chat.ts 25 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845
  1. import { getMessageTextContent, trimTopic } from "../utils";
  2. import { indexedDBStorage } from "@/app/utils/indexedDB-storage";
  3. import { nanoid } from "nanoid";
  4. import type {
  5. ClientApi,
  6. MultimodalContent,
  7. RequestMessage,
  8. } from "../client/api";
  9. import { getClientApi } from "../client/api";
  10. import { ChatControllerPool } from "../client/controller";
  11. import { showToast } from "../components/ui-lib";
  12. import {
  13. DEFAULT_INPUT_TEMPLATE,
  14. DEFAULT_MODELS,
  15. DEFAULT_SYSTEM_TEMPLATE,
  16. KnowledgeCutOffDate,
  17. StoreKey,
  18. SUMMARIZE_MODEL,
  19. GEMINI_SUMMARIZE_MODEL,
  20. ServiceProvider,
  21. } from "../constant";
  22. import Locale, { getLang } from "../locales";
  23. import { isDalle3, safeLocalStorage } from "../utils";
  24. import { prettyObject } from "../utils/format";
  25. import { createPersistStore } from "../utils/store";
  26. import { estimateTokenLength } from "../utils/token";
  27. import { ModelConfig, ModelType, useAppConfig } from "./config";
  28. import { useAccessStore } from "./access";
  29. import { collectModelsWithDefaultModel } from "../utils/model";
  30. import { createEmptyMask, Mask } from "./mask";
// Guarded localStorage accessor (safe to call when browser storage is unavailable).
const localStorage = safeLocalStorage();
// A tool (function) call attached to an assistant message: the invocation
// details plus, once executed, its result or error.
export type ChatMessageTool = {
  id: string;
  index?: number;
  type?: string;
  function?: {
    name: string;
    arguments?: string; // serialized argument payload as provided by the model
  };
  content?: string; // tool execution result, if any
  isError?: boolean;
  errorMsg?: string;
};
// A message as stored in a session: the wire-format RequestMessage plus
// client-side bookkeeping (id, timestamp, streaming/error flags, tool calls).
export type ChatMessage = RequestMessage & {
  date: string;
  streaming?: boolean; // true while the assistant reply is still streaming in
  isError?: boolean;
  id: string;
  model?: ModelType; // model that produced this message (assistant side)
  tools?: ChatMessageTool[];
  audio_url?: string;
};
  53. export function createMessage(override: Partial<ChatMessage>): ChatMessage {
  54. return {
  55. id: nanoid(),
  56. date: new Date().toLocaleString(),
  57. role: "user",
  58. content: "",
  59. ...override,
  60. };
  61. }
