// chat.ts — chat session store (sessions, messages, memory summarization)
  1. import { getMessageTextContent, trimTopic } from "../utils";
  2. import { indexedDBStorage } from "@/app/utils/indexedDB-storage";
  3. import { nanoid } from "nanoid";
  4. import type {
  5. ClientApi,
  6. MultimodalContent,
  7. RequestMessage,
  8. } from "../client/api";
  9. import { getClientApi } from "../client/api";
  10. import { ChatControllerPool } from "../client/controller";
  11. import { showToast } from "../components/ui-lib";
  12. import {
  13. DEFAULT_INPUT_TEMPLATE,
  14. DEFAULT_MODELS,
  15. DEFAULT_SYSTEM_TEMPLATE,
  16. KnowledgeCutOffDate,
  17. StoreKey,
  18. SUMMARIZE_MODEL,
  19. GEMINI_SUMMARIZE_MODEL,
  20. ServiceProvider,
  21. } from "../constant";
  22. import Locale, { getLang } from "../locales";
  23. import { isDalle3, safeLocalStorage } from "../utils";
  24. import { prettyObject } from "../utils/format";
  25. import { createPersistStore } from "../utils/store";
  26. import { estimateTokenLength } from "../utils/token";
  27. import { ModelConfig, ModelType, useAppConfig } from "./config";
  28. import { useAccessStore } from "./access";
  29. import { collectModelsWithDefaultModel } from "../utils/model";
  30. import { createEmptyMask, Mask } from "./mask";
// SSR-safe wrapper around window.localStorage.
const localStorage = safeLocalStorage();

// One tool invocation attached to an assistant message (e.g. a
// function/tool call streamed back by the provider).
export type ChatMessageTool = {
  id: string;
  // position of this call in the provider's tool_calls array
  index?: number;
  type?: string;
  function?: {
    name: string;
    // JSON-encoded arguments, accumulated while streaming
    arguments?: string;
  };
  // tool execution result, if any
  content?: string;
  isError?: boolean;
  errorMsg?: string;
};

// A chat message as stored in a session: the wire-format RequestMessage
// plus local bookkeeping (id, date, streaming/error flags, tool calls).
export type ChatMessage = RequestMessage & {
  date: string;
  streaming?: boolean;
  isError?: boolean;
  id: string;
  model?: ModelType;
  tools?: ChatMessageTool[];
};
  52. export function createMessage(override: Partial<ChatMessage>): ChatMessage {
  53. return {
  54. id: nanoid(),
  55. date: new Date().toLocaleString(),
  56. role: "user",
  57. content: "",
  58. ...override,
  59. };
  60. }
// Aggregate usage statistics for a session.
export interface ChatStat {
  tokenCount: number;
  wordCount: number;
  charCount: number;
}

