Parcourir la source

feat: Support DeepSeek API streaming with thinking mode

Kadxy il y a 10 mois
Parent
commit
c449737127
2 fichiers modifiés avec 301 ajouts et 6 suppressions
  1. 37 5
      app/client/platforms/deepseek.ts
  2. 264 1
      app/utils/chat.ts

+ 37 - 5
app/client/platforms/deepseek.ts

@@ -13,7 +13,7 @@ import {
   ChatMessageTool,
   usePluginStore,
 } from "@/app/store";
-import { stream } from "@/app/utils/chat";
+import { streamWithThink } from "@/app/utils/chat";
 import {
   ChatOptions,
   getHeaders,
@@ -107,6 +107,8 @@ export class DeepSeekApi implements LLMApi {
         headers: getHeaders(),
       };
 
+      console.log(chatPayload);
+
       // make a fetch request
       const requestTimeoutId = setTimeout(
         () => controller.abort(),
@@ -119,7 +121,7 @@ export class DeepSeekApi implements LLMApi {
           .getAsTools(
             useChatStore.getState().currentSession().mask?.plugin || [],
           );
-        return stream(
+        return streamWithThink(
           chatPath,
           requestPayload,
           getHeaders(),
@@ -128,12 +130,13 @@ export class DeepSeekApi implements LLMApi {
           controller,
           // parseSSE
           (text: string, runTools: ChatMessageTool[]) => {
-            // console.log("parseSSE", text, runTools);
+            console.log("parseSSE", text, runTools);
             const json = JSON.parse(text);
             const choices = json.choices as Array<{
               delta: {
-                content: string;
+                content: string | null;
                 tool_calls: ChatMessageTool[];
+                reasoning_content: string | null;
               };
             }>;
             const tool_calls = choices[0]?.delta?.tool_calls;
@@ -155,7 +158,36 @@ export class DeepSeekApi implements LLMApi {
                 runTools[index]["function"]["arguments"] += args;
               }
             }
-            return choices[0]?.delta?.content;
+            const reasoning = choices[0]?.delta?.reasoning_content;
+            const content = choices[0]?.delta?.content;
+
+            // Skip if both content and reasoning_content are empty or null
+            if (
+              (!reasoning || reasoning.trim().length === 0) &&
+              (!content || content.trim().length === 0)
+            ) {
+              return {
+                isThinking: false,
+                content: "",
+              };
+            }
+
+            if (reasoning && reasoning.trim().length > 0) {
+              return {
+                isThinking: true,
+                content: reasoning,
+              };
+            } else if (content && content.trim().length > 0) {
+              return {
+                isThinking: false,
+                content: content,
+              };
+            }
+
+            return {
+              isThinking: false,
+              content: "",
+            };
           },
           // processToolMessage, include tool_calls message and tool call results
           (

+ 264 - 1
app/utils/chat.ts

@@ -344,8 +344,12 @@ export function stream(
           return finish();
         }
         const text = msg.data;
+        // Skip empty messages
+        if (!text || text.trim().length === 0) {
+          return;
+        }
         try {
-          const chunk = parseSSE(msg.data, runTools);
+          const chunk = parseSSE(text, runTools);
           if (chunk) {
             remainText += chunk;
           }
@@ -366,3 +370,262 @@ export function stream(
   console.debug("[ChatAPI] start");
   chatApi(chatPath, headers, requestPayload, tools); // call fetchEventSource
 }
+
+/**
+ * Stream a chat completion over SSE, rendering "thinking" (reasoning) output
+ * separately from the final answer.
+ *
+ * Works like `stream`, except `parseSSE` returns `{ isThinking, content }`
+ * instead of a plain string. Thinking chunks are emitted as a Markdown
+ * blockquote (prefixed with "> "); when the model switches back to normal
+ * output a blank line is inserted so the answer starts outside the quote.
+ *
+ * @param chatPath - URL the POST request is sent to.
+ * @param requestPayload - JSON-serializable request body; `tools` is merged in.
+ * @param headers - Request headers.
+ * @param tools - Tool definitions sent with the request (omitted when empty).
+ * @param funcs - Tool implementations keyed by tool function name.
+ * @param controller - Abort controller; aborting finishes the stream.
+ * @param parseSSE - Parses one SSE `data` payload. May push/extend entries in
+ *   `runTools` to accumulate tool calls. Returns the text of the chunk and
+ *   whether it belongs to the reasoning phase.
+ * @param processToolMessage - Called with the tool-call message and results
+ *   before the request is re-issued.
+ * @param options - Callbacks: `onUpdate`, `onFinish`, `onError`,
+ *   `onBeforeTool`, `onAfterTool`.
+ */
+export function streamWithThink(
+  chatPath: string,
+  requestPayload: any,
+  headers: any,
+  tools: any[],
+  funcs: Record<string, Function>,
+  controller: AbortController,
+  parseSSE: (
+    text: string,
+    runTools: any[],
+  ) => {
+    isThinking: boolean;
+    content: string | undefined;
+  },
+  processToolMessage: (
+    requestPayload: any,
+    toolCallMessage: any,
+    toolCallResult: any[],
+  ) => void,
+  options: any,
+) {
+  // Text already handed to onUpdate.
+  let responseText = "";
+  // Text received from the server but not yet revealed by the animation.
+  let remainText = "";
+  let finished = false;
+  // True while tool calls are executing / the follow-up request is pending.
+  let running = false;
+  let runTools: any[] = [];
+  let responseRes: Response;
+  // Whether the output is currently inside a thinking blockquote.
+  let isInThinkingMode = false;
+  // isThinking flag of the previous non-empty chunk.
+  let lastIsThinking = false;
+
+  // animate response to make it look smooth
+  function animateResponseText() {
+    if (finished || controller.signal.aborted) {
+      responseText += remainText;
+      console.log("[Response Animation] finished");
+      if (responseText?.length === 0) {
+        options.onError?.(new Error("empty response from server"));
+      }
+      return;
+    }
+
+    if (remainText.length > 0) {
+      // Reveal at least one character per frame, more when the backlog grows.
+      const fetchCount = Math.max(1, Math.round(remainText.length / 60));
+      const fetchText = remainText.slice(0, fetchCount);
+      responseText += fetchText;
+      remainText = remainText.slice(fetchCount);
+      options.onUpdate?.(responseText, fetchText);
+    }
+
+    requestAnimationFrame(animateResponseText);
+  }
+
+  // start animation
+  animateResponseText();
+
+  // Ends the stream. If tool calls were collected, runs them and re-issues
+  // the request instead of finishing; otherwise reports the final text.
+  const finish = () => {
+    if (!finished) {
+      if (!running && runTools.length > 0) {
+        const toolCallMessage = {
+          role: "assistant",
+          tool_calls: [...runTools],
+        };
+        running = true;
+        runTools.splice(0, runTools.length); // empty runTools
+        return Promise.all(
+          toolCallMessage.tool_calls.map((tool) => {
+            options?.onBeforeTool?.(tool);
+            return Promise.resolve(
+              // @ts-ignore
+              funcs[tool.function.name](
+                // @ts-ignore
+                tool?.function?.arguments
+                  ? JSON.parse(tool?.function?.arguments)
+                  : {},
+              ),
+            )
+              .then((res) => {
+                let content = res.data || res?.statusText;
+                // hotfix #5614
+                content =
+                  typeof content === "string"
+                    ? content
+                    : JSON.stringify(content);
+                if (res.status >= 300) {
+                  return Promise.reject(content);
+                }
+                return content;
+              })
+              .then((content) => {
+                options?.onAfterTool?.({
+                  ...tool,
+                  content,
+                  isError: false,
+                });
+                return content;
+              })
+              .catch((e) => {
+                // A failed tool is reported to the UI and its error text is
+                // sent back to the model as the tool result.
+                options?.onAfterTool?.({
+                  ...tool,
+                  isError: true,
+                  errorMsg: e.toString(),
+                });
+                return e.toString();
+              })
+              .then((content) => ({
+                name: tool.function.name,
+                role: "tool",
+                content,
+                tool_call_id: tool.id,
+              }));
+          }),
+        ).then((toolCallResult) => {
+          processToolMessage(requestPayload, toolCallMessage, toolCallResult);
+          setTimeout(() => {
+            // call again
+            console.debug("[ChatAPI] restart");
+            running = false;
+            chatApi(chatPath, headers, requestPayload, tools); // call fetchEventSource
+          }, 60);
+        });
+        return;
+      }
+      if (running) {
+        return;
+      }
+      console.debug("[ChatAPI] end");
+      finished = true;
+      options.onFinish(responseText + remainText, responseRes);
+    }
+  };
+
+  controller.signal.onabort = finish;
+
+  // Issues one streaming request and feeds its events into remainText.
+  function chatApi(
+    chatPath: string,
+    headers: any,
+    requestPayload: any,
+    tools: any,
+  ) {
+    const chatPayload = {
+      method: "POST",
+      body: JSON.stringify({
+        ...requestPayload,
+        tools: tools && tools.length ? tools : undefined,
+      }),
+      signal: controller.signal,
+      headers,
+    };
+    // Abort if the connection is not opened in time; cleared in onopen.
+    const requestTimeoutId = setTimeout(
+      () => controller.abort(),
+      REQUEST_TIMEOUT_MS,
+    );
+    fetchEventSource(chatPath, {
+      fetch: tauriFetch as any,
+      ...chatPayload,
+      async onopen(res) {
+        clearTimeout(requestTimeoutId);
+        const contentType = res.headers.get("content-type");
+        console.log("[Request] response content type: ", contentType);
+        responseRes = res;
+
+        // Plain-text responses are not streamed: use the body as the answer.
+        if (contentType?.startsWith("text/plain")) {
+          responseText = await res.clone().text();
+          return finish();
+        }
+
+        // Anything that is not a successful event stream is surfaced as text.
+        if (
+          !res.ok ||
+          !res.headers
+            .get("content-type")
+            ?.startsWith(EventStreamContentType) ||
+          res.status !== 200
+        ) {
+          const responseTexts = [responseText];
+          let extraInfo = await res.clone().text();
+          try {
+            const resJson = await res.clone().json();
+            extraInfo = prettyObject(resJson);
+          } catch {}
+
+          if (res.status === 401) {
+            responseTexts.push(Locale.Error.Unauthorized);
+          }
+
+          if (extraInfo) {
+            responseTexts.push(extraInfo);
+          }
+
+          responseText = responseTexts.join("\n\n");
+
+          return finish();
+        }
+      },
+      onmessage(msg) {
+        if (msg.data === "[DONE]" || finished) {
+          return finish();
+        }
+        const text = msg.data;
+        // Skip empty messages
+        if (!text || text.trim().length === 0) {
+          return;
+        }
+        try {
+          const chunk = parseSSE(text, runTools);
+          // Skip only truly empty chunks. Whitespace-only chunks (" ", "\n",
+          // "\n\n") are real content: dropping them glues words/paragraphs
+          // together and prevents the blockquote continuation below from
+          // ever seeing a standalone paragraph break.
+          if (!chunk?.content) {
+            return;
+          }
+          // Check if thinking mode changed
+          const isThinkingChanged = lastIsThinking !== chunk.isThinking;
+          lastIsThinking = chunk.isThinking;
+
+          if (chunk.isThinking) {
+            // If in thinking mode
+            if (!isInThinkingMode || isThinkingChanged) {
+              // If this is a new thinking block or mode changed, add prefix
+              isInThinkingMode = true;
+              if (remainText.length > 0) {
+                remainText += "\n";
+              }
+              remainText += "> " + chunk.content;
+            } else {
+              // Handle newlines in thinking content: re-open the blockquote
+              // after every paragraph break so the whole reasoning stays quoted.
+              if (chunk.content.includes("\n\n")) {
+                const lines = chunk.content.split("\n\n");
+                remainText += lines.join("\n\n> ");
+              } else {
+                remainText += chunk.content;
+              }
+            }
+          } else {
+            // If in normal mode
+            if (isInThinkingMode || isThinkingChanged) {
+              // If switching from thinking mode to normal mode
+              isInThinkingMode = false;
+              remainText += "\n\n" + chunk.content;
+            } else {
+              remainText += chunk.content;
+            }
+          }
+        } catch (e) {
+          console.error("[Request] parse error", text, msg, e);
+          // Don't throw error for parse failures, just log them
+        }
+      },
+      onclose() {
+        finish();
+      },
+      onerror(e) {
+        options?.onError?.(e);
+        // Rethrow so the error propagates out of this handler instead of being swallowed.
+        throw e;
+      },
+      openWhenHidden: true,
+    });
+  }
+  console.debug("[ChatAPI] start");
+  chatApi(chatPath, headers, requestPayload, tools); // call fetchEventSource
+}