// Aggregate counters for a session. Only charCount is currently updated
// (token/word counts are a pending TODO in updateStat).
export interface ChatStat {
  tokenCount: number;
  wordCount: number;
  charCount: number;
}
// One conversation: its messages plus summarization bookkeeping and the
// mask (persona + model config) it runs under.
export interface ChatSession {
  id: string;
  topic: string;
  memoryPrompt: string; // rolling summary of older messages
  messages: ChatMessage[];
  stat: ChatStat;
  lastUpdate: number; // epoch millis of the latest activity
  lastSummarizeIndex: number; // messages before this index are covered by memoryPrompt
  clearContextIndex?: number; // user "clear context" marker; earlier messages are excluded
  mask: Mask;
}
export const DEFAULT_TOPIC = Locale.Store.DefaultTopic;
// Canned assistant greeting used for brand-new sessions.
export const BOT_HELLO: ChatMessage = createMessage({
  role: "assistant",
  content: Locale.Store.BotHello,
});
  83. function createEmptySession(): ChatSession {
  84. return {
  85. id: nanoid(),
  86. topic: DEFAULT_TOPIC,
  87. memoryPrompt: "",
  88. messages: [],
  89. stat: {
  90. tokenCount: 0,
  91. wordCount: 0,
  92. charCount: 0,
  93. },
  94. lastUpdate: Date.now(),
  95. lastSummarizeIndex: 0,
  96. mask: createEmptyMask(),
  97. };
  98. }
  99. function getSummarizeModel(
  100. currentModel: string,
  101. providerName: string,
  102. ): string[] {
  103. // if it is using gpt-* models, force to use 4o-mini to summarize
  104. if (currentModel.startsWith("gpt") || currentModel.startsWith("chatgpt")) {
  105. const configStore = useAppConfig.getState();
  106. const accessStore = useAccessStore.getState();
  107. const allModel = collectModelsWithDefaultModel(
  108. configStore.models,
  109. [configStore.customModels, accessStore.customModels].join(","),
  110. accessStore.defaultModel,
  111. );
  112. const summarizeModel = allModel.find(
  113. (m) => m.name === SUMMARIZE_MODEL && m.available,
  114. );
  115. if (summarizeModel) {
  116. return [
  117. summarizeModel.name,
  118. summarizeModel.provider?.providerName as string,
  119. ];
  120. }
  121. }
  122. if (currentModel.startsWith("gemini")) {
  123. return [GEMINI_SUMMARIZE_MODEL, ServiceProvider.Google];
  124. }
  125. return [currentModel, providerName];
  126. }
  127. function countMessages(msgs: ChatMessage[]) {
  128. return msgs.reduce(
  129. (pre, cur) => pre + estimateTokenLength(getMessageTextContent(cur)),
  130. 0,
  131. );
  132. }
  133. function fillTemplateWith(input: string, modelConfig: ModelConfig) {
  134. const cutoff =
  135. KnowledgeCutOffDate[modelConfig.model] ?? KnowledgeCutOffDate.default;
  136. // Find the model in the DEFAULT_MODELS array that matches the modelConfig.model
  137. const modelInfo = DEFAULT_MODELS.find((m) => m.name === modelConfig.model);
  138. var serviceProvider = "OpenAI";
  139. if (modelInfo) {
  140. // TODO: auto detect the providerName from the modelConfig.model
  141. // Directly use the providerName from the modelInfo
  142. serviceProvider = modelInfo.provider.providerName;
  143. }
  144. const vars = {
  145. ServiceProvider: serviceProvider,
  146. cutoff,
  147. model: modelConfig.model,
  148. time: new Date().toString(),
  149. lang: getLang(),
  150. input: input,
  151. };
  152. let output = modelConfig.template ?? DEFAULT_INPUT_TEMPLATE;
  153. // remove duplicate
  154. if (input.startsWith(output)) {
  155. output = "";
  156. }
  157. // must contains {{input}}
  158. const inputVar = "{{input}}";
  159. if (!output.includes(inputVar)) {
  160. output += "\n" + inputVar;
  161. }
  162. Object.entries(vars).forEach(([name, value]) => {
  163. const regex = new RegExp(`{{${name}}}`, "g");
  164. output = output.replace(regex, value.toString()); // Ensure value is a string
  165. });
  166. return output;
  167. }
