chat.ts 25 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840
  1. import { getMessageTextContent, trimTopic } from "../utils";
  2. import { indexedDBStorage } from "@/app/utils/indexedDB-storage";
  3. import { nanoid } from "nanoid";
  4. import type {
  5. ClientApi,
  6. MultimodalContent,
  7. RequestMessage,
  8. } from "../client/api";
  9. import { getClientApi } from "../client/api";
  10. import { ChatControllerPool } from "../client/controller";
  11. import { showToast } from "../components/ui-lib";
  12. import {
  13. DEFAULT_INPUT_TEMPLATE,
  14. DEFAULT_MODELS,
  15. DEFAULT_SYSTEM_TEMPLATE,
  16. KnowledgeCutOffDate,
  17. StoreKey,
  18. SUMMARIZE_MODEL,
  19. GEMINI_SUMMARIZE_MODEL,
  20. ServiceProvider,
  21. } from "../constant";
  22. import Locale, { getLang } from "../locales";
  23. import { isDalle3, safeLocalStorage } from "../utils";
  24. import { prettyObject } from "../utils/format";
  25. import { createPersistStore } from "../utils/store";
  26. import { estimateTokenLength } from "../utils/token";
  27. import { ModelConfig, ModelType, useAppConfig } from "./config";
  28. import { useAccessStore } from "./access";
  29. import { collectModelsWithDefaultModel } from "../utils/model";
  30. import { createEmptyMask, Mask } from "./mask";
// Module-level storage handle obtained via safeLocalStorage();
// presumably a guard for environments without window.localStorage (SSR) — verify in ../utils.
const localStorage = safeLocalStorage();
// One tool (function) call attached to an assistant message. Shape mirrors
// the incremental tool-call payloads handled in onBeforeTool/onAfterTool.
export type ChatMessageTool = {
  id: string;
  // position of this call in the provider's tool-call array — TODO confirm against client/api
  index?: number;
  type?: string;
  function?: {
    name: string;
    // arguments as a raw string; presumably JSON-encoded by the model — verify with the client
    arguments?: string;
  };
  // output produced by executing the tool
  content?: string;
  isError?: boolean;
  errorMsg?: string;
};
// A stored chat message: the wire-format RequestMessage extended with
// client-side bookkeeping used by this store and the UI.
export type ChatMessage = RequestMessage & {
  // creation timestamp, produced by new Date().toLocaleString() in createMessage
  date: string;
  // true while the assistant reply is still being streamed in
  streaming?: boolean;
  isError?: boolean;
  id: string;
  // model that produced this message (set on assistant messages in onUserInput)
  model?: ModelType;
  // tool calls collected during streaming, if any
  tools?: ChatMessageTool[];
};
  52. export function createMessage(override: Partial<ChatMessage>): ChatMessage {
  53. return {
  54. id: nanoid(),
  55. date: new Date().toLocaleString(),
  56. role: "user",
  57. content: "",
  58. ...override,
  59. };
  60. }
// Aggregate counters for a session; only charCount is maintained today
// (see updateStat's TODO about word/chat counts).
export interface ChatStat {
  tokenCount: number;
  wordCount: number;
  charCount: number;
}
// A single conversation, including its messages and summarization state.
export interface ChatSession {
  id: string;
  // display title; DEFAULT_TOPIC until auto-titled by summarizeSession
  topic: string;
  // compressed summary of older messages, maintained by summarizeSession
  memoryPrompt: string;
  messages: ChatMessage[];
  stat: ChatStat;
  // last-modified time (Date.now()), bumped in onNewMessage
  lastUpdate: number;
  // index of the first message NOT yet covered by memoryPrompt
  lastSummarizeIndex: number;
  // messages before this index are excluded from context ("clear context")
  clearContextIndex?: number;
  // persona/config applied to this session
  mask: Mask;
}
// Default topic label for sessions that have not been auto-titled yet.
export const DEFAULT_TOPIC = Locale.Store.DefaultTopic;

