chat.ts 25 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841
  1. import { getMessageTextContent, trimTopic } from "../utils";
  2. import { indexedDBStorage } from "@/app/utils/indexedDB-storage";
  3. import { nanoid } from "nanoid";
  4. import type {
  5. ClientApi,
  6. MultimodalContent,
  7. RequestMessage,
  8. } from "../client/api";
  9. import { getClientApi } from "../client/api";
  10. import { ChatControllerPool } from "../client/controller";
  11. import { showToast } from "../components/ui-lib";
  12. import {
  13. DEFAULT_INPUT_TEMPLATE,
  14. DEFAULT_MODELS,
  15. DEFAULT_SYSTEM_TEMPLATE,
  16. KnowledgeCutOffDate,
  17. StoreKey,
  18. SUMMARIZE_MODEL,
  19. GEMINI_SUMMARIZE_MODEL,
  20. ServiceProvider,
  21. } from "../constant";
  22. import Locale, { getLang } from "../locales";
  23. import { isDalle3, safeLocalStorage } from "../utils";
  24. import { prettyObject } from "../utils/format";
  25. import { createPersistStore } from "../utils/store";
  26. import { estimateTokenLength } from "../utils/token";
  27. import { ModelConfig, ModelType, useAppConfig } from "./config";
  28. import { useAccessStore } from "./access";
  29. import { collectModelsWithDefaultModel } from "../utils/model";
  30. import { createEmptyMask, Mask } from "./mask";
// Storage access via the project's safe wrapper (presumably falls back to an
// in-memory store when window.localStorage is unavailable — see ../utils).
const localStorage = safeLocalStorage();

// One tool (function) call attached to a streaming assistant message.
export type ChatMessageTool = {
  id: string;
  index?: number;
  type?: string;
  function?: {
    name: string;
    arguments?: string;
  };
  content?: string; // tool output, once available
  isError?: boolean;
  errorMsg?: string;
};

// A chat message as stored in a session: the request payload plus
// client-side bookkeeping (streaming flag, error flag, originating model).
export type ChatMessage = RequestMessage & {
  date: string;
  streaming?: boolean;
  isError?: boolean;
  id: string;
  model?: ModelType;
  tools?: ChatMessageTool[];
  audio_url?: string;
};
  53. export function createMessage(override: Partial<ChatMessage>): ChatMessage {
  54. return {
  55. id: nanoid(),
  56. date: new Date().toLocaleString(),
  57. role: "user",
  58. content: "",
  59. ...override,
  60. };
  61. }
// Aggregate statistics for a session. Note: only charCount is currently
// maintained (see updateStat); token/word counts are TODO.
export interface ChatStat {
  tokenCount: number;
  wordCount: number;
  charCount: number;
}

// One conversation: its messages, summarization bookkeeping and the mask
// (persona + model config) it runs under.
export interface ChatSession {
  id: string;
  topic: string;
  memoryPrompt: string; // rolling summary produced by summarizeSession
  messages: ChatMessage[];
  stat: ChatStat;
  lastUpdate: number; // epoch millis of the last message activity
  lastSummarizeIndex: number; // messages before this index are covered by memoryPrompt
  clearContextIndex?: number; // messages before this index are excluded from context
  mask: Mask;
}

export const DEFAULT_TOPIC = Locale.Store.DefaultTopic;

