// chat.ts — persisted chat-session store (copy-paste line-number residue removed)
  1. import { trimTopic, getMessageTextContent } from "../utils";
  2. import Locale, { getLang } from "../locales";
  3. import { showToast } from "../components/ui-lib";
  4. import { ModelConfig, ModelType, useAppConfig } from "./config";
  5. import { createEmptyMask, Mask } from "./mask";
  6. import {
  7. DEFAULT_INPUT_TEMPLATE,
  8. DEFAULT_MODELS,
  9. DEFAULT_SYSTEM_TEMPLATE,
  10. KnowledgeCutOffDate,
  11. StoreKey,
  12. SUMMARIZE_MODEL,
  13. GEMINI_SUMMARIZE_MODEL,
  14. } from "../constant";
  15. import { getClientApi } from "../client/api";
  16. import type {
  17. ClientApi,
  18. RequestMessage,
  19. MultimodalContent,
  20. } from "../client/api";
  21. import { ChatControllerPool } from "../client/controller";
  22. import { prettyObject } from "../utils/format";
  23. import { estimateTokenLength } from "../utils/token";
  24. import { nanoid } from "nanoid";
  25. import { createPersistStore } from "../utils/store";
  26. import { collectModelsWithDefaultModel } from "../utils/model";
  27. import { useAccessStore } from "./access";
// A chat message as stored in a session: the wire-format RequestMessage
// plus client-side bookkeeping fields.
export type ChatMessage = RequestMessage & {
  date: string; // locale-formatted creation timestamp (set in createMessage)
  streaming?: boolean; // true while the assistant reply is still arriving
  isError?: boolean; // set when the request failed (but not when user-aborted)
  id: string; // nanoid, unique per message
  model?: ModelType; // model that produced an assistant message
};
  35. export function createMessage(override: Partial<ChatMessage>): ChatMessage {
  36. return {
  37. id: nanoid(),
  38. date: new Date().toLocaleString(),
  39. role: "user",
  40. content: "",
  41. ...override,
  42. };
  43. }
