|
|
@@ -33,6 +33,7 @@ import {
|
|
|
getMessageTextContent,
|
|
|
getMessageImages,
|
|
|
isVisionModel,
|
|
|
+ isDalle3 as _isDalle3,
|
|
|
} from "@/app/utils";
|
|
|
|
|
|
export interface OpenAIListModelResponse {
|
|
|
@@ -58,6 +59,13 @@ export interface RequestPayload {
|
|
|
max_tokens?: number;
|
|
|
}
|
|
|
|
|
|
+export interface DalleRequestPayload {
|
|
|
+ model: string;
|
|
|
+ prompt: string;
|
|
|
+ n: number;
|
|
|
+ size: "1024x1024" | "1792x1024" | "1024x1792";
|
|
|
+}
|
|
|
+
|
|
|
export class ChatGPTApi implements LLMApi {
|
|
|
private disableListModels = true;
|
|
|
|
|
|
@@ -101,19 +109,25 @@ export class ChatGPTApi implements LLMApi {
|
|
|
}
|
|
|
|
|
|
extractMessage(res: any) {
|
|
|
+ if (res.error) {
|
|
|
+ return "```\n" + JSON.stringify(res, null, 4) + "\n```";
|
|
|
+ }
|
|
|
+ // dalle3 model return url, just return
|
|
|
+ if (res.data) {
|
|
|
+ const url = res.data?.at(0)?.url ?? "";
|
|
|
+ return [
|
|
|
+ {
|
|
|
+ type: "image_url",
|
|
|
+ image_url: {
|
|
|
+ url,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ ];
|
|
|
+ }
|
|
|
return res.choices?.at(0)?.message?.content ?? "";
|
|
|
}
|
|
|
|
|
|
async chat(options: ChatOptions) {
|
|
|
- const visionModel = isVisionModel(options.config.model);
|
|
|
- const messages: ChatOptions["messages"] = [];
|
|
|
- for (const v of options.messages) {
|
|
|
- const content = visionModel
|
|
|
- ? await preProcessImageContent(v.content)
|
|
|
- : getMessageTextContent(v);
|
|
|
- messages.push({ role: v.role, content });
|
|
|
- }
|
|
|
-
|
|
|
const modelConfig = {
|
|
|
...useAppConfig.getState().modelConfig,
|
|
|
...useChatStore.getState().currentSession().mask.modelConfig,
|
|
|
@@ -123,26 +137,48 @@ export class ChatGPTApi implements LLMApi {
|
|
|
},
|
|
|
};
|
|
|
|
|
|
- const requestPayload: RequestPayload = {
|
|
|
- messages,
|
|
|
- stream: options.config.stream,
|
|
|
- model: modelConfig.model,
|
|
|
- temperature: modelConfig.temperature,
|
|
|
- presence_penalty: modelConfig.presence_penalty,
|
|
|
- frequency_penalty: modelConfig.frequency_penalty,
|
|
|
- top_p: modelConfig.top_p,
|
|
|
- // max_tokens: Math.max(modelConfig.max_tokens, 1024),
|
|
|
- // Please do not ask me why not send max_tokens, no reason, this param is just shit, I dont want to explain anymore.
|
|
|
- };
|
|
|
+ let requestPayload: RequestPayload | DalleRequestPayload;
|
|
|
+
|
|
|
+ const isDalle3 = _isDalle3(options.config.model);
|
|
|
+ if (isDalle3) {
|
|
|
+ const prompt = getMessageTextContent(options.messages.slice(-1)?.pop());
|
|
|
+ requestPayload = {
|
|
|
+ model: options.config.model,
|
|
|
+ prompt,
|
|
|
+ n: 1,
|
|
|
+ size: options.config?.size ?? "1024x1024",
|
|
|
+ };
|
|
|
+ } else {
|
|
|
+ const visionModel = isVisionModel(options.config.model);
|
|
|
+ const messages: ChatOptions["messages"] = [];
|
|
|
+ for (const v of options.messages) {
|
|
|
+ const content = visionModel
|
|
|
+ ? await preProcessImageContent(v.content)
|
|
|
+ : getMessageTextContent(v);
|
|
|
+ messages.push({ role: v.role, content });
|
|
|
+ }
|
|
|
|
|
|
- // add max_tokens to vision model
|
|
|
- if (visionModel && modelConfig.model.includes("preview")) {
|
|
|
- requestPayload["max_tokens"] = Math.max(modelConfig.max_tokens, 4000);
|
|
|
+ requestPayload = {
|
|
|
+ messages,
|
|
|
+ stream: options.config.stream,
|
|
|
+ model: modelConfig.model,
|
|
|
+ temperature: modelConfig.temperature,
|
|
|
+ presence_penalty: modelConfig.presence_penalty,
|
|
|
+ frequency_penalty: modelConfig.frequency_penalty,
|
|
|
+ top_p: modelConfig.top_p,
|
|
|
+ // max_tokens: Math.max(modelConfig.max_tokens, 1024),
|
|
|
+ // Please do not ask me why not send max_tokens, no reason, this param is just shit, I dont want to explain anymore.
|
|
|
+ };
|
|
|
+
|
|
|
+ // add max_tokens to vision model
|
|
|
+ if (visionModel && modelConfig.model.includes("preview")) {
|
|
|
+ requestPayload["max_tokens"] = Math.max(modelConfig.max_tokens, 4000);
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
console.log("[Request] openai payload: ", requestPayload);
|
|
|
|
|
|
- const shouldStream = !!options.config.stream;
|
|
|
+ const shouldStream = !isDalle3 && !!options.config.stream;
|
|
|
const controller = new AbortController();
|
|
|
options.onController?.(controller);
|
|
|
|
|
|
@@ -168,13 +204,15 @@ export class ChatGPTApi implements LLMApi {
|
|
|
model?.provider?.providerName === ServiceProvider.Azure,
|
|
|
);
|
|
|
chatPath = this.path(
|
|
|
- Azure.ChatPath(
|
|
|
+ (isDalle3 ? Azure.ImagePath : Azure.ChatPath)(
|
|
|
(model?.displayName ?? model?.name) as string,
|
|
|
useCustomConfig ? useAccessStore.getState().azureApiVersion : "",
|
|
|
),
|
|
|
);
|
|
|
} else {
|
|
|
- chatPath = this.path(OpenaiPath.ChatPath);
|
|
|
+ chatPath = this.path(
|
|
|
+ isDalle3 ? OpenaiPath.ImagePath : OpenaiPath.ChatPath,
|
|
|
+ );
|
|
|
}
|
|
|
const chatPayload = {
|
|
|
method: "POST",
|