// chat.ts — chat session store (viewer line-number residue removed)
  1. import {
  2. getMessageTextContent,
  3. isDalle3,
  4. safeLocalStorage,
  5. trimTopic,
  6. } from "../utils";
  7. import { indexedDBStorage } from "@/app/utils/indexedDB-storage";
  8. import { nanoid } from "nanoid";
  9. import type {
  10. ClientApi,
  11. MultimodalContent,
  12. RequestMessage,
  13. } from "../client/api";
  14. import { getClientApi } from "../client/api";
  15. import { ChatControllerPool } from "../client/controller";
  16. import { showToast } from "../components/ui-lib";
  17. import {
  18. DEFAULT_INPUT_TEMPLATE,
  19. DEFAULT_MODELS,
  20. DEFAULT_SYSTEM_TEMPLATE,
  21. GEMINI_SUMMARIZE_MODEL,
  22. KnowledgeCutOffDate,
  23. MCP_SYSTEM_TEMPLATE,
  24. MCP_TOOLS_TEMPLATE,
  25. ServiceProvider,
  26. StoreKey,
  27. SUMMARIZE_MODEL,
  28. } from "../constant";
  29. import Locale, { getLang } from "../locales";
  30. import { prettyObject } from "../utils/format";
  31. import { createPersistStore } from "../utils/store";
  32. import { estimateTokenLength } from "../utils/token";
  33. import { ModelConfig, ModelType, useAppConfig } from "./config";
  34. import { useAccessStore } from "./access";
  35. import { collectModelsWithDefaultModel } from "../utils/model";
  36. import { createEmptyMask, Mask } from "./mask";
  37. import { executeMcpAction, getAllTools } from "../mcp/actions";
  38. import { extractMcpJson, isMcpJson } from "../mcp/utils";
// Safe wrapper around window.localStorage that degrades gracefully when
// storage is unavailable (e.g. during SSR or when blocked by the browser).
const localStorage = safeLocalStorage();
// A tool invocation attached to an assistant message (function-calling).
export type ChatMessageTool = {
  id: string;
  // Position of this tool call within the model's tool-call list.
  index?: number;
  type?: string;
  // Function name and raw (possibly partial, while streaming) JSON arguments.
  function?: {
    name: string;
    arguments?: string;
  };
  // Result returned by the tool after execution.
  content?: string;
  isError?: boolean;
  errorMsg?: string;
};
// A message as stored in a session: the wire-format RequestMessage plus
// local bookkeeping (id, timestamps, streaming/error flags, tool calls).
export type ChatMessage = RequestMessage & {
  date: string;
  // True while the assistant reply is still being streamed in.
  streaming?: boolean;
  isError?: boolean;
  id: string;
  // Model that produced this message (assistant messages only).
  model?: ModelType;
  tools?: ChatMessageTool[];
  audio_url?: string;
  // True when this message is the synthetic follow-up carrying an MCP tool result.
  isMcpResponse?: boolean;
};
  62. export function createMessage(override: Partial<ChatMessage>): ChatMessage {
  63. return {
  64. id: nanoid(),
  65. date: new Date().toLocaleString(),
  66. role: "user",
  67. content: "",
  68. ...override,
  69. };
  70. }
