
fix compressModel, related #5426, fix #5606, #5603, #5575

lloydzhou · 1 year ago
commit a925b424a8
2 changed files with 50 additions and 8 deletions
  1. app/store/chat.ts (+46, -4)
  2. app/store/config.ts (+4, -4)
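Net effect for existing users: both stores bump their persisted-state versions (chat 3.2 → 3.3, config 4 → 4.1) and clear any previously persisted compressModel, so summarization falls back to the new getSummarizeModel helper at call time. A minimal sketch of that reset, assuming the persisted shapes shown in the diffs below (the helper name here is hypothetical):

// Hypothetical helper; mirrors the two migrations in the diffs below.
type CompressConfig = {
  compressModel?: string;
  compressProviderName?: string;
};

function resetCompressConfig(cfg: CompressConfig): void {
  // Clearing both fields makes the chat store call getSummarizeModel()
  // the next time a session needs a summary or memory prompt.
  cfg.compressModel = undefined;
  cfg.compressProviderName = undefined;
}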

app/store/chat.ts (+46, -4)

@@ -16,6 +16,8 @@ import {
   DEFAULT_SYSTEM_TEMPLATE,
   KnowledgeCutOffDate,
   StoreKey,
+  SUMMARIZE_MODEL,
+  GEMINI_SUMMARIZE_MODEL,
 } from "../constant";
 import Locale, { getLang } from "../locales";
 import { isDalle3, safeLocalStorage } from "../utils";
@@ -23,6 +25,8 @@ import { prettyObject } from "../utils/format";
 import { createPersistStore } from "../utils/store";
 import { estimateTokenLength } from "../utils/token";
 import { ModelConfig, ModelType, useAppConfig } from "./config";
+import { useAccessStore } from "./access";
+import { collectModelsWithDefaultModel } from "../utils/model";
 import { createEmptyMask, Mask } from "./mask";
 
 const localStorage = safeLocalStorage();
@@ -103,6 +107,29 @@ function createEmptySession(): ChatSession {
   };
 }
 
+function getSummarizeModel(currentModel: string, providerName: string) {
+  // if a gpt-* or chatgpt-* model is in use, force 4o-mini for summarizing
+  if (currentModel.startsWith("gpt") || currentModel.startsWith("chatgpt")) {
+    const configStore = useAppConfig.getState();
+    const accessStore = useAccessStore.getState();
+    const allModel = collectModelsWithDefaultModel(
+      configStore.models,
+      [configStore.customModels, accessStore.customModels].join(","),
+      accessStore.defaultModel,
+    );
+    const summarizeModel = allModel.find(
+      (m) => m.name === SUMMARIZE_MODEL && m.available,
+    );
+    if (summarizeModel) {
+      return [summarizeModel.name, summarizeModel.providerName];
+    }
+  }
+  if (currentModel.startsWith("gemini")) {
+    return [GEMINI_SUMMARIZE_MODEL, ServiceProvider.Google];
+  }
+  return [currentModel, providerName];
+}
+
 function countMessages(msgs: ChatMessage[]) {
   return msgs.reduce(
     (pre, cur) => pre + estimateTokenLength(getMessageTextContent(cur)),
@@ -579,7 +606,13 @@ export const useChatStore = createPersistStore(
           return;
         }
 
-        const providerName = modelConfig.compressProviderName;
+        // if compressModel is not configured, fall back to getSummarizeModel
+        const [model, providerName] = modelConfig.compressModel
+          ? [modelConfig.compressModel, modelConfig.compressProviderName]
+          : getSummarizeModel(
+              session.mask.modelConfig.model,
+              session.mask.modelConfig.providerName,
+            );
         const api: ClientApi = getClientApi(providerName);
 
         // remove error messages if any
@@ -611,7 +644,7 @@ export const useChatStore = createPersistStore(
           api.llm.chat({
             messages: topicMessages,
             config: {
-              model: modelConfig.compressModel,
+              model,
               stream: false,
               providerName,
             },
@@ -675,7 +708,8 @@ export const useChatStore = createPersistStore(
             config: {
               ...modelcfg,
               stream: true,
-              model: modelConfig.compressModel,
+              model,
+              providerName,
             },
             onUpdate(message) {
               session.memoryPrompt = message;
@@ -728,7 +762,7 @@ export const useChatStore = createPersistStore(
   },
   {
     name: StoreKey.Chat,
-    version: 3.2,
+    version: 3.3,
     migrate(persistedState, version) {
       const state = persistedState as any;
       const newState = JSON.parse(
@@ -784,6 +818,14 @@ export const useChatStore = createPersistStore(
             config.modelConfig.compressProviderName;
         });
       }
+      // reset each session's summarize model so the new runtime fallback applies
+      if (version < 3.3) {
+        newState.sessions.forEach((s) => {
+          const config = useAppConfig.getState();
+          s.mask.modelConfig.compressModel = undefined;
+          s.mask.modelConfig.compressProviderName = undefined;
+        });
+      }
 
       return newState as any;
     },
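A quick usage sketch of how getSummarizeModel above resolves the summarize model. This is not part of the commit; the concrete values behind SUMMARIZE_MODEL and GEMINI_SUMMARIZE_MODEL are assumptions.

// Usage sketch only. Assumes SUMMARIZE_MODEL maps to an available
// "gpt-4o-mini" entry under the OpenAI provider and GEMINI_SUMMARIZE_MODEL
// to the Gemini summarize default.

// gpt-* / chatgpt-* sessions are forced onto the cheaper summarize model,
// provided such an entry is listed as available:
getSummarizeModel("gpt-4-turbo", "OpenAI");
// -> ["gpt-4o-mini", "OpenAI"]   (assumed value of SUMMARIZE_MODEL)

// gemini-* sessions summarize via the Google provider:
getSummarizeModel("gemini-1.5-pro", "Google");
// -> [GEMINI_SUMMARIZE_MODEL, ServiceProvider.Google]

// anything else keeps the session's own model and provider:
getSummarizeModel("claude-3-5-sonnet", "Anthropic");
// -> ["claude-3-5-sonnet", "Anthropic"]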

app/store/config.ts (+4, -4)

@@ -71,8 +71,8 @@ export const DEFAULT_CONFIG = {
     sendMemory: true,
     historyMessageCount: 4,
     compressMessageLengthThreshold: 1000,
-    compressModel: "gpt-4o-mini" as ModelType,
-    compressProviderName: "OpenAI" as ServiceProvider,
+    compressModel: undefined,
+    compressProviderName: undefined,
     enableInjectSystemPrompts: true,
     template: config?.template ?? DEFAULT_INPUT_TEMPLATE,
     size: "1024x1024" as DalleSize,
@@ -178,7 +178,7 @@ export const useAppConfig = createPersistStore(
   }),
   {
     name: StoreKey.Config,
-    version: 4,
+    version: 4.1,
 
     merge(persistedState, currentState) {
       const state = persistedState as ChatConfig | undefined;
@@ -231,7 +231,7 @@ export const useAppConfig = createPersistStore(
             : config?.template ?? DEFAULT_INPUT_TEMPLATE;
       }
 
-      if (version < 4) {
+      if (version < 4.1) {
         state.modelConfig.compressModel =
           DEFAULT_CONFIG.modelConfig.compressModel;
         state.modelConfig.compressProviderName =