Forráskód Böngészése

新增chat 中止

tsurumure 5 hónapja
szülő
commit
12d2b93852

+ 7 - 0
src/main/java/com/backendsys/modules/ai/chat/controller/ChatController.java

@@ -30,6 +30,13 @@ public class ChatController {
         return Result.success().put("data", chatService.sendChat(chat));
     }
 
+    @Operation(summary = "中止对话")
+    @PreAuthorize("@sr.hasPermission('31')")
+    @PostMapping("/api/ai/chat/abortChat")
+    public Result abortChat(@Validated(Chat.Abort.class) @RequestBody Chat chat) {
+        return Result.success().put("data", chatService.abortChat(chat));
+    }
+
     @Operation(summary = "获取我的对话")
     @PreAuthorize("@sr.hasPermission('31')")
     @GetMapping("/api/ai/chat/getChatList")

+ 3 - 2
src/main/java/com/backendsys/modules/ai/chat/entity/Chat.java

@@ -13,6 +13,7 @@ import lombok.Data;
 @TableName("ai_chat")
 public class Chat {
 
+    public static interface Abort{}
     public static interface Detail{}
     public static interface Create{}
     public static interface Update{}
@@ -23,8 +24,8 @@ public class Chat {
     @NotNull(message = "id 不能为空", groups = { DeleteOne.class })
     private Long id;
 
-    @Size(max = 36, message = "对话历史记录ID长度不超过 {max} 字符", groups = { Detail.class, Delete.class })
-    @NotEmpty(message = "history_code 不能为空", groups = { Detail.class, Delete.class })
+    @Size(max = 36, message = "对话历史记录ID长度不超过 {max} 字符", groups = { Abort.class, Detail.class, Delete.class })
+    @NotEmpty(message = "history_code 不能为空", groups = { Abort.class, Detail.class, Delete.class })
     private String history_code;
 
     @NotEmpty(message = "model 不能为空", groups = { Create.class })

+ 1 - 1
src/main/java/com/backendsys/modules/ai/chat/entity/ChatSseMessage.java

@@ -8,7 +8,7 @@ import lombok.NoArgsConstructor;
 public class ChatSseMessage {
 
     private String content;
-    private String content_type;    // (LOADING: 加载中, REPLY: 回复, THINK: 思考)
+    private String content_type;    // (LOADING-加载中, REPLY-回复, REPLY-回复, REPLY_ABORT-回复中止, THINK-思考, THINK_ABORT-思考中止)
     private Long duration;
 
     public ChatSseMessage(String contentType, String content) {

+ 2 - 0
src/main/java/com/backendsys/modules/ai/chat/service/ChatService.java

@@ -10,6 +10,8 @@ public interface ChatService {
 
     // 发起对话
     Map<String, Object> sendChat(Chat chat);
+    // 中止对话
+    Map<String, Object> abortChat(Chat chat);
 
     // 获取我的对话列表
     PageEntity selectChatList(Chat chat);

+ 33 - 1
src/main/java/com/backendsys/modules/ai/chat/service/impl/ChatServiceImpl.java

@@ -1,6 +1,7 @@
 package com.backendsys.modules.ai.chat.service.impl;
 
 import cn.hutool.core.bean.BeanUtil;
+import cn.hutool.core.util.ObjectUtil;
 import cn.hutool.core.util.StrUtil;
 import com.backendsys.exception.CustException;
 import com.backendsys.modules.ai.chat.dao.ChatDao;
@@ -9,6 +10,7 @@ import com.backendsys.modules.ai.chat.entity.Chat;
 import com.backendsys.modules.ai.chat.entity.ChatHistory;
 import com.backendsys.modules.ai.chat.entity.ChatResult;
 import com.backendsys.modules.ai.chat.service.ChatService;
+import com.backendsys.modules.common.config.redis.utils.RedisUtil;
 import com.backendsys.modules.common.config.security.utils.SecurityUtil;
 import com.backendsys.modules.sdk.deepseek.service.DeepSeekClient;
 import com.backendsys.modules.sdk.deepseek.utils.OllamaUtil;
@@ -18,6 +20,7 @@ import com.backendsys.utils.response.PageInfoResult;
 import com.backendsys.utils.v2.PageUtils;
 import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
 import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.beans.factory.annotation.Value;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
@@ -37,6 +40,11 @@ public class ChatServiceImpl implements ChatService {
     private DeepSeekClient deepSeekClient;
     @Autowired
     private OllamaUtil ollamaUtil;
+    @Autowired
+    private RedisUtil redisUtil;
+
+    @Value("${spring.application.name}")
+    private String APPLICATION_NAME;
 
     @Autowired
     private ChatDao chatDao;
@@ -101,6 +109,16 @@ public class ChatServiceImpl implements ChatService {
 
         // -- 对话 -----------------------------------------------------
 
+
+        // -- 创建一个对话缓存键值 (用于停止请求) ---------------------------
+        String requestOfRedisKey = APPLICATION_NAME + "-chat-history-" + history_code;
+        redisUtil.setCacheObject(requestOfRedisKey, 1);
+//
+//        // -- 是否对话进行中 ---------------------------------------------
+//        if (ObjectUtil.isNotEmpty(redisUtil.getCacheObject(requestOfRedisKey))) {
+//            throw new CustException("请等待对话结束");
+//        }
+
         // 创建一个 CompletableFuture 来执行异步任务
         String final_history_code = history_code;
         CompletableFuture<Void> future = CompletableFuture.runAsync(() -> {
@@ -130,7 +148,7 @@ public class ChatServiceImpl implements ChatService {
                 ChatResult chatResult = null;
                 // -- [Deepseek R1] -------------------------------------------------------
                 if ("DEEPSEEK_R1".equals(model) || "GEMMA".equals(model)) {
-                    chatResult = ollamaUtil.chatDeepSeek(user_id, model_version, prompt, final_history_code, chatList);
+                    chatResult = ollamaUtil.chatCompletion(user_id, model_version, prompt, final_history_code, chatList);
                 }
                 // -- [Deepseek Api] ------------------------------------------------------
                 if ("DEEPSEEK_API".equals(model)) {
@@ -180,6 +198,9 @@ public class ChatServiceImpl implements ChatService {
 
             } catch (Exception e) {
                 System.out.println(e.getMessage());
+            } finally {
+                // 流程结束后,删除锁
+                redisUtil.delete(requestOfRedisKey);
             }
         });
         // ------------------------------------------------------------
@@ -189,6 +210,17 @@ public class ChatServiceImpl implements ChatService {
     }
 
 
+    /**
+     * 中止对话
+     */
+    @Override
+    public Map<String, Object> abortChat(Chat chat) {
+        String history_code = chat.getHistory_code();
+        String requestOfRedisKey = APPLICATION_NAME + "-chat-history-" + history_code;
+        redisUtil.delete(requestOfRedisKey);
+        return Map.of("history_code", history_code);
+    }
+
     /**
      * 获取我的对话列表 (升序, 最新的在最后面出现)
      */

+ 1 - 1
src/main/java/com/backendsys/modules/sdk/deepseek/controller/DeepSeekController.java

@@ -56,7 +56,7 @@ public class DeepSeekController {
     @PreAuthorize("@sr.hasPermission('101')")
     @PostMapping("/api/deepSeek/chatLocal")
     public Result chatLocal(@Validated @RequestBody DSParam param) {
-        return Result.success().put("data", ollamaUtil.chatDeepSeek(SecurityUtil.getUserId(), param.getModel(), param.getPrompt(), null, null));
+        return Result.success().put("data", ollamaUtil.chatCompletion(SecurityUtil.getUserId(), param.getModel(), param.getPrompt(), null, null));
     }
 
     @Operation(summary = "DS-获得模型列表")

+ 71 - 28
src/main/java/com/backendsys/modules/sdk/deepseek/utils/OllamaUtil.java

@@ -1,11 +1,14 @@
 package com.backendsys.modules.sdk.deepseek.utils;
 
 import cn.hutool.core.convert.Convert;
+import cn.hutool.core.util.NumberUtil;
+import cn.hutool.core.util.ObjectUtil;
 import cn.hutool.core.util.StrUtil;
 import com.alibaba.fastjson.JSONObject;
 import com.backendsys.modules.ai.chat.entity.Chat;
 import com.backendsys.modules.ai.chat.entity.ChatResult;
 import com.backendsys.modules.ai.chat.entity.ChatSseMessage;
+import com.backendsys.modules.common.config.redis.utils.RedisUtil;
 import com.backendsys.modules.sdk.deepseek.entity.DSRequest;
 import com.backendsys.modules.sdk.deepseek.entity.DSRequestMessage;
 import com.backendsys.modules.sse.entity.SseResponse;
@@ -38,6 +41,8 @@ public class OllamaUtil {
 
     @Autowired
     private SseUtil sseUtil;
+    @Autowired
+    private RedisUtil redisUtil;
 
     @Value("${spring.application.name}")
     private String APPLICATION_NAME;
@@ -48,19 +53,17 @@ public class OllamaUtil {
     /**
      * 流式对话
      */
-    public ChatResult chatDeepSeek(Long user_id, String model, String prompt, String history_code, List<Chat> chatList) {
+    public ChatResult chatCompletion(Long user_id, String model, String prompt, String history_code, List<Chat> chatList) {
 
-        long contentDuration = 0L;
+        Long contentDuration = 0L;
 
         ChatResult chatResult = new ChatResult();
-        try {
+//        try {
             System.out.println("向模型: " + model + " 提问: " + prompt);
 
             // 记录请求开始时间
             long allStartTime = System.currentTimeMillis();
 
-
-
             // 加入上下文历史对话
             System.out.println("---- 历史对话 (history_code): " + history_code + " ----");
 
@@ -84,7 +87,11 @@ public class OllamaUtil {
             });
             System.out.println("---------------------------------------------------------------------");
 
-
+            // 定义作用于全局的变量
+            Boolean isThinking = false;
+            StringBuilder allReplyContent = new StringBuilder();
+            StringBuilder allThinkContent = new StringBuilder();
+            String requestOfRedisKey = APPLICATION_NAME + "-chat-history-" + history_code;
 
             ObjectMapper objectMapper = new ObjectMapper();
             try (CloseableHttpClient client = HttpClients.createDefault()) {
@@ -135,6 +142,7 @@ public class OllamaUtil {
 
                 request.setEntity(new StringEntity(requestBody, StandardCharsets.UTF_8));
 
+
                 try (CloseableHttpResponse response = client.execute(request);
                      BufferedReader reader = new BufferedReader(new InputStreamReader(response.getEntity().getContent(), StandardCharsets.UTF_8))) {
 
@@ -142,7 +150,7 @@ public class OllamaUtil {
 
                     long thinkStartTime = 0L;                            // 开始思考时间
                     long thinkDuration = 0L;                             // 思考耗时
-                    Boolean isThinking = false;
+
 
                     System.out.println("API 调用耗时: " + apiDuration + " 毫秒");
                     System.out.println("---- 开始流式回答: ------------------------------------");
@@ -151,12 +159,20 @@ public class OllamaUtil {
                     ChatSseMessage chatLoadingSseMessage = new ChatSseMessage("LOADING", "正在思考");
                     sseUtil.send(user_id, new SseResponse(SseResponseEnum.DEEPSEEK, chatLoadingSseMessage).toJsonStr());
 
-                    StringBuilder allReplyContent = new StringBuilder();
-                    StringBuilder allThinkContent = new StringBuilder();
-
                     String line;
                     while ((line = reader.readLine()) != null) {
 
+
+                        // 判断是否中止
+                        if (ObjectUtil.isEmpty(redisUtil.getCacheObject(requestOfRedisKey))) {
+                            System.out.println("中止!");
+                            request.abort();
+                            // 流程结束后,删除锁
+                            redisUtil.delete(requestOfRedisKey);
+                            break;
+                        }
+
+
                         // System.out.println(line);
                         /*
                             ---------------------- [Chat] line ----------------------
@@ -259,47 +275,74 @@ public class OllamaUtil {
                     chatResult.setContent(allReplyContent.toString());
                     chatResult.setContent_duration(contentDuration);
 
+                    // [SSE] 发送消息 (完成)
+                    ChatSseMessage chatSseMessage = new ChatSseMessage("REPLY", "[DONE][REPLY]", contentDuration);
+                    sseUtil.send(user_id, new SseResponse(SseResponseEnum.DEEPSEEK, chatSseMessage).toJsonStr());
+
                     return chatResult;
 
                 } catch (Exception e) {
                     System.out.println("Exception(1): " + e.getMessage());
                     String message = e.getMessage();
                     if (message.contains("failed to respond")) {
-                        message = "系统繁忙,请稍后再试 (Failed to respond)";
+                        message = "(系统繁忙,请稍后再试)";
+                    }
+                    if (message.contains("Premature end of chunk coded message body: closing chunk expected")) {
+                        message = "(请求中止)";
                     }
                     // [SSE] 发送消息
-                    ChatSseMessage chatSseMessage = new ChatSseMessage("REPLY", message, contentDuration);
+                    String contentType = (isThinking ? "THINK_ABORT" : "REPLY_ABORT");
+                    ChatSseMessage chatSseMessage = new ChatSseMessage(contentType, message, contentDuration);
                     sseUtil.send(user_id, new SseResponse(SseResponseEnum.DEEPSEEK, chatSseMessage).toJsonStr());
 
-                    chatResult.setContent(e.getMessage());
-                    e.printStackTrace();
+//                    chatResult.setContent(e.getMessage());
+
+                    // 由于中止导致的错误信息叠加 (一并保存进数据库)
+                    if (StrUtil.isNotEmpty(allThinkContent.toString())) {
+                        if (isThinking) {
+                            chatResult.setReasoning_content(allThinkContent.toString() + " " + message);
+                        } else {
+                            chatResult.setReasoning_content(allThinkContent.toString());
+                        }
+                    }
+                    chatResult.setContent(allReplyContent.toString() + " " + message);
+
+                    redisUtil.delete(requestOfRedisKey);
 
                     return chatResult;
+
+//                    return chatResult;
                 }
             } catch (Exception e) {
                 System.out.println("Exception(2): " + e.getMessage());
                 // [SSE] 发送消息
-                ChatSseMessage chatSseMessage = new ChatSseMessage("REPLY", e.getMessage(), contentDuration);
+                String contentType = (isThinking ? "THINK_ABORT" : "REPLY_ABORT");
+                ChatSseMessage chatSseMessage = new ChatSseMessage(contentType, e.getMessage(), contentDuration);
                 sseUtil.send(user_id, new SseResponse(SseResponseEnum.DEEPSEEK, chatSseMessage).toJsonStr());
 
+                redisUtil.delete(requestOfRedisKey);
+
                 chatResult.setContent(e.getMessage());
                 return chatResult;
             }
 
 
-        } catch (Exception e) {
-            System.out.println("Exception(3): " + e.getMessage());
-            // [SSE] 发送消息
-            ChatSseMessage chatSseMessage = new ChatSseMessage("REPLY", e.getMessage(), contentDuration);
-            sseUtil.send(user_id, new SseResponse(SseResponseEnum.DEEPSEEK, chatSseMessage).toJsonStr());
-
-            chatResult.setContent(e.getMessage());
-            return chatResult;
-        } finally {
-            // [SSE] 发送消息
-            ChatSseMessage chatSseMessage = new ChatSseMessage("REPLY", "[DONE][REPLY]", contentDuration);
-            sseUtil.send(user_id, new SseResponse(SseResponseEnum.DEEPSEEK, chatSseMessage).toJsonStr());
-        }
+//        } catch (Exception e) {
+//            System.out.println("Exception(3): " + e.getMessage());
+//            // [SSE] 发送消息
+//            ChatSseMessage chatSseMessage = new ChatSseMessage("REPLY", e.getMessage(), contentDuration);
+//            sseUtil.send(user_id, new SseResponse(SseResponseEnum.DEEPSEEK, chatSseMessage).toJsonStr());
+//
+//            chatResult.setContent(e.getMessage());
+//            return chatResult;
+//        }
+
+//        } finally {
+//            System.out.println("Finally.");
+//            // [SSE] 发送消息
+//            ChatSseMessage chatSseMessage = new ChatSseMessage("REPLY", "[DONE][REPLY]", contentDuration);
+//            sseUtil.send(user_id, new SseResponse(SseResponseEnum.DEEPSEEK, chatSseMessage).toJsonStr());
+//        }
 
     }
 }

+ 2 - 1
src/main/java/com/backendsys/modules/sse/controller/SseController.java

@@ -50,7 +50,8 @@ public class SseController {
 //        System.out.println("emitterKey = " + emitterKey);
 
 //        SseEmitterUTF8 emitter = new SseEmitterUTF8(Long.MAX_VALUE);
-        SseEmitterUTF8 emitter = new SseEmitterUTF8(60 * 1000L);
+        SseEmitterUTF8 emitter = new SseEmitterUTF8(2 * 60 * 1000L);
+
         SseEmitterManager manager = SseEmitterManager.getInstance();
         manager.addEmitter(emitterKey, emitter);
         try {