Преглед на файлове

完善DeepseekR1的对话上下文功能

tsurumure преди 5 месеца
родител
ревизия
173ca77c64

+ 0 - 2
db/ai_chat.sql

@@ -16,8 +16,6 @@ CREATE TABLE `ai_chat` (
     `content` TEXT NOT NULL COMMENT '对话内容',
     `duration` BIGINT COMMENT '耗时',
     `user_id` BIGINT COMMENT '用户ID',
-#     `user_nickname` VARCHAR(255) COMMENT '用户名',
-#     `user_avatar` VARCHAR(255) COMMENT '用户头像',
     `robot_code` VARCHAR(255) COMMENT '机器人CODE',
     `del_flag` TINYINT(1) DEFAULT '-1' COMMENT '删除标志 (-1未删除, 1删除)',
     `create_time` DATETIME DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',

+ 1 - 1
db/ai_chat_history.sql

@@ -11,7 +11,7 @@ CREATE TABLE `ai_chat_history` (
     `history_code` VARCHAR(36) NOT NULL COMMENT '对话历史记录ID',
     `user_id` BIGINT NOT NULL COMMENT '用户ID',
     `robot_code` VARCHAR(255) COMMENT '机器人CODE',
-    `last_prompt` VARCHAR(255) NOT NULL COMMENT '最后一次的对话内容',
+    `last_prompt` TEXT NOT NULL COMMENT '最后一次的对话内容',
     `del_flag` TINYINT(1) DEFAULT '-1' COMMENT '删除标志 (-1未删除, 1删除)',
     `create_time` DATETIME DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
     `update_time` DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',

+ 5 - 5
src/main/java/com/backendsys/modules/ai/chat/service/impl/ChatServiceImpl.java

@@ -102,7 +102,7 @@ public class ChatServiceImpl implements ChatService {
         // -- 对话 -----------------------------------------------------
 
         // 创建一个 CompletableFuture 来执行异步任务
-        String finalHistory_code = history_code;
+        String final_history_code = history_code;
         CompletableFuture<Void> future = CompletableFuture.runAsync(() -> {
             try {
 
@@ -112,7 +112,7 @@ public class ChatServiceImpl implements ChatService {
                 List<Chat> chatList = chatDao.selectList(
                     new LambdaQueryWrapper<Chat>()
                         .eq(Chat::getDel_flag, -1)
-                        .eq(Chat::getHistory_code, finalHistory_code)
+                        .eq(Chat::getHistory_code, final_history_code)
                         .and(wrapper -> wrapper
                             .ne(Chat::getContent_type, "THINK")
                             .or()
@@ -132,16 +132,16 @@ public class ChatServiceImpl implements ChatService {
                 ChatResult chatResult = null;
                 // -- [Deepseek R1] -------------------------------------------------------
                 if ("DEEPSEEK_R1".equals(model)) {
-                    chatResult = ollamaUtil.chatDeepSeek(user_id, model_version, prompt, chatList);
+                    chatResult = ollamaUtil.chatDeepSeek(user_id, model_version, prompt, final_history_code, chatList);
                 }
                 // -- [Deepseek Api] ------------------------------------------------------
                 if ("DEEPSEEK_API".equals(model)) {
                     // [DeepSeek] 发起对话
-                    chatResult = deepSeekClient.chatCompletion(user_id, model_version, prompt, chatList);
+                    chatResult = deepSeekClient.chatCompletion(user_id, model_version, prompt, final_history_code, chatList);
                 }
                 // -- [腾讯混元大模型] -------------------------------------------------------
                 if ("HUNYUAN".equals(model)) {
-                    chatResult = hunYuanClient.chatCompletion(user_id, prompt, chatList);
+                    chatResult = hunYuanClient.chatCompletion(user_id, prompt, final_history_code, chatList);
                 }
                 // ------------------------------------------------------------------------
 

+ 2 - 2
src/main/java/com/backendsys/modules/sdk/deepseek/controller/DeepSeekController.java

@@ -44,7 +44,7 @@ public class DeepSeekController {
 //        CompletableFuture<Void> future = CompletableFuture.runAsync(() -> {
 //            deepSeekClient.chatCompletion(param.getModel(), param.getPrompt());
 //        });
-        return Result.success().put("data", deepSeekClient.chatCompletion(SecurityUtil.getUserId(), param.getModel(), param.getPrompt(), null));
+        return Result.success().put("data", deepSeekClient.chatCompletion(SecurityUtil.getUserId(), param.getModel(), param.getPrompt(), null, null));
     }
 
     /**
@@ -56,7 +56,7 @@ public class DeepSeekController {
     @PreAuthorize("@sr.hasPermission('101')")
     @PostMapping("/api/deepSeek/chatLocal")
     public Result chatLocal(@Validated @RequestBody DSParam param) {
-        return Result.success().put("data", ollamaUtil.chatDeepSeek(SecurityUtil.getUserId(), param.getModel(), param.getPrompt(), null));
+        return Result.success().put("data", ollamaUtil.chatDeepSeek(SecurityUtil.getUserId(), param.getModel(), param.getPrompt(), null, null));
     }
 
     @Operation(summary = "DS-获得模型列表")

+ 1 - 1
src/main/java/com/backendsys/modules/sdk/deepseek/service/DeepSeekClient.java

@@ -9,7 +9,7 @@ import java.util.List;
 public interface DeepSeekClient {
 
     // [DeepSeek] 发起对话
-    ChatResult chatCompletion(Long user_id, String model, String prompt, List<Chat> chatList);
+    ChatResult chatCompletion(Long user_id, String model, String prompt, String history_code, List<Chat> chatList);
 
     // [DeepSeek] 获得模型
     JSONArray getModels(Long user_id);

+ 5 - 6
src/main/java/com/backendsys/modules/sdk/deepseek/service/impl/DeepSeekClientImpl.java

@@ -53,7 +53,7 @@ public class DeepSeekClientImpl implements DeepSeekClient {
      * - 文档:https://api-docs.deepseek.com/zh-cn/api/create-chat-completion
      */
     @Override
-    public ChatResult chatCompletion(Long user_id, String model, String prompt, List<Chat> chatList) {
+    public ChatResult chatCompletion(Long user_id, String model, String prompt, String history_code, List<Chat> chatList) {
 
         long replyDuration = 0L;
 
@@ -71,7 +71,7 @@ public class DeepSeekClientImpl implements DeepSeekClient {
             // - (还没做) 超过50条记录的限制
 
             // 加入上下文历史对话
-            System.out.println("---------------------- 历史对话: ----------------------");
+            System.out.println("---- 历史对话 (history_code): " + history_code + " ----");
 
             List<DSRequestMessage> messages = new ArrayList<>();
             if (chatList != null && !chatList.isEmpty()) {
@@ -91,13 +91,12 @@ public class DeepSeekClientImpl implements DeepSeekClient {
             messages.stream().forEach(msg -> {
                 System.out.println("[" + msg.getRole() + "]: " + msg.getContent());
             });
-            System.out.println("-----------------------------------------------------");
+            System.out.println("---------------------------------------------------------------------");
 
 
             // 构建请求体
+            // - model: (deepseek-chat: 对话模型, deepseek-reasoner: 推理模型)
             DSRequest body = new DSRequest();
-
-            // model: (deepseek-chat: 对话模型, deepseek-reasoner: 推理模型)
             body.setModel(model);
             body.setMessages(messages);
             body.setStream(true);
@@ -125,7 +124,7 @@ public class DeepSeekClientImpl implements DeepSeekClient {
                     Boolean isThinking = false;
 
                     System.out.println("API 调用耗时: " + apiDuration + " 毫秒");
-                    System.out.println("-------------------- 开始流式回答: --------------------");
+                    System.out.println("---- 开始流式回答: ------------------------------------");
 
                     // [SSE] 发送消息
                     ChatSseMessage chatLoadingSseMessage = new ChatSseMessage("LOADING", "正在思考");

+ 88 - 17
src/main/java/com/backendsys/modules/sdk/deepseek/utils/OllamaUtil.java

@@ -6,6 +6,8 @@ import com.alibaba.fastjson.JSONObject;
 import com.backendsys.modules.ai.chat.entity.Chat;
 import com.backendsys.modules.ai.chat.entity.ChatResult;
 import com.backendsys.modules.ai.chat.entity.ChatSseMessage;
+import com.backendsys.modules.sdk.deepseek.entity.DSRequest;
+import com.backendsys.modules.sdk.deepseek.entity.DSRequestMessage;
 import com.backendsys.modules.sse.entity.SseResponse;
 import com.backendsys.modules.sse.entity.SseResponseEnum;
 import com.backendsys.modules.sse.utils.SseUtil;
@@ -22,9 +24,7 @@ import org.springframework.stereotype.Component;
 import java.io.BufferedReader;
 import java.io.InputStreamReader;
 import java.nio.charset.StandardCharsets;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
+import java.util.*;
 import java.util.concurrent.CompletableFuture;
 
 /**
@@ -48,7 +48,7 @@ public class OllamaUtil {
     /**
      * 流式对话
      */
-    public ChatResult chatDeepSeek(Long user_id, String model, String prompt, List<Chat> chatList) {
+    public ChatResult chatDeepSeek(Long user_id, String model, String prompt, String history_code, List<Chat> chatList) {
 
         long contentDuration = 0L;
 
@@ -59,22 +59,80 @@ public class OllamaUtil {
             // 记录请求开始时间
             long allStartTime = System.currentTimeMillis();
 
+
+
             // 加入上下文历史对话
-            System.out.println("---------------------- 历史对话: ----------------------");
-            System.out.println("-----------------------------------------------------");
+            System.out.println("---- 历史对话 (history_code): " + history_code + " ----");
+
+            List<DSRequestMessage> messages = new ArrayList<>();
+            if (chatList != null && !chatList.isEmpty()) {
+                chatList.stream().forEach(chat -> {
+                    if (!"THINK".equals(chat.getContent_type())) {
+                        messages.add(new DSRequestMessage(chat.getRole(), chat.getContent()));
+                    }
+                });
+                // 反转列表
+                Collections.reverse(messages);
+            }
+
+            // 新的对话内容
+            messages.add(new DSRequestMessage("user", prompt));
+
+            // 输出全部对话内容
+            messages.stream().forEach(msg -> {
+                System.out.println("[" + msg.getRole() + "]: " + msg.getContent());
+            });
+            System.out.println("---------------------------------------------------------------------");
 
 
 
             ObjectMapper objectMapper = new ObjectMapper();
             try (CloseableHttpClient client = HttpClients.createDefault()) {
 
-                HttpPost request = new HttpPost(DOMAIN + "/api/generate");
-                Map<String, Object> requestMap = new HashMap<>();
-                requestMap.put("model", model);
-                requestMap.put("prompt", prompt);
-                requestMap.put("stream", true);
+                /*
+                【/api/generate】
+                它是一个相对基础的文本生成端点,主要用于根据给定的提示信息生成一段连续的文本。
+                这个端点会基于输入的提示,按照模型的语言生成能力输出一段完整的内容,更侧重于单纯的文本生成任务。
+                生成过程不依赖于上下文的历史对话信息,每次请求都是独立的,模型仅依据当前输入的提示进行文本生成。
+                {
+                    "model": "llama2", "prompt": "请描述一下美丽的海滩", "num_predict": 200, "temperature": 0.7
+                }
+
+                【/api/chat】
+                该端点专为模拟聊天场景设计,具备处理对话上下文的能力。它可以跟踪对话的历史记录,理解对话的上下文信息,从而生成更符合对话逻辑和连贯性的回复。
+                更注重模拟真实的人机对话交互,能够根据历史对话和当前输入生成合适的回应,适用于构建聊天机器人等交互式应用。
+                {
+                    "model": "deepseek-r1:1.5b",
+                    "messages": [
+                        {
+                            "role": "system",
+                            "content": "你是一个能够理解中文指令并帮助完成任务的智能助手。你的任务是根据用户的需求生成合适的分类任务或生成任务,并准确判断这些任务的类型。请确保你的回答简洁、准确且符合中英文语境。"
+                        },
+                        {
+                            "role": "user",
+                            "content": "写一个简单的 Python 函数,用于计算两个数的和"
+                        }
+                    ],
+                    "stream": false
+                }
+                 */
+
+                // [Chat] 构建请求体
+                HttpPost request = new HttpPost(DOMAIN + "/api/chat");
+                DSRequest body = new DSRequest();
+                body.setModel(model);
+                body.setMessages(messages);
+                body.setStream(true);
+                String requestBody = objectMapper.writeValueAsString(body);
+
+//                // [Generate] 构建请求体
+//                HttpPost request = new HttpPost(DOMAIN + "/api/generate");
+//                Map<String, Object> requestMap = new HashMap<>();
+//                requestMap.put("model", model);
+//                requestMap.put("prompt", prompt);
+//                requestMap.put("stream", true);
+//                String requestBody = objectMapper.writeValueAsString(requestMap);
 
-                String requestBody = objectMapper.writeValueAsString(requestMap);
                 request.setEntity(new StringEntity(requestBody, StandardCharsets.UTF_8));
 
                 try (CloseableHttpResponse response = client.execute(request);
@@ -87,7 +145,7 @@ public class OllamaUtil {
                     Boolean isThinking = false;
 
                     System.out.println("API 调用耗时: " + apiDuration + " 毫秒");
-                    System.out.println("-------------------- 开始流式回答: --------------------");
+                    System.out.println("---- 开始流式回答: ------------------------------------");
 
                     // [SSE] 发送消息
                     ChatSseMessage chatLoadingSseMessage = new ChatSseMessage("LOADING", "正在思考");
@@ -101,15 +159,26 @@ public class OllamaUtil {
 
                          System.out.println(line);
                         /*
-                            ---------------------- line ----------------------
+                            ---------------------- [Chat] line ----------------------
+                            {"model":"deepseek-r1:1.5b","created_at":"2025-03-18T07:37:06.483163789Z","message":{"role":"assistant","content":"\u003cthink\u003e"},"done":false}
+                            ---------------------- [Generate] line ----------------------
                             {"model":"deepseek-r1:1.5b","created_at":"2025-03-05T10:51:17.443189986Z","response":"\u003cthink\u003e","done":false}
                             {"model":"deepseek-r1:1.5b","created_at":"2025-03-06T11:08:30.9219611Z","response":"\n\n","done":false}
-                            --------------------------------------------------
+                            -------------------------------------------------------------
                          */
 
                         // 每行数据可以是一个JSON对象,根据实际情况处理
                         JSONObject resJson = JSONObject.parseObject(line);
-                        String content = resJson.getString("response");
+
+                        // --------------------------------------------------------------
+                        // [Chat]
+                        JSONObject resJsonMessage = resJson.getJSONObject("message");
+                        String content = resJsonMessage.getString("content");
+
+//                        // [Generate]
+//                        String content = resJson.getString("response");
+                        // --------------------------------------------------------------
+
 
                         // System.out.println("content: " + content);
                         // content: \n\n
@@ -126,7 +195,7 @@ public class OllamaUtil {
                             isThinking = false;
                             thinkDuration = thinkStartTime - allStartTime;
                             System.out.println("推理耗时: " + thinkDuration + "毫秒");
-                            System.out.println("-----------------------------------------------");
+                            System.out.println("-----------------------------------------------------");
 
                             if (allThinkContent.length() > 0){
                                 // [SSE] 发送消息
@@ -201,6 +270,8 @@ public class OllamaUtil {
                     sseUtil.send(user_id, new SseResponse(SseResponseEnum.DEEPSEEK, chatSseMessage).toJsonStr());
 
                     chatResult.setContent(e.getMessage());
+                    e.printStackTrace();
+
                     return chatResult;
                 }
             } catch (Exception e) {

+ 1 - 1
src/main/java/com/backendsys/modules/sdk/tencentcloud/huanyuan/service/HunYuanClient.java

@@ -8,6 +8,6 @@ import java.util.List;
 public interface HunYuanClient {
 
     // [HunYuan] 发起对话
-    ChatResult chatCompletion(Long user_id, String prompt, List<Chat> chatList);
+    ChatResult chatCompletion(Long user_id, String prompt, String history_code, List<Chat> chatList);
 
 }

+ 3 - 3
src/main/java/com/backendsys/modules/sdk/tencentcloud/huanyuan/service/impl/HunYuanClientImpl.java

@@ -57,7 +57,7 @@ public class HunYuanClientImpl implements HunYuanClient {
      * https://cloud.tencent.com/document/product/1729/101836
      */
     @Override
-    public ChatResult chatCompletion(Long user_id, String prompt, List<Chat> chatList) {
+    public ChatResult chatCompletion(Long user_id, String prompt, String history_code, List<Chat> chatList) {
 
         long replyDuration = 0L;
 
@@ -71,7 +71,7 @@ public class HunYuanClientImpl implements HunYuanClient {
             long allStartTime = System.currentTimeMillis();
 
             // 加入上下文历史对话
-            System.out.println("---------------------- 历史对话: ----------------------");
+            System.out.println("---- 历史对话 (history_code): " + history_code + " ----");
 
             List<Message> messages = new ArrayList<>();
             if (chatList != null && !chatList.isEmpty()) {
@@ -90,7 +90,7 @@ public class HunYuanClientImpl implements HunYuanClient {
             messages.stream().forEach(msg -> {
                 System.out.println("[" + msg.getRole() + "]: " + msg.getContent());
             });
-            System.out.println("-----------------------------------------------------");
+            System.out.println("---------------------------------------------------------------------");
 
 
             // [混元大模型]