// A single conversation: its messages, derived summary state, and the
// mask (persona + model config) it runs under.
export interface ChatSession {
  id: string;
  topic: string;
  // rolling summary of older messages, used as long-term memory
  memoryPrompt: string;
  messages: ChatMessage[];
  stat: ChatStat;
  lastUpdate: number;
  // index of the last message covered by memoryPrompt
  lastSummarizeIndex: number;
  // messages before this index are excluded from the request context
  clearContextIndex?: number;
  mask: Mask;
}
export const DEFAULT_TOPIC = Locale.Store.DefaultTopic;
// Canned greeting shown at the top of brand-new sessions.
export const BOT_HELLO: ChatMessage = createMessage({
  role: "assistant",
  content: Locale.Store.BotHello,
});
// Create a brand-new session: default topic, no messages, zeroed stats,
// and an empty mask.
function createEmptySession(): ChatSession {
  return {
    id: nanoid(),
    topic: DEFAULT_TOPIC,
    memoryPrompt: "",
    messages: [],
    stat: {
      tokenCount: 0,
      wordCount: 0,
      charCount: 0,
    },
    lastUpdate: Date.now(),
    lastSummarizeIndex: 0,
    mask: createEmptyMask(),
  };
}
  98. function getSummarizeModel(
  99. currentModel: string,
  100. providerName: string,
  101. ): string[] {
  102. // if it is using gpt-* models, force to use 4o-mini to summarize
  103. if (currentModel.startsWith("gpt") || currentModel.startsWith("chatgpt")) {
  104. const configStore = useAppConfig.getState();
  105. const accessStore = useAccessStore.getState();
  106. const allModel = collectModelsWithDefaultModel(
  107. configStore.models,
  108. [configStore.customModels, accessStore.customModels].join(","),
  109. accessStore.defaultModel,
  110. );
  111. const summarizeModel = allModel.find(
  112. (m) => m.name === SUMMARIZE_MODEL && m.available,
  113. );
  114. if (summarizeModel) {
  115. return [
  116. summarizeModel.name,
  117. summarizeModel.provider?.providerName as string,
  118. ];
  119. }
  120. }
  121. if (currentModel.startsWith("gemini")) {
  122. return [GEMINI_SUMMARIZE_MODEL, ServiceProvider.Google];
  123. }
  124. return [currentModel, providerName];
  125. }
  126. function countMessages(msgs: ChatMessage[]) {
  127. return msgs.reduce(
  128. (pre, cur) => pre + estimateTokenLength(getMessageTextContent(cur)),
  129. 0,
  130. );
  131. }
  132. function fillTemplateWith(input: string, modelConfig: ModelConfig) {
  133. const cutoff =
  134. KnowledgeCutOffDate[modelConfig.model] ?? KnowledgeCutOffDate.default;
  135. // Find the model in the DEFAULT_MODELS array that matches the modelConfig.model
  136. const modelInfo = DEFAULT_MODELS.find((m) => m.name === modelConfig.model);
  137. var serviceProvider = "OpenAI";
  138. if (modelInfo) {
  139. // TODO: auto detect the providerName from the modelConfig.model
  140. // Directly use the providerName from the modelInfo
  141. serviceProvider = modelInfo.provider.providerName;
  142. }
  143. const vars = {
  144. ServiceProvider: serviceProvider,
  145. cutoff,
  146. model: modelConfig.model,
  147. time: new Date().toString(),
  148. lang: getLang(),
  149. input: input,
  150. };
  151. let output = modelConfig.template ?? DEFAULT_INPUT_TEMPLATE;
  152. // remove duplicate
  153. if (input.startsWith(output)) {
  154. output = "";
  155. }
  156. // must contains {{input}}
  157. const inputVar = "{{input}}";
  158. if (!output.includes(inputVar)) {
  159. output += "\n" + inputVar;
  160. }
  161. Object.entries(vars).forEach(([name, value]) => {
  162. const regex = new RegExp(`{{${name}}}`, "g");
  163. output = output.replace(regex, value.toString()); // Ensure value is a string
  164. });
  165. return output;
  166. }
