store.ts 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. import { create } from "zustand";
  2. import { persist } from "zustand/middleware";
  3. import { type ChatCompletionResponseMessage } from "openai";
  4. import { requestChat, requestChatStream, requestWithPrompt } from "./requests";
  5. import { trimTopic } from "./utils";
  6. export type Message = ChatCompletionResponseMessage & {
  7. date: string;
  8. streaming?: boolean;
  9. };
  10. export enum SubmitKey {
  11. Enter = "Enter",
  12. CtrlEnter = "Ctrl + Enter",
  13. ShiftEnter = "Shift + Enter",
  14. AltEnter = "Alt + Enter",
  15. }
  16. export enum Theme {
  17. Auto = "auto",
  18. Dark = "dark",
  19. Light = "light",
  20. }
  21. interface ChatConfig {
  22. maxToken?: number;
  23. historyMessageCount: number; // -1 means all
  24. sendBotMessages: boolean; // send bot's message or not
  25. submitKey: SubmitKey;
  26. avatar: string;
  27. theme: Theme;
  28. tightBorder: boolean;
  29. }
  30. const DEFAULT_CONFIG: ChatConfig = {
  31. historyMessageCount: 5,
  32. sendBotMessages: false as boolean,
  33. submitKey: SubmitKey.CtrlEnter as SubmitKey,
  34. avatar: "1fae0",
  35. theme: Theme.Auto as Theme,
  36. tightBorder: false,
  37. };
  38. interface ChatStat {
  39. tokenCount: number;
  40. wordCount: number;
  41. charCount: number;
  42. }
  43. interface ChatSession {
  44. id: number;
  45. topic: string;
  46. memoryPrompt: string;
  47. messages: Message[];
  48. stat: ChatStat;
  49. lastUpdate: string;
  50. deleted?: boolean;
  51. }
  52. const DEFAULT_TOPIC = "新的聊天";
  53. function createEmptySession(): ChatSession {
  54. const createDate = new Date().toLocaleString();
  55. return {
  56. id: Date.now(),
  57. topic: DEFAULT_TOPIC,
  58. memoryPrompt: "",
  59. messages: [
  60. {
  61. role: "assistant",
  62. content: "有什么可以帮你的吗",
  63. date: createDate,
  64. },
  65. ],
  66. stat: {
  67. tokenCount: 0,
  68. wordCount: 0,
  69. charCount: 0,
  70. },
  71. lastUpdate: createDate,
  72. };
  73. }
  74. interface ChatStore {
  75. config: ChatConfig;
  76. sessions: ChatSession[];
  77. currentSessionIndex: number;
  78. removeSession: (index: number) => void;
  79. selectSession: (index: number) => void;
  80. newSession: () => void;
  81. currentSession: () => ChatSession;
  82. onNewMessage: (message: Message) => void;
  83. onUserInput: (content: string) => Promise<void>;
  84. onBotResponse: (message: Message) => void;
  85. summarizeSession: () => void;
  86. updateStat: (message: Message) => void;
  87. updateCurrentSession: (updater: (session: ChatSession) => void) => void;
  88. updateMessage: (
  89. sessionIndex: number,
  90. messageIndex: number,
  91. updater: (message?: Message) => void
  92. ) => void;
  93. getConfig: () => ChatConfig;
  94. resetConfig: () => void;
  95. updateConfig: (updater: (config: ChatConfig) => void) => void;
  96. }
  97. export const useChatStore = create<ChatStore>()(
  98. persist(
  99. (set, get) => ({
  100. sessions: [createEmptySession()],
  101. currentSessionIndex: 0,
  102. config: {
  103. ...DEFAULT_CONFIG,
  104. },
  105. resetConfig() {
  106. set(() => ({ config: { ...DEFAULT_CONFIG } }));
  107. },
  108. getConfig() {
  109. return get().config;
  110. },
  111. updateConfig(updater) {
  112. const config = get().config;
  113. updater(config);
  114. set(() => ({ config }));
  115. },
  116. selectSession(index: number) {
  117. set({
  118. currentSessionIndex: index,
  119. });
  120. },
  121. removeSession(index: number) {
  122. set((state) => {
  123. let nextIndex = state.currentSessionIndex;
  124. const sessions = state.sessions;
  125. if (sessions.length === 1) {
  126. return {
  127. currentSessionIndex: 0,
  128. sessions: [createEmptySession()],
  129. };
  130. }
  131. sessions.splice(index, 1);
  132. if (nextIndex === index) {
  133. nextIndex -= 1;
  134. }
  135. return {
  136. currentSessionIndex: nextIndex,
  137. sessions,
  138. };
  139. });
  140. },
  141. newSession() {
  142. set((state) => ({
  143. currentSessionIndex: 0,
  144. sessions: [createEmptySession()].concat(state.sessions),
  145. }));
  146. },
  147. currentSession() {
  148. let index = get().currentSessionIndex;
  149. const sessions = get().sessions;
  150. if (index < 0 || index >= sessions.length) {
  151. index = Math.min(sessions.length - 1, Math.max(0, index));
  152. set(() => ({ currentSessionIndex: index }));
  153. }
  154. const session = sessions[index];
  155. return session;
  156. },
  157. onNewMessage(message) {
  158. get().updateCurrentSession((session) => {
  159. session.messages.push(message);
  160. });
  161. get().updateStat(message);
  162. get().summarizeSession();
  163. },
  164. async onUserInput(content) {
  165. const message: Message = {
  166. role: "user",
  167. content,
  168. date: new Date().toLocaleString(),
  169. };
  170. // get last five messges
  171. const messages = get().currentSession().messages.concat(message);
  172. get().onNewMessage(message);
  173. const botMessage: Message = {
  174. content: "",
  175. role: "assistant",
  176. date: new Date().toLocaleString(),
  177. streaming: true,
  178. };
  179. get().updateCurrentSession((session) => {
  180. session.messages.push(botMessage);
  181. });
  182. const fiveMessages = messages.slice(-5);
  183. requestChatStream(fiveMessages, {
  184. onMessage(content, done) {
  185. if (done) {
  186. botMessage.streaming = false;
  187. get().updateStat(botMessage);
  188. get().summarizeSession();
  189. } else {
  190. botMessage.content = content;
  191. set(() => ({}));
  192. }
  193. },
  194. onError(error) {
  195. botMessage.content = "出错了,稍后重试吧";
  196. botMessage.streaming = false;
  197. set(() => ({}));
  198. },
  199. filterBot: !get().config.sendBotMessages,
  200. });
  201. },
  202. updateMessage(
  203. sessionIndex: number,
  204. messageIndex: number,
  205. updater: (message?: Message) => void
  206. ) {
  207. const sessions = get().sessions;
  208. const session = sessions.at(sessionIndex);
  209. const messages = session?.messages;
  210. updater(messages?.at(messageIndex));
  211. set(() => ({ sessions }));
  212. },
  213. onBotResponse(message) {
  214. get().onNewMessage(message);
  215. },
  216. summarizeSession() {
  217. const session = get().currentSession();
  218. if (session.topic !== DEFAULT_TOPIC) return;
  219. requestWithPrompt(
  220. session.messages,
  221. "简明扼要地 10 字以内总结主题"
  222. ).then((res) => {
  223. get().updateCurrentSession(
  224. (session) => (session.topic = trimTopic(res))
  225. );
  226. });
  227. },
  228. updateStat(message) {
  229. get().updateCurrentSession((session) => {
  230. session.stat.charCount += message.content.length;
  231. // TODO: should update chat count and word count
  232. });
  233. },
  234. updateCurrentSession(updater) {
  235. const sessions = get().sessions;
  236. const index = get().currentSessionIndex;
  237. updater(sessions[index]);
  238. set(() => ({ sessions }));
  239. },
  240. }),
  241. { name: "chat-next-web-store" }
  242. )
  243. );