Ver código fonte

停止聊天代码修改

S0025136190 7 meses atrás
pai
commit
e3a11392ff

+ 23 - 43
takai-ai/src/main/java/com/takai/ai/service/impl/TakaiAiServiceImpl.java

@@ -33,16 +33,12 @@ import org.springframework.web.multipart.MultipartFile;
 import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
 
 import javax.validation.constraints.NotNull;
-import java.io.File;
 import java.io.IOException;
 import java.time.LocalDate;
 import java.time.LocalDateTime;
 import java.time.temporal.ChronoUnit;
 import java.util.*;
-import java.util.concurrent.Executors;
-import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.stream.Collectors;
 
 /**
@@ -151,33 +147,6 @@ public class TakaiAiServiceImpl implements ITakaiAiService {
             TakaiKnowledge knowledge = takaiKnowledgeMapper.selectTargetKnowledge(TakaiKnowledge.builder().knowledgeId(appInfo.getKnowledgeIds()).build());
             if (knowledge != null) {
                 SseEmitter sseEmitter = new SseEmitter(0L);
-                // 生成唯一会话ID(可从请求参数获取或自动生成)
-                String sessionId = UUID.randomUUID().toString();
-
-                redisTemplate.opsForValue().set(
-                        "sse:terminate:" + sessionId,
-                        "false",
-                        1, TimeUnit.HOURS
-                );
-
-                // 定时检查终止信号(每2秒检查一次)
-                ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();
-                scheduler.scheduleAtFixedRate(() -> {
-                    try {
-                        String flag = redisTemplate.opsForValue()
-                                .get("sse:terminate:" + sessionId).toString();
-
-                        if ("true".equals(flag)) {
-                            // 主动终止流程
-                            sseEmitter.complete();
-                            scheduler.shutdown();
-                            redisTemplate.delete("sse:terminate:" + sessionId);
-                        }
-                    } catch (Exception e) {
-                        log.error("终止检查失败", e);
-                    }
-                }, 0, 2, TimeUnit.SECONDS);
-
                 String url = deepseekConfig.getBaseurl() + deepseekConfig.getChat();
                 TakaiAppInfo info = takaiAppInfoMapper.selectAppInfoByAppId(sseParams.getAppId());
                 JSONObject json = JSONObject.parseObject(info.getAppInfo());
@@ -208,11 +177,6 @@ public class TakaiAiServiceImpl implements ITakaiAiService {
 
                     @Override
                     public void onEvent(@NotNull EventSource eventSource, String id, String type, @NotNull String data) {
-                        // 检查Redis终止标记(双重保障)
-                        if ("true".equals(redisTemplate.opsForValue().get("sse:terminate:" + sessionId))) {
-                            eventSource.cancel();
-                            return;
-                        }
                         if (!StringUtils.isEmpty(data)) {
                             String newData = data.substring(preData.length());
                             if (com.takai.common.utils.StringUtils.isNotEmpty(type) && "finish".equals(type)) {
@@ -271,20 +235,36 @@ public class TakaiAiServiceImpl implements ITakaiAiService {
                             sseEmitter.send(obj);
                         } catch (IOException e) {
                             log.error("deepseek 推送数据失败", e);
-                            throw new RuntimeException(e);
                         }
                     }
                 };
 
-                // 资源清理
+                OkHttpClient client = buildOkHttpClient();
+                EventSource.Factory factory = EventSources.createFactory(client);
+                final EventSource eventSources = factory.newEventSource(request, listener);
+
+                // 客户端主动关闭连接
                 sseEmitter.onCompletion(() -> {
-                    scheduler.shutdown();
-                    redisTemplate.delete("sse:terminate:" + sessionId);
+                    logger.info("deepseek客户端主动关闭连接 -- SSE 连接关闭");
                 });
 
-                OkHttpClient client = buildOkHttpClient();
-                EventSource.Factory factory = EventSources.createFactory(client);
-                factory.newEventSource(request, listener);
+                // 超时回调
+                sseEmitter.onTimeout(() -> {
+                    logger.info("deepseek客户端连接超时 -- SSE 连接关闭");
+                    if(eventSources != null) {
+                        logger.info("deepseek超时回调 -- 成功关闭SSE连接 ");
+                        eventSources.cancel();
+                    }
+                });
+
+                // 错误回调
+                sseEmitter.onError(e -> {
+                    logger.info("deepseek客户端回调失败 -- SSE 连接关闭");
+                    if(eventSources != null) {
+                        logger.info("deepseek错误回调 -- 成功关闭SSE连接 ");
+                        eventSources.cancel();
+                    }
+                });
                 return sseEmitter;
             }
         }

+ 28 - 45
takai-bigmodel/src/main/java/com/takai/bigmodel/service/impl/BigModelServiceImpl.java

@@ -20,6 +20,8 @@ import okhttp3.sse.EventSource;
 import okhttp3.sse.EventSourceListener;
 import okhttp3.sse.EventSources;
 import org.apache.commons.io.FileUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 import org.springframework.beans.BeanUtils;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.data.redis.core.RedisTemplate;
@@ -36,8 +38,7 @@ import java.time.LocalDate;
 import java.time.LocalDateTime;
 import java.time.temporal.ChronoUnit;
 import java.util.*;
-import java.util.concurrent.Executors;
-import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.TimeUnit;
 import java.util.stream.Collectors;
 
@@ -51,6 +52,8 @@ import java.util.stream.Collectors;
 @DataSource(DataSourceType.SLAVE)
 public class BigModelServiceImpl implements IBigModelService
 {
+    private static final Logger logger = LoggerFactory.getLogger(BigModelServiceImpl.class);
+
     @Autowired
     private BmMediaReplacementMapper bmMediaReplacementMapper;
 
@@ -123,34 +126,6 @@ public class BigModelServiceImpl implements IBigModelService
     @Override
     public SseEmitter sseInvoke(SseParams sseParams) {
         SseEmitter sseEmitter = new SseEmitter(0L);
-
-        // 生成唯一会话ID(可从请求参数获取或自动生成)
-        String sessionId = UUID.randomUUID().toString();
-
-        redisTemplate.opsForValue().set(
-                "sse:terminate:" + sessionId,
-                "false",
-                1, TimeUnit.HOURS
-        );
-
-        // 定时检查终止信号(每2秒检查一次)
-        ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();
-        scheduler.scheduleAtFixedRate(() -> {
-            try {
-                String flag = redisTemplate.opsForValue()
-                        .get("sse:terminate:" + sessionId).toString();
-
-                if ("true".equals(flag)) {
-                    // 主动终止流程
-                    sseEmitter.complete();
-                    scheduler.shutdown();
-                    redisTemplate.delete("sse:terminate:" + sessionId);
-                }
-            } catch (Exception e) {
-                log.error("终止检查失败", e);
-            }
-        }, 0, 2, TimeUnit.SECONDS);
-
         String url = bigModelConfig.getBaseurl() + bigModelConfig.getSse().replace("{id}",sseParams.getAppId());
         JSONObject json = new JSONObject();
         List<PromptObject> list = sseParams.getPrompt();
@@ -186,11 +161,6 @@ public class BigModelServiceImpl implements IBigModelService
             @Override
             public void onEvent(@NotNull EventSource eventSource,  String id,  String type, @NotNull String data) {
                 try {
-                    // 检查Redis终止标记(双重保障)
-                    if ("true".equals(redisTemplate.opsForValue().get("sse:terminate:" + sessionId))) {
-                        eventSource.cancel();
-                        return;
-                    }
                     String newData = data.substring(preData.length());
                     preData = data;
                     if(newData.indexOf(START_SIGN) > -1 || symbolData.length() > 0) {
@@ -237,12 +207,9 @@ public class BigModelServiceImpl implements IBigModelService
                         log.info("智谱返回信息:" + json);
                         send(sseEmitter,json);
 
-//                        threadPoolTaskExecutor.execute(() -> send(sseEmitter,json));
                     }
-                    //sseEmitter.send(json);
                 } catch (Exception e) {
                     log.error("智谱 推送数据失败", e);
-                    throw new RuntimeException(e);
                 }
             }
             @Override
@@ -256,20 +223,36 @@ public class BigModelServiceImpl implements IBigModelService
                     sseEmitter.send(obj);
                 } catch (IOException e) {
                     log.error("智谱 推送数据失败", e);
-                    throw new RuntimeException(e);
                 }
             }
         };
 
-        // 资源清理
+        OkHttpClient client = buildOkHttpClient();
+        EventSource.Factory factory = EventSources.createFactory(client);
+        final EventSource eventSources = factory.newEventSource(request, listener);
+
+        // 客户端主动关闭连接
         sseEmitter.onCompletion(() -> {
-            scheduler.shutdown();
-            redisTemplate.delete("sse:terminate:" + sessionId);
+            logger.info("客户端主动关闭连接 -- SSE 连接关闭");
         });
 
-        OkHttpClient client = buildOkHttpClient();
-        EventSource.Factory factory = EventSources.createFactory(client);
-        factory.newEventSource(request, listener);
+        // 超时回调
+        sseEmitter.onTimeout(() -> {
+            logger.info("客户端连接超时 -- SSE 连接关闭");
+            if(eventSources != null) {
+                logger.info("超时回调 -- 成功关闭SSE连接 ");
+                eventSources.cancel();
+            }
+        });
+
+        // 错误回调
+        sseEmitter.onError(e -> {
+            logger.info("客户端回调失败 -- SSE 连接关闭");
+            if(eventSources != null) {
+                logger.info("错误回调 -- 成功关闭SSE连接 ");
+                eventSources.cancel();
+            }
+        });
         return sseEmitter;
     }