glm.ts

  1. "use client";
  2. import {
  3. ApiPath,
  4. CHATGLM_BASE_URL,
  5. ChatGLM,
  6. REQUEST_TIMEOUT_MS,
  7. } from "@/app/constant";
  8. import {
  9. useAccessStore,
  10. useAppConfig,
  11. useChatStore,
  12. ChatMessageTool,
  13. usePluginStore,
  14. } from "@/app/store";
  15. import { stream } from "@/app/utils/chat";
  16. import {
  17. ChatOptions,
  18. getHeaders,
  19. LLMApi,
  20. LLMModel,
  21. SpeechOptions,
  22. } from "../api";
  23. import { getClientConfig } from "@/app/config/client";
  24. import { getMessageTextContent, isVisionModel } from "@/app/utils";
  25. import { RequestPayload } from "./openai";
  26. import { fetch } from "@/app/utils/stream";
  27. import { preProcessImageContent } from "@/app/utils/chat";
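
// Request payload shapes for the three ChatGLM endpoint families:
// chat completion, image generation (CogView), and video generation (CogVideo).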
interface BasePayload {
  model: string;
}

interface ChatPayload extends BasePayload {
  messages: ChatOptions["messages"];
  stream?: boolean;
  temperature?: number;
  presence_penalty?: number;
  frequency_penalty?: number;
  top_p?: number;
}

interface ImageGenerationPayload extends BasePayload {
  prompt: string;
  size?: string;
  user_id?: string;
}

interface VideoGenerationPayload extends BasePayload {
  prompt: string;
  duration?: number;
  resolution?: string;
  user_id?: string;
}

type ModelType = "chat" | "image" | "video";
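
// LLMApi implementation for the ChatGLM (Zhipu AI) platform. A request is
// routed to the chat, image, or video endpoint based on the model name prefix.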
export class ChatGLMApi implements LLMApi {
  private disableListModels = true;
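
  // "cogview-*" models generate images and "cogvideo-*" models generate
  // videos; everything else is treated as a chat-completion model.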
  private getModelType(model: string): ModelType {
    if (model.startsWith("cogview-")) return "image";
    if (model.startsWith("cogvideo-")) return "video";
    return "chat";
  }
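
  // Map a model type to its REST path on the ChatGLM API.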
  private getModelPath(type: ModelType): string {
    switch (type) {
      case "image":
        return ChatGLM.ImagePath;
      case "video":
        return ChatGLM.VideoPath;
      default:
        return ChatGLM.ChatPath;
    }
  }
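
  // Build the request body. Image generation uses only the last message as the
  // prompt; every other model type (including video, for which
  // VideoGenerationPayload is declared but not yet constructed here) falls back
  // to an OpenAI-style chat payload.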
  private createPayload(
    messages: ChatOptions["messages"],
    modelConfig: any,
    options: ChatOptions,
  ): BasePayload {
    const modelType = this.getModelType(modelConfig.model);
    const lastMessage = messages[messages.length - 1];
    const prompt =
      typeof lastMessage.content === "string"
        ? lastMessage.content
        : lastMessage.content.map((c) => c.text).join("\n");

    switch (modelType) {
      case "image":
        return {
          model: modelConfig.model,
          prompt,
          size: options.config.size,
        } as ImageGenerationPayload;
      default:
        return {
          messages,
          stream: options.config.stream,
          model: modelConfig.model,
          temperature: modelConfig.temperature,
          presence_penalty: modelConfig.presence_penalty,
          frequency_penalty: modelConfig.frequency_penalty,
          top_p: modelConfig.top_p,
        } as ChatPayload;
    }
  }
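
  // Convert an endpoint-specific JSON response into markdown/HTML that the
  // chat UI can render directly.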
  private parseResponse(modelType: ModelType, json: any): string {
    switch (modelType) {
      case "image": {
        const imageUrl = json.data?.[0]?.url;
        return imageUrl ? `![Generated Image](${imageUrl})` : "";
      }
      case "video": {
        const videoUrl = json.data?.[0]?.url;
        return videoUrl ? `<video controls src="${videoUrl}"></video>` : "";
      }
      default:
        return this.extractMessage(json);
    }
  }
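
  // Resolve the request URL: prefer a user-configured base URL, otherwise use
  // the app's proxy path (web) or the official ChatGLM endpoint (desktop app).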
  path(path: string): string {
    const accessStore = useAccessStore.getState();

    let baseUrl = "";
    if (accessStore.useCustomConfig) {
      baseUrl = accessStore.chatglmUrl;
    }

    if (baseUrl.length === 0) {
      const isApp = !!getClientConfig()?.isApp;
      const apiPath = ApiPath.ChatGLM;
      baseUrl = isApp ? CHATGLM_BASE_URL : apiPath;
    }

    if (baseUrl.endsWith("/")) {
      baseUrl = baseUrl.slice(0, baseUrl.length - 1);
    }
    if (!baseUrl.startsWith("http") && !baseUrl.startsWith(ApiPath.ChatGLM)) {
      baseUrl = "https://" + baseUrl;
    }

    console.log("[Proxy Endpoint] ", baseUrl, path);

    return [baseUrl, path].join("/");
  }
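
  // Pull the assistant text out of an OpenAI-style chat completion response.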
  extractMessage(res: any) {
    return res.choices?.at(0)?.message?.content ?? "";
  }

  speech(options: SpeechOptions): Promise<ArrayBuffer> {
    throw new Error("Method not implemented.");
  }
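
  // Main entry point: preprocess messages (inlining images for vision models),
  // build the payload for the detected model type, then either fire a single
  // request (image/video and non-streaming chat) or stream SSE chunks with
  // plugin/tool-call support.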
  async chat(options: ChatOptions) {
    const visionModel = isVisionModel(options.config.model);
    const messages: ChatOptions["messages"] = [];

    for (const v of options.messages) {
      const content = visionModel
        ? await preProcessImageContent(v.content)
        : getMessageTextContent(v);
      messages.push({ role: v.role, content });
    }

    const modelConfig = {
      ...useAppConfig.getState().modelConfig,
      ...useChatStore.getState().currentSession().mask.modelConfig,
      ...{
        model: options.config.model,
        providerName: options.config.providerName,
      },
    };
    const modelType = this.getModelType(modelConfig.model);
    const requestPayload = this.createPayload(messages, modelConfig, options);
    const path = this.path(this.getModelPath(modelType));

    console.log(`[Request] glm ${modelType} payload: `, requestPayload);

    const controller = new AbortController();
    options.onController?.(controller);

    try {
      const chatPayload = {
        method: "POST",
        body: JSON.stringify(requestPayload),
        signal: controller.signal,
        headers: getHeaders(),
      };

      const requestTimeoutId = setTimeout(
        () => controller.abort(),
        REQUEST_TIMEOUT_MS,
      );
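
      // Image and video generation are single request/response calls; the
      // result is converted to renderable markup and returned immediately.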
      if (modelType === "image" || modelType === "video") {
        const res = await fetch(path, chatPayload);
        clearTimeout(requestTimeoutId);

        const resJson = await res.json();
        console.log(`[Response] glm ${modelType}:`, resJson);
        const message = this.parseResponse(modelType, resJson);
        options.onFinish(message, res);
        return;
      }
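
      // Chat models: stream SSE deltas (with plugin tool-call support) when
      // streaming is enabled, otherwise make a plain completion request.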
      const shouldStream = !!options.config.stream;
      if (shouldStream) {
        const [tools, funcs] = usePluginStore
          .getState()
          .getAsTools(
            useChatStore.getState().currentSession().mask?.plugin || [],
          );
        return stream(
          path,
          requestPayload,
          getHeaders(),
          tools as any,
          funcs,
          controller,
          // parseSSE
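          // Each SSE chunk is an OpenAI-style delta. Tool calls arrive split
          // across chunks: the first chunk carries the call id and name, later
          // chunks append to the arguments string of the call at that index.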
          (text: string, runTools: ChatMessageTool[]) => {
            const json = JSON.parse(text);
            const choices = json.choices as Array<{
              delta: {
                content: string;
                tool_calls: ChatMessageTool[];
              };
            }>;
            const tool_calls = choices[0]?.delta?.tool_calls;
            if (tool_calls?.length > 0) {
              const index = tool_calls[0]?.index;
              const id = tool_calls[0]?.id;
              const args = tool_calls[0]?.function?.arguments;
              if (id) {
                runTools.push({
                  id,
                  type: tool_calls[0]?.type,
                  function: {
                    name: tool_calls[0]?.function?.name as string,
                    arguments: args,
                  },
                });
              } else {
                // @ts-ignore
                runTools[index]["function"]["arguments"] += args;
              }
            }
            return choices[0]?.delta?.content;
          },
          // processToolMessage
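          // Append the assistant's tool-call message and the tool results to
          // the payload so the follow-up request carries the full context.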
          (
            requestPayload: RequestPayload,
            toolCallMessage: any,
            toolCallResult: any[],
          ) => {
            // @ts-ignore
            requestPayload?.messages?.splice(
              // @ts-ignore
              requestPayload?.messages?.length,
              0,
              toolCallMessage,
              ...toolCallResult,
            );
          },
          options,
        );
      } else {
        const res = await fetch(path, chatPayload);
        clearTimeout(requestTimeoutId);

        const resJson = await res.json();
        const message = this.extractMessage(resJson);
        options.onFinish(message, res);
      }
    } catch (e) {
      console.log("[Request] failed to make a chat request", e);
      options.onError?.(e as Error);
    }
  }

  async usage() {
    return {
      used: 0,
      total: 0,
    };
  }

  async models(): Promise<LLMModel[]> {
    return [];
  }
}
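
// A minimal usage sketch (not part of the module). It assumes the ChatOptions
// shape imported above exposes `messages`, `config`, and onUpdate/onFinish/
// onError callbacks; adjust to the actual interface if it differs.
//
// const api = new ChatGLMApi();
// await api.chat({
//   messages: [{ role: "user", content: "Describe a sunset" }],
//   config: { model: "glm-4", stream: true },
//   onUpdate: (message) => console.log("partial:", message),
//   onFinish: (message) => console.log("final:", message),
//   onError: (err) => console.error(err),
// });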