浏览代码

Create a common `stream` function for fetchEventSource

lloydzhou 1 年之前
父节点
当前提交
7fc0d11931
共有 2 个文件被更改,包括 279 次插入255 次删除
  1. 73 254
      app/client/platforms/openai.ts
  2. 206 1
      app/utils/chat.ts

+ 73 - 254
app/client/platforms/openai.ts

@@ -20,6 +20,7 @@ import {
   preProcessImageContent,
   uploadImage,
   base64Image2Blob,
+  stream,
 } from "@/app/utils/chat";
 import { cloudflareAIGatewayUrl } from "@/app/utils/cloudflare";
 import { DalleSize, DalleQuality, DalleStyle } from "@/app/typing";
@@ -238,52 +239,30 @@ export class ChatGPTApi implements LLMApi {
           isDalle3 ? OpenaiPath.ImagePath : OpenaiPath.ChatPath,
         );
       }
-      const chatPayload = {
-        method: "POST",
-        body: JSON.stringify(requestPayload),
-        signal: controller.signal,
-        headers: getHeaders(),
-      };
-
-      // make a fetch request
-      const requestTimeoutId = setTimeout(
-        () => controller.abort(),
-        isDalle3 ? REQUEST_TIMEOUT_MS * 2 : REQUEST_TIMEOUT_MS, // dalle3 using b64_json is slow.
-      );
-
       if (shouldStream) {
-        let responseText = "";
-        let remainText = "";
-        let finished = false;
-        let running = false;
-        let runTools: ChatMessageTool[] = [];
-
-        // animate response to make it looks smooth
-        function animateResponseText() {
-          if (finished || controller.signal.aborted) {
-            responseText += remainText;
-            console.log("[Response Animation] finished");
-            if (responseText?.length === 0) {
-              options.onError?.(new Error("empty response from server"));
-            }
-            return;
-          }
-
-          if (remainText.length > 0) {
-            const fetchCount = Math.max(1, Math.round(remainText.length / 60));
-            const fetchText = remainText.slice(0, fetchCount);
-            responseText += fetchText;
-            remainText = remainText.slice(fetchCount);
-            options.onUpdate?.(responseText, fetchText);
-          }
-
-          requestAnimationFrame(animateResponseText);
-        }
-
-        // start animaion
-        animateResponseText();
-
-        // TODO 后面这里是从选择的plugins中获取function列表
+        const tools = [
+          {
+            type: "function",
+            function: {
+              name: "get_current_weather",
+              description: "Get the current weather",
+              parameters: {
+                type: "object",
+                properties: {
+                  location: {
+                    type: "string",
+                    description: "The city and country, eg. San Francisco, USA",
+                  },
+                  format: {
+                    type: "string",
+                    enum: ["celsius", "fahrenheit"],
+                  },
+                },
+                required: ["location", "format"],
+              },
+            },
+          },
+        ];
         const funcs = {
           get_current_weather: (args: any) => {
             console.log("call get_current_weather", args);
@@ -292,221 +271,61 @@ export class ChatGPTApi implements LLMApi {
             });
           },
         };
