|
|
@@ -2,7 +2,13 @@ import { trimTopic } from "../utils";
|
|
|
|
|
|
import Locale, { getLang } from "../locales";
|
|
|
import { showToast } from "../components/ui-lib";
|
|
|
-import { ModelConfig, ModelType, useAppConfig } from "./config";
|
|
|
+import {
|
|
|
+ LLMProvider,
|
|
|
+ MaskConfig,
|
|
|
+ ModelConfig,
|
|
|
+ ModelType,
|
|
|
+ useAppConfig,
|
|
|
+} from "./config";
|
|
|
import { createEmptyMask, Mask } from "./mask";
|
|
|
import {
|
|
|
DEFAULT_INPUT_TEMPLATE,
|
|
|
@@ -10,19 +16,19 @@ import {
|
|
|
StoreKey,
|
|
|
SUMMARIZE_MODEL,
|
|
|
} from "../constant";
|
|
|
-import { api, RequestMessage } from "../client/api";
|
|
|
-import { ChatControllerPool } from "../client/controller";
|
|
|
+import { ChatControllerPool } from "../client/common/controller";
|
|
|
import { prettyObject } from "../utils/format";
|
|
|
import { estimateTokenLength } from "../utils/token";
|
|
|
import { nanoid } from "nanoid";
|
|
|
import { createPersistStore } from "../utils/store";
|
|
|
+import { RequestMessage, api } from "../client";
|
|
|
|
|
|
export type ChatMessage = RequestMessage & {
|
|
|
date: string;
|
|
|
streaming?: boolean;
|
|
|
isError?: boolean;
|
|
|
id: string;
|
|
|
- model?: ModelType;
|
|
|
+ model?: string;
|
|
|
};
|
|
|
|
|
|
export function createMessage(override: Partial<ChatMessage>): ChatMessage {
|
|
|
@@ -84,46 +90,25 @@ function getSummarizeModel(currentModel: string) {
|
|
|
return currentModel.startsWith("gpt") ? SUMMARIZE_MODEL : currentModel;
|
|
|
}
|
|
|
|
|
|
-interface ChatStore {
|
|
|
- sessions: ChatSession[];
|
|
|
- currentSessionIndex: number;
|
|
|
- clearSessions: () => void;
|
|
|
- moveSession: (from: number, to: number) => void;
|
|
|
- selectSession: (index: number) => void;
|
|
|
- newSession: (mask?: Mask) => void;
|
|
|
- deleteSession: (index: number) => void;
|
|
|
- currentSession: () => ChatSession;
|
|
|
- nextSession: (delta: number) => void;
|
|
|
- onNewMessage: (message: ChatMessage) => void;
|
|
|
- onUserInput: (content: string) => Promise<void>;
|
|
|
- summarizeSession: () => void;
|
|
|
- updateStat: (message: ChatMessage) => void;
|
|
|
- updateCurrentSession: (updater: (session: ChatSession) => void) => void;
|
|
|
- updateMessage: (
|
|
|
- sessionIndex: number,
|
|
|
- messageIndex: number,
|
|
|
- updater: (message?: ChatMessage) => void,
|
|
|
- ) => void;
|
|
|
- resetSession: () => void;
|
|
|
- getMessagesWithMemory: () => ChatMessage[];
|
|
|
- getMemoryPrompt: () => ChatMessage;
|
|
|
-
|
|
|
- clearAllData: () => void;
|
|
|
-}
|
|
|
-
|
|
|
function countMessages(msgs: ChatMessage[]) {
|
|
|
return msgs.reduce((pre, cur) => pre + estimateTokenLength(cur.content), 0);
|
|
|
}
|
|
|
|
|
|
-function fillTemplateWith(input: string, modelConfig: ModelConfig) {
|
|
|
+function fillTemplateWith(
|
|
|
+ input: string,
|
|
|
+ context: {
|
|
|
+ model: string;
|
|
|
+ template?: string;
|
|
|
+ },
|
|
|
+) {
|
|
|
const vars = {
|
|
|
- model: modelConfig.model,
|
|
|
+ model: context.model,
|
|
|
time: new Date().toLocaleString(),
|
|
|
lang: getLang(),
|
|
|
input: input,
|
|
|
};
|
|
|
|
|
|
- let output = modelConfig.template ?? DEFAULT_INPUT_TEMPLATE;
|
|
|
+ let output = context.template ?? DEFAULT_INPUT_TEMPLATE;
|
|
|
|
|
|
// must contains {{input}}
|
|
|
const inputVar = "{{input}}";
|
|
|
@@ -197,13 +182,13 @@ export const useChatStore = createPersistStore(
|
|
|
|
|
|
if (mask) {
|
|
|
const config = useAppConfig.getState();
|
|
|
- const globalModelConfig = config.modelConfig;
|
|
|
+ const globalModelConfig = config.globalMaskConfig;
|
|
|
|
|
|
session.mask = {
|
|
|
...mask,
|
|
|
- modelConfig: {
|
|
|
+ config: {
|
|
|
...globalModelConfig,
|
|
|
- ...mask.modelConfig,
|
|
|
+ ...mask.config,
|
|
|
},
|
|
|
};
|
|
|
session.topic = mask.name;
|
|
|
@@ -288,11 +273,39 @@ export const useChatStore = createPersistStore(
|
|
|
get().summarizeSession();
|
|
|
},
|
|
|
|
|
|
+ getCurrentMaskConfig() {
|
|
|
+ return get().currentSession().mask.config;
|
|
|
+ },
|
|
|
+
|
|
|
+ extractModelConfig(maskConfig: MaskConfig) {
|
|
|
+ const provider = maskConfig.provider;
|
|
|
+ if (!maskConfig.modelConfig[provider]) {
|
|
|
+ throw Error("[Chat] failed to initialize provider: " + provider);
|
|
|
+ }
|
|
|
+
|
|
|
+ return maskConfig.modelConfig[provider];
|
|
|
+ },
|
|
|
+
|
|
|
+ getCurrentModelConfig() {
|
|
|
+ const maskConfig = this.getCurrentMaskConfig();
|
|
|
+ return this.extractModelConfig(maskConfig);
|
|
|
+ },
|
|
|
+
|
|
|
+ getClient() {
|
|
|
+ const appConfig = useAppConfig.getState();
|
|
|
+ const currentMaskConfig = get().getCurrentMaskConfig();
|
|
|
+ return api.createLLMClient(appConfig.providerConfig, currentMaskConfig);
|
|
|
+ },
|
|
|
+
|
|
|
async onUserInput(content: string) {
|
|
|
const session = get().currentSession();
|
|
|
- const modelConfig = session.mask.modelConfig;
|
|
|
+ const maskConfig = this.getCurrentMaskConfig();
|
|
|
+ const modelConfig = this.getCurrentModelConfig();
|
|
|
|
|
|
- const userContent = fillTemplateWith(content, modelConfig);
|
|
|
+ const userContent = fillTemplateWith(content, {
|
|
|
+ model: modelConfig.model,
|
|
|
+ template: maskConfig.chatConfig.template,
|
|
|
+ });
|
|
|
console.log("[User Input] after template: ", userContent);
|
|
|
|
|
|
const userMessage: ChatMessage = createMessage({
|
|
|
@@ -323,10 +336,11 @@ export const useChatStore = createPersistStore(
|
|
|
]);
|
|
|
});
|
|
|
|
|
|
+ const client = this.getClient();
|
|
|
+
|
|
|
// make request
|
|
|
- api.llm.chat({
|
|
|
+ client.chatStream({
|
|
|
messages: sendMessages,
|
|
|
- config: { ...modelConfig, stream: true },
|
|
|
onUpdate(message) {
|
|
|
botMessage.streaming = true;
|
|
|
if (message) {
|
|
|
@@ -391,7 +405,9 @@ export const useChatStore = createPersistStore(
|
|
|
|
|
|
getMessagesWithMemory() {
|
|
|
const session = get().currentSession();
|
|
|
- const modelConfig = session.mask.modelConfig;
|
|
|
+ const maskConfig = this.getCurrentMaskConfig();
|
|
|
+ const chatConfig = maskConfig.chatConfig;
|
|
|
+ const modelConfig = this.getCurrentModelConfig();
|
|
|
const clearContextIndex = session.clearContextIndex ?? 0;
|
|
|
const messages = session.messages.slice();
|
|
|
const totalMessageCount = session.messages.length;
|
|
|
@@ -400,14 +416,14 @@ export const useChatStore = createPersistStore(
|
|
|
const contextPrompts = session.mask.context.slice();
|
|
|
|
|
|
// system prompts, to get close to OpenAI Web ChatGPT
|
|
|
- const shouldInjectSystemPrompts = modelConfig.enableInjectSystemPrompts;
|
|
|
+ const shouldInjectSystemPrompts = chatConfig.enableInjectSystemPrompts;
|
|
|
const systemPrompts = shouldInjectSystemPrompts
|
|
|
? [
|
|
|
createMessage({
|
|
|
role: "system",
|
|
|
content: fillTemplateWith("", {
|
|
|
- ...modelConfig,
|
|
|
- template: DEFAULT_SYSTEM_TEMPLATE,
|
|
|
+ model: modelConfig.model,
|
|
|
+ template: chatConfig.template,
|
|
|
}),
|
|
|
}),
|
|
|
]
|
|
|
@@ -421,7 +437,7 @@ export const useChatStore = createPersistStore(
|
|
|
|
|
|
// long term memory
|
|
|
const shouldSendLongTermMemory =
|
|
|
- modelConfig.sendMemory &&
|
|
|
+ chatConfig.sendMemory &&
|
|
|
session.memoryPrompt &&
|
|
|
session.memoryPrompt.length > 0 &&
|
|
|
session.lastSummarizeIndex > clearContextIndex;
|
|
|
@@ -433,7 +449,7 @@ export const useChatStore = createPersistStore(
|
|
|
// short term memory
|
|
|
const shortTermMemoryStartIndex = Math.max(
|
|
|
0,
|
|
|
- totalMessageCount - modelConfig.historyMessageCount,
|
|
|
+ totalMessageCount - chatConfig.historyMessageCount,
|
|
|
);
|
|
|
|
|
|
// lets concat send messages, including 4 parts:
|
|
|
@@ -494,6 +510,8 @@ export const useChatStore = createPersistStore(
|
|
|
|
|
|
summarizeSession() {
|
|
|
const config = useAppConfig.getState();
|
|
|
+ const maskConfig = this.getCurrentMaskConfig();
|
|
|
+ const chatConfig = maskConfig.chatConfig;
|
|
|
const session = get().currentSession();
|
|
|
|
|
|
// remove error messages if any
|
|
|
@@ -502,7 +520,7 @@ export const useChatStore = createPersistStore(
|
|
|
// should summarize topic after chating more than 50 words
|
|
|
const SUMMARIZE_MIN_LEN = 50;
|
|
|
if (
|
|
|
- config.enableAutoGenerateTitle &&
|
|
|
+ chatConfig.enableAutoGenerateTitle &&
|
|
|
session.topic === DEFAULT_TOPIC &&
|
|
|
countMessages(messages) >= SUMMARIZE_MIN_LEN
|
|
|
) {
|
|
|
@@ -512,11 +530,12 @@ export const useChatStore = createPersistStore(
|
|
|
content: Locale.Store.Prompt.Topic,
|
|
|
}),
|
|
|
);
|
|
|
- api.llm.chat({
|
|
|
+
|
|
|
+ const client = this.getClient();
|
|
|
+ client.chat({
|
|
|
messages: topicMessages,
|
|
|
- config: {
|
|
|
- model: getSummarizeModel(session.mask.modelConfig.model),
|
|
|
- },
|
|
|
+ shouldSummarize: true,
|
|
|
+
|
|
|
onFinish(message) {
|
|
|
get().updateCurrentSession(
|
|
|
(session) =>
|
|
|
@@ -527,7 +546,7 @@ export const useChatStore = createPersistStore(
|
|
|
});
|
|
|
}
|
|
|
|
|
|
- const modelConfig = session.mask.modelConfig;
|
|
|
+ const modelConfig = this.getCurrentModelConfig();
|
|
|
const summarizeIndex = Math.max(
|
|
|
session.lastSummarizeIndex,
|
|
|
session.clearContextIndex ?? 0,
|
|
|
@@ -541,7 +560,7 @@ export const useChatStore = createPersistStore(
|
|
|
if (historyMsgLength > modelConfig?.max_tokens ?? 4000) {
|
|
|
const n = toBeSummarizedMsgs.length;
|
|
|
toBeSummarizedMsgs = toBeSummarizedMsgs.slice(
|
|
|
- Math.max(0, n - modelConfig.historyMessageCount),
|
|
|
+ Math.max(0, n - chatConfig.historyMessageCount),
|
|
|
);
|
|
|
}
|
|
|
|
|
|
@@ -554,14 +573,14 @@ export const useChatStore = createPersistStore(
|
|
|
"[Chat History] ",
|
|
|
toBeSummarizedMsgs,
|
|
|
historyMsgLength,
|
|
|
- modelConfig.compressMessageLengthThreshold,
|
|
|
+ chatConfig.compressMessageLengthThreshold,
|
|
|
);
|
|
|
|
|
|
if (
|
|
|
- historyMsgLength > modelConfig.compressMessageLengthThreshold &&
|
|
|
- modelConfig.sendMemory
|
|
|
+ historyMsgLength > chatConfig.compressMessageLengthThreshold &&
|
|
|
+ chatConfig.sendMemory
|
|
|
) {
|
|
|
- api.llm.chat({
|
|
|
+ this.getClient().chatStream({
|
|
|
messages: toBeSummarizedMsgs.concat(
|
|
|
createMessage({
|
|
|
role: "system",
|
|
|
@@ -569,11 +588,7 @@ export const useChatStore = createPersistStore(
|
|
|
date: "",
|
|
|
}),
|
|
|
),
|
|
|
- config: {
|
|
|
- ...modelConfig,
|
|
|
- stream: true,
|
|
|
- model: getSummarizeModel(session.mask.modelConfig.model),
|
|
|
- },
|
|
|
+ shouldSummarize: true,
|
|
|
onUpdate(message) {
|
|
|
session.memoryPrompt = message;
|
|
|
},
|
|
|
@@ -614,52 +629,9 @@ export const useChatStore = createPersistStore(
|
|
|
name: StoreKey.Chat,
|
|
|
version: 3.1,
|
|
|
migrate(persistedState, version) {
|
|
|
- const state = persistedState as any;
|
|
|
- const newState = JSON.parse(
|
|
|
- JSON.stringify(state),
|
|
|
- ) as typeof DEFAULT_CHAT_STATE;
|
|
|
-
|
|
|
- if (version < 2) {
|
|
|
- newState.sessions = [];
|
|
|
-
|
|
|
- const oldSessions = state.sessions;
|
|
|
- for (const oldSession of oldSessions) {
|
|
|
- const newSession = createEmptySession();
|
|
|
- newSession.topic = oldSession.topic;
|
|
|
- newSession.messages = [...oldSession.messages];
|
|
|
- newSession.mask.modelConfig.sendMemory = true;
|
|
|
- newSession.mask.modelConfig.historyMessageCount = 4;
|
|
|
- newSession.mask.modelConfig.compressMessageLengthThreshold = 1000;
|
|
|
- newState.sessions.push(newSession);
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- if (version < 3) {
|
|
|
- // migrate id to nanoid
|
|
|
- newState.sessions.forEach((s) => {
|
|
|
- s.id = nanoid();
|
|
|
- s.messages.forEach((m) => (m.id = nanoid()));
|
|
|
- });
|
|
|
- }
|
|
|
-
|
|
|
- // Enable `enableInjectSystemPrompts` attribute for old sessions.
|
|
|
- // Resolve issue of old sessions not automatically enabling.
|
|
|
- if (version < 3.1) {
|
|
|
- newState.sessions.forEach((s) => {
|
|
|
- if (
|
|
|
- // Exclude those already set by user
|
|
|
- !s.mask.modelConfig.hasOwnProperty("enableInjectSystemPrompts")
|
|
|
- ) {
|
|
|
- // Because users may have changed this configuration,
|
|
|
- // the user's current configuration is used instead of the default
|
|
|
- const config = useAppConfig.getState();
|
|
|
- s.mask.modelConfig.enableInjectSystemPrompts =
|
|
|
- config.modelConfig.enableInjectSystemPrompts;
|
|
|
- }
|
|
|
- });
|
|
|
- }
|
|
|
+ // TODO(yifei): migrate from old versions
|
|
|
|
|
|
- return newState as any;
|
|
|
+ return persistedState as any;
|
|
|
},
|
|
|
},
|
|
|
);
|