tsurumure 1 mēnesi atpakaļ
vecāks
revīzija
f0f9c6e75a

+ 51 - 51
src/main/java/com/backendsys/modules/queue/controller/QueueController.java

@@ -1,51 +1,51 @@
-package com.backendsys.modules.queue.controller;
-
-import com.backendsys.modules.common.config.security.annotations.Anonymous;
-import com.backendsys.modules.queue.entity.QueuePosition;
-import com.backendsys.modules.queue.entity.QueueRequest;
-import com.backendsys.modules.queue.service.QueueService;
-import org.springframework.beans.factory.annotation.Autowired;
-import org.springframework.web.bind.annotation.*;
-
-@RestController
-public class QueueController {
-
-    @Autowired
-    private QueueService queueService;
-
-    /**
-     * 提交请求并加入队列
-     */
-    @Anonymous
-    @PostMapping("/api/queue/submit")
-    public String submitRequest() {
-        QueueRequest request = new QueueRequest();
-        int position = queueService.enqueue("taskQueue", request);
-        return "Your request has been submitted. You are at position " + position + " in the queue. request id: " + request.getId();
-    }
-
-    /**
-     * 启动处理队列中的请求
-     */
-    @Anonymous
-    @GetMapping("/api/queue/startProcessing")
-    public String startProcessing() {
-        queueService.startProcessing("taskQueue");
-        return "Processing has started.";
-    }
-
-    /**
-     * 查询队列
-     */
-    @Anonymous
-    @GetMapping("/api/queue/position")
-    public String getPosition(@RequestParam String requestId) {
-        QueuePosition positionInfo = queueService.getPosition("taskQueue", requestId);
-        if (positionInfo.getPosition() == -1) {
-            return "Request not found.";
-        } else {
-            return "Your request is at position " + positionInfo.getPosition() + " out of " + positionInfo.getTotal() + " in the queue.";
-        }
-    }
-
-}
+//package com.backendsys.modules.queue.controller;
+//
+//import com.backendsys.modules.common.config.security.annotations.Anonymous;
+//import com.backendsys.modules.queue.entity.QueuePosition;
+//import com.backendsys.modules.queue.entity.QueueRequest;
+//import com.backendsys.modules.queue.service.QueueService;
+//import org.springframework.beans.factory.annotation.Autowired;
+//import org.springframework.web.bind.annotation.*;
+//
+//@RestController
+//public class QueueController {
+//
+//    @Autowired
+//    private QueueService queueService;
+//
+//    /**
+//     * 提交请求并加入队列
+//     */
+//    @Anonymous
+//    @PostMapping("/api/queue/submit")
+//    public String submitRequest() {
+//        QueueRequest request = new QueueRequest();
+//        int position = queueService.enqueue("taskQueue", request);
+//        return "Your request has been submitted. You are at position " + position + " in the queue. request id: " + request.getId();
+//    }
+//
+//    /**
+//     * 启动处理队列中的请求
+//     */
+//    @Anonymous
+//    @GetMapping("/api/queue/startProcessing")
+//    public String startProcessing() {
+//        queueService.startProcessing("taskQueue");
+//        return "Processing has started.";
+//    }
+//
+//    /**
+//     * 查询队列
+//     */
+//    @Anonymous
+//    @GetMapping("/api/queue/position")
+//    public String getPosition(@RequestParam String requestId) {
+//        QueuePosition positionInfo = queueService.getPosition("taskQueue", requestId);
+//        if (positionInfo.getPosition() == -1) {
+//            return "Request not found.";
+//        } else {
+//            return "Your request is at position " + positionInfo.getPosition() + " out of " + positionInfo.getTotal() + " in the queue.";
+//        }
+//    }
+//
+//}

+ 37 - 0
src/main/java/com/backendsys/modules/queue/controller/TaskStatusController.java

