
Merge branch 'main' of https://github.com/ChatGPTNextWeb/ChatGPT-Next-Web into add_deepseek

suruiqiang, 11 months ago
Commit 7380c8a2c1

+ 2 - 2
app/api/alibaba.ts

@@ -8,7 +8,7 @@ import {
 import { prettyObject } from "@/app/utils/format";
 import { NextRequest, NextResponse } from "next/server";
 import { auth } from "@/app/api/auth";
-import { isModelAvailableInServer } from "@/app/utils/model";
+import { isModelNotavailableInServer } from "@/app/utils/model";
 
 const serverConfig = getServerSideConfig();
 
@@ -89,7 +89,7 @@ async function request(req: NextRequest) {
 
       // not undefined and is false
       if (
-        isModelAvailableInServer(
+        isModelNotavailableInServer(
           serverConfig.customModels,
           jsonBody?.model as string,
           ServiceProvider.Alibaba as string,

+ 2 - 2
app/api/anthropic.ts

@@ -9,7 +9,7 @@ import {
 import { prettyObject } from "@/app/utils/format";
 import { NextRequest, NextResponse } from "next/server";
 import { auth } from "./auth";
-import { isModelAvailableInServer } from "@/app/utils/model";
+import { isModelNotavailableInServer } from "@/app/utils/model";
 import { cloudflareAIGatewayUrl } from "@/app/utils/cloudflare";
 
 const ALLOWD_PATH = new Set([Anthropic.ChatPath, Anthropic.ChatPath1]);
@@ -122,7 +122,7 @@ async function request(req: NextRequest) {
 
       // not undefined and is false
       if (
-        isModelAvailableInServer(
+        isModelNotavailableInServer(
           serverConfig.customModels,
           jsonBody?.model as string,
           ServiceProvider.Anthropic as string,

+ 2 - 2
app/api/baidu.ts

@@ -8,7 +8,7 @@ import {
 import { prettyObject } from "@/app/utils/format";
 import { NextRequest, NextResponse } from "next/server";
 import { auth } from "@/app/api/auth";
-import { isModelAvailableInServer } from "@/app/utils/model";
+import { isModelNotavailableInServer } from "@/app/utils/model";
 import { getAccessToken } from "@/app/utils/baidu";
 
 const serverConfig = getServerSideConfig();
@@ -104,7 +104,7 @@ async function request(req: NextRequest) {
 
       // not undefined and is false
       if (
-        isModelAvailableInServer(
+        isModelNotavailableInServer(
           serverConfig.customModels,
           jsonBody?.model as string,
           ServiceProvider.Baidu as string,

+ 2 - 2
app/api/bytedance.ts

@@ -8,7 +8,7 @@ import {
 import { prettyObject } from "@/app/utils/format";
 import { NextRequest, NextResponse } from "next/server";
 import { auth } from "@/app/api/auth";
-import { isModelAvailableInServer } from "@/app/utils/model";
+import { isModelNotavailableInServer } from "@/app/utils/model";
 
 const serverConfig = getServerSideConfig();
 
@@ -88,7 +88,7 @@ async function request(req: NextRequest) {
 
       // not undefined and is false
       if (
-        isModelAvailableInServer(
+        isModelNotavailableInServer(
           serverConfig.customModels,
           jsonBody?.model as string,
           ServiceProvider.ByteDance as string,

+ 7 - 8
app/api/common.ts

@@ -2,7 +2,7 @@ import { NextRequest, NextResponse } from "next/server";
 import { getServerSideConfig } from "../config/server";
 import { OPENAI_BASE_URL, ServiceProvider } from "../constant";
 import { cloudflareAIGatewayUrl } from "../utils/cloudflare";
-import { getModelProvider, isModelAvailableInServer } from "../utils/model";
+import { getModelProvider, isModelNotavailableInServer } from "../utils/model";
 
 const serverConfig = getServerSideConfig();
 
@@ -118,15 +118,14 @@ export async function requestOpenai(req: NextRequest) {
 
       // not undefined and is false
       if (
-        isModelAvailableInServer(
+        isModelNotavailableInServer(
           serverConfig.customModels,
           jsonBody?.model as string,
-          ServiceProvider.OpenAI as string,
-        ) ||
-        isModelAvailableInServer(
-          serverConfig.customModels,
-          jsonBody?.model as string,
-          ServiceProvider.Azure as string,
+          [
+            ServiceProvider.OpenAI,
+            ServiceProvider.Azure,
+            jsonBody?.model as string, // support provider-unspecified model
+          ],
         )
       ) {
         return NextResponse.json(
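
Note on the change above: the OpenAI route previously called isModelAvailableInServer once per provider; it now makes a single isModelNotavailableInServer call with a provider list, passing the raw model name as the final entry so provider-unspecified custom models are matched too. A minimal standalone sketch of that call shape (hypothetical values, relying on the helper added in app/utils/model.ts later in this diff):

// Hypothetical sketch, not part of the commit:
import { isModelNotavailableInServer } from "@/app/utils/model";
import { ServiceProvider } from "@/app/constant";

const blocked = isModelNotavailableInServer(
  "",        // customModels: empty, i.e. no server-side restrictions
  "gpt-4",   // model requested by the client
  [ServiceProvider.OpenAI, ServiceProvider.Azure, "gpt-4"], // model name last, for provider-unspecified entries
);
// blocked === false, matching the first case in test/model-available.test.ts below
// (assuming DISABLE_GPT4 is unset)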

+ 2 - 2
app/api/glm.ts

@@ -8,7 +8,7 @@ import {
 import { prettyObject } from "@/app/utils/format";
 import { NextRequest, NextResponse } from "next/server";
 import { auth } from "@/app/api/auth";
-import { isModelAvailableInServer } from "@/app/utils/model";
+import { isModelNotavailableInServer } from "@/app/utils/model";
 
 const serverConfig = getServerSideConfig();
 
@@ -89,7 +89,7 @@ async function request(req: NextRequest) {
 
       // not undefined and is false
       if (
-        isModelAvailableInServer(
+        isModelNotavailableInServer(
           serverConfig.customModels,
           jsonBody?.model as string,
           ServiceProvider.ChatGLM as string,

+ 2 - 2
app/api/iflytek.ts

@@ -8,7 +8,7 @@ import {
 import { prettyObject } from "@/app/utils/format";
 import { NextRequest, NextResponse } from "next/server";
 import { auth } from "@/app/api/auth";
-import { isModelAvailableInServer } from "@/app/utils/model";
+import { isModelNotavailableInServer } from "@/app/utils/model";
 // iflytek
 
 const serverConfig = getServerSideConfig();
@@ -89,7 +89,7 @@ async function request(req: NextRequest) {
 
       // not undefined and is false
       if (
-        isModelAvailableInServer(
+        isModelNotavailableInServer(
           serverConfig.customModels,
           jsonBody?.model as string,
           ServiceProvider.Iflytek as string,

+ 2 - 2
app/api/moonshot.ts

@@ -8,7 +8,7 @@ import {
 import { prettyObject } from "@/app/utils/format";
 import { NextRequest, NextResponse } from "next/server";
 import { auth } from "@/app/api/auth";
-import { isModelAvailableInServer } from "@/app/utils/model";
+import { isModelNotavailableInServer } from "@/app/utils/model";
 
 const serverConfig = getServerSideConfig();
 
@@ -88,7 +88,7 @@ async function request(req: NextRequest) {
 
       // not undefined and is false
       if (
-        isModelAvailableInServer(
+        isModelNotavailableInServer(
           serverConfig.customModels,
           jsonBody?.model as string,
           ServiceProvider.Moonshot as string,

+ 2 - 2
app/api/xai.ts

@@ -8,7 +8,7 @@ import {
 import { prettyObject } from "@/app/utils/format";
 import { NextRequest, NextResponse } from "next/server";
 import { auth } from "@/app/api/auth";
-import { isModelAvailableInServer } from "@/app/utils/model";
+import { isModelNotavailableInServer } from "@/app/utils/model";
 
 const serverConfig = getServerSideConfig();
 
@@ -88,7 +88,7 @@ async function request(req: NextRequest) {
 
       // not undefined and is false
       if (
-        isModelAvailableInServer(
+        isModelNotavailableInServer(
           serverConfig.customModels,
           jsonBody?.model as string,
           ServiceProvider.XAI as string,

+ 118 - 22
app/client/platforms/glm.ts

@@ -21,16 +21,108 @@ import {
   SpeechOptions,
 } from "../api";
 import { getClientConfig } from "@/app/config/client";
-import { getMessageTextContent } from "@/app/utils";
+import { getMessageTextContent, isVisionModel } from "@/app/utils";
 import { RequestPayload } from "./openai";
 import { fetch } from "@/app/utils/stream";
+import { preProcessImageContent } from "@/app/utils/chat";
+
+interface BasePayload {
+  model: string;
+}
+
+interface ChatPayload extends BasePayload {
+  messages: ChatOptions["messages"];
+  stream?: boolean;
+  temperature?: number;
+  presence_penalty?: number;
+  frequency_penalty?: number;
+  top_p?: number;
+}
+
+interface ImageGenerationPayload extends BasePayload {
+  prompt: string;
+  size?: string;
+  user_id?: string;
+}
+
+interface VideoGenerationPayload extends BasePayload {
+  prompt: string;
+  duration?: number;
+  resolution?: string;
+  user_id?: string;
+}
+
+type ModelType = "chat" | "image" | "video";
 
 export class ChatGLMApi implements LLMApi {
   private disableListModels = true;
 
+  private getModelType(model: string): ModelType {
+    if (model.startsWith("cogview-")) return "image";
+    if (model.startsWith("cogvideo-")) return "video";
+    return "chat";
+  }
+
+  private getModelPath(type: ModelType): string {
+    switch (type) {
+      case "image":
+        return ChatGLM.ImagePath;
+      case "video":
+        return ChatGLM.VideoPath;
+      default:
+        return ChatGLM.ChatPath;
+    }
+  }
+
+  private createPayload(
+    messages: ChatOptions["messages"],
+    modelConfig: any,
+    options: ChatOptions,
+  ): BasePayload {
+    const modelType = this.getModelType(modelConfig.model);
+    const lastMessage = messages[messages.length - 1];
+    const prompt =
+      typeof lastMessage.content === "string"
+        ? lastMessage.content
+        : lastMessage.content.map((c) => c.text).join("\n");
+
+    switch (modelType) {
+      case "image":
+        return {
+          model: modelConfig.model,
+          prompt,
+          size: options.config.size,
+        } as ImageGenerationPayload;
+      default:
+        return {
+          messages,
+          stream: options.config.stream,
+          model: modelConfig.model,
+          temperature: modelConfig.temperature,
+          presence_penalty: modelConfig.presence_penalty,
+          frequency_penalty: modelConfig.frequency_penalty,
+          top_p: modelConfig.top_p,
+        } as ChatPayload;
+    }
+  }
+
+  private parseResponse(modelType: ModelType, json: any): string {
+    switch (modelType) {
+      case "image": {
+        const imageUrl = json.data?.[0]?.url;
+        return imageUrl ? `![Generated Image](${imageUrl})` : "";
+      }
+      case "video": {
+        const videoUrl = json.data?.[0]?.url;
+        return videoUrl ? `<video controls src="${videoUrl}"></video>` : "";
+      }
+      default:
+        return this.extractMessage(json);
+    }
+  }
+
   path(path: string): string {
     const accessStore = useAccessStore.getState();
-
     let baseUrl = "";
 
     if (accessStore.useCustomConfig) {
@@ -51,7 +143,6 @@ export class ChatGLMApi implements LLMApi {
     }
 
     console.log("[Proxy Endpoint] ", baseUrl, path);
-
     return [baseUrl, path].join("/");
   }
 
@@ -64,9 +155,12 @@ export class ChatGLMApi implements LLMApi {
   }
 
   async chat(options: ChatOptions) {
+    const visionModel = isVisionModel(options.config.model);
     const messages: ChatOptions["messages"] = [];
     for (const v of options.messages) {
-      const content = getMessageTextContent(v);
+      const content = visionModel
+        ? await preProcessImageContent(v.content)
+        : getMessageTextContent(v);
       messages.push({ role: v.role, content });
     }
 
@@ -78,25 +172,16 @@ export class ChatGLMApi implements LLMApi {
         providerName: options.config.providerName,
       },
     };
+    const modelType = this.getModelType(modelConfig.model);
+    const requestPayload = this.createPayload(messages, modelConfig, options);
+    const path = this.path(this.getModelPath(modelType));
 
-    const requestPayload: RequestPayload = {
-      messages,
-      stream: options.config.stream,
-      model: modelConfig.model,
-      temperature: modelConfig.temperature,
-      presence_penalty: modelConfig.presence_penalty,
-      frequency_penalty: modelConfig.frequency_penalty,
-      top_p: modelConfig.top_p,
-    };
+    console.log(`[Request] glm ${modelType} payload: `, requestPayload);
 
-    console.log("[Request] glm payload: ", requestPayload);
-
-    const shouldStream = !!options.config.stream;
     const controller = new AbortController();
     options.onController?.(controller);
 
     try {
-      const chatPath = this.path(ChatGLM.ChatPath);
       const chatPayload = {
         method: "POST",
         body: JSON.stringify(requestPayload),
@@ -104,12 +189,23 @@ export class ChatGLMApi implements LLMApi {
         headers: getHeaders(),
       };
 
-      // make a fetch request
       const requestTimeoutId = setTimeout(
         () => controller.abort(),
         REQUEST_TIMEOUT_MS,
       );
 
+      if (modelType === "image" || modelType === "video") {
+        const res = await fetch(path, chatPayload);
+        clearTimeout(requestTimeoutId);
+
+        const resJson = await res.json();
+        console.log(`[Response] glm ${modelType}:`, resJson);
+        const message = this.parseResponse(modelType, resJson);
+        options.onFinish(message, res);
+        return;
+      }
+
+      const shouldStream = !!options.config.stream;
       if (shouldStream) {
         const [tools, funcs] = usePluginStore
           .getState()
@@ -117,7 +213,7 @@ export class ChatGLMApi implements LLMApi {
             useChatStore.getState().currentSession().mask?.plugin || [],
           );
         return stream(
-          chatPath,
+          path,
           requestPayload,
           getHeaders(),
           tools as any,
@@ -125,7 +221,6 @@ export class ChatGLMApi implements LLMApi {
           controller,
           // parseSSE
           (text: string, runTools: ChatMessageTool[]) => {
-            // console.log("parseSSE", text, runTools);
             const json = JSON.parse(text);
             const choices = json.choices as Array<{
               delta: {
@@ -154,7 +249,7 @@ export class ChatGLMApi implements LLMApi {
             }
             return choices[0]?.delta?.content;
           },
-          // processToolMessage, include tool_calls message and tool call results
+          // processToolMessage
           (
             requestPayload: RequestPayload,
             toolCallMessage: any,
@@ -172,7 +267,7 @@ export class ChatGLMApi implements LLMApi {
           options,
         );
       } else {
-        const res = await fetch(chatPath, chatPayload);
+        const res = await fetch(path, chatPayload);
         clearTimeout(requestTimeoutId);
 
         const resJson = await res.json();
@@ -184,6 +279,7 @@ export class ChatGLMApi implements LLMApi {
       options.onError?.(e as Error);
     }
   }
+
   async usage() {
     return {
       used: 0,
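
The ChatGLM client above now picks an endpoint from the model name: cogview-* models go to ChatGLM.ImagePath, cogvideo-* to ChatGLM.VideoPath, everything else to the chat path, and image/video requests skip streaming and are rendered from data[0].url. A small hedged illustration of the image branch (hypothetical prompt and URL; response shape as assumed by parseResponse above):

// Hypothetical request body for a cogview model, as built by createPayload above:
const imagePayload = {
  model: "cogview-3-flash",
  prompt: "a watercolor fox", // text of the last message
  size: "1024x1024",          // a ModelSize value (see app/typing.ts below)
};
// POSTed to ChatGLM.ImagePath ("api/paas/v4/images/generations"); if the response
// contains data[0].url, parseResponse emits markdown such as:
//   ![Generated Image](https://example.com/fox.png)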

+ 15 - 3
app/client/platforms/google.ts

@@ -60,9 +60,18 @@ export class GeminiProApi implements LLMApi {
   extractMessage(res: any) {
     console.log("[Response] gemini-pro response: ", res);
 
+    const getTextFromParts = (parts: any[]) => {
+      if (!Array.isArray(parts)) return "";
+
+      return parts
+        .map((part) => part?.text || "")
+        .filter((text) => text.trim() !== "")
+        .join("\n\n");
+    };
+
     return (
-      res?.candidates?.at(0)?.content?.parts.at(0)?.text ||
-      res?.at(0)?.candidates?.at(0)?.content?.parts.at(0)?.text ||
+      getTextFromParts(res?.candidates?.at(0)?.content?.parts) ||
+      getTextFromParts(res?.at(0)?.candidates?.at(0)?.content?.parts) ||
       res?.error?.message ||
       ""
     );
@@ -223,7 +232,10 @@ export class GeminiProApi implements LLMApi {
                 },
               });
             }
-            return chunkJson?.candidates?.at(0)?.content.parts.at(0)?.text;
+            return chunkJson?.candidates
+              ?.at(0)
+              ?.content.parts?.map((part: { text: string }) => part.text)
+              .join("\n\n");
           },
           // processToolMessage, include tool_calls message and tool call results
           (
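
Both the non-streaming extractMessage path and the streaming parseSSE callback now concatenate the text of every part instead of reading only parts[0]. A small example of what that means for a multi-part candidate (hypothetical response fragment):

// Hypothetical candidate content with two text parts:
const parts = [{ text: "First paragraph." }, { text: "Second paragraph." }];
// Before: only "First paragraph." was returned.
// After:  getTextFromParts(parts) === "First paragraph.\n\nSecond paragraph."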

+ 2 - 2
app/client/platforms/openai.ts

@@ -24,7 +24,7 @@ import {
   stream,
 } from "@/app/utils/chat";
 import { cloudflareAIGatewayUrl } from "@/app/utils/cloudflare";
-import { DalleSize, DalleQuality, DalleStyle } from "@/app/typing";
+import { ModelSize, DalleQuality, DalleStyle } from "@/app/typing";
 
 import {
   ChatOptions,
@@ -73,7 +73,7 @@ export interface DalleRequestPayload {
   prompt: string;
   response_format: "url" | "b64_json";
   n: number;
-  size: DalleSize;
+  size: ModelSize;
   quality: DalleQuality;
   style: DalleStyle;
 }

+ 8 - 5
app/components/chat.tsx

@@ -72,6 +72,8 @@ import {
   isDalle3,
   showPlugins,
   safeLocalStorage,
+  getModelSizes,
+  supportsCustomSize,
 } from "../utils";
 
 import { uploadImage as uploadImageRemote } from "@/app/utils/chat";
@@ -79,7 +81,7 @@ import { uploadImage as uploadImageRemote } from "@/app/utils/chat";
 import dynamic from "next/dynamic";
 
 import { ChatControllerPool } from "../client/controller";
-import { DalleSize, DalleQuality, DalleStyle } from "../typing";
+import { DalleQuality, DalleStyle, ModelSize } from "../typing";
 import { Prompt, usePromptStore } from "../store/prompt";
 import Locale from "../locales";
 
@@ -519,10 +521,11 @@ export function ChatActions(props: {
   const [showSizeSelector, setShowSizeSelector] = useState(false);
   const [showQualitySelector, setShowQualitySelector] = useState(false);
   const [showStyleSelector, setShowStyleSelector] = useState(false);
-  const dalle3Sizes: DalleSize[] = ["1024x1024", "1792x1024", "1024x1792"];
+  const modelSizes = getModelSizes(currentModel);
   const dalle3Qualitys: DalleQuality[] = ["standard", "hd"];
   const dalle3Styles: DalleStyle[] = ["vivid", "natural"];
-  const currentSize = session.mask.modelConfig?.size ?? "1024x1024";
+  const currentSize =
+    session.mask.modelConfig?.size ?? ("1024x1024" as ModelSize);
   const currentQuality = session.mask.modelConfig?.quality ?? "standard";
   const currentStyle = session.mask.modelConfig?.style ?? "vivid";
 
@@ -673,7 +676,7 @@ export function ChatActions(props: {
           />
         )}
 
-        {isDalle3(currentModel) && (
+        {supportsCustomSize(currentModel) && (
           <ChatAction
             onClick={() => setShowSizeSelector(true)}
             text={currentSize}
@@ -684,7 +687,7 @@ export function ChatActions(props: {
         {showSizeSelector && (
           <Selector
             defaultSelectedValue={currentSize}
-            items={dalle3Sizes.map((m) => ({
+            items={modelSizes.map((m) => ({
               title: m,
               value: m,
             }))}

+ 11 - 6
app/components/sidebar.tsx

@@ -22,7 +22,6 @@ import {
   MIN_SIDEBAR_WIDTH,
   NARROW_SIDEBAR_WIDTH,
   Path,
-  PLUGINS,
   REPO_URL,
 } from "../constant";
 
@@ -32,6 +31,12 @@ import dynamic from "next/dynamic";
 import { showConfirm, Selector } from "./ui-lib";
 import clsx from "clsx";
 
+const DISCOVERY = [
+  { name: Locale.Plugin.Name, path: Path.Plugins },
+  { name: "Stable Diffusion", path: Path.Sd },
+  { name: Locale.SearchChat.Page.Title, path: Path.SearchChat },
+];
+
 const ChatList = dynamic(async () => (await import("./chat-list")).ChatList, {
   loading: () => null,
 });
@@ -219,7 +224,7 @@ export function SideBarTail(props: {
 export function SideBar(props: { className?: string }) {
   useHotKey();
   const { onDragStart, shouldNarrow } = useDragSideBar();
-  const [showPluginSelector, setShowPluginSelector] = useState(false);
+  const [showDiscoverySelector, setshowDiscoverySelector] = useState(false);
   const navigate = useNavigate();
   const config = useAppConfig();
   const chatStore = useChatStore();
@@ -254,21 +259,21 @@ export function SideBar(props: { className?: string }) {
             icon={<DiscoveryIcon />}
             text={shouldNarrow ? undefined : Locale.Discovery.Name}
             className={styles["sidebar-bar-button"]}
-            onClick={() => setShowPluginSelector(true)}
+            onClick={() => setshowDiscoverySelector(true)}
             shadow
           />
         </div>
-        {showPluginSelector && (
+        {showDiscoverySelector && (
           <Selector
             items={[
-              ...PLUGINS.map((item) => {
+              ...DISCOVERY.map((item) => {
                 return {
                   title: item.name,
                   value: item.path,
                 };
               }),
             ]}
-            onClose={() => setShowPluginSelector(false)}
+            onClose={() => setshowDiscoverySelector(false)}
             onSelection={(s) => {
               navigate(s[0], { state: { fromHome: true } });
             }}

+ 4 - 13
app/config/server.ts

@@ -1,5 +1,6 @@
 import md5 from "spark-md5";
 import { DEFAULT_MODELS, DEFAULT_GA_ID } from "../constant";
+import { isGPT4Model } from "../utils/model";
 
 declare global {
   namespace NodeJS {
@@ -130,22 +131,12 @@ export const getServerSideConfig = () => {
 
   if (disableGPT4) {
     if (customModels) customModels += ",";
-    customModels += DEFAULT_MODELS.filter(
-      (m) =>
-        (m.name.startsWith("gpt-4") ||
-          m.name.startsWith("chatgpt-4o") ||
-          m.name.startsWith("o1")) &&
-        !m.name.startsWith("gpt-4o-mini"),
-    )
+    customModels += DEFAULT_MODELS.filter((m) => isGPT4Model(m.name))
       .map((m) => "-" + m.name)
       .join(",");
-    if (
-      (defaultModel.startsWith("gpt-4") ||
-        defaultModel.startsWith("chatgpt-4o") ||
-        defaultModel.startsWith("o1")) &&
-      !defaultModel.startsWith("gpt-4o-mini")
-    )
+    if (defaultModel && isGPT4Model(defaultModel)) {
       defaultModel = "";
+    }
   }
 
   const isStability = !!process.env.STABILITY_API_KEY;
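
The inline prefix checks are folded into the isGPT4Model helper (defined in app/utils/model.ts later in this diff); a few example evaluations, based on that definition:

// Illustration only, based on the isGPT4Model definition shown below:
import { isGPT4Model } from "@/app/utils/model";

isGPT4Model("gpt-4-turbo");   // true  -> listed with a "-" prefix when DISABLE_GPT4 is set
isGPT4Model("chatgpt-4o");    // true
isGPT4Model("o1-mini");       // true  (matches the "o1" prefix)
isGPT4Model("gpt-4o-mini");   // false -> stays available
isGPT4Model("gpt-3.5-turbo"); // false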

+ 12 - 5
app/constant.ts

@@ -243,6 +243,8 @@ export const XAI = {
 export const ChatGLM = {
   ExampleEndpoint: CHATGLM_BASE_URL,
   ChatPath: "api/paas/v4/chat/completions",
+  ImagePath: "api/paas/v4/images/generations",
+  VideoPath: "api/paas/v4/videos/generations",
 };
 
 export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang
@@ -313,6 +315,7 @@ export const VISION_MODEL_REGEXES = [
   /qwen2-vl/,
   /gpt-4-turbo(?!.*preview)/, // Matches "gpt-4-turbo" but not "gpt-4-turbo-preview"
   /^dall-e-3$/, // Matches exactly "dall-e-3"
+  /glm-4v/,
 ];
 
 export const EXCLUDE_VISION_MODEL_REGEXES = [/claude-3-5-haiku-20241022/];
@@ -443,6 +446,15 @@ const chatglmModels = [
   "glm-4-long",
   "glm-4-flashx",
   "glm-4-flash",
+  "glm-4v-plus",
+  "glm-4v",
+  "glm-4v-flash", // free
+  "cogview-3-plus",
+  "cogview-3",
+  "cogview-3-flash", // free
+  // Not supported yet: the polling-based task flow cannot be handled
+  //   "cogvideox",
+  //   "cogvideox-flash", // free
 ];
 
 let seq = 1000; // built-in model sequence generator starts from 1000
@@ -609,11 +621,6 @@ export const internalAllowedWebDavEndpoints = [
 ];
 
 export const DEFAULT_GA_ID = "G-89WN60ZK2E";
-export const PLUGINS = [
-  { name: "Plugins", path: Path.Plugins },
-  { name: "Stable Diffusion", path: Path.Sd },
-  { name: "Search Chat", path: Path.SearchChat },
-];
 
 export const SAAS_CHAT_URL = "https://nextchat.dev/chat";
 export const SAAS_CHAT_UTM_URL = "https://nextchat.dev/chat?utm=github";

+ 2 - 2
app/locales/cn.ts

@@ -176,7 +176,7 @@ const cn = {
       },
     },
     Lang: {
-      Name: "Language", // ATTENTION: if you wanna add a new translation, please do not translate this value, leave it as `Language`
+      Name: "Language", // 注意:如果要添加新的翻译,请不要翻译此值,将它保留为 `Language`
       All: "所有语言",
     },
     Avatar: "头像",
@@ -630,7 +630,7 @@ const cn = {
     Sysmessage: "你是一个助手",
   },
   SearchChat: {
-    Name: "搜索",
+    Name: "搜索聊天记录",
     Page: {
       Title: "搜索聊天记录",
       Search: "输入搜索关键词",

+ 1 - 1
app/locales/tw.ts

@@ -485,7 +485,7 @@ const tw = {
     },
   },
   SearchChat: {
-    Name: "搜尋",
+    Name: "搜尋聊天記錄",
     Page: {
       Title: "搜尋聊天記錄",
       Search: "輸入搜尋關鍵詞",

+ 2 - 2
app/store/config.ts

@@ -1,5 +1,5 @@
 import { LLMModel } from "../client/api";
-import { DalleSize, DalleQuality, DalleStyle } from "../typing";
+import { DalleQuality, DalleStyle, ModelSize } from "../typing";
 import { getClientConfig } from "../config/client";
 import {
   DEFAULT_INPUT_TEMPLATE,
@@ -78,7 +78,7 @@ export const DEFAULT_CONFIG = {
     compressProviderName: "",
     enableInjectSystemPrompts: true,
     template: config?.template ?? DEFAULT_INPUT_TEMPLATE,
-    size: "1024x1024" as DalleSize,
+    size: "1024x1024" as ModelSize,
     quality: "standard" as DalleQuality,
     style: "vivid" as DalleStyle,
   },

+ 11 - 0
app/typing.ts

@@ -11,3 +11,14 @@ export interface RequestMessage {
 export type DalleSize = "1024x1024" | "1792x1024" | "1024x1792";
 export type DalleQuality = "standard" | "hd";
 export type DalleStyle = "vivid" | "natural";
+
+export type ModelSize =
+  | "1024x1024"
+  | "1792x1024"
+  | "1024x1792"
+  | "768x1344"
+  | "864x1152"
+  | "1344x768"
+  | "1152x864"
+  | "1440x720"
+  | "720x1440";

+ 23 - 0
app/utils.ts

@@ -7,6 +7,7 @@ import { ServiceProvider } from "./constant";
 import { fetch as tauriStreamFetch } from "./utils/stream";
 import { VISION_MODEL_REGEXES, EXCLUDE_VISION_MODEL_REGEXES } from "./constant";
 import { getClientConfig } from "./config/client";
+import { ModelSize } from "./typing";
 
 export function trimTopic(topic: string) {
   // Fix an issue where double quotes still show in the Indonesian language
@@ -271,6 +272,28 @@ export function isDalle3(model: string) {
   return "dall-e-3" === model;
 }
 
+export function getModelSizes(model: string): ModelSize[] {
+  if (isDalle3(model)) {
+    return ["1024x1024", "1792x1024", "1024x1792"];
+  }
+  if (model.toLowerCase().includes("cogview")) {
+    return [
+      "1024x1024",
+      "768x1344",
+      "864x1152",
+      "1344x768",
+      "1152x864",
+      "1440x720",
+      "720x1440",
+    ];
+  }
+  return [];
+}
+
+export function supportsCustomSize(model: string): boolean {
+  return getModelSizes(model).length > 0;
+}
+
 export function showPlugins(provider: ServiceProvider, model: string) {
   if (
     provider == ServiceProvider.OpenAI ||
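
getModelSizes returns the selectable sizes for a given image model and supportsCustomSize simply checks that the list is non-empty; chat.tsx (above) uses the pair to decide whether to show the size selector. A quick usage sketch, with expected results based on the definitions above:

// Illustration only:
import { getModelSizes, supportsCustomSize } from "@/app/utils";

getModelSizes("dall-e-3");        // ["1024x1024", "1792x1024", "1024x1792"]
getModelSizes("cogview-3-flash"); // the seven CogView sizes listed above
getModelSizes("gpt-4o");          // []  -> supportsCustomSize("gpt-4o") === false,
                                  //        so the size ChatAction stays hidden
supportsCustomSize("dall-e-3");   // true -> the size selector is rendered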

+ 49 - 0
app/utils/model.ts

@@ -202,3 +202,52 @@ export function isModelAvailableInServer(
   const modelTable = collectModelTable(DEFAULT_MODELS, customModels);
   return modelTable[fullName]?.available === false;
 }
+
+/**
+ * Check if the model name is a GPT-4 related model
+ *
+ * @param modelName The name of the model to check
+ * @returns True if the model is a GPT-4 related model (excluding gpt-4o-mini)
+ */
+export function isGPT4Model(modelName: string): boolean {
+  return (
+    (modelName.startsWith("gpt-4") ||
+      modelName.startsWith("chatgpt-4o") ||
+      modelName.startsWith("o1")) &&
+    !modelName.startsWith("gpt-4o-mini")
+  );
+}
+
+/**
+ * Checks if a model is not available on any of the specified providers in the server.
+ *
+ * @param {string} customModels - A string of custom models, comma-separated.
+ * @param {string} modelName - The name of the model to check.
+ * @param {string|string[]} providerNames - A string or array of provider names to check against.
+ *
+ * @returns {boolean} True if the model is not available on any of the specified providers, false otherwise.
+ */
+export function isModelNotavailableInServer(
+  customModels: string,
+  modelName: string,
+  providerNames: string | string[],
+): boolean {
+  // Check DISABLE_GPT4 environment variable
+  if (
+    process.env.DISABLE_GPT4 === "1" &&
+    isGPT4Model(modelName.toLowerCase())
+  ) {
+    return true;
+  }
+
+  const modelTable = collectModelTable(DEFAULT_MODELS, customModels);
+
+  const providerNamesArray = Array.isArray(providerNames)
+    ? providerNames
+    : [providerNames];
+  for (const providerName of providerNamesArray) {
+    const fullName = `${modelName}@${providerName.toLowerCase()}`;
+    if (modelTable?.[fullName]?.available === true) return false;
+  }
+  return true;
+}

+ 80 - 0
test/model-available.test.ts

@@ -0,0 +1,80 @@
+import { isModelNotavailableInServer } from "../app/utils/model";
+
+describe("isModelNotavailableInServer", () => {
+  test("test model will return false, which means the model is available", () => {
+    const customModels = "";
+    const modelName = "gpt-4";
+    const providerNames = "OpenAI";
+    const result = isModelNotavailableInServer(
+      customModels,
+      modelName,
+      providerNames,
+    );
+    expect(result).toBe(false);
+  });
+
+  test("test model will return true when model is not available in custom models", () => {
+    const customModels = "-all,gpt-4o-mini";
+    const modelName = "gpt-4";
+    const providerNames = "OpenAI";
+    const result = isModelNotavailableInServer(
+      customModels,
+      modelName,
+      providerNames,
+    );
+    expect(result).toBe(true);
+  });
+
+  test("should respect DISABLE_GPT4 setting", () => {
+    process.env.DISABLE_GPT4 = "1";
+    const result = isModelNotavailableInServer("", "gpt-4", "OpenAI");
+    expect(result).toBe(true);
+  });
+
+  test("should handle empty provider names", () => {
+    const result = isModelNotavailableInServer("-all,gpt-4", "gpt-4", "");
+    expect(result).toBe(true);
+  });
+
+  test("should be case insensitive for model names", () => {
+    const result = isModelNotavailableInServer("-all,GPT-4", "gpt-4", "OpenAI");
+    expect(result).toBe(true);
+  });
+
+  test("support passing multiple providers, model unavailable on one of the providers will return true", () => {
+    const customModels = "-all,gpt-4@google";
+    const modelName = "gpt-4";
+    const providerNames = ["OpenAI", "Azure"];
+    const result = isModelNotavailableInServer(
+      customModels,
+      modelName,
+      providerNames,
+    );
+    expect(result).toBe(true);
+  });
+
+  // FIXME: this test case is broken and needs to be fixed
+  //   test("support passing multiple providers, model available on one of the providers will return false", () => {
+  //     const customModels = "-all,gpt-4@google";
+  //     const modelName = "gpt-4";
+  //     const providerNames = ["OpenAI", "Google"];
+  //     const result = isModelNotavailableInServer(
+  //       customModels,
+  //       modelName,
+  //       providerNames,
+  //     );
+  //     expect(result).toBe(false);
+  //   });
+
+  test("test custom model without setting provider", () => {
+    const customModels = "-all,mistral-large";
+    const modelName = "mistral-large";
+    const providerNames = modelName;
+    const result = isModelNotavailableInServer(
+      customModels,
+      modelName,
+      providerNames,
+    );
+    expect(result).toBe(false);
+  });
+});