tsurumure 6 miesięcy temu
rodzic
commit
fdb3a0fbd8

+ 9 - 2
src/main/java/com/backendsys/modules/ai/chat/service/impl/ChatServiceImpl.java

@@ -17,6 +17,9 @@ import com.backendsys.utils.response.PageInfoResult;
 import com.backendsys.utils.v2.PageUtils;
 import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
 import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.SecurityContext;
+import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.stereotype.Service;
 
 import java.util.List;
@@ -44,6 +47,10 @@ public class ChatServiceImpl implements ChatService {
     @Override
     public Map<String, Object> sendChat(Chat chat) {
 
+        // 手动设置 SecurityContext
+        SecurityContext context = SecurityContextHolder.getContext();
+        SecurityContextHolder.setContext(context);
+
         Long user_id = SecurityUtil.getUserId();
         chat.setUser_id(user_id);
         chat.setRole("user");
@@ -114,12 +121,12 @@ public class ChatServiceImpl implements ChatService {
                 // ------------------------------------------------------------
                 if ("DEEPSEEK".equals(model)) {
                     // [DeepSeek] 发起对话
-                    chatResult = deepSeekClient.chatCompletion(model_version, prompt, chatList);
+                    chatResult = deepSeekClient.chatCompletion(user_id, model_version, prompt, chatList);
                 }
                 // ------------------------------------------------------------
                 if ("HUNYUAN".equals(model)) {
                     // [混元] 发起对话
-                    chatResult = hunYuanClient.chatCompletion(prompt, chatList);
+                    chatResult = hunYuanClient.chatCompletion(user_id, prompt, chatList);
                 }
                 // ------------------------------------------------------------
 

+ 0 - 1
src/main/java/com/backendsys/modules/common/config/security/utils/SecurityUtil.java

@@ -76,7 +76,6 @@ public class SecurityUtil {
         String token = getToken();
         if (StrUtil.isEmpty(token)) throw new CustException("getUserInfo() token is empty.");
         if (token.contains("SessionId") || token.contains("RemoteIpAddress")) throw new CustException("getUserInfo() need token.");
-
         Claims tokenInfo = Jwts.parser().verifyWith(getSignInKey()).build().parseSignedClaims(token).getPayload();
         JSONObject userInfo = JSONUtil.parseObj(tokenInfo.get("userInfo"));
         String target = Convert.toStr(tokenInfo.get("target"));

+ 1 - 1
src/main/java/com/backendsys/modules/log/controller/LogStreamController.java

@@ -104,7 +104,7 @@ public class LogStreamController {
     public String send() {
         String message = "{\"message\": \"Hello World 中文\"}";
         // 获得当前用户Id
-        sseUtil.send(APPLICATION_NAME + "-1", message);
+        sseUtil.send(SecurityUtil.getUserId(), message);
         return "success";
     }
 

+ 3 - 3
src/main/java/com/backendsys/modules/sdk/deepseek/controller/DeepSeekController.java

@@ -44,7 +44,7 @@ public class DeepSeekController {
 //        CompletableFuture<Void> future = CompletableFuture.runAsync(() -> {
 //            deepSeekClient.chatCompletion(param.getModel(), param.getPrompt());
 //        });
-        return Result.success().put("data", deepSeekClient.chatCompletion(param.getModel(), param.getPrompt(), null));
+        return Result.success().put("data", deepSeekClient.chatCompletion(SecurityUtil.getUserId(), param.getModel(), param.getPrompt(), null));
     }
 
     /**
@@ -56,7 +56,7 @@ public class DeepSeekController {
     @PreAuthorize("@sr.hasPermission('101')")
     @PostMapping("/api/deepSeek/chatLocal")
     public Result chatLocal(@Validated @RequestBody DSParam param) {
-        ollamaUtil.chatDeepSeek(param.getModel(), param.getPrompt(), SecurityUtil.getUserId());
+        ollamaUtil.chatDeepSeek(SecurityUtil.getUserId(), param.getModel(), param.getPrompt());
         return Result.success();
     }
 
@@ -64,7 +64,7 @@ public class DeepSeekController {
     @PreAuthorize("@sr.hasPermission('101')")
     @GetMapping("/api/deepSeek/getModels")
     public Result getModels() {
-        return Result.success().put("data", deepSeekClient.getModels());
+        return Result.success().put("data", deepSeekClient.getModels(SecurityUtil.getUserId()));
     }
 
 }

+ 2 - 2
src/main/java/com/backendsys/modules/sdk/deepseek/service/DeepSeekClient.java

@@ -9,9 +9,9 @@ import java.util.List;
 public interface DeepSeekClient {
 
     // [DeepSeek] 发起对话
-    ChatResult chatCompletion(String model, String prompt, List<Chat> chatList);
+    ChatResult chatCompletion(Long user_id, String model, String prompt, List<Chat> chatList);
 
     // [DeepSeek] 获得模型
-    JSONArray getModels();
+    JSONArray getModels(Long user_id);
 
 }

+ 9 - 9
src/main/java/com/backendsys/modules/sdk/deepseek/service/impl/DeepSeekClientImpl.java

@@ -53,7 +53,7 @@ public class DeepSeekClientImpl implements DeepSeekClient {
      * - 文档:https://api-docs.deepseek.com/zh-cn/api/create-chat-completion
      */
     @Override
-    public ChatResult chatCompletion(String model, String prompt, List<Chat> chatList) {
+    public ChatResult chatCompletion(Long user_id, String model, String prompt, List<Chat> chatList) {
 
         ChatResult chatResult = new ChatResult();
         try {
@@ -127,7 +127,7 @@ public class DeepSeekClientImpl implements DeepSeekClient {
                     ChatSseMessage chatLoadingSseMessage = new ChatSseMessage();
                     chatLoadingSseMessage.setContent_type("loading");
                     chatLoadingSseMessage.setContent("正在思考");
-                    sseUtil.send(new SseResponse(SseResponseEnum.DEEPSEEK, chatLoadingSseMessage).toJsonStr());
+                    sseUtil.send(user_id, new SseResponse(SseResponseEnum.DEEPSEEK, chatLoadingSseMessage).toJsonStr());
 
                     StringBuilder allContent = new StringBuilder();
                     StringBuilder allReasoningContent = new StringBuilder();
@@ -174,7 +174,7 @@ public class DeepSeekClientImpl implements DeepSeekClient {
                                 ChatSseMessage chatSseMessage = new ChatSseMessage();
                                 chatSseMessage.setContent_type("think");
                                 chatSseMessage.setContent(reasoning_content);
-                                sseUtil.send(new SseResponse(SseResponseEnum.DEEPSEEK, chatSseMessage).toJsonStr());
+                                sseUtil.send(user_id, new SseResponse(SseResponseEnum.DEEPSEEK, chatSseMessage).toJsonStr());
 
                                 // 收集推理内容
                                 allReasoningContent.append(reasoning_content);
@@ -198,7 +198,7 @@ public class DeepSeekClientImpl implements DeepSeekClient {
                                 ChatSseMessage chatSseMessage = new ChatSseMessage();
                                 chatSseMessage.setContent_type("reply");
                                 chatSseMessage.setContent(content);
-                                sseUtil.send(new SseResponse(SseResponseEnum.DEEPSEEK, chatSseMessage).toJsonStr());
+                                sseUtil.send(user_id, new SseResponse(SseResponseEnum.DEEPSEEK, chatSseMessage).toJsonStr());
 
                                 // 收集回答内容
                                 allContent.append(content);
@@ -223,13 +223,13 @@ public class DeepSeekClientImpl implements DeepSeekClient {
 
                 }
             } catch (Exception e) {
-                sseUtil.send((new SseResponse(SseResponseEnum.DEEPSEEK, e.getMessage())).toJsonStr());
+                sseUtil.send(user_id, (new SseResponse(SseResponseEnum.DEEPSEEK, e.getMessage())).toJsonStr());
                 System.out.println(e.getMessage());
                 chatResult.setContent(e.getMessage());
                 return chatResult;
             }
         } catch (Exception e) {
-            sseUtil.send((new SseResponse(SseResponseEnum.DEEPSEEK, e.getMessage())).toJsonStr());
+            sseUtil.send(user_id, (new SseResponse(SseResponseEnum.DEEPSEEK, e.getMessage())).toJsonStr());
             System.out.println(e.getMessage());
             chatResult.setContent(e.getMessage());
             return chatResult;
@@ -241,7 +241,7 @@ public class DeepSeekClientImpl implements DeepSeekClient {
      * [DeepSeek] 获得模型
      */
     @Override
-    public JSONArray getModels() {
+    public JSONArray getModels(Long user_id) {
 
         // 调用 Deepseek API
         try (CloseableHttpClient client = HttpClients.createDefault()) {
@@ -266,11 +266,11 @@ public class DeepSeekClientImpl implements DeepSeekClient {
                 }
 
             } catch (Exception e) {
-                sseUtil.send((new SseResponse(SseResponseEnum.DEEPSEEK, e.getMessage())).toJsonStr());
+                sseUtil.send(user_id, (new SseResponse(SseResponseEnum.DEEPSEEK, e.getMessage())).toJsonStr());
                 throw new CustException(e.getMessage());
             }
         } catch (Exception e) {
-            sseUtil.send((new SseResponse(SseResponseEnum.DEEPSEEK, e.getMessage())).toJsonStr());
+            sseUtil.send(user_id, (new SseResponse(SseResponseEnum.DEEPSEEK, e.getMessage())).toJsonStr());
             throw new CustException(e.getMessage());
         }
 

+ 5 - 7
src/main/java/com/backendsys/modules/sdk/deepseek/utils/OllamaUtil.java

@@ -43,15 +43,13 @@ public class OllamaUtil {
     /**
      * 流式对话
      */
-    public void chatDeepSeek(String model, String question, Long userId) {
+    public void chatDeepSeek(Long userId, String model, String question) {
 
         System.out.println("(userId): " + userId + " 提问: " + question);
 
         // 创建一个 CompletableFuture 来执行异步任务
         CompletableFuture<Void> future = CompletableFuture.runAsync(() -> {
 
-            String emitterKey = APPLICATION_NAME + "-userid-" + Convert.toStr(userId);
-
             ObjectMapper objectMapper = new ObjectMapper();
             try (CloseableHttpClient client = HttpClients.createDefault()) {
 
@@ -82,22 +80,22 @@ public class OllamaUtil {
                         System.out.println("content: " + responseContent);
 
                         String dataStr = (new SseResponse(SseResponseEnum.DEEPSEEK, responseContent)).toJsonStr();
-                        sseUtil.send(emitterKey, dataStr);
+                        sseUtil.send(userId, dataStr);
 
                     }
 
                     System.out.println("-- 回答结束 ------------------------------------------");
                     String dataStr = (new SseResponse(SseResponseEnum.DEEPSEEK, "end")).toJsonStr();
-                    sseUtil.send(emitterKey, dataStr);
+                    sseUtil.send(userId, dataStr);
 
                 } catch (Exception e) {
                     String dataStr = (new SseResponse(SseResponseEnum.DEEPSEEK, e.getMessage())).toJsonStr();
-                    sseUtil.send(emitterKey, dataStr);
+                    sseUtil.send(userId, dataStr);
                     e.printStackTrace();
                 }
             } catch (Exception e) {
                 String dataStr = (new SseResponse(SseResponseEnum.DEEPSEEK, e.getMessage())).toJsonStr();
-                sseUtil.send(emitterKey, dataStr);
+                sseUtil.send(userId, dataStr);
                 e.printStackTrace();
             }
 

+ 1 - 2
src/main/java/com/backendsys/modules/sdk/douyincloud/tos/service/impl/DouyinTosServiceImpl.java

@@ -91,8 +91,7 @@ public class DouyinTosServiceImpl implements DouyinTosService {
                 progress.setPercent(percentage);
                 progress.setState(state);
                 String dataStr = (new SseResponse(SseResponseEnum.UPLOAD, progress)).toJsonStr();
-                String emitterKey = APPLICATION_NAME + "-userid-" + Convert.toStr(SecurityUtil.getUserId());
-                sseUtil.send(emitterKey, dataStr);
+                sseUtil.send(SecurityUtil.getUserId(), dataStr);
 
             }
         };

+ 4 - 3
src/main/java/com/backendsys/modules/sdk/tencentcloud/cos/service/impl/TencentCosServiceImpl.java

@@ -79,12 +79,13 @@ public class TencentCosServiceImpl implements TencentCosService {
     // [腾讯云COS][高级接口] 获取进度函数
     private void showTransferProgress(String filename, Transfer transfer) {
 
+        Long user_id = SecurityUtil.getUserId();
+
         // [SSE] 进度回传
         Progress progress = new Progress();
         progress.setState("init");
         progress.setFilename(filename);
-        String emitterKey = APPLICATION_NAME + "-userid-" + Convert.toStr(SecurityUtil.getUserId());
-        sseUtil.send(emitterKey, (new SseResponse(SseResponseEnum.UPLOAD, progress)).toJsonStr());
+        sseUtil.send(user_id, (new SseResponse(SseResponseEnum.UPLOAD, progress)).toJsonStr());
 
         // 查询上传是否已经完成
         while (transfer.isDone() == false) {
@@ -104,7 +105,7 @@ public class TencentCosServiceImpl implements TencentCosService {
             progress.setPercent(percent);
             progress.setState(state.toLowerCase());
             String dataStr = (new SseResponse(SseResponseEnum.UPLOAD, progress)).toJsonStr();
-            sseUtil.send(emitterKey, dataStr);
+            sseUtil.send(user_id, dataStr);
 
             // state: (完成 Completed, 失败 Failed)
             // System.out.println(transfer.getState());

+ 1 - 1
src/main/java/com/backendsys/modules/sdk/tencentcloud/huanyuan/service/HunYuanClient.java

@@ -8,6 +8,6 @@ import java.util.List;
 public interface HunYuanClient {
 
     // [HunYuan] 发起对话
-    ChatResult chatCompletion(String prompt, List<Chat> chatList);
+    ChatResult chatCompletion(Long user_id, String prompt, List<Chat> chatList);
 
 }

+ 7 - 3
src/main/java/com/backendsys/modules/sdk/tencentcloud/huanyuan/service/impl/HunYuanClientImpl.java

@@ -4,6 +4,7 @@ import cn.hutool.core.util.ArrayUtil;
 import com.backendsys.modules.ai.chat.entity.Chat;
 import com.backendsys.modules.ai.chat.entity.ChatResult;
 import com.backendsys.modules.ai.chat.entity.ChatSseMessage;
+import com.backendsys.modules.common.config.security.utils.SecurityUtil;
 import com.backendsys.modules.sdk.deepseek.entity.DSRequestMessage;
 import com.backendsys.modules.sdk.tencentcloud.huanyuan.service.HunYuanClient;
 import com.backendsys.modules.sse.entity.SseResponse;
@@ -23,6 +24,8 @@ import com.tencentcloudapi.hunyuan.v20230901.models.ChatStdResponse;
 import com.tencentcloudapi.hunyuan.v20230901.models.Message;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.beans.factory.annotation.Value;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.stereotype.Service;
 
 import java.util.ArrayList;
@@ -54,7 +57,9 @@ public class HunYuanClientImpl implements HunYuanClient {
      * https://cloud.tencent.com/document/product/1729/101836
      */
     @Override
-    public ChatResult chatCompletion(String prompt, List<Chat> chatList) {
+    public ChatResult chatCompletion(Long user_id, String prompt, List<Chat> chatList) {
+
+        System.out.println("(chatCompletion) user_id = " + user_id);
 
         ChatResult chatResult = new ChatResult();
         try {
@@ -87,7 +92,6 @@ public class HunYuanClientImpl implements HunYuanClient {
             });
             System.out.println("-----------------------------------------------------");
 
-
             // 构建请求体
             ChatStdRequest req = new ChatStdRequest();
             req.setMessages(ArrayUtil.toArray(messages, Message.class));
@@ -124,7 +128,7 @@ public class HunYuanClientImpl implements HunYuanClient {
                 ChatSseMessage chatSseMessage = new ChatSseMessage();
                 chatSseMessage.setContent_type("reply");
                 chatSseMessage.setContent(content);
-                sseUtil.send(new SseResponse(SseResponseEnum.HUNYUAN, chatSseMessage).toJsonStr());
+                sseUtil.send(user_id, new SseResponse(SseResponseEnum.HUNYUAN, chatSseMessage).toJsonStr());
 
                 // 收集回答内容
                 allContent.append(content);

+ 3 - 2
src/main/java/com/backendsys/modules/sse/controller/SseController.java

@@ -10,6 +10,8 @@ import com.backendsys.modules.sse.utils.SseUtil;
 import jakarta.servlet.http.HttpServletResponse;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.beans.factory.annotation.Value;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.web.bind.annotation.GetMapping;
 import org.springframework.web.bind.annotation.RestController;
 import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
@@ -62,8 +64,7 @@ public class SseController {
     @GetMapping("/api/sse/sendHello")
     public String sendHelloWorld() {
         String dataStr = (new SseResponse("Hello World")).toJsonStr();
-        String emitterKey = APPLICATION_NAME + "-userid-" + Convert.toStr(SecurityUtil.getUserId());
-        sseUtil.send(emitterKey, dataStr);
+        sseUtil.send(SecurityUtil.getUserId(), dataStr);
         return "success";
     }
 

+ 17 - 15
src/main/java/com/backendsys/modules/sse/utils/SseUtil.java

@@ -19,23 +19,25 @@ public class SseUtil {
     @Value("${spring.application.name}")
     private String APPLICATION_NAME;
 
-    // [SSE] 发送消息 (单个)
-    public void send(String emitterKey, Object data) {
-        SseEmitterManager manager = SseEmitterManager.getInstance();
-        SseEmitterUTF8 emitter = manager.getEmitter(emitterKey);
-        if (emitter != null) {
-            try {
-                emitter.send(SseEmitter.event().data(data));
-            } catch (IOException e) {
-                System.out.println(e.getMessage());
-                manager.removeEmitter(emitter);
-            }
-        }
-    }
+//    // [SSE] 发送消息 (单个)
+//    public void send(String emitterKey, Object data) {
+//        SseEmitterManager manager = SseEmitterManager.getInstance();
+//        SseEmitterUTF8 emitter = manager.getEmitter(emitterKey);
+//        if (emitter != null) {
+//            try {
+//                emitter.send(SseEmitter.event().data(data));
+//            } catch (IOException e) {
+//                System.out.println(e.getMessage());
+//                manager.removeEmitter(emitter);
+//            }
+//        }
+//    }
+
+    // 当多层嵌套时,Service 没有经过 Spring AOP 代理,因此不会正确获得 SecurityContextHolder 上下文;
+    // 因此此处不能使用 SecurityUtil.getUserId() 直接获得 UserId
 
     // [SSE] 发送消息 (单个) (自己)
-    public void send(Object data) {
-        Long user_id = SecurityUtil.getUserId();
+    public void send(Long user_id, Object data) {
         SseEmitterManager manager = SseEmitterManager.getInstance();
         SseEmitterUTF8 emitter = manager.getEmitter(APPLICATION_NAME + "-userid-" + Convert.toStr(user_id));
         if (emitter != null) {

+ 3 - 6
src/main/java/com/backendsys/modules/system/service/impl/SysUserServiceImpl.java

@@ -334,9 +334,8 @@ public class SysUserServiceImpl extends ServiceImpl<SysUserDao, SysUser> impleme
             sysUserDao.updateById(entity);
 
             // [SSE] 发送退出登录的消息
-            String emitterKey = APPLICATION_NAME + "-userid-" + Convert.toStr(sysUserDTO.getUser_id());
             String dataStr = (new SseResponse(SseResponseEnum.LOGOUT)).toJsonStr();
-            sseUtil.send(emitterKey, dataStr);
+            sseUtil.send(sysUserDTO.getUser_id(), dataStr);
 
             return Map.of("user_id", sysUserDTO.getUser_id());
 
@@ -373,9 +372,8 @@ public class SysUserServiceImpl extends ServiceImpl<SysUserDao, SysUser> impleme
             sysUserDao.updateById(entity);
 
             // [SSE] 发送退出登录的消息
-            String emitterKey = APPLICATION_NAME + "-userid-" + Convert.toStr(sysUserDTO.getUser_id());
             String dataStr = (new SseResponse(SseResponseEnum.LOGOUT)).toJsonStr();
-            sseUtil.send(emitterKey, dataStr);
+            sseUtil.send(sysUserDTO.getUser_id(), dataStr);
 
             Map<String, Object> response = new LinkedHashMap<>();
             response.put("user_id", sysUserDTO.getUser_id());
@@ -441,9 +439,8 @@ public class SysUserServiceImpl extends ServiceImpl<SysUserDao, SysUser> impleme
                 sysUserInfoDao.update(null, updateWrapper);
 
                 // [SSE] 发送退出登录的消息
-                String emitterKey = APPLICATION_NAME + "-userid-" + Convert.toStr(sysUserInfo.getUser_id());
                 String dataStr = (new SseResponse(SseResponseEnum.LOGOUT)).toJsonStr();
-                sseUtil.send(emitterKey, dataStr);
+                sseUtil.send(sysUserInfo.getUser_id(), dataStr);
             }
 
             return Map.of("user_id", user_id);

+ 1 - 4
src/main/java/com/backendsys/service/Ai/AiChatServiceImpl.java

@@ -115,9 +115,6 @@ public class AiChatServiceImpl implements AiChatService {
 //        String receiver = (String) loginUserInfo.get("username");
 
 
-        String emitterKey = APPLICATION_NAME + "-userid-" + Convert.toStr(user_id);
-
-
 
         // [发送/接收] 内容
         String sendContent = aiChatDTO.getContent();
@@ -196,7 +193,7 @@ public class AiChatServiceImpl implements AiChatService {
                 messagingTemplate.convertAndSendToUser(receiver, "/queue/ai-chating", e.Data);
 
                 // [SSE] 发送消息
-                sseUtil.send(emitterKey, e.Data);
+                sseUtil.send(user_id, e.Data);
 
                 receiveRobotCode = dataObject.getStr("Id");
                 JSONArray choicesArray = dataObject.getJSONArray("Choices");