dcs 4 dienas atpakaļ
vecāks
revīzija
fb9638e778

+ 6 - 2
virgo.api/src/main/java/com/bosshand/virgo/api/workark/controller/DifyController.java

@@ -3,12 +3,14 @@ package com.bosshand.virgo.api.workark.controller;
 import com.bosshand.virgo.api.workark.model.*;
 import com.bosshand.virgo.api.workark.service.DifyDatasetService;
 import com.bosshand.virgo.api.workark.service.DifyService;
+import com.bosshand.virgo.api.workark.util.SseEmitterUtil;
 import com.bosshand.virgo.core.response.Response;
 import com.bosshand.virgo.exception.Constant;
 import io.swagger.annotations.*;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.web.bind.annotation.*;
 import org.springframework.web.multipart.MultipartFile;
+import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
 
 import java.io.IOException;
 import java.io.InputStream;
@@ -135,8 +137,10 @@ public class DifyController {
     @ApiImplicitParams({
             @ApiImplicitParam(name = "difyTypeId", value = "dify类型id")
     })
-    public Response chatRun(@PathVariable long difyTypeId, @RequestBody Map<String, Object> inputs) {
-        return Response.ok(difyService.chatRun(difyTypeId, inputs));
+    public SseEmitter chatRun(@PathVariable long difyTypeId, @RequestBody Map<String, Object> inputs) {
+        String simpleUUID = difyService.chatRun(difyTypeId, inputs);
+        // 用于创建连接sse
+        return SseEmitterUtil.connect(simpleUUID);
     }
 
     @ApiOperation("获取对话型Chat执行情况")

+ 6 - 0
virgo.api/src/main/java/com/bosshand/virgo/api/workark/service/DifyService.java

@@ -3,6 +3,7 @@ package com.bosshand.virgo.api.workark.service;
 import com.alibaba.fastjson.JSONObject;
 import com.bosshand.virgo.api.workark.dao.*;
 import com.bosshand.virgo.api.workark.model.*;
+import com.bosshand.virgo.api.workark.util.SseEmitterUtil;
 import com.bosshand.virgo.core.model.UserContext;
 import com.bosshand.virgo.core.utils.ContextUtils;
 import com.bosshand.virgo.core.utils.StringUtil;
@@ -364,6 +365,8 @@ public class DifyService {
                 public void onMessage(MessageEvent event) {
                     //System.out.println("收到消息片段: " + event.getAnswer());
                     chatM.append(event.getAnswer());
+                    // 发送消息
+                    SseEmitterUtil.sendMessage(simpleUUID, event.getAnswer());
                 }
 
                 @Override
@@ -380,6 +383,7 @@ public class DifyService {
                     difyChat.setStatus("succeeded");
                     difyChatDao.update(difyChat);
                     System.out.println("消息结束,完整消息ID: " + event.getMessageId());
+                    SseEmitterUtil.removeUser(simpleUUID);
                 }
 
                 @Override
@@ -406,6 +410,8 @@ public class DifyService {
                 public void onAgentMessage(AgentMessageEvent event) {
                     //System.out.println("收到Agent消息片段: " + event.getAnswer());
                     agentM.append(event.getAnswer());
+                    // 发送消息
+                    SseEmitterUtil.sendMessage(simpleUUID, event.getAnswer());
                 }
 
                 @Override

+ 132 - 0
virgo.api/src/main/java/com/bosshand/virgo/api/workark/util/SseEmitterUtil.java

@@ -0,0 +1,132 @@
+package com.bosshand.virgo.api.workark.util;
+
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.http.MediaType;
+import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.Consumer;
+
+/**
+ * SSE长链接工具类
+ */
+@Slf4j
+public class SseEmitterUtil {
+
+    /**
+     * 使用map对象,便于根据simpleUUID来获取对应的SseEmitter,或者放redis里面
+     */
+    private final static Map<String, SseEmitter> sseEmitterMap = new ConcurrentHashMap<>();
+
+    public static SseEmitter connect(String simpleUUID) {
+        // 设置超时时间,0表示不过期。默认30S,超时时间未完成会抛出异常:AsyncRequestTimeoutException
+        SseEmitter sseEmitter = new SseEmitter(0L);
+
+        // 注册回调
+        sseEmitter.onCompletion(completionCallBack(simpleUUID));
+        sseEmitter.onError(errorCallBack(simpleUUID));
+        sseEmitter.onTimeout(timeoutCallBack(simpleUUID));
+        sseEmitterMap.put(simpleUUID, sseEmitter);
+
+        log.info("创建新的 SSE 连接,当前用户 {}, 连接总数 {}", simpleUUID, sseEmitterMap.size());
+        return sseEmitter;
+    }
+
+    /**
+     * 给制定用户发送消息
+     *
+     * @param simpleUUID 指定用户名
+     * @param sseMessage 消息体
+     */
+    public static void sendMessage(String simpleUUID, String sseMessage) {
+        if (sseEmitterMap.containsKey(simpleUUID)) {
+            try {
+                sseEmitterMap.get(simpleUUID).send(sseMessage);
+                log.info("用户 {} 推送消息 {}", simpleUUID, sseMessage);
+            } catch (IOException e) {
+                log.error("用户 {} 推送消息异常", simpleUUID, e);
+                removeUser(simpleUUID);
+            }
+        } else {
+            log.error("消息推送 用户 {} 不存在,链接总数 {}", simpleUUID, sseEmitterMap.size());
+        }
+    }
+
+    /**
+     * 群发消息
+     */
+    public static void batchSendMessage(String message, List<String> ids) {
+        ids.forEach(simpleUUID -> sendMessage(simpleUUID, message));
+    }
+
+    /**
+     * 群发所有人
+     */
+    public static void batchSendMessage(String message) {
+        sseEmitterMap.forEach((k, v) -> {
+            try {
+                v.send(message, MediaType.APPLICATION_JSON);
+            } catch (IOException e) {
+                log.error("用户 {} 推送异常", k, e);
+                removeUser(k);
+            }
+        });
+    }
+
+    /**
+     * 移除用户连接
+     *
+     * @param simpleUUID 用户 ID
+     */
+    public static void removeUser(String simpleUUID) {
+        if (sseEmitterMap.containsKey(simpleUUID)) {
+            sseEmitterMap.get(simpleUUID).complete();
+            sseEmitterMap.remove(simpleUUID);
+            log.info("移除用户 {}, 剩余连接 {}", simpleUUID, sseEmitterMap.size());
+        } else {
+            log.error("消息推送 用户 {} 已被移除,剩余连接 {}", simpleUUID, sseEmitterMap.size());
+        }
+    }
+
+    /**
+     * 获取当前连接信息
+     *
+     * @return 所有的连接用户
+     */
+    public static List<String> getIds() {
+        return new ArrayList<>(sseEmitterMap.keySet());
+    }
+
+    /**
+     * 获取当前的连接数量
+     *
+     * @return 当前的连接数量
+     */
+    public static int getUserCount() {
+        return sseEmitterMap.size();
+    }
+
+    private static Runnable completionCallBack(String simpleUUID) {
+        return () -> {
+            log.info("用户 {} 结束连接", simpleUUID);
+        };
+    }
+
+    private static Runnable timeoutCallBack(String simpleUUID) {
+        return () -> {
+            log.error("用户 {} 连接超时", simpleUUID);
+            removeUser(simpleUUID);
+        };
+    }
+
+    private static Consumer<Throwable> errorCallBack(String simpleUUID) {
+        return throwable -> {
+            log.error("用户 {} 连接异常", simpleUUID);
+            removeUser(simpleUUID);
+        };
+    }
+}

+ 3 - 0
virgo.core/src/main/java/com/bosshand/virgo/core/config/ShiroConfig.java

@@ -2,6 +2,7 @@ package com.bosshand.virgo.core.config;
 
 import com.bosshand.virgo.core.service.MgrUserService;
 import com.bosshand.virgo.core.shiro.*;
+import org.apache.shiro.SecurityUtils;
 import org.apache.shiro.mgt.SecurityManager;
 import org.apache.shiro.realm.Realm;
 import org.apache.shiro.session.mgt.SessionManager;
@@ -158,6 +159,8 @@ public class ShiroConfig {
         securityManager.setSessionManager(sessionManager);
         securityManager.setCacheManager(cacheManager);
         // securityManager.setRememberMeManager(rememberMeManager());
+
+		SecurityUtils.setSecurityManager(securityManager); // 关键设置
         return securityManager;
     }