// Canned assistant greeting shown in a fresh chat.
export const BOT_HELLO: ChatMessage = createMessage({
  role: "assistant",
  content: Locale.Store.BotHello,
});
  82. function createEmptySession(): ChatSession {
  83. return {
  84. id: nanoid(),
  85. topic: DEFAULT_TOPIC,
  86. memoryPrompt: "",
  87. messages: [],
  88. stat: {
  89. tokenCount: 0,
  90. wordCount: 0,
  91. charCount: 0,
  92. },
  93. lastUpdate: Date.now(),
  94. lastSummarizeIndex: 0,
  95. mask: createEmptyMask(),
  96. };
  97. }
  98. function getSummarizeModel(
  99. currentModel: string,
  100. providerName: string,
  101. ): string[] {
  102. // if it is using gpt-* models, force to use 4o-mini to summarize
  103. if (currentModel.startsWith("gpt") || currentModel.startsWith("chatgpt")) {
  104. const configStore = useAppConfig.getState();
  105. const accessStore = useAccessStore.getState();
  106. const allModel = collectModelsWithDefaultModel(
  107. configStore.models,
  108. [configStore.customModels, accessStore.customModels].join(","),
  109. accessStore.defaultModel,
  110. );
  111. const summarizeModel = allModel.find(
  112. (m) => m.name === SUMMARIZE_MODEL && m.available,
  113. );
  114. if (summarizeModel) {
  115. return [
  116. summarizeModel.name,
  117. summarizeModel.provider?.providerName as string,
  118. ];
  119. }
  120. }
  121. if (currentModel.startsWith("gemini")) {
  122. return [GEMINI_SUMMARIZE_MODEL, ServiceProvider.Google];
  123. }
  124. return [currentModel, providerName];
  125. }
  126. function countMessages(msgs: ChatMessage[]) {
  127. return msgs.reduce(
  128. (pre, cur) => pre + estimateTokenLength(getMessageTextContent(cur)),
  129. 0,
  130. );
  131. }
  132. function fillTemplateWith(input: string, modelConfig: ModelConfig) {
  133. const cutoff =
  134. KnowledgeCutOffDate[modelConfig.model] ?? KnowledgeCutOffDate.default;
  135. // Find the model in the DEFAULT_MODELS array that matches the modelConfig.model
  136. const modelInfo = DEFAULT_MODELS.find((m) => m.name === modelConfig.model);
  137. var serviceProvider = "OpenAI";
  138. if (modelInfo) {
  139. // TODO: auto detect the providerName from the modelConfig.model
  140. // Directly use the providerName from the modelInfo
  141. serviceProvider = modelInfo.provider.providerName;
  142. }
  143. const vars = {
  144. ServiceProvider: serviceProvider,
  145. cutoff,
  146. model: modelConfig.model,
  147. time: new Date().toString(),
  148. lang: getLang(),
  149. input: input,
  150. };
  151. let output = modelConfig.template ?? DEFAULT_INPUT_TEMPLATE;
  152. // remove duplicate
  153. if (input.startsWith(output)) {
  154. output = "";
  155. }
  156. // must contains {{input}}
  157. const inputVar = "{{input}}";
  158. if (!output.includes(inputVar)) {
  159. output += "\n" + inputVar;
  160. }
  161. Object.entries(vars).forEach(([name, value]) => {
  162. const regex = new RegExp(`{{${name}}}`, "g");
  163. output = output.replace(regex, value.toString()); // Ensure value is a string
  164. });
  165. return output;
  166. }
