tsurumure 10 mesiacov pred
rodič
commit
1fa44cb5fe

+ 3 - 1
src/main/java/com/backendsys/modules/sse/controller/SseController.java

@@ -1,5 +1,6 @@
 package com.backendsys.modules.sse.controller;
 
+import cn.hutool.core.convert.Convert;
 import com.backendsys.modules.sse.emitter.SseEmitterManager;
 import com.backendsys.modules.sse.utils.SseUtil;
 import org.springframework.beans.factory.annotation.Autowired;
@@ -24,9 +25,10 @@ public class SseController {
 
     @GetMapping(value = "/api/sse/stream", produces = "text/event-stream")
     public SseEmitter stream() {
+        String userId = Convert.toStr(1L);
         SseEmitter emitter = new SseEmitter(Long.MAX_VALUE);
         SseEmitterManager manager = SseEmitterManager.getInstance();
-        manager.addEmitter(emitter);
+        manager.addEmitter(userId, emitter);
         executor.execute(() -> {
             try {
                 emitter.send(SseEmitter.event().data("success"));

+ 12 - 4
src/main/java/com/backendsys/modules/sse/emitter/SseEmitterManager.java

@@ -2,13 +2,16 @@ package com.backendsys.modules.sse.emitter;
 
 import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
 
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.CopyOnWriteArrayList;
 
 public class SseEmitterManager {
     // 单例实例
     private static final SseEmitterManager INSTANCE = new SseEmitterManager();
     // 存储SseEmitter的线程安全列表
-    public final CopyOnWriteArrayList<SseEmitter> emitters = new CopyOnWriteArrayList<>();
+//    public final CopyOnWriteArrayList<Long, SseEmitter> emitters = new CopyOnWriteArrayList<>();
+    public final ConcurrentHashMap<String, SseEmitter> emitters = new ConcurrentHashMap<>();
+
     // 私有构造函数,防止外部直接实例化
     private SseEmitterManager() {}
     // 公共静态方法,获取单例实例
@@ -16,13 +19,18 @@ public class SseEmitterManager {
         return INSTANCE;
     }
     // 公共方法,供外部添加SseEmitter
-    public void addEmitter(SseEmitter emitter) {
-        this.emitters.add(emitter);
+    public void addEmitter(String userId, SseEmitter emitter) {
+        this.emitters.put(userId, emitter);
         emitter.onTimeout(() -> this.emitters.remove(emitter));
         emitter.onCompletion(() -> this.emitters.remove(emitter));
     }
     // 公共方法,供外部移除SseEmitter
+    public SseEmitter getEmitter(String userId) {
+        // 根据用户ID获取 SseEmitter
+        return this.emitters.get(userId);
+    }
+    // 公共方法,供外部移除SseEmitter
     public void removeEmitter(SseEmitter emitter) {
         this.emitters.remove(emitter);
     }
-}
+}

+ 8 - 2
src/main/java/com/backendsys/modules/sse/utils/SseUtil.java

@@ -1,5 +1,6 @@
 package com.backendsys.modules.sse.utils;
 
+import cn.hutool.core.convert.Convert;
 import com.backendsys.modules.sse.emitter.SseEmitterManager;
 import org.springframework.stereotype.Component;
 import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
@@ -11,12 +12,17 @@ public class SseUtil {
 
     // [SSE] 发送消息
     public void send(String data) {
+
+        Long userId = 1L;
+
         SseEmitterManager manager = SseEmitterManager.getInstance();
-        for (SseEmitter emitter : manager.emitters) {
+        SseEmitter emitter = manager.getEmitter(Convert.toStr(userId));
+        if (emitter != null) {
             try {
                 emitter.send(SseEmitter.event().data(data));
             } catch (IOException e) {
-                manager.emitters.remove(emitter);
+                System.out.println(e.getMessage());
+                manager.removeEmitter(emitter);
             }
         }
     }