// Canned assistant greeting for brand-new sessions.
export const BOT_HELLO: ChatMessage = createMessage({
  role: "assistant",
  content: Locale.Store.BotHello,
});
  83. function createEmptySession(): ChatSession {
  84. return {
  85. id: nanoid(),
  86. topic: DEFAULT_TOPIC,
  87. memoryPrompt: "",
  88. messages: [],
  89. stat: {
  90. tokenCount: 0,
  91. wordCount: 0,
  92. charCount: 0,
  93. },
  94. lastUpdate: Date.now(),
  95. lastSummarizeIndex: 0,
  96. mask: createEmptyMask(),
  97. };
  98. }
  99. function getSummarizeModel(
  100. currentModel: string,
  101. providerName: string,
  102. ): string[] {
  103. // if it is using gpt-* models, force to use 4o-mini to summarize
  104. if (currentModel.startsWith("gpt") || currentModel.startsWith("chatgpt")) {
  105. const configStore = useAppConfig.getState();
  106. const accessStore = useAccessStore.getState();
  107. const allModel = collectModelsWithDefaultModel(
  108. configStore.models,
  109. [configStore.customModels, accessStore.customModels].join(","),
  110. accessStore.defaultModel,
  111. );
  112. const summarizeModel = allModel.find(
  113. (m) => m.name === SUMMARIZE_MODEL && m.available,
  114. );
  115. if (summarizeModel) {
  116. return [
  117. summarizeModel.name,
  118. summarizeModel.provider?.providerName as string,
  119. ];
  120. }
  121. }
  122. if (currentModel.startsWith("gemini")) {
  123. return [GEMINI_SUMMARIZE_MODEL, ServiceProvider.Google];
  124. }
  125. return [currentModel, providerName];
  126. }
  127. function countMessages(msgs: ChatMessage[]) {
  128. return msgs.reduce(
  129. (pre, cur) => pre + estimateTokenLength(getMessageTextContent(cur)),
  130. 0,
  131. );
  132. }
  133. function fillTemplateWith(input: string, modelConfig: ModelConfig) {
  134. const cutoff =
  135. KnowledgeCutOffDate[modelConfig.model] ?? KnowledgeCutOffDate.default;
  136. // Find the model in the DEFAULT_MODELS array that matches the modelConfig.model
  137. const modelInfo = DEFAULT_MODELS.find((m) => m.name === modelConfig.model);
  138. var serviceProvider = "OpenAI";
  139. if (modelInfo) {
  140. // TODO: auto detect the providerName from the modelConfig.model
  141. // Directly use the providerName from the modelInfo
  142. serviceProvider = modelInfo.provider.providerName;
  143. }
  144. const vars = {
  145. ServiceProvider: serviceProvider,
  146. cutoff,
  147. model: modelConfig.model,
  148. time: new Date().toString(),
  149. lang: getLang(),
  150. input: input,
  151. };
  152. let output = modelConfig.template ?? DEFAULT_INPUT_TEMPLATE;
  153. // remove duplicate
  154. if (input.startsWith(output)) {
  155. output = "";
  156. }
  157. // must contains {{input}}
  158. const inputVar = "{{input}}";
  159. if (!output.includes(inputVar)) {
  160. output += "\n" + inputVar;
  161. }
  162. Object.entries(vars).forEach(([name, value]) => {
  163. const regex = new RegExp(`{{${name}}}`, "g");
  164. output = output.replace(regex, value.toString()); // Ensure value is a string
  165. });
  166. return output;
  167. }
