opchips преди 1 година
родител
ревизия
6667ee1c7f
променени са 11 файла, в които са добавени 155 реда и са изтрити 73 реда
  1. 8 0
      README.md
  2. 7 0
      README_CN.md
  3. 2 2
      app/api/common.ts
  4. 48 37
      app/components/chat.tsx
  5. 7 2
      app/components/model-config.tsx
  6. 3 2
      app/constant.ts
  7. 3 2
      app/store/access.ts
  8. 30 23
      app/store/chat.ts
  9. 15 4
      app/utils/model.ts
  10. 1 1
      src-tauri/tauri.conf.json
  11. 31 0
      test/model-provider.test.ts

+ 8 - 0
README.md

@@ -301,6 +301,14 @@ iflytek Api Key.
 
 iflytek Api Secret.
 
+### `CHATGLM_API_KEY` (optional)
+
+ChatGLM Api Key.
+
+### `CHATGLM_URL` (optional)
+
+ChatGLM Api Url.
+
 ### `HIDE_USER_API_KEY` (optional)
 
 > Default: Empty

+ 7 - 0
README_CN.md

@@ -184,6 +184,13 @@ ByteDance Api Url.
 
 讯飞星火Api Secret.
 
+### `CHATGLM_API_KEY` (可选)
+
+ChatGLM Api Key.
+
+### `CHATGLM_URL` (可选)
+
+ChatGLM Api Url.
 
 
 ### `HIDE_USER_API_KEY` (可选)

+ 2 - 2
app/api/common.ts

@@ -1,8 +1,8 @@
 import { NextRequest, NextResponse } from "next/server";
 import { getServerSideConfig } from "../config/server";
 import { OPENAI_BASE_URL, ServiceProvider } from "../constant";
-import { isModelAvailableInServer } from "../utils/model";
 import { cloudflareAIGatewayUrl } from "../utils/cloudflare";
+import { getModelProvider, isModelAvailableInServer } from "../utils/model";
 
 const serverConfig = getServerSideConfig();
 
