浏览代码

Dev ComfyUI

tsurumure 2 月之前
父节点
当前提交
a4a292315d

+ 33 - 0
src/main/java/com/backendsys/modules/crt/controller/CrtGenerateController.java

@@ -0,0 +1,33 @@
+package com.backendsys.modules.crt.controller;
+
+import com.backendsys.modules.common.config.security.annotations.Anonymous;
+import com.backendsys.modules.common.utils.Result;
+import com.backendsys.modules.crt.entity.CrtDramaProjectStoryboard;
+import com.backendsys.modules.crt.service.CrtGenerateService;
+import io.swagger.v3.oas.annotations.Operation;
+import io.swagger.v3.oas.annotations.tags.Tag;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.validation.annotation.Validated;
+import org.springframework.web.bind.annotation.PostMapping;
+import org.springframework.web.bind.annotation.RequestBody;
+import org.springframework.web.bind.annotation.RestController;
+
+@Validated
+@RestController
+@Tag(name = "短剧创作-生成")
+public class CrtGenerateController {
+
+    @Autowired
+    private CrtGenerateService crtGenerateService;
+
+    // 生成图片
+    @Operation(summary = "生成图片")
+    @Anonymous
+    @PostMapping("/api/crt/generate/image")
+    public Result generateImage(@Validated(CrtDramaProjectStoryboard.GenerateImage.class) @RequestBody CrtDramaProjectStoryboard crtDramaProjectStoryboard) {
+        return Result.success().put("data", crtGenerateService.generateImage(crtDramaProjectStoryboard));
+    }
+
+    // 生成视频
+
+}

+ 2 - 1
src/main/java/com/backendsys/modules/crt/entity/CrtDramaProjectStoryboard.java