// Aggregate usage statistics for a session.
export interface ChatStat {
  tokenCount: number; // estimated tokens (not currently maintained — see updateStat's TODO)
  wordCount: number; // words (not currently maintained)
  charCount: number; // characters accumulated across messages
}
// One conversation: its messages plus summarization/context bookkeeping.
export interface ChatSession {
  id: string; // nanoid
  topic: string; // display title; DEFAULT_TOPIC until auto-generated
  memoryPrompt: string; // long-term memory: LLM-generated summary of older messages
  messages: ChatMessage[];
  stat: ChatStat;
  lastUpdate: number; // epoch ms of last activity
  lastSummarizeIndex: number; // messages before this index are covered by memoryPrompt
  clearContextIndex?: number; // user cleared context here; earlier messages excluded from requests
  mask: Mask; // persona + per-session model config
}
// Default (localized) title for brand-new sessions.
export const DEFAULT_TOPIC = Locale.Store.DefaultTopic;
// Canned assistant greeting shown before any real exchange.
export const BOT_HELLO: ChatMessage = createMessage({
  role: "assistant",
  content: Locale.Store.BotHello,
});
  65. function createEmptySession(): ChatSession {
  66. return {
  67. id: nanoid(),
  68. topic: DEFAULT_TOPIC,
  69. memoryPrompt: "",
  70. messages: [],
  71. stat: {
  72. tokenCount: 0,
  73. wordCount: 0,
  74. charCount: 0,
  75. },
  76. lastUpdate: Date.now(),
  77. lastSummarizeIndex: 0,
  78. mask: createEmptyMask(),
  79. };
  80. }
  81. function getSummarizeModel(currentModel: string) {
  82. // if it is using gpt-* models, force to use 3.5 to summarize
  83. if (currentModel.startsWith("gpt")) {
  84. const configStore = useAppConfig.getState();
  85. const accessStore = useAccessStore.getState();
  86. const allModel = collectModelsWithDefaultModel(
  87. configStore.models,
  88. [configStore.customModels, accessStore.customModels].join(","),
  89. accessStore.defaultModel,
  90. );
  91. const summarizeModel = allModel.find(
  92. (m) => m.name === SUMMARIZE_MODEL && m.available,
  93. );
  94. return summarizeModel?.name ?? currentModel;
  95. }
  96. if (currentModel.startsWith("gemini")) {
  97. return GEMINI_SUMMARIZE_MODEL;
  98. }
  99. return currentModel;
  100. }
  101. function countMessages(msgs: ChatMessage[]) {
  102. return msgs.reduce(
  103. (pre, cur) => pre + estimateTokenLength(getMessageTextContent(cur)),
  104. 0,
  105. );
  106. }
  107. function fillTemplateWith(input: string, modelConfig: ModelConfig) {
  108. const cutoff =
  109. KnowledgeCutOffDate[modelConfig.model] ?? KnowledgeCutOffDate.default;
  110. // Find the model in the DEFAULT_MODELS array that matches the modelConfig.model
  111. const modelInfo = DEFAULT_MODELS.find((m) => m.name === modelConfig.model);
  112. var serviceProvider = "OpenAI";
  113. if (modelInfo) {
  114. // TODO: auto detect the providerName from the modelConfig.model
  115. // Directly use the providerName from the modelInfo
  116. serviceProvider = modelInfo.provider.providerName;
  117. }
  118. const vars = {
  119. ServiceProvider: serviceProvider,
  120. cutoff,
  121. model: modelConfig.model,
  122. time: new Date().toString(),
  123. lang: getLang(),
  124. input: input,
  125. };
  126. let output = modelConfig.template ?? DEFAULT_INPUT_TEMPLATE;
  127. // remove duplicate
  128. if (input.startsWith(output)) {
  129. output = "";
  130. }
  131. // must contains {{input}}
  132. const inputVar = "{{input}}";
  133. if (!output.includes(inputVar)) {
  134. output += "\n" + inputVar;
  135. }
  136. Object.entries(vars).forEach(([name, value]) => {
  137. const regex = new RegExp(`{{${name}}}`, "g");
  138. output = output.replace(regex, value.toString()); // Ensure value is a string
  139. });
  140. return output;
  141. }