@@ -71,7 +71,7 @@ export async function requestOpenai(req: NextRequest) {
         .filter((v) => !!v && !v.startsWith("-") && v.includes(modelName))
         .forEach((m) => {
           const [fullName, displayName] = m.split("=");
-          const [_, providerName] = fullName.split("@");
+          const [_, providerName] = getModelProvider(fullName);
           if (providerName === "azure" && !displayName) {
             const [_, deployId] = (serverConfig?.azureUrl ?? "").split(
               "deployments/",

+ 48 - 37
app/components/chat.tsx

@@ -120,6 +120,7 @@ import { createTTSPlayer } from "../utils/audio";
 import { MsEdgeTTS, OUTPUT_FORMAT } from "../utils/ms_edge_tts";
 
 import { isEmpty } from "lodash-es";
+import { getModelProvider } from "../utils/model";
 
 const localStorage = safeLocalStorage();
 
@@ -148,7 +149,8 @@ export function SessionConfigModel(props: { onClose: () => void }) {
             text={Locale.Chat.Config.Reset}
             onClick={async () => {
               if (await showConfirm(Locale.Memory.ResetConfirm)) {
-                chatStore.updateCurrentSession(
+                chatStore.updateTargetSession(
+                  session,
                   (session) => (session.memoryPrompt = ""),
                 );
               }
@@ -173,7 +175,10 @@ export function SessionConfigModel(props: { onClose: () => void }) {
           updateMask={(updater) => {
             const mask = { ...session.mask };
             updater(mask);
-            chatStore.updateCurrentSession((session) => (session.mask = mask));
+            chatStore.updateTargetSession(
+              session,
+              (session) => (session.mask = mask),
+            );
           }}
           shouldSyncFromGlobal
           extraListItems={
@@ -345,12 +350,14 @@ export function PromptHints(props: {
 
 function ClearContextDivider() {
   const chatStore = useChatStore();
+  const session = chatStore.currentSession();
 
   return (
     <div
       className={styles["clear-context"]}
       onClick={() =>
-        chatStore.updateCurrentSession(
+        chatStore.updateTargetSession(
+          session,
           (session) => (session.clearContextIndex = undefined),
         )
       }
@@ -460,6 +467,7 @@ export function ChatActions(props: {
   const navigate = useNavigate();
   const chatStore = useChatStore();
   const pluginStore = usePluginStore();
+  const session = chatStore.currentSession();
 
   // switch themes
   const theme = config.theme;
@@ -476,10 +484,9 @@ export function ChatActions(props: {
   const stopAll = () => ChatControllerPool.stopAll();
 
   // switch model
-  const currentModel = chatStore.currentSession().mask.modelConfig.model;
+  const currentModel = session.mask.modelConfig.model;
   const currentProviderName =
-    chatStore.currentSession().mask.modelConfig?.providerName ||
-    ServiceProvider.OpenAI;
+    session.mask.modelConfig?.providerName || ServiceProvider.OpenAI;
   const allModels = useAllModels();
   const models = useMemo(() => {
     const filteredModels = allModels.filter((m) => m.available);
@@ -513,12 +520,9 @@ export function ChatActions(props: {
   const dalle3Sizes: DalleSize[] = ["1024x1024", "1792x1024", "1024x1792"];
   const dalle3Qualitys: DalleQuality[] = ["standard", "hd"];
   const dalle3Styles: DalleStyle[] = ["vivid", "natural"];
-  const currentSize =
-    chatStore.currentSession().mask.modelConfig?.size ?? "1024x1024";
-  const currentQuality =
-    chatStore.currentSession().mask.modelConfig?.quality ?? "standard";
-  const currentStyle =
-    chatStore.currentSession().mask.modelConfig?.style ?? "vivid";
+  const currentSize = session.mask.modelConfig?.size ?? "1024x1024";
+  const currentQuality = session.mask.modelConfig?.quality ?? "standard";
+  const currentStyle = session.mask.modelConfig?.style ?? "vivid";
 
   const isMobileScreen = useMobileScreen();
 
@@ -536,7 +540,7 @@ export function ChatActions(props: {
     if (isUnavailableModel && models.length > 0) {
       // show next model to default model if exist
       let nextModel = models.find((model) => model.isDefault) || models[0];
-      chatStore.updateCurrentSession((session) => {
+      chatStore.updateTargetSession(session, (session) => {
         session.mask.modelConfig.model = nextModel.name;
         session.mask.modelConfig.providerName = nextModel?.provider
           ?.providerName as ServiceProvider;
@@ -547,7 +551,7 @@ export function ChatActions(props: {
           : nextModel.name,
       );
     }
-  }, [chatStore, currentModel, models]);
+  }, [chatStore, currentModel, models, session]);
 
   return (
     <div className={styles["chat-input-actions"]}>
@@ -614,7 +618,7 @@ export function ChatActions(props: {
         text={Locale.Chat.InputActions.Clear}
         icon={<BreakIcon />}
         onClick={() => {
-          chatStore.updateCurrentSession((session) => {
+          chatStore.updateTargetSession(session, (session) => {
             if (session.clearContextIndex === session.messages.length) {
               session.clearContextIndex = undefined;
             } else {
@@ -645,8 +649,8 @@ export function ChatActions(props: {
           onClose={() => setShowModelSelector(false)}
           onSelection={(s) => {
             if (s.length === 0) return;
-            const [model, providerName] = s[0].split("@");
-            chatStore.updateCurrentSession((session) => {
+            const [model, providerName] = getModelProvider(s[0]);
+            chatStore.updateTargetSession(session, (session) => {
               session.mask.modelConfig.model = model as ModelType;
               session.mask.modelConfig.providerName =
                 providerName as ServiceProvider;
@@ -684,7 +688,7 @@ export function ChatActions(props: {
           onSelection={(s) => {
             if (s.length === 0) return;
             const size = s[0];
-            chatStore.updateCurrentSession((session) => {
+            chatStore.updateTargetSession(session, (session) => {
               session.mask.modelConfig.size = size;
             });
             showToast(size);
@@ -711,7 +715,7 @@ export function ChatActions(props: {
           onSelection={(q) => {
             if (q.length === 0) return;
             const quality = q[0];
-            chatStore.updateCurrentSession((session) => {
+            chatStore.updateTargetSession(session, (session) => {
               session.mask.modelConfig.quality = quality;
             });
             showToast(quality);
@@ -738,7 +742,7 @@ export function ChatActions(props: {
           onSelection={(s) => {
             if (s.length === 0) return;
             const style = s[0];
-            chatStore.updateCurrentSession((session) => {
+            chatStore.updateTargetSession(session, (session) => {
               session.mask.modelConfig.style = style;
             });
             showToast(style);
@@ -769,7 +773,7 @@ export function ChatActions(props: {
           }))}
           onClose={() => setShowPluginSelector(false)}
           onSelection={(s) => {
-            chatStore.updateCurrentSession((session) => {
+            chatStore.updateTargetSession(session, (session) => {
               session.mask.plugin = s as string[];
             });
           }}
@@ -812,7 +816,8 @@ export function EditMessageModal(props: { onClose: () => void }) {
             icon={<ConfirmIcon />}
             key="ok"
             onClick={() => {
-              chatStore.updateCurrentSession(
+              chatStore.updateTargetSession(
+                session,
                 (session) => (session.messages = messages),
               );
               props.onClose();
@@ -829,7 +834,8 @@ export function EditMessageModal(props: { onClose: () => void }) {
               type="text"
               value={session.topic}
               onInput={(e) =>
-                chatStore.updateCurrentSession(
+                chatStore.updateTargetSession(
+                  session,
                   (session) => (session.topic = e.currentTarget.value),
                 )
               }
@@ -990,7 +996,8 @@ function _Chat() {
     prev: () => chatStore.nextSession(-1),
     next: () => chatStore.nextSession(1),
     clear: () =>
-      chatStore.updateCurrentSession(
+      chatStore.updateTargetSession(
+        session,
         (session) => (session.clearContextIndex = session.messages.length),
       ),
     fork: () => chatStore.forkSession(),
@@ -1061,7 +1068,7 @@ function _Chat() {
   };
 
   useEffect(() => {
-    chatStore.updateCurrentSession((session) => {
+    chatStore.updateTargetSession(session, (session) => {
       const stopTiming = Date.now() - REQUEST_TIMEOUT_MS;
       session.messages.forEach((m) => {
         // check if should stop all stale messages
@@ -1087,7 +1094,7 @@ function _Chat() {
       }
     });
     // eslint-disable-next-line react-hooks/exhaustive-deps
-  }, []);
+  }, [session]);
 
   // check if should send message
   const onInputKeyDown = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
@@ -1118,7 +1125,8 @@ function _Chat() {
   };
 
   const deleteMessage = (msgId?: string) => {
-    chatStore.updateCurrentSession(
+    chatStore.updateTargetSession(
+      session,
       (session) =>
         (session.messages = session.messages.filter((m) => m.id !== msgId)),
     );
@@ -1185,7 +1193,7 @@ function _Chat() {
   };
 
   const onPinMessage = (message: ChatMessage) => {
-    chatStore.updateCurrentSession((session) =>
+    chatStore.updateTargetSession(session, (session) =>
       session.mask.context.push(message),
     );
 
@@ -1607,7 +1615,7 @@ function _Chat() {
               title={Locale.Chat.Actions.RefreshTitle}
               onClick={() => {
                 showToast(Locale.Chat.Actions.RefreshToast);
-                chatStore.summarizeSession(true);
+                chatStore.summarizeSession(true, session);
               }}
             />
           </div>
@@ -1711,14 +1719,17 @@ function _Chat() {
                                 });
                               }
                             }
-                            chatStore.updateCurrentSession((session) => {
-                              const m = session.mask.context
-                                .concat(session.messages)
-                                .find((m) => m.id === message.id);
-                              if (m) {
-                                m.content = newContent;
-                              }
-                            });
+                            chatStore.updateTargetSession(
+                              session,
+                              (session) => {
+                                const m = session.mask.context
+                                  .concat(session.messages)
+                                  .find((m) => m.id === message.id);
+                                if (m) {
+                                  m.content = newContent;
+                                }
+                              },
+                            );
                           }}
                         ></IconButton>
                       </div>

+ 7 - 2
app/components/model-config.tsx

@@ -7,6 +7,7 @@ import { ListItem, Select } from "./ui-lib";
 import { useAllModels } from "../utils/hooks";
 import { groupBy } from "lodash-es";
 import styles from "./model-config.module.scss";
+import { getModelProvider } from "../utils/model";
 
 export function ModelConfigList(props: {
   modelConfig: ModelConfig;
@@ -28,7 +29,9 @@ export function ModelConfigList(props: {
           value={value}
           align="left"
           onChange={(e) => {
-            const [model, providerName] = e.currentTarget.value.split("@");
+            const [model, providerName] = getModelProvider(
+              e.currentTarget.value,
+            );
             props.updateConfig((config) => {
               config.model = ModalConfigValidator.model(model);
               config.providerName = providerName as ServiceProvider;
@@ -247,7 +250,9 @@ export function ModelConfigList(props: {
           aria-label={Locale.Settings.CompressModel.Title}
           value={compressModelValue}
           onChange={(e) => {
-            const [model, providerName] = e.currentTarget.value.split("@");
+            const [model, providerName] = getModelProvider(
+              e.currentTarget.value,
+            );
             props.updateConfig((config) => {
               config.compressModel = ModalConfigValidator.model(model);
               config.compressProviderName = providerName as ServiceProvider;

+ 3 - 2
app/constant.ts

@@ -232,7 +232,7 @@ export const XAI = {
 
 export const ChatGLM = {
   ExampleEndpoint: CHATGLM_BASE_URL,
-  ChatPath: "/api/paas/v4/chat/completions",
+  ChatPath: "api/paas/v4/chat/completions",
 };
 
 export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang
@@ -327,12 +327,13 @@ const anthropicModels = [
   "claude-2.1",
   "claude-3-sonnet-20240229",
   "claude-3-opus-20240229",
+  "claude-3-opus-latest",
   "claude-3-haiku-20240307",
   "claude-3-5-haiku-20241022",
   "claude-3-5-sonnet-20240620",
   "claude-3-5-sonnet-20241022",
   "claude-3-5-sonnet-latest",
-  "claude-3-opus-latest",
+  "claude-3-5-haiku-latest",
 ];
 
 const baiduModels = [

+ 3 - 2
app/store/access.ts

@@ -21,6 +21,7 @@ import { getClientConfig } from "../config/client";
 import { createPersistStore } from "../utils/store";
 import { ensure } from "../utils/clone";
 import { DEFAULT_CONFIG } from "./config";
+import { getModelProvider } from "../utils/model";
 
 let fetchState = 0; // 0 not fetch, 1 fetching, 2 done
 
@@ -226,9 +227,9 @@ export const useAccessStore = createPersistStore(
         .then((res) => {
           const defaultModel = res.defaultModel ?? "";
           if (defaultModel !== "") {
-            const [model, providerName] = defaultModel.split("@");
+            const [model, providerName] = getModelProvider(defaultModel);
             DEFAULT_CONFIG.modelConfig.model = model;
-            DEFAULT_CONFIG.modelConfig.providerName = providerName;
+            DEFAULT_CONFIG.modelConfig.providerName = providerName as any;
           }
 
           return res;

+ 30 - 23
app/store/chat.ts

@@ -352,13 +352,13 @@ export const useChatStore = createPersistStore(
         return session;
       },
 
-      onNewMessage(message: ChatMessage) {
-        get().updateCurrentSession((session) => {
+      onNewMessage(message: ChatMessage, targetSession: ChatSession) {
+        get().updateTargetSession(targetSession, (session) => {
           session.messages = session.messages.concat();
           session.lastUpdate = Date.now();
         });
-        get().updateStat(message);
-        get().summarizeSession();
+        get().updateStat(message, targetSession);
+        get().summarizeSession(false, targetSession);
       },
 
       async onUserInput(content: string, attachImages?: string[]) {
@@ -396,10 +396,10 @@ export const useChatStore = createPersistStore(
         // get recent messages
         const recentMessages = get().getMessagesWithMemory();
         const sendMessages = recentMessages.concat(userMessage);
-        const messageIndex = get().currentSession().messages.length + 1;
+        const messageIndex = session.messages.length + 1;
 
         // save user's and bot's message
-        get().updateCurrentSession((session) => {
+        get().updateTargetSession(session, (session) => {
           const savedUserMessage = {
             ...userMessage,
             content: mContent,
@@ -420,7 +420,7 @@ export const useChatStore = createPersistStore(
             if (message) {
               botMessage.content = message;
             }
-            get().updateCurrentSession((session) => {
+            get().updateTargetSession(session, (session) => {
               session.messages = session.messages.concat();
             });
           },
@@ -428,13 +428,14 @@ export const useChatStore = createPersistStore(
             botMessage.streaming = false;
             if (message) {
               botMessage.content = message;
-              get().onNewMessage(botMessage);
+              botMessage.date = new Date().toLocaleString();
+              get().onNewMessage(botMessage, session);
             }
             ChatControllerPool.remove(session.id, botMessage.id);
           },
           onBeforeTool(tool: ChatMessageTool) {
             (botMessage.tools = botMessage?.tools || []).push(tool);
-            get().updateCurrentSession((session) => {
+            get().updateTargetSession(session, (session) => {
               session.messages = session.messages.concat();
             });
           },
@@ -444,7 +445,7 @@ export const useChatStore = createPersistStore(
                 tools[i] = { ...tool };
               }
             });
-            get().updateCurrentSession((session) => {
+            get().updateTargetSession(session, (session) => {
               session.messages = session.messages.concat();
             });
           },
@@ -459,7 +460,7 @@ export const useChatStore = createPersistStore(
             botMessage.streaming = false;
             userMessage.isError = !isAborted;
             botMessage.isError = !isAborted;
-            get().updateCurrentSession((session) => {
+            get().updateTargetSession(session, (session) => {
               session.messages = session.messages.concat();
             });
             ChatControllerPool.remove(
@@ -591,16 +592,19 @@ export const useChatStore = createPersistStore(
         set(() => ({ sessions }));
       },
 
-      resetSession() {
-        get().updateCurrentSession((session) => {
+      resetSession(session: ChatSession) {
+        get().updateTargetSession(session, (session) => {
           session.messages = [];
           session.memoryPrompt = "";
         });
       },
 
-      summarizeSession(refreshTitle: boolean = false) {
+      summarizeSession(
+        refreshTitle: boolean = false,
+        targetSession: ChatSession,
+      ) {
         const config = useAppConfig.getState();
-        const session = get().currentSession();
+        const session = targetSession;
         const modelConfig = session.mask.modelConfig;
         // skip summarize when using dalle3?
         if (isDalle3(modelConfig.model)) {
@@ -651,7 +655,8 @@ export const useChatStore = createPersistStore(
             },
             onFinish(message, responseRes) {
               if (responseRes?.status === 200) {
-                get().updateCurrentSession(
+                get().updateTargetSession(
+                  session,
                   (session) =>
                     (session.topic =
                       message.length > 0 ? trimTopic(message) : DEFAULT_TOPIC),
@@ -719,7 +724,7 @@ export const useChatStore = createPersistStore(
             onFinish(message, responseRes) {
               if (responseRes?.status === 200) {
                 console.log("[Memory] ", message);
-                get().updateCurrentSession((session) => {
+                get().updateTargetSession(session, (session) => {
                   session.lastSummarizeIndex = lastSummarizeIndex;
                   session.memoryPrompt = message; // Update the memory prompt for stored it in local storage
                 });
@@ -732,20 +737,22 @@ export const useChatStore = createPersistStore(
         }
       },
 
-      updateStat(message: ChatMessage) {
-        get().updateCurrentSession((session) => {
+      updateStat(message: ChatMessage, session: ChatSession) {
+        get().updateTargetSession(session, (session) => {
           session.stat.charCount += message.content.length;
           // TODO: should update chat count and word count
         });
       },
-
-      updateCurrentSession(updater: (session: ChatSession) => void) {
+      updateTargetSession(
+        targetSession: ChatSession,
+        updater: (session: ChatSession) => void,
+      ) {
         const sessions = get().sessions;
-        const index = get().currentSessionIndex;
+        const index = sessions.findIndex((s) => s.id === targetSession.id);
+        if (index < 0) return;
         updater(sessions[index]);
         set(() => ({ sessions }));
       },
-
       async clearAllData() {
         await indexedDBStorage.clear();
         localStorage.clear();

+ 15 - 4
app/utils/model.ts

@@ -37,6 +37,17 @@ const sortModelTable = (models: ReturnType<typeof collectModels>) =>
     }
   });
 
+/**
+ * get model name and provider from a formatted string,
+ * e.g. `gpt-4@OpenAi` or `claude-3-5-sonnet@20240620@Google`
+ * @param modelWithProvider model name with provider separated by last `@` char,
+ * @returns [model, provider] tuple, if no `@` char found, provider is undefined
+ */
+export function getModelProvider(modelWithProvider: string): [string, string?] {
+  const [model, provider] = modelWithProvider.split(/@(?!.*@)/);
+  return [model, provider];
+}
+
 export function collectModelTable(
   models: readonly LLMModel[],
   customModels: string,
@@ -79,10 +90,10 @@ export function collectModelTable(
         );
       } else {
         // 1. find model by name, and set available value
-        const [customModelName, customProviderName] = name.split("@");
+        const [customModelName, customProviderName] = getModelProvider(name);
         let count = 0;
         for (const fullName in modelTable) {
-          const [modelName, providerName] = fullName.split("@");
+          const [modelName, providerName] = getModelProvider(fullName);
           if (
             customModelName == modelName &&
             (customProviderName === undefined ||
@@ -102,7 +113,7 @@ export function collectModelTable(
         }
         // 2. if model not exists, create new model with available value
         if (count === 0) {
-          let [customModelName, customProviderName] = name.split("@");
+          let [customModelName, customProviderName] = getModelProvider(name);
           const provider = customProvider(
             customProviderName || customModelName,
           );
@@ -139,7 +150,7 @@ export function collectModelTableWithDefaultModel(
       for (const key of Object.keys(modelTable)) {
         if (
           modelTable[key].available &&
-          key.split("@").shift() == defaultModel
+          getModelProvider(key)[0] == defaultModel
         ) {
           modelTable[key].isDefault = true;
           break;

+ 1 - 1
src-tauri/tauri.conf.json

@@ -9,7 +9,7 @@
   },
   "package": {
     "productName": "NextChat",
-    "version": "2.15.6"
+    "version": "2.15.7"
   },
   "tauri": {
     "allowlist": {

+ 31 - 0
test/model-provider.test.ts

@@ -0,0 +1,31 @@
+import { getModelProvider } from "../app/utils/model";
+
+describe("getModelProvider", () => {
+  test("should return model and provider when input contains '@'", () => {
+    const input = "model@provider";
+    const [model, provider] = getModelProvider(input);
+    expect(model).toBe("model");
+    expect(provider).toBe("provider");
+  });
+
+  test("should return model and undefined provider when input does not contain '@'", () => {
+    const input = "model";
+    const [model, provider] = getModelProvider(input);
+    expect(model).toBe("model");
+    expect(provider).toBeUndefined();
+  });
+
+  test("should handle multiple '@' characters correctly", () => {
+    const input = "model@provider@extra";
+    const [model, provider] = getModelProvider(input);
+    expect(model).toBe("model@provider");
+    expect(provider).toBe("extra");
+  });
+
+  test("should return empty strings when input is empty", () => {
+    const input = "";
+    const [model, provider] = getModelProvider(input);
+    expect(model).toBe("");
+    expect(provider).toBeUndefined();
+  });
+});