-        const finish = () => {
-          if (!finished) {
-            console.log("try run tools", runTools.length, finished, running);
-            if (!running && runTools.length > 0) {
-              const toolCallMessage = {
-                role: "assistant",
-                tool_calls: [...runTools],
+        stream(
+          chatPath,
+          requestPayload,
+          getHeaders(),
+          tools,
+          funcs,
+          controller,
+          (text: string, runTools: ChatMessageTool[]) => {
+            console.log("parseSSE", text, runTools);
+            const json = JSON.parse(text);
+            const choices = json.choices as Array<{
+              delta: {
+                content: string;
+                tool_calls: ChatMessageTool[];
               };
-              running = true;
-              runTools.splice(0, runTools.length); // empty runTools
-              return Promise.all(
-                toolCallMessage.tool_calls.map((tool) => {
-                  options?.onBeforeTool?.(tool);
-                  return Promise.resolve(
-                    // @ts-ignore
-                    funcs[tool.function.name](
-                      // @ts-ignore
-                      JSON.parse(tool.function.arguments),
-                    ),
-                  )
-                    .then((content) => {
-                      options?.onAfterTool?.({
-                        ...tool,
-                        content,
-                        isError: false,
-                      });
-                      return content;
-                    })
-                    .catch((e) => {
-                      options?.onAfterTool?.({ ...tool, isError: true });
-                      return e.toString();
-                    })
-                    .then((content) => ({
-                      role: "tool",
-                      content,
-                      tool_call_id: tool.id,
-                    }));
-                }),
-              ).then((toolCallResult) => {
-                console.log("end runTools", toolCallMessage, toolCallResult);
-                // @ts-ignore
-                requestPayload?.messages?.splice(
-                  // @ts-ignore
-                  requestPayload?.messages?.length,
-                  0,
-                  toolCallMessage,
-                  ...toolCallResult,
-                );
-                setTimeout(() => {
-                  // call again
-                  console.log("start again");
-                  running = false;
-                  chatApi(chatPath, requestPayload as RequestPayload); // call fetchEventSource
-                }, 60);
-              });
-              console.log("try run tools", runTools.length, finished);
-              return;
-            }
-            if (running) {
-              return;
-            }
-            finished = true;
-            options.onFinish(responseText + remainText);
-          }
-        };
-
-        controller.signal.onabort = finish;
-
-        function chatApi(chatPath: string, requestPayload: RequestPayload) {
-          const chatPayload = {
-            method: "POST",
-            body: JSON.stringify({
-              ...requestPayload,
-              // TODO 这里暂时写死的,后面从store.tools中按照当前session中选择的获取
-              tools: [
-                {
-                  type: "function",
+            }>;
+            const tool_calls = choices[0]?.delta?.tool_calls;
+            if (tool_calls?.length > 0) {
+              const index = tool_calls[0]?.index;
+              const id = tool_calls[0]?.id;
+              const args = tool_calls[0]?.function?.arguments;
+              if (id) {
+                runTools.push({
+                  id,
+                  type: tool_calls[0]?.type,
                   function: {
-                    name: "get_current_weather",
-                    description: "Get the current weather",
-                    parameters: {
-                      type: "object",
-                      properties: {
-                        location: {
-                          type: "string",
-                          description:
-                            "The city and country, eg. San Francisco, USA",
-                        },
-                        format: {
-                          type: "string",
-                          enum: ["celsius", "fahrenheit"],
-                        },
-                      },
-                      required: ["location", "format"],
-                    },
+                    name: tool_calls[0]?.function?.name as string,
+                    arguments: args,
                   },
-                },
-              ],
-            }),
-            signal: controller.signal,
-            headers: getHeaders(),
-          };
-          console.log("chatApi", chatPath, requestPayload, chatPayload);
-          fetchEventSource(chatPath, {
-            ...chatPayload,
-            async onopen(res) {
-              clearTimeout(requestTimeoutId);
-              const contentType = res.headers.get("content-type");
-              console.log(
-                "[OpenAI] request response content type: ",
-                contentType,
-              );
-
-              if (contentType?.startsWith("text/plain")) {
-                responseText = await res.clone().text();
-                return finish();
+                });
+              } else {
+                // @ts-ignore
+                runTools[index]["function"]["arguments"] += args;
               }
+            }
 
-              if (
-                !res.ok ||
-                !res.headers
-                  .get("content-type")
-                  ?.startsWith(EventStreamContentType) ||
-                res.status !== 200
-              ) {
-                const responseTexts = [responseText];
-                let extraInfo = await res.clone().text();
-                try {
-                  const resJson = await res.clone().json();
-                  extraInfo = prettyObject(resJson);
-                } catch {}
-
-                if (res.status === 401) {
-                  responseTexts.push(Locale.Error.Unauthorized);
-                }
-
-                if (extraInfo) {
-                  responseTexts.push(extraInfo);
-                }
-
-                responseText = responseTexts.join("\n\n");
-
-                return finish();
-              }
-            },
-            onmessage(msg) {
-              if (msg.data === "[DONE]" || finished) {
-                return finish();
-              }
-              const text = msg.data;
-              try {
-                const json = JSON.parse(text);
-                const choices = json.choices as Array<{
-                  delta: {
-                    content: string;
-                    tool_calls: ChatMessageTool[];
-                  };
-                }>;
-                console.log("choices", choices);
-                const delta = choices[0]?.delta?.content;
-                const tool_calls = choices[0]?.delta?.tool_calls;
-                const textmoderation = json?.prompt_filter_results;
-
-                if (delta) {
-                  remainText += delta;
-                }
-                if (tool_calls?.length > 0) {
-                  const index = tool_calls[0]?.index;
-                  const id = tool_calls[0]?.id;
-                  const args = tool_calls[0]?.function?.arguments;
-                  if (id) {
-                    runTools.push({
-                      id,
-                      type: tool_calls[0]?.type,
-                      function: {
-                        name: tool_calls[0]?.function?.name as string,
-                        arguments: args,
-                      },
-                    });
-                  } else {
-                    // @ts-ignore
-                    runTools[index]["function"]["arguments"] += args;
-                  }
-                }
-
-                console.log("runTools", runTools);
-
-                if (
-                  textmoderation &&
-                  textmoderation.length > 0 &&
-                  ServiceProvider.Azure
-                ) {
-                  const contentFilterResults =
-                    textmoderation[0]?.content_filter_results;
-                  console.log(
-                    `[${ServiceProvider.Azure}] [Text Moderation] flagged categories result:`,
-                    contentFilterResults,
-                  );
-                }
-              } catch (e) {
-                console.error("[Request] parse error", text, msg);
-              }
-            },
-            onclose() {
-              finish();
-            },
-            onerror(e) {
-              options.onError?.(e);
-              throw e;
-            },
-            openWhenHidden: true,
-          });
-        }
-        chatApi(chatPath, requestPayload as RequestPayload); // call fetchEventSource
+            console.log("runTools", runTools);
+            return choices[0]?.delta?.content;
+          },
+          options,
+        );
       } else {
+        const chatPayload = {
+          method: "POST",
+          body: JSON.stringify(requestPayload),
+          signal: controller.signal,
+          headers: getHeaders(),
+        };
+
+        // make a fetch request
+        const requestTimeoutId = setTimeout(
+          () => controller.abort(),
+          isDalle3 ? REQUEST_TIMEOUT_MS * 2 : REQUEST_TIMEOUT_MS, // dalle3 using b64_json is slow.
+        );
+
         const res = await fetch(chatPath, chatPayload);
         clearTimeout(requestTimeoutId);
 

+ 206 - 1
app/utils/chat.ts

@@ -1,5 +1,15 @@
-import { CACHE_URL_PREFIX, UPLOAD_URL } from "@/app/constant";
+import {
+  CACHE_URL_PREFIX,
+  UPLOAD_URL,
+  REQUEST_TIMEOUT_MS,
+} from "@/app/constant";
 import { RequestMessage } from "@/app/client/api";
+import Locale from "@/app/locales";
+import {
+  EventStreamContentType,
+  fetchEventSource,
+} from "@fortaine/fetch-event-source";
+import { prettyObject } from "./format";
 
 export function compressImage(file: Blob, maxSize: number): Promise<string> {
   return new Promise((resolve, reject) => {
@@ -142,3 +152,198 @@ export function removeImage(imageUrl: string) {
     credentials: "include",
   });
 }
+/**
+ * Shared streaming helper built on fetchEventSource.
+ *
+ * POSTs `requestPayload` (with `tools` added to the body) to `chatPath`,
+ * passes every server-sent message to `parseSSE`, buffers the text it
+ * returns and releases that text gradually through `options.onUpdate`.
+ * When a response ends and `parseSSE` has collected tool calls in the shared
+ * `runTools` array, the matching entries of `funcs` are invoked, the
+ * assistant tool-call message and the tool results are appended to
+ * `requestPayload.messages`, and the request is issued again. Once a response
+ * ends with no pending tool calls, `options.onFinish` receives the full text.
+ *
+ * @param chatPath - URL handed to fetchEventSource.
+ * @param requestPayload - Request body; if it has a `messages` array, that
+ *   array is mutated in place when tool calls are executed.
+ * @param headers - Request headers, passed through unchanged.
+ * @param tools - Value sent as the `tools` field of every request body.
+ * @param funcs - Object keyed by tool function name; each entry is called
+ *   with the JSON-parsed arguments of the tool call and may return a value
+ *   or a promise.
+ * @param controller - Its signal is attached to every request, its
+ *   `signal.onabort` handler is replaced, and it is aborted if a response
+ *   has not opened within REQUEST_TIMEOUT_MS.
+ * @param parseSSE - Receives the raw message data and the shared `runTools`
+ *   array; returns the text chunk to append (or nothing). Exceptions it
+ *   throws are caught and logged, and that message is skipped.
+ * @param options - Callback bag: `onUpdate`, `onError`, `onBeforeTool` and
+ *   `onAfterTool` are optional; `onFinish` is called unconditionally.
+ */
+export function stream(
+  chatPath: string,
+  requestPayload: any,
+  headers: any,
+  tools: any[],
+  funcs: any,
+  controller: AbortController,
+  parseSSE: (text: string, runTools: any[]) => string | undefined,
+  options: any,
+) {
+  let responseText = "";
+  let remainText = "";
+  let finished = false;
+  let running = false;
+  let runTools: any[] = [];
+
+  /* Animation loop: on each frame, move a slice of the buffered text (about
+   * 1/60 of what remains, at least one character) from remainText to
+   * responseText and report it through options.onUpdate. The loop stops
+   * rescheduling itself once the stream is finished or the controller is
+   * aborted. */
+  function animateResponseText() {
+    if (finished || controller.signal.aborted) {
+      responseText += remainText; // flush whatever is still buffered
+      console.log("[Response Animation] finished");
+      if (responseText?.length === 0) {
+        options.onError?.(new Error("empty response from server"));
+      }
+      return;
+    }
+
+    if (remainText.length > 0) {
+      const fetchCount = Math.max(1, Math.round(remainText.length / 60));
+      const fetchText = remainText.slice(0, fetchCount);
+      responseText += fetchText;
+      remainText = remainText.slice(fetchCount);
+      options.onUpdate?.(responseText, fetchText);
+    }
+
+    requestAnimationFrame(animateResponseText);
+  }
+
+  // start animation
+  animateResponseText();
+  /* Called whenever a response ends: the "[DONE]" message, connection close,
+   * abort, or a plain-text / error response detected in onopen. If tool calls
+   * were collected, run them and issue the request again; otherwise mark the
+   * stream finished and hand the complete text to options.onFinish. */
+  const finish = () => {
+    if (!finished) {
+      console.log("try run tools", runTools.length, finished, running);
+      if (!running && runTools.length > 0) {
+        const toolCallMessage = {
+          role: "assistant",
+          tool_calls: [...runTools],
+        };
+        running = true;
+        runTools.splice(0, runTools.length); // clear the shared list; the copy above keeps the calls being run
+        return Promise.all(
+          toolCallMessage.tool_calls.map((tool) => {
+            options?.onBeforeTool?.(tool);
+            return Promise.resolve( /* NOTE(review): the lookup, JSON.parse and call below are evaluated
+               * before Promise.resolve wraps them, so a synchronous throw is not handled by the .catch further down */
+              // @ts-ignore
+              funcs[tool.function.name](
+                // @ts-ignore
+                JSON.parse(tool.function.arguments),
+              ),
+            )
+              .then((content) => {
+                options?.onAfterTool?.({
+                  ...tool,
+                  content,
+                  isError: false,
+                });
+                return content;
+              })
+              .catch((e) => {
+                options?.onAfterTool?.({ ...tool, isError: true });
+                return e.toString(); // the error text becomes the content of the tool message
+              })
+              .then((content) => ({
+                role: "tool",
+                content,
+                tool_call_id: tool.id,
+              }));
+          }),
+        ).then((toolCallResult) => {
+          console.log("end runTools", toolCallMessage, toolCallResult);
+          // @ts-ignore
+          requestPayload?.messages?.splice(
+            // @ts-ignore
+            requestPayload?.messages?.length,
+            0,
+            toolCallMessage,
+            ...toolCallResult,
+          );
+          setTimeout(() => {
+            // re-issue the request after a 60 ms delay, now including the tool results
+            console.log("start again");
+            running = false;
+            chatApi(chatPath, headers, requestPayload, tools); // call fetchEventSource
+          }, 60);
+        });
+        console.log("try run tools", runTools.length, finished);
+        return;
+      }
+      if (running) { // tool calls are still executing; do not finish yet
+        return;
+      }
+      finished = true;
+      options.onFinish(responseText + remainText);
+    }
+  };
+
+  controller.signal.onabort = finish; // replaces any onabort handler already set on this signal
+  /* Issues one streaming POST request. Invoked once at the end of stream()
+   * and again from finish() after each round of tool calls. */
+  function chatApi(
+    chatPath: string,
+    headers: any,
+    requestPayload: any,
+    tools: any,
+  ) {
+    const chatPayload = {
+      method: "POST",
+      body: JSON.stringify({
+        ...requestPayload,
+        tools,
+      }),
+      signal: controller.signal,
+      headers,
+    };
+    const requestTimeoutId = setTimeout(
+      () => controller.abort(),
+      REQUEST_TIMEOUT_MS,
+    ); // cleared in onopen once a response arrives
+    fetchEventSource(chatPath, {
+      ...chatPayload,
+      async onopen(res) {
+        clearTimeout(requestTimeoutId);
+        const contentType = res.headers.get("content-type");
+        console.log("[Request] response content type: ", contentType);
+
+        if (contentType?.startsWith("text/plain")) { // plain-text body is used as the whole response
+          responseText = await res.clone().text();
+          return finish();
+        }
+        /* Any response that is not OK, not status 200, or not an event
+         * stream: put its body (pretty-printed when it parses as JSON), plus
+         * the unauthorized message on 401, into responseText and finish. */
+        if (
+          !res.ok ||
+          !res.headers
+            .get("content-type")
+            ?.startsWith(EventStreamContentType) ||
+          res.status !== 200
+        ) {
+          const responseTexts = [responseText];
+          let extraInfo = await res.clone().text();
+          try {
+            const resJson = await res.clone().json();
+            extraInfo = prettyObject(resJson);
+          } catch {} // body is not JSON; keep the raw text
+
+          if (res.status === 401) {
+            responseTexts.push(Locale.Error.Unauthorized);
+          }
+
+          if (extraInfo) {
+            responseTexts.push(extraInfo);
+          }
+
+          responseText = responseTexts.join("\n\n");
+
+          return finish();
+        }
+      },
+      onmessage(msg) {
+        if (msg.data === "[DONE]" || finished) {
+          return finish();
+        }
+        const text = msg.data;
+        try {
+          const chunk = parseSSE(msg.data, runTools);
+          if (chunk) {
+            remainText += chunk; // queued for the animation loop
+          }
+        } catch (e) {
+          console.error("[Request] parse error", text, msg); // the message is logged and skipped
+        }
+      },
+      onclose() {
+        finish();
+      },
+      onerror(e) {
+        options?.onError?.(e);
+        throw e; // NOTE(review): presumably rethrowing stops fetchEventSource from retrying; confirm against the library docs
+      },
+      openWhenHidden: true,
+    });
+    console.log("chatApi", chatPath, requestPayload, tools);
+  }
+  chatApi(chatPath, headers, requestPayload, tools); // call fetchEventSource
+}