Pārlūkot izejas kodu

Merge pull request #5883 from code-october/fix/model-leak

fix model leak issue
Dogtiti 11 mēneši atpakaļ
vecāks
revīzija
d91af7f983

+ 2 - 2
app/api/alibaba.ts

@@ -8,7 +8,7 @@ import {
 import { prettyObject } from "@/app/utils/format";
 import { NextRequest, NextResponse } from "next/server";
 import { auth } from "@/app/api/auth";
-import { isModelAvailableInServer } from "@/app/utils/model";
+import { isModelNotavailableInServer } from "@/app/utils/model";
 
 const serverConfig = getServerSideConfig();
 
@@ -89,7 +89,7 @@ async function request(req: NextRequest) {
 
       // not undefined and is false
       if (
-        isModelAvailableInServer(
+        isModelNotavailableInServer(
           serverConfig.customModels,
           jsonBody?.model as string,
           ServiceProvider.Alibaba as string,

+ 2 - 2
app/api/anthropic.ts

@@ -9,7 +9,7 @@ import {
 import { prettyObject } from "@/app/utils/format";
 import { NextRequest, NextResponse } from "next/server";
 import { auth } from "./auth";
-import { isModelAvailableInServer } from "@/app/utils/model";
+import { isModelNotavailableInServer } from "@/app/utils/model";
 import { cloudflareAIGatewayUrl } from "@/app/utils/cloudflare";
 
 const ALLOWD_PATH = new Set([Anthropic.ChatPath, Anthropic.ChatPath1]);
@@ -122,7 +122,7 @@ async function request(req: NextRequest) {
 
       // not undefined and is false
       if (
-        isModelAvailableInServer(
+        isModelNotavailableInServer(
           serverConfig.customModels,
           jsonBody?.model as string,
           ServiceProvider.Anthropic as string,

+ 2 - 2
app/api/baidu.ts

@@ -8,7 +8,7 @@ import {
 import { prettyObject } from "@/app/utils/format";
 import { NextRequest, NextResponse } from "next/server";
 import { auth } from "@/app/api/auth";
-import { isModelAvailableInServer } from "@/app/utils/model";
+import { isModelNotavailableInServer } from "@/app/utils/model";
 import { getAccessToken } from "@/app/utils/baidu";
 
 const serverConfig = getServerSideConfig();
@@ -104,7 +104,7 @@ async function request(req: NextRequest) {
 
       // not undefined and is false
       if (
-        isModelAvailableInServer(
+        isModelNotavailableInServer(
           serverConfig.customModels,
           jsonBody?.model as string,
           ServiceProvider.Baidu as string,

+ 2 - 2
app/api/bytedance.ts

@@ -8,7 +8,7 @@ import {
 import { prettyObject } from "@/app/utils/format";
 import { NextRequest, NextResponse } from "next/server";
 import { auth } from "@/app/api/auth";
-import { isModelAvailableInServer } from "@/app/utils/model";
+import { isModelNotavailableInServer } from "@/app/utils/model";
 
 const serverConfig = getServerSideConfig();
 
@@ -88,7 +88,7 @@ async function request(req: NextRequest) {
 
       // not undefined and is false
       if (
-        isModelAvailableInServer(
+        isModelNotavailableInServer(
           serverConfig.customModels,
           jsonBody?.model as string,
           ServiceProvider.ByteDance as string,

+ 7 - 8
app/api/common.ts

@@ -2,7 +2,7 @@ import { NextRequest, NextResponse } from "next/server";
 import { getServerSideConfig } from "../config/server";
 import { OPENAI_BASE_URL, ServiceProvider } from "../constant";
 import { cloudflareAIGatewayUrl } from "../utils/cloudflare";
-import { getModelProvider, isModelAvailableInServer } from "../utils/model";
+import { getModelProvider, isModelNotavailableInServer } from "../utils/model";
 
 const serverConfig = getServerSideConfig();
 
@@ -118,15 +118,14 @@ export async function requestOpenai(req: NextRequest) {
 
       // not undefined and is false
       if (
-        isModelAvailableInServer(
+        isModelNotavailableInServer(
           serverConfig.customModels,
           jsonBody?.model as string,
-          ServiceProvider.OpenAI as string,
-        ) ||
-        isModelAvailableInServer(
-          serverConfig.customModels,
-          jsonBody?.model as string,
-          ServiceProvider.Azure as string,
+          [
+            ServiceProvider.OpenAI,
+            ServiceProvider.Azure,
+            jsonBody?.model as string,  // support provider-unspecified model
+          ],
         )
       ) {
         return NextResponse.json(

+ 2 - 2
app/api/glm.ts

@@ -8,7 +8,7 @@ import {
 import { prettyObject } from "@/app/utils/format";
 import { NextRequest, NextResponse } from "next/server";
 import { auth } from "@/app/api/auth";
-import { isModelAvailableInServer } from "@/app/utils/model";
+import { isModelNotavailableInServer } from "@/app/utils/model";
 
 const serverConfig = getServerSideConfig();
 
@@ -89,7 +89,7 @@ async function request(req: NextRequest) {
 
       // not undefined and is false
       if (
-        isModelAvailableInServer(
+        isModelNotavailableInServer(
           serverConfig.customModels,
           jsonBody?.model as string,
           ServiceProvider.ChatGLM as string,

+ 2 - 2
app/api/iflytek.ts

@@ -8,7 +8,7 @@ import {
 import { prettyObject } from "@/app/utils/format";
 import { NextRequest, NextResponse } from "next/server";
 import { auth } from "@/app/api/auth";
-import { isModelAvailableInServer } from "@/app/utils/model";
+import { isModelNotavailableInServer } from "@/app/utils/model";
 // iflytek
 
 const serverConfig = getServerSideConfig();
@@ -89,7 +89,7 @@ async function request(req: NextRequest) {
 
       // not undefined and is false
       if (
-        isModelAvailableInServer(
+        isModelNotavailableInServer(
           serverConfig.customModels,
           jsonBody?.model as string,
           ServiceProvider.Iflytek as string,

+ 2 - 2
app/api/moonshot.ts

@@ -8,7 +8,7 @@ import {
 import { prettyObject } from "@/app/utils/format";
 import { NextRequest, NextResponse } from "next/server";
 import { auth } from "@/app/api/auth";
-import { isModelAvailableInServer } from "@/app/utils/model";
+import { isModelNotavailableInServer } from "@/app/utils/model";
 
 const serverConfig = getServerSideConfig();
 
@@ -88,7 +88,7 @@ async function request(req: NextRequest) {
 
       // not undefined and is false
       if (
-        isModelAvailableInServer(
+        isModelNotavailableInServer(
           serverConfig.customModels,
           jsonBody?.model as string,
           ServiceProvider.Moonshot as string,

+ 2 - 2
app/api/xai.ts

@@ -8,7 +8,7 @@ import {
 import { prettyObject } from "@/app/utils/format";
 import { NextRequest, NextResponse } from "next/server";
 import { auth } from "@/app/api/auth";
-import { isModelAvailableInServer } from "@/app/utils/model";
+import { isModelNotavailableInServer } from "@/app/utils/model";
 
 const serverConfig = getServerSideConfig();
 
@@ -88,7 +88,7 @@ async function request(req: NextRequest) {
 
       // not undefined and is false
       if (
-        isModelAvailableInServer(
+        isModelNotavailableInServer(
           serverConfig.customModels,
           jsonBody?.model as string,
           ServiceProvider.XAI as string,

+ 24 - 0
app/utils/model.ts

@@ -202,3 +202,27 @@ export function isModelAvailableInServer(
   const modelTable = collectModelTable(DEFAULT_MODELS, customModels);
   return modelTable[fullName]?.available === false;
 }
+
+/**
+ * Checks if a model is not available on any of the specified providers in the server.
+ * 
+ * @param {string} customModels - A string of custom models, comma-separated.
+ * @param {string} modelName - The name of the model to check.
+ * @param {string|string[]} providerNames - A string or array of provider names to check against.
+ * 
+ * @returns {boolean} True if the model is not available on any of the specified providers, false otherwise.
+ */
+export function isModelNotavailableInServer(
+  customModels: string,
+  modelName: string,
+  providerNames: string | string[],
+) {
+  const modelTable = collectModelTable(DEFAULT_MODELS, customModels);
+  const providerNamesArray = Array.isArray(providerNames) ? providerNames : [providerNames];
+  for (const providerName of providerNamesArray){
+    const fullName = `${modelName}@${providerName.toLowerCase()}`;
+    if (modelTable[fullName]?.available === true)
+      return false;
+  }
+  return true;
+}

+ 59 - 0
test/model-available.test.ts

@@ -0,0 +1,59 @@
+import { isModelNotavailableInServer } from "../app/utils/model";
+
+describe("isModelNotavailableInServer", () => {
+  test("default model returns false, meaning the model is available", () => {
+    const customModels = "";
+    const modelName = "gpt-4";
+    const providerNames = "OpenAI";
+    const result = isModelNotavailableInServer(customModels, modelName, providerNames);
+    expect(result).toBe(false);
+  });
+
+  test("returns true when the model is excluded by custom models", () => {
+    const customModels = "-all,gpt-4o-mini";
+    const modelName = "gpt-4";
+    const providerNames = "OpenAI";
+    const result = isModelNotavailableInServer(customModels, modelName, providerNames);
+    expect(result).toBe(true);
+  });
+
+  test("should respect DISABLE_GPT4 setting", () => {
+    // NOTE(review): this assumes DISABLE_GPT4 is read after the module is
+    // imported; if DEFAULT_MODELS is computed at import time, this test may
+    // pass or fail depending on test-file ordering — confirm.
+    process.env.DISABLE_GPT4 = "1";
+    const result = isModelNotavailableInServer("", "gpt-4", "OpenAI");
+    expect(result).toBe(true);
+  });
+
+  test("should handle empty provider names", () => {
+    const result = isModelNotavailableInServer("-all,gpt-4", "gpt-4", "");
+    expect(result).toBe(true);
+  });
+
+  // Title corrected: the assertion expects `true` (unavailable), i.e.
+  // re-enabling "GPT-4" does NOT make "gpt-4" available — matching is
+  // case SENSITIVE. The original title claimed the opposite.
+  test("model name matching is case sensitive", () => {
+    const result = isModelNotavailableInServer("-all,GPT-4", "gpt-4", "OpenAI");
+    expect(result).toBe(true);
+  });
+
+  test("multiple providers: model unavailable on all listed providers returns true", () => {
+    const customModels = "-all,gpt-4@Google";
+    const modelName = "gpt-4";
+    const providerNames = ["OpenAI", "Azure"];
+    const result = isModelNotavailableInServer(customModels, modelName, providerNames);
+    expect(result).toBe(true);
+  });
+
+  test("multiple providers: model available on one listed provider returns false", () => {
+    const customModels = "-all,gpt-4@Google";
+    const modelName = "gpt-4";
+    const providerNames = ["OpenAI", "Google"];
+    const result = isModelNotavailableInServer(customModels, modelName, providerNames);
+    expect(result).toBe(false);
+  });
+
+  test("custom model without an explicit provider", () => {
+    // Provider-unspecified custom models appear to be keyed by the model
+    // name itself — hence passing the model name as the provider here.
+    const customModels = "-all,mistral-large";
+    const modelName = "mistral-large";
+    const providerNames = modelName;
+    const result = isModelNotavailableInServer(customModels, modelName, providerNames);
+    expect(result).toBe(false);
+  });
+});