소스 검색

add dalle3 model

lloydzhou 1 년 전
부모
커밋
ac599aa47c
5개의 변경된 파일113개의 추가작업 그리고 27개의 파일을 삭제
  1. 64 26
      app/client/platforms/openai.ts
  2. 34 0
      app/components/chat.tsx
  3. 6 1
      app/constant.ts
  4. 5 0
      app/store/chat.ts
  5. 4 0
      app/utils.ts

+ 64 - 26
app/client/platforms/openai.ts

@@ -33,6 +33,7 @@ import {
   getMessageTextContent,
   getMessageImages,
   isVisionModel,
+  isDalle3 as _isDalle3,
 } from "@/app/utils";
 
 export interface OpenAIListModelResponse {
@@ -58,6 +59,13 @@ export interface RequestPayload {
   max_tokens?: number;
 }
 
+export interface DalleRequestPayload {
+  model: string;
+  prompt: string;
+  n: number;
+  size: "1024x1024" | "1792x1024" | "1024x1792";
+}
+
 export class ChatGPTApi implements LLMApi {
   private disableListModels = true;
 
@@ -101,19 +109,25 @@ export class ChatGPTApi implements LLMApi {
   }
 
   extractMessage(res: any) {
+    if (res.error) {
+      return "```\n" + JSON.stringify(res, null, 4) + "\n```";
+    }
+    // dalle3 model return url, just return
+    if (res.data) {
+      const url = res.data?.at(0)?.url ?? "";
+      return [
+        {
+          type: "image_url",
+          image_url: {
+            url,
+          },
+        },
+      ];
+    }
     return res.choices?.at(0)?.message?.content ?? "";
   }
 
   async chat(options: ChatOptions) {
-    const visionModel = isVisionModel(options.config.model);
-    const messages: ChatOptions["messages"] = [];
-    for (const v of options.messages) {
-      const content = visionModel
-        ? await preProcessImageContent(v.content)
-        : getMessageTextContent(v);
-      messages.push({ role: v.role, content });
-    }
-
     const modelConfig = {
       ...useAppConfig.getState().modelConfig,
       ...useChatStore.getState().currentSession().mask.modelConfig,
@@ -123,26 +137,48 @@ export class ChatGPTApi implements LLMApi {
       },
     };
 
-    const requestPayload: RequestPayload = {
-      messages,
-      stream: options.config.stream,
-      model: modelConfig.model,
-      temperature: modelConfig.temperature,
-      presence_penalty: modelConfig.presence_penalty,
-      frequency_penalty: modelConfig.frequency_penalty,
-      top_p: modelConfig.top_p,
-      // max_tokens: Math.max(modelConfig.max_tokens, 1024),
-      // Please do not ask me why not send max_tokens, no reason, this param is just shit, I dont want to explain anymore.
-    };
+    let requestPayload: RequestPayload | DalleRequestPayload;
+
+    const isDalle3 = _isDalle3(options.config.model);
+    if (isDalle3) {
+      const prompt = getMessageTextContent(options.messages.slice(-1)?.pop());
+      requestPayload = {
+        model: options.config.model,
+        prompt,
+        n: 1,
+        size: options.config?.size ?? "1024x1024",
+      };
+    } else {
+      const visionModel = isVisionModel(options.config.model);
+      const messages: ChatOptions["messages"] = [];
+      for (const v of options.messages) {
+        const content = visionModel
+          ? await preProcessImageContent(v.content)
+          : getMessageTextContent(v);
+        messages.push({ role: v.role, content });
+      }
 
-    // add max_tokens to vision model
-    if (visionModel && modelConfig.model.includes("preview")) {
-      requestPayload["max_tokens"] = Math.max(modelConfig.max_tokens, 4000);
+      requestPayload = {
+        messages,
+        stream: options.config.stream,
+        model: modelConfig.model,
+        temperature: modelConfig.temperature,
+        presence_penalty: modelConfig.presence_penalty,
+        frequency_penalty: modelConfig.frequency_penalty,
+        top_p: modelConfig.top_p,
+        // max_tokens: Math.max(modelConfig.max_tokens, 1024),
+        // Please do not ask me why not send max_tokens, no reason, this param is just shit, I dont want to explain anymore.
+      };
+
+      // add max_tokens to vision model
+      if (visionModel && modelConfig.model.includes("preview")) {
+        requestPayload["max_tokens"] = Math.max(modelConfig.max_tokens, 4000);
+      }
     }
 
     console.log("[Request] openai payload: ", requestPayload);
 
-    const shouldStream = !!options.config.stream;
+    const shouldStream = !isDalle3 && !!options.config.stream;
     const controller = new AbortController();
     options.onController?.(controller);
 
@@ -168,13 +204,15 @@ export class ChatGPTApi implements LLMApi {
             model?.provider?.providerName === ServiceProvider.Azure,
         );
         chatPath = this.path(
-          Azure.ChatPath(
+          (isDalle3 ? Azure.ImagePath : Azure.ChatPath)(
             (model?.displayName ?? model?.name) as string,
             useCustomConfig ? useAccessStore.getState().azureApiVersion : "",
           ),
         );
       } else {
-        chatPath = this.path(OpenaiPath.ChatPath);
+        chatPath = this.path(
+          isDalle3 ? OpenaiPath.ImagePath : OpenaiPath.ChatPath,
+        );
       }
       const chatPayload = {
         method: "POST",

+ 34 - 0
app/components/chat.tsx

@@ -37,6 +37,7 @@ import AutoIcon from "../icons/auto.svg";
 import BottomIcon from "../icons/bottom.svg";
 import StopIcon from "../icons/pause.svg";
 import RobotIcon from "../icons/robot.svg";
+import SizeIcon from "../icons/size.svg";
 import PluginIcon from "../icons/plugin.svg";
 
 import {
@@ -60,6 +61,7 @@ import {
   getMessageTextContent,
   getMessageImages,
   isVisionModel,
+  isDalle3,
 } from "../utils";
 
 import { uploadImage as uploadImageRemote } from "@/app/utils/chat";
@@ -481,6 +483,11 @@ export function ChatActions(props: {
   const [showPluginSelector, setShowPluginSelector] = useState(false);
   const [showUploadImage, setShowUploadImage] = useState(false);
 
+  const [showSizeSelector, setShowSizeSelector] = useState(false);
+  const dalle3Sizes = ["1024x1024", "1792x1024", "1024x1792"];
+  const currentSize =
+    chatStore.currentSession().mask.modelConfig?.size || "1024x1024";
+
   useEffect(() => {
     const show = isVisionModel(currentModel);
     setShowUploadImage(show);
@@ -624,6 +631,33 @@ export function ChatActions(props: {
         />
       )}
 
+      {isDalle3(currentModel) && (
+        <ChatAction
+          onClick={() => setShowSizeSelector(true)}
+          text={currentSize}
+          icon={<SizeIcon />}
+        />
+      )}
+
+      {showSizeSelector && (
+        <Selector
+          defaultSelectedValue={currentSize}
+          items={dalle3Sizes.map((m) => ({
+            title: m,
+            value: m,
+          }))}
+          onClose={() => setShowSizeSelector(false)}
+          onSelection={(s) => {
+            if (s.length === 0) return;
+            const size = s[0];
+            chatStore.updateCurrentSession((session) => {
+              session.mask.modelConfig.size = size;
+            });
+            showToast(size);
+          }}
+        />
+      )}
+
       <ChatAction
         onClick={() => setShowPluginSelector(true)}
         text={Locale.Plugin.Name}

+ 6 - 1
app/constant.ts

@@ -146,6 +146,7 @@ export const Anthropic = {
 
 export const OpenaiPath = {
   ChatPath: "v1/chat/completions",
+  ImagePath: "v1/images/generations",
   UsagePath: "dashboard/billing/usage",
   SubsPath: "dashboard/billing/subscription",
   ListModelPath: "v1/models",
@@ -154,7 +155,10 @@ export const OpenaiPath = {
 export const Azure = {
   ChatPath: (deployName: string, apiVersion: string) =>
     `deployments/${deployName}/chat/completions?api-version=${apiVersion}`,
-  ExampleEndpoint: "https://{resource-url}/openai/deployments/{deploy-id}",
+  // https://<your_resource_name>.openai.azure.com/openai/deployments/<your_deployment_name>/images/generations?api-version=<api_version>
+  ImagePath: (deployName: string, apiVersion: string) =>
+    `deployments/${deployName}/images/generations?api-version=${apiVersion}`,
+  ExampleEndpoint: "https://{resource-url}/openai",
 };
 
 export const Google = {
@@ -256,6 +260,7 @@ const openaiModels = [
   "gpt-4-vision-preview",
   "gpt-4-turbo-2024-04-09",
   "gpt-4-1106-preview",
+  "dall-e-3",
 ];
 
 const googleModels = [

+ 5 - 0
app/store/chat.ts

@@ -26,6 +26,7 @@ import { nanoid } from "nanoid";
 import { createPersistStore } from "../utils/store";
 import { collectModelsWithDefaultModel } from "../utils/model";
 import { useAccessStore } from "./access";
+import { isDalle3 } from "../utils";
 
 export type ChatMessage = RequestMessage & {
   date: string;
@@ -541,6 +542,10 @@ export const useChatStore = createPersistStore(
         const config = useAppConfig.getState();
         const session = get().currentSession();
         const modelConfig = session.mask.modelConfig;
+        // skip summarize when using dalle3?
+        if (isDalle3(modelConfig.model)) {
+          return;
+        }
 
         const api: ClientApi = getClientApi(modelConfig.providerName);
 

+ 4 - 0
app/utils.ts

@@ -265,3 +265,7 @@ export function isVisionModel(model: string) {
     visionKeywords.some((keyword) => model.includes(keyword)) || isGpt4Turbo
   );
 }
+
+export function isDalle3(model: string) {
+  return "dall-e-3" === model;
+}