
Merge pull request #5173 from ConnectAI-E/feature/dalle

add dalle3 model
Dogtiti committed 1 year ago, commit fec80c6c51
9 changed files with 141 additions and 33 deletions
  1. app/client/api.ts (+2 -1)
  2. app/client/platforms/openai.ts (+84 -31)
  3. app/components/chat.tsx (+35 -0)
  4. app/constant.ts (+6 -1)
  5. app/icons/size.svg (+1 -0)
  6. app/store/chat.ts (+5 -0)
  7. app/store/config.ts (+2 -0)
  8. app/typing.ts (+2 -0)
  9. app/utils.ts (+4 -0)

app/client/api.ts (+2 -1)

@@ -6,7 +6,7 @@ import {
   ServiceProvider,
 } from "../constant";
 import { ChatMessage, ModelType, useAccessStore, useChatStore } from "../store";
-import { ChatGPTApi } from "./platforms/openai";
+import { ChatGPTApi, DalleRequestPayload } from "./platforms/openai";
 import { GeminiProApi } from "./platforms/google";
 import { ClaudeApi } from "./platforms/anthropic";
 import { ErnieApi } from "./platforms/baidu";
@@ -42,6 +42,7 @@ export interface LLMConfig {
   stream?: boolean;
   presence_penalty?: number;
   frequency_penalty?: number;
+  size?: DalleRequestPayload["size"];
 }
 
 export interface ChatOptions {
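
A note on the type: DalleRequestPayload["size"] is an indexed access type, so this field carries the same DalleSize union declared in app/typing.ts rather than a bare string. A minimal sketch of a config using it (the object literal is illustrative and assumes the other LLMConfig fields stay optional):

    // size only accepts the values of the DalleSize union
    const config: LLMConfig = {
      model: "dall-e-3",
      stream: false,
      size: "1792x1024", // "1024x1024" | "1792x1024" | "1024x1792"
    };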

app/client/platforms/openai.ts (+84 -31)

@@ -11,8 +11,13 @@ import {
 } from "@/app/constant";
 import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
 import { collectModelsWithDefaultModel } from "@/app/utils/model";
-import { preProcessImageContent } from "@/app/utils/chat";
+import {
+  preProcessImageContent,
+  uploadImage,
+  base64Image2Blob,
+} from "@/app/utils/chat";
 import { cloudflareAIGatewayUrl } from "@/app/utils/cloudflare";
+import { DalleSize } from "@/app/typing";
 
 import {
   ChatOptions,
@@ -33,6 +38,7 @@ import {
   getMessageTextContent,
   getMessageImages,
   isVisionModel,
+  isDalle3 as _isDalle3,
 } from "@/app/utils";
 
 export interface OpenAIListModelResponse {
@@ -58,6 +64,14 @@ export interface RequestPayload {
   max_tokens?: number;
 }
 
+export interface DalleRequestPayload {
+  model: string;
+  prompt: string;
+  response_format: "url" | "b64_json";
+  n: number;
+  size: DalleSize;
+}
+
 export class ChatGPTApi implements LLMApi {
   private disableListModels = true;
 
@@ -100,20 +114,31 @@ export class ChatGPTApi implements LLMApi {
     return cloudflareAIGatewayUrl([baseUrl, path].join("/"));
   }
 
-  extractMessage(res: any) {
-    return res.choices?.at(0)?.message?.content ?? "";
+  async extractMessage(res: any) {
+    if (res.error) {
+      return "```\n" + JSON.stringify(res, null, 4) + "\n```";
+    }
+    // the dall-e-3 image endpoint returns url / b64_json entries; build an image message from them
+    if (res.data) {
+      let url = res.data?.at(0)?.url ?? "";
+      const b64_json = res.data?.at(0)?.b64_json ?? "";
+      if (!url && b64_json) {
+        // persist the base64 payload via uploadImage and use the resulting URL
+        url = await uploadImage(base64Image2Blob(b64_json, "image/png"));
+      }
+      return [
+        {
+          type: "image_url",
+          image_url: {
+            url,
+          },
+        },
+      ];
+    }
+    return res.choices?.at(0)?.message?.content ?? res;
   }
 
   async chat(options: ChatOptions) {
-    const visionModel = isVisionModel(options.config.model);
-    const messages: ChatOptions["messages"] = [];
-    for (const v of options.messages) {
-      const content = visionModel
-        ? await preProcessImageContent(v.content)
-        : getMessageTextContent(v);
-      messages.push({ role: v.role, content });
-    }
-
     const modelConfig = {
       ...useAppConfig.getState().modelConfig,
       ...useChatStore.getState().currentSession().mask.modelConfig,
@@ -123,26 +148,52 @@ export class ChatGPTApi implements LLMApi {
       },
     };
 
-    const requestPayload: RequestPayload = {
-      messages,
-      stream: options.config.stream,
-      model: modelConfig.model,
-      temperature: modelConfig.temperature,
-      presence_penalty: modelConfig.presence_penalty,
-      frequency_penalty: modelConfig.frequency_penalty,
-      top_p: modelConfig.top_p,
-      // max_tokens: Math.max(modelConfig.max_tokens, 1024),
-      // Please do not ask me why not send max_tokens, no reason, this param is just shit, I dont want to explain anymore.
-    };
+    let requestPayload: RequestPayload | DalleRequestPayload;
+
+    const isDalle3 = _isDalle3(options.config.model);
+    if (isDalle3) {
+      const prompt = getMessageTextContent(
+        options.messages.slice(-1)?.pop() as any,
+      );
+      requestPayload = {
+        model: options.config.model,
+        prompt,
+        // URLs are only valid for 60 minutes after the image has been generated.
+        response_format: "b64_json", // use b64_json and save the image to CacheStorage
+        n: 1,
+        size: options.config?.size ?? "1024x1024",
+      };
+    } else {
+      const visionModel = isVisionModel(options.config.model);
+      const messages: ChatOptions["messages"] = [];
+      for (const v of options.messages) {
+        const content = visionModel
+          ? await preProcessImageContent(v.content)
+          : getMessageTextContent(v);
+        messages.push({ role: v.role, content });
+      }
 
-    // add max_tokens to vision model
-    if (visionModel && modelConfig.model.includes("preview")) {
-      requestPayload["max_tokens"] = Math.max(modelConfig.max_tokens, 4000);
+      requestPayload = {
+        messages,
+        stream: options.config.stream,
+        model: modelConfig.model,
+        temperature: modelConfig.temperature,
+        presence_penalty: modelConfig.presence_penalty,
+        frequency_penalty: modelConfig.frequency_penalty,
+        top_p: modelConfig.top_p,
+        // max_tokens: Math.max(modelConfig.max_tokens, 1024),
+        // Please do not ask me why not send max_tokens, no reason, this param is just shit, I dont want to explain anymore.
+      };
+
+      // add max_tokens to vision model
+      if (visionModel && modelConfig.model.includes("preview")) {
+        requestPayload["max_tokens"] = Math.max(modelConfig.max_tokens, 4000);
+      }
     }
 
     console.log("[Request] openai payload: ", requestPayload);
 
-    const shouldStream = !!options.config.stream;
+    const shouldStream = !isDalle3 && !!options.config.stream;
     const controller = new AbortController();
     options.onController?.(controller);
 
@@ -168,13 +219,15 @@ export class ChatGPTApi implements LLMApi {
             model?.provider?.providerName === ServiceProvider.Azure,
         );
         chatPath = this.path(
-          Azure.ChatPath(
+          (isDalle3 ? Azure.ImagePath : Azure.ChatPath)(
             (model?.displayName ?? model?.name) as string,
             useCustomConfig ? useAccessStore.getState().azureApiVersion : "",
           ),
         );
       } else {
-        chatPath = this.path(OpenaiPath.ChatPath);
+        chatPath = this.path(
+          isDalle3 ? OpenaiPath.ImagePath : OpenaiPath.ChatPath,
+        );
       }
       const chatPayload = {
         method: "POST",
@@ -186,7 +239,7 @@ export class ChatGPTApi implements LLMApi {
       // make a fetch request
       const requestTimeoutId = setTimeout(
         () => controller.abort(),
-        REQUEST_TIMEOUT_MS,
+        isDalle3 ? REQUEST_TIMEOUT_MS * 2 : REQUEST_TIMEOUT_MS, // dall-e-3 with b64_json responses is slow, so double the timeout
       );
 
       if (shouldStream) {
@@ -317,7 +370,7 @@ export class ChatGPTApi implements LLMApi {
         clearTimeout(requestTimeoutId);
 
         const resJson = await res.json();
-        const message = this.extractMessage(resJson);
+        const message = await this.extractMessage(resJson);
         options.onFinish(message);
       }
     } catch (e) {
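
For context on what this branch sends and receives: the non-streaming dall-e-3 path posts a DalleRequestPayload to the images endpoint, and extractMessage reads either url or b64_json out of res.data. A rough sketch with illustrative values (the response shape follows the OpenAI Images API: a created timestamp plus a data array whose entries hold url or b64_json depending on response_format):

    // request body built above for v1/images/generations
    const payload: DalleRequestPayload = {
      model: "dall-e-3",
      prompt: "a watercolor fox in the snow",
      response_format: "b64_json", // a plain url would expire about 60 minutes after generation
      n: 1,
      size: "1024x1024",
    };

    // response consumed by extractMessage (values abbreviated)
    // { "created": 1700000000, "data": [{ "b64_json": "...", "revised_prompt": "..." }] }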

app/components/chat.tsx (+35 -0)

@@ -37,6 +37,7 @@ import AutoIcon from "../icons/auto.svg";
 import BottomIcon from "../icons/bottom.svg";
 import StopIcon from "../icons/pause.svg";
 import RobotIcon from "../icons/robot.svg";
+import SizeIcon from "../icons/size.svg";
 import PluginIcon from "../icons/plugin.svg";
 
 import {
@@ -60,6 +61,7 @@ import {
   getMessageTextContent,
   getMessageImages,
   isVisionModel,
+  isDalle3,
 } from "../utils";
 
 import { uploadImage as uploadImageRemote } from "@/app/utils/chat";
@@ -67,6 +69,7 @@ import { uploadImage as uploadImageRemote } from "@/app/utils/chat";
 import dynamic from "next/dynamic";
 
 import { ChatControllerPool } from "../client/controller";
+import { DalleSize } from "../typing";
 import { Prompt, usePromptStore } from "../store/prompt";
 import Locale from "../locales";
 
@@ -481,6 +484,11 @@ export function ChatActions(props: {
   const [showPluginSelector, setShowPluginSelector] = useState(false);
   const [showUploadImage, setShowUploadImage] = useState(false);
 
+  const [showSizeSelector, setShowSizeSelector] = useState(false);
+  const dalle3Sizes: DalleSize[] = ["1024x1024", "1792x1024", "1024x1792"];
+  const currentSize =
+    chatStore.currentSession().mask.modelConfig?.size ?? "1024x1024";
+
   useEffect(() => {
     const show = isVisionModel(currentModel);
     setShowUploadImage(show);
@@ -624,6 +632,33 @@ export function ChatActions(props: {
         />
       )}
 
+      {isDalle3(currentModel) && (
+        <ChatAction
+          onClick={() => setShowSizeSelector(true)}
+          text={currentSize}
+          icon={<SizeIcon />}
+        />
+      )}
+
+      {showSizeSelector && (
+        <Selector
+          defaultSelectedValue={currentSize}
+          items={dalle3Sizes.map((m) => ({
+            title: m,
+            value: m,
+          }))}
+          onClose={() => setShowSizeSelector(false)}
+          onSelection={(s) => {
+            if (s.length === 0) return;
+            const size = s[0];
+            chatStore.updateCurrentSession((session) => {
+              session.mask.modelConfig.size = size;
+            });
+            showToast(size);
+          }}
+        />
+      )}
+
       <ChatAction
         onClick={() => setShowPluginSelector(true)}
         text={Locale.Plugin.Name}
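
The selector and the currentSize value computed earlier form a round trip through the session mask; condensed, the flow is roughly:

    // write: user picks a size in the Selector
    chatStore.updateCurrentSession((session) => {
      session.mask.modelConfig.size = "1792x1024";
    });

    // read: the next render of ChatActions picks it up, defaulting to 1024x1024
    const currentSize =
      chatStore.currentSession().mask.modelConfig?.size ?? "1024x1024";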

app/constant.ts (+6 -1)

@@ -146,6 +146,7 @@ export const Anthropic = {
 
 export const OpenaiPath = {
   ChatPath: "v1/chat/completions",
+  ImagePath: "v1/images/generations",
   UsagePath: "dashboard/billing/usage",
   SubsPath: "dashboard/billing/subscription",
   ListModelPath: "v1/models",
@@ -154,7 +155,10 @@ export const OpenaiPath = {
 export const Azure = {
   ChatPath: (deployName: string, apiVersion: string) =>
     `deployments/${deployName}/chat/completions?api-version=${apiVersion}`,
-  ExampleEndpoint: "https://{resource-url}/openai/deployments/{deploy-id}",
+  // https://<your_resource_name>.openai.azure.com/openai/deployments/<your_deployment_name>/images/generations?api-version=<api_version>
+  ImagePath: (deployName: string, apiVersion: string) =>
+    `deployments/${deployName}/images/generations?api-version=${apiVersion}`,
+  ExampleEndpoint: "https://{resource-url}/openai",
 };
 
 export const Google = {
@@ -256,6 +260,7 @@ const openaiModels = [
   "gpt-4-vision-preview",
   "gpt-4-turbo-2024-04-09",
   "gpt-4-1106-preview",
+  "dall-e-3",
 ];
 
 const googleModels = [
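
As a sanity check on the new Azure helper, this is roughly how the template resolves; the resource name, deployment name and api-version below are placeholders, not values from this change:

    const path = Azure.ImagePath("my-dalle3-deployment", "2024-02-01");
    // => "deployments/my-dalle3-deployment/images/generations?api-version=2024-02-01"
    // joined with the ExampleEndpoint base, the full URL looks like:
    // https://my-resource.openai.azure.com/openai/deployments/my-dalle3-deployment/images/generations?api-version=2024-02-01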

app/icons/size.svg (+1 -0)

@@ -0,0 +1 @@
+<?xml version="1.0" encoding="UTF-8"?><svg width="16" height="16" viewBox="0 0 48 48" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M42 7H6C4.89543 7 4 7.89543 4 9V39C4 40.1046 4.89543 41 6 41H42C43.1046 41 44 40.1046 44 39V9C44 7.89543 43.1046 7 42 7Z" fill="none" stroke="#333" stroke-width="4"/><path d="M30 30V18L38 30V18" stroke="#333" stroke-width="4" stroke-linecap="round" stroke-linejoin="round"/><path d="M10 30V18L18 30V18" stroke="#333" stroke-width="4" stroke-linecap="round" stroke-linejoin="round"/><path d="M24 20V21" stroke="#333" stroke-width="4" stroke-linecap="round"/><path d="M24 27V28" stroke="#333" stroke-width="4" stroke-linecap="round"/></svg>

app/store/chat.ts (+5 -0)

@@ -26,6 +26,7 @@ import { nanoid } from "nanoid";
 import { createPersistStore } from "../utils/store";
 import { collectModelsWithDefaultModel } from "../utils/model";
 import { useAccessStore } from "./access";
+import { isDalle3 } from "../utils";
 
 export type ChatMessage = RequestMessage & {
   date: string;
@@ -541,6 +542,10 @@ export const useChatStore = createPersistStore(
         const config = useAppConfig.getState();
         const session = get().currentSession();
         const modelConfig = session.mask.modelConfig;
+        // skip summarization when using dall-e-3
+        if (isDalle3(modelConfig.model)) {
+          return;
+        }
 
         const api: ClientApi = getClientApi(modelConfig.providerName);
 

app/store/config.ts (+2 -0)

@@ -1,4 +1,5 @@
 import { LLMModel } from "../client/api";
+import { DalleSize } from "../typing";
 import { getClientConfig } from "../config/client";
 import {
   DEFAULT_INPUT_TEMPLATE,
@@ -61,6 +62,7 @@ export const DEFAULT_CONFIG = {
     compressMessageLengthThreshold: 1000,
     enableInjectSystemPrompts: true,
     template: config?.template ?? DEFAULT_INPUT_TEMPLATE,
+    size: "1024x1024" as DalleSize,
   },
 };
 

app/typing.ts (+2 -0)

@@ -7,3 +7,5 @@ export interface RequestMessage {
   role: MessageRole;
   content: string;
 }
+
+export type DalleSize = "1024x1024" | "1792x1024" | "1024x1792";

app/utils.ts (+4 -0)

@@ -266,3 +266,7 @@ export function isVisionModel(model: string) {
     visionKeywords.some((keyword) => model.includes(keyword)) || isGpt4Turbo
   );
 }
+
+export function isDalle3(model: string) {
+  return "dall-e-3" === model;
+}
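
The check is an exact string comparison, so only the stock model id takes the new image path; a model configured under any other name (for example a renamed Azure deployment alias) would fall back to the chat-completions flow. A quick illustration:

    isDalle3("dall-e-3");    // true
    isDalle3("dall-e-3-hd"); // false, exact match only
    isDalle3("gpt-4o");      // false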