// Initial store state: a single empty session, selected, no remembered input.
const DEFAULT_CHAT_STATE = {
  sessions: [createEmptySession()],
  currentSessionIndex: 0,
  lastInput: "",
};
  172. export const useChatStore = createPersistStore(
  173. DEFAULT_CHAT_STATE,
  174. (set, _get) => {
  175. function get() {
  176. return {
  177. ..._get(),
  178. ...methods,
  179. };
  180. }
  181. const methods = {
  182. forkSession() {
  183. // 获取当前会话
  184. const currentSession = get().currentSession();
  185. if (!currentSession) return;
  186. const newSession = createEmptySession();
  187. newSession.topic = currentSession.topic;
  188. newSession.messages = [...currentSession.messages];
  189. newSession.mask = {
  190. ...currentSession.mask,
  191. modelConfig: {
  192. ...currentSession.mask.modelConfig,
  193. },
  194. };
  195. set((state) => ({
  196. currentSessionIndex: 0,
  197. sessions: [newSession, ...state.sessions],
  198. }));
  199. },
  200. clearSessions() {
  201. set(() => ({
  202. sessions: [createEmptySession()],
  203. currentSessionIndex: 0,
  204. }));
  205. },
  206. selectSession(index: number) {
  207. set({
  208. currentSessionIndex: index,
  209. });
  210. },
      /**
       * Move a session from one list position to another (e.g. drag & drop),
       * adjusting currentSessionIndex so the same session stays selected.
       */
      moveSession(from: number, to: number) {
        set((state) => {
          const { sessions, currentSessionIndex: oldIndex } = state;

          // move the session
          const newSessions = [...sessions];
          const session = newSessions[from];
          newSessions.splice(from, 1);
          newSessions.splice(to, 0, session);

          // modify current session id
          let newIndex = oldIndex === from ? to : oldIndex;
          // selection shifted by one when the moved item crossed over it
          if (oldIndex > from && oldIndex <= to) {
            newIndex -= 1;
          } else if (oldIndex < from && oldIndex >= to) {
            newIndex += 1;
          }

          return {
            currentSessionIndex: newIndex,
            sessions: newSessions,
          };
        });
      },
  232. newSession(mask?: Mask) {
  233. const session = createEmptySession();
  234. if (mask) {
  235. const config = useAppConfig.getState();
  236. const globalModelConfig = config.modelConfig;
  237. session.mask = {
  238. ...mask,
  239. modelConfig: {
  240. ...globalModelConfig,
  241. ...mask.modelConfig,
  242. },
  243. };
  244. session.topic = mask.name;
  245. }
  246. set((state) => ({
  247. currentSessionIndex: 0,
  248. sessions: [session].concat(state.sessions),
  249. }));
  250. },
  251. nextSession(delta: number) {
  252. const n = get().sessions.length;
  253. const limit = (x: number) => (x + n) % n;
  254. const i = get().currentSessionIndex;
  255. get().selectSession(limit(i + delta));
  256. },
      /**
       * Delete the session at `index`, keep a sensible neighbor selected, and
       * show a toast that lets the user undo the deletion for five seconds.
       */
      deleteSession(index: number) {
        const deletingLastSession = get().sessions.length === 1;
        const deletedSession = get().sessions.at(index);

        if (!deletedSession) return;

        const sessions = get().sessions.slice();
        sessions.splice(index, 1);

        const currentIndex = get().currentSessionIndex;
        // shift the selection left if a preceding session was removed,
        // then clamp to the shortened list
        let nextIndex = Math.min(
          currentIndex - Number(index < currentIndex),
          sessions.length - 1,
        );

        // deleting the only session: replace it with a fresh empty one
        if (deletingLastSession) {
          nextIndex = 0;
          sessions.push(createEmptySession());
        }

        // for undo delete action — snapshot taken BEFORE committing the delete
        const restoreState = {
          currentSessionIndex: get().currentSessionIndex,
          sessions: get().sessions.slice(),
        };

        set(() => ({
          currentSessionIndex: nextIndex,
          sessions,
        }));

        showToast(
          Locale.Home.DeleteToast,
          {
            text: Locale.Home.Revert,
            onClick() {
              // restore the pre-delete snapshot
              set(() => restoreState);
            },
          },
          5000,
        );
      },
  292. currentSession() {
  293. let index = get().currentSessionIndex;
  294. const sessions = get().sessions;
  295. if (index < 0 || index >= sessions.length) {
  296. index = Math.min(sessions.length - 1, Math.max(0, index));
  297. set(() => ({ currentSessionIndex: index }));
  298. }
  299. const session = sessions[index];
  300. return session;
  301. },
      /**
       * Finalize a completed message: refresh the session's message array
       * identity and timestamp, update stats, and trigger summarization.
       */
      onNewMessage(message: ChatMessage, targetSession: ChatSession) {
        get().updateTargetSession(targetSession, (session) => {
          // concat() with no args creates a new array so subscribers re-render
          session.messages = session.messages.concat();
          session.lastUpdate = Date.now();
        });
        get().updateStat(message, targetSession);
        get().summarizeSession(false, targetSession);
      },
      /**
       * Handle a user submission: render the input through the model's
       * template, append the user message plus a streaming bot placeholder,
       * and fire the chat request, updating the placeholder via callbacks.
       */
      async onUserInput(content: string, attachImages?: string[]) {
        const session = get().currentSession();
        const modelConfig = session.mask.modelConfig;

        const userContent = fillTemplateWith(content, modelConfig);
        console.log("[User Input] after template: ", userContent);

        let mContent: string | MultimodalContent[] = userContent;

        if (attachImages && attachImages.length > 0) {
          // images attached: switch to multimodal parts (text first, then images)
          mContent = [
            ...(userContent
              ? [{ type: "text" as const, text: userContent }]
              : []),
            ...attachImages.map((url) => ({
              type: "image_url" as const,
              image_url: { url },
            })),
          ];
        }

        let userMessage: ChatMessage = createMessage({
          role: "user",
          content: mContent,
        });

        // placeholder that is mutated in place as the reply streams
        const botMessage: ChatMessage = createMessage({
          role: "assistant",
          streaming: true,
          model: modelConfig.model,
        });

        // get recent messages
        const recentMessages = get().getMessagesWithMemory();
        const sendMessages = recentMessages.concat(userMessage);
        // fallback key for the controller pool if botMessage.id is missing
        const messageIndex = session.messages.length + 1;

        // save user's and bot's message
        get().updateTargetSession(session, (session) => {
          const savedUserMessage = {
            ...userMessage,
            content: mContent,
          };
          session.messages = session.messages.concat([
            savedUserMessage,
            botMessage,
          ]);
        });

        const api: ClientApi = getClientApi(modelConfig.providerName);

        // make request
        api.llm.chat({
          messages: sendMessages,
          config: { ...modelConfig, stream: true },
          onUpdate(message) {
            botMessage.streaming = true;
            if (message) {
              botMessage.content = message;
            }
            // new array identity triggers subscriber re-render
            get().updateTargetSession(session, (session) => {
              session.messages = session.messages.concat();
            });
          },
          onFinish(message) {
            botMessage.streaming = false;
            if (message) {
              botMessage.content = message;
              botMessage.date = new Date().toLocaleString();
              get().onNewMessage(botMessage, session);
            }
            ChatControllerPool.remove(session.id, botMessage.id);
          },
          onBeforeTool(tool: ChatMessageTool) {
            // lazily create the tools array, then record the pending call
            (botMessage.tools = botMessage?.tools || []).push(tool);
            get().updateTargetSession(session, (session) => {
              session.messages = session.messages.concat();
            });
          },
          onAfterTool(tool: ChatMessageTool) {
            // replace the matching pending call with its finished result
            botMessage?.tools?.forEach((t, i, tools) => {
              if (tool.id == t.id) {
                tools[i] = { ...tool };
              }
            });
            get().updateTargetSession(session, (session) => {
              session.messages = session.messages.concat();
            });
          },
          onError(error) {
            const isAborted = error.message?.includes?.("aborted");
            // append the pretty-printed error to whatever content streamed in
            botMessage.content +=
              "\n\n" +
              prettyObject({
                error: true,
                message: error.message,
              });
            botMessage.streaming = false;
            // user-initiated aborts are not marked as errors
            userMessage.isError = !isAborted;
            botMessage.isError = !isAborted;
            get().updateTargetSession(session, (session) => {
              session.messages = session.messages.concat();
            });
            ChatControllerPool.remove(
              session.id,
              botMessage.id ?? messageIndex,
            );

            console.error("[Chat] failed ", error);
          },
          onController(controller) {
            // collect controller for stop/retry
            ChatControllerPool.addController(
              session.id,
              botMessage.id ?? messageIndex,
              controller,
            );
          },
        });
      },
  420. getMemoryPrompt() {
  421. const session = get().currentSession();
  422. if (session.memoryPrompt.length) {
  423. return {
  424. role: "system",
  425. content: Locale.Store.Prompt.History(session.memoryPrompt),
  426. date: "",
  427. } as ChatMessage;
  428. }
  429. },
      /**
       * Assemble the context to send to the model, in order: optional
       * injected system prompt, long-term memory summary, the mask's
       * in-context prompts, then the most recent messages that fit within
       * the token budget.
       */
      getMessagesWithMemory() {
        const session = get().currentSession();
        const modelConfig = session.mask.modelConfig;
        const clearContextIndex = session.clearContextIndex ?? 0;
        const messages = session.messages.slice();
        const totalMessageCount = session.messages.length;

        // in-context prompts
        const contextPrompts = session.mask.context.slice();

        // system prompts, to get close to OpenAI Web ChatGPT
        // (only injected for GPT-family models when the option is enabled)
        const shouldInjectSystemPrompts =
          modelConfig.enableInjectSystemPrompts &&
          (session.mask.modelConfig.model.startsWith("gpt-") ||
            session.mask.modelConfig.model.startsWith("chatgpt-"));

        var systemPrompts: ChatMessage[] = [];
        systemPrompts = shouldInjectSystemPrompts
          ? [
              createMessage({
                role: "system",
                content: fillTemplateWith("", {
                  ...modelConfig,
                  template: DEFAULT_SYSTEM_TEMPLATE,
                }),
              }),
            ]
          : [];
        if (shouldInjectSystemPrompts) {
          console.log(
            "[Global System Prompt] ",
            systemPrompts.at(0)?.content ?? "empty",
          );
        }
        const memoryPrompt = get().getMemoryPrompt();
        // long term memory: only sent when enabled, present, and not wiped
        // out by a later "clear context" action
        const shouldSendLongTermMemory =
          modelConfig.sendMemory &&
          session.memoryPrompt &&
          session.memoryPrompt.length > 0 &&
          session.lastSummarizeIndex > clearContextIndex;
        const longTermMemoryPrompts =
          shouldSendLongTermMemory && memoryPrompt ? [memoryPrompt] : [];
        const longTermMemoryStartIndex = session.lastSummarizeIndex;

        // short term memory: the latest historyMessageCount messages
        const shortTermMemoryStartIndex = Math.max(
          0,
          totalMessageCount - modelConfig.historyMessageCount,
        );

        // lets concat send messages, including 4 parts:
        // 0. system prompt: to get close to OpenAI Web ChatGPT
        // 1. long term memory: summarized memory messages
        // 2. pre-defined in-context prompts
        // 3. short term memory: latest n messages
        // 4. newest input message
        const memoryStartIndex = shouldSendLongTermMemory
          ? Math.min(longTermMemoryStartIndex, shortTermMemoryStartIndex)
          : shortTermMemoryStartIndex;
        // and if user has cleared history messages, we should exclude the memory too.
        const contextStartIndex = Math.max(clearContextIndex, memoryStartIndex);
        const maxTokenThreshold = modelConfig.max_tokens;

        // get recent messages as much as possible: walk backwards from the
        // newest message until the token budget or the start index is hit,
        // skipping error messages
        const reversedRecentMessages = [];
        for (
          let i = totalMessageCount - 1, tokenCount = 0;
          i >= contextStartIndex && tokenCount < maxTokenThreshold;
          i -= 1
        ) {
          const msg = messages[i];
          if (!msg || msg.isError) continue;
          tokenCount += estimateTokenLength(getMessageTextContent(msg));
          reversedRecentMessages.push(msg);
        }

        // concat all messages (recent ones restored to chronological order)
        const recentMessages = [
          ...systemPrompts,
          ...longTermMemoryPrompts,
          ...contextPrompts,
          ...reversedRecentMessages.reverse(),
        ];

        return recentMessages;
      },
  509. updateMessage(
  510. sessionIndex: number,
  511. messageIndex: number,
  512. updater: (message?: ChatMessage) => void,
  513. ) {
  514. const sessions = get().sessions;
  515. const session = sessions.at(sessionIndex);
  516. const messages = session?.messages;
  517. updater(messages?.at(messageIndex));
  518. set(() => ({ sessions }));
  519. },
  520. resetSession(session: ChatSession) {
  521. get().updateTargetSession(session, (session) => {
  522. session.messages = [];
  523. session.memoryPrompt = "";
  524. });
  525. },
      /**
       * Maybe auto-generate the session topic and compress older history
       * into `memoryPrompt`, using the configured compress model or the
       * pair chosen by getSummarizeModel.
       */
      summarizeSession(
        refreshTitle: boolean = false,
        targetSession: ChatSession,
      ) {
        const config = useAppConfig.getState();
        const session = targetSession;
        const modelConfig = session.mask.modelConfig;
        // skip summarize when using dalle3?
        if (isDalle3(modelConfig.model)) {
          return;
        }

        // if not config compressModel, then using getSummarizeModel
        const [model, providerName] = modelConfig.compressModel
          ? [modelConfig.compressModel, modelConfig.compressProviderName]
          : getSummarizeModel(
              session.mask.modelConfig.model,
              session.mask.modelConfig.providerName,
            );
        const api: ClientApi = getClientApi(providerName as ServiceProvider);

        const messages = session.messages;

        // should summarize topic after chating more than 50 words
        const SUMMARIZE_MIN_LEN = 50;
        if (
          (config.enableAutoGenerateTitle &&
            session.topic === DEFAULT_TOPIC &&
            countMessages(messages) >= SUMMARIZE_MIN_LEN) ||
          refreshTitle
        ) {
          // title the session from the last historyMessageCount messages
          const startIndex = Math.max(
            0,
            messages.length - modelConfig.historyMessageCount,
          );
          const topicMessages = messages
            .slice(
              startIndex < messages.length ? startIndex : messages.length - 1,
              messages.length,
            )
            .concat(
              createMessage({
                role: "user",
                content: Locale.Store.Prompt.Topic,
              }),
            );
          api.llm.chat({
            messages: topicMessages,
            config: {
              model,
              stream: false,
              providerName,
            },
            onFinish(message, responseRes) {
              if (responseRes?.status === 200) {
                get().updateTargetSession(
                  session,
                  (session) =>
                    (session.topic =
                      message.length > 0 ? trimTopic(message) : DEFAULT_TOPIC),
                );
              }
            },
          });
        }

        // summarize everything after the last summary / clear-context point,
        // dropping error messages
        const summarizeIndex = Math.max(
          session.lastSummarizeIndex,
          session.clearContextIndex ?? 0,
        );
        let toBeSummarizedMsgs = messages
          .filter((msg) => !msg.isError)
          .slice(summarizeIndex);

        const historyMsgLength = countMessages(toBeSummarizedMsgs);

        // too long for one request: keep only the most recent messages
        if (historyMsgLength > (modelConfig?.max_tokens || 4000)) {
          const n = toBeSummarizedMsgs.length;
          toBeSummarizedMsgs = toBeSummarizedMsgs.slice(
            Math.max(0, n - modelConfig.historyMessageCount),
          );
        }
        const memoryPrompt = get().getMemoryPrompt();
        if (memoryPrompt) {
          // add memory prompt so the new summary builds on the old one
          toBeSummarizedMsgs.unshift(memoryPrompt);
        }

        const lastSummarizeIndex = session.messages.length;

        console.log(
          "[Chat History] ",
          toBeSummarizedMsgs,
          historyMsgLength,
          modelConfig.compressMessageLengthThreshold,
        );

        if (
          historyMsgLength > modelConfig.compressMessageLengthThreshold &&
          modelConfig.sendMemory
        ) {
          // drop max_tokens from the summarize request config: it would
          // needlessly cap the length of the generated summary
          const { max_tokens, ...modelcfg } = modelConfig;
          api.llm.chat({
            messages: toBeSummarizedMsgs.concat(
              createMessage({
                role: "system",
                content: Locale.Store.Prompt.Summarize,
                date: "",
              }),
            ),
            config: {
              ...modelcfg,
              stream: true,
              model,
              providerName,
            },
            onUpdate(message) {
              // stream partial summary into the session as it arrives
              session.memoryPrompt = message;
            },
            onFinish(message, responseRes) {
              if (responseRes?.status === 200) {
                console.log("[Memory] ", message);
                get().updateTargetSession(session, (session) => {
                  session.lastSummarizeIndex = lastSummarizeIndex;
                  session.memoryPrompt = message; // Update the memory prompt for stored it in local storage
                });
              }
            },
            onError(err) {
              console.error("[Summarize] ", err);
            },
          });
        }
      },
  655. updateStat(message: ChatMessage, session: ChatSession) {
  656. get().updateTargetSession(session, (session) => {
  657. session.stat.charCount += message.content.length;
  658. // TODO: should update chat count and word count
  659. });
  660. },
  661. updateTargetSession(
  662. targetSession: ChatSession,
  663. updater: (session: ChatSession) => void,
  664. ) {
  665. const sessions = get().sessions;
  666. const index = sessions.findIndex((s) => s.id === targetSession.id);
  667. if (index < 0) return;
  668. updater(sessions[index]);
  669. set(() => ({ sessions }));
  670. },
      /**
       * Irreversibly wipe all persisted data (IndexedDB and localStorage)
       * and reload the page so every store re-initializes from defaults.
       */
      async clearAllData() {
        await indexedDBStorage.clear();
        localStorage.clear();
        location.reload();
      },
  676. setLastInput(lastInput: string) {
  677. set({
  678. lastInput,
  679. });
  680. },
  681. };
  682. return methods;
  683. },
  {
    name: StoreKey.Chat,
    version: 3.3,
    /**
     * Upgrade persisted state from older schema versions, applying each
     * version's migration cumulatively on a deep clone of the stored state.
     */
    migrate(persistedState, version) {
      const state = persistedState as any;
      // deep-clone so the migrations never mutate the persisted object
      const newState = JSON.parse(
        JSON.stringify(state),
      ) as typeof DEFAULT_CHAT_STATE;

      if (version < 2) {
        // v2: rebuild every session on the mask-based session shape
        newState.sessions = [];

        const oldSessions = state.sessions;
        for (const oldSession of oldSessions) {
          const newSession = createEmptySession();
          newSession.topic = oldSession.topic;
          newSession.messages = [...oldSession.messages];
          newSession.mask.modelConfig.sendMemory = true;
          newSession.mask.modelConfig.historyMessageCount = 4;
          newSession.mask.modelConfig.compressMessageLengthThreshold = 1000;
          newState.sessions.push(newSession);
        }
      }

      if (version < 3) {
        // migrate id to nanoid
        newState.sessions.forEach((s) => {
          s.id = nanoid();
          s.messages.forEach((m) => (m.id = nanoid()));
        });
      }

      // Enable `enableInjectSystemPrompts` attribute for old sessions.
      // Resolve issue of old sessions not automatically enabling.
      if (version < 3.1) {
        newState.sessions.forEach((s) => {
          if (
            // Exclude those already set by user
            !s.mask.modelConfig.hasOwnProperty("enableInjectSystemPrompts")
          ) {
            // Because users may have changed this configuration,
            // the user's current configuration is used instead of the default
            const config = useAppConfig.getState();
            s.mask.modelConfig.enableInjectSystemPrompts =
              config.modelConfig.enableInjectSystemPrompts;
          }
        });
      }

      // add default summarize model for every session
      if (version < 3.2) {
        newState.sessions.forEach((s) => {
          const config = useAppConfig.getState();
          s.mask.modelConfig.compressModel = config.modelConfig.compressModel;
          s.mask.modelConfig.compressProviderName =
            config.modelConfig.compressProviderName;
        });
      }

      // revert default summarize model for every session
      // (fix: dropped the unused `config` local this branch declared)
      if (version < 3.3) {
        newState.sessions.forEach((s) => {
          s.mask.modelConfig.compressModel = "";
          s.mask.modelConfig.compressProviderName = "";
        });
      }

      return newState as any;
    },
  },
  748. );