Просмотр исходного кода

接入ComfyUI websocket (未接入sse)

tsurumure 2 месяцев назад
Родитель
Сommit
62e7ade2b0

+ 107 - 107
src/main/java/com/backendsys/config/WebSocket/WebSocketConfig.java

@@ -1,107 +1,107 @@
-package com.backendsys.config.WebSocket;
-
-import cn.hutool.core.util.StrUtil;
-import com.backendsys.modules.common.config.security.utils.JwtUtil;
-
-import io.jsonwebtoken.Claims;
-import lombok.RequiredArgsConstructor;
-import org.springframework.beans.factory.annotation.Autowired;
-import org.springframework.context.annotation.Configuration;
-import org.springframework.http.HttpHeaders;
-import org.springframework.messaging.Message;
-import org.springframework.messaging.MessageChannel;
-import org.springframework.messaging.simp.config.ChannelRegistration;
-import org.springframework.messaging.simp.config.MessageBrokerRegistry;
-import org.springframework.messaging.simp.stomp.StompCommand;
-import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
-import org.springframework.messaging.support.ChannelInterceptor;
-import org.springframework.messaging.support.MessageHeaderAccessor;
-import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
-import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
-import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;
-
-@Configuration
-@EnableWebSocketMessageBroker
-@RequiredArgsConstructor
-public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {
-
-
-    @Autowired
-    private JwtUtil jwtUtil;
-
-    /**
-     * 注册一个端点,客户端通过这个端点进行连接
-     */
-    @Override
-    public void registerStompEndpoints(StompEndpointRegistry registry) {
-        registry
-                .addEndpoint("/ws")   // 注册了一个 /ws 的端点
-                .setAllowedOriginPatterns("*") // 允许跨域的 WebSocket 连接
-                .withSockJS();  // 启用 SockJS (浏览器不支持WebSocket,SockJS 将会提供兼容性支持)
-    }
-
-    /**
-     * 配置消息代理
-     */
-    @Override
-    public void configureMessageBroker(MessageBrokerRegistry registry) {
-        // 客户端发送消息的请求前缀
-        registry.setApplicationDestinationPrefixes("/app");
-        // 客户端订阅消息的请求前缀,topic一般用于广播推送,queue用于点对点推送
-        registry.enableSimpleBroker("/topic", "/queue");
-        // 服务端通知客户端的前缀,可以不设置,默认为user
-        registry.setUserDestinationPrefix("/user");
-    }
-
-    /**
-     * 配置客户端入站通道拦截器
-     * <p>
-     * 添加 ChannelInterceptor 拦截器,用于在消息发送前,从请求头中获取 token 并解析出用户信息(username),用于点对点发送消息给指定用户
-     *
-     * @param registration 通道注册器
-     */
-    @Override
-    public void configureClientInboundChannel(ChannelRegistration registration) {
-        registration.interceptors(new ChannelInterceptor() {
-            @Override
-            public Message<?> preSend(Message<?> message, MessageChannel channel) {
-
-                StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
-
-                // (改) 如果是连接请求(CONNECT 命令),从请求头中取出 token 并设置到认证信息中
-                if (accessor != null && StompCommand.CONNECT.equals(accessor.getCommand())) {
-
-                   // 从连接头中提取授权令牌
-                   String bearerToken = accessor.getFirstNativeHeader(HttpHeaders.AUTHORIZATION);
-
-                   // 验证令牌格式并提取用户信息
-                   if (StrUtil.isNotBlank(bearerToken) && bearerToken.startsWith("Bearer ")) {
-                       try {
-                           // 移除 "Bearer " 前缀,从令牌中提取用户信息(username), 并设置到认证信息中
-                           String tokenWithoutPrefix = bearerToken.substring(7);
-
-                           Claims tokenInfo = jwtUtil.extractAllClaims(tokenWithoutPrefix);
-                           String username = (String) tokenInfo.get("username");
-
-                           if (StrUtil.isNotBlank(username)) {
-                               accessor.setUser(() -> username);
-                               return message;
-                           }
-                           
-                       } catch (Exception e) {
-                           throw new RuntimeException("Failed to process authentication token.");
-                       }
-                   }
-
-
-                }
-                // 不是连接请求,直接放行
-
-
-
-                return ChannelInterceptor.super.preSend(message, channel);
-            }
-        });
-    }
-
-}
+//package com.backendsys.config.WebSocket;
+//
+//import cn.hutool.core.util.StrUtil;
+//import com.backendsys.modules.common.config.security.utils.JwtUtil;
+//
+//import io.jsonwebtoken.Claims;
+//import lombok.RequiredArgsConstructor;
+//import org.springframework.beans.factory.annotation.Autowired;
+//import org.springframework.context.annotation.Configuration;
+//import org.springframework.http.HttpHeaders;
+//import org.springframework.messaging.Message;
+//import org.springframework.messaging.MessageChannel;
+//import org.springframework.messaging.simp.config.ChannelRegistration;
+//import org.springframework.messaging.simp.config.MessageBrokerRegistry;
+//import org.springframework.messaging.simp.stomp.StompCommand;
+//import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
+//import org.springframework.messaging.support.ChannelInterceptor;
+//import org.springframework.messaging.support.MessageHeaderAccessor;
+//import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
+//import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
+//import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;
+//
+//@Configuration
+//@EnableWebSocketMessageBroker
+//@RequiredArgsConstructor
+//public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {
+//
+//
+//    @Autowired
+//    private JwtUtil jwtUtil;
+//
+//    /**
+//     * 注册一个端点,客户端通过这个端点进行连接
+//     */
+//    @Override
+//    public void registerStompEndpoints(StompEndpointRegistry registry) {
+//        registry
+//                .addEndpoint("/ws")   // 注册了一个 /ws 的端点
+//                .setAllowedOriginPatterns("*") // 允许跨域的 WebSocket 连接
+//                .withSockJS();  // 启用 SockJS (浏览器不支持WebSocket,SockJS 将会提供兼容性支持)
+//    }
+//
+//    /**
+//     * 配置消息代理
+//     */
+//    @Override
+//    public void configureMessageBroker(MessageBrokerRegistry registry) {
+//        // 客户端发送消息的请求前缀
+//        registry.setApplicationDestinationPrefixes("/app");
+//        // 客户端订阅消息的请求前缀,topic一般用于广播推送,queue用于点对点推送
+//        registry.enableSimpleBroker("/topic", "/queue");
+//        // 服务端通知客户端的前缀,可以不设置,默认为user
+//        registry.setUserDestinationPrefix("/user");
+//    }
+//
+//    /**
+//     * 配置客户端入站通道拦截器
+//     * <p>
+//     * 添加 ChannelInterceptor 拦截器,用于在消息发送前,从请求头中获取 token 并解析出用户信息(username),用于点对点发送消息给指定用户
+//     *
+//     * @param registration 通道注册器
+//     */
+//    @Override
+//    public void configureClientInboundChannel(ChannelRegistration registration) {
+//        registration.interceptors(new ChannelInterceptor() {
+//            @Override
+//            public Message<?> preSend(Message<?> message, MessageChannel channel) {
+//
+//                StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
+//
+//                // (改) 如果是连接请求(CONNECT 命令),从请求头中取出 token 并设置到认证信息中
+//                if (accessor != null && StompCommand.CONNECT.equals(accessor.getCommand())) {
+//
+//                   // 从连接头中提取授权令牌
+//                   String bearerToken = accessor.getFirstNativeHeader(HttpHeaders.AUTHORIZATION);
+//
+//                   // 验证令牌格式并提取用户信息
+//                   if (StrUtil.isNotBlank(bearerToken) && bearerToken.startsWith("Bearer ")) {
+//                       try {
+//                           // 移除 "Bearer " 前缀,从令牌中提取用户信息(username), 并设置到认证信息中
+//                           String tokenWithoutPrefix = bearerToken.substring(7);
+//
+//                           Claims tokenInfo = jwtUtil.extractAllClaims(tokenWithoutPrefix);
+//                           String username = (String) tokenInfo.get("username");
+//
+//                           if (StrUtil.isNotBlank(username)) {
+//                               accessor.setUser(() -> username);
+//                               return message;
+//                           }
+//
+//                       } catch (Exception e) {
+//                           throw new RuntimeException("Failed to process authentication token.");
+//                       }
+//                   }
+//
+//
+//                }
+//                // 不是连接请求,直接放行
+//
+//
+//
+//                return ChannelInterceptor.super.preSend(message, channel);
+//            }
+//        });
+//    }
+//
+//}