// Initial persisted state: a single fresh session, selected.
const DEFAULT_CHAT_STATE = {
  sessions: [createEmptySession()],
  currentSessionIndex: 0,
  lastInput: "",
};
  172. export const useChatStore = createPersistStore(
  173. DEFAULT_CHAT_STATE,
  174. (set, _get) => {
    // Merge the raw zustand state with the methods object so internal
    // calls can use get().someMethod() as well as plain state fields.
    function get() {
      return {
        ..._get(),
        ...methods,
      };
    }
  181. const methods = {
      forkSession() {
        // Get the current session
        const currentSession = get().currentSession();
        if (!currentSession) return;

        const newSession = createEmptySession();

        // Copy topic, messages and mask; modelConfig is cloned so the
        // fork does not share config with the original session.
        newSession.topic = currentSession.topic;
        newSession.messages = [...currentSession.messages];
        newSession.mask = {
          ...currentSession.mask,
          modelConfig: {
            ...currentSession.mask.modelConfig,
          },
        };

        set((state) => ({
          currentSessionIndex: 0,
          sessions: [newSession, ...state.sessions],
        }));
      },
  200. clearSessions() {
  201. set(() => ({
  202. sessions: [createEmptySession()],
  203. currentSessionIndex: 0,
  204. }));
  205. },
  206. selectSession(index: number) {
  207. set({
  208. currentSessionIndex: index,
  209. });
  210. },
      moveSession(from: number, to: number) {
        set((state) => {
          const { sessions, currentSessionIndex: oldIndex } = state;

          // move the session
          const newSessions = [...sessions];
          const session = newSessions[from];
          newSessions.splice(from, 1);
          newSessions.splice(to, 0, session);

          // Keep the selection pointing at the same session: the moved
          // session keeps focus, and sessions between `from` and `to`
          // shift by one in the opposite direction.
          let newIndex = oldIndex === from ? to : oldIndex;
          if (oldIndex > from && oldIndex <= to) {
            newIndex -= 1;
          } else if (oldIndex < from && oldIndex >= to) {
            newIndex += 1;
          }

          return {
            currentSessionIndex: newIndex,
            sessions: newSessions,
          };
        });
      },
  232. newSession(mask?: Mask) {
  233. const session = createEmptySession();
  234. if (mask) {
  235. const config = useAppConfig.getState();
  236. const globalModelConfig = config.modelConfig;
  237. session.mask = {
  238. ...mask,
  239. modelConfig: {
  240. ...globalModelConfig,
  241. ...mask.modelConfig,
  242. },
  243. };
  244. session.topic = mask.name;
  245. }
  246. set((state) => ({
  247. currentSessionIndex: 0,
  248. sessions: [session].concat(state.sessions),
  249. }));
  250. },
  251. nextSession(delta: number) {
  252. const n = get().sessions.length;
  253. const limit = (x: number) => (x + n) % n;
  254. const i = get().currentSessionIndex;
  255. get().selectSession(limit(i + delta));
  256. },
      deleteSession(index: number) {
        const deletingLastSession = get().sessions.length === 1;
        const deletedSession = get().sessions.at(index);

        if (!deletedSession) return;

        const sessions = get().sessions.slice();
        sessions.splice(index, 1);

        // Keep the selection on the same session: shift left when a
        // session before it was removed, and clamp into range.
        const currentIndex = get().currentSessionIndex;
        let nextIndex = Math.min(
          currentIndex - Number(index < currentIndex),
          sessions.length - 1,
        );

        // Never leave the list empty.
        if (deletingLastSession) {
          nextIndex = 0;
          sessions.push(createEmptySession());
        }

        // for undo delete action
        const restoreState = {
          currentSessionIndex: get().currentSessionIndex,
          sessions: get().sessions.slice(),
        };

        set(() => ({
          currentSessionIndex: nextIndex,
          sessions,
        }));

        // Offer a 5-second window to revert the deletion.
        showToast(
          Locale.Home.DeleteToast,
          {
            text: Locale.Home.Revert,
            onClick() {
              set(() => restoreState);
            },
          },
          5000,
        );
      },
  292. currentSession() {
  293. let index = get().currentSessionIndex;
  294. const sessions = get().sessions;
  295. if (index < 0 || index >= sessions.length) {
  296. index = Math.min(sessions.length - 1, Math.max(0, index));
  297. set(() => ({ currentSessionIndex: index }));
  298. }
  299. const session = sessions[index];
  300. return session;
  301. },
      // Bookkeeping after a message is finalized: touch the session so
      // subscribers re-render, update stats, and kick off summarization.
      onNewMessage(message: ChatMessage) {
        get().updateCurrentSession((session) => {
          // concat() clones the array so zustand sees a new reference
          session.messages = session.messages.concat();
          session.lastUpdate = Date.now();
        });
        get().updateStat(message);
        get().summarizeSession();
      },
      /**
       * Handle a user submission: expand the input template, append the
       * user message plus a streaming assistant placeholder, then fire
       * the LLM request and mirror streaming/tool/error callbacks into
       * the session.
       */
      async onUserInput(content: string, attachImages?: string[]) {
        const session = get().currentSession();
        const modelConfig = session.mask.modelConfig;

        const userContent = fillTemplateWith(content, modelConfig);
        console.log("[User Input] after template: ", userContent);

        // With attachments, upgrade the plain string to a multimodal
        // list: one text part followed by one image_url part per image.
        let mContent: string | MultimodalContent[] = userContent;
        if (attachImages && attachImages.length > 0) {
          mContent = [
            {
              type: "text",
              text: userContent,
            },
          ];
          mContent = mContent.concat(
            attachImages.map((url) => {
              return {
                type: "image_url",
                image_url: {
                  url: url,
                },
              };
            }),
          );
        }

        let userMessage: ChatMessage = createMessage({
          role: "user",
          content: mContent,
        });

        // Placeholder assistant message, filled in as the reply streams.
        const botMessage: ChatMessage = createMessage({
          role: "assistant",
          streaming: true,
          model: modelConfig.model,
        });

        // get recent messages
        const recentMessages = get().getMessagesWithMemory();
        const sendMessages = recentMessages.concat(userMessage);
        const messageIndex = get().currentSession().messages.length + 1;

        // save user's and bot's message
        get().updateCurrentSession((session) => {
          const savedUserMessage = {
            ...userMessage,
            content: mContent,
          };
          session.messages = session.messages.concat([
            savedUserMessage,
            botMessage,
          ]);
        });

        const api: ClientApi = getClientApi(modelConfig.providerName);

        // make request
        api.llm.chat({
          messages: sendMessages,
          config: { ...modelConfig, stream: true },
          onUpdate(message) {
            botMessage.streaming = true;
            if (message) {
              botMessage.content = message;
            }
            // concat() clones the array so subscribers re-render
            get().updateCurrentSession((session) => {
              session.messages = session.messages.concat();
            });
          },
          onFinish(message) {
            botMessage.streaming = false;
            if (message) {
              botMessage.content = message;
              get().onNewMessage(botMessage);
            }
            ChatControllerPool.remove(session.id, botMessage.id);
          },
          onBeforeTool(tool: ChatMessageTool) {
            // Record the pending tool call on the streaming message.
            (botMessage.tools = botMessage?.tools || []).push(tool);
            get().updateCurrentSession((session) => {
              session.messages = session.messages.concat();
            });
          },
          onAfterTool(tool: ChatMessageTool) {
            // Replace the matching pending tool entry with its result.
            botMessage?.tools?.forEach((t, i, tools) => {
              if (tool.id == t.id) {
                tools[i] = { ...tool };
              }
            });
            get().updateCurrentSession((session) => {
              session.messages = session.messages.concat();
            });
          },
          onError(error) {
            const isAborted = error.message?.includes?.("aborted");
            // Surface the error inline in the bot message; user-initiated
            // aborts are not flagged as errors.
            botMessage.content +=
              "\n\n" +
              prettyObject({
                error: true,
                message: error.message,
              });
            botMessage.streaming = false;
            userMessage.isError = !isAborted;
            botMessage.isError = !isAborted;
            get().updateCurrentSession((session) => {
              session.messages = session.messages.concat();
            });
            ChatControllerPool.remove(
              session.id,
              botMessage.id ?? messageIndex,
            );

            console.error("[Chat] failed ", error);
          },
          onController(controller) {
            // collect controller for stop/retry
            ChatControllerPool.addController(
              session.id,
              botMessage.id ?? messageIndex,
              controller,
            );
          },
        });
      },
  426. getMemoryPrompt() {
  427. const session = get().currentSession();
  428. if (session.memoryPrompt.length) {
  429. return {
  430. role: "system",
  431. content: Locale.Store.Prompt.History(session.memoryPrompt),
  432. date: "",
  433. } as ChatMessage;
  434. }
  435. },
      /**
       * Assemble the request context: injected system prompt (gpt-family
       * only), long-term memory summary, the mask's in-context prompts,
       * and as many recent messages as the token budget allows.
       */
      getMessagesWithMemory() {
        const session = get().currentSession();
        const modelConfig = session.mask.modelConfig;
        const clearContextIndex = session.clearContextIndex ?? 0;
        const messages = session.messages.slice();
        const totalMessageCount = session.messages.length;

        // in-context prompts
        const contextPrompts = session.mask.context.slice();

        // system prompts, to get close to OpenAI Web ChatGPT
        // (only injected for gpt-* / chatgpt-* models)
        const shouldInjectSystemPrompts =
          modelConfig.enableInjectSystemPrompts &&
          (session.mask.modelConfig.model.startsWith("gpt-") ||
            session.mask.modelConfig.model.startsWith("chatgpt-"));

        var systemPrompts: ChatMessage[] = [];
        systemPrompts = shouldInjectSystemPrompts
          ? [
              createMessage({
                role: "system",
                content: fillTemplateWith("", {
                  ...modelConfig,
                  template: DEFAULT_SYSTEM_TEMPLATE,
                }),
              }),
            ]
          : [];
        if (shouldInjectSystemPrompts) {
          console.log(
            "[Global System Prompt] ",
            systemPrompts.at(0)?.content ?? "empty",
          );
        }

        const memoryPrompt = get().getMemoryPrompt();
        // long term memory: only sent when enabled, present, and not
        // invalidated by a later "clear context" marker.
        const shouldSendLongTermMemory =
          modelConfig.sendMemory &&
          session.memoryPrompt &&
          session.memoryPrompt.length > 0 &&
          session.lastSummarizeIndex > clearContextIndex;
        const longTermMemoryPrompts =
          shouldSendLongTermMemory && memoryPrompt ? [memoryPrompt] : [];
        const longTermMemoryStartIndex = session.lastSummarizeIndex;

        // short term memory
        const shortTermMemoryStartIndex = Math.max(
          0,
          totalMessageCount - modelConfig.historyMessageCount,
        );

        // lets concat send messages, including 4 parts:
        // 0. system prompt: to get close to OpenAI Web ChatGPT
        // 1. long term memory: summarized memory messages
        // 2. pre-defined in-context prompts
        // 3. short term memory: latest n messages
        // 4. newest input message
        const memoryStartIndex = shouldSendLongTermMemory
          ? Math.min(longTermMemoryStartIndex, shortTermMemoryStartIndex)
          : shortTermMemoryStartIndex;
        // and if user has cleared history messages, we should exclude the memory too.
        const contextStartIndex = Math.max(clearContextIndex, memoryStartIndex);
        const maxTokenThreshold = modelConfig.max_tokens;

        // get recent messages as much as possible: walk backwards until
        // the token budget runs out, skipping error messages.
        const reversedRecentMessages = [];
        for (
          let i = totalMessageCount - 1, tokenCount = 0;
          i >= contextStartIndex && tokenCount < maxTokenThreshold;
          i -= 1
        ) {
          const msg = messages[i];
          if (!msg || msg.isError) continue;
          tokenCount += estimateTokenLength(getMessageTextContent(msg));
          reversedRecentMessages.push(msg);
        }

        // concat all messages
        const recentMessages = [
          ...systemPrompts,
          ...longTermMemoryPrompts,
          ...contextPrompts,
          ...reversedRecentMessages.reverse(),
        ];

        return recentMessages;
      },
      // Run `updater` against the message at [sessionIndex][messageIndex]
      // (which may be undefined), then re-set sessions so the change is
      // picked up and persisted.
      updateMessage(
        sessionIndex: number,
        messageIndex: number,
        updater: (message?: ChatMessage) => void,
      ) {
        const sessions = get().sessions;
        const session = sessions.at(sessionIndex);
        const messages = session?.messages;
        updater(messages?.at(messageIndex));
        set(() => ({ sessions }));
      },
      // Wipe the current session's messages and its summarized memory.
      resetSession() {
        get().updateCurrentSession((session) => {
          session.messages = [];
          session.memoryPrompt = "";
        });
      },
  532. summarizeSession(refreshTitle: boolean = false) {
  533. const config = useAppConfig.getState();
  534. const session = get().currentSession();
  535. const modelConfig = session.mask.modelConfig;
  536. // skip summarize when using dalle3?
  537. if (isDalle3(modelConfig.model)) {
  538. return;
  539. }
  540. // if not config compressModel, then using getSummarizeModel
  541. const [model, providerName] = modelConfig.compressModel
  542. ? [modelConfig.compressModel, modelConfig.compressProviderName]
  543. : getSummarizeModel(
  544. session.mask.modelConfig.model,
  545. session.mask.modelConfig.providerName,
  546. );
  547. const api: ClientApi = getClientApi(providerName as ServiceProvider);
  548. // remove error messages if any
  549. const messages = session.messages;
  550. // should summarize topic after chating more than 50 words
  551. const SUMMARIZE_MIN_LEN = 50;
  552. if (
  553. (config.enableAutoGenerateTitle &&
  554. session.topic === DEFAULT_TOPIC &&
  555. countMessages(messages) >= SUMMARIZE_MIN_LEN) ||
  556. refreshTitle
  557. ) {
  558. const startIndex = Math.max(
  559. 0,
  560. messages.length - modelConfig.historyMessageCount,
  561. );
  562. const topicMessages = messages
  563. .slice(
  564. startIndex < messages.length ? startIndex : messages.length - 1,
  565. messages.length,
  566. )
  567. .concat(
  568. createMessage({
  569. role: "user",
  570. content: Locale.Store.Prompt.Topic,
  571. }),
  572. );
  573. api.llm.chat({
  574. messages: topicMessages,
  575. config: {
  576. model,
  577. stream: false,
  578. providerName,
  579. },
  580. onFinish(message) {
  581. if (!isValidMessage(message)) return;
  582. get().updateCurrentSession(
  583. (session) =>
  584. (session.topic =
  585. message.length > 0 ? trimTopic(message) : DEFAULT_TOPIC),
  586. );
  587. },
  588. });
  589. }
  590. const summarizeIndex = Math.max(
  591. session.lastSummarizeIndex,
  592. session.clearContextIndex ?? 0,
  593. );
  594. let toBeSummarizedMsgs = messages
  595. .filter((msg) => !msg.isError)
  596. .slice(summarizeIndex);
  597. const historyMsgLength = countMessages(toBeSummarizedMsgs);
  598. if (historyMsgLength > modelConfig?.max_tokens ?? 4000) {
  599. const n = toBeSummarizedMsgs.length;
  600. toBeSummarizedMsgs = toBeSummarizedMsgs.slice(
  601. Math.max(0, n - modelConfig.historyMessageCount),
  602. );
  603. }
  604. const memoryPrompt = get().getMemoryPrompt();
  605. if (memoryPrompt) {
  606. // add memory prompt
  607. toBeSummarizedMsgs.unshift(memoryPrompt);
  608. }
  609. const lastSummarizeIndex = session.messages.length;
  610. console.log(
  611. "[Chat History] ",
  612. toBeSummarizedMsgs,
  613. historyMsgLength,
  614. modelConfig.compressMessageLengthThreshold,
  615. );
  616. if (
  617. historyMsgLength > modelConfig.compressMessageLengthThreshold &&
  618. modelConfig.sendMemory
  619. ) {
  620. /** Destruct max_tokens while summarizing
  621. * this param is just shit
  622. **/
  623. const { max_tokens, ...modelcfg } = modelConfig;
  624. api.llm.chat({
  625. messages: toBeSummarizedMsgs.concat(
  626. createMessage({
  627. role: "system",
  628. content: Locale.Store.Prompt.Summarize,
  629. date: "",
  630. }),
  631. ),
  632. config: {
  633. ...modelcfg,
  634. stream: true,
  635. model,
  636. providerName,
  637. },
  638. onUpdate(message) {
  639. session.memoryPrompt = message;
  640. },
  641. onFinish(message) {
  642. console.log("[Memory] ", message);
  643. get().updateCurrentSession((session) => {
  644. session.lastSummarizeIndex = lastSummarizeIndex;
  645. session.memoryPrompt = message; // Update the memory prompt for stored it in local storage
  646. });
  647. },
  648. onError(err) {
  649. console.error("[Summarize] ", err);
  650. },
  651. });
  652. }
  653. function isValidMessage(message: any): boolean {
  654. return typeof message === "string" && !message.startsWith("```json");
  655. }
  656. },
  657. updateStat(message: ChatMessage) {
  658. get().updateCurrentSession((session) => {
  659. session.stat.charCount += message.content.length;
  660. // TODO: should update chat count and word count
  661. });
  662. },
      // Mutate the current session via `updater`, then re-set sessions to
      // notify subscribers and trigger persistence.
      updateCurrentSession(updater: (session: ChatSession) => void) {
        const sessions = get().sessions;
        const index = get().currentSessionIndex;
        updater(sessions[index]);
        set(() => ({ sessions }));
      },
      // Wipe all persisted data (IndexedDB + localStorage) and reload the
      // page so the app restarts from defaults.
      async clearAllData() {
        await indexedDBStorage.clear();
        localStorage.clear();
        location.reload();
      },
      // Remember the last text the user typed (restored on revisit).
      setLastInput(lastInput: string) {
        set({
          lastInput,
        });
      },
  679. };
  680. return methods;
  681. },
  682. {
  683. name: StoreKey.Chat,
  684. version: 3.3,
    // Upgrade persisted state from older store versions, step by step.
    migrate(persistedState, version) {
      const state = persistedState as any;
      // Deep-clone so migrations never mutate the persisted snapshot.
      const newState = JSON.parse(
        JSON.stringify(state),
      ) as typeof DEFAULT_CHAT_STATE;

      if (version < 2) {
        // v2: rebuild each session on fresh defaults, keeping only its
        // topic and messages.
        newState.sessions = [];

        const oldSessions = state.sessions;
        for (const oldSession of oldSessions) {
          const newSession = createEmptySession();
          newSession.topic = oldSession.topic;
          newSession.messages = [...oldSession.messages];
          newSession.mask.modelConfig.sendMemory = true;
          newSession.mask.modelConfig.historyMessageCount = 4;
          newSession.mask.modelConfig.compressMessageLengthThreshold = 1000;
          newState.sessions.push(newSession);
        }
      }

      if (version < 3) {
        // migrate id to nanoid
        newState.sessions.forEach((s) => {
          s.id = nanoid();
          s.messages.forEach((m) => (m.id = nanoid()));
        });
      }

      // Enable `enableInjectSystemPrompts` attribute for old sessions.
      // Resolve issue of old sessions not automatically enabling.
      if (version < 3.1) {
        newState.sessions.forEach((s) => {
          if (
            // Exclude those already set by user
            !s.mask.modelConfig.hasOwnProperty("enableInjectSystemPrompts")
          ) {
            // Because users may have changed this configuration,
            // the user's current configuration is used instead of the default
            const config = useAppConfig.getState();
            s.mask.modelConfig.enableInjectSystemPrompts =
              config.modelConfig.enableInjectSystemPrompts;
          }
        });
      }

      // add default summarize model for every session
      if (version < 3.2) {
        newState.sessions.forEach((s) => {
          const config = useAppConfig.getState();
          s.mask.modelConfig.compressModel = config.modelConfig.compressModel;
          s.mask.modelConfig.compressProviderName =
            config.modelConfig.compressProviderName;
        });
      }

      // revert default summarize model for every session
      if (version < 3.3) {
        newState.sessions.forEach((s) => {
          const config = useAppConfig.getState();
          s.mask.modelConfig.compressModel = "";
          s.mask.modelConfig.compressProviderName = "";
        });
      }

      return newState as any;
    },
  745. },
  746. );