// Initial store shape: a single fresh session, selected, and no
// remembered last input.
const DEFAULT_CHAT_STATE = {
  sessions: [createEmptySession()],
  currentSessionIndex: 0,
  lastInput: "",
};
  173. export const useChatStore = createPersistStore(
  174. DEFAULT_CHAT_STATE,
  175. (set, _get) => {
  176. function get() {
  177. return {
  178. ..._get(),
  179. ...methods,
  180. };
  181. }
  182. const methods = {
  183. forkSession() {
  184. // 获取当前会话
  185. const currentSession = get().currentSession();
  186. if (!currentSession) return;
  187. const newSession = createEmptySession();
  188. newSession.topic = currentSession.topic;
  189. newSession.messages = [...currentSession.messages];
  190. newSession.mask = {
  191. ...currentSession.mask,
  192. modelConfig: {
  193. ...currentSession.mask.modelConfig,
  194. },
  195. };
  196. set((state) => ({
  197. currentSessionIndex: 0,
  198. sessions: [newSession, ...state.sessions],
  199. }));
  200. },
/** Drop every session and start over with a single empty one. */
clearSessions() {
  set(() => ({
    sessions: [createEmptySession()],
    currentSessionIndex: 0,
  }));
},
/** Make the session at `index` the current one. */
selectSession(index: number) {
  set({
    currentSessionIndex: index,
  });
},
/**
 * Reorder sessions (e.g. drag & drop), keeping the selection pointing
 * at the same session it was on before the move.
 */
moveSession(from: number, to: number) {
  set((state) => {
    const { sessions, currentSessionIndex: oldIndex } = state;
    // move the session
    const newSessions = [...sessions];
    const session = newSessions[from];
    newSessions.splice(from, 1);
    newSessions.splice(to, 0, session);
    // modify current session id
    let newIndex = oldIndex === from ? to : oldIndex;
    if (oldIndex > from && oldIndex <= to) {
      // an item before the selection moved past it: selection shifts up
      newIndex -= 1;
    } else if (oldIndex < from && oldIndex >= to) {
      // an item after the selection moved before it: selection shifts down
      newIndex += 1;
    }
    return {
      currentSessionIndex: newIndex,
      sessions: newSessions,
    };
  });
},
  233. newSession(mask?: Mask) {
  234. const session = createEmptySession();
  235. if (mask) {
  236. const config = useAppConfig.getState();
  237. const globalModelConfig = config.modelConfig;
  238. session.mask = {
  239. ...mask,
  240. modelConfig: {
  241. ...globalModelConfig,
  242. ...mask.modelConfig,
  243. },
  244. };
  245. session.topic = mask.name;
  246. }
  247. set((state) => ({
  248. currentSessionIndex: 0,
  249. sessions: [session].concat(state.sessions),
  250. }));
  251. },
/** Cycle the selection by `delta`, wrapping around at both ends. */
nextSession(delta: number) {
  const n = get().sessions.length;
  const limit = (x: number) => (x + n) % n;
  const i = get().currentSessionIndex;
  get().selectSession(limit(i + delta));
},
/**
 * Delete the session at `index`, keeping the selection stable and
 * offering a 5-second toast that can undo the deletion.
 */
deleteSession(index: number) {
  const deletingLastSession = get().sessions.length === 1;
  const deletedSession = get().sessions.at(index);
  if (!deletedSession) return;
  const sessions = get().sessions.slice();
  sessions.splice(index, 1);
  const currentIndex = get().currentSessionIndex;
  // Shift the selection left when a session before it was removed, and
  // clamp to the new last index.
  let nextIndex = Math.min(
    currentIndex - Number(index < currentIndex),
    sessions.length - 1,
  );
  if (deletingLastSession) {
    // Never leave the list empty: replace the sole session with a fresh one.
    nextIndex = 0;
    sessions.push(createEmptySession());
  }
  // for undo delete action
  const restoreState = {
    currentSessionIndex: get().currentSessionIndex,
    sessions: get().sessions.slice(),
  };
  set(() => ({
    currentSessionIndex: nextIndex,
    sessions,
  }));
  showToast(
    Locale.Home.DeleteToast,
    {
      text: Locale.Home.Revert,
      onClick() {
        // Undo: restore the pre-deletion snapshot.
        set(() => restoreState);
      },
    },
    5000,
  );
},
/** The currently selected session, clamping an out-of-range index first. */
currentSession() {
  let index = get().currentSessionIndex;
  const sessions = get().sessions;
  if (index < 0 || index >= sessions.length) {
    index = Math.min(sessions.length - 1, Math.max(0, index));
    set(() => ({ currentSessionIndex: index }));
  }
  const session = sessions[index];
  return session;
},
/**
 * Post-message bookkeeping: bump lastUpdate, refresh stats, and kick off
 * auto-summarization for the target session.
 */
onNewMessage(message: ChatMessage, targetSession: ChatSession) {
  get().updateTargetSession(targetSession, (session) => {
    // concat() clones the array so store subscribers see a new reference.
    session.messages = session.messages.concat();
    session.lastUpdate = Date.now();
  });
  get().updateStat(message, targetSession);
  get().summarizeSession(false, targetSession);
},
/**
 * Handle a user submission: expand the input template, build the user
 * message and a streaming assistant placeholder, append both to the
 * current session, and fire the LLM request, wiring the stream / tool /
 * error callbacks back into the store.
 */
async onUserInput(content: string, attachImages?: string[]) {
  const session = get().currentSession();
  const modelConfig = session.mask.modelConfig;
  // Expand {{input}} and the other template vars from the mask's template.
  const userContent = fillTemplateWith(content, modelConfig);
  console.log("[User Input] after template: ", userContent);
  let mContent: string | MultimodalContent[] = userContent;
  if (attachImages && attachImages.length > 0) {
    // Images attached: switch to the multimodal part-list format —
    // optional text part first, then one image_url part per attachment.
    mContent = [
      ...(userContent
        ? [{ type: "text" as const, text: userContent }]
        : []),
      ...attachImages.map((url) => ({
        type: "image_url" as const,
        image_url: { url },
      })),
    ];
  }
  let userMessage: ChatMessage = createMessage({
    role: "user",
    content: mContent,
  });
  // Placeholder assistant message that is filled in while streaming.
  const botMessage: ChatMessage = createMessage({
    role: "assistant",
    streaming: true,
    model: modelConfig.model,
  });
  // get recent messages
  const recentMessages = get().getMessagesWithMemory();
  const sendMessages = recentMessages.concat(userMessage);
  const messageIndex = session.messages.length + 1;
  // save user's and bot's message
  get().updateTargetSession(session, (session) => {
    const savedUserMessage = {
      ...userMessage,
      content: mContent,
    };
    session.messages = session.messages.concat([
      savedUserMessage,
      botMessage,
    ]);
  });
  const api: ClientApi = getClientApi(modelConfig.providerName);
  // make request
  api.llm.chat({
    messages: sendMessages,
    config: { ...modelConfig, stream: true },
    onUpdate(message) {
      // Streaming chunk: update content and clone the messages array so
      // subscribers re-render.
      botMessage.streaming = true;
      if (message) {
        botMessage.content = message;
      }
      get().updateTargetSession(session, (session) => {
        session.messages = session.messages.concat();
      });
    },
    onFinish(message) {
      botMessage.streaming = false;
      if (message) {
        botMessage.content = message;
        botMessage.date = new Date().toLocaleString();
        // Triggers stats update and (possibly) summarization.
        get().onNewMessage(botMessage, session);
      }
      ChatControllerPool.remove(session.id, botMessage.id);
    },
    onBeforeTool(tool: ChatMessageTool) {
      // Record the pending tool call on the streaming bot message.
      (botMessage.tools = botMessage?.tools || []).push(tool);
      get().updateTargetSession(session, (session) => {
        session.messages = session.messages.concat();
      });
    },
    onAfterTool(tool: ChatMessageTool) {
      // Replace the matching pending tool entry with its finished state.
      botMessage?.tools?.forEach((t, i, tools) => {
        if (tool.id == t.id) {
          tools[i] = { ...tool };
        }
      });
      get().updateTargetSession(session, (session) => {
        session.messages = session.messages.concat();
      });
    },
    onError(error) {
      const isAborted = error.message?.includes?.("aborted");
      // Append error details after whatever content already streamed in.
      botMessage.content +=
        "\n\n" +
        prettyObject({
          error: true,
          message: error.message,
        });
      botMessage.streaming = false;
      // User-initiated aborts are not flagged as errors.
      userMessage.isError = !isAborted;
      botMessage.isError = !isAborted;
      get().updateTargetSession(session, (session) => {
        session.messages = session.messages.concat();
      });
      ChatControllerPool.remove(
        session.id,
        botMessage.id ?? messageIndex,
      );
      console.error("[Chat] failed ", error);
    },
    onController(controller) {
      // collect controller for stop/retry
      ChatControllerPool.addController(
        session.id,
        botMessage.id ?? messageIndex,
        controller,
      );
    },
  });
},
/**
 * The current session's long-term memory summary wrapped as a system
 * message, or undefined when no summary has been produced yet.
 */
getMemoryPrompt() {
  const session = get().currentSession();
  if (session.memoryPrompt.length) {
    return {
      role: "system",
      content: Locale.Store.Prompt.History(session.memoryPrompt),
      date: "",
    } as ChatMessage;
  }
},
/**
 * Assemble the context window to send to the model, in order: injected
 * system prompt (GPT-family models only), long-term memory summary, the
 * mask's in-context prompts, then as many recent messages as the token
 * budget allows.
 */
getMessagesWithMemory() {
  const session = get().currentSession();
  const modelConfig = session.mask.modelConfig;
  // Messages at indices below clearContextIndex are excluded entirely.
  const clearContextIndex = session.clearContextIndex ?? 0;
  const messages = session.messages.slice();
  const totalMessageCount = session.messages.length;
  // in-context prompts
  const contextPrompts = session.mask.context.slice();
  // system prompts, to get close to OpenAI Web ChatGPT
  const shouldInjectSystemPrompts =
    modelConfig.enableInjectSystemPrompts &&
    (session.mask.modelConfig.model.startsWith("gpt-") ||
      session.mask.modelConfig.model.startsWith("chatgpt-"));
  var systemPrompts: ChatMessage[] = [];
  systemPrompts = shouldInjectSystemPrompts
    ? [
        createMessage({
          role: "system",
          content: fillTemplateWith("", {
            ...modelConfig,
            template: DEFAULT_SYSTEM_TEMPLATE,
          }),
        }),
      ]
    : [];
  if (shouldInjectSystemPrompts) {
    console.log(
      "[Global System Prompt] ",
      systemPrompts.at(0)?.content ?? "empty",
    );
  }
  const memoryPrompt = get().getMemoryPrompt();
  // long term memory: only sent when enabled, present, and not wiped by
  // a later clear-context point.
  const shouldSendLongTermMemory =
    modelConfig.sendMemory &&
    session.memoryPrompt &&
    session.memoryPrompt.length > 0 &&
    session.lastSummarizeIndex > clearContextIndex;
  const longTermMemoryPrompts =
    shouldSendLongTermMemory && memoryPrompt ? [memoryPrompt] : [];
  const longTermMemoryStartIndex = session.lastSummarizeIndex;
  // short term memory
  const shortTermMemoryStartIndex = Math.max(
    0,
    totalMessageCount - modelConfig.historyMessageCount,
  );
  // lets concat send messages, including 4 parts:
  // 0. system prompt: to get close to OpenAI Web ChatGPT
  // 1. long term memory: summarized memory messages
  // 2. pre-defined in-context prompts
  // 3. short term memory: latest n messages
  // 4. newest input message
  const memoryStartIndex = shouldSendLongTermMemory
    ? Math.min(longTermMemoryStartIndex, shortTermMemoryStartIndex)
    : shortTermMemoryStartIndex;
  // and if user has cleared history messages, we should exclude the memory too.
  const contextStartIndex = Math.max(clearContextIndex, memoryStartIndex);
  const maxTokenThreshold = modelConfig.max_tokens;
  // get recent messages as much as possible: walk backwards from the
  // newest message until the estimated token budget is spent, skipping
  // error messages.
  const reversedRecentMessages = [];
  for (
    let i = totalMessageCount - 1, tokenCount = 0;
    i >= contextStartIndex && tokenCount < maxTokenThreshold;
    i -= 1
  ) {
    const msg = messages[i];
    if (!msg || msg.isError) continue;
    tokenCount += estimateTokenLength(getMessageTextContent(msg));
    reversedRecentMessages.push(msg);
  }
  // concat all messages (recent ones were collected newest-first, so reverse)
  const recentMessages = [
    ...systemPrompts,
    ...longTermMemoryPrompts,
    ...contextPrompts,
    ...reversedRecentMessages.reverse(),
  ];
  return recentMessages;
},
/**
 * Apply `updater` to the message at (sessionIndex, messageIndex) — it
 * receives undefined when either index is out of range — then publish
 * the sessions array back to the store.
 */
updateMessage(
  sessionIndex: number,
  messageIndex: number,
  updater: (message?: ChatMessage) => void,
) {
  const sessions = get().sessions;
  const session = sessions.at(sessionIndex);
  const messages = session?.messages;
  updater(messages?.at(messageIndex));
  set(() => ({ sessions }));
},
/** Wipe a session's messages and its long-term memory summary. */
resetSession(session: ChatSession) {
  get().updateTargetSession(session, (session) => {
    session.messages = [];
    session.memoryPrompt = "";
  });
},
/**
 * Two summarization duties for `targetSession`:
 * 1. auto-generate a topic/title once the chat exceeds a minimum length
 *    (or when `refreshTitle` forces it), and
 * 2. compress older messages into `memoryPrompt` once the uncompressed
 *    history exceeds the configured threshold.
 * Skipped entirely for DALL·E 3 sessions.
 */
summarizeSession(
  refreshTitle: boolean = false,
  targetSession: ChatSession,
) {
  const config = useAppConfig.getState();
  const session = targetSession;
  const modelConfig = session.mask.modelConfig;
  // skip summarize when using dalle3?
  if (isDalle3(modelConfig.model)) {
    return;
  }
  // if not config compressModel, then using getSummarizeModel
  const [model, providerName] = modelConfig.compressModel
    ? [modelConfig.compressModel, modelConfig.compressProviderName]
    : getSummarizeModel(
        session.mask.modelConfig.model,
        session.mask.modelConfig.providerName,
      );
  const api: ClientApi = getClientApi(providerName as ServiceProvider);
  // remove error messages if any
  const messages = session.messages;
  // should summarize topic after chating more than 50 words
  const SUMMARIZE_MIN_LEN = 50;
  if (
    (config.enableAutoGenerateTitle &&
      session.topic === DEFAULT_TOPIC &&
      countMessages(messages) >= SUMMARIZE_MIN_LEN) ||
    refreshTitle
  ) {
    // Title generation: send the most recent messages plus the topic
    // prompt, then store the trimmed answer as the session topic.
    const startIndex = Math.max(
      0,
      messages.length - modelConfig.historyMessageCount,
    );
    const topicMessages = messages
      .slice(
        startIndex < messages.length ? startIndex : messages.length - 1,
        messages.length,
      )
      .concat(
        createMessage({
          role: "user",
          content: Locale.Store.Prompt.Topic,
        }),
      );
    api.llm.chat({
      messages: topicMessages,
      config: {
        model,
        stream: false,
        providerName,
      },
      onFinish(message, responseRes) {
        if (responseRes?.status === 200) {
          get().updateTargetSession(
            session,
            (session) =>
              (session.topic =
                message.length > 0 ? trimTopic(message) : DEFAULT_TOPIC),
          );
        }
      },
    });
  }
  // Memory compression: summarize everything after the last summarize /
  // clear-context point.
  const summarizeIndex = Math.max(
    session.lastSummarizeIndex,
    session.clearContextIndex ?? 0,
  );
  let toBeSummarizedMsgs = messages
    .filter((msg) => !msg.isError)
    .slice(summarizeIndex);
  const historyMsgLength = countMessages(toBeSummarizedMsgs);
  if (historyMsgLength > (modelConfig?.max_tokens || 4000)) {
    // Too much history for one request: keep only the newest
    // historyMessageCount messages.
    const n = toBeSummarizedMsgs.length;
    toBeSummarizedMsgs = toBeSummarizedMsgs.slice(
      Math.max(0, n - modelConfig.historyMessageCount),
    );
  }
  const memoryPrompt = get().getMemoryPrompt();
  if (memoryPrompt) {
    // add memory prompt so the new summary builds on the old one
    toBeSummarizedMsgs.unshift(memoryPrompt);
  }
  const lastSummarizeIndex = session.messages.length;
  console.log(
    "[Chat History] ",
    toBeSummarizedMsgs,
    historyMsgLength,
    modelConfig.compressMessageLengthThreshold,
  );
  if (
    historyMsgLength > modelConfig.compressMessageLengthThreshold &&
    modelConfig.sendMemory
  ) {
    /** Destruct max_tokens while summarizing
     * this param is just shit
     **/
    const { max_tokens, ...modelcfg } = modelConfig;
    api.llm.chat({
      messages: toBeSummarizedMsgs.concat(
        createMessage({
          role: "system",
          content: Locale.Store.Prompt.Summarize,
          date: "",
        }),
      ),
      config: {
        ...modelcfg,
        stream: true,
        model,
        providerName,
      },
      onUpdate(message) {
        // Stream the evolving summary straight onto the session object.
        session.memoryPrompt = message;
      },
      onFinish(message, responseRes) {
        if (responseRes?.status === 200) {
          console.log("[Memory] ", message);
          get().updateTargetSession(session, (session) => {
            session.lastSummarizeIndex = lastSummarizeIndex;
            session.memoryPrompt = message; // Update the memory prompt for stored it in local storage
          });
        }
      },
      onError(err) {
        console.error("[Summarize] ", err);
      },
    });
  }
},
  656. updateStat(message: ChatMessage, session: ChatSession) {
  657. get().updateTargetSession(session, (session) => {
  658. session.stat.charCount += message.content.length;
  659. // TODO: should update chat count and word count
  660. });
  661. },
/**
 * Find `targetSession` by id in the store, mutate it in place via
 * `updater`, then publish the sessions array. No-op when the session no
 * longer exists.
 */
updateTargetSession(
  targetSession: ChatSession,
  updater: (session: ChatSession) => void,
) {
  const sessions = get().sessions;
  const index = sessions.findIndex((s) => s.id === targetSession.id);
  if (index < 0) return;
  updater(sessions[index]);
  set(() => ({ sessions }));
},
/** Clear IndexedDB and localStorage, then reload the app. */
async clearAllData() {
  await indexedDBStorage.clear();
  localStorage.clear();
  location.reload();
},
/** Remember the last input string. */
setLastInput(lastInput: string) {
  set({
    lastInput,
  });
},
  682. };
  683. return methods;
  684. },
  685. {
  686. name: StoreKey.Chat,
  687. version: 3.3,
  688. migrate(persistedState, version) {
  689. const state = persistedState as any;
  690. const newState = JSON.parse(
  691. JSON.stringify(state),
  692. ) as typeof DEFAULT_CHAT_STATE;
  693. if (version < 2) {
  694. newState.sessions = [];
  695. const oldSessions = state.sessions;
  696. for (const oldSession of oldSessions) {
  697. const newSession = createEmptySession();
  698. newSession.topic = oldSession.topic;
  699. newSession.messages = [...oldSession.messages];
  700. newSession.mask.modelConfig.sendMemory = true;
  701. newSession.mask.modelConfig.historyMessageCount = 4;
  702. newSession.mask.modelConfig.compressMessageLengthThreshold = 1000;
  703. newState.sessions.push(newSession);
  704. }
  705. }
  706. if (version < 3) {
  707. // migrate id to nanoid
  708. newState.sessions.forEach((s) => {
  709. s.id = nanoid();
  710. s.messages.forEach((m) => (m.id = nanoid()));
  711. });
  712. }
  713. // Enable `enableInjectSystemPrompts` attribute for old sessions.
  714. // Resolve issue of old sessions not automatically enabling.
  715. if (version < 3.1) {
  716. newState.sessions.forEach((s) => {
  717. if (
  718. // Exclude those already set by user
  719. !s.mask.modelConfig.hasOwnProperty("enableInjectSystemPrompts")
  720. ) {
  721. // Because users may have changed this configuration,
  722. // the user's current configuration is used instead of the default
  723. const config = useAppConfig.getState();
  724. s.mask.modelConfig.enableInjectSystemPrompts =
  725. config.modelConfig.enableInjectSystemPrompts;
  726. }
  727. });
  728. }
  729. // add default summarize model for every session
  730. if (version < 3.2) {
  731. newState.sessions.forEach((s) => {
  732. const config = useAppConfig.getState();
  733. s.mask.modelConfig.compressModel = config.modelConfig.compressModel;
  734. s.mask.modelConfig.compressProviderName =
  735. config.modelConfig.compressProviderName;
  736. });
  737. }
  738. // revert default summarize model for every session
  739. if (version < 3.3) {
  740. newState.sessions.forEach((s) => {
  741. const config = useAppConfig.getState();
  742. s.mask.modelConfig.compressModel = "";
  743. s.mask.modelConfig.compressProviderName = "";
  744. });
  745. }
  746. return newState as any;
  747. },
  748. },
  749. );