@@ -0,0 +1,37 @@
+package com.backendsys.modules.queue.controller;
+
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.data.redis.core.RedisTemplate;
+import org.springframework.http.HttpStatus;
+import org.springframework.web.bind.annotation.GetMapping;
+import org.springframework.web.bind.annotation.PathVariable;
+import org.springframework.web.bind.annotation.RequestMapping;
+import org.springframework.web.bind.annotation.RestController;
+import org.springframework.web.server.ResponseStatusException;
+
+import java.util.Map;
+import java.util.stream.Collectors;
+
+@RestController
+public class TaskStatusController {
+
+    @Autowired
+    private RedisTemplate redisTemplate;
+
+    @GetMapping("/api/tasks/{taskId}/status")
+    public Map<String, String> getTaskStatus(@PathVariable String taskId) {
+        String statusKey = "task:status:" + taskId;
+        Map<Object, Object> entries = redisTemplate.opsForHash().entries(statusKey);
+
+        if (entries.isEmpty()) {
+            System.out.println("Task not found");
+        }
+
+        // 转换为String类型返回
+        return entries.entrySet().stream()
+                .collect(Collectors.toMap(
+                        e -> e.getKey().toString(),
+                        e -> e.getValue().toString()
+                ));
+    }
+}

+ 15 - 0
src/main/java/com/backendsys/modules/queue/entity/GenerateRequest.java

@@ -0,0 +1,15 @@
+package com.backendsys.modules.queue.entity;
+
+
+import lombok.Data;
+
+import java.util.HashMap;
+import java.util.Map;
+
+@Data
+public class GenerateRequest {
+    private String prompt;
+    private String userId;
+    private String model = "stable-diffusion-v2.1";              // 默认模型
+    private Map<String, Object> parameters = new HashMap<>();
+}

+ 14 - 14
src/main/java/com/backendsys/modules/queue/entity/QueuePosition.java

@@ -1,14 +1,14 @@
-package com.backendsys.modules.queue.entity;
-
-import lombok.Data;
-
-@Data
-public class QueuePosition {
-    private int position; // 当前排队位置
-    private int total; // 队列总数
-
-    public QueuePosition(int position, int total) {
-        this.position = position;
-        this.total = total;
-    }
-}
+//package com.backendsys.modules.queue.entity;
+//
+//import lombok.Data;
+//
+//@Data
+//public class QueuePosition {
+//    private int position; // 当前排队位置
+//    private int total; // 队列总数
+//
+//    public QueuePosition(int position, int total) {
+//        this.position = position;
+//        this.total = total;
+//    }
+//}

+ 15 - 15
src/main/java/com/backendsys/modules/queue/entity/QueueRequest.java

@@ -1,15 +1,15 @@
-package com.backendsys.modules.queue.entity;
-
-import lombok.Data;
-
-import java.util.UUID;
-
-@Data
-public class QueueRequest {
-    private String id;
-    private int position;
-
-    public QueueRequest() {
-        this.id = UUID.randomUUID().toString(); // 自动生成唯一标识符
-    }
-}
+//package com.backendsys.modules.queue.entity;
+//
+//import lombok.Data;
+//
+//import java.util.UUID;
+//
+//@Data
+//public class QueueRequest {
+//    private String id;
+//    private int position;
+//
+//    public QueueRequest() {
+//        this.id = UUID.randomUUID().toString(); // 自动生成唯一标识符
+//    }
+//}

+ 104 - 104
src/main/java/com/backendsys/modules/queue/service/QueueService.java