+ 50 - 0
src/main/java/com/backendsys/modules/sdk/comfyui/controller/ComfyUIDemoController.java

@@ -0,0 +1,50 @@
+package com.backendsys.modules.sdk.comfyui.controller;
+
+import com.backendsys.exception.CustException;
+import com.backendsys.modules.common.config.security.annotations.Anonymous;
+import com.backendsys.modules.sdk.comfyui.service.ComfyUIService;
+import com.tencentcloudapi.tione.v20211111.models.ChatCompletionResponse;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.http.MediaType;
+import org.springframework.web.bind.annotation.GetMapping;
+import org.springframework.web.bind.annotation.PathVariable;
+import org.springframework.web.bind.annotation.PostMapping;
+import org.springframework.web.bind.annotation.RestController;
+import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
+import reactor.core.publisher.Flux;
+
+import java.io.IOException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+
+@RestController
+public class ComfyUIDemoController {
+
+    @Autowired
+    private ComfyUIService comfyUIService;
+
+    /**
+     * 创建WebSocket连接
+     * @param clientId 客户端ID
+     * @return 操作结果
+     */
+    @Anonymous
+    @PostMapping("/api/comfyui/ws/connect")
+    public String connect(String clientId) {
+        comfyUIService.connect(clientId, "ws://43.128.1.201:8007/ws").subscribe();
+        return "Connection initiated for: " + clientId;
+    }
+
+    /**
+     * 断开WebSocket连接
+     * @param clientId 客户端ID
+     * @return 操作结果
+     */
+    @Anonymous
+    @PostMapping("/api/comfyui/ws/disconnect")
+    public String disconnect(String clientId) {
+        comfyUIService.disconnect(clientId).subscribe();
+        return "Disconnected: " + clientId;
+    }
+
+}