// Running usage counters for a session.
// NOTE(review): only charCount is actually maintained by updateStat below;
// tokenCount and wordCount stay at 0 — confirm before relying on them.
export interface ChatStat {
  tokenCount: number;
  wordCount: number;
  charCount: number;
}
// A single conversation: its messages, derived memory, and the mask
// (persona + model config) it runs under.
export interface ChatSession {
  id: string;
  topic: string;
  // Rolling summary of older messages, produced by summarizeSession.
  memoryPrompt: string;
  messages: ChatMessage[];
  stat: ChatStat;
  lastUpdate: number;
  // Index of the last message already folded into memoryPrompt.
  lastSummarizeIndex: number;
  // Messages before this index are excluded from the context window.
  clearContextIndex?: number;
  mask: Mask;
}
// Localized fallback title for sessions that have not been auto-titled yet.
export const DEFAULT_TOPIC = Locale.Store.DefaultTopic;
// Canned assistant greeting shown in a fresh chat.
export const BOT_HELLO: ChatMessage = createMessage({
  role: "assistant",
  content: Locale.Store.BotHello,
});
  92. function createEmptySession(): ChatSession {
  93. return {
  94. id: nanoid(),
  95. topic: DEFAULT_TOPIC,
  96. memoryPrompt: "",
  97. messages: [],
  98. stat: {
  99. tokenCount: 0,
  100. wordCount: 0,
  101. charCount: 0,
  102. },
  103. lastUpdate: Date.now(),
  104. lastSummarizeIndex: 0,
  105. mask: createEmptyMask(),
  106. };
  107. }
  108. function getSummarizeModel(
  109. currentModel: string,
  110. providerName: string,
  111. ): string[] {
  112. // if it is using gpt-* models, force to use 4o-mini to summarize
  113. if (currentModel.startsWith("gpt") || currentModel.startsWith("chatgpt")) {
  114. const configStore = useAppConfig.getState();
  115. const accessStore = useAccessStore.getState();
  116. const allModel = collectModelsWithDefaultModel(
  117. configStore.models,
  118. [configStore.customModels, accessStore.customModels].join(","),
  119. accessStore.defaultModel,
  120. );
  121. const summarizeModel = allModel.find(
  122. (m) => m.name === SUMMARIZE_MODEL && m.available,
  123. );
  124. if (summarizeModel) {
  125. return [
  126. summarizeModel.name,
  127. summarizeModel.provider?.providerName as string,
  128. ];
  129. }
  130. }
  131. if (currentModel.startsWith("gemini")) {
  132. return [GEMINI_SUMMARIZE_MODEL, ServiceProvider.Google];
  133. }
  134. return [currentModel, providerName];
  135. }
  136. function countMessages(msgs: ChatMessage[]) {
  137. return msgs.reduce(
  138. (pre, cur) => pre + estimateTokenLength(getMessageTextContent(cur)),
  139. 0,
  140. );
  141. }
  142. function fillTemplateWith(input: string, modelConfig: ModelConfig) {
  143. const cutoff =
  144. KnowledgeCutOffDate[modelConfig.model] ?? KnowledgeCutOffDate.default;
  145. // Find the model in the DEFAULT_MODELS array that matches the modelConfig.model
  146. const modelInfo = DEFAULT_MODELS.find((m) => m.name === modelConfig.model);
  147. var serviceProvider = "OpenAI";
  148. if (modelInfo) {
  149. // TODO: auto detect the providerName from the modelConfig.model
  150. // Directly use the providerName from the modelInfo
  151. serviceProvider = modelInfo.provider.providerName;
  152. }
  153. const vars = {
  154. ServiceProvider: serviceProvider,
  155. cutoff,
  156. model: modelConfig.model,
  157. time: new Date().toString(),
  158. lang: getLang(),
  159. input: input,
  160. };
  161. let output = modelConfig.template ?? DEFAULT_INPUT_TEMPLATE;
  162. // remove duplicate
  163. if (input.startsWith(output)) {
  164. output = "";
  165. }
  166. // must contains {{input}}
  167. const inputVar = "{{input}}";
  168. if (!output.includes(inputVar)) {
  169. output += "\n" + inputVar;
  170. }
  171. Object.entries(vars).forEach(([name, value]) => {
  172. const regex = new RegExp(`{{${name}}}`, "g");
  173. output = output.replace(regex, value.toString()); // Ensure value is a string
  174. });
  175. return output;
  176. }
  177. async function getMcpSystemPrompt(): Promise<string> {
  178. const tools = await getAllTools();
  179. let toolsStr = "";
  180. tools.forEach((i) => {
  181. // error client has no tools
  182. if (!i.tools) return;
  183. toolsStr += MCP_TOOLS_TEMPLATE.replace(
  184. "{{ clientId }}",
  185. i.clientId,
  186. ).replace(
  187. "{{ tools }}",
  188. i.tools.tools.map((p: object) => JSON.stringify(p, null, 2)).join("\n"),
  189. );
  190. });
  191. return MCP_SYSTEM_TEMPLATE.replace("{{ MCP_TOOLS }}", toolsStr);
  192. }
// Initial persisted state: one empty session, selected, with no draft input.
const DEFAULT_CHAT_STATE = {
  sessions: [createEmptySession()],
  currentSessionIndex: 0,
  lastInput: "",
};
/**
 * Persisted chat store: session list, message flow (user input → LLM
 * request → streamed bot reply), memory summarization, and MCP handling.
 * Persisted under StoreKey.Chat with versioned migrations.
 */
