// chat.ts — chat session store (pasted line-number gutter removed)
  1. import {
  2. getMessageTextContent,
  3. isDalle3,
  4. safeLocalStorage,
  5. trimTopic,
  6. } from "../utils";
  7. import { indexedDBStorage } from "@/app/utils/indexedDB-storage";
  8. import { nanoid } from "nanoid";
  9. import type {
  10. ClientApi,
  11. MultimodalContent,
  12. RequestMessage,
  13. } from "../client/api";
  14. import { getClientApi } from "../client/api";
  15. import { ChatControllerPool } from "../client/controller";
  16. import { showToast } from "../components/ui-lib";
  17. import {
  18. DEFAULT_INPUT_TEMPLATE,
  19. DEFAULT_MODELS,
  20. DEFAULT_SYSTEM_TEMPLATE,
  21. GEMINI_SUMMARIZE_MODEL,
  22. KnowledgeCutOffDate,
  23. MCP_SYSTEM_TEMPLATE,
  24. MCP_TOOLS_TEMPLATE,
  25. ServiceProvider,
  26. StoreKey,
  27. SUMMARIZE_MODEL,
  28. } from "../constant";
  29. import Locale, { getLang } from "../locales";
  30. import { prettyObject } from "../utils/format";
  31. import { createPersistStore } from "../utils/store";
  32. import { estimateTokenLength } from "../utils/token";
  33. import { ModelConfig, ModelType, useAppConfig } from "./config";
  34. import { useAccessStore } from "./access";
  35. import { collectModelsWithDefaultModel } from "../utils/model";
  36. import { createEmptyMask, Mask } from "./mask";
  37. import { executeMcpAction, getAllTools } from "../mcp/actions";
  38. import { extractMcpJson, isMcpJson } from "../mcp/utils";
// Safe wrapper around window.localStorage (no-ops when unavailable, e.g. SSR).
const localStorage = safeLocalStorage();

// A tool invocation attached to an assistant message (e.g. an LLM function call).
export type ChatMessageTool = {
  id: string;
  index?: number;
  type?: string;
  function?: {
    name: string;
    arguments?: string;
  };
  content?: string;
  isError?: boolean;
  errorMsg?: string;
};

// A chat message as stored in a session: the request payload plus UI/state metadata.
export type ChatMessage = RequestMessage & {
  date: string; // locale-formatted creation time
  streaming?: boolean; // true while the reply is still being streamed
  isError?: boolean; // marks failed messages (excluded from context/summaries)
  id: string;
  model?: ModelType; // model that produced an assistant message
  tools?: ChatMessageTool[];
  audio_url?: string;
  isMcpResponse?: boolean; // true when this message carries an MCP tool result
};
  62. export function createMessage(override: Partial<ChatMessage>): ChatMessage {
  63. return {
  64. id: nanoid(),
  65. date: new Date().toLocaleString(),
  66. role: "user",
  67. content: "",
  68. ...override,
  69. };
  70. }
// Aggregate counters for a session's content.
export interface ChatStat {
  tokenCount: number;
  wordCount: number;
  charCount: number;
}

// A single conversation: its messages plus summarization bookkeeping.
export interface ChatSession {
  id: string;
  topic: string;
  // summarized long-term memory, injected as a system prompt when enabled
  memoryPrompt: string;
  messages: ChatMessage[];
  stat: ChatStat;
  lastUpdate: number;
  // index of the last message already covered by memoryPrompt
  lastSummarizeIndex: number;
  // messages before this index are excluded from the model context
  clearContextIndex?: number;
  mask: Mask;
}
// Localized default topic for new sessions.
export const DEFAULT_TOPIC = Locale.Store.DefaultTopic;