@@ -1,104 +1,104 @@
-package com.backendsys.modules.queue.service;
-
-import cn.hutool.core.convert.Convert;
-import com.backendsys.modules.queue.entity.QueuePosition;
-import com.backendsys.modules.queue.entity.QueueRequest;
-import org.springframework.beans.factory.annotation.Autowired;
-import org.springframework.data.redis.core.StringRedisTemplate;
-import org.springframework.stereotype.Service;
-
-import java.util.concurrent.atomic.AtomicInteger;
-
-@Service
-public class QueueService {
-    private final StringRedisTemplate redisTemplate;
-    private final AtomicInteger counter = new AtomicInteger(0);
-
-    @Autowired
-    public QueueService(StringRedisTemplate redisTemplate) {
-        this.redisTemplate = redisTemplate;
-    }
-
-    /**
-     * 分配排队编号
-     */
-    public int enqueue(String queueKey, QueueRequest queueRequest) {
-        // 为每个请求分配一个排队编号
-        Long size = redisTemplate.opsForZSet().size(queueKey);
-        int position = (size == null) ? 1 : size.intValue() + 1;
-        //int position = counter.incrementAndGet();
-        System.out.println("排号: " + position + ", request_id: " + queueRequest.getId());
-
-        queueRequest.setPosition(position);
-        redisTemplate.opsForZSet().add(queueKey, queueRequest.getId(), position);
-
-        return position;
-    }
-
-    /**
-     * 开始排队
-     */
-    public void startProcessing(String queueKey) {
-        new Thread(() -> {
-            while (true) {
-                // 从有序集合中取出第一个请求
-                String requestId = Convert.toStr(redisTemplate.opsForZSet().popMin(queueKey));
-                if (requestId != null) {
-                    QueueRequest queueRequest = getRequestById(requestId);
-                    int position = queueRequest.getPosition();
-                    // 处理请求
-                    processRequest(queueRequest);
-                    // 可以通知用户处理完成
-                    notifyUser(queueRequest, position);
-                }
-            }
-        }).start();
-    }
-
-    private void processRequest(QueueRequest queueRequest) {
-        // 模拟耗时操作
-        try {
-            Thread.sleep(10 * 1000);
-        } catch (InterruptedException e) {
-            Thread.currentThread().interrupt();
-        }
-    }
-
-    /**
-     * 通知
-     */
-    private void notifyUser(QueueRequest queueRequest, int position) {
-        // 通知用户处理完成
-        System.out.println("Request " + position + " processed.");
-    }
-
-    private QueueRequest getRequestById(String requestId) {
-        // 根据请求ID获取请求对象
-        return new QueueRequest();
-    }
-
-    /**
-     * 获取请求的排队位置和队列总数
-     * @param requestId 请求ID
-     * @return 一个包含排队位置和队列总数的对象
-     */
-    public QueuePosition getPosition(String queueKey, String requestId) {
-        // 获取请求的排队位置
-        Long rank = redisTemplate.opsForZSet().rank(queueKey, requestId);
-        if (rank == null) {
-            return new QueuePosition(-1, 0); // 请求不存在
-        }
-
-        // 获取队列的总大小(未处理的请求数量)
-        Long size = redisTemplate.opsForZSet().size(queueKey);
-        int total = (size == null) ? 0 : size.intValue();
-
-        // 当前排队位置从0开始,加1表示实际排队位置
-        int position = rank.intValue() + 1;
-
-        return new QueuePosition(position, total);
-    }
-
-
-}
-
+//package com.backendsys.modules.queue.service;
+//
+//import cn.hutool.core.convert.Convert;
+//import com.backendsys.modules.queue.entity.QueuePosition;
+//import com.backendsys.modules.queue.entity.QueueRequest;
+//import org.springframework.beans.factory.annotation.Autowired;
+//import org.springframework.data.redis.core.StringRedisTemplate;
+//import org.springframework.stereotype.Service;
+//
+//import java.util.concurrent.atomic.AtomicInteger;
+//
+//@Service
+//public class QueueService {
+//    private final StringRedisTemplate redisTemplate;
+//    private final AtomicInteger counter = new AtomicInteger(0);
+//
+//    @Autowired
+//    public QueueService(StringRedisTemplate redisTemplate) {
+//        this.redisTemplate = redisTemplate;
+//    }
+//
+//    /**
+//     * 分配排队编号
+//     */
+//    public int enqueue(String queueKey, QueueRequest queueRequest) {
+//        // 为每个请求分配一个排队编号
+//        Long size = redisTemplate.opsForZSet().size(queueKey);
+//        int position = (size == null) ? 1 : size.intValue() + 1;
+//        //int position = counter.incrementAndGet();
+//        System.out.println("排号: " + position + ", request_id: " + queueRequest.getId());
+//
+//        queueRequest.setPosition(position);
+//        redisTemplate.opsForZSet().add(queueKey, queueRequest.getId(), position);
+//
+//        return position;
+//    }
+//
+//    /**
+//     * 开始排队
+//     */
+//    public void startProcessing(String queueKey) {
+//        new Thread(() -> {
+//            while (true) {
+//                // 从有序集合中取出第一个请求
+//                String requestId = Convert.toStr(redisTemplate.opsForZSet().popMin(queueKey));
+//                if (requestId != null) {
+//                    QueueRequest queueRequest = getRequestById(requestId);
+//                    int position = queueRequest.getPosition();
+//                    // 处理请求
+//                    processRequest(queueRequest);
+//                    // 可以通知用户处理完成
+//                    notifyUser(queueRequest, position);
+//                }
+//            }
+//        }).start();
+//    }
+//
+//    private void processRequest(QueueRequest queueRequest) {
+//        // 模拟耗时操作
+//        try {
+//            Thread.sleep(10 * 1000);
+//        } catch (InterruptedException e) {
+//            Thread.currentThread().interrupt();
+//        }
+//    }
+//
+//    /**
+//     * 通知
+//     */
+//    private void notifyUser(QueueRequest queueRequest, int position) {
+//        // 通知用户处理完成
+//        System.out.println("Request " + position + " processed.");
+//    }
+//
+//    private QueueRequest getRequestById(String requestId) {
+//        // 根据请求ID获取请求对象
+//        return new QueueRequest();
+//    }
+//
+//    /**
+//     * 获取请求的排队位置和队列总数
+//     * @param requestId 请求ID
+//     * @return 一个包含排队位置和队列总数的对象
+//     */
+//    public QueuePosition getPosition(String queueKey, String requestId) {
+//        // 获取请求的排队位置
+//        Long rank = redisTemplate.opsForZSet().rank(queueKey, requestId);
+//        if (rank == null) {
+//            return new QueuePosition(-1, 0); // 请求不存在
+//        }
+//
+//        // 获取队列的总大小(未处理的请求数量)
+//        Long size = redisTemplate.opsForZSet().size(queueKey);
+//        int total = (size == null) ? 0 : size.intValue();
+//
+//        // 当前排队位置从0开始,加1表示实际排队位置
+//        int position = rank.intValue() + 1;
+//
+//        return new QueuePosition(position, total);
+//    }
+//
+//
+//}
+//

