Browse Source

feat: support streaming for Gemini Pro (#3688)

* feat: support streaming for Gemini Pro

* feat: display texts smoothly

* chore: remove comments
Fred Liang 2 years ago
parent
commit
5cf58d9446
1 changed files with 48 additions and 80 deletions
  1. 48 80
      app/client/platforms/google.ts

+ 48 - 80
app/client/platforms/google.ts

@@ -20,6 +20,7 @@ export class GeminiProApi implements LLMApi {
     );
   }
   async chat(options: ChatOptions): Promise<void> {
+    const apiClient = this;
     const messages = options.messages.map((v) => ({
       role: v.role.replace("assistant", "model").replace("system", "user"),
       parts: [{ text: v.content }],
@@ -61,8 +62,7 @@ export class GeminiProApi implements LLMApi {
 
     console.log("[Request] google payload: ", requestPayload);
 
-    // todo: support stream later
-    const shouldStream = false;
+    const shouldStream = !!options.config.stream;
     const controller = new AbortController();
     options.onController?.(controller);
     try {
@@ -82,13 +82,21 @@ export class GeminiProApi implements LLMApi {
       if (shouldStream) {
         let responseText = "";
         let remainText = "";
+        let streamChatPath = chatPath.replace(
+          "generateContent",
+          "streamGenerateContent",
+        );
         let finished = false;
+        const finish = () => {
+          finished = true;
+          options.onFinish(responseText + remainText);
+        };
 
         // animate response to make it look smooth
         function animateResponseText() {
           if (finished || controller.signal.aborted) {
             responseText += remainText;
-            console.log("[Response Animation] finished");
+            finish();
             return;
           }
 
@@ -105,88 +113,41 @@ export class GeminiProApi implements LLMApi {
 
         // start animation
         animateResponseText();
+        fetch(streamChatPath, chatPayload)
+          .then((response) => {
+            const reader = response?.body?.getReader();
+            const decoder = new TextDecoder();
+            let partialData = "";
+
+            return reader?.read().then(function processText({
+              done,
+              value,
+            }): Promise<any> {
+              if (done) {
+                console.log("Stream complete");
+                // options.onFinish(responseText + remainText);
+                finished = true;
+                return Promise.resolve();
+              }
 
-        const finish = () => {
-          if (!finished) {
-            finished = true;
-            options.onFinish(responseText + remainText);
-          }
-        };
+              partialData += decoder.decode(value, { stream: true });
 
-        controller.signal.onabort = finish;
-
-        fetchEventSource(chatPath, {
-          ...chatPayload,
-          async onopen(res) {
-            clearTimeout(requestTimeoutId);
-            const contentType = res.headers.get("content-type");
-            console.log(
-              "[OpenAI] request response content type: ",
-              contentType,
-            );
-
-            if (contentType?.startsWith("text/plain")) {
-              responseText = await res.clone().text();
-              return finish();
-            }
-
-            if (
-              !res.ok ||
-              !res.headers
-                .get("content-type")
-                ?.startsWith(EventStreamContentType) ||
-              res.status !== 200
-            ) {
-              const responseTexts = [responseText];
-              let extraInfo = await res.clone().text();
               try {
-                const resJson = await res.clone().json();
-                extraInfo = prettyObject(resJson);
-              } catch {}
-
-              if (res.status === 401) {
-                responseTexts.push(Locale.Error.Unauthorized);
+                let data = JSON.parse(ensureProperEnding(partialData));
+                console.log(data);
+                let fetchText = apiClient.extractMessage(data[data.length - 1]);
+                console.log("[Response Animation] fetchText: ", fetchText);
+                remainText += fetchText;
+              } catch (error) {
+                // skip error message when parsing json
               }
 
-              if (extraInfo) {
-                responseTexts.push(extraInfo);
-              }
-
-              responseText = responseTexts.join("\n\n");
-
-              return finish();
-            }
-          },
-          onmessage(msg) {
-            if (msg.data === "[DONE]" || finished) {
-              return finish();
-            }
-            const text = msg.data;
-            try {
-              const json = JSON.parse(text) as {
-                choices: Array<{
-                  delta: {
-                    content: string;
-                  };
-                }>;
-              };
-              const delta = json.choices[0]?.delta?.content;
-              if (delta) {
-                remainText += delta;
-              }
-            } catch (e) {
-              console.error("[Request] parse error", text);
-            }
-          },
-          onclose() {
-            finish();
-          },
-          onerror(e) {
-            options.onError?.(e);
-            throw e;
-          },
-          openWhenHidden: true,
-        });
+              return reader.read().then(processText);
+            });
+          })
+          .catch((error) => {
+            console.error("Error:", error);
+          });
       } else {
         const res = await fetch(chatPath, chatPayload);
         clearTimeout(requestTimeoutId);
@@ -220,3 +181,10 @@ export class GeminiProApi implements LLMApi {
     return "/api/google/" + path;
   }
 }
+
+function ensureProperEnding(str: string) { // Returns `str` with a "]" appended when it starts with "[" but does not yet end with "]"; chat() feeds the result to JSON.parse on the partial stream buffer.
+  if (str.startsWith("[") && !str.endsWith("]")) { // NOTE(review): only the very last character is checked, so a buffer ending in "]" followed by whitespace still gets a second "]"
+    return str + "]"; // close the still-open array so the accumulated text may parse
+  }
+  return str; // already ends with "]" or does not start with "[": returned unchanged
+}