@@ -26,12 +26,13 @@ public class CrtDramaProjectStoryboard {
     public static interface Update{}
     public static interface Clear{}
     public static interface Delete{}
+    public static interface GenerateImage{}
 
     @TableId(type = IdType.AUTO)
     private Long id;
 
     @TableField(exist = false)
-    @NotNull(message = "分镜ID不能为空", groups = { Update.class})
+    @NotNull(message = "分镜ID不能为空", groups = { Update.class, GenerateImage.class })
     private Long drama_project_storyboard_id;
 
     private Long user_id;

+ 11 - 0
src/main/java/com/backendsys/modules/crt/service/CrtGenerateService.java

@@ -0,0 +1,11 @@
+package com.backendsys.modules.crt.service;
+
+import com.backendsys.modules.crt.entity.CrtDramaProjectStoryboard;
+
+import java.util.Map;
+
+public interface CrtGenerateService {
+
+    Map<String, Object> generateImage(CrtDramaProjectStoryboard crtDramaProjectStoryboard);
+
+}

+ 42 - 0
src/main/java/com/backendsys/modules/crt/service/impl/CrtGenerateServiceImpl.java

@@ -0,0 +1,42 @@
+package com.backendsys.modules.crt.service.impl;
+
+import com.backendsys.exception.CustException;
+import com.backendsys.modules.common.utils.Result;
+import com.backendsys.modules.crt.dao.CrtDramaProjectStoryboardDao;
+import com.backendsys.modules.crt.entity.CrtDramaProjectStoryboard;
+import com.backendsys.modules.crt.service.CrtGenerateService;
+import com.backendsys.modules.sdk.comfyui.entity.CFPromptResponse;
+import com.backendsys.modules.sdk.comfyui.service.ComfyUIService;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.stereotype.Service;
+import reactor.core.publisher.Mono;
+
+import java.util.Map;
+
+@Service
+public class CrtGenerateServiceImpl implements CrtGenerateService {
+
+    @Autowired
+    private ComfyUIService comfyUIService;
+
+    @Autowired
+    private CrtDramaProjectStoryboardDao crtDramaProjectStoryboardDao;
+
+    @Override
+    public Map<String, Object> generateImage(CrtDramaProjectStoryboard crtDramaProjectStoryboard) {
+
+        Long drama_project_storyboard_id = crtDramaProjectStoryboard.getDrama_project_storyboard_id();
+
+        CrtDramaProjectStoryboard detail = crtDramaProjectStoryboardDao.selectById(drama_project_storyboard_id);
+        if (detail == null) throw new CustException("分镜不存在");
+
+        String prompt = "{}";
+
+        Mono<CFPromptResponse> cfPromptResponse = comfyUIService.prompt(prompt);
+        CFPromptResponse response = cfPromptResponse.block();
+        System.out.println("结果: " + response);
+
+        return Map.of("drama_project_storyboard_id", drama_project_storyboard_id);
+    }
+
+}

+ 6 - 2
src/main/java/com/backendsys/modules/sdk/comfyui/controller/ComfyUIDemoController.java

@@ -3,6 +3,7 @@ package com.backendsys.modules.sdk.comfyui.controller;
 import com.backendsys.exception.CustException;
 import com.backendsys.modules.common.config.security.annotations.Anonymous;
 import com.backendsys.modules.sdk.comfyui.service.ComfyUIService;
+import com.backendsys.modules.sdk.comfyui.service.ComfyUISocketService;
 import com.tencentcloudapi.tione.v20211111.models.ChatCompletionResponse;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.http.MediaType;
@@ -23,13 +24,16 @@ public class ComfyUIDemoController {
     @Autowired
     private ComfyUIService comfyUIService;
 
+    @Autowired
+    private ComfyUISocketService comfyUISocketService;
+
     /**
      * [ComfyUI] 创建 WebSocket 监听连接
      */
     @Anonymous
     @PostMapping("/api/comfyui/ws/connect")
     public String connect(String clientId) {
-        comfyUIService.connect(clientId, "ws://43.128.1.201:8007/ws").subscribe();
+        comfyUISocketService.connect(clientId, "ws://43.128.1.201:8007/ws").subscribe();
         return "Connection initiated for: " + clientId;
     }
 
@@ -39,7 +43,7 @@ public class ComfyUIDemoController {
     @Anonymous
     @PostMapping("/api/comfyui/ws/disconnect")
     public String disconnect(String clientId) {
-        comfyUIService.disconnect(clientId).subscribe();
+        comfyUISocketService.disconnect(clientId).subscribe();
         return "Disconnected: " + clientId;
     }
 

+ 12 - 0
src/main/java/com/backendsys/modules/sdk/comfyui/entity/CFPromptResponse.java

@@ -0,0 +1,12 @@
+package com.backendsys.modules.sdk.comfyui.entity;
+
+import lombok.Data;
+
+@Data
+public class CFPromptResponse {
+
+    private String prompt_id;
+    private Integer number;
+    private Object node_errors;
+
+}

+ 3 - 4
src/main/java/com/backendsys/modules/sdk/comfyui/service/ComfyUIService.java

@@ -1,12 +1,11 @@
 package com.backendsys.modules.sdk.comfyui.service;
 
+import com.backendsys.modules.sdk.comfyui.entity.CFPromptResponse;
 import reactor.core.publisher.Mono;
 
 public interface ComfyUIService {
 
-    // [ComfyUI] 创建 WebSocket 监听连接
-    Mono<Void> connect(String clientId, String wsUrl);
-    // [ComfyUI] 断开 WebSocket 监听连接
-    Mono<Void> disconnect(String clientId);
+    // [ComfyUI] 执行任务
+    Mono<CFPromptResponse> prompt(String prompt);
 
 }

+ 13 - 0
src/main/java/com/backendsys/modules/sdk/comfyui/service/ComfyUISocketService.java

@@ -0,0 +1,13 @@
+package com.backendsys.modules.sdk.comfyui.service;
+
+import reactor.core.publisher.Mono;
+
+public interface ComfyUISocketService {
+
+    // [ComfyUI] 创建 WebSocket 监听连接
+    Mono<Void> connect(String clientId, String wsUrl);
+
+    // [ComfyUI] 断开 WebSocket 监听连接
+    Mono<Void> disconnect(String clientId);
+
+}

+ 44 - 100
src/main/java/com/backendsys/modules/sdk/comfyui/service/impl/ComfyUIServiceImpl.java

@@ -1,121 +1,65 @@
 package com.backendsys.modules.sdk.comfyui.service.impl;
 
+import cn.hutool.core.convert.Convert;
+import com.backendsys.modules.common.Filter.WebClientFilter;
+import com.backendsys.modules.sdk.comfyui.entity.CFPromptResponse;
 import com.backendsys.modules.sdk.comfyui.service.ComfyUIService;
-import com.tencentcloudapi.tione.v20211111.models.ChatCompletionResponse;
-import org.apache.http.client.methods.HttpPost;
-import org.apache.http.impl.client.CloseableHttpClient;
-import org.apache.http.impl.client.HttpClients;
 import org.springframework.beans.factory.annotation.Value;
+import org.springframework.http.MediaType;
 import org.springframework.stereotype.Service;
-import org.springframework.web.reactive.socket.WebSocketMessage;
-import org.springframework.web.reactive.socket.WebSocketSession;
-import org.springframework.web.reactive.socket.client.ReactorNettyWebSocketClient;
-import org.springframework.web.reactive.socket.client.WebSocketClient;
-import reactor.core.publisher.Flux;
-import reactor.core.publisher.FluxSink;
+import org.springframework.util.LinkedMultiValueMap;
+import org.springframework.util.MultiValueMap;
+import org.springframework.web.reactive.function.client.WebClient;
+import org.springframework.web.util.UriComponentsBuilder;
 import reactor.core.publisher.Mono;
-import reactor.netty.http.client.HttpClient;
-import reactor.util.retry.Retry;
 
-import java.net.URI;
-import java.nio.ByteBuffer;
-import java.nio.charset.StandardCharsets;
-import java.time.Duration;
-import java.util.Base64;
-import java.util.Map;
-import java.util.concurrent.ConcurrentHashMap;
+import java.util.UUID;
 
 @Service
 public class ComfyUIServiceImpl implements ComfyUIService {
 
-    @Value("${comfyui.token}")
-    private String COMFYUI_TOKEN = "";
-
-//    // 单例 WebSocketClient(线程安全)
-//    private WebSocketClient webSocketClient;
-
-    // 管理多个连接
-    private final Map<String, WebSocketSession> sessions = new ConcurrentHashMap<>();
-
-    /**
-     * 创建带有认证头的 WebSocketClient
-     * @param token 认证令牌
-     * @return 配置好的 WebSocketClient
-     */
-    private WebSocketClient createWebSocketClientWithToken(String token) {
-        HttpClient httpClient = HttpClient.create()
-            .headers(headers -> headers.add("Authorization", "Bearer " + token))
-            .responseTimeout(Duration.ofSeconds(30));  // 30秒超时
-        return new ReactorNettyWebSocketClient(httpClient);
-    }
-
-    /**
-     * [ComfyUI] 创建 WebSocket 监听连接
-     * @param clientId 客户端ID(用于标识连接)
-     * @param wsUrl WebSocket 地址
-     * @return Mono<Void> 表示连接操作
-     */
-    @Override
-    public Mono<Void> connect(String clientId, String wsUrl) {
-
-        return Mono.defer(() -> {
-            if (sessions.containsKey(clientId)) {
-                return Mono.error(new IllegalStateException("Connection already exists for client: " + clientId));
-            }
-
-            // 动态创建带有认证头的客户端
-            WebSocketClient clientWithAuth = createWebSocketClientWithToken(COMFYUI_TOKEN);
-            return clientWithAuth.execute(URI.create(wsUrl + "?clientId=" + clientId), session -> {
-                // 保存会话
-                sessions.put(clientId, session);
-
-                // 接收消息
-                Flux<String> incomingMessages = session.receive()
-                    .map(WebSocketMessage::getPayloadAsText)
-                    .doOnNext(message -> {
-
-                        System.out.println("(doOnNext) Received from " + clientId + ": " + message);
-//                            // 转发到消息总线
-//                            messageSink.tryEmitNext(message);
-                    })
-                    .doOnError(e -> {
-                        System.err.println("(doOnError) Error for " + clientId + ": " + e.getMessage());
-                    })
-                    .doFinally(signal -> {
-                        System.out.println("(doFinally) Connection closed for " + clientId + ": " + signal);
-                        sessions.remove(clientId);
-                    });
-
-                // 需要返回一个Mono<Void>来表示处理完成
-                return incomingMessages.then();
-            });
-        });
 
+    @Value("${comfyui.token}")
+    private String COMFYUI_TOKEN;
+    private final String BASE_URL = "http://43.128.1.201:8007";
+
+    private WebClient webClient;
+    public WebClient getWebClient() {
+        if (webClient == null) {
+            webClient = WebClient.builder().baseUrl(BASE_URL).filter(WebClientFilter.logFilter).build();
+        }
+        return webClient;
     }
 
     /**
-     * [ComfyUI] 断开 WebSocket 监听连接
-     * @param clientId 客户端ID
-     * @return Mono<Void> 表示断开操作
+     * [ComfyUI] 执行任务
      */
     @Override
-    public Mono<Void> disconnect(String clientId) {
-        return Mono.fromRunnable(() -> {
-            WebSocketSession session = sessions.get(clientId);
-            if (session != null) {
-                System.out.println("disconnect success! clientId: " + clientId);
-                session.close().subscribe();
-                sessions.remove(clientId);
-            }
-        });
-    }
+    public Mono<CFPromptResponse> prompt(String prompt) {
+
+        String client_id = Convert.toStr(UUID.randomUUID());
+
+        MultiValueMap<String, String> params = new LinkedMultiValueMap<>();
+        params.add("client_id", client_id);          // Unix时间戳、单位ms
+        params.add("prompt", prompt);
+
+        System.out.println("params = " + params);
+
+        String url = "/prompt?token=" + COMFYUI_TOKEN;
+        String uri = UriComponentsBuilder.fromUriString(url).toUriString();
+        WebClient webClient = getWebClient();
+        return webClient.get()
+                .uri(uri)
+                .accept(MediaType.APPLICATION_JSON)
+                .exchangeToMono(response -> {
+                    return response.bodyToMono(CFPromptResponse.class);
+//                    if (response.statusCode().is2xxSuccessful()) {
+//                        return response.bodyToMono(CFPromptResponse.class); // 成功响应
+//                    } else {
+//                        return response.bodyToMono(CFPromptResponse.class).map(e -> KLingUtil.mapErrorResponse(e));
+//                    }
+                });
 
-    /**
-     * 获取所有活动的连接ID
-     * @return 连接ID集合
-     */
-    public Flux<String> getActiveConnections() {
-        return Flux.fromIterable(sessions.keySet());
     }
 
 }

+ 111 - 0
src/main/java/com/backendsys/modules/sdk/comfyui/service/impl/ComfyUISocketServiceImpl.java

@@ -0,0 +1,111 @@
+package com.backendsys.modules.sdk.comfyui.service.impl;
+
+import com.backendsys.modules.sdk.comfyui.service.ComfyUISocketService;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.stereotype.Service;
+import org.springframework.web.reactive.socket.WebSocketMessage;
+import org.springframework.web.reactive.socket.WebSocketSession;
+import org.springframework.web.reactive.socket.client.ReactorNettyWebSocketClient;
+import org.springframework.web.reactive.socket.client.WebSocketClient;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
+import reactor.netty.http.client.HttpClient;
+
+import java.net.URI;
+import java.time.Duration;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+@Service
+public class ComfyUISocketServiceImpl implements ComfyUISocketService {
+
+    @Value("${comfyui.token}")
+    private String COMFYUI_TOKEN;
+
+    /**
+     * 管理多个连接
+     */
+    private final Map<String, WebSocketSession> sessions = new ConcurrentHashMap<>();
+
+    /**
+     * 创建带有认证头的 WebSocketClient
+     * @param token 认证令牌
+     * @return 配置好的 WebSocketClient
+     */
+    private WebSocketClient createWebSocketClientWithToken(String token) {
+        HttpClient httpClient = HttpClient.create()
+            .headers(headers -> headers.add("Authorization", "Bearer " + token))
+            .responseTimeout(Duration.ofSeconds(30));  // 30秒超时
+        return new ReactorNettyWebSocketClient(httpClient);
+    }
+
+    /**
+     * [ComfyUI] 创建 WebSocket 监听连接
+     * @param clientId 客户端ID(用于标识连接)
+     * @param wsUrl WebSocket 地址
+     * @return Mono<Void> 表示连接操作
+     */
+    @Override
+    public Mono<Void> connect(String clientId, String wsUrl) {
+
+        return Mono.defer(() -> {
+            if (sessions.containsKey(clientId)) {
+                return Mono.error(new IllegalStateException("Connection already exists for client: " + clientId));
+            }
+
+            // 动态创建带有认证头的客户端
+            WebSocketClient clientWithAuth = createWebSocketClientWithToken(COMFYUI_TOKEN);
+            return clientWithAuth.execute(URI.create(wsUrl + "?clientId=" + clientId), session -> {
+                // 保存会话
+                sessions.put(clientId, session);
+
+                // 接收消息
+                Flux<String> incomingMessages = session.receive()
+                    .map(WebSocketMessage::getPayloadAsText)
+                    .doOnNext(message -> {
+
+                        System.out.println("(doOnNext) Received from " + clientId + ": " + message);
+//                            // 转发到消息总线
+//                            messageSink.tryEmitNext(message);
+                    })
+                    .doOnError(e -> {
+                        System.err.println("(doOnError) Error for " + clientId + ": " + e.getMessage());
+                    })
+                    .doFinally(signal -> {
+                        System.out.println("(doFinally) Connection closed for " + clientId + ": " + signal);
+                        sessions.remove(clientId);
+                    });
+
+                // 需要返回一个Mono<Void>来表示处理完成
+                return incomingMessages.then();
+            });
+        });
+
+    }
+
+    /**
+     * [ComfyUI] 断开 WebSocket 监听连接
+     * @param clientId 客户端ID
+     * @return Mono<Void> 表示断开操作
+     */
+    @Override
+    public Mono<Void> disconnect(String clientId) {
+        return Mono.fromRunnable(() -> {
+            WebSocketSession session = sessions.get(clientId);
+            if (session != null) {
+                System.out.println("disconnect success! clientId: " + clientId);
+                session.close().subscribe();
+                sessions.remove(clientId);
+            }
+        });
+    }
+
+    /**
+     * 获取所有活动的连接ID
+     * @return 连接ID集合
+     */
+    public Flux<String> getActiveConnections() {
+        return Flux.fromIterable(sessions.keySet());
+    }
+
+}

+ 8 - 0
src/main/java/com/backendsys/modules/sdk/comfyui/utils/ComfyUtil.java

@@ -0,0 +1,8 @@
+package com.backendsys.modules.sdk.comfyui.utils;
+
+import org.springframework.stereotype.Component;
+
+@Component
+public class ComfyUtil {
+
+}