// Initial persisted chat state: one empty session, selected, with no
// remembered last input.
const DEFAULT_CHAT_STATE = {
  sessions: [createEmptySession()],
  currentSessionIndex: 0,
  lastInput: "",
};
  173. export const useChatStore = createPersistStore(
  174. DEFAULT_CHAT_STATE,
  175. (set, _get) => {
  176. function get() {
  177. return {
  178. ..._get(),
  179. ...methods,
  180. };
  181. }
  182. const methods = {
  183. forkSession() {
  184. // 获取当前会话
  185. const currentSession = get().currentSession();
  186. if (!currentSession) return;
  187. const newSession = createEmptySession();
  188. newSession.topic = currentSession.topic;
  189. // 深拷贝消息
  190. newSession.messages = currentSession.messages.map(msg => ({
  191. ...msg,
  192. id: nanoid(), // 生成新的消息 ID
  193. }));
  194. newSession.mask = {
  195. ...currentSession.mask,
  196. modelConfig: {
  197. ...currentSession.mask.modelConfig,
  198. },
  199. };
  200. set((state) => ({
  201. currentSessionIndex: 0,
  202. sessions: [newSession, ...state.sessions],
  203. }));
  204. },
  205. clearSessions() {
  206. set(() => ({
  207. sessions: [createEmptySession()],
  208. currentSessionIndex: 0,
  209. }));
  210. },
  211. selectSession(index: number) {
  212. set({
  213. currentSessionIndex: index,
  214. });
  215. },
  216. moveSession(from: number, to: number) {
  217. set((state) => {
  218. const { sessions, currentSessionIndex: oldIndex } = state;
  219. // move the session
  220. const newSessions = [...sessions];
  221. const session = newSessions[from];
  222. newSessions.splice(from, 1);
  223. newSessions.splice(to, 0, session);
  224. // modify current session id
  225. let newIndex = oldIndex === from ? to : oldIndex;
  226. if (oldIndex > from && oldIndex <= to) {
  227. newIndex -= 1;
  228. } else if (oldIndex < from && oldIndex >= to) {
  229. newIndex += 1;
  230. }
  231. return {
  232. currentSessionIndex: newIndex,
  233. sessions: newSessions,
  234. };
  235. });
  236. },
  237. newSession(mask?: Mask) {
  238. const session = createEmptySession();
  239. if (mask) {
  240. const config = useAppConfig.getState();
  241. const globalModelConfig = config.modelConfig;
  242. session.mask = {
  243. ...mask,
  244. modelConfig: {
  245. ...globalModelConfig,
  246. ...mask.modelConfig,
  247. },
  248. };
  249. session.topic = mask.name;
  250. }
  251. set((state) => ({
  252. currentSessionIndex: 0,
  253. sessions: [session].concat(state.sessions),
  254. }));
  255. },
  256. nextSession(delta: number) {
  257. const n = get().sessions.length;
  258. const limit = (x: number) => (x + n) % n;
  259. const i = get().currentSessionIndex;
  260. get().selectSession(limit(i + delta));
  261. },
  262. deleteSession(index: number) {
  263. const deletingLastSession = get().sessions.length === 1;
  264. const deletedSession = get().sessions.at(index);
  265. if (!deletedSession) return;
  266. const sessions = get().sessions.slice();
  267. sessions.splice(index, 1);
  268. const currentIndex = get().currentSessionIndex;
  269. let nextIndex = Math.min(
  270. currentIndex - Number(index < currentIndex),
  271. sessions.length - 1,
  272. );
  273. if (deletingLastSession) {
  274. nextIndex = 0;
  275. sessions.push(createEmptySession());
  276. }
  277. // for undo delete action
  278. const restoreState = {
  279. currentSessionIndex: get().currentSessionIndex,
  280. sessions: get().sessions.slice(),
  281. };
  282. set(() => ({
  283. currentSessionIndex: nextIndex,
  284. sessions,
  285. }));
  286. showToast(
  287. Locale.Home.DeleteToast,
  288. {
  289. text: Locale.Home.Revert,
  290. onClick() {
  291. set(() => restoreState);
  292. },
  293. },
  294. 5000,
  295. );
  296. },
  297. currentSession() {
  298. let index = get().currentSessionIndex;
  299. const sessions = get().sessions;
  300. if (index < 0 || index >= sessions.length) {
  301. index = Math.min(sessions.length - 1, Math.max(0, index));
  302. set(() => ({ currentSessionIndex: index }));
  303. }
  304. const session = sessions[index];
  305. return session;
  306. },
      /**
       * Record that `message` arrived in `targetSession`: bump the session's
       * timestamp, refresh stats, and possibly trigger auto-summarization.
       */
      onNewMessage(message: ChatMessage, targetSession: ChatSession) {
        get().updateTargetSession(targetSession, (session) => {
          // concat() clones the array so the store notices the change.
          session.messages = session.messages.concat();
          session.lastUpdate = Date.now();
        });
        get().updateStat(message, targetSession);
        get().summarizeSession(false, targetSession);
      },
      /**
       * Handle a user submission: render the input through the prompt
       * template, attach any images as multimodal content, append the user
       * and (streaming) bot messages to the current session, and fire the
       * LLM request, wiring up streaming/tool/error callbacks.
       */
      async onUserInput(content: string, attachImages?: string[]) {
        const session = get().currentSession();
        const modelConfig = session.mask.modelConfig;

        // Expand the configured input template around the raw text.
        const userContent = fillTemplateWith(content, modelConfig);
        console.log("[User Input] after template: ", userContent);

        // Plain string, or multimodal parts when images are attached.
        let mContent: string | MultimodalContent[] = userContent;

        if (attachImages && attachImages.length > 0) {
          mContent = [
            ...(userContent
              ? [{ type: "text" as const, text: userContent }]
              : []),
            ...attachImages.map((url) => ({
              type: "image_url" as const,
              image_url: { url },
            })),
          ];
        }

        let userMessage: ChatMessage = createMessage({
          role: "user",
          content: mContent,
        });

        // Placeholder bot message that the stream callbacks mutate in place.
        const botMessage: ChatMessage = createMessage({
          role: "assistant",
          streaming: true,
          model: modelConfig.model,
        });

        // get recent messages: memory + history + the new user message
        const recentMessages = get().getMessagesWithMemory();
        const sendMessages = recentMessages.concat(userMessage);
        const messageIndex = session.messages.length + 1;

        // save user's and bot's message before the request starts
        get().updateTargetSession(session, (session) => {
          const savedUserMessage = {
            ...userMessage,
            content: mContent,
          };
          session.messages = session.messages.concat([
            savedUserMessage,
            botMessage,
          ]);
        });

        const api: ClientApi = getClientApi(modelConfig.providerName);
        // make request
        api.llm.chat({
          messages: sendMessages,
          config: { ...modelConfig, stream: true },
          onUpdate(message) {
            // Streaming chunk: replace the partial content and re-render.
            botMessage.streaming = true;
            if (message) {
              botMessage.content = message;
            }
            get().updateTargetSession(session, (session) => {
              session.messages = session.messages.concat();
            });
          },
          onFinish(message) {
            botMessage.streaming = false;
            if (message) {
              botMessage.content = message;
              botMessage.date = new Date().toLocaleString();
              get().onNewMessage(botMessage, session);
            }
            ChatControllerPool.remove(session.id, botMessage.id);
          },
          onBeforeTool(tool: ChatMessageTool) {
            // Track the pending tool call on the bot message.
            (botMessage.tools = botMessage?.tools || []).push(tool);
            get().updateTargetSession(session, (session) => {
              session.messages = session.messages.concat();
            });
          },
          onAfterTool(tool: ChatMessageTool) {
            // Replace the matching pending tool entry with its result.
            botMessage?.tools?.forEach((t, i, tools) => {
              if (tool.id == t.id) {
                tools[i] = { ...tool };
              }
            });
            get().updateTargetSession(session, (session) => {
              session.messages = session.messages.concat();
            });
          },
          onError(error) {
            const isAborted = error.message?.includes?.("aborted");
            // Append error details to the bot message; deliberate aborts
            // are not flagged as errors.
            botMessage.content +=
              "\n\n" +
              prettyObject({
                error: true,
                message: error.message,
              });
            botMessage.streaming = false;
            userMessage.isError = !isAborted;
            botMessage.isError = !isAborted;
            get().updateTargetSession(session, (session) => {
              session.messages = session.messages.concat();
            });
            ChatControllerPool.remove(
              session.id,
              botMessage.id ?? messageIndex,
            );
            console.error("[Chat] failed ", error);
          },
          onController(controller) {
            // collect controller for stop/retry
            ChatControllerPool.addController(
              session.id,
              botMessage.id ?? messageIndex,
              controller,
            );
          },
        });
      },
  425. getMemoryPrompt() {
  426. const session = get().currentSession();
  427. if (session.memoryPrompt.length) {
  428. return {
  429. role: "system",
  430. content: Locale.Store.Prompt.History(session.memoryPrompt),
  431. date: "",
  432. } as ChatMessage;
  433. }
  434. },
      /**
       * Assemble the message list sent to the model, combining (in order):
       * injected system prompt, summarized long-term memory, the mask's
       * in-context prompts, and as many recent messages as fit under the
       * token budget — all bounded by the user's "clear context" marker.
       */
      getMessagesWithMemory() {
        const session = get().currentSession();
        const modelConfig = session.mask.modelConfig;
        const clearContextIndex = session.clearContextIndex ?? 0;
        const messages = session.messages.slice();
        const totalMessageCount = session.messages.length;

        // in-context prompts
        const contextPrompts = session.mask.context.slice();

        // system prompts, to get close to OpenAI Web ChatGPT
        // (only injected for gpt-* / chatgpt-* models)
        const shouldInjectSystemPrompts =
          modelConfig.enableInjectSystemPrompts &&
          (session.mask.modelConfig.model.startsWith("gpt-") ||
            session.mask.modelConfig.model.startsWith("chatgpt-"));

        var systemPrompts: ChatMessage[] = [];
        systemPrompts = shouldInjectSystemPrompts
          ? [
              createMessage({
                role: "system",
                content: fillTemplateWith("", {
                  ...modelConfig,
                  template: DEFAULT_SYSTEM_TEMPLATE,
                }),
              }),
            ]
          : [];
        if (shouldInjectSystemPrompts) {
          console.log(
            "[Global System Prompt] ",
            systemPrompts.at(0)?.content ?? "empty",
          );
        }
        const memoryPrompt = get().getMemoryPrompt();
        // long term memory: only sent when enabled, present, and not
        // invalidated by a later "clear context".
        const shouldSendLongTermMemory =
          modelConfig.sendMemory &&
          session.memoryPrompt &&
          session.memoryPrompt.length > 0 &&
          session.lastSummarizeIndex > clearContextIndex;
        const longTermMemoryPrompts =
          shouldSendLongTermMemory && memoryPrompt ? [memoryPrompt] : [];
        const longTermMemoryStartIndex = session.lastSummarizeIndex;

        // short term memory: the latest historyMessageCount messages
        const shortTermMemoryStartIndex = Math.max(
          0,
          totalMessageCount - modelConfig.historyMessageCount,
        );

        // lets concat send messages, including 4 parts:
        // 0. system prompt: to get close to OpenAI Web ChatGPT
        // 1. long term memory: summarized memory messages
        // 2. pre-defined in-context prompts
        // 3. short term memory: latest n messages
        // 4. newest input message
        const memoryStartIndex = shouldSendLongTermMemory
          ? Math.min(longTermMemoryStartIndex, shortTermMemoryStartIndex)
          : shortTermMemoryStartIndex;
        // and if user has cleared history messages, we should exclude the memory too.
        const contextStartIndex = Math.max(clearContextIndex, memoryStartIndex);
        const maxTokenThreshold = modelConfig.max_tokens;

        // get recent messages as much as possible: walk backwards until the
        // token estimate exceeds the budget, skipping errored messages
        const reversedRecentMessages = [];
        for (
          let i = totalMessageCount - 1, tokenCount = 0;
          i >= contextStartIndex && tokenCount < maxTokenThreshold;
          i -= 1
        ) {
          const msg = messages[i];
          if (!msg || msg.isError) continue;
          tokenCount += estimateTokenLength(getMessageTextContent(msg));
          reversedRecentMessages.push(msg);
        }
        // concat all messages (recent ones back in chronological order)
        const recentMessages = [
          ...systemPrompts,
          ...longTermMemoryPrompts,
          ...contextPrompts,
          ...reversedRecentMessages.reverse(),
        ];
        return recentMessages;
      },
  514. updateMessage(
  515. sessionIndex: number,
  516. messageIndex: number,
  517. updater: (message?: ChatMessage) => void,
  518. ) {
  519. const sessions = get().sessions;
  520. const session = sessions.at(sessionIndex);
  521. const messages = session?.messages;
  522. updater(messages?.at(messageIndex));
  523. set(() => ({ sessions }));
  524. },
  525. resetSession(session: ChatSession) {
  526. get().updateTargetSession(session, (session) => {
  527. session.messages = [];
  528. session.memoryPrompt = "";
  529. });
  530. },
      /**
       * Auto-generate the session topic (when untitled or `refreshTitle`
       * is set) and compress older history into `memoryPrompt` once it
       * grows past the configured threshold. Both steps use the dedicated
       * compress/summarize model when one is configured.
       */
      summarizeSession(
        refreshTitle: boolean = false,
        targetSession: ChatSession,
      ) {
        const config = useAppConfig.getState();
        const session = targetSession;
        const modelConfig = session.mask.modelConfig;
        // skip summarize when using dalle3 (image model; nothing to summarize)
        if (isDalle3(modelConfig.model)) {
          return;
        }

        // if compressModel is not configured, fall back to getSummarizeModel
        const [model, providerName] = modelConfig.compressModel
          ? [modelConfig.compressModel, modelConfig.compressProviderName]
          : getSummarizeModel(
              session.mask.modelConfig.model,
              session.mask.modelConfig.providerName,
            );
        const api: ClientApi = getClientApi(providerName as ServiceProvider);

        // full message list; error messages are filtered out further below
        // when building the memory-summary slice
        const messages = session.messages;

        // should summarize topic after chatting more than 50 words
        const SUMMARIZE_MIN_LEN = 50;
        if (
          (config.enableAutoGenerateTitle &&
            session.topic === DEFAULT_TOPIC &&
            countMessages(messages) >= SUMMARIZE_MIN_LEN) ||
          refreshTitle
        ) {
          // Ask the model for a topic based on the last historyMessageCount
          // messages plus the topic prompt.
          const startIndex = Math.max(
            0,
            messages.length - modelConfig.historyMessageCount,
          );
          const topicMessages = messages
            .slice(
              startIndex < messages.length ? startIndex : messages.length - 1,
              messages.length,
            )
            .concat(
              createMessage({
                role: "user",
                content: Locale.Store.Prompt.Topic,
              }),
            );
          api.llm.chat({
            messages: topicMessages,
            config: {
              model,
              stream: false,
              providerName,
            },
            onFinish(message, responseRes) {
              // Only accept the generated topic on a successful response.
              if (responseRes?.status === 200) {
                get().updateTargetSession(
                  session,
                  (session) =>
                    (session.topic =
                      message.length > 0 ? trimTopic(message) : DEFAULT_TOPIC),
                );
              }
            },
          });
        }
        // Summarize only messages not yet covered by the last summary and
        // not cleared from context.
        const summarizeIndex = Math.max(
          session.lastSummarizeIndex,
          session.clearContextIndex ?? 0,
        );
        let toBeSummarizedMsgs = messages
          .filter((msg) => !msg.isError)
          .slice(summarizeIndex);

        const historyMsgLength = countMessages(toBeSummarizedMsgs);

        // If the candidate slice is still over the token budget, keep only
        // the most recent historyMessageCount messages.
        if (historyMsgLength > (modelConfig?.max_tokens || 4000)) {
          const n = toBeSummarizedMsgs.length;
          toBeSummarizedMsgs = toBeSummarizedMsgs.slice(
            Math.max(0, n - modelConfig.historyMessageCount),
          );
        }
        const memoryPrompt = get().getMemoryPrompt();
        if (memoryPrompt) {
          // Prepend the existing memory so the new summary builds on it.
          toBeSummarizedMsgs.unshift(memoryPrompt);
        }

        const lastSummarizeIndex = session.messages.length;

        console.log(
          "[Chat History] ",
          toBeSummarizedMsgs,
          historyMsgLength,
          modelConfig.compressMessageLengthThreshold,
        );

        if (
          historyMsgLength > modelConfig.compressMessageLengthThreshold &&
          modelConfig.sendMemory
        ) {
          // Exclude max_tokens from the summarize request config — it is
          // not a useful cap for this request.
          const { max_tokens, ...modelcfg } = modelConfig;
          api.llm.chat({
            messages: toBeSummarizedMsgs.concat(
              createMessage({
                role: "system",
                content: Locale.Store.Prompt.Summarize,
                date: "",
              }),
            ),
            config: {
              ...modelcfg,
              stream: true,
              model,
              providerName,
            },
            onUpdate(message) {
              // Stream the in-progress summary straight into the session.
              session.memoryPrompt = message;
            },
            onFinish(message, responseRes) {
              if (responseRes?.status === 200) {
                console.log("[Memory] ", message);
                get().updateTargetSession(session, (session) => {
                  session.lastSummarizeIndex = lastSummarizeIndex;
                  session.memoryPrompt = message; // update the memory prompt so it is persisted to local storage
                });
              }
            },
            onError(err) {
              console.error("[Summarize] ", err);
            },
          });
        }
      },
  660. updateStat(message: ChatMessage, session: ChatSession) {
  661. get().updateTargetSession(session, (session) => {
  662. session.stat.charCount += message.content.length;
  663. // TODO: should update chat count and word count
  664. });
  665. },
      /**
       * Locate `targetSession` by id, mutate it via `updater`, and write the
       * sessions array back to commit the change. No-op when the session is
       * no longer in the store (e.g. deleted meanwhile).
       */
      updateTargetSession(
        targetSession: ChatSession,
        updater: (session: ChatSession) => void,
      ) {
        const sessions = get().sessions;
        const index = sessions.findIndex((s) => s.id === targetSession.id);
        if (index < 0) return;
        updater(sessions[index]);
        set(() => ({ sessions }));
      },
      /** Irreversibly wipe all persisted data (IndexedDB and localStorage) and reload the app. */
      async clearAllData() {
        await indexedDBStorage.clear();
        localStorage.clear();
        location.reload();
      },
  681. setLastInput(lastInput: string) {
  682. set({
  683. lastInput,
  684. });
  685. },
  686. };
  687. return methods;
  688. },
  689. {
  690. name: StoreKey.Chat,
  691. version: 3.3,
  692. migrate(persistedState, version) {
  693. const state = persistedState as any;
  694. const newState = JSON.parse(
  695. JSON.stringify(state),
  696. ) as typeof DEFAULT_CHAT_STATE;
  697. if (version < 2) {
  698. newState.sessions = [];
  699. const oldSessions = state.sessions;
  700. for (const oldSession of oldSessions) {
  701. const newSession = createEmptySession();
  702. newSession.topic = oldSession.topic;
  703. newSession.messages = [...oldSession.messages];
  704. newSession.mask.modelConfig.sendMemory = true;
  705. newSession.mask.modelConfig.historyMessageCount = 4;
  706. newSession.mask.modelConfig.compressMessageLengthThreshold = 1000;
  707. newState.sessions.push(newSession);
  708. }
  709. }
  710. if (version < 3) {
  711. // migrate id to nanoid
  712. newState.sessions.forEach((s) => {
  713. s.id = nanoid();
  714. s.messages.forEach((m) => (m.id = nanoid()));
  715. });
  716. }
  717. // Enable `enableInjectSystemPrompts` attribute for old sessions.
  718. // Resolve issue of old sessions not automatically enabling.
  719. if (version < 3.1) {
  720. newState.sessions.forEach((s) => {
  721. if (
  722. // Exclude those already set by user
  723. !s.mask.modelConfig.hasOwnProperty("enableInjectSystemPrompts")
  724. ) {
  725. // Because users may have changed this configuration,
  726. // the user's current configuration is used instead of the default
  727. const config = useAppConfig.getState();
  728. s.mask.modelConfig.enableInjectSystemPrompts =
  729. config.modelConfig.enableInjectSystemPrompts;
  730. }
  731. });
  732. }
  733. // add default summarize model for every session
  734. if (version < 3.2) {
  735. newState.sessions.forEach((s) => {
  736. const config = useAppConfig.getState();
  737. s.mask.modelConfig.compressModel = config.modelConfig.compressModel;
  738. s.mask.modelConfig.compressProviderName =
  739. config.modelConfig.compressProviderName;
  740. });
  741. }
  742. // revert default summarize model for every session
  743. if (version < 3.3) {
  744. newState.sessions.forEach((s) => {
  745. const config = useAppConfig.getState();
  746. s.mask.modelConfig.compressModel = "";
  747. s.mask.modelConfig.compressProviderName = "";
  748. });
  749. }
  750. return newState as any;
  751. },
  752. },
  753. );