소스 검색

using b64_json for dall-e-3

lloydzhou 1 년 전
부모
커밋
8c83fe23a1
1개의 변경된 파일18개의 추가작업 그리고 6개의 파일을 삭제
  1. 18 6
      app/client/platforms/openai.ts

+ 18 - 6
app/client/platforms/openai.ts

@@ -11,7 +11,11 @@ import {
 } from "@/app/constant";
 import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
 import { collectModelsWithDefaultModel } from "@/app/utils/model";
-import { preProcessImageContent } from "@/app/utils/chat";
+import {
+  preProcessImageContent,
+  uploadImage,
+  base64Image2Blob,
+} from "@/app/utils/chat";
 import { cloudflareAIGatewayUrl } from "@/app/utils/cloudflare";
 import { DalleSize } from "@/app/typing";
 
@@ -63,6 +67,7 @@ export interface RequestPayload {
 export interface DalleRequestPayload {
   model: string;
   prompt: string;
+  response_format: "url" | "b64_json";
   n: number;
   size: DalleSize;
 }
@@ -109,13 +114,18 @@ export class ChatGPTApi implements LLMApi {
     return cloudflareAIGatewayUrl([baseUrl, path].join("/"));
   }
 
-  extractMessage(res: any) {
+  async extractMessage(res: any) {
     if (res.error) {
       return "```\n" + JSON.stringify(res, null, 4) + "\n```";
     }
-    // dalle3 model return url, just return
+    // dalle3 model return url, using url create image message
     if (res.data) {
-      const url = res.data?.at(0)?.url ?? "";
+      let url = res.data?.at(0)?.url ?? "";
+      const b64_json = res.data?.at(0)?.b64_json ?? "";
+      if (!url && b64_json) {
+        // uploadImage
+        url = await uploadImage(base64Image2Blob(b64_json, "image/png"));
+      }
       return [
         {
           type: "image_url",
@@ -148,6 +158,8 @@ export class ChatGPTApi implements LLMApi {
       requestPayload = {
         model: options.config.model,
         prompt,
+        // URLs are only valid for 60 minutes after the image has been generated.
+        response_format: "b64_json", // using b64_json, and save image in CacheStorage
         n: 1,
         size: options.config?.size ?? "1024x1024",
       };
@@ -227,7 +239,7 @@ export class ChatGPTApi implements LLMApi {
       // make a fetch request
       const requestTimeoutId = setTimeout(
         () => controller.abort(),
-        REQUEST_TIMEOUT_MS,
+        isDalle3 ? REQUEST_TIMEOUT_MS * 2 : REQUEST_TIMEOUT_MS, // dalle3 using b64_json is slow.
       );
 
       if (shouldStream) {
@@ -358,7 +370,7 @@ export class ChatGPTApi implements LLMApi {
         clearTimeout(requestTimeoutId);
 
         const resJson = await res.json();
-        const message = this.extractMessage(resJson);
+        const message = await this.extractMessage(resJson);
         options.onFinish(message);
       }
     } catch (e) {