|
@@ -0,0 +1,132 @@
|
|
|
+package com.bosshand.virgo.api.workark.util;
|
|
|
+
|
|
|
+import lombok.extern.slf4j.Slf4j;
|
|
|
+import org.springframework.http.MediaType;
|
|
|
+import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
|
|
+
|
|
|
+import java.io.IOException;
|
|
|
+import java.util.ArrayList;
|
|
|
+import java.util.List;
|
|
|
+import java.util.Map;
|
|
|
+import java.util.concurrent.ConcurrentHashMap;
|
|
|
+import java.util.function.Consumer;
|
|
|
+
|
|
|
+/**
|
|
|
+ * SSE长链接工具类
|
|
|
+ */
|
|
|
+@Slf4j
|
|
|
+public class SseEmitterUtil {
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 使用map对象,便于根据simpleUUID来获取对应的SseEmitter,或者放redis里面
|
|
|
+ */
|
|
|
+ private final static Map<String, SseEmitter> sseEmitterMap = new ConcurrentHashMap<>();
|
|
|
+
|
|
|
+ public static SseEmitter connect(String simpleUUID) {
|
|
|
+ // 设置超时时间,0表示不过期。默认30S,超时时间未完成会抛出异常:AsyncRequestTimeoutException
|
|
|
+ SseEmitter sseEmitter = new SseEmitter(0L);
|
|
|
+
|
|
|
+ // 注册回调
|
|
|
+ sseEmitter.onCompletion(completionCallBack(simpleUUID));
|
|
|
+ sseEmitter.onError(errorCallBack(simpleUUID));
|
|
|
+ sseEmitter.onTimeout(timeoutCallBack(simpleUUID));
|
|
|
+ sseEmitterMap.put(simpleUUID, sseEmitter);
|
|
|
+
|
|
|
+ log.info("创建新的 SSE 连接,当前用户 {}, 连接总数 {}", simpleUUID, sseEmitterMap.size());
|
|
|
+ return sseEmitter;
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 给制定用户发送消息
|
|
|
+ *
|
|
|
+ * @param simpleUUID 指定用户名
|
|
|
+ * @param sseMessage 消息体
|
|
|
+ */
|
|
|
+ public static void sendMessage(String simpleUUID, String sseMessage) {
|
|
|
+ if (sseEmitterMap.containsKey(simpleUUID)) {
|
|
|
+ try {
|
|
|
+ sseEmitterMap.get(simpleUUID).send(sseMessage);
|
|
|
+ log.info("用户 {} 推送消息 {}", simpleUUID, sseMessage);
|
|
|
+ } catch (IOException e) {
|
|
|
+ log.error("用户 {} 推送消息异常", simpleUUID, e);
|
|
|
+ removeUser(simpleUUID);
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ log.error("消息推送 用户 {} 不存在,链接总数 {}", simpleUUID, sseEmitterMap.size());
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 群发消息
|
|
|
+ */
|
|
|
+ public static void batchSendMessage(String message, List<String> ids) {
|
|
|
+ ids.forEach(simpleUUID -> sendMessage(simpleUUID, message));
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 群发所有人
|
|
|
+ */
|
|
|
+ public static void batchSendMessage(String message) {
|
|
|
+ sseEmitterMap.forEach((k, v) -> {
|
|
|
+ try {
|
|
|
+ v.send(message, MediaType.APPLICATION_JSON);
|
|
|
+ } catch (IOException e) {
|
|
|
+ log.error("用户 {} 推送异常", k, e);
|
|
|
+ removeUser(k);
|
|
|
+ }
|
|
|
+ });
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 移除用户连接
|
|
|
+ *
|
|
|
+ * @param simpleUUID 用户 ID
|
|
|
+ */
|
|
|
+ public static void removeUser(String simpleUUID) {
|
|
|
+ if (sseEmitterMap.containsKey(simpleUUID)) {
|
|
|
+ sseEmitterMap.get(simpleUUID).complete();
|
|
|
+ sseEmitterMap.remove(simpleUUID);
|
|
|
+ log.info("移除用户 {}, 剩余连接 {}", simpleUUID, sseEmitterMap.size());
|
|
|
+ } else {
|
|
|
+ log.error("消息推送 用户 {} 已被移除,剩余连接 {}", simpleUUID, sseEmitterMap.size());
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 获取当前连接信息
|
|
|
+ *
|
|
|
+ * @return 所有的连接用户
|
|
|
+ */
|
|
|
+ public static List<String> getIds() {
|
|
|
+ return new ArrayList<>(sseEmitterMap.keySet());
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 获取当前的连接数量
|
|
|
+ *
|
|
|
+ * @return 当前的连接数量
|
|
|
+ */
|
|
|
+ public static int getUserCount() {
|
|
|
+ return sseEmitterMap.size();
|
|
|
+ }
|
|
|
+
|
|
|
+ private static Runnable completionCallBack(String simpleUUID) {
|
|
|
+ return () -> {
|
|
|
+ log.info("用户 {} 结束连接", simpleUUID);
|
|
|
+ };
|
|
|
+ }
|
|
|
+
|
|
|
+ private static Runnable timeoutCallBack(String simpleUUID) {
|
|
|
+ return () -> {
|
|
|
+ log.error("用户 {} 连接超时", simpleUUID);
|
|
|
+ removeUser(simpleUUID);
|
|
|
+ };
|
|
|
+ }
|
|
|
+
|
|
|
+ private static Consumer<Throwable> errorCallBack(String simpleUUID) {
|
|
|
+ return throwable -> {
|
|
|
+ log.error("用户 {} 连接异常", simpleUUID);
|
|
|
+ removeUser(simpleUUID);
|
|
|
+ };
|
|
|
+ }
|
|
|
+}
|