// Greeting message shown in a brand-new chat.
export const BOT_HELLO: ChatMessage = createMessage({
  role: "assistant",
  content: Locale.Store.BotHello,
});
  92. function createEmptySession(): ChatSession {
  93. return {
  94. id: nanoid(),
  95. topic: DEFAULT_TOPIC,
  96. memoryPrompt: "",
  97. messages: [],
  98. stat: {
  99. tokenCount: 0,
  100. wordCount: 0,
  101. charCount: 0,
  102. },
  103. lastUpdate: Date.now(),
  104. lastSummarizeIndex: 0,
  105. mask: createEmptyMask(),
  106. };
  107. }
  108. function getSummarizeModel(
  109. currentModel: string,
  110. providerName: string,
  111. ): string[] {
  112. // if it is using gpt-* models, force to use 4o-mini to summarize
  113. if (currentModel.startsWith("gpt") || currentModel.startsWith("chatgpt")) {
  114. const configStore = useAppConfig.getState();
  115. const accessStore = useAccessStore.getState();
  116. const allModel = collectModelsWithDefaultModel(
  117. configStore.models,
  118. [configStore.customModels, accessStore.customModels].join(","),
  119. accessStore.defaultModel,
  120. );
  121. const summarizeModel = allModel.find(
  122. (m) => m.name === SUMMARIZE_MODEL && m.available,
  123. );
  124. if (summarizeModel) {
  125. return [
  126. summarizeModel.name,
  127. summarizeModel.provider?.providerName as string,
  128. ];
  129. }
  130. }
  131. if (currentModel.startsWith("gemini")) {
  132. return [GEMINI_SUMMARIZE_MODEL, ServiceProvider.Google];
  133. }
  134. return [currentModel, providerName];
  135. }
  136. function countMessages(msgs: ChatMessage[]) {
  137. return msgs.reduce(
  138. (pre, cur) => pre + estimateTokenLength(getMessageTextContent(cur)),
  139. 0,
  140. );
  141. }
  142. function fillTemplateWith(input: string, modelConfig: ModelConfig) {
  143. const cutoff =
  144. KnowledgeCutOffDate[modelConfig.model] ?? KnowledgeCutOffDate.default;
  145. // Find the model in the DEFAULT_MODELS array that matches the modelConfig.model
  146. const modelInfo = DEFAULT_MODELS.find((m) => m.name === modelConfig.model);
  147. var serviceProvider = "OpenAI";
  148. if (modelInfo) {
  149. // TODO: auto detect the providerName from the modelConfig.model
  150. // Directly use the providerName from the modelInfo
  151. serviceProvider = modelInfo.provider.providerName;
  152. }
  153. const vars = {
  154. ServiceProvider: serviceProvider,
  155. cutoff,
  156. model: modelConfig.model,
  157. time: new Date().toString(),
  158. lang: getLang(),
  159. input: input,
  160. };
  161. let output = modelConfig.template ?? DEFAULT_INPUT_TEMPLATE;
  162. // remove duplicate
  163. if (input.startsWith(output)) {
  164. output = "";
  165. }
  166. // must contains {{input}}
  167. const inputVar = "{{input}}";
  168. if (!output.includes(inputVar)) {
  169. output += "\n" + inputVar;
  170. }
  171. Object.entries(vars).forEach(([name, value]) => {
  172. const regex = new RegExp(`{{${name}}}`, "g");
  173. output = output.replace(regex, value.toString()); // Ensure value is a string
  174. });
  175. return output;
  176. }
  177. async function getMcpSystemPrompt(): Promise<string> {
  178. const tools = await getAllTools();
  179. let toolsStr = "";
  180. tools.forEach((i) => {
  181. // error client has no tools
  182. if (!i.tools) return;
  183. toolsStr += MCP_TOOLS_TEMPLATE.replace(
  184. "{{ clientId }}",
  185. i.clientId,
  186. ).replace(
  187. "{{ tools }}",
  188. i.tools.tools.map((p: object) => JSON.stringify(p, null, 2)).join("\n"),
  189. );
  190. });
  191. return MCP_SYSTEM_TEMPLATE.replace("{{ MCP_TOOLS }}", toolsStr);
  192. }
