openai.ts

  1. "use client";
  2. // azure and openai, using same models. so using same LLMApi.
import {
  ApiPath,
  OPENAI_BASE_URL,
  DEFAULT_MODELS,
  OpenaiPath,
  Azure,
  REQUEST_TIMEOUT_MS,
  ServiceProvider,
} from "@/app/constant";
import {
  ChatMessageTool,
  useAccessStore,
  useAppConfig,
  useChatStore,
  usePluginStore,
} from "@/app/store";
import { collectModelsWithDefaultModel } from "@/app/utils/model";
import {
  preProcessImageContent,
  uploadImage,
  base64Image2Blob,
  stream,
} from "@/app/utils/chat";
import { cloudflareAIGatewayUrl } from "@/app/utils/cloudflare";
import { ModelSize, DalleQuality, DalleStyle } from "@/app/typing";
import {
  ChatOptions,
  getHeaders,
  LLMApi,
  LLMModel,
  LLMUsage,
  MultimodalContent,
  SpeechOptions,
} from "../api";
import Locale from "../../locales";
import { getClientConfig } from "@/app/config/client";
import {
  getMessageTextContent,
  isVisionModel,
  isDalle3 as _isDalle3,
} from "@/app/utils";
import { fetch } from "@/app/utils/stream";

export interface OpenAIListModelResponse {
  object: string;
  data: Array<{
    id: string;
    object: string;
    root: string;
  }>;
}
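
// Shape of the body sent to the chat completions endpoint. Note that
// max_tokens is only attached for vision models and max_completion_tokens
// only for o1 models (see chat() below).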
export interface RequestPayload {
  messages: {
    role: "system" | "user" | "assistant";
    content: string | MultimodalContent[];
  }[];
  stream?: boolean;
  model: string;
  temperature: number;
  presence_penalty: number;
  frequency_penalty: number;
  top_p: number;
  max_tokens?: number;
  max_completion_tokens?: number;
}

export interface DalleRequestPayload {
  model: string;
  prompt: string;
  response_format: "url" | "b64_json";
  n: number;
  size: ModelSize;
  quality: DalleQuality;
  style: DalleStyle;
}

export class ChatGPTApi implements LLMApi {
  private disableListModels = true;
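
  // Resolve the full request URL for a given API path, honoring a custom
  // OpenAI/Azure endpoint from settings and the app/web deployment mode.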
  path(path: string): string {
    const accessStore = useAccessStore.getState();

    let baseUrl = "";

    const isAzure = path.includes("deployments");
    if (accessStore.useCustomConfig) {
      if (isAzure && !accessStore.isValidAzure()) {
        throw Error(
          "incomplete azure config, please check it in your settings page",
        );
      }

      baseUrl = isAzure ? accessStore.azureUrl : accessStore.openaiUrl;
    }

    if (baseUrl.length === 0) {
      const isApp = !!getClientConfig()?.isApp;
      const apiPath = isAzure ? ApiPath.Azure : ApiPath.OpenAI;
      baseUrl = isApp ? OPENAI_BASE_URL : apiPath;
    }

    if (baseUrl.endsWith("/")) {
      baseUrl = baseUrl.slice(0, baseUrl.length - 1);
    }
    if (
      !baseUrl.startsWith("http") &&
      !isAzure &&
      !baseUrl.startsWith(ApiPath.OpenAI)
    ) {
      baseUrl = "https://" + baseUrl;
    }

    console.log("[Proxy Endpoint] ", baseUrl, path);

    // try to rebuild the URL when using a Cloudflare AI Gateway on the client
    return cloudflareAIGatewayUrl([baseUrl, path].join("/"));
  }
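
  // Normalize a non-streaming response: wrap error JSON in a code fence,
  // turn DALL-E responses into image_url content, and otherwise return the
  // first choice's message text.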
  async extractMessage(res: any) {
    if (res.error) {
      return "```\n" + JSON.stringify(res, null, 4) + "\n```";
    }
    // the dall-e-3 model returns a URL; use it to create an image message
    if (res.data) {
      let url = res.data?.at(0)?.url ?? "";
      const b64_json = res.data?.at(0)?.b64_json ?? "";
      if (!url && b64_json) {
        // upload the base64 image to get a stable URL
        url = await uploadImage(base64Image2Blob(b64_json, "image/png"));
      }
      return [
        {
          type: "image_url",
          image_url: {
            url,
          },
        },
      ];
    }
    return res.choices?.at(0)?.message?.content ?? res;
  }
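
  // Text-to-speech via the OpenAI audio endpoint; returns the raw audio
  // bytes and aborts the request after REQUEST_TIMEOUT_MS.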
  async speech(options: SpeechOptions): Promise<ArrayBuffer> {
    const requestPayload = {
      model: options.model,
      input: options.input,
      voice: options.voice,
      response_format: options.response_format,
      speed: options.speed,
    };

    console.log("[Request] openai speech payload: ", requestPayload);

    const controller = new AbortController();
    options.onController?.(controller);

    try {
      const speechPath = this.path(OpenaiPath.SpeechPath);
      const speechPayload = {
        method: "POST",
        body: JSON.stringify(requestPayload),
        signal: controller.signal,
        headers: getHeaders(),
      };

      // make a fetch request
      const requestTimeoutId = setTimeout(
        () => controller.abort(),
        REQUEST_TIMEOUT_MS,
      );

      const res = await fetch(speechPath, speechPayload);
      clearTimeout(requestTimeoutId);
      return await res.arrayBuffer();
    } catch (e) {
      console.log("[Request] failed to make a speech request", e);
      throw e;
    }
  }
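
  // Main chat entry point. Builds either a DALL-E image payload or a chat
  // completion payload (with o1-specific parameter handling), then either
  // streams the response with tool-call support or does a plain fetch.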
  async chat(options: ChatOptions) {
    const modelConfig = {
      ...useAppConfig.getState().modelConfig,
      ...useChatStore.getState().currentSession().mask.modelConfig,
      ...{
        model: options.config.model,
        providerName: options.config.providerName,
      },
    };

    let requestPayload: RequestPayload | DalleRequestPayload;

    const isDalle3 = _isDalle3(options.config.model);
    const isO1 = options.config.model.startsWith("o1");
    if (isDalle3) {
      const prompt = getMessageTextContent(
        options.messages.slice(-1)?.pop() as any,
      );
      requestPayload = {
        model: options.config.model,
        prompt,
        // OpenAI image URLs are only valid for 60 minutes after generation,
        // so request b64_json and save the image in CacheStorage instead.
        response_format: "b64_json",
        n: 1,
        size: options.config?.size ?? "1024x1024",
        quality: options.config?.quality ?? "standard",
        style: options.config?.style ?? "vivid",
      };
    } else {
      const visionModel = isVisionModel(options.config.model);
      const messages: ChatOptions["messages"] = [];
      for (const v of options.messages) {
        const content = visionModel
          ? await preProcessImageContent(v.content)
          : getMessageTextContent(v);
        if (!(isO1 && v.role === "system"))
          messages.push({ role: v.role, content });
      }

      // o1 does not yet support images, tools (plugins in ChatGPTNextWeb),
      // system messages, stream, logprobs, temperature, top_p, n,
      // presence_penalty, or frequency_penalty.
      requestPayload = {
        messages,
        stream: options.config.stream,
        model: modelConfig.model,
        temperature: !isO1 ? modelConfig.temperature : 1,
        presence_penalty: !isO1 ? modelConfig.presence_penalty : 0,
        frequency_penalty: !isO1 ? modelConfig.frequency_penalty : 0,
        top_p: !isO1 ? modelConfig.top_p : 1,
        // max_tokens: Math.max(modelConfig.max_tokens, 1024),
        // max_tokens is intentionally not sent for regular chat requests.
      };

      // o1 uses max_completion_tokens to control the token budget
      // (https://platform.openai.com/docs/guides/reasoning#controlling-costs)
      if (isO1) {
        requestPayload["max_completion_tokens"] = modelConfig.max_tokens;
      }

      // add max_tokens for vision models
      if (visionModel) {
        requestPayload["max_tokens"] = Math.max(modelConfig.max_tokens, 4000);
      }
    }

    console.log("[Request] openai payload: ", requestPayload);

    const shouldStream = !isDalle3 && !!options.config.stream;
    const controller = new AbortController();
    options.onController?.(controller);

    try {
      let chatPath = "";
      if (modelConfig.providerName === ServiceProvider.Azure) {
        // find the model and use its displayName as the Azure deployment name
        const { models: configModels, customModels: configCustomModels } =
          useAppConfig.getState();
        const {
          defaultModel,
          customModels: accessCustomModels,
          useCustomConfig,
        } = useAccessStore.getState();
        const models = collectModelsWithDefaultModel(
          configModels,
          [configCustomModels, accessCustomModels].join(","),
          defaultModel,
        );
        const model = models.find(
          (model) =>
            model.name === modelConfig.model &&
            model?.provider?.providerName === ServiceProvider.Azure,
        );
        chatPath = this.path(
          (isDalle3 ? Azure.ImagePath : Azure.ChatPath)(
            (model?.displayName ?? model?.name) as string,
            useCustomConfig ? useAccessStore.getState().azureApiVersion : "",
          ),
        );
      } else {
        chatPath = this.path(
          isDalle3 ? OpenaiPath.ImagePath : OpenaiPath.ChatPath,
        );
      }
      if (shouldStream) {
        let index = -1;
        const [tools, funcs] = usePluginStore
          .getState()
          .getAsTools(
            useChatStore.getState().currentSession().mask?.plugin || [],
          );
        // console.log("getAsTools", tools, funcs);
        stream(
          chatPath,
          requestPayload,
          getHeaders(),
          tools as any,
          funcs,
          controller,
          // parseSSE
          (text: string, runTools: ChatMessageTool[]) => {
            // console.log("parseSSE", text, runTools);
            const json = JSON.parse(text);
            const choices = json.choices as Array<{
              delta: {
                content: string;
                tool_calls: ChatMessageTool[];
              };
            }>;
            const tool_calls = choices[0]?.delta?.tool_calls;
            if (tool_calls?.length > 0) {
              const id = tool_calls[0]?.id;
              const args = tool_calls[0]?.function?.arguments;
              if (id) {
                index += 1;
                runTools.push({
                  id,
                  type: tool_calls[0]?.type,
                  function: {
                    name: tool_calls[0]?.function?.name as string,
                    arguments: args,
                  },
                });
              } else {
                // @ts-ignore
                runTools[index]["function"]["arguments"] += args;
              }
            }
            return choices[0]?.delta?.content;
          },
          // processToolMessage, including the tool_calls message and tool call results
          (
            requestPayload: RequestPayload,
            toolCallMessage: any,
            toolCallResult: any[],
          ) => {
            // reset index value
            index = -1;
            // @ts-ignore
            requestPayload?.messages?.splice(
              // @ts-ignore
              requestPayload?.messages?.length,
              0,
              toolCallMessage,
              ...toolCallResult,
            );
          },
          options,
        );
      } else {
        const chatPayload = {
          method: "POST",
          body: JSON.stringify(requestPayload),
          signal: controller.signal,
          headers: getHeaders(),
        };

        // make a fetch request
        const requestTimeoutId = setTimeout(
          () => controller.abort(),
          isDalle3 || isO1 ? REQUEST_TIMEOUT_MS * 4 : REQUEST_TIMEOUT_MS, // dalle3 using b64_json is slow
        );

        const res = await fetch(chatPath, chatPayload);
        clearTimeout(requestTimeoutId);

        const resJson = await res.json();
        const message = await this.extractMessage(resJson);
        options.onFinish(message, res);
      }
    } catch (e) {
      console.log("[Request] failed to make a chat request", e);
      options.onError?.(e as Error);
    }
  }
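
  // Query month-to-date usage and the subscription hard limit from the
  // usage/subscription endpoints, converting both amounts to dollars.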
  async usage() {
    const formatDate = (d: Date) =>
      `${d.getFullYear()}-${(d.getMonth() + 1).toString().padStart(2, "0")}-${d
        .getDate()
        .toString()
        .padStart(2, "0")}`;
    const ONE_DAY = 1 * 24 * 60 * 60 * 1000;
    const now = new Date();
    const startOfMonth = new Date(now.getFullYear(), now.getMonth(), 1);
    const startDate = formatDate(startOfMonth);
    const endDate = formatDate(new Date(Date.now() + ONE_DAY));

    const [used, subs] = await Promise.all([
      fetch(
        this.path(
          `${OpenaiPath.UsagePath}?start_date=${startDate}&end_date=${endDate}`,
        ),
        {
          method: "GET",
          headers: getHeaders(),
        },
      ),
      fetch(this.path(OpenaiPath.SubsPath), {
        method: "GET",
        headers: getHeaders(),
      }),
    ]);

    if (used.status === 401) {
      throw new Error(Locale.Error.Unauthorized);
    }

    if (!used.ok || !subs.ok) {
      throw new Error("Failed to query usage from openai");
    }

    const response = (await used.json()) as {
      total_usage?: number;
      error?: {
        type: string;
        message: string;
      };
    };

    const total = (await subs.json()) as {
      hard_limit_usd?: number;
    };

    if (response.error && response.error.type) {
      throw Error(response.error.message);
    }

    if (response.total_usage) {
      response.total_usage = Math.round(response.total_usage) / 100;
    }

    if (total.hard_limit_usd) {
      total.hard_limit_usd = Math.round(total.hard_limit_usd * 100) / 100;
    }

    return {
      used: response.total_usage,
      total: total.hard_limit_usd,
    } as LLMUsage;
  }
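
  // List available chat models. With disableListModels set to true (the
  // default above), this short-circuits to DEFAULT_MODELS.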
  async models(): Promise<LLMModel[]> {
    if (this.disableListModels) {
      return DEFAULT_MODELS.slice();
    }

    const res = await fetch(this.path(OpenaiPath.ListModelPath), {
      method: "GET",
      headers: {
        ...getHeaders(),
      },
    });

    const resJson = (await res.json()) as OpenAIListModelResponse;
    const chatModels = resJson.data?.filter(
      (m) => m.id.startsWith("gpt-") || m.id.startsWith("chatgpt-"),
    );
    console.log("[Models]", chatModels);

    if (!chatModels) {
      return [];
    }
    // Since disableListModels currently defaults to true for OpenAI, this
    // code path is effectively unreachable at the moment.
    let seq = 1000; // keep the ordering consistent with Constant.ts
    return chatModels.map((m) => ({
      name: m.id,
      available: true,
      sorted: seq++,
      provider: {
        id: "openai",
        providerName: "OpenAI",
        providerType: "openai",
        sorted: 1,
      },
    }));
  }
}

export { OpenaiPath };
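
// Example usage (a minimal sketch; ChatOptions is defined in "../api" and the
// config normally comes from the app's stores, so the fields below are
// illustrative rather than exhaustive):
//
//   const api = new ChatGPTApi();
//   await api.chat({
//     messages: [{ role: "user", content: "Hello" }],
//     config: { model: "gpt-4o-mini", stream: false },
//     onFinish: (message, res) => console.log("final:", message),
//     onError: (e) => console.error(e),
//   } as ChatOptions);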