// chat.ts — chat session store (sessions, messages, summarization, MCP).
import { nanoid } from "nanoid";

import type {
  ClientApi,
  MultimodalContent,
  RequestMessage,
} from "../client/api";
import { getClientApi } from "../client/api";
import { ChatControllerPool } from "../client/controller";
import { showToast } from "../components/ui-lib";
import {
  DEFAULT_INPUT_TEMPLATE,
  DEFAULT_MODELS,
  DEFAULT_SYSTEM_TEMPLATE,
  GEMINI_SUMMARIZE_MODEL,
  KnowledgeCutOffDate,
  MCP_PRIMITIVES_TEMPLATE,
  MCP_SYSTEM_TEMPLATE,
  ServiceProvider,
  StoreKey,
  SUMMARIZE_MODEL,
} from "../constant";
import Locale, { getLang } from "../locales";
import { executeMcpAction, getAllPrimitives } from "../mcp/actions";
import { extractMcpJson, isMcpJson } from "../mcp/utils";
import {
  getMessageTextContent,
  isDalle3,
  safeLocalStorage,
  trimTopic,
} from "../utils";
import { prettyObject } from "../utils/format";
import { collectModelsWithDefaultModel } from "../utils/model";
import { createPersistStore } from "../utils/store";
import { estimateTokenLength } from "../utils/token";
import { useAccessStore } from "./access";
import { ModelConfig, ModelType, useAppConfig } from "./config";
import { createEmptyMask, Mask } from "./mask";
import { indexedDBStorage } from "@/app/utils/indexedDB-storage";
  39. const localStorage = safeLocalStorage();
  40. export type ChatMessageTool = {
  41. id: string;
  42. index?: number;
  43. type?: string;
  44. function?: {
  45. name: string;
  46. arguments?: string;
  47. };
  48. content?: string;
  49. isError?: boolean;
  50. errorMsg?: string;
  51. };
  52. export type ChatMessage = RequestMessage & {
  53. date: string;
  54. streaming?: boolean;
  55. isError?: boolean;
  56. id: string;
  57. model?: ModelType;
  58. tools?: ChatMessageTool[];
  59. audio_url?: string;
  60. isMcpResponse?: boolean;
  61. };
  62. export function createMessage(override: Partial<ChatMessage>): ChatMessage {
  63. return {
  64. id: nanoid(),
  65. date: new Date().toLocaleString(),
  66. role: "user",
  67. content: "",
  68. ...override,
  69. };
  70. }
  71. export interface ChatStat {
  72. tokenCount: number;
  73. wordCount: number;
  74. charCount: number;
  75. }
  76. export interface ChatSession {
  77. id: string;
  78. topic: string;
  79. memoryPrompt: string;
  80. messages: ChatMessage[];
  81. stat: ChatStat;
  82. lastUpdate: number;
  83. lastSummarizeIndex: number;
  84. clearContextIndex?: number;
  85. mask: Mask;
  86. }
  87. export const DEFAULT_TOPIC = Locale.Store.DefaultTopic;
  88. export const BOT_HELLO: ChatMessage = createMessage({
  89. role: "assistant",
  90. content: Locale.Store.BotHello,
  91. });
  92. function createEmptySession(): ChatSession {
  93. return {
  94. id: nanoid(),
  95. topic: DEFAULT_TOPIC,
  96. memoryPrompt: "",
  97. messages: [],
  98. stat: {
  99. tokenCount: 0,
  100. wordCount: 0,
  101. charCount: 0,
  102. },
  103. lastUpdate: Date.now(),
  104. lastSummarizeIndex: 0,
  105. mask: createEmptyMask(),
  106. };
  107. }
  108. function getSummarizeModel(
  109. currentModel: string,
  110. providerName: string,
  111. ): string[] {
  112. // if it is using gpt-* models, force to use 4o-mini to summarize
  113. if (currentModel.startsWith("gpt") || currentModel.startsWith("chatgpt")) {
  114. const configStore = useAppConfig.getState();
  115. const accessStore = useAccessStore.getState();
  116. const allModel = collectModelsWithDefaultModel(
  117. configStore.models,
  118. [configStore.customModels, accessStore.customModels].join(","),
  119. accessStore.defaultModel,
  120. );
  121. const summarizeModel = allModel.find(
  122. (m) => m.name === SUMMARIZE_MODEL && m.available,
  123. );
  124. if (summarizeModel) {
  125. return [
  126. summarizeModel.name,
  127. summarizeModel.provider?.providerName as string,
  128. ];
  129. }
  130. }
  131. if (currentModel.startsWith("gemini")) {
  132. return [GEMINI_SUMMARIZE_MODEL, ServiceProvider.Google];
  133. }
  134. return [currentModel, providerName];
  135. }
  136. function countMessages(msgs: ChatMessage[]) {
  137. return msgs.reduce(
  138. (pre, cur) => pre + estimateTokenLength(getMessageTextContent(cur)),
  139. 0,
  140. );
  141. }
  142. function fillTemplateWith(input: string, modelConfig: ModelConfig) {
  143. const cutoff =
  144. KnowledgeCutOffDate[modelConfig.model] ?? KnowledgeCutOffDate.default;
  145. // Find the model in the DEFAULT_MODELS array that matches the modelConfig.model
  146. const modelInfo = DEFAULT_MODELS.find((m) => m.name === modelConfig.model);
  147. var serviceProvider = "OpenAI";
  148. if (modelInfo) {
  149. // TODO: auto detect the providerName from the modelConfig.model
  150. // Directly use the providerName from the modelInfo
  151. serviceProvider = modelInfo.provider.providerName;
  152. }
  153. const vars = {
  154. ServiceProvider: serviceProvider,
  155. cutoff,
  156. model: modelConfig.model,
  157. time: new Date().toString(),
  158. lang: getLang(),
  159. input: input,
  160. };
  161. let output = modelConfig.template ?? DEFAULT_INPUT_TEMPLATE;
  162. // remove duplicate
  163. if (input.startsWith(output)) {
  164. output = "";
  165. }
  166. // must contains {{input}}
  167. const inputVar = "{{input}}";
  168. if (!output.includes(inputVar)) {
  169. output += "\n" + inputVar;
  170. }
  171. Object.entries(vars).forEach(([name, value]) => {
  172. const regex = new RegExp(`{{${name}}}`, "g");
  173. output = output.replace(regex, value.toString()); // Ensure value is a string
  174. });
  175. return output;
  176. }
  177. async function getMcpSystemPrompt(): Promise<string> {
  178. let primitives = await getAllPrimitives();
  179. primitives = primitives.filter((i) =>
  180. i.primitives.some((p) => p.type === "tool"),
  181. );
  182. let primitivesString = "";
  183. primitives.forEach((i) => {
  184. primitivesString += MCP_PRIMITIVES_TEMPLATE.replace(
  185. "{{ clientId }}",
  186. i.clientId,
  187. ).replace(
  188. "{{ primitives }}",
  189. i.primitives.map((p) => JSON.stringify(p, null, 2)).join("\n"),
  190. );
  191. });
  192. return MCP_SYSTEM_TEMPLATE.replace("{{ MCP_PRIMITIVES }}", primitivesString);
  193. }
  194. const DEFAULT_CHAT_STATE = {
  195. sessions: [createEmptySession()],
  196. currentSessionIndex: 0,
  197. lastInput: "",
  198. };
  199. export const useChatStore = createPersistStore(
  200. DEFAULT_CHAT_STATE,
  201. (set, _get) => {
  202. function get() {
  203. return {
  204. ..._get(),
  205. ...methods,
  206. };
  207. }
  208. const methods = {
  209. forkSession() {
  210. // 获取当前会话
  211. const currentSession = get().currentSession();
  212. if (!currentSession) return;
  213. const newSession = createEmptySession();
  214. newSession.topic = currentSession.topic;
  215. newSession.messages = [...currentSession.messages];
  216. newSession.mask = {
  217. ...currentSession.mask,
  218. modelConfig: {
  219. ...currentSession.mask.modelConfig,
  220. },
  221. };
  222. set((state) => ({
  223. currentSessionIndex: 0,
  224. sessions: [newSession, ...state.sessions],
  225. }));
  226. },
  227. clearSessions() {
  228. set(() => ({
  229. sessions: [createEmptySession()],
  230. currentSessionIndex: 0,
  231. }));
  232. },
  233. selectSession(index: number) {
  234. set({
  235. currentSessionIndex: index,
  236. });
  237. },
  238. moveSession(from: number, to: number) {
  239. set((state) => {
  240. const { sessions, currentSessionIndex: oldIndex } = state;
  241. // move the session
  242. const newSessions = [...sessions];
  243. const session = newSessions[from];
  244. newSessions.splice(from, 1);
  245. newSessions.splice(to, 0, session);
  246. // modify current session id
  247. let newIndex = oldIndex === from ? to : oldIndex;
  248. if (oldIndex > from && oldIndex <= to) {
  249. newIndex -= 1;
  250. } else if (oldIndex < from && oldIndex >= to) {
  251. newIndex += 1;
  252. }
  253. return {
  254. currentSessionIndex: newIndex,
  255. sessions: newSessions,
  256. };
  257. });
  258. },
  259. newSession(mask?: Mask) {
  260. const session = createEmptySession();
  261. if (mask) {
  262. const config = useAppConfig.getState();
  263. const globalModelConfig = config.modelConfig;
  264. session.mask = {
  265. ...mask,
  266. modelConfig: {
  267. ...globalModelConfig,
  268. ...mask.modelConfig,
  269. },
  270. };
  271. session.topic = mask.name;
  272. }
  273. set((state) => ({
  274. currentSessionIndex: 0,
  275. sessions: [session].concat(state.sessions),
  276. }));
  277. },
  278. nextSession(delta: number) {
  279. const n = get().sessions.length;
  280. const limit = (x: number) => (x + n) % n;
  281. const i = get().currentSessionIndex;
  282. get().selectSession(limit(i + delta));
  283. },
  284. deleteSession(index: number) {
  285. const deletingLastSession = get().sessions.length === 1;
  286. const deletedSession = get().sessions.at(index);
  287. if (!deletedSession) return;
  288. const sessions = get().sessions.slice();
  289. sessions.splice(index, 1);
  290. const currentIndex = get().currentSessionIndex;
  291. let nextIndex = Math.min(
  292. currentIndex - Number(index < currentIndex),
  293. sessions.length - 1,
  294. );
  295. if (deletingLastSession) {
  296. nextIndex = 0;
  297. sessions.push(createEmptySession());
  298. }
  299. // for undo delete action
  300. const restoreState = {
  301. currentSessionIndex: get().currentSessionIndex,
  302. sessions: get().sessions.slice(),
  303. };
  304. set(() => ({
  305. currentSessionIndex: nextIndex,
  306. sessions,
  307. }));
  308. showToast(
  309. Locale.Home.DeleteToast,
  310. {
  311. text: Locale.Home.Revert,
  312. onClick() {
  313. set(() => restoreState);
  314. },
  315. },
  316. 5000,
  317. );
  318. },
  319. currentSession() {
  320. let index = get().currentSessionIndex;
  321. const sessions = get().sessions;
  322. if (index < 0 || index >= sessions.length) {
  323. index = Math.min(sessions.length - 1, Math.max(0, index));
  324. set(() => ({ currentSessionIndex: index }));
  325. }
  326. const session = sessions[index];
  327. return session;
  328. },
  329. onNewMessage(message: ChatMessage, targetSession: ChatSession) {
  330. get().updateTargetSession(targetSession, (session) => {
  331. session.messages = session.messages.concat();
  332. session.lastUpdate = Date.now();
  333. });
  334. get().updateStat(message, targetSession);
  335. get().checkMcpJson(message);
  336. get().summarizeSession(false, targetSession);
  337. },
    /**
     * Submit user input: build the user message and a streaming placeholder
     * bot message, assemble the context window, and fire the LLM request.
     *
     * @param content       raw user text (or a pre-formatted MCP response)
     * @param attachImages  optional image URLs to send as multimodal content
     * @param isMcpResponse when true, skip template filling entirely
     */
    async onUserInput(
      content: string,
      attachImages?: string[],
      isMcpResponse?: boolean,
    ) {
      const session = get().currentSession();
      const modelConfig = session.mask.modelConfig;

      // MCP Response no need to fill template
      let mContent: string | MultimodalContent[] = isMcpResponse
        ? content
        : fillTemplateWith(content, modelConfig);

      // With attachments, send multimodal content instead of plain text.
      // NOTE(review): this branch uses the raw `content` for the text part,
      // so the input template is NOT applied when images are attached —
      // confirm this is intended.
      if (!isMcpResponse && attachImages && attachImages.length > 0) {
        mContent = [
          ...(content ? [{ type: "text" as const, text: content }] : []),
          ...attachImages.map((url) => ({
            type: "image_url" as const,
            image_url: { url },
          })),
        ];
      }

      let userMessage: ChatMessage = createMessage({
        role: "user",
        content: mContent,
        isMcpResponse,
      });

      // Placeholder assistant message that streams content in place.
      const botMessage: ChatMessage = createMessage({
        role: "assistant",
        streaming: true,
        model: modelConfig.model,
      });

      // get recent messages
      const recentMessages = await get().getMessagesWithMemory();
      const sendMessages = recentMessages.concat(userMessage);
      const messageIndex = session.messages.length + 1;

      // save user's and bot's message
      get().updateTargetSession(session, (session) => {
        const savedUserMessage = {
          ...userMessage,
          content: mContent,
        };
        session.messages = session.messages.concat([
          savedUserMessage,
          botMessage,
        ]);
      });

      const api: ClientApi = getClientApi(modelConfig.providerName);
      // make request
      api.llm.chat({
        messages: sendMessages,
        config: { ...modelConfig, stream: true },
        onUpdate(message) {
          botMessage.streaming = true;
          if (message) {
            botMessage.content = message;
          }
          // Republish with a new array identity to trigger re-render.
          get().updateTargetSession(session, (session) => {
            session.messages = session.messages.concat();
          });
        },
        async onFinish(message) {
          botMessage.streaming = false;
          if (message) {
            botMessage.content = message;
            botMessage.date = new Date().toLocaleString();
            get().onNewMessage(botMessage, session);
          }
          ChatControllerPool.remove(session.id, botMessage.id);
        },
        onBeforeTool(tool: ChatMessageTool) {
          // Append the pending tool call to the streaming bot message.
          (botMessage.tools = botMessage?.tools || []).push(tool);
          get().updateTargetSession(session, (session) => {
            session.messages = session.messages.concat();
          });
        },
        onAfterTool(tool: ChatMessageTool) {
          // Replace the matching pending tool entry with its final result.
          botMessage?.tools?.forEach((t, i, tools) => {
            if (tool.id == t.id) {
              tools[i] = { ...tool };
            }
          });
          get().updateTargetSession(session, (session) => {
            session.messages = session.messages.concat();
          });
        },
        onError(error) {
          // User-initiated aborts are not flagged as errors.
          const isAborted = error.message?.includes?.("aborted");
          botMessage.content +=
            "\n\n" +
            prettyObject({
              error: true,
              message: error.message,
            });
          botMessage.streaming = false;
          userMessage.isError = !isAborted;
          botMessage.isError = !isAborted;
          get().updateTargetSession(session, (session) => {
            session.messages = session.messages.concat();
          });
          ChatControllerPool.remove(
            session.id,
            botMessage.id ?? messageIndex,
          );
          console.error("[Chat] failed ", error);
        },
        onController(controller) {
          // collect controller for stop/retry
          ChatControllerPool.addController(
            session.id,
            botMessage.id ?? messageIndex,
            controller,
          );
        },
      });
    },
  452. getMemoryPrompt() {
  453. const session = get().currentSession();
  454. if (session.memoryPrompt.length) {
  455. return {
  456. role: "system",
  457. content: Locale.Store.Prompt.History(session.memoryPrompt),
  458. date: "",
  459. } as ChatMessage;
  460. }
  461. },
    /**
     * Assemble the context window sent to the model:
     * system prompts (+ MCP tool prompt), long-term memory, the mask's
     * in-context prompts, and the most recent messages that fit the budget.
     */
    async getMessagesWithMemory() {
      const session = get().currentSession();
      const modelConfig = session.mask.modelConfig;
      const clearContextIndex = session.clearContextIndex ?? 0;
      const messages = session.messages.slice();
      const totalMessageCount = session.messages.length;

      // in-context prompts
      const contextPrompts = session.mask.context.slice();

      // system prompts, to get close to OpenAI Web ChatGPT
      // (only injected for GPT-family models)
      const shouldInjectSystemPrompts =
        modelConfig.enableInjectSystemPrompts &&
        (session.mask.modelConfig.model.startsWith("gpt-") ||
          session.mask.modelConfig.model.startsWith("chatgpt-"));

      // MCP tool descriptions are always appended to the system prompt.
      const mcpSystemPrompt = await getMcpSystemPrompt();

      var systemPrompts: ChatMessage[] = [];

      systemPrompts = shouldInjectSystemPrompts
        ? [
            createMessage({
              role: "system",
              content:
                fillTemplateWith("", {
                  ...modelConfig,
                  template: DEFAULT_SYSTEM_TEMPLATE,
                }) + mcpSystemPrompt,
            }),
          ]
        : [
            createMessage({
              role: "system",
              content: mcpSystemPrompt,
            }),
          ];

      if (shouldInjectSystemPrompts) {
        console.log(
          "[Global System Prompt] ",
          systemPrompts.at(0)?.content ?? "empty",
        );
      }

      const memoryPrompt = get().getMemoryPrompt();

      // long term memory: only when enabled, present, and not cleared away.
      const shouldSendLongTermMemory =
        modelConfig.sendMemory &&
        session.memoryPrompt &&
        session.memoryPrompt.length > 0 &&
        session.lastSummarizeIndex > clearContextIndex;
      const longTermMemoryPrompts =
        shouldSendLongTermMemory && memoryPrompt ? [memoryPrompt] : [];
      const longTermMemoryStartIndex = session.lastSummarizeIndex;

      // short term memory: at most historyMessageCount trailing messages.
      const shortTermMemoryStartIndex = Math.max(
        0,
        totalMessageCount - modelConfig.historyMessageCount,
      );

      // lets concat send messages, including 4 parts:
      // 0. system prompt: to get close to OpenAI Web ChatGPT
      // 1. long term memory: summarized memory messages
      // 2. pre-defined in-context prompts
      // 3. short term memory: latest n messages
      // 4. newest input message
      const memoryStartIndex = shouldSendLongTermMemory
        ? Math.min(longTermMemoryStartIndex, shortTermMemoryStartIndex)
        : shortTermMemoryStartIndex;
      // and if user has cleared history messages, we should exclude the memory too.
      const contextStartIndex = Math.max(clearContextIndex, memoryStartIndex);
      const maxTokenThreshold = modelConfig.max_tokens;

      // get recent messages as much as possible, walking backwards until
      // the token budget is exhausted; error messages are skipped.
      const reversedRecentMessages = [];
      for (
        let i = totalMessageCount - 1, tokenCount = 0;
        i >= contextStartIndex && tokenCount < maxTokenThreshold;
        i -= 1
      ) {
        const msg = messages[i];
        if (!msg || msg.isError) continue;
        tokenCount += estimateTokenLength(getMessageTextContent(msg));
        reversedRecentMessages.push(msg);
      }

      // concat all messages (recent ones were collected newest-first).
      const recentMessages = [
        ...systemPrompts,
        ...longTermMemoryPrompts,
        ...contextPrompts,
        ...reversedRecentMessages.reverse(),
      ];

      return recentMessages;
    },
  548. updateMessage(
  549. sessionIndex: number,
  550. messageIndex: number,
  551. updater: (message?: ChatMessage) => void,
  552. ) {
  553. const sessions = get().sessions;
  554. const session = sessions.at(sessionIndex);
  555. const messages = session?.messages;
  556. updater(messages?.at(messageIndex));
  557. set(() => ({ sessions }));
  558. },
  559. resetSession(session: ChatSession) {
  560. get().updateTargetSession(session, (session) => {
  561. session.messages = [];
  562. session.memoryPrompt = "";
  563. });
  564. },
    /**
     * Two summarization passes over a session:
     * 1. auto-generate a topic/title once the chat exceeds ~50 tokens
     *    (or when `refreshTitle` forces it);
     * 2. compress older messages into `memoryPrompt` once they exceed the
     *    configured length threshold.
     */
    summarizeSession(
      refreshTitle: boolean = false,
      targetSession: ChatSession,
    ) {
      const config = useAppConfig.getState();
      const session = targetSession;
      const modelConfig = session.mask.modelConfig;
      // skip summarize when using dalle3?
      if (isDalle3(modelConfig.model)) {
        return;
      }

      // if not config compressModel, then using getSummarizeModel
      const [model, providerName] = modelConfig.compressModel
        ? [modelConfig.compressModel, modelConfig.compressProviderName]
        : getSummarizeModel(
            session.mask.modelConfig.model,
            session.mask.modelConfig.providerName,
          );
      const api: ClientApi = getClientApi(providerName as ServiceProvider);

      // remove error messages if any
      const messages = session.messages;

      // should summarize topic after chating more than 50 words
      const SUMMARIZE_MIN_LEN = 50;
      if (
        (config.enableAutoGenerateTitle &&
          session.topic === DEFAULT_TOPIC &&
          countMessages(messages) >= SUMMARIZE_MIN_LEN) ||
        refreshTitle
      ) {
        // Use only the most recent historyMessageCount messages plus the
        // topic-generation instruction.
        const startIndex = Math.max(
          0,
          messages.length - modelConfig.historyMessageCount,
        );
        const topicMessages = messages
          .slice(
            startIndex < messages.length ? startIndex : messages.length - 1,
            messages.length,
          )
          .concat(
            createMessage({
              role: "user",
              content: Locale.Store.Prompt.Topic,
            }),
          );
        api.llm.chat({
          messages: topicMessages,
          config: {
            model,
            stream: false,
            providerName,
          },
          onFinish(message, responseRes) {
            if (responseRes?.status === 200) {
              get().updateTargetSession(
                session,
                (session) =>
                  (session.topic =
                    message.length > 0 ? trimTopic(message) : DEFAULT_TOPIC),
              );
            }
          },
        });
      }

      // Memory compression: only messages after the last summarize point
      // (and after any context-clear point) are candidates.
      const summarizeIndex = Math.max(
        session.lastSummarizeIndex,
        session.clearContextIndex ?? 0,
      );
      let toBeSummarizedMsgs = messages
        .filter((msg) => !msg.isError)
        .slice(summarizeIndex);

      const historyMsgLength = countMessages(toBeSummarizedMsgs);

      // If the candidate slice is too large, keep only the recent tail.
      if (historyMsgLength > (modelConfig?.max_tokens || 4000)) {
        const n = toBeSummarizedMsgs.length;
        toBeSummarizedMsgs = toBeSummarizedMsgs.slice(
          Math.max(0, n - modelConfig.historyMessageCount),
        );
      }
      const memoryPrompt = get().getMemoryPrompt();
      if (memoryPrompt) {
        // add memory prompt
        toBeSummarizedMsgs.unshift(memoryPrompt);
      }

      const lastSummarizeIndex = session.messages.length;

      console.log(
        "[Chat History] ",
        toBeSummarizedMsgs,
        historyMsgLength,
        modelConfig.compressMessageLengthThreshold,
      );

      if (
        historyMsgLength > modelConfig.compressMessageLengthThreshold &&
        modelConfig.sendMemory
      ) {
        /** Destruct max_tokens while summarizing
         * this param is just shit
         **/
        const { max_tokens, ...modelcfg } = modelConfig;
        api.llm.chat({
          messages: toBeSummarizedMsgs.concat(
            createMessage({
              role: "system",
              content: Locale.Store.Prompt.Summarize,
              date: "",
            }),
          ),
          config: {
            ...modelcfg,
            stream: true,
            model,
            providerName,
          },
          onUpdate(message) {
            session.memoryPrompt = message;
          },
          onFinish(message, responseRes) {
            if (responseRes?.status === 200) {
              console.log("[Memory] ", message);
              get().updateTargetSession(session, (session) => {
                session.lastSummarizeIndex = lastSummarizeIndex;
                session.memoryPrompt = message; // Update the memory prompt for stored it in local storage
              });
            }
          },
          onError(err) {
            console.error("[Summarize] ", err);
          },
        });
      }
    },
  694. updateStat(message: ChatMessage, session: ChatSession) {
  695. get().updateTargetSession(session, (session) => {
  696. session.stat.charCount += message.content.length;
  697. // TODO: should update chat count and word count
  698. });
  699. },
  700. updateTargetSession(
  701. targetSession: ChatSession,
  702. updater: (session: ChatSession) => void,
  703. ) {
  704. const sessions = get().sessions;
  705. const index = sessions.findIndex((s) => s.id === targetSession.id);
  706. if (index < 0) return;
  707. updater(sessions[index]);
  708. set(() => ({ sessions }));
  709. },
    /** Irreversibly wipe all persisted data (IndexedDB then localStorage) and reload the page. */
    async clearAllData() {
      await indexedDBStorage.clear();
      localStorage.clear();
      location.reload();
    },
  715. setLastInput(lastInput: string) {
  716. set({
  717. lastInput,
  718. });
  719. },
  720. /** check if the message contains MCP JSON and execute the MCP action */
  721. checkMcpJson(message: ChatMessage) {
  722. const content = getMessageTextContent(message);
  723. if (isMcpJson(content)) {
  724. try {
  725. const mcpRequest = extractMcpJson(content);
  726. if (mcpRequest) {
  727. console.debug("[MCP Request]", mcpRequest);
  728. executeMcpAction(mcpRequest.clientId, mcpRequest.mcp)
  729. .then((result) => {
  730. console.log("[MCP Response]", result);
  731. const mcpResponse =
  732. typeof result === "object"
  733. ? JSON.stringify(result)
  734. : String(result);
  735. get().onUserInput(
  736. `\`\`\`json:mcp-response:${mcpRequest.clientId}\n${mcpResponse}\n\`\`\``,
  737. [],
  738. true,
  739. );
  740. })
  741. .catch((error) => showToast("MCP execution failed", error));
  742. }
  743. } catch (error) {
  744. console.error("[Check MCP JSON]", error);
  745. }
  746. }
  747. },
  748. };
  749. return methods;
  750. },
  751. {
  752. name: StoreKey.Chat,
  753. version: 3.3,
  754. migrate(persistedState, version) {
  755. const state = persistedState as any;
  756. const newState = JSON.parse(
  757. JSON.stringify(state),
  758. ) as typeof DEFAULT_CHAT_STATE;
  759. if (version < 2) {
  760. newState.sessions = [];
  761. const oldSessions = state.sessions;
  762. for (const oldSession of oldSessions) {
  763. const newSession = createEmptySession();
  764. newSession.topic = oldSession.topic;
  765. newSession.messages = [...oldSession.messages];
  766. newSession.mask.modelConfig.sendMemory = true;
  767. newSession.mask.modelConfig.historyMessageCount = 4;
  768. newSession.mask.modelConfig.compressMessageLengthThreshold = 1000;
  769. newState.sessions.push(newSession);
  770. }
  771. }
  772. if (version < 3) {
  773. // migrate id to nanoid
  774. newState.sessions.forEach((s) => {
  775. s.id = nanoid();
  776. s.messages.forEach((m) => (m.id = nanoid()));
  777. });
  778. }
  779. // Enable `enableInjectSystemPrompts` attribute for old sessions.
  780. // Resolve issue of old sessions not automatically enabling.
  781. if (version < 3.1) {
  782. newState.sessions.forEach((s) => {
  783. if (
  784. // Exclude those already set by user
  785. !s.mask.modelConfig.hasOwnProperty("enableInjectSystemPrompts")
  786. ) {
  787. // Because users may have changed this configuration,
  788. // the user's current configuration is used instead of the default
  789. const config = useAppConfig.getState();
  790. s.mask.modelConfig.enableInjectSystemPrompts =
  791. config.modelConfig.enableInjectSystemPrompts;
  792. }
  793. });
  794. }
  795. // add default summarize model for every session
  796. if (version < 3.2) {
  797. newState.sessions.forEach((s) => {
  798. const config = useAppConfig.getState();
  799. s.mask.modelConfig.compressModel = config.modelConfig.compressModel;
  800. s.mask.modelConfig.compressProviderName =
  801. config.modelConfig.compressProviderName;
  802. });
  803. }
  804. // revert default summarize model for every session
  805. if (version < 3.3) {
  806. newState.sessions.forEach((s) => {
  807. const config = useAppConfig.getState();
  808. s.mask.modelConfig.compressModel = "";
  809. s.mask.modelConfig.compressProviderName = "";
  810. });
  811. }
  812. return newState as any;
  813. },
  814. },
  815. );