// chat.ts
import { getMessageTextContent, trimTopic } from "../utils";
import { indexedDBStorage } from "@/app/utils/indexedDB-storage";
import { nanoid } from "nanoid";
import type {
  ClientApi,
  MultimodalContent,
  RequestMessage,
} from "../client/api";
import { getClientApi } from "../client/api";
import { ChatControllerPool } from "../client/controller";
import { showToast } from "../components/ui-lib";
import {
  DEFAULT_INPUT_TEMPLATE,
  DEFAULT_MODELS,
  DEFAULT_SYSTEM_TEMPLATE,
  KnowledgeCutOffDate,
  StoreKey,
  SUMMARIZE_MODEL,
  GEMINI_SUMMARIZE_MODEL,
  ServiceProvider,
} from "../constant";
import Locale, { getLang } from "../locales";
import { isDalle3, safeLocalStorage } from "../utils";
import { prettyObject } from "../utils/format";
import { createPersistStore } from "../utils/store";
import { estimateTokenLength } from "../utils/token";
import { ModelConfig, ModelType, useAppConfig } from "./config";
import { useAccessStore } from "./access";
import { collectModelsWithDefaultModel } from "../utils/model";
import { createEmptyMask, Mask } from "./mask";
import { executeMcpAction } from "../mcp/actions";

const localStorage = safeLocalStorage();

export type ChatMessageTool = {
  id: string;
  index?: number;
  type?: string;
  function?: {
    name: string;
    arguments?: string;
  };
  content?: string;
  isError?: boolean;
  errorMsg?: string;
};

export type ChatMessage = RequestMessage & {
  date: string;
  streaming?: boolean;
  isError?: boolean;
  id: string;
  model?: ModelType;
  tools?: ChatMessageTool[];
  audio_url?: string;
};

export function createMessage(override: Partial<ChatMessage>): ChatMessage {
  return {
    id: nanoid(),
    date: new Date().toLocaleString(),
    role: "user",
    content: "",
    ...override,
  };
}

export interface ChatStat {
  tokenCount: number;
  wordCount: number;
  charCount: number;
}

export interface ChatSession {
  id: string;
  topic: string;
  memoryPrompt: string;
  messages: ChatMessage[];
  stat: ChatStat;
  lastUpdate: number;
  lastSummarizeIndex: number;
  clearContextIndex?: number;
  mask: Mask;
}

export const DEFAULT_TOPIC = Locale.Store.DefaultTopic;
export const BOT_HELLO: ChatMessage = createMessage({
  role: "assistant",
  content: Locale.Store.BotHello,
});

function createEmptySession(): ChatSession {
  return {
    id: nanoid(),
    topic: DEFAULT_TOPIC,
    memoryPrompt: "",
    messages: [],
    stat: {
      tokenCount: 0,
      wordCount: 0,
      charCount: 0,
    },
    lastUpdate: Date.now(),
    lastSummarizeIndex: 0,
    mask: createEmptyMask(),
  };
}
function getSummarizeModel(
  currentModel: string,
  providerName: string,
): string[] {
  // if a gpt-* / chatgpt-* model is in use, prefer the dedicated
  // summarize model (SUMMARIZE_MODEL, e.g. gpt-4o-mini) when it is available
  if (currentModel.startsWith("gpt") || currentModel.startsWith("chatgpt")) {
    const configStore = useAppConfig.getState();
    const accessStore = useAccessStore.getState();
    const allModel = collectModelsWithDefaultModel(
      configStore.models,
      [configStore.customModels, accessStore.customModels].join(","),
      accessStore.defaultModel,
    );
    const summarizeModel = allModel.find(
      (m) => m.name === SUMMARIZE_MODEL && m.available,
    );
    if (summarizeModel) {
      return [
        summarizeModel.name,
        summarizeModel.provider?.providerName as string,
      ];
    }
  }
  if (currentModel.startsWith("gemini")) {
    return [GEMINI_SUMMARIZE_MODEL, ServiceProvider.Google];
  }
  return [currentModel, providerName];
}
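
// Illustrative routing (hypothetical model names; the actual summarize models
// come from SUMMARIZE_MODEL / GEMINI_SUMMARIZE_MODEL in ../constant):
//   getSummarizeModel("gpt-4", "OpenAI")       -> [SUMMARIZE_MODEL, <its provider>]
//   getSummarizeModel("gemini-pro", "Google")  -> [GEMINI_SUMMARIZE_MODEL, ServiceProvider.Google]
//   getSummarizeModel("claude-3", "Anthropic") -> ["claude-3", "Anthropic"]  (passed through)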
function countMessages(msgs: ChatMessage[]) {
  return msgs.reduce(
    (pre, cur) => pre + estimateTokenLength(getMessageTextContent(cur)),
    0,
  );
}
function fillTemplateWith(input: string, modelConfig: ModelConfig) {
  const cutoff =
    KnowledgeCutOffDate[modelConfig.model] ?? KnowledgeCutOffDate.default;
  // find the entry in DEFAULT_MODELS that matches modelConfig.model
  const modelInfo = DEFAULT_MODELS.find((m) => m.name === modelConfig.model);

  let serviceProvider = "OpenAI";
  if (modelInfo) {
    // TODO: auto detect the providerName from modelConfig.model
    // for now, use the providerName from the matched modelInfo directly
    serviceProvider = modelInfo.provider.providerName;
  }

  const vars = {
    ServiceProvider: serviceProvider,
    cutoff,
    model: modelConfig.model,
    time: new Date().toString(),
    lang: getLang(),
    input: input,
  };

  let output = modelConfig.template ?? DEFAULT_INPUT_TEMPLATE;

  // remove the duplicated template prefix if the input already starts with it
  if (input.startsWith(output)) {
    output = "";
  }

  // the template must contain {{input}}
  const inputVar = "{{input}}";
  if (!output.includes(inputVar)) {
    output += "\n" + inputVar;
  }

  Object.entries(vars).forEach(([name, value]) => {
    const regex = new RegExp(`{{${name}}}`, "g");
    output = output.replace(regex, value.toString()); // ensure value is a string
  });

  return output;
}
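
// Illustrative expansion (hypothetical template and values; the real default
// template is DEFAULT_INPUT_TEMPLATE from ../constant):
//   template: "Model: {{model}}\nCutoff: {{cutoff}}\n{{input}}"
//   input:    "Hello"
//   output:   "Model: gpt-4o-mini\nCutoff: 2023-10\nHello"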
const DEFAULT_CHAT_STATE = {
  sessions: [createEmptySession()],
  currentSessionIndex: 0,
  lastInput: "",
};

export const useChatStore = createPersistStore(
  DEFAULT_CHAT_STATE,
  (set, _get) => {
    function get() {
      return {
        ..._get(),
        ...methods,
      };
    }

    const methods = {
      forkSession() {
        // get the current session
        const currentSession = get().currentSession();
        if (!currentSession) return;

        const newSession = createEmptySession();
        newSession.topic = currentSession.topic;
        newSession.messages = [...currentSession.messages];
        newSession.mask = {
          ...currentSession.mask,
          modelConfig: {
            ...currentSession.mask.modelConfig,
          },
        };

        set((state) => ({
          currentSessionIndex: 0,
          sessions: [newSession, ...state.sessions],
        }));
      },
      clearSessions() {
        set(() => ({
          sessions: [createEmptySession()],
          currentSessionIndex: 0,
        }));
      },

      selectSession(index: number) {
        set({
          currentSessionIndex: index,
        });
      },

      moveSession(from: number, to: number) {
        set((state) => {
          const { sessions, currentSessionIndex: oldIndex } = state;

          // move the session
          const newSessions = [...sessions];
          const session = newSessions[from];
          newSessions.splice(from, 1);
          newSessions.splice(to, 0, session);

          // adjust the current session index
          let newIndex = oldIndex === from ? to : oldIndex;
          if (oldIndex > from && oldIndex <= to) {
            newIndex -= 1;
          } else if (oldIndex < from && oldIndex >= to) {
            newIndex += 1;
          }

          return {
            currentSessionIndex: newIndex,
            sessions: newSessions,
          };
        });
      },
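
      // Illustrative index bookkeeping (hypothetical state): with sessions
      // [A, B, C, D] and currentSessionIndex = 2 (C), moveSession(0, 3)
      // yields [B, C, D, A]; the selected session C shifts left one slot,
      // so the index is decremented to 1 and C stays selected.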
      newSession(mask?: Mask) {
        const session = createEmptySession();

        if (mask) {
          const config = useAppConfig.getState();
          const globalModelConfig = config.modelConfig;

          session.mask = {
            ...mask,
            modelConfig: {
              ...globalModelConfig,
              ...mask.modelConfig,
            },
          };
          session.topic = mask.name;
        }

        set((state) => ({
          currentSessionIndex: 0,
          sessions: [session].concat(state.sessions),
        }));
      },

      nextSession(delta: number) {
        const n = get().sessions.length;
        const limit = (x: number) => (x + n) % n;
        const i = get().currentSessionIndex;
        get().selectSession(limit(i + delta));
      },
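
      // Illustrative wrap-around (hypothetical state): with 3 sessions and
      // currentSessionIndex = 0, nextSession(-1) computes (0 - 1 + 3) % 3 = 2,
      // i.e. stepping backwards from the first session selects the last one.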
      deleteSession(index: number) {
        const deletingLastSession = get().sessions.length === 1;
        const deletedSession = get().sessions.at(index);

        if (!deletedSession) return;

        const sessions = get().sessions.slice();
        sessions.splice(index, 1);

        const currentIndex = get().currentSessionIndex;
        let nextIndex = Math.min(
          currentIndex - Number(index < currentIndex),
          sessions.length - 1,
        );

        if (deletingLastSession) {
          nextIndex = 0;
          sessions.push(createEmptySession());
        }

        // snapshot for undoing the delete action
        const restoreState = {
          currentSessionIndex: get().currentSessionIndex,
          sessions: get().sessions.slice(),
        };

        set(() => ({
          currentSessionIndex: nextIndex,
          sessions,
        }));

        showToast(
          Locale.Home.DeleteToast,
          {
            text: Locale.Home.Revert,
            onClick() {
              set(() => restoreState);
            },
          },
          5000,
        );
      },
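
      // Illustrative index adjustment (hypothetical state): with 4 sessions
      // and currentSessionIndex = 2, deleteSession(0) removes a session before
      // the current one, so Number(index < currentIndex) = 1 shifts the
      // selection down to index 1 and the same session remains selected.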
      currentSession() {
        let index = get().currentSessionIndex;
        const sessions = get().sessions;

        if (index < 0 || index >= sessions.length) {
          index = Math.min(sessions.length - 1, Math.max(0, index));
          set(() => ({ currentSessionIndex: index }));
        }

        const session = sessions[index];
        return session;
      },

      onNewMessage(message: ChatMessage, targetSession: ChatSession) {
        get().updateTargetSession(targetSession, (session) => {
          session.messages = session.messages.concat();
          session.lastUpdate = Date.now();
        });
        get().updateStat(message, targetSession);
        get().summarizeSession(false, targetSession);
      },
      async onUserInput(content: string, attachImages?: string[]) {
        const session = get().currentSession();
        const modelConfig = session.mask.modelConfig;

        const userContent = fillTemplateWith(content, modelConfig);
        console.log("[User Input] after template: ", userContent);

        let mContent: string | MultimodalContent[] = userContent;

        if (attachImages && attachImages.length > 0) {
          mContent = [
            ...(userContent
              ? [{ type: "text" as const, text: userContent }]
              : []),
            ...attachImages.map((url) => ({
              type: "image_url" as const,
              image_url: { url },
            })),
          ];
        }

        const userMessage: ChatMessage = createMessage({
          role: "user",
          content: mContent,
        });

        const botMessage: ChatMessage = createMessage({
          role: "assistant",
          streaming: true,
          model: modelConfig.model,
        });

        // get recent messages
        const recentMessages = get().getMessagesWithMemory();
        const sendMessages = recentMessages.concat(userMessage);
        const messageIndex = session.messages.length + 1;

        // save user's and bot's message
        get().updateTargetSession(session, (session) => {
          const savedUserMessage = {
            ...userMessage,
            content: mContent,
          };
          session.messages = session.messages.concat([
            savedUserMessage,
            botMessage,
          ]);
        });
        const api: ClientApi = getClientApi(modelConfig.providerName);

        // make request
        api.llm.chat({
          messages: sendMessages,
          config: { ...modelConfig, stream: true },
          onUpdate(message) {
            botMessage.streaming = true;
            if (message) {
              botMessage.content = message;
            }
            get().updateTargetSession(session, (session) => {
              session.messages = session.messages.concat();
            });
          },
          async onFinish(message) {
            botMessage.streaming = false;
            if (message) {
              // console.log("[Bot Response] ", message);
              const mcpMatch = message.match(/```json:mcp([\s\S]*?)```/);
              if (mcpMatch) {
                try {
                  const mcp = JSON.parse(mcpMatch[1]);
                  console.log("[MCP Request]", mcp);
                  // call the server-side action directly
                  const result = await executeMcpAction(mcp);
                  console.log("[MCP Response]", result);
                } catch (error) {
                  console.error("[MCP Error]", error);
                }
              } else {
                console.log("[MCP] No MCP found in response");
              }
              botMessage.content = message;
              botMessage.date = new Date().toLocaleString();
              get().onNewMessage(botMessage, session);
            }
            ChatControllerPool.remove(session.id, botMessage.id);
          },
          onBeforeTool(tool: ChatMessageTool) {
            (botMessage.tools = botMessage?.tools || []).push(tool);
            get().updateTargetSession(session, (session) => {
              session.messages = session.messages.concat();
            });
          },
          onAfterTool(tool: ChatMessageTool) {
            botMessage?.tools?.forEach((t, i, tools) => {
              if (tool.id === t.id) {
                tools[i] = { ...tool };
              }
            });
            get().updateTargetSession(session, (session) => {
              session.messages = session.messages.concat();
            });
          },
          onError(error) {
            const isAborted = error.message?.includes?.("aborted");
            botMessage.content +=
              "\n\n" +
              prettyObject({
                error: true,
                message: error.message,
              });
            botMessage.streaming = false;
            userMessage.isError = !isAborted;
            botMessage.isError = !isAborted;
            get().updateTargetSession(session, (session) => {
              session.messages = session.messages.concat();
            });
            ChatControllerPool.remove(
              session.id,
              botMessage.id ?? messageIndex,
            );
            console.error("[Chat] failed ", error);
          },
          onController(controller) {
            // collect controller for stop/retry
            ChatControllerPool.addController(
              session.id,
              botMessage.id ?? messageIndex,
              controller,
            );
          },
        });
      },
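
      // Illustrative multimodal payload (hypothetical URL): calling
      // onUserInput("describe this", ["https://example.com/cat.png"]) stores
      //   [
      //     { type: "text", text: "<templated input>" },
      //     { type: "image_url", image_url: { url: "https://example.com/cat.png" } },
      //   ]
      // as the message content, while a plain text turn keeps it as one string.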
      getMemoryPrompt() {
        const session = get().currentSession();

        if (session.memoryPrompt.length) {
          return {
            role: "system",
            content: Locale.Store.Prompt.History(session.memoryPrompt),
            date: "",
          } as ChatMessage;
        }
      },
      getMessagesWithMemory() {
        const session = get().currentSession();
        const modelConfig = session.mask.modelConfig;
        const clearContextIndex = session.clearContextIndex ?? 0;
        const messages = session.messages.slice();
        const totalMessageCount = session.messages.length;

        // in-context prompts
        const contextPrompts = session.mask.context.slice();

        // system prompts, to get close to OpenAI Web ChatGPT
        const shouldInjectSystemPrompts =
          modelConfig.enableInjectSystemPrompts &&
          (session.mask.modelConfig.model.startsWith("gpt-") ||
            session.mask.modelConfig.model.startsWith("chatgpt-"));

        const systemPrompts: ChatMessage[] = shouldInjectSystemPrompts
          ? [
              createMessage({
                role: "system",
                content: fillTemplateWith("", {
                  ...modelConfig,
                  template: DEFAULT_SYSTEM_TEMPLATE,
                }),
              }),
            ]
          : [];

        if (shouldInjectSystemPrompts) {
          console.log(
            "[Global System Prompt] ",
            systemPrompts.at(0)?.content ?? "empty",
          );
        }

        const memoryPrompt = get().getMemoryPrompt();

        // long term memory
        const shouldSendLongTermMemory =
          modelConfig.sendMemory &&
          session.memoryPrompt &&
          session.memoryPrompt.length > 0 &&
          session.lastSummarizeIndex > clearContextIndex;
        const longTermMemoryPrompts =
          shouldSendLongTermMemory && memoryPrompt ? [memoryPrompt] : [];
        const longTermMemoryStartIndex = session.lastSummarizeIndex;

        // short term memory
        const shortTermMemoryStartIndex = Math.max(
          0,
          totalMessageCount - modelConfig.historyMessageCount,
        );

        // let's concat the messages to send, in five parts:
        // 0. system prompt: to get close to OpenAI Web ChatGPT
        // 1. long term memory: summarized memory messages
        // 2. pre-defined in-context prompts
        // 3. short term memory: latest n messages
        // 4. newest input message
        const memoryStartIndex = shouldSendLongTermMemory
          ? Math.min(longTermMemoryStartIndex, shortTermMemoryStartIndex)
          : shortTermMemoryStartIndex;
        // and if the user has cleared history messages, exclude the memory too
        const contextStartIndex = Math.max(clearContextIndex, memoryStartIndex);
        const maxTokenThreshold = modelConfig.max_tokens;

        // collect as many recent messages as the token budget allows
        const reversedRecentMessages = [];
        for (
          let i = totalMessageCount - 1, tokenCount = 0;
          i >= contextStartIndex && tokenCount < maxTokenThreshold;
          i -= 1
        ) {
          const msg = messages[i];
          if (!msg || msg.isError) continue;
          tokenCount += estimateTokenLength(getMessageTextContent(msg));
          reversedRecentMessages.push(msg);
        }

        // concat all messages
        const recentMessages = [
          ...systemPrompts,
          ...longTermMemoryPrompts,
          ...contextPrompts,
          ...reversedRecentMessages.reverse(),
        ];

        return recentMessages;
      },
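
      // Illustrative assembled prompt (hypothetical session): with a memory
      // summary, one mask context prompt, and two recent turns, the returned
      // list is ordered
      //   [system, memory-summary, context-prompt, user, assistant]
      // and onUserInput then appends the newest user message at the end.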
      updateMessage(
        sessionIndex: number,
        messageIndex: number,
        updater: (message?: ChatMessage) => void,
      ) {
        const sessions = get().sessions;
        const session = sessions.at(sessionIndex);
        const messages = session?.messages;
        updater(messages?.at(messageIndex));
        set(() => ({ sessions }));
      },

      resetSession(session: ChatSession) {
        get().updateTargetSession(session, (session) => {
          session.messages = [];
          session.memoryPrompt = "";
        });
      },
      summarizeSession(
        refreshTitle: boolean = false,
        targetSession: ChatSession,
      ) {
        const config = useAppConfig.getState();
        const session = targetSession;
        const modelConfig = session.mask.modelConfig;

        // skip summarization when using dalle3
        if (isDalle3(modelConfig.model)) {
          return;
        }

        // if compressModel is not configured, fall back to getSummarizeModel
        const [model, providerName] = modelConfig.compressModel
          ? [modelConfig.compressModel, modelConfig.compressProviderName]
          : getSummarizeModel(
              session.mask.modelConfig.model,
              session.mask.modelConfig.providerName,
            );
        const api: ClientApi = getClientApi(providerName as ServiceProvider);

        const messages = session.messages;

        // only auto-generate a topic once the chat exceeds this token estimate
        const SUMMARIZE_MIN_LEN = 50;
        if (
          (config.enableAutoGenerateTitle &&
            session.topic === DEFAULT_TOPIC &&
            countMessages(messages) >= SUMMARIZE_MIN_LEN) ||
          refreshTitle
        ) {
          const startIndex = Math.max(
            0,
            messages.length - modelConfig.historyMessageCount,
          );
          const topicMessages = messages
            .slice(
              startIndex < messages.length ? startIndex : messages.length - 1,
              messages.length,
            )
            .concat(
              createMessage({
                role: "user",
                content: Locale.Store.Prompt.Topic,
              }),
            );
          api.llm.chat({
            messages: topicMessages,
            config: {
              model,
              stream: false,
              providerName,
            },
            onFinish(message, responseRes) {
              if (responseRes?.status === 200) {
                get().updateTargetSession(
                  session,
                  (session) =>
                    (session.topic =
                      message.length > 0 ? trimTopic(message) : DEFAULT_TOPIC),
                );
              }
            },
          });
        }

        const summarizeIndex = Math.max(
          session.lastSummarizeIndex,
          session.clearContextIndex ?? 0,
        );

        // remove error messages if any, then keep only the unsummarized tail
        let toBeSummarizedMsgs = messages
          .filter((msg) => !msg.isError)
          .slice(summarizeIndex);

        const historyMsgLength = countMessages(toBeSummarizedMsgs);
        if (historyMsgLength > (modelConfig?.max_tokens || 4000)) {
          const n = toBeSummarizedMsgs.length;
          toBeSummarizedMsgs = toBeSummarizedMsgs.slice(
            Math.max(0, n - modelConfig.historyMessageCount),
          );
        }

        const memoryPrompt = get().getMemoryPrompt();
        if (memoryPrompt) {
          // prepend the existing memory prompt so the new summary builds on it
          toBeSummarizedMsgs.unshift(memoryPrompt);
        }

        const lastSummarizeIndex = session.messages.length;

        console.log(
          "[Chat History] ",
          toBeSummarizedMsgs,
          historyMsgLength,
          modelConfig.compressMessageLengthThreshold,
        );

        if (
          historyMsgLength > modelConfig.compressMessageLengthThreshold &&
          modelConfig.sendMemory
        ) {
          // destructure max_tokens out of the config while summarizing;
          // this param only causes trouble here
          const { max_tokens, ...modelcfg } = modelConfig;
          api.llm.chat({
            messages: toBeSummarizedMsgs.concat(
              createMessage({
                role: "system",
                content: Locale.Store.Prompt.Summarize,
                date: "",
              }),
            ),
            config: {
              ...modelcfg,
              stream: true,
              model,
              providerName,
            },
            onUpdate(message) {
              session.memoryPrompt = message;
            },
            onFinish(message, responseRes) {
              if (responseRes?.status === 200) {
                console.log("[Memory] ", message);
                get().updateTargetSession(session, (session) => {
                  session.lastSummarizeIndex = lastSummarizeIndex;
                  // update the memory prompt so it is persisted to local storage
                  session.memoryPrompt = message;
                });
              }
            },
            onError(err) {
              console.error("[Summarize] ", err);
            },
          });
        }
      },
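
      // Note: summarizeSession runs two independent passes: a short
      // non-streaming request that titles the session once it passes
      // SUMMARIZE_MIN_LEN, and a streaming request that compresses history
      // into memoryPrompt once it exceeds compressMessageLengthThreshold.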
      updateStat(message: ChatMessage, session: ChatSession) {
        get().updateTargetSession(session, (session) => {
          session.stat.charCount += message.content.length;
          // TODO: should update chat count and word count
        });
      },

      updateTargetSession(
        targetSession: ChatSession,
        updater: (session: ChatSession) => void,
      ) {
        const sessions = get().sessions;
        const index = sessions.findIndex((s) => s.id === targetSession.id);
        if (index < 0) return;
        updater(sessions[index]);
        set(() => ({ sessions }));
      },

      async clearAllData() {
        await indexedDBStorage.clear();
        localStorage.clear();
        location.reload();
      },

      setLastInput(lastInput: string) {
        set({
          lastInput,
        });
      },
    };

    return methods;
  },
  {
    name: StoreKey.Chat,
    version: 3.3,
    migrate(persistedState, version) {
      const state = persistedState as any;
      const newState = JSON.parse(
        JSON.stringify(state),
      ) as typeof DEFAULT_CHAT_STATE;

      if (version < 2) {
        newState.sessions = [];

        const oldSessions = state.sessions;
        for (const oldSession of oldSessions) {
          const newSession = createEmptySession();
          newSession.topic = oldSession.topic;
          newSession.messages = [...oldSession.messages];
          newSession.mask.modelConfig.sendMemory = true;
          newSession.mask.modelConfig.historyMessageCount = 4;
          newSession.mask.modelConfig.compressMessageLengthThreshold = 1000;
          newState.sessions.push(newSession);
        }
      }

      if (version < 3) {
        // migrate id to nanoid
        newState.sessions.forEach((s) => {
          s.id = nanoid();
          s.messages.forEach((m) => (m.id = nanoid()));
        });
      }

      // enable `enableInjectSystemPrompts` for old sessions,
      // so that sessions created before this option existed pick it up
      if (version < 3.1) {
        newState.sessions.forEach((s) => {
          if (
            // exclude those already set by the user
            !s.mask.modelConfig.hasOwnProperty("enableInjectSystemPrompts")
          ) {
            // the user may have changed this configuration,
            // so use the user's current setting instead of the default
            const config = useAppConfig.getState();
            s.mask.modelConfig.enableInjectSystemPrompts =
              config.modelConfig.enableInjectSystemPrompts;
          }
        });
      }

      // add a default summarize model for every session
      if (version < 3.2) {
        newState.sessions.forEach((s) => {
          const config = useAppConfig.getState();
          s.mask.modelConfig.compressModel = config.modelConfig.compressModel;
          s.mask.modelConfig.compressProviderName =
            config.modelConfig.compressProviderName;
        });
      }

      // revert the default summarize model for every session
      if (version < 3.3) {
        newState.sessions.forEach((s) => {
          s.mask.modelConfig.compressModel = "";
          s.mask.modelConfig.compressProviderName = "";
        });
      }

      return newState as any;
    },
  },
);