+ 8 - 0
src/main/java/com/backendsys/modules/sdk/comfyui/service/ComfyUIService.java

@@ -1,4 +1,12 @@
 package com.backendsys.modules.sdk.comfyui.service;
 
+import reactor.core.publisher.Mono;
+
 public interface ComfyUIService {
+
+    // 连接到第三方 WebSocket 服务
+    Mono<Void> connect(String clientId, String wsUrl);
+    // 断开指定连接
+    Mono<Void> disconnect(String clientId);
+
 }

+ 99 - 27
src/main/java/com/backendsys/modules/sdk/comfyui/service/impl/ComfyUIServiceImpl.java

@@ -1,51 +1,123 @@
 package com.backendsys.modules.sdk.comfyui.service.impl;
 
 import com.backendsys.modules.sdk.comfyui.service.ComfyUIService;
+import com.tencentcloudapi.tione.v20211111.models.ChatCompletionResponse;
+import org.apache.http.client.methods.HttpPost;
+import org.apache.http.impl.client.CloseableHttpClient;
+import org.apache.http.impl.client.HttpClients;
+import org.springframework.beans.factory.annotation.Value;
 import org.springframework.stereotype.Service;
 import org.springframework.web.reactive.socket.WebSocketMessage;
 import org.springframework.web.reactive.socket.WebSocketSession;
-import org.springframework.web.socket.client.WebSocketClient;
-import org.springframework.web.socket.client.standard.StandardWebSocketClient;
+import org.springframework.web.reactive.socket.client.ReactorNettyWebSocketClient;
+import org.springframework.web.reactive.socket.client.WebSocketClient;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.FluxSink;
 import reactor.core.publisher.Mono;
+import reactor.netty.http.client.HttpClient;
+import reactor.util.retry.Retry;
 
 import java.net.URI;
-import java.net.URISyntaxException;
+import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
+import java.time.Duration;
+import java.util.Base64;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
 