export const useChatStore = createPersistStore(
  DEFAULT_CHAT_STATE,
  (set, _get) => {
    // Merge raw state with the methods object so methods can call each
    // other (and read fresh state) through a single get().
    function get() {
      return {
        ..._get(),
        ...methods,
      };
    }

    const methods = {
      // Duplicate the current session (topic, messages, mask) into a new
      // session and select it.
      forkSession() {
        // Get the current session
        const currentSession = get().currentSession();
        if (!currentSession) return;

        const newSession = createEmptySession();

        newSession.topic = currentSession.topic;
        newSession.messages = [...currentSession.messages];
        newSession.mask = {
          ...currentSession.mask,
          modelConfig: {
            ...currentSession.mask.modelConfig,
          },
        };

        set((state) => ({
          currentSessionIndex: 0,
          sessions: [newSession, ...state.sessions],
        }));
      },

      // Replace all sessions with a single fresh one.
      clearSessions() {
        set(() => ({
          sessions: [createEmptySession()],
          currentSessionIndex: 0,
        }));
      },

      selectSession(index: number) {
        set({
          currentSessionIndex: index,
        });
      },

      // Reorder sessions (e.g. drag & drop) while keeping the currently
      // selected session selected.
      moveSession(from: number, to: number) {
        set((state) => {
          const { sessions, currentSessionIndex: oldIndex } = state;

          // move the session
          const newSessions = [...sessions];
          const session = newSessions[from];
          newSessions.splice(from, 1);
          newSessions.splice(to, 0, session);

          // modify current session id
          let newIndex = oldIndex === from ? to : oldIndex;
          if (oldIndex > from && oldIndex <= to) {
            newIndex -= 1;
          } else if (oldIndex < from && oldIndex >= to) {
            newIndex += 1;
          }

          return {
            currentSessionIndex: newIndex,
            sessions: newSessions,
          };
        });
      },

      // Create a session; when a mask is given, layer its modelConfig
      // over the global one and seed the topic from the mask name.
      newSession(mask?: Mask) {
        const session = createEmptySession();

        if (mask) {
          const config = useAppConfig.getState();
          const globalModelConfig = config.modelConfig;

          session.mask = {
            ...mask,
            modelConfig: {
              ...globalModelConfig,
              ...mask.modelConfig,
            },
          };
          session.topic = mask.name;
        }

        set((state) => ({
          currentSessionIndex: 0,
          sessions: [session].concat(state.sessions),
        }));
      },

      // Move the selection by delta, wrapping around both ends.
      nextSession(delta: number) {
        const n = get().sessions.length;
        const limit = (x: number) => (x + n) % n;
        const i = get().currentSessionIndex;
        get().selectSession(limit(i + delta));
      },

      // Delete a session, keep the selection sensible, and offer a 5s
      // undo toast that restores the previous state snapshot.
      deleteSession(index: number) {
        const deletingLastSession = get().sessions.length === 1;
        const deletedSession = get().sessions.at(index);

        if (!deletedSession) return;

        const sessions = get().sessions.slice();
        sessions.splice(index, 1);

        const currentIndex = get().currentSessionIndex;
        let nextIndex = Math.min(
          currentIndex - Number(index < currentIndex),
          sessions.length - 1,
        );

        // Never leave the list empty: replace the sole session.
        if (deletingLastSession) {
          nextIndex = 0;
          sessions.push(createEmptySession());
        }

        // for undo delete action
        const restoreState = {
          currentSessionIndex: get().currentSessionIndex,
          sessions: get().sessions.slice(),
        };

        set(() => ({
          currentSessionIndex: nextIndex,
          sessions,
        }));

        showToast(
          Locale.Home.DeleteToast,
          {
            text: Locale.Home.Revert,
            onClick() {
              set(() => restoreState);
            },
          },
          5000,
        );
      },

      // Return the selected session, clamping (and persisting) the index
      // if it has drifted out of range.
      currentSession() {
        let index = get().currentSessionIndex;
        const sessions = get().sessions;

        if (index < 0 || index >= sessions.length) {
          index = Math.min(sessions.length - 1, Math.max(0, index));
          set(() => ({ currentSessionIndex: index }));
        }

        const session = sessions[index];
        return session;
      },

      // Post-processing after a bot message finishes: refresh the session,
      // update stats, run any embedded MCP request, and maybe summarize.
      onNewMessage(message: ChatMessage, targetSession: ChatSession) {
        get().updateTargetSession(targetSession, (session) => {
          // concat() clones the array so subscribers see a new reference.
          session.messages = session.messages.concat();
          session.lastUpdate = Date.now();
        });

        get().updateStat(message, targetSession);

        get().checkMcpJson(message);

        get().summarizeSession(false, targetSession);
      },

      /**
       * Handle a user message (or a synthetic MCP response): build the
       * outgoing content, append user+bot placeholder messages, then
       * stream the LLM reply into the bot message via callbacks.
       */
      async onUserInput(
        content: string,
        attachImages?: string[],
        isMcpResponse?: boolean,
      ) {
        const session = get().currentSession();
        const modelConfig = session.mask.modelConfig;

        // MCP Response no need to fill template
        let mContent: string | MultimodalContent[] = isMcpResponse
          ? content
          : fillTemplateWith(content, modelConfig);

        // With attachments, switch to multimodal content parts.
        // NOTE(review): the text part uses the raw `content`, not the
        // template-filled version — confirm this is intentional.
        if (!isMcpResponse && attachImages && attachImages.length > 0) {
          mContent = [
            ...(content ? [{ type: "text" as const, text: content }] : []),
            ...attachImages.map((url) => ({
              type: "image_url" as const,
              image_url: { url },
            })),
          ];
        }

        let userMessage: ChatMessage = createMessage({
          role: "user",
          content: mContent,
          isMcpResponse,
        });

        // Placeholder the streamed reply is written into.
        const botMessage: ChatMessage = createMessage({
          role: "assistant",
          streaming: true,
          model: modelConfig.model,
        });

        // get recent messages
        const recentMessages = await get().getMessagesWithMemory();
        const sendMessages = recentMessages.concat(userMessage);
        const messageIndex = session.messages.length + 1;

        // save user's and bot's message
        get().updateTargetSession(session, (session) => {
          const savedUserMessage = {
            ...userMessage,
            content: mContent,
          };
          session.messages = session.messages.concat([
            savedUserMessage,
            botMessage,
          ]);
        });

        const api: ClientApi = getClientApi(modelConfig.providerName);
        // make request
        api.llm.chat({
          messages: sendMessages,
          config: { ...modelConfig, stream: true },
          // Streaming chunk: overwrite bot content and re-clone messages
          // so the UI re-renders.
          onUpdate(message) {
            botMessage.streaming = true;
            if (message) {
              botMessage.content = message;
            }
            get().updateTargetSession(session, (session) => {
              session.messages = session.messages.concat();
            });
          },
          async onFinish(message) {
            botMessage.streaming = false;
            if (message) {
              botMessage.content = message;
              botMessage.date = new Date().toLocaleString();
              get().onNewMessage(botMessage, session);
            }
            ChatControllerPool.remove(session.id, botMessage.id);
          },
          // Tool-call lifecycle: register the pending call, then replace
          // it in place once finished.
          onBeforeTool(tool: ChatMessageTool) {
            (botMessage.tools = botMessage?.tools || []).push(tool);
            get().updateTargetSession(session, (session) => {
              session.messages = session.messages.concat();
            });
          },
          onAfterTool(tool: ChatMessageTool) {
            botMessage?.tools?.forEach((t, i, tools) => {
              if (tool.id == t.id) {
                tools[i] = { ...tool };
              }
            });
            get().updateTargetSession(session, (session) => {
              session.messages = session.messages.concat();
            });
          },
          // Append the error to the bot message; user aborts are not
          // flagged as errors.
          onError(error) {
            const isAborted = error.message?.includes?.("aborted");
            botMessage.content +=
              "\n\n" +
              prettyObject({
                error: true,
                message: error.message,
              });
            botMessage.streaming = false;
            userMessage.isError = !isAborted;
            botMessage.isError = !isAborted;
            get().updateTargetSession(session, (session) => {
              session.messages = session.messages.concat();
            });
            ChatControllerPool.remove(
              session.id,
              botMessage.id ?? messageIndex,
            );

            console.error("[Chat] failed ", error);
          },
          onController(controller) {
            // collect controller for stop/retry
            ChatControllerPool.addController(
              session.id,
              botMessage.id ?? messageIndex,
              controller,
            );
          },
        });
      },

      // Wrap the session's long-term memory summary as a system message,
      // or undefined when there is no memory yet.
      getMemoryPrompt() {
        const session = get().currentSession();

        if (session.memoryPrompt.length) {
          return {
            role: "system",
            content: Locale.Store.Prompt.History(session.memoryPrompt),
            date: "",
          } as ChatMessage;
        }
      },

      /**
       * Assemble the context window: optional system prompt(s), long-term
       * memory summary, mask context prompts, and as many recent messages
       * as fit under the token budget.
       */
      async getMessagesWithMemory() {
        const session = get().currentSession();
        const modelConfig = session.mask.modelConfig;
        const clearContextIndex = session.clearContextIndex ?? 0;
        const messages = session.messages.slice();
        const totalMessageCount = session.messages.length;

        // in-context prompts
        const contextPrompts = session.mask.context.slice();

        // system prompts, to get close to OpenAI Web ChatGPT
        const shouldInjectSystemPrompts =
          modelConfig.enableInjectSystemPrompts &&
          (session.mask.modelConfig.model.startsWith("gpt-") ||
            session.mask.modelConfig.model.startsWith("chatgpt-"));

        const mcpSystemPrompt = await getMcpSystemPrompt();

        var systemPrompts: ChatMessage[] = [];

        systemPrompts = shouldInjectSystemPrompts
          ? [
              createMessage({
                role: "system",
                content:
                  fillTemplateWith("", {
                    ...modelConfig,
                    template: DEFAULT_SYSTEM_TEMPLATE,
                  }) + mcpSystemPrompt,
              }),
            ]
          : [
              createMessage({
                role: "system",
                content: mcpSystemPrompt,
              }),
            ];

        if (shouldInjectSystemPrompts) {
          console.log(
            "[Global System Prompt] ",
            systemPrompts.at(0)?.content ?? "empty",
          );
        }

        const memoryPrompt = get().getMemoryPrompt();
        // long term memory
        const shouldSendLongTermMemory =
          modelConfig.sendMemory &&
          session.memoryPrompt &&
          session.memoryPrompt.length > 0 &&
          session.lastSummarizeIndex > clearContextIndex;
        const longTermMemoryPrompts =
          shouldSendLongTermMemory && memoryPrompt ? [memoryPrompt] : [];
        const longTermMemoryStartIndex = session.lastSummarizeIndex;

        // short term memory
        const shortTermMemoryStartIndex = Math.max(
          0,
          totalMessageCount - modelConfig.historyMessageCount,
        );

        // lets concat send messages, including 4 parts:
        // 0. system prompt: to get close to OpenAI Web ChatGPT
        // 1. long term memory: summarized memory messages
        // 2. pre-defined in-context prompts
        // 3. short term memory: latest n messages
        // 4. newest input message
        const memoryStartIndex = shouldSendLongTermMemory
          ? Math.min(longTermMemoryStartIndex, shortTermMemoryStartIndex)
          : shortTermMemoryStartIndex;
        // and if user has cleared history messages, we should exclude the memory too.
        const contextStartIndex = Math.max(clearContextIndex, memoryStartIndex);
        const maxTokenThreshold = modelConfig.max_tokens;

        // get recent messages as much as possible, newest first, stopping
        // at the context boundary or the token budget.
        const reversedRecentMessages = [];
        for (
          let i = totalMessageCount - 1, tokenCount = 0;
          i >= contextStartIndex && tokenCount < maxTokenThreshold;
          i -= 1
        ) {
          const msg = messages[i];
          if (!msg || msg.isError) continue;
          tokenCount += estimateTokenLength(getMessageTextContent(msg));
          reversedRecentMessages.push(msg);
        }

        // concat all messages
        const recentMessages = [
          ...systemPrompts,
          ...longTermMemoryPrompts,
          ...contextPrompts,
          ...reversedRecentMessages.reverse(),
        ];

        return recentMessages;
      },

      // Mutate one message in place via the updater, then force a state
      // write so subscribers are notified.
      updateMessage(
        sessionIndex: number,
        messageIndex: number,
        updater: (message?: ChatMessage) => void,
      ) {
        const sessions = get().sessions;
        const session = sessions.at(sessionIndex);
        const messages = session?.messages;
        updater(messages?.at(messageIndex));
        set(() => ({ sessions }));
      },

      // Wipe a session's messages and memory summary.
      resetSession(session: ChatSession) {
        get().updateTargetSession(session, (session) => {
          session.messages = [];
          session.memoryPrompt = "";
        });
      },

      /**
       * Auto-generate the session title (when enabled / forced via
       * refreshTitle) and compress older messages into memoryPrompt when
       * they exceed the configured length threshold.
       */
      summarizeSession(
        refreshTitle: boolean = false,
        targetSession: ChatSession,
      ) {
        const config = useAppConfig.getState();
        const session = targetSession;
        const modelConfig = session.mask.modelConfig;
        // skip summarize when using dalle3?
        if (isDalle3(modelConfig.model)) {
          return;
        }

        // if not config compressModel, then using getSummarizeModel
        const [model, providerName] = modelConfig.compressModel
          ? [modelConfig.compressModel, modelConfig.compressProviderName]
          : getSummarizeModel(
              session.mask.modelConfig.model,
              session.mask.modelConfig.providerName,
            );
        const api: ClientApi = getClientApi(providerName as ServiceProvider);

        // remove error messages if any
        const messages = session.messages;

        // should summarize topic after chating more than 50 words
        const SUMMARIZE_MIN_LEN = 50;
        if (
          (config.enableAutoGenerateTitle &&
            session.topic === DEFAULT_TOPIC &&
            countMessages(messages) >= SUMMARIZE_MIN_LEN) ||
          refreshTitle
        ) {
          const startIndex = Math.max(
            0,
            messages.length - modelConfig.historyMessageCount,
          );
          const topicMessages = messages
            .slice(
              startIndex < messages.length ? startIndex : messages.length - 1,
              messages.length,
            )
            .concat(
              createMessage({
                role: "user",
                content: Locale.Store.Prompt.Topic,
              }),
            );
          api.llm.chat({
            messages: topicMessages,
            config: {
              model,
              stream: false,
              providerName,
            },
            onFinish(message, responseRes) {
              if (responseRes?.status === 200) {
                get().updateTargetSession(
                  session,
                  (session) =>
                    (session.topic =
                      message.length > 0 ? trimTopic(message) : DEFAULT_TOPIC),
                );
              }
            },
          });
        }

        // Messages not yet folded into memory (and after any context clear).
        const summarizeIndex = Math.max(
          session.lastSummarizeIndex,
          session.clearContextIndex ?? 0,
        );
        let toBeSummarizedMsgs = messages
          .filter((msg) => !msg.isError)
          .slice(summarizeIndex);

        const historyMsgLength = countMessages(toBeSummarizedMsgs);

        // If over budget, keep only the most recent historyMessageCount.
        if (historyMsgLength > (modelConfig?.max_tokens || 4000)) {
          const n = toBeSummarizedMsgs.length;
          toBeSummarizedMsgs = toBeSummarizedMsgs.slice(
            Math.max(0, n - modelConfig.historyMessageCount),
          );
        }
        const memoryPrompt = get().getMemoryPrompt();
        if (memoryPrompt) {
          // add memory prompt
          toBeSummarizedMsgs.unshift(memoryPrompt);
        }

        const lastSummarizeIndex = session.messages.length;

        console.log(
          "[Chat History] ",
          toBeSummarizedMsgs,
          historyMsgLength,
          modelConfig.compressMessageLengthThreshold,
        );

        if (
          historyMsgLength > modelConfig.compressMessageLengthThreshold &&
          modelConfig.sendMemory
        ) {
          /** Destruct max_tokens while summarizing
           * this param is just shit
           **/
          const { max_tokens, ...modelcfg } = modelConfig;
          api.llm.chat({
            messages: toBeSummarizedMsgs.concat(
              createMessage({
                role: "system",
                content: Locale.Store.Prompt.Summarize,
                date: "",
              }),
            ),
            config: {
              ...modelcfg,
              stream: true,
              model,
              providerName,
            },
            onUpdate(message) {
              session.memoryPrompt = message;
            },
            onFinish(message, responseRes) {
              if (responseRes?.status === 200) {
                console.log("[Memory] ", message);
                get().updateTargetSession(session, (session) => {
                  session.lastSummarizeIndex = lastSummarizeIndex;
                  session.memoryPrompt = message; // Update the memory prompt for stored it in local storage
                });
              }
            },
            onError(err) {
              console.error("[Summarize] ", err);
            },
          });
        }
      },

      updateStat(message: ChatMessage, session: ChatSession) {
        get().updateTargetSession(session, (session) => {
          // NOTE(review): content may be MultimodalContent[] — .length then
          // counts parts, not characters; confirm intended.
          session.stat.charCount += message.content.length;
          // TODO: should update chat count and word count
        });
      },

      // Find the stored session matching targetSession by id, apply the
      // updater, then write state back.
      updateTargetSession(
        targetSession: ChatSession,
        updater: (session: ChatSession) => void,
      ) {
        const sessions = get().sessions;
        const index = sessions.findIndex((s) => s.id === targetSession.id);
        if (index < 0) return;
        updater(sessions[index]);
        set(() => ({ sessions }));
      },

      // Nuke all persisted data (IndexedDB + localStorage) and reload.
      async clearAllData() {
        await indexedDBStorage.clear();
        localStorage.clear();
        location.reload();
      },

      setLastInput(lastInput: string) {
        set({
          lastInput,
        });
      },

      /** check if the message contains MCP JSON and execute the MCP action */
      checkMcpJson(message: ChatMessage) {
        const content = getMessageTextContent(message);
        if (isMcpJson(content)) {
          try {
            const mcpRequest = extractMcpJson(content);
            if (mcpRequest) {
              console.debug("[MCP Request]", mcpRequest);

              executeMcpAction(mcpRequest.clientId, mcpRequest.mcp)
                .then((result) => {
                  console.log("[MCP Response]", result);
                  const mcpResponse =
                    typeof result === "object"
                      ? JSON.stringify(result)
                      : String(result);
                  // Feed the tool result back into the chat as a fenced
                  // mcp-response block.
                  get().onUserInput(
                    `\`\`\`json:mcp-response:${mcpRequest.clientId}\n${mcpResponse}\n\`\`\``,
                    [],
                    true,
                  );
                })
                .catch((error) => showToast("MCP execution failed", error));
            }
          } catch (error) {
            console.error("[Check MCP JSON]", error);
          }
        }
      },
    };

    return methods;
  },
  {
    name: StoreKey.Chat,
    version: 3.3,
    // Upgrade persisted state from older schema versions in sequence.
    migrate(persistedState, version) {
      const state = persistedState as any;
      // Deep-clone the persisted state before mutating it.
      const newState = JSON.parse(
        JSON.stringify(state),
      ) as typeof DEFAULT_CHAT_STATE;

      if (version < 2) {
        newState.sessions = [];

        const oldSessions = state.sessions;
        for (const oldSession of oldSessions) {
          const newSession = createEmptySession();
          newSession.topic = oldSession.topic;
          newSession.messages = [...oldSession.messages];
          newSession.mask.modelConfig.sendMemory = true;
          newSession.mask.modelConfig.historyMessageCount = 4;
          newSession.mask.modelConfig.compressMessageLengthThreshold = 1000;
          newState.sessions.push(newSession);
        }
      }

      if (version < 3) {
        // migrate id to nanoid
        newState.sessions.forEach((s) => {
          s.id = nanoid();
          s.messages.forEach((m) => (m.id = nanoid()));
        });
      }

      // Enable `enableInjectSystemPrompts` attribute for old sessions.
      // Resolve issue of old sessions not automatically enabling.
      if (version < 3.1) {
        newState.sessions.forEach((s) => {
          if (
            // Exclude those already set by user
            !s.mask.modelConfig.hasOwnProperty("enableInjectSystemPrompts")
          ) {
            // Because users may have changed this configuration,
            // the user's current configuration is used instead of the default
            const config = useAppConfig.getState();
            s.mask.modelConfig.enableInjectSystemPrompts =
              config.modelConfig.enableInjectSystemPrompts;
          }
        });
      }

      // add default summarize model for every session
      if (version < 3.2) {
        newState.sessions.forEach((s) => {
          const config = useAppConfig.getState();
          s.mask.modelConfig.compressModel = config.modelConfig.compressModel;
          s.mask.modelConfig.compressProviderName =
            config.modelConfig.compressProviderName;
        });
      }
      // revert default summarize model for every session
      if (version < 3.3) {
        newState.sessions.forEach((s) => {
          // NOTE(review): `config` is read but unused here — left as-is.
          const config = useAppConfig.getState();
          s.mask.modelConfig.compressModel = "";
          s.mask.modelConfig.compressProviderName = "";
        });
      }

      return newState as any;
    },
  },
);