// Initial store shape: one empty session, selected, with no remembered input.
const DEFAULT_CHAT_STATE = {
  sessions: [createEmptySession()],
  currentSessionIndex: 0,
  lastInput: "",
};
export const useChatStore = createPersistStore(
  DEFAULT_CHAT_STATE,
  (set, _get) => {
    // Merge the raw persisted state with the methods object so that methods
    // can call each other through get().someMethod().
    function get() {
      return {
        ..._get(),
        ...methods,
      };
    }

    const methods = {
  208. forkSession() {
  209. // 获取当前会话
  210. const currentSession = get().currentSession();
  211. if (!currentSession) return;
  212. const newSession = createEmptySession();
  213. newSession.topic = currentSession.topic;
  214. // 深拷贝消息
  215. newSession.messages = currentSession.messages.map(msg => ({
  216. ...msg,
  217. id: nanoid(), // 生成新的消息 ID
  218. }));
  219. newSession.mask = {
  220. ...currentSession.mask,
  221. modelConfig: {
  222. ...currentSession.mask.modelConfig,
  223. },
  224. };
  225. set((state) => ({
  226. currentSessionIndex: 0,
  227. sessions: [newSession, ...state.sessions],
  228. }));
  229. },
  230. clearSessions() {
  231. set(() => ({
  232. sessions: [createEmptySession()],
  233. currentSessionIndex: 0,
  234. }));
  235. },
  236. selectSession(index: number) {
  237. set({
  238. currentSessionIndex: index,
  239. });
  240. },
      // Reorder the session list (drag & drop), keeping the same session
      // selected after the move.
      moveSession(from: number, to: number) {
        set((state) => {
          const { sessions, currentSessionIndex: oldIndex } = state;

          // move the session
          const newSessions = [...sessions];
          const session = newSessions[from];
          newSessions.splice(from, 1);
          newSessions.splice(to, 0, session);

          // modify current session id: the selected index shifts by one when
          // the moved session crossed over it
          let newIndex = oldIndex === from ? to : oldIndex;
          if (oldIndex > from && oldIndex <= to) {
            newIndex -= 1;
          } else if (oldIndex < from && oldIndex >= to) {
            newIndex += 1;
          }

          return {
            currentSessionIndex: newIndex,
            sessions: newSessions,
          };
        });
      },
  262. newSession(mask?: Mask) {
  263. const session = createEmptySession();
  264. if (mask) {
  265. const config = useAppConfig.getState();
  266. const globalModelConfig = config.modelConfig;
  267. session.mask = {
  268. ...mask,
  269. modelConfig: {
  270. ...globalModelConfig,
  271. ...mask.modelConfig,
  272. },
  273. };
  274. session.topic = mask.name;
  275. }
  276. set((state) => ({
  277. currentSessionIndex: 0,
  278. sessions: [session].concat(state.sessions),
  279. }));
  280. },
  281. nextSession(delta: number) {
  282. const n = get().sessions.length;
  283. const limit = (x: number) => (x + n) % n;
  284. const i = get().currentSessionIndex;
  285. get().selectSession(limit(i + delta));
  286. },
      // Delete the session at `index`, always keeping at least one session
      // alive, and offer a 5-second undo toast that restores the old state.
      deleteSession(index: number) {
        const deletingLastSession = get().sessions.length === 1;
        const deletedSession = get().sessions.at(index);

        if (!deletedSession) return;

        const sessions = get().sessions.slice();
        sessions.splice(index, 1);

        const currentIndex = get().currentSessionIndex;
        // shift the selection left when a session before it was removed,
        // and clamp to the shortened list
        let nextIndex = Math.min(
          currentIndex - Number(index < currentIndex),
          sessions.length - 1,
        );

        if (deletingLastSession) {
          nextIndex = 0;
          sessions.push(createEmptySession());
        }

        // for undo delete action: snapshot taken BEFORE set() applies changes
        const restoreState = {
          currentSessionIndex: get().currentSessionIndex,
          sessions: get().sessions.slice(),
        };

        set(() => ({
          currentSessionIndex: nextIndex,
          sessions,
        }));

        showToast(
          Locale.Home.DeleteToast,
          {
            text: Locale.Home.Revert,
            onClick() {
              set(() => restoreState);
            },
          },
          5000,
        );
      },
  322. currentSession() {
  323. let index = get().currentSessionIndex;
  324. const sessions = get().sessions;
  325. if (index < 0 || index >= sessions.length) {
  326. index = Math.min(sessions.length - 1, Math.max(0, index));
  327. set(() => ({ currentSessionIndex: index }));
  328. }
  329. const session = sessions[index];
  330. return session;
  331. },
      // A completed message arrived: bump lastUpdate, refresh stats, execute
      // any MCP action the message contains, and maybe auto-summarize.
      onNewMessage(message: ChatMessage, targetSession: ChatSession) {
        get().updateTargetSession(targetSession, (session) => {
          // concat() clones the array so subscribers see a new reference
          session.messages = session.messages.concat();
          session.lastUpdate = Date.now();
        });

        get().updateStat(message, targetSession);

        get().checkMcpJson(message);

        get().summarizeSession(false, targetSession);
      },
      /**
       * Send user input (or an MCP tool response) to the model: builds the
       * user + placeholder bot messages, assembles the context window, fires
       * the streaming request and wires all stream callbacks.
       */
      async onUserInput(
        content: string,
        attachImages?: string[],
        isMcpResponse?: boolean,
      ) {
        const session = get().currentSession();
        const modelConfig = session.mask.modelConfig;

        // MCP Response no need to fill template
        let mContent: string | MultimodalContent[] = isMcpResponse
          ? content
          : fillTemplateWith(content, modelConfig);

        // With attached images the content becomes multimodal parts; note the
        // text part here uses the raw `content`, not the templated string.
        if (!isMcpResponse && attachImages && attachImages.length > 0) {
          mContent = [
            ...(content ? [{ type: "text" as const, text: content }] : []),
            ...attachImages.map((url) => ({
              type: "image_url" as const,
              image_url: { url },
            })),
          ];
        }

        let userMessage: ChatMessage = createMessage({
          role: "user",
          content: mContent,
          isMcpResponse,
        });

        // placeholder assistant message, filled in as the stream arrives
        const botMessage: ChatMessage = createMessage({
          role: "assistant",
          streaming: true,
          model: modelConfig.model,
        });

        // get recent messages
        const recentMessages = await get().getMessagesWithMemory();
        const sendMessages = recentMessages.concat(userMessage);
        const messageIndex = session.messages.length + 1;

        // save user's and bot's message
        get().updateTargetSession(session, (session) => {
          const savedUserMessage = {
            ...userMessage,
            content: mContent,
          };
          session.messages = session.messages.concat([
            savedUserMessage,
            botMessage,
          ]);
        });

        const api: ClientApi = getClientApi(modelConfig.providerName);
        // make request
        api.llm.chat({
          messages: sendMessages,
          config: { ...modelConfig, stream: true },
          onUpdate(message) {
            botMessage.streaming = true;
            if (message) {
              botMessage.content = message;
            }
            // re-set the messages array to trigger a re-render with the
            // partial reply (botMessage is mutated in place above)
            get().updateTargetSession(session, (session) => {
              session.messages = session.messages.concat();
            });
          },
          async onFinish(message) {
            botMessage.streaming = false;
            if (message) {
              botMessage.content = message;
              botMessage.date = new Date().toLocaleString();
              get().onNewMessage(botMessage, session);
            }
            ChatControllerPool.remove(session.id, botMessage.id);
          },
          onBeforeTool(tool: ChatMessageTool) {
            // lazily create the tools array, then append the pending call
            (botMessage.tools = botMessage?.tools || []).push(tool);
            get().updateTargetSession(session, (session) => {
              session.messages = session.messages.concat();
            });
          },
          onAfterTool(tool: ChatMessageTool) {
            // replace the matching pending tool entry with the finished one
            botMessage?.tools?.forEach((t, i, tools) => {
              if (tool.id == t.id) {
                tools[i] = { ...tool };
              }
            });
            get().updateTargetSession(session, (session) => {
              session.messages = session.messages.concat();
            });
          },
          onError(error) {
            // user-initiated aborts are not marked as errors
            const isAborted = error.message?.includes?.("aborted");
            botMessage.content +=
              "\n\n" +
              prettyObject({
                error: true,
                message: error.message,
              });
            botMessage.streaming = false;
            userMessage.isError = !isAborted;
            botMessage.isError = !isAborted;
            get().updateTargetSession(session, (session) => {
              session.messages = session.messages.concat();
            });
            ChatControllerPool.remove(
              session.id,
              botMessage.id ?? messageIndex,
            );

            console.error("[Chat] failed ", error);
          },
          onController(controller) {
            // collect controller for stop/retry
            ChatControllerPool.addController(
              session.id,
              botMessage.id ?? messageIndex,
              controller,
            );
          },
        });
      },
  455. getMemoryPrompt() {
  456. const session = get().currentSession();
  457. if (session.memoryPrompt.length) {
  458. return {
  459. role: "system",
  460. content: Locale.Store.Prompt.History(session.memoryPrompt),
  461. date: "",
  462. } as ChatMessage;
  463. }
  464. },
  465. async getMessagesWithMemory() {
  466. const session = get().currentSession();
  467. const modelConfig = session.mask.modelConfig;
  468. const clearContextIndex = session.clearContextIndex ?? 0;
  469. const messages = session.messages.slice();
  470. const totalMessageCount = session.messages.length;
  471. // in-context prompts
  472. const contextPrompts = session.mask.context.slice();
  473. // system prompts, to get close to OpenAI Web ChatGPT
  474. const shouldInjectSystemPrompts =
  475. modelConfig.enableInjectSystemPrompts &&
  476. (session.mask.modelConfig.model.startsWith("gpt-") ||
  477. session.mask.modelConfig.model.startsWith("chatgpt-"));
  478. const mcpSystemPrompt = await getMcpSystemPrompt();
  479. var systemPrompts: ChatMessage[] = [];
  480. systemPrompts = shouldInjectSystemPrompts
  481. ? [
  482. createMessage({
  483. role: "system",
  484. content:
  485. fillTemplateWith("", {
  486. ...modelConfig,
  487. template: DEFAULT_SYSTEM_TEMPLATE,
  488. }) + mcpSystemPrompt,
  489. }),
  490. ]
  491. : [
  492. createMessage({
  493. role: "system",
  494. content: mcpSystemPrompt,
  495. }),
  496. ];
  497. if (shouldInjectSystemPrompts) {
  498. console.log(
  499. "[Global System Prompt] ",
  500. systemPrompts.at(0)?.content ?? "empty",
  501. );
  502. }
  503. const memoryPrompt = get().getMemoryPrompt();
  504. // long term memory
  505. const shouldSendLongTermMemory =
  506. modelConfig.sendMemory &&
  507. session.memoryPrompt &&
  508. session.memoryPrompt.length > 0 &&
  509. session.lastSummarizeIndex > clearContextIndex;
  510. const longTermMemoryPrompts =
  511. shouldSendLongTermMemory && memoryPrompt ? [memoryPrompt] : [];
  512. const longTermMemoryStartIndex = session.lastSummarizeIndex;
  513. // short term memory
  514. const shortTermMemoryStartIndex = Math.max(
  515. 0,
  516. totalMessageCount - modelConfig.historyMessageCount,
  517. );
  518. // lets concat send messages, including 4 parts:
  519. // 0. system prompt: to get close to OpenAI Web ChatGPT
  520. // 1. long term memory: summarized memory messages
  521. // 2. pre-defined in-context prompts
  522. // 3. short term memory: latest n messages
  523. // 4. newest input message
  524. const memoryStartIndex = shouldSendLongTermMemory
  525. ? Math.min(longTermMemoryStartIndex, shortTermMemoryStartIndex)
  526. : shortTermMemoryStartIndex;
  527. // and if user has cleared history messages, we should exclude the memory too.
  528. const contextStartIndex = Math.max(clearContextIndex, memoryStartIndex);
  529. const maxTokenThreshold = modelConfig.max_tokens;
  530. // get recent messages as much as possible
  531. const reversedRecentMessages = [];
  532. for (
  533. let i = totalMessageCount - 1, tokenCount = 0;
  534. i >= contextStartIndex && tokenCount < maxTokenThreshold;
  535. i -= 1
  536. ) {
  537. const msg = messages[i];
  538. if (!msg || msg.isError) continue;
  539. tokenCount += estimateTokenLength(getMessageTextContent(msg));
  540. reversedRecentMessages.push(msg);
  541. }
  542. // concat all messages
  543. const recentMessages = [
  544. ...systemPrompts,
  545. ...longTermMemoryPrompts,
  546. ...contextPrompts,
  547. ...reversedRecentMessages.reverse(),
  548. ];
  549. return recentMessages;
  550. },
      /**
       * Mutate a message addressed by (sessionIndex, messageIndex) through
       * `updater`, then re-set the sessions array to persist and re-render.
       * NOTE(review): the updater mutates state in place; the message may be
       * undefined when either index is out of range.
       */
      updateMessage(
        sessionIndex: number,
        messageIndex: number,
        updater: (message?: ChatMessage) => void,
      ) {
        const sessions = get().sessions;
        const session = sessions.at(sessionIndex);
        const messages = session?.messages;
        updater(messages?.at(messageIndex));
        set(() => ({ sessions }));
      },
  562. resetSession(session: ChatSession) {
  563. get().updateTargetSession(session, (session) => {
  564. session.messages = [];
  565. session.memoryPrompt = "";
  566. });
  567. },
      /**
       * Maybe (a) generate a session title and (b) compress older messages
       * into the session's long-term memory prompt, using the configured
       * compress model or a derived summarize model.
       */
      summarizeSession(
        refreshTitle: boolean = false,
        targetSession: ChatSession,
      ) {
        const config = useAppConfig.getState();
        const session = targetSession;
        const modelConfig = session.mask.modelConfig;
        // skip summarize when using dalle3?
        if (isDalle3(modelConfig.model)) {
          return;
        }

        // if not config compressModel, then using getSummarizeModel
        const [model, providerName] = modelConfig.compressModel
          ? [modelConfig.compressModel, modelConfig.compressProviderName]
          : getSummarizeModel(
              session.mask.modelConfig.model,
              session.mask.modelConfig.providerName,
            );
        const api: ClientApi = getClientApi(providerName as ServiceProvider);

        // remove error messages if any
        const messages = session.messages;

        // should summarize topic after chating more than 50 words
        const SUMMARIZE_MIN_LEN = 50;
        if (
          (config.enableAutoGenerateTitle &&
            session.topic === DEFAULT_TOPIC &&
            countMessages(messages) >= SUMMARIZE_MIN_LEN) ||
          refreshTitle
        ) {
          // title generation: ask the model to name the recent conversation
          const startIndex = Math.max(
            0,
            messages.length - modelConfig.historyMessageCount,
          );
          const topicMessages = messages
            .slice(
              startIndex < messages.length ? startIndex : messages.length - 1,
              messages.length,
            )
            .concat(
              createMessage({
                role: "user",
                content: Locale.Store.Prompt.Topic,
              }),
            );
          api.llm.chat({
            messages: topicMessages,
            config: {
              model,
              stream: false,
              providerName,
            },
            onFinish(message, responseRes) {
              if (responseRes?.status === 200) {
                get().updateTargetSession(
                  session,
                  (session) =>
                    (session.topic =
                      message.length > 0 ? trimTopic(message) : DEFAULT_TOPIC),
                );
              }
            },
          });
        }

        // memory compression: summarize everything after the last summary
        // (or after a context clear), skipping error messages
        const summarizeIndex = Math.max(
          session.lastSummarizeIndex,
          session.clearContextIndex ?? 0,
        );
        let toBeSummarizedMsgs = messages
          .filter((msg) => !msg.isError)
          .slice(summarizeIndex);

        const historyMsgLength = countMessages(toBeSummarizedMsgs);
        // oversized history: keep only the latest historyMessageCount messages
        if (historyMsgLength > (modelConfig?.max_tokens || 4000)) {
          const n = toBeSummarizedMsgs.length;
          toBeSummarizedMsgs = toBeSummarizedMsgs.slice(
            Math.max(0, n - modelConfig.historyMessageCount),
          );
        }
        const memoryPrompt = get().getMemoryPrompt();
        if (memoryPrompt) {
          // add memory prompt
          toBeSummarizedMsgs.unshift(memoryPrompt);
        }

        const lastSummarizeIndex = session.messages.length;

        console.log(
          "[Chat History] ",
          toBeSummarizedMsgs,
          historyMsgLength,
          modelConfig.compressMessageLengthThreshold,
        );

        if (
          historyMsgLength > modelConfig.compressMessageLengthThreshold &&
          modelConfig.sendMemory
        ) {
          /** Destruct max_tokens while summarizing
           * this param is just shit
           **/
          const { max_tokens, ...modelcfg } = modelConfig;
          api.llm.chat({
            messages: toBeSummarizedMsgs.concat(
              createMessage({
                role: "system",
                content: Locale.Store.Prompt.Summarize,
                date: "",
              }),
            ),
            config: {
              ...modelcfg,
              stream: true,
              model,
              providerName,
            },
            onUpdate(message) {
              // streamed partial memory, shown live in the UI
              session.memoryPrompt = message;
            },
            onFinish(message, responseRes) {
              if (responseRes?.status === 200) {
                console.log("[Memory] ", message);
                get().updateTargetSession(session, (session) => {
                  session.lastSummarizeIndex = lastSummarizeIndex;
                  session.memoryPrompt = message; // Update the memory prompt for stored it in local storage
                });
              }
            },
            onError(err) {
              console.error("[Summarize] ", err);
            },
          });
        }
      },
  697. updateStat(message: ChatMessage, session: ChatSession) {
  698. get().updateTargetSession(session, (session) => {
  699. session.stat.charCount += message.content.length;
  700. // TODO: should update chat count and word count
  701. });
  702. },
  703. updateTargetSession(
  704. targetSession: ChatSession,
  705. updater: (session: ChatSession) => void,
  706. ) {
  707. const sessions = get().sessions;
  708. const index = sessions.findIndex((s) => s.id === targetSession.id);
  709. if (index < 0) return;
  710. updater(sessions[index]);
  711. set(() => ({ sessions }));
  712. },
      /** Wipe IndexedDB and localStorage, then reload the app. Irreversible. */
      async clearAllData() {
        await indexedDBStorage.clear();
        localStorage.clear();
        location.reload();
      },
  718. setLastInput(lastInput: string) {
  719. set({
  720. lastInput,
  721. });
  722. },
      /** check if the message contains MCP JSON and execute the MCP action */
      checkMcpJson(message: ChatMessage) {
        const content = getMessageTextContent(message);
        if (isMcpJson(content)) {
          try {
            const mcpRequest = extractMcpJson(content);
            if (mcpRequest) {
              console.debug("[MCP Request]", mcpRequest);

              executeMcpAction(mcpRequest.clientId, mcpRequest.mcp)
                .then((result) => {
                  console.log("[MCP Response]", result);
                  const mcpResponse =
                    typeof result === "object"
                      ? JSON.stringify(result)
                      : String(result);
                  // feed the tool result back into the chat as an MCP response
                  get().onUserInput(
                    `\`\`\`json:mcp-response:${mcpRequest.clientId}\n${mcpResponse}\n\`\`\``,
                    [],
                    true,
                  );
                })
                .catch((error) => showToast("MCP execution failed", error));
            }
          } catch (error) {
            console.error("[Check MCP JSON]", error);
          }
        }
      },
  751. };
  752. return methods;
  753. },
  {
    name: StoreKey.Chat,
    version: 3.3,

    /**
     * Migrate persisted chat state across store versions. Works on a deep
     * JSON copy of the persisted state so the original is never mutated.
     *
     * Fixes vs. original: loop-invariant `useAppConfig.getState()` calls
     * hoisted out of the per-session loops; unused `config` local removed
     * from the `< 3.3` branch.
     */
    migrate(persistedState, version) {
      const state = persistedState as any;
      const newState = JSON.parse(
        JSON.stringify(state),
      ) as typeof DEFAULT_CHAT_STATE;

      if (version < 2) {
        // rebuild sessions with the new session shape and memory defaults
        newState.sessions = [];

        const oldSessions = state.sessions;
        for (const oldSession of oldSessions) {
          const newSession = createEmptySession();
          newSession.topic = oldSession.topic;
          newSession.messages = [...oldSession.messages];
          newSession.mask.modelConfig.sendMemory = true;
          newSession.mask.modelConfig.historyMessageCount = 4;
          newSession.mask.modelConfig.compressMessageLengthThreshold = 1000;
          newState.sessions.push(newSession);
        }
      }

      if (version < 3) {
        // migrate id to nanoid
        newState.sessions.forEach((s) => {
          s.id = nanoid();
          s.messages.forEach((m) => (m.id = nanoid()));
        });
      }

      // Enable `enableInjectSystemPrompts` attribute for old sessions.
      // Resolve issue of old sessions not automatically enabling.
      if (version < 3.1) {
        const config = useAppConfig.getState();
        newState.sessions.forEach((s) => {
          if (
            // Exclude those already set by user
            !s.mask.modelConfig.hasOwnProperty("enableInjectSystemPrompts")
          ) {
            // Because users may have changed this configuration,
            // the user's current configuration is used instead of the default
            s.mask.modelConfig.enableInjectSystemPrompts =
              config.modelConfig.enableInjectSystemPrompts;
          }
        });
      }

      // add default summarize model for every session
      if (version < 3.2) {
        const config = useAppConfig.getState();
        newState.sessions.forEach((s) => {
          s.mask.modelConfig.compressModel = config.modelConfig.compressModel;
          s.mask.modelConfig.compressProviderName =
            config.modelConfig.compressProviderName;
        });
      }

      // revert default summarize model for every session
      if (version < 3.3) {
        newState.sessions.forEach((s) => {
          s.mask.modelConfig.compressModel = "";
          s.mask.modelConfig.compressProviderName = "";
        });
      }

      return newState as any;
    },
  },
  818. );