-import static org.springframework.core.io.support.SpringFactoriesLoader.FailureHandler.handleMessage;
-
 @Service
 public class ComfyUIServiceImpl implements ComfyUIService {
 
-    // 单例 WebSocketClient(线程安全)
-    private final WebSocketClient webSocketClient = new StandardWebSocketClient();
+    @Value("${comfyui.token}")
+    private String COMFYUI_TOKEN = "";
+
+//    // 单例 WebSocketClient(线程安全)
+//    private WebSocketClient webSocketClient;
 
     // 管理多个连接
     private final Map<String, WebSocketSession> sessions = new ConcurrentHashMap<>();
 
-//    public Mono<Void> connect(String clientId, String wsUrl) throws URISyntaxException {
-//        return webSocketClient.execute(new URI(wsUrl), session -> {
-//            // 在 WebSocketSession 的回调中处理消息
-//            session.receive()          // 返回 Flux<WebSocketMessage>
-//                    .doOnNext(message -> {
-//                        // 处理每条消息
-//                        String payload = message.getPayloadAsText();
-//                        System.out.println("Received: " + payload);
-//                    })
-//                    .doOnError(e -> System.err.println("Error: " + e.getMessage()))
-//                    .doFinally(signal -> System.out.println("Connection closed: " + signal))
-//                    .then();               // 返回 Mono<Void> 表示完成
-//        });
-//    }
-
-    private void handleMessage(String clientId, WebSocketMessage message) {
-        // 转发到 SSE 或其他逻辑
-        System.out.println("Received from " + clientId + ": " + message.getPayloadAsText());
+    /**
+     * 创建带有认证头的 WebSocketClient
+     * @param token 认证令牌
+     * @return 配置好的 WebSocketClient
+     */
+    private WebSocketClient createWebSocketClientWithToken(String token) {
+        HttpClient httpClient = HttpClient.create()
+            .headers(headers -> headers.add("Authorization", "Bearer " + token))
+            .responseTimeout(Duration.ofSeconds(30));  // 30秒超时
+        return new ReactorNettyWebSocketClient(httpClient);
+    }
+
+    /**
+     * 连接到第三方 WebSocket 服务
+     * @param clientId 客户端ID(用于标识连接)
+     * @param wsUrl WebSocket 地址
+     * @return Mono<Void> 表示连接操作
+     */
+    @Override
+    public Mono<Void> connect(String clientId, String wsUrl) {
+
+        return Mono.defer(() -> {
+            if (sessions.containsKey(clientId)) {
+                return Mono.error(new IllegalStateException("Connection already exists for client: " + clientId));
+            }
+
+            // 动态创建带有认证头的客户端
+            WebSocketClient clientWithAuth = createWebSocketClientWithToken(COMFYUI_TOKEN);
+            return clientWithAuth.execute(URI.create(wsUrl + "?clientId=" + clientId), session -> {
+                // 保存会话
+                sessions.put(clientId, session);
+
+                // 接收消息
+                Flux<String> incomingMessages = session.receive()
+                    .map(WebSocketMessage::getPayloadAsText)
+                    .doOnNext(message -> {
+
+                        System.out.println("(doOnNext) Received from " + clientId + ": " + message);
+//                            // 转发到消息总线
+//                            messageSink.tryEmitNext(message);
+                    })
+                    .doOnError(e -> {
+                        System.err.println("(doOnError) Error for " + clientId + ": " + e.getMessage());
+                    })
+                    .doFinally(signal -> {
+                        System.out.println("(doFinally) Connection closed for " + clientId + ": " + signal);
+                        sessions.remove(clientId);
+                        // 断开连接
+//                        disconnect(clientId);
+                    });
+
+                // 需要返回一个Mono<Void>来表示处理完成
+                return incomingMessages.then();
+            });
+        });
+
+    }
+
+    /**
+     * 断开指定连接
+     * @param clientId 客户端ID
+     * @return Mono<Void> 表示断开操作
+     */
+    @Override
+    public Mono<Void> disconnect(String clientId) {
+        return Mono.fromRunnable(() -> {
+            WebSocketSession session = sessions.get(clientId);
+            if (session != null) {
+                System.out.println("disconnect success! clientId: " + clientId);
+                session.close().subscribe();
+                sessions.remove(clientId);
+            }
+        });
     }
 
-    private void handleError(String clientId, Throwable e) {
-        System.err.println("Error for " + clientId + ": " + e.getMessage());
+    /**
+     * 获取所有活动的连接ID
+     * @return 连接ID集合
+     */
+    public Flux<String> getActiveConnections() {
+        return Flux.fromIterable(sessions.keySet());
     }
 
 }

+ 4 - 1
src/main/resources/application-dev.yml

@@ -188,4 +188,7 @@ klingai:
   url: https://api-beijing.klingai.com
   access-key: A9baBTPFLH8RfrAHGeb4mGagmRHhRHTg
   secret-key: PQnT4E9TMYkPN93pb8JCHJC3dtFAtNPC
-  token-duration-time: 10000
+  token-duration-time: 10000
+
+comfyui:
+  token: $2b$12$.MR4qGaFetN1FPQzbfyIrehsyjnPJ12xAZhR/l7KZpLkUPQTCG4gy

+ 4 - 1
src/main/resources/application-local.yml

@@ -201,4 +201,7 @@ klingai:
   url: https://api-beijing.klingai.com
   access-key: A9baBTPFLH8RfrAHGeb4mGagmRHhRHTg
   secret-key: PQnT4E9TMYkPN93pb8JCHJC3dtFAtNPC
-  token-duration-time: 10000
+  token-duration-time: 10000
+
+comfyui:
+  token: $2b$12$.MR4qGaFetN1FPQzbfyIrehsyjnPJ12xAZhR/l7KZpLkUPQTCG4gy

+ 4 - 1
src/main/resources/application-prod.yml

@@ -189,4 +189,7 @@ klingai:
   url: https://api-beijing.klingai.com
   access-key: A9baBTPFLH8RfrAHGeb4mGagmRHhRHTg
   secret-key: PQnT4E9TMYkPN93pb8JCHJC3dtFAtNPC
-  token-duration-time: 10000
+  token-duration-time: 10000
+
+comfyui:
+  token: $2b$12$.MR4qGaFetN1FPQzbfyIrehsyjnPJ12xAZhR/l7KZpLkUPQTCG4gy