// Initial persisted state: a single empty session, selected.
const DEFAULT_CHAT_STATE = {
  sessions: [createEmptySession()],
  currentSessionIndex: 0,
};
// Persisted chat store: session list, selection, messaging and summarization.
export const useChatStore = createPersistStore(
  DEFAULT_CHAT_STATE,
  (set, _get) => {
    // Merge the raw zustand state with the methods object below so that
    // get() lets methods call their siblings (e.g. get().currentSession()).
    function get() {
      return {
        ..._get(),
        ...methods,
      };
    }

    const methods = {
  156. clearSessions() {
  157. set(() => ({
  158. sessions: [createEmptySession()],
  159. currentSessionIndex: 0,
  160. }));
  161. },
  162. selectSession(index: number) {
  163. set({
  164. currentSessionIndex: index,
  165. });
  166. },
  167. moveSession(from: number, to: number) {
  168. set((state) => {
  169. const { sessions, currentSessionIndex: oldIndex } = state;
  170. // move the session
  171. const newSessions = [...sessions];
  172. const session = newSessions[from];
  173. newSessions.splice(from, 1);
  174. newSessions.splice(to, 0, session);
  175. // modify current session id
  176. let newIndex = oldIndex === from ? to : oldIndex;
  177. if (oldIndex > from && oldIndex <= to) {
  178. newIndex -= 1;
  179. } else if (oldIndex < from && oldIndex >= to) {
  180. newIndex += 1;
  181. }
  182. return {
  183. currentSessionIndex: newIndex,
  184. sessions: newSessions,
  185. };
  186. });
  187. },
  188. newSession(mask?: Mask) {
  189. const session = createEmptySession();
  190. if (mask) {
  191. const config = useAppConfig.getState();
  192. const globalModelConfig = config.modelConfig;
  193. session.mask = {
  194. ...mask,
  195. modelConfig: {
  196. ...globalModelConfig,
  197. ...mask.modelConfig,
  198. },
  199. };
  200. session.topic = mask.name;
  201. }
  202. set((state) => ({
  203. currentSessionIndex: 0,
  204. sessions: [session].concat(state.sessions),
  205. }));
  206. },
  207. nextSession(delta: number) {
  208. const n = get().sessions.length;
  209. const limit = (x: number) => (x + n) % n;
  210. const i = get().currentSessionIndex;
  211. get().selectSession(limit(i + delta));
  212. },
      deleteSession(index: number) {
        const deletingLastSession = get().sessions.length === 1;
        const deletedSession = get().sessions.at(index);

        if (!deletedSession) return;

        const sessions = get().sessions.slice();
        sessions.splice(index, 1);

        const currentIndex = get().currentSessionIndex;
        // Removing a session before the current one shifts the selection
        // left by one; also clamp into the now-shorter list.
        let nextIndex = Math.min(
          currentIndex - Number(index < currentIndex),
          sessions.length - 1,
        );

        if (deletingLastSession) {
          // Never leave the store empty: replace the sole session with a fresh one.
          nextIndex = 0;
          sessions.push(createEmptySession());
        }

        // Snapshot the pre-delete state (taken from get(), i.e. before set()
        // below) so the toast's "revert" action can undo the deletion.
        const restoreState = {
          currentSessionIndex: get().currentSessionIndex,
          sessions: get().sessions.slice(),
        };

        set(() => ({
          currentSessionIndex: nextIndex,
          sessions,
        }));

        // Offer a 5-second undo window.
        showToast(
          Locale.Home.DeleteToast,
          {
            text: Locale.Home.Revert,
            onClick() {
              set(() => restoreState);
            },
          },
          5000,
        );
      },
  248. currentSession() {
  249. let index = get().currentSessionIndex;
  250. const sessions = get().sessions;
  251. if (index < 0 || index >= sessions.length) {
  252. index = Math.min(sessions.length - 1, Math.max(0, index));
  253. set(() => ({ currentSessionIndex: index }));
  254. }
  255. const session = sessions[index];
  256. return session;
  257. },
      onNewMessage(message: ChatMessage) {
        get().updateCurrentSession((session) => {
          // concat() with no args clones the array so subscribers see a
          // fresh reference and re-render.
          session.messages = session.messages.concat();
          session.lastUpdate = Date.now();
        });
        get().updateStat(message);
        get().summarizeSession();
      },
      /**
       * Handle a user submission: expand the input template, append the
       * user message plus a streaming assistant placeholder, then fire the
       * LLM request and patch the placeholder as chunks arrive.
       */
      async onUserInput(content: string, attachImages?: string[]) {
        const session = get().currentSession();
        const modelConfig = session.mask.modelConfig;

        const userContent = fillTemplateWith(content, modelConfig);
        console.log("[User Input] after template: ", userContent);

        // Plain string content, or multimodal [text, ...image_url] parts
        // when images are attached.
        let mContent: string | MultimodalContent[] = userContent;

        if (attachImages && attachImages.length > 0) {
          mContent = [
            {
              type: "text",
              text: userContent,
            },
          ];
          mContent = mContent.concat(
            attachImages.map((url) => {
              return {
                type: "image_url",
                image_url: {
                  url: url,
                },
              };
            }),
          );
        }

        let userMessage: ChatMessage = createMessage({
          role: "user",
          content: mContent,
        });

        // Placeholder that will be filled in by the streaming callbacks.
        const botMessage: ChatMessage = createMessage({
          role: "assistant",
          streaming: true,
          model: modelConfig.model,
        });

        // get recent messages (system prompt + memory + context + history)
        const recentMessages = get().getMessagesWithMemory();
        const sendMessages = recentMessages.concat(userMessage);
        // fallback key for the controller pool if the message id is missing
        const messageIndex = get().currentSession().messages.length + 1;

        // save user's and bot's message before the request so both appear
        // in the UI immediately
        get().updateCurrentSession((session) => {
          const savedUserMessage = {
            ...userMessage,
            content: mContent,
          };
          session.messages = session.messages.concat([
            savedUserMessage,
            botMessage,
          ]);
        });

        const api: ClientApi = getClientApi(modelConfig.providerName);

        // make request
        api.llm.chat({
          messages: sendMessages,
          config: { ...modelConfig, stream: true },
          onUpdate(message) {
            // Streaming chunk: replace the partial content and re-render.
            botMessage.streaming = true;
            if (message) {
              botMessage.content = message;
            }
            get().updateCurrentSession((session) => {
              session.messages = session.messages.concat();
            });
          },
          onFinish(message) {
            botMessage.streaming = false;
            if (message) {
              botMessage.content = message;
              get().onNewMessage(botMessage);
            }
            ChatControllerPool.remove(session.id, botMessage.id);
          },
          onError(error) {
            // A user-initiated abort is not treated as an error.
            const isAborted = error.message.includes("aborted");
            botMessage.content +=
              "\n\n" +
              prettyObject({
                error: true,
                message: error.message,
              });
            botMessage.streaming = false;
            userMessage.isError = !isAborted;
            botMessage.isError = !isAborted;
            get().updateCurrentSession((session) => {
              session.messages = session.messages.concat();
            });
            ChatControllerPool.remove(
              session.id,
              botMessage.id ?? messageIndex,
            );

            console.error("[Chat] failed ", error);
          },
          onController(controller) {
            // collect controller for stop/retry
            ChatControllerPool.addController(
              session.id,
              botMessage.id ?? messageIndex,
              controller,
            );
          },
        });
      },
  366. getMemoryPrompt() {
  367. const session = get().currentSession();
  368. if (session.memoryPrompt.length) {
  369. return {
  370. role: "system",
  371. content: Locale.Store.Prompt.History(session.memoryPrompt),
  372. date: "",
  373. } as ChatMessage;
  374. }
  375. },
      /**
       * Assemble the message list to send to the model, concatenated from:
       *   0. injected system prompt (GPT models only, when enabled)
       *   1. long-term memory (summarized history)
       *   2. the mask's pre-defined in-context prompts
       *   3. short-term memory: the latest messages within the token budget
       */
      getMessagesWithMemory() {
        const session = get().currentSession();
        const modelConfig = session.mask.modelConfig;
        const clearContextIndex = session.clearContextIndex ?? 0;
        const messages = session.messages.slice();
        const totalMessageCount = session.messages.length;

        // in-context prompts
        const contextPrompts = session.mask.context.slice();

        // system prompts, to get close to OpenAI Web ChatGPT
        const shouldInjectSystemPrompts =
          modelConfig.enableInjectSystemPrompts &&
          session.mask.modelConfig.model.startsWith("gpt-");

        var systemPrompts: ChatMessage[] = [];
        systemPrompts = shouldInjectSystemPrompts
          ? [
              createMessage({
                role: "system",
                content: fillTemplateWith("", {
                  ...modelConfig,
                  template: DEFAULT_SYSTEM_TEMPLATE,
                }),
              }),
            ]
          : [];
        if (shouldInjectSystemPrompts) {
          console.log(
            "[Global System Prompt] ",
            systemPrompts.at(0)?.content ?? "empty",
          );
        }
        const memoryPrompt = get().getMemoryPrompt();
        // long term memory: only when enabled, non-empty, and not wiped out
        // by a later "clear context"
        const shouldSendLongTermMemory =
          modelConfig.sendMemory &&
          session.memoryPrompt &&
          session.memoryPrompt.length > 0 &&
          session.lastSummarizeIndex > clearContextIndex;
        const longTermMemoryPrompts =
          shouldSendLongTermMemory && memoryPrompt ? [memoryPrompt] : [];
        const longTermMemoryStartIndex = session.lastSummarizeIndex;

        // short term memory: at most historyMessageCount recent messages
        const shortTermMemoryStartIndex = Math.max(
          0,
          totalMessageCount - modelConfig.historyMessageCount,
        );

        // lets concat send messages, including 4 parts:
        // 0. system prompt: to get close to OpenAI Web ChatGPT
        // 1. long term memory: summarized memory messages
        // 2. pre-defined in-context prompts
        // 3. short term memory: latest n messages
        // 4. newest input message
        const memoryStartIndex = shouldSendLongTermMemory
          ? Math.min(longTermMemoryStartIndex, shortTermMemoryStartIndex)
          : shortTermMemoryStartIndex;
        // and if user has cleared history messages, we should exclude the memory too.
        const contextStartIndex = Math.max(clearContextIndex, memoryStartIndex);
        // NOTE(review): max_tokens is the reply-length cap, used here as a
        // send-budget heuristic — confirm this is intentional.
        const maxTokenThreshold = modelConfig.max_tokens;

        // Walk backwards from the newest message, collecting until the
        // token budget or the context start is reached (errors skipped).
        const reversedRecentMessages = [];
        for (
          let i = totalMessageCount - 1, tokenCount = 0;
          i >= contextStartIndex && tokenCount < maxTokenThreshold;
          i -= 1
        ) {
          const msg = messages[i];
          if (!msg || msg.isError) continue;
          tokenCount += estimateTokenLength(getMessageTextContent(msg));
          reversedRecentMessages.push(msg);
        }

        // concat all messages (restore chronological order for history)
        const recentMessages = [
          ...systemPrompts,
          ...longTermMemoryPrompts,
          ...contextPrompts,
          ...reversedRecentMessages.reverse(),
        ];

        return recentMessages;
      },
  454. updateMessage(
  455. sessionIndex: number,
  456. messageIndex: number,
  457. updater: (message?: ChatMessage) => void,
  458. ) {
  459. const sessions = get().sessions;
  460. const session = sessions.at(sessionIndex);
  461. const messages = session?.messages;
  462. updater(messages?.at(messageIndex));
  463. set(() => ({ sessions }));
  464. },
      resetSession() {
        // Wipe the current session's messages and its summarized memory.
        get().updateCurrentSession((session) => {
          session.messages = [];
          session.memoryPrompt = "";
        });
      },
  471. summarizeSession() {
  472. const config = useAppConfig.getState();
  473. const session = get().currentSession();
  474. const modelConfig = session.mask.modelConfig;
  475. const api: ClientApi = getClientApi(modelConfig.providerName);
  476. // remove error messages if any
  477. const messages = session.messages;
  478. // should summarize topic after chating more than 50 words
  479. const SUMMARIZE_MIN_LEN = 50;
  480. if (
  481. config.enableAutoGenerateTitle &&
  482. session.topic === DEFAULT_TOPIC &&
  483. countMessages(messages) >= SUMMARIZE_MIN_LEN
  484. ) {
  485. const topicMessages = messages.concat(
  486. createMessage({
  487. role: "user",
  488. content: Locale.Store.Prompt.Topic,
  489. }),
  490. );
  491. api.llm.chat({
  492. messages: topicMessages,
  493. config: {
  494. model: getSummarizeModel(session.mask.modelConfig.model),
  495. stream: false,
  496. },
  497. onFinish(message) {
  498. get().updateCurrentSession(
  499. (session) =>
  500. (session.topic =
  501. message.length > 0 ? trimTopic(message) : DEFAULT_TOPIC),
  502. );
  503. },
  504. });
  505. }
  506. const summarizeIndex = Math.max(
  507. session.lastSummarizeIndex,
  508. session.clearContextIndex ?? 0,
  509. );
  510. let toBeSummarizedMsgs = messages
  511. .filter((msg) => !msg.isError)
  512. .slice(summarizeIndex);
  513. const historyMsgLength = countMessages(toBeSummarizedMsgs);
  514. if (historyMsgLength > modelConfig?.max_tokens ?? 4000) {
  515. const n = toBeSummarizedMsgs.length;
  516. toBeSummarizedMsgs = toBeSummarizedMsgs.slice(
  517. Math.max(0, n - modelConfig.historyMessageCount),
  518. );
  519. }
  520. const memoryPrompt = get().getMemoryPrompt();
  521. if (memoryPrompt) {
  522. // add memory prompt
  523. toBeSummarizedMsgs.unshift(memoryPrompt);
  524. }
  525. const lastSummarizeIndex = session.messages.length;
  526. console.log(
  527. "[Chat History] ",
  528. toBeSummarizedMsgs,
  529. historyMsgLength,
  530. modelConfig.compressMessageLengthThreshold,
  531. );
  532. if (
  533. historyMsgLength > modelConfig.compressMessageLengthThreshold &&
  534. modelConfig.sendMemory
  535. ) {
  536. /** Destruct max_tokens while summarizing
  537. * this param is just shit
  538. **/
  539. const { max_tokens, ...modelcfg } = modelConfig;
  540. api.llm.chat({
  541. messages: toBeSummarizedMsgs.concat(
  542. createMessage({
  543. role: "system",
  544. content: Locale.Store.Prompt.Summarize,
  545. date: "",
  546. }),
  547. ),
  548. config: {
  549. ...modelcfg,
  550. stream: true,
  551. model: getSummarizeModel(session.mask.modelConfig.model),
  552. },
  553. onUpdate(message) {
  554. session.memoryPrompt = message;
  555. },
  556. onFinish(message) {
  557. console.log("[Memory] ", message);
  558. get().updateCurrentSession((session) => {
  559. session.lastSummarizeIndex = lastSummarizeIndex;
  560. session.memoryPrompt = message; // Update the memory prompt for stored it in local storage
  561. });
  562. },
  563. onError(err) {
  564. console.error("[Summarize] ", err);
  565. },
  566. });
  567. }
  568. },
  569. updateStat(message: ChatMessage) {
  570. get().updateCurrentSession((session) => {
  571. session.stat.charCount += message.content.length;
  572. // TODO: should update chat count and word count
  573. });
  574. },
  575. updateCurrentSession(updater: (session: ChatSession) => void) {
  576. const sessions = get().sessions;
  577. const index = get().currentSessionIndex;
  578. updater(sessions[index]);
  579. set(() => ({ sessions }));
  580. },
      clearAllData() {
        // Destructive: wipe everything persisted for this origin and reload.
        localStorage.clear();
        location.reload();
      },
  585. };
  586. return methods;
  587. },
  {
    name: StoreKey.Chat,
    version: 3.1,

    // Upgrade older persisted snapshots to the current schema, step by step.
    migrate(persistedState, version) {
      const state = persistedState as any;
      // Deep-clone so migrations never mutate the stored snapshot directly.
      const newState = JSON.parse(
        JSON.stringify(state),
      ) as typeof DEFAULT_CHAT_STATE;

      if (version < 2) {
        // v2: rebuild every session on the mask-based session shape,
        // carrying over topic and messages with legacy memory defaults.
        newState.sessions = [];

        const oldSessions = state.sessions;
        for (const oldSession of oldSessions) {
          const newSession = createEmptySession();
          newSession.topic = oldSession.topic;
          newSession.messages = [...oldSession.messages];
          newSession.mask.modelConfig.sendMemory = true;
          newSession.mask.modelConfig.historyMessageCount = 4;
          newSession.mask.modelConfig.compressMessageLengthThreshold = 1000;
          newState.sessions.push(newSession);
        }
      }

      if (version < 3) {
        // migrate id to nanoid
        newState.sessions.forEach((s) => {
          s.id = nanoid();
          s.messages.forEach((m) => (m.id = nanoid()));
        });
      }

      // Enable `enableInjectSystemPrompts` attribute for old sessions.
      // Resolve issue of old sessions not automatically enabling.
      if (version < 3.1) {
        newState.sessions.forEach((s) => {
          if (
            // Exclude those already set by user
            !s.mask.modelConfig.hasOwnProperty("enableInjectSystemPrompts")
          ) {
            // Because users may have changed this configuration,
            // the user's current configuration is used instead of the default
            const config = useAppConfig.getState();
            s.mask.modelConfig.enableInjectSystemPrompts =
              config.modelConfig.enableInjectSystemPrompts;
          }
        });
      }

      return newState as any;
    },
  },
  635. );