+ 4 - 0
src/main/java/com/backendsys/modules/queue/service/TaskQueueService.java

@@ -0,0 +1,4 @@
+package com.backendsys.modules.queue.service;
+
+public interface TaskQueueService {
+}

+ 61 - 0
src/main/java/com/backendsys/modules/queue/service/impl/TaskQueueServiceImpl.java

@@ -0,0 +1,61 @@
+package com.backendsys.modules.queue.service.impl;
+
+import com.backendsys.modules.queue.entity.GenerateRequest;
+import com.backendsys.modules.queue.service.TaskQueueService;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.data.redis.connection.stream.ObjectRecord;
+import org.springframework.data.redis.connection.stream.StreamRecords;
+import org.springframework.data.redis.core.RedisTemplate;
+import org.springframework.stereotype.Service;
+
+import java.time.Instant;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+
+@Service
+public class TaskQueueServiceImpl implements TaskQueueService {
+
+    @Value("${spring.data.redis.stream.key}")
+    private String streamKey;
+
+    @Autowired
+    private RedisTemplate redisTemplate;
+
+    // 添加任务到Stream
+    public String addTask(Map<String, String> taskData) {
+        ObjectRecord<String, Map<String, String>> record =
+                StreamRecords.newRecord()
+                        .ofObject(taskData)
+                        .withStreamKey(streamKey);
+
+        return redisTemplate.opsForStream()
+                .add(record)
+                .getValue();
+    }
+
+    // 初始化任务状态(Redis Hash)
+    public void initTaskStatus(String taskId, GenerateRequest request) {
+        String statusKey = "task:status:" + taskId;
+
+        Map<String, String> status = new HashMap<>();
+        status.put("status", "QUEUED");
+        status.put("progress", "0");
+        status.put("prompt", request.getPrompt());
+        status.put("createdAt", Instant.now().toString());
+
+        redisTemplate.opsForHash().putAll(statusKey, status);
+
+        // 设置24小时过期
+        redisTemplate.expire(statusKey, 24, TimeUnit.HOURS);
+    }
+
+    // 获取队列位置
+    public String getQueuePosition(String taskId) {
+        // 使用XPENDING获取任务在消费者组中的位置
+        // 实际实现需根据业务逻辑计算
+        return "10"; // 示例值
+    }
+
+}

+ 4 - 0
src/main/resources/application-local.yml

@@ -47,6 +47,10 @@ spring:
       host: 127.0.0.1
       port: 6388
       password: 123456
+      stream:
+        key: comfyui-stream-task      # 任务队列名称
+        group: comfyui-stream-group   # 消费者组名
+
 #    cache:
 #      type: redis
 #      redis: