Quellcode durchsuchen

feat: now user can choose their own summarize model

skymkmk vor 1 Jahr
Ursprung
Commit
93bc2f5870
3 geänderte Dateien mit 65 neuen und 46 gelöschten Zeilen
  1. 25 0
      app/components/model-config.tsx
  2. 29 44
      app/store/chat.ts
  3. 11 2
      app/store/config.ts

+ 25 - 0
app/components/model-config.tsx

@@ -12,6 +12,7 @@ export function ModelConfigList(props: {
 }) {
   const allModels = useAllModels();
   const value = `${props.modelConfig.model}@${props.modelConfig?.providerName}`;
+  const compressModelValue = `${props.modelConfig.compressModel}@${props.modelConfig?.compressProviderName}`;
 
   return (
     <>
@@ -228,6 +229,30 @@ export function ModelConfigList(props: {
           }
         ></input>
       </ListItem>
+      <ListItem
+        title={Locale.Settings.CompressModel.Title}
+        subTitle={Locale.Settings.CompressModel.SubTitle}
+      >
+        <Select
+          aria-label={Locale.Settings.CompressModel.Title}
+          value={compressModelValue}
+          onChange={(e) => {
+            const [model, providerName] = e.currentTarget.value.split("@");
+            props.updateConfig((config) => {
+              config.compressModel = ModalConfigValidator.model(model);
+              config.compressProviderName = providerName as ServiceProvider;
+            });
+          }}
+        >
+          {allModels
+            .filter((v) => v.available)
+            .map((v, i) => (
+              <option value={`${v.name}@${v.provider?.providerName}`} key={i}>
+                {v.displayName}({v.provider?.providerName})
+              </option>
+            ))}
+        </Select>
+      </ListItem>
     </>
   );
 }

+ 29 - 44
app/store/chat.ts

@@ -1,33 +1,29 @@
-import { trimTopic, getMessageTextContent } from "../utils";
+import { getMessageTextContent, trimTopic } from "../utils";
 
-import Locale, { getLang } from "../locales";
+import { indexedDBStorage } from "@/app/utils/indexedDB-storage";
+import { nanoid } from "nanoid";
+import type {
+  ClientApi,
+  MultimodalContent,
+  RequestMessage,
+} from "../client/api";
+import { getClientApi } from "../client/api";
+import { ChatControllerPool } from "../client/controller";
 import { showToast } from "../components/ui-lib";
-import { ModelConfig, ModelType, useAppConfig } from "./config";
-import { createEmptyMask, Mask } from "./mask";
 import {
   DEFAULT_INPUT_TEMPLATE,
   DEFAULT_MODELS,
   DEFAULT_SYSTEM_TEMPLATE,
   KnowledgeCutOffDate,
   StoreKey,
-  SUMMARIZE_MODEL,
-  GEMINI_SUMMARIZE_MODEL,
 } from "../constant";
-import { getClientApi } from "../client/api";
-import type {
-  ClientApi,
-  RequestMessage,
-  MultimodalContent,
-} from "../client/api";
-import { ChatControllerPool } from "../client/controller";
+import Locale, { getLang } from "../locales";
+import { isDalle3, safeLocalStorage } from "../utils";
 import { prettyObject } from "../utils/format";
-import { estimateTokenLength } from "../utils/token";
-import { nanoid } from "nanoid";
 import { createPersistStore } from "../utils/store";
-import { collectModelsWithDefaultModel } from "../utils/model";
-import { useAccessStore } from "./access";
-import { isDalle3, safeLocalStorage } from "../utils";
-import { indexedDBStorage } from "@/app/utils/indexedDB-storage";
+import { estimateTokenLength } from "../utils/token";
+import { ModelConfig, ModelType, useAppConfig } from "./config";
+import { createEmptyMask, Mask } from "./mask";
 
 const localStorage = safeLocalStorage();
 
@@ -106,27 +102,6 @@ function createEmptySession(): ChatSession {
   };
 }
 
-function getSummarizeModel(currentModel: string) {
-  // if it is using gpt-* models, force to use 4o-mini to summarize
-  if (currentModel.startsWith("gpt") || currentModel.startsWith("chatgpt")) {
-    const configStore = useAppConfig.getState();
-    const accessStore = useAccessStore.getState();
-    const allModel = collectModelsWithDefaultModel(
-      configStore.models,
-      [configStore.customModels, accessStore.customModels].join(","),
-      accessStore.defaultModel,
-    );
-    const summarizeModel = allModel.find(
-      (m) => m.name === SUMMARIZE_MODEL && m.available,
-    );
-    return summarizeModel?.name ?? currentModel;
-  }
-  if (currentModel.startsWith("gemini")) {
-    return GEMINI_SUMMARIZE_MODEL;
-  }
-  return currentModel;
-}
-
 function countMessages(msgs: ChatMessage[]) {
   return msgs.reduce(
     (pre, cur) => pre + estimateTokenLength(getMessageTextContent(cur)),
@@ -581,7 +556,7 @@ export const useChatStore = createPersistStore(
           return;
         }
 
-        const providerName = modelConfig.providerName;
+        const providerName = modelConfig.compressProviderName;
         const api: ClientApi = getClientApi(providerName);
 
         // remove error messages if any
@@ -603,7 +578,7 @@ export const useChatStore = createPersistStore(
           api.llm.chat({
             messages: topicMessages,
             config: {
-              model: getSummarizeModel(session.mask.modelConfig.model),
+              model: modelConfig.compressModel,
               stream: false,
               providerName,
             },
@@ -666,7 +641,7 @@ export const useChatStore = createPersistStore(
             config: {
               ...modelcfg,
               stream: true,
-              model: getSummarizeModel(session.mask.modelConfig.model),
+              model: modelConfig.compressModel,
             },
             onUpdate(message) {
               session.memoryPrompt = message;
@@ -715,7 +690,7 @@ export const useChatStore = createPersistStore(
   },
   {
     name: StoreKey.Chat,
-    version: 3.1,
+    version: 3.2,
     migrate(persistedState, version) {
       const state = persistedState as any;
       const newState = JSON.parse(
@@ -762,6 +737,16 @@ export const useChatStore = createPersistStore(
         });
       }
 
+      // add default summarize model for every session
+      if (version < 3.2) {
+        newState.sessions.forEach((s) => {
+          const config = useAppConfig.getState();
+          s.mask.modelConfig.compressModel = config.modelConfig.compressModel;
+          s.mask.modelConfig.compressProviderName =
+            config.modelConfig.compressProviderName;
+        });
+      }
+
       return newState as any;
     },
   },

+ 11 - 2
app/store/config.ts

@@ -50,7 +50,7 @@ export const DEFAULT_CONFIG = {
   models: DEFAULT_MODELS as any as LLMModel[],
 
   modelConfig: {
-    model: "gpt-3.5-turbo" as ModelType,
+    model: "gpt-4o-mini" as ModelType,
     providerName: "OpenAI" as ServiceProvider,
     temperature: 0.5,
     top_p: 1,
@@ -60,6 +60,8 @@ export const DEFAULT_CONFIG = {
     sendMemory: true,
     historyMessageCount: 4,
     compressMessageLengthThreshold: 1000,
+    compressModel: "gpt-4o-mini" as ModelType,
+    compressProviderName: "OpenAI" as ServiceProvider,
     enableInjectSystemPrompts: true,
     template: config?.template ?? DEFAULT_INPUT_TEMPLATE,
     size: "1024x1024" as DalleSize,
@@ -140,7 +142,7 @@ export const useAppConfig = createPersistStore(
   }),
   {
     name: StoreKey.Config,
-    version: 3.9,
+    version: 4,
     migrate(persistedState, version) {
       const state = persistedState as ChatConfig;
 
@@ -178,6 +180,13 @@ export const useAppConfig = createPersistStore(
             : config?.template ?? DEFAULT_INPUT_TEMPLATE;
       }
 
+      if (version < 4) {
+        state.modelConfig.compressModel =
+          DEFAULT_CONFIG.modelConfig.compressModel;
+        state.modelConfig.compressProviderName =
+          DEFAULT_CONFIG.modelConfig.compressProviderName;
+      }
+
       return state as any;
     },
   },