Ver código fonte

Merge pull request #5426 from skymkmk/pr-summarize-customization

feat: summarize model customization
Dogtiti 1 ano atrás
pai
commit
b32d82e6c1

+ 25 - 0
app/components/model-config.tsx

@@ -17,6 +17,7 @@ export function ModelConfigList(props: {
     "provider.providerName",
   );
   const value = `${props.modelConfig.model}@${props.modelConfig?.providerName}`;
+  const compressModelValue = `${props.modelConfig.compressModel}@${props.modelConfig?.compressProviderName}`;
 
   return (
     <>
@@ -236,6 +237,30 @@ export function ModelConfigList(props: {
           }
         ></input>
       </ListItem>
+      <ListItem
+        title={Locale.Settings.CompressModel.Title}
+        subTitle={Locale.Settings.CompressModel.SubTitle}
+      >
+        <Select
+          aria-label={Locale.Settings.CompressModel.Title}
+          value={compressModelValue}
+          onChange={(e) => {
+            const [model, providerName] = e.currentTarget.value.split("@");
+            props.updateConfig((config) => {
+              config.compressModel = ModalConfigValidator.model(model);
+              config.compressProviderName = providerName as ServiceProvider;
+            });
+          }}
+        >
+          {allModels
+            .filter((v) => v.available)
+            .map((v, i) => (
+              <option value={`${v.name}@${v.provider?.providerName}`} key={i}>
+                {v.displayName}({v.provider?.providerName})
+              </option>
+            ))}
+        </Select>
+      </ListItem>
     </>
   );
 }

+ 4 - 0
app/locales/ar.ts

@@ -404,6 +404,10 @@ const ar: PartialLocaleType = {
     },
 
     Model: "النموذج",
+    CompressModel: {
+      Title: "نموذج الضغط",
+      SubTitle: "النموذج المستخدم لضغط السجل التاريخي",
+    },
     Temperature: {
       Title: "العشوائية (temperature)",
       SubTitle: "كلما زادت القيمة، زادت العشوائية في الردود",

+ 4 - 0
app/locales/bn.ts

@@ -411,6 +411,10 @@ const bn: PartialLocaleType = {
     },
 
     Model: "মডেল (model)",
+    CompressModel: {
+      Title: "সংকোচন মডেল",
+      SubTitle: "ইতিহাস সংকুচিত করার জন্য ব্যবহৃত মডেল",
+    },
     Temperature: {
       Title: "যাদুকরিতা (temperature)",
       SubTitle: "মান বাড়ালে উত্তর বেশি এলোমেলো হবে",

+ 4 - 0
app/locales/cn.ts

@@ -470,6 +470,10 @@ const cn = {
     },
 
     Model: "模型 (model)",
+    CompressModel: {
+      Title: "压缩模型",
+      SubTitle: "用于压缩历史记录的模型",
+    },
     Temperature: {
       Title: "随机性 (temperature)",
       SubTitle: "值越大,回复越随机",

+ 4 - 0
app/locales/cs.ts

@@ -410,6 +410,10 @@ const cs: PartialLocaleType = {
     },
 
     Model: "Model (model)",
+    CompressModel: {
+      Title: "Kompresní model",
+      SubTitle: "Model používaný pro kompresi historie",
+    },
     Temperature: {
       Title: "Náhodnost (temperature)",
       SubTitle: "Čím vyšší hodnota, tím náhodnější odpovědi",

+ 4 - 0
app/locales/de.ts

@@ -421,6 +421,10 @@ const de: PartialLocaleType = {
     },
 
     Model: "Modell",
+    CompressModel: {
+      Title: "Kompressionsmodell",
+      SubTitle: "Modell zur Komprimierung des Verlaufs",
+    },
     Temperature: {
       Title: "Zufälligkeit (temperature)",
       SubTitle: "Je höher der Wert, desto zufälliger die Antwort",

+ 4 - 0
app/locales/en.ts

@@ -474,6 +474,10 @@ const en: LocaleType = {
     },
 
     Model: "Model",
+    CompressModel: {
+      Title: "Compression Model",
+      SubTitle: "Model used to compress history",
+    },
     Temperature: {
       Title: "Temperature",
       SubTitle: "A larger value makes the more random output",

+ 4 - 0
app/locales/es.ts

@@ -423,6 +423,10 @@ const es: PartialLocaleType = {
     },
 
     Model: "Modelo (model)",
+    CompressModel: {
+      Title: "Modelo de compresión",
+      SubTitle: "Modelo utilizado para comprimir el historial",
+    },
     Temperature: {
       Title: "Aleatoriedad (temperature)",
       SubTitle: "Cuanto mayor sea el valor, más aleatorio será el resultado",

+ 4 - 0
app/locales/fr.ts

@@ -422,6 +422,10 @@ const fr: PartialLocaleType = {
     },
 
     Model: "Modèle",
+    CompressModel: {
+      Title: "Modèle de compression",
+      SubTitle: "Modèle utilisé pour compresser l'historique",
+    },
     Temperature: {
       Title: "Aléatoire (temperature)",
       SubTitle: "Plus la valeur est élevée, plus les réponses sont aléatoires",

+ 4 - 0
app/locales/id.ts

@@ -411,6 +411,10 @@ const id: PartialLocaleType = {
     },
 
     Model: "Model",
+    CompressModel: {
+      Title: "Model Kompresi",
+      SubTitle: "Model yang digunakan untuk mengompres riwayat",
+    },
     Temperature: {
       Title: "Randomness (temperature)",
       SubTitle: "Semakin tinggi nilainya, semakin acak responsnya",

+ 4 - 0
app/locales/it.ts

@@ -423,6 +423,10 @@ const it: PartialLocaleType = {
     },
 
     Model: "Modello (model)",
+    CompressModel: {
+      Title: "Modello di compressione",
+      SubTitle: "Modello utilizzato per comprimere la cronologia",
+    },
     Temperature: {
       Title: "Casualità (temperature)",
       SubTitle: "Valore più alto, risposte più casuali",

+ 4 - 0
app/locales/jp.ts

@@ -407,6 +407,10 @@ const jp: PartialLocaleType = {
     },
 
     Model: "モデル (model)",
+    CompressModel: {
+      Title: "圧縮モデル",
+      SubTitle: "履歴を圧縮するために使用されるモデル",
+    },
     Temperature: {
       Title: "ランダム性 (temperature)",
       SubTitle: "値が大きいほど応答がランダムになります",

+ 4 - 0
app/locales/ko.ts

@@ -404,6 +404,10 @@ const ko: PartialLocaleType = {
     },
 
     Model: "모델 (model)",
+    CompressModel: {
+      Title: "압축 모델",
+      SubTitle: "기록을 압축하는 데 사용되는 모델",
+    },
     Temperature: {
       Title: "무작위성 (temperature)",
       SubTitle: "값이 클수록 응답이 더 무작위적",

+ 4 - 0
app/locales/no.ts

@@ -415,6 +415,10 @@ const no: PartialLocaleType = {
     },
 
     Model: "Modell",
+    CompressModel: {
+      Title: "Komprimeringsmodell",
+      SubTitle: "Modell brukt for å komprimere historikken",
+    },
     Temperature: {
       Title: "Tilfeldighet (temperature)",
       SubTitle: "Høyere verdi gir mer tilfeldige svar",

+ 4 - 0
app/locales/pt.ts

@@ -346,6 +346,10 @@ const pt: PartialLocaleType = {
     },
 
     Model: "Modelo",
+    CompressModel: {
+      Title: "Modelo de Compressão",
+      SubTitle: "Modelo usado para comprimir o histórico",
+    },
     Temperature: {
       Title: "Temperatura",
       SubTitle: "Um valor maior torna a saída mais aleatória",

+ 4 - 0
app/locales/ru.ts

@@ -414,6 +414,10 @@ const ru: PartialLocaleType = {
     },
 
     Model: "Модель",
+    CompressModel: {
+      Title: "Модель сжатия",
+      SubTitle: "Модель, используемая для сжатия истории",
+    },
     Temperature: {
       Title: "Случайность (temperature)",
       SubTitle: "Чем больше значение, тем более случайные ответы",

+ 4 - 0
app/locales/sk.ts

@@ -365,6 +365,10 @@ const sk: PartialLocaleType = {
     },
 
     Model: "Model",
+    CompressModel: {
+      Title: "Kompresný model",
+      SubTitle: "Model používaný na kompresiu histórie",
+    },
     Temperature: {
       Title: "Teplota",
       SubTitle: "Vyššia hodnota robí výstup náhodnejším",

+ 4 - 0
app/locales/tr.ts

@@ -414,6 +414,10 @@ const tr: PartialLocaleType = {
     },
 
     Model: "Model (model)",
+    CompressModel: {
+      Title: "Sıkıştırma Modeli",
+      SubTitle: "Geçmişi sıkıştırmak için kullanılan model",
+    },
     Temperature: {
       Title: "Rastgelelik (temperature)",
       SubTitle: "Değer arttıkça yanıt daha rastgele olur",

+ 4 - 0
app/locales/tw.ts

@@ -368,6 +368,10 @@ const tw = {
     },
 
     Model: "模型 (model)",
+    CompressModel: {
+      Title: "壓縮模型",
+      SubTitle: "用於壓縮歷史記錄的模型",
+    },
     Temperature: {
       Title: "隨機性 (temperature)",
       SubTitle: "值越大,回應越隨機",

+ 4 - 0
app/locales/vi.ts

@@ -410,6 +410,10 @@ const vi: PartialLocaleType = {
     },
 
     Model: "Mô hình (model)",
+    CompressModel: {
+      Title: "Mô hình nén",
+      SubTitle: "Mô hình được sử dụng để nén lịch sử",
+    },
     Temperature: {
       Title: "Độ ngẫu nhiên (temperature)",
       SubTitle: "Giá trị càng lớn, câu trả lời càng ngẫu nhiên",

+ 29 - 44
app/store/chat.ts

@@ -1,33 +1,29 @@
-import { trimTopic, getMessageTextContent } from "../utils";
+import { getMessageTextContent, trimTopic } from "../utils";
 
-import Locale, { getLang } from "../locales";
+import { indexedDBStorage } from "@/app/utils/indexedDB-storage";
+import { nanoid } from "nanoid";
+import type {
+  ClientApi,
+  MultimodalContent,
+  RequestMessage,
+} from "../client/api";
+import { getClientApi } from "../client/api";
+import { ChatControllerPool } from "../client/controller";
 import { showToast } from "../components/ui-lib";
-import { ModelConfig, ModelType, useAppConfig } from "./config";
-import { createEmptyMask, Mask } from "./mask";
 import {
   DEFAULT_INPUT_TEMPLATE,
   DEFAULT_MODELS,
   DEFAULT_SYSTEM_TEMPLATE,
   KnowledgeCutOffDate,
   StoreKey,
-  SUMMARIZE_MODEL,
-  GEMINI_SUMMARIZE_MODEL,
 } from "../constant";
-import { getClientApi } from "../client/api";
-import type {
-  ClientApi,
-  RequestMessage,
-  MultimodalContent,
-} from "../client/api";
-import { ChatControllerPool } from "../client/controller";
+import Locale, { getLang } from "../locales";
+import { isDalle3, safeLocalStorage } from "../utils";
 import { prettyObject } from "../utils/format";
-import { estimateTokenLength } from "../utils/token";
-import { nanoid } from "nanoid";
 import { createPersistStore } from "../utils/store";
-import { collectModelsWithDefaultModel } from "../utils/model";
-import { useAccessStore } from "./access";
-import { isDalle3, safeLocalStorage } from "../utils";
-import { indexedDBStorage } from "@/app/utils/indexedDB-storage";
+import { estimateTokenLength } from "../utils/token";
+import { ModelConfig, ModelType, useAppConfig } from "./config";
+import { createEmptyMask, Mask } from "./mask";
 
 const localStorage = safeLocalStorage();
 
@@ -106,27 +102,6 @@ function createEmptySession(): ChatSession {
   };
 }
 
-function getSummarizeModel(currentModel: string) {
-  // if it is using gpt-* models, force to use 4o-mini to summarize
-  if (currentModel.startsWith("gpt") || currentModel.startsWith("chatgpt")) {
-    const configStore = useAppConfig.getState();
-    const accessStore = useAccessStore.getState();
-    const allModel = collectModelsWithDefaultModel(
-      configStore.models,
-      [configStore.customModels, accessStore.customModels].join(","),
-      accessStore.defaultModel,
-    );
-    const summarizeModel = allModel.find(
-      (m) => m.name === SUMMARIZE_MODEL && m.available,
-    );
-    return summarizeModel?.name ?? currentModel;
-  }
-  if (currentModel.startsWith("gemini")) {
-    return GEMINI_SUMMARIZE_MODEL;
-  }
-  return currentModel;
-}
-
 function countMessages(msgs: ChatMessage[]) {
   return msgs.reduce(
     (pre, cur) => pre + estimateTokenLength(getMessageTextContent(cur)),
@@ -581,7 +556,7 @@ export const useChatStore = createPersistStore(
           return;
         }
 
-        const providerName = modelConfig.providerName;
+        const providerName = modelConfig.compressProviderName;
         const api: ClientApi = getClientApi(providerName);
 
         // remove error messages if any
@@ -603,7 +578,7 @@ export const useChatStore = createPersistStore(
           api.llm.chat({
             messages: topicMessages,
             config: {
-              model: getSummarizeModel(session.mask.modelConfig.model),
+              model: modelConfig.compressModel,
               stream: false,
               providerName,
             },
@@ -666,7 +641,7 @@ export const useChatStore = createPersistStore(
             config: {
               ...modelcfg,
               stream: true,
-              model: getSummarizeModel(session.mask.modelConfig.model),
+              model: modelConfig.compressModel,
             },
             onUpdate(message) {
               session.memoryPrompt = message;
@@ -715,7 +690,7 @@ export const useChatStore = createPersistStore(
   },
   {
     name: StoreKey.Chat,
-    version: 3.1,
+    version: 3.2,
     migrate(persistedState, version) {
       const state = persistedState as any;
       const newState = JSON.parse(
@@ -762,6 +737,16 @@ export const useChatStore = createPersistStore(
         });
       }
 
+      // add default summarize model for every session
+      if (version < 3.2) {
+        newState.sessions.forEach((s) => {
+          const config = useAppConfig.getState();
+          s.mask.modelConfig.compressModel = config.modelConfig.compressModel;
+          s.mask.modelConfig.compressProviderName =
+            config.modelConfig.compressProviderName;
+        });
+      }
+
       return newState as any;
     },
   },

+ 11 - 2
app/store/config.ts

@@ -50,7 +50,7 @@ export const DEFAULT_CONFIG = {
   models: DEFAULT_MODELS as any as LLMModel[],
 
   modelConfig: {
-    model: "gpt-3.5-turbo" as ModelType,
+    model: "gpt-4o-mini" as ModelType,
     providerName: "OpenAI" as ServiceProvider,
     temperature: 0.5,
     top_p: 1,
@@ -60,6 +60,8 @@ export const DEFAULT_CONFIG = {
     sendMemory: true,
     historyMessageCount: 4,
     compressMessageLengthThreshold: 1000,
+    compressModel: "gpt-4o-mini" as ModelType,
+    compressProviderName: "OpenAI" as ServiceProvider,
     enableInjectSystemPrompts: true,
     template: config?.template ?? DEFAULT_INPUT_TEMPLATE,
     size: "1024x1024" as DalleSize,
@@ -140,7 +142,7 @@ export const useAppConfig = createPersistStore(
   }),
   {
     name: StoreKey.Config,
-    version: 3.9,
+    version: 4,
     migrate(persistedState, version) {
       const state = persistedState as ChatConfig;
 
@@ -178,6 +180,13 @@ export const useAppConfig = createPersistStore(
             : config?.template ?? DEFAULT_INPUT_TEMPLATE;
       }
 
+      if (version < 4) {
+        state.modelConfig.compressModel =
+          DEFAULT_CONFIG.modelConfig.compressModel;
+        state.modelConfig.compressProviderName =
+          DEFAULT_CONFIG.modelConfig.compressProviderName;
+      }
+
       return state as any;
     },
   },