stash code

lloydzhou, 1 year ago
parent commit f5209fc344

+ 1 - 4
app/api/common.ts

@@ -32,10 +32,7 @@ export async function requestOpenai(req: NextRequest) {
     authHeaderName = "Authorization";
   }
 
-  let path = `${req.nextUrl.pathname}${req.nextUrl.search}`.replaceAll(
-    "/api/openai/",
-    "",
-  );
+  let path = `${req.nextUrl.pathname}`.replaceAll("/api/openai/", "");
 
   let baseUrl =
     (isAzure ? serverConfig.azureUrl : serverConfig.baseUrl) || OPENAI_BASE_URL;
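
The app/api/common.ts change builds the upstream path from req.nextUrl.pathname alone, so any query string on the incoming /api/openai/... request is no longer forwarded. A minimal sketch of the difference, using a made-up request URL:

  // hypothetical request: /api/openai/v1/chat/completions?foo=bar
  const pathname = "/api/openai/v1/chat/completions";
  const search = "?foo=bar";

  const oldPath = `${pathname}${search}`.replaceAll("/api/openai/", "");
  const newPath = `${pathname}`.replaceAll("/api/openai/", "");

  console.log(oldPath); // "v1/chat/completions?foo=bar"
  console.log(newPath); // "v1/chat/completions"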

+ 9 - 1
app/client/api.ts

@@ -5,7 +5,13 @@ import {
   ModelProvider,
   ServiceProvider,
 } from "../constant";
-import { ChatMessage, ModelType, useAccessStore, useChatStore } from "../store";
+import {
+  ChatMessageTool,
+  ChatMessage,
+  ModelType,
+  useAccessStore,
+  useChatStore,
+} from "../store";
 import { ChatGPTApi, DalleRequestPayload } from "./platforms/openai";
 import { GeminiProApi } from "./platforms/google";
 import { ClaudeApi } from "./platforms/anthropic";
@@ -56,6 +62,8 @@ export interface ChatOptions {
   onFinish: (message: string) => void;
   onError?: (err: Error) => void;
   onController?: (controller: AbortController) => void;
+  onBeforeTool?: (tool: ChatMessageTool) => void;
+  onAfterTool?: (tool: ChatMessageTool) => void;
 }
 
 export interface LLMUsage {
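
The two new optional hooks in ChatOptions let callers observe tool execution: onBeforeTool fires when a collected tool call is about to run, and onAfterTool fires once it has resolved, with content and isError filled in. A minimal sketch of a caller wiring them up; the messages/config fields are assumptions, since this hunk only shows the callback side:

  import { ChatGPTApi } from "./platforms/openai";
  import { ChatMessageTool } from "../store";

  const api = new ChatGPTApi();
  api.chat({
    messages: [{ role: "user", content: "What is the weather in Paris?" }],
    config: { model: "gpt-4o", stream: true }, // assumed LLMConfig fields
    onBeforeTool(tool: ChatMessageTool) {
      // a collected tool call is about to be executed
      console.log("running tool", tool.function?.name, tool.function?.arguments);
    },
    onAfterTool(tool: ChatMessageTool) {
      // the call finished; isError tells success from failure
      console.log("tool done", tool.function?.name, tool.isError ? "failed" : tool.content);
    },
    onFinish(message) {
      console.log("assistant:", message);
    },
    onError(err) {
      console.error(err);
    },
  });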

+ 199 - 77
app/client/platforms/openai.ts

@@ -250,6 +250,8 @@ export class ChatGPTApi implements LLMApi {
         let responseText = "";
         let remainText = "";
         let finished = false;
+        let running = false;
+        let runTools = [];
 
         // animate response to make it looks smooth
         function animateResponseText() {
@@ -276,8 +278,70 @@ export class ChatGPTApi implements LLMApi {
        // start animation
         animateResponseText();
 
+        // TODO: later, get the function list from the selected plugins
+        const funcs = {
+          get_current_weather: (args) => {
+            console.log("call get_current_weather", args);
+            return "30";
+          },
+        };
         const finish = () => {
           if (!finished) {
+            console.log("try run tools", runTools.length, finished, running);
+            if (!running && runTools.length > 0) {
+              const toolCallMessage = {
+                role: "assistant",
+                tool_calls: [...runTools],
+              };
+              running = true;
+              runTools.splice(0, runTools.length); // empty runTools
+              return Promise.all(
+                toolCallMessage.tool_calls.map((tool) => {
+                  options?.onBeforeTool(tool);
+                  return Promise.resolve(
+                    funcs[tool.function.name](
+                      JSON.parse(tool.function.arguments),
+                    ),
+                  )
+                    .then((content) => {
+                      options?.onAfterTool({
+                        ...tool,
+                        content,
+                        isError: false,
+                      });
+                      return content;
+                    })
+                    .catch((e) => {
+                      options?.onAfterTool({ ...tool, isError: true });
+                      return e.toString();
+                    })
+                    .then((content) => ({
+                      role: "tool",
+                      content,
+                      tool_call_id: tool.id,
+                    }));
+                }),
+              ).then((toolCallResult) => {
+                console.log("end runTools", toolCallMessage, toolCallResult);
+                requestPayload["messages"].splice(
+                  requestPayload["messages"].length,
+                  0,
+                  toolCallMessage,
+                  ...toolCallResult,
+                );
+                setTimeout(() => {
+                  // call again
+                  console.log("start again");
+                  running = false;
+                  chatApi(chatPath, requestPayload); // call fetchEventSource
+                }, 0);
+              });
+              console.log("try run tools", runTools.length, finished);
+              return;
+            }
+            if (running) {
+              return;
+            }
             finished = true;
             options.onFinish(responseText + remainText);
           }
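
The reworked finish() above defers completion while tools are pending: when the stream ends with entries in runTools, it builds an assistant message carrying the tool_calls, runs each matching local function from funcs, wraps every result (or error string) in a role "tool" message with the matching tool_call_id, appends both to requestPayload.messages, and re-invokes chatApi; only when nothing is pending or running does it fall through to onFinish. A reduced, standalone sketch of that execution step (message shapes follow the OpenAI tools protocol; the weather stub is from this diff):

  type ToolCall = { id: string; type?: string; function: { name: string; arguments: string } };

  const funcs: Record<string, (args: unknown) => unknown> = {
    get_current_weather: (_args) => "30", // stub, as in the diff
  };

  async function runPendingTools(runTools: ToolCall[], messages: unknown[]) {
    const toolCallMessage = { role: "assistant", tool_calls: [...runTools] };
    runTools.splice(0, runTools.length); // drain the queue before the next round
    const toolCallResult = await Promise.all(
      toolCallMessage.tool_calls.map(async (tool) => {
        let content: string;
        try {
          content = String(await funcs[tool.function.name](JSON.parse(tool.function.arguments)));
        } catch (e) {
          content = String(e);
        }
        // each result goes back to the model as a "tool" message tied to its call id
        return { role: "tool", content, tool_call_id: tool.id };
      }),
    );
    messages.push(toolCallMessage, ...toolCallResult);
    // the caller would now re-run the request with the extended message list
  }
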
@@ -285,90 +349,148 @@ export class ChatGPTApi implements LLMApi {
 
         controller.signal.onabort = finish;
 
-        fetchEventSource(chatPath, {
-          ...chatPayload,
-          async onopen(res) {
-            clearTimeout(requestTimeoutId);
-            const contentType = res.headers.get("content-type");
-            console.log(
-              "[OpenAI] request response content type: ",
-              contentType,
-            );
-
-            if (contentType?.startsWith("text/plain")) {
-              responseText = await res.clone().text();
-              return finish();
-            }
+        function chatApi(chatPath, requestPayload) {
+          const chatPayload = {
+            method: "POST",
+            body: JSON.stringify({
+              ...requestPayload,
+              // TODO: hard-coded for now; later, fetch from store.tools according to the current session's selection
+              tools: [
+                {
+                  type: "function",
+                  function: {
+                    name: "get_current_weather",
+                    description: "Get the current weather",
+                    parameters: {
+                      type: "object",
+                      properties: {
+                        location: {
+                          type: "string",
+                          description:
+                            "The city and country, e.g. San Francisco, USA",
+                        },
+                        format: {
+                          type: "string",
+                          enum: ["celsius", "fahrenheit"],
+                        },
+                      },
+                      required: ["location", "format"],
+                    },
+                  },
+                },
+              ],
+            }),
+            signal: controller.signal,
+            headers: getHeaders(),
+          };
+          console.log("chatApi", chatPath, requestPayload, chatPayload);
+          fetchEventSource(chatPath, {
+            ...chatPayload,
+            async onopen(res) {
+              clearTimeout(requestTimeoutId);
+              const contentType = res.headers.get("content-type");
+              console.log(
+                "[OpenAI] request response content type: ",
+                contentType,
+              );
+
+              if (contentType?.startsWith("text/plain")) {
+                responseText = await res.clone().text();
+                return finish();
+              }
 
-            if (
-              !res.ok ||
-              !res.headers
-                .get("content-type")
-                ?.startsWith(EventStreamContentType) ||
-              res.status !== 200
-            ) {
-              const responseTexts = [responseText];
-              let extraInfo = await res.clone().text();
-              try {
-                const resJson = await res.clone().json();
-                extraInfo = prettyObject(resJson);
-              } catch {}
+              if (
+                !res.ok ||
+                !res.headers
+                  .get("content-type")
+                  ?.startsWith(EventStreamContentType) ||
+                res.status !== 200
+              ) {
+                const responseTexts = [responseText];
+                let extraInfo = await res.clone().text();
+                try {
+                  const resJson = await res.clone().json();
+                  extraInfo = prettyObject(resJson);
+                } catch {}
 
-              if (res.status === 401) {
-                responseTexts.push(Locale.Error.Unauthorized);
-              }
+                if (res.status === 401) {
+                  responseTexts.push(Locale.Error.Unauthorized);
+                }
 
-              if (extraInfo) {
-                responseTexts.push(extraInfo);
-              }
+                if (extraInfo) {
+                  responseTexts.push(extraInfo);
+                }
 
-              responseText = responseTexts.join("\n\n");
+                responseText = responseTexts.join("\n\n");
 
-              return finish();
-            }
-          },
-          onmessage(msg) {
-            if (msg.data === "[DONE]" || finished) {
-              return finish();
-            }
-            const text = msg.data;
-            try {
-              const json = JSON.parse(text);
-              const choices = json.choices as Array<{
-                delta: { content: string };
-              }>;
-              const delta = choices[0]?.delta?.content;
-              const textmoderation = json?.prompt_filter_results;
-
-              if (delta) {
-                remainText += delta;
+                return finish();
               }
-
-              if (
-                textmoderation &&
-                textmoderation.length > 0 &&
-                ServiceProvider.Azure
-              ) {
-                const contentFilterResults =
-                  textmoderation[0]?.content_filter_results;
-                console.log(
-                  `[${ServiceProvider.Azure}] [Text Moderation] flagged categories result:`,
-                  contentFilterResults,
-                );
+            },
+            onmessage(msg) {
+              if (msg.data === "[DONE]" || finished) {
+                return finish();
               }
-            } catch (e) {
-              console.error("[Request] parse error", text, msg);
-            }
-          },
-          onclose() {
-            finish();
-          },
-          onerror(e) {
-            options.onError?.(e);
-            throw e;
-          },
-          openWhenHidden: true,
-        });
+              const text = msg.data;
+              try {
+                const json = JSON.parse(text);
+                const choices = json.choices as Array<{
+                  delta: { content: string };
+                }>;
+                console.log("choices", choices);
+                const delta = choices[0]?.delta?.content;
+                const tool_calls = choices[0]?.delta?.tool_calls;
+                const textmoderation = json?.prompt_filter_results;
+
+                if (delta) {
+                  remainText += delta;
+                }
+                if (tool_calls?.length > 0) {
+                  const index = tool_calls[0]?.index;
+                  const id = tool_calls[0]?.id;
+                  const args = tool_calls[0]?.function?.arguments;
+                  if (id) {
+                    runTools.push({
+                      id,
+                      type: tool_calls[0]?.type,
+                      function: {
+                        name: tool_calls[0]?.function?.name,
+                        arguments: args,
+                      },
+                    });
+                  } else {
+                    runTools[index]["function"]["arguments"] += args;
+                  }
+                }
+
+                console.log("runTools", runTools);
+
+                if (
+                  textmoderation &&
+                  textmoderation.length > 0 &&
+                  ServiceProvider.Azure
+                ) {
+                  const contentFilterResults =
+                    textmoderation[0]?.content_filter_results;
+                  console.log(
+                    `[${ServiceProvider.Azure}] [Text Moderation] flagged categories result:`,
+                    contentFilterResults,
+                  );
+                }
+              } catch (e) {
+                console.error("[Request] parse error", text, msg);
+              }
+            },
+            onclose() {
+              finish();
+            },
+            onerror(e) {
+              options.onError?.(e);
+              throw e;
+            },
+            openWhenHidden: true,
+          });
+        }
+        chatApi(chatPath, requestPayload); // call fetchEventSource
       } else {
         const res = await fetch(chatPath, chatPayload);
         clearTimeout(requestTimeoutId);
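
In the streaming response, a tool call arrives as a sequence of delta chunks: the first chunk for a call carries its id, type, and function name, while follow-up chunks only carry argument fragments plus the index of the call they extend. The onmessage handler above accumulates these into runTools, which finish() later executes. A standalone sketch of the accumulation step with made-up chunk data:

  type ToolCallDelta = {
    index: number;
    id?: string;
    type?: string;
    function?: { name?: string; arguments?: string };
  };

  const runTools: { id: string; type?: string; function: { name?: string; arguments: string } }[] = [];

  function onToolCallDelta(tc: ToolCallDelta) {
    if (tc.id) {
      // first chunk of a new call: create the entry
      runTools.push({
        id: tc.id,
        type: tc.type,
        function: { name: tc.function?.name, arguments: tc.function?.arguments ?? "" },
      });
    } else {
      // follow-up chunk: append the fragment to the call at the same index
      runTools[tc.index].function.arguments += tc.function?.arguments ?? "";
    }
  }

  onToolCallDelta({ index: 0, id: "call_1", type: "function", function: { name: "get_current_weather", arguments: "" } });
  onToolCallDelta({ index: 0, function: { arguments: '{"location":"Paris, France",' } });
  onToolCallDelta({ index: 0, function: { arguments: '"format":"celsius"}' } });
  // runTools[0].function.arguments now holds a complete JSON string, ready for JSON.parse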

+ 16 - 1
app/components/chat.module.scss

@@ -413,6 +413,21 @@
   margin-top: 5px;
 }
 
+.chat-message-tools {
+  font-size: 12px;
+  color: #aaa;
+  line-height: 1.5;
+  margin-top: 5px;
+  .chat-message-tool {
+    display: inline-flex;
+    align-items: end;
+    svg {
+      margin-left: 5px;
+      margin-right: 5px;
+    }
+  }
+}
+
 .chat-message-item {
   box-sizing: border-box;
   max-width: 100%;
@@ -630,4 +645,4 @@
   .chat-input-send {
     bottom: 30px;
   }
-}
+}

+ 21 - 1
app/components/chat.tsx

@@ -28,6 +28,7 @@ import DeleteIcon from "../icons/clear.svg";
 import PinIcon from "../icons/pin.svg";
 import EditIcon from "../icons/rename.svg";
 import ConfirmIcon from "../icons/confirm.svg";
+import CloseIcon from "../icons/close.svg";
 import CancelIcon from "../icons/cancel.svg";
 import ImageIcon from "../icons/image.svg";
 
@@ -1573,11 +1574,30 @@ function _Chat() {
                       </div>
                     )}
                   </div>
-                  {showTyping && (
+                  {message?.tools?.length == 0 && showTyping && (
                     <div className={styles["chat-message-status"]}>
                       {Locale.Chat.Typing}
                     </div>
                   )}
+                  {message?.tools?.length > 0 && (
+                    <div className={styles["chat-message-tools"]}>
+                      {message?.tools?.map((tool) => (
+                        <div
+                          key={tool.id}
+                          className={styles["chat-message-tool"]}
+                        >
+                          {tool.isError === false ? (
+                            <ConfirmIcon />
+                          ) : tool.isError === true ? (
+                            <CloseIcon />
+                          ) : (
+                            <LoadingButtonIcon />
+                          )}
+                          <span>{tool.function.name}</span>
+                        </div>
+                      ))}
+                    </div>
+                  )}
                   <div className={styles["chat-message-item"]}>
                     <Markdown
                       key={message.streaming ? "loading" : "done"}
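
The tool list rendered above keys off the tri-state isError field of ChatMessageTool: undefined while the call is still running (loading icon), false on success (ConfirmIcon), true on failure (CloseIcon). A tiny helper expressing the same mapping; the helper name is hypothetical:

  function toolStatusIcon(isError?: boolean): "confirm" | "close" | "loading" {
    if (isError === false) return "confirm"; // finished successfully
    if (isError === true) return "close";    // finished with an error
    return "loading";                        // still running, isError not set yet
  }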

+ 29 - 0
app/store/chat.ts

@@ -28,12 +28,24 @@ import { collectModelsWithDefaultModel } from "../utils/model";
 import { useAccessStore } from "./access";
 import { isDalle3 } from "../utils";
 
+export type ChatMessageTool = {
+  id: string;
+  type?: string;
+  function?: {
+    name: string;
+    arguments?: string;
+  };
+  content?: string;
+  isError?: boolean;
+};
+
 export type ChatMessage = RequestMessage & {
   date: string;
   streaming?: boolean;
   isError?: boolean;
   id: string;
   model?: ModelType;
+  tools?: ChatMessageTool[];
 };
 
 export function createMessage(override: Partial<ChatMessage>): ChatMessage {
@@ -389,6 +401,23 @@ export const useChatStore = createPersistStore(
             }
             ChatControllerPool.remove(session.id, botMessage.id);
           },
+          onBeforeTool(tool: ChatMessageTool) {
+            (botMessage.tools = botMessage?.tools || []).push(tool);
+            get().updateCurrentSession((session) => {
+              session.messages = session.messages.concat();
+            });
+          },
+          onAfterTool(tool: ChatMessageTool) {
+            console.log("onAfterTool", botMessage);
+            botMessage?.tools?.forEach((t, i, tools) => {
+              if (tool.id == t.id) {
+                tools[i] = { ...tool };
+              }
+            });
+            get().updateCurrentSession((session) => {
+              session.messages = session.messages.concat();
+            });
+          },
           onError(error) {
             const isAborted = error.message.includes("aborted");
             botMessage.content +=
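
Store-side, onBeforeTool appends the pending call to botMessage.tools and onAfterTool replaces the entry with the same id once it resolves; both then reassign session.messages to a fresh array via concat() so the persisted store and React pick up the change. A reduced sketch of that update pattern outside the store, with the ChatMessageTool shape redeclared to keep it self-contained:

  type ChatMessageTool = {
    id: string;
    type?: string;
    function?: { name: string; arguments?: string };
    content?: string;
    isError?: boolean;
  };

  const tools: ChatMessageTool[] = [];

  // onBeforeTool: record the call as pending (isError left undefined)
  function beforeTool(tool: ChatMessageTool) {
    tools.push(tool);
  }

  // onAfterTool: swap in the finished copy of the call with the same id
  function afterTool(tool: ChatMessageTool) {
    tools.forEach((t, i, arr) => {
      if (t.id === tool.id) arr[i] = { ...tool };
    });
  }

  beforeTool({ id: "call_1", type: "function", function: { name: "get_current_weather" } });
  afterTool({ id: "call_1", type: "function", function: { name: "get_current_weather" }, content: "30", isError: false });
  console.log(tools); // [{ id: "call_1", ..., content: "30", isError: false }]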