
Merge pull request #4923 from ConnectAI-E/refactor-model-table

Refactor model table
Dogtiti committed 1 year ago
commit c4a6c933f8
4 changed files with 74 additions and 34 deletions
  1. app/api/anthropic/[...path]/route.ts  +9 -6
  2. app/api/common.ts  +28 -18
  3. app/store/config.ts  +2 -2
  4. app/utils/model.ts  +35 -8

app/api/anthropic/[...path]/route.ts  +9 -6

@@ -4,12 +4,13 @@ import {
   Anthropic,
   ApiPath,
   DEFAULT_MODELS,
+  ServiceProvider,
   ModelProvider,
 } from "@/app/constant";
 import { prettyObject } from "@/app/utils/format";
 import { NextRequest, NextResponse } from "next/server";
 import { auth } from "../../auth";
-import { collectModelTable } from "@/app/utils/model";
+import { isModelAvailableInServer } from "@/app/utils/model";
 
 const ALLOWD_PATH = new Set([Anthropic.ChatPath, Anthropic.ChatPath1]);
 
@@ -136,17 +137,19 @@ async function request(req: NextRequest) {
   // #1815 try to refuse some request to some models
   if (serverConfig.customModels && req.body) {
     try {
-      const modelTable = collectModelTable(
-        DEFAULT_MODELS,
-        serverConfig.customModels,
-      );
       const clonedBody = await req.text();
       fetchOptions.body = clonedBody;
 
       const jsonBody = JSON.parse(clonedBody) as { model?: string };
 
       // not undefined and is false
-      if (modelTable[jsonBody?.model ?? ""].available === false) {
+      if (
+        isModelAvailableInServer(
+          serverConfig.customModels,
+          jsonBody?.model as string,
+          ServiceProvider.Anthropic as string,
+        )
+      ) {
         return NextResponse.json(
           {
             error: true,
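
The route no longer builds a model table inline with collectModelTable; it delegates to the new isModelAvailableInServer helper and passes the provider explicitly, so availability is resolved against a "<model>@<provider>" key. Note that, despite its name, the helper returns true when the table marks that entry as available === false, which is why a truthy result triggers the error response. A minimal sketch of the call shape (the config and request-body values below are assumed, not taken from this PR):

    import { ServiceProvider } from "@/app/constant";
    import { isModelAvailableInServer } from "@/app/utils/model";

    // Assumed example values for serverConfig.customModels and the parsed request body;
    // a "-<model>" entry is taken to mean "disable this model" in the customModels string.
    const customModels = "-claude-3-opus-20240229";
    const jsonBody: { model?: string } = { model: "claude-3-opus-20240229" };

    const shouldRefuse = isModelAvailableInServer(
      customModels,
      jsonBody?.model as string,
      ServiceProvider.Anthropic as string,
    );
    // A truthy result means the "<model>@<provider>" entry is explicitly marked unavailable,
    // so the route answers with the JSON error instead of forwarding the request.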

app/api/common.ts  +28 -18

@@ -1,7 +1,12 @@
 import { NextRequest, NextResponse } from "next/server";
 import { getServerSideConfig } from "../config/server";
-import { DEFAULT_MODELS, OPENAI_BASE_URL, GEMINI_BASE_URL } from "../constant";
-import { collectModelTable } from "../utils/model";
+import {
+  DEFAULT_MODELS,
+  OPENAI_BASE_URL,
+  GEMINI_BASE_URL,
+  ServiceProvider,
+} from "../constant";
+import { isModelAvailableInServer } from "../utils/model";
 import { makeAzurePath } from "../azure";
 
 const serverConfig = getServerSideConfig();
@@ -83,17 +88,24 @@ export async function requestOpenai(req: NextRequest) {
   // #1815 try to refuse gpt4 request
   if (serverConfig.customModels && req.body) {
     try {
-      const modelTable = collectModelTable(
-        DEFAULT_MODELS,
-        serverConfig.customModels,
-      );
       const clonedBody = await req.text();
       fetchOptions.body = clonedBody;
 
       const jsonBody = JSON.parse(clonedBody) as { model?: string };
 
       // not undefined and is false
-      if (modelTable[jsonBody?.model ?? ""].available === false) {
+      if (
+        isModelAvailableInServer(
+          serverConfig.customModels,
+          jsonBody?.model as string,
+          ServiceProvider.OpenAI as string,
+        ) ||
+        isModelAvailableInServer(
+          serverConfig.customModels,
+          jsonBody?.model as string,
+          ServiceProvider.Azure as string,
+        )
+      ) {
         return NextResponse.json(
           {
             error: true,
@@ -112,16 +124,16 @@ export async function requestOpenai(req: NextRequest) {
   try {
     const res = await fetch(fetchUrl, fetchOptions);
 
-  // Extract the OpenAI-Organization header from the response
-  const openaiOrganizationHeader = res.headers.get("OpenAI-Organization");
+    // Extract the OpenAI-Organization header from the response
+    const openaiOrganizationHeader = res.headers.get("OpenAI-Organization");
 
-  // Check if serverConfig.openaiOrgId is defined and not an empty string
-  if (serverConfig.openaiOrgId && serverConfig.openaiOrgId.trim() !== "") {
-    // If openaiOrganizationHeader is present, log it; otherwise, log that the header is not present
-    console.log("[Org ID]", openaiOrganizationHeader);
-  } else {
-    console.log("[Org ID] is not set up.");
-  }
+    // Check if serverConfig.openaiOrgId is defined and not an empty string
+    if (serverConfig.openaiOrgId && serverConfig.openaiOrgId.trim() !== "") {
+      // If openaiOrganizationHeader is present, log it; otherwise, log that the header is not present
+      console.log("[Org ID]", openaiOrganizationHeader);
+    } else {
+      console.log("[Org ID] is not set up.");
+    }
 
     // to prevent browser prompt for credentials
     const newHeaders = new Headers(res.headers);
@@ -129,7 +141,6 @@ export async function requestOpenai(req: NextRequest) {
     // to disable nginx buffering
     newHeaders.set("X-Accel-Buffering", "no");
 
-
     // Conditionally delete the OpenAI-Organization header from the response if [Org ID] is undefined or empty (not setup in ENV)
     // Also, this is to prevent the header from being sent to the client
     if (!serverConfig.openaiOrgId || serverConfig.openaiOrgId.trim() === "") {
@@ -142,7 +153,6 @@ export async function requestOpenai(req: NextRequest) {
     // The browser will try to decode the response with brotli and fail
     newHeaders.delete("content-encoding");
 
-
     return new Response(res.body, {
       status: res.status,
       statusText: res.statusText,
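
Because the same proxy route serves both OpenAI and Azure requests, the rewritten check queries the table once per provider and refuses the request if either provider-qualified entry is marked unavailable; the remaining hunks in this file only reindent the OpenAI-Organization handling and drop blank lines. A condensed sketch of the new condition (variable values assumed):

    import { ServiceProvider } from "../constant";
    import { isModelAvailableInServer } from "../utils/model";

    const customModels = "-gpt-4"; // assumed server setting; "-<model>" disables a model
    const model = "gpt-4";         // assumed value taken from the parsed request body

    const shouldRefuse =
      isModelAvailableInServer(customModels, model, ServiceProvider.OpenAI as string) ||
      isModelAvailableInServer(customModels, model, ServiceProvider.Azure as string);
    // If either provider-qualified entry is explicitly disabled, requestOpenai returns the
    // JSON error instead of proxying the call upstream.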

app/store/config.ts  +2 -2

@@ -116,12 +116,12 @@ export const useAppConfig = createPersistStore(
 
       for (const model of oldModels) {
         model.available = false;
-        modelMap[model.name] = model;
+        modelMap[`${model.name}@${model?.provider?.id}`] = model;
       }
 
       for (const model of newModels) {
         model.available = true;
-        modelMap[model.name] = model;
+        modelMap[`${model.name}@${model?.provider?.id}`] = model;
       }
 
       set(() => ({
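
In the config store, the merge routine shown in this hunk now keys modelMap entries by `${name}@${provider.id}` instead of the bare model name, so two providers exposing the same model name no longer overwrite each other's entry. A small illustrative sketch (the model objects below are invented, not the real DEFAULT_MODELS):

    type MiniModel = { name: string; available: boolean; provider?: { id: string } };

    const modelMap: Record<string, MiniModel> = {};
    const fromOpenAI: MiniModel = { name: "gpt-4", available: false, provider: { id: "openai" } };
    const fromAzure: MiniModel = { name: "gpt-4", available: true, provider: { id: "azure" } };

    modelMap[`${fromOpenAI.name}@${fromOpenAI.provider?.id}`] = fromOpenAI; // key "gpt-4@openai"
    modelMap[`${fromAzure.name}@${fromAzure.provider?.id}`] = fromAzure; // key "gpt-4@azure"
    // With the old bare-name key, the second assignment would have replaced the first,
    // losing the per-provider availability flag.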

app/utils/model.ts  +35 -8

@@ -1,8 +1,9 @@
+import { DEFAULT_MODELS } from "../constant";
 import { LLMModel } from "../client/api";
 
 const customProvider = (modelName: string) => ({
   id: modelName,
-  providerName: "",
+  providerName: "Custom",
   providerType: "custom",
 });
 
@@ -23,7 +24,8 @@ export function collectModelTable(
 
   // default models
   models.forEach((m) => {
-    modelTable[m.name] = {
+    // using <modelName>@<providerId> as fullName
+    modelTable[`${m.name}@${m?.provider?.id}`] = {
       ...m,
       displayName: m.name, // 'provider' is copied over if it exists
     };
@@ -45,12 +47,27 @@ export function collectModelTable(
           (model) => (model.available = available),
         );
       } else {
-        modelTable[name] = {
-          name,
-          displayName: displayName || name,
-          available,
-          provider: modelTable[name]?.provider ?? customProvider(name), // Use optional chaining
-        };
+        // 1. find model by name(), and set available value
+        let count = 0;
+        for (const fullName in modelTable) {
+          if (fullName.split("@").shift() == name) {
+            count += 1;
+            modelTable[fullName]["available"] = available;
+            if (displayName) {
+              modelTable[fullName]["displayName"] = displayName;
+            }
+          }
+        }
+        // 2. if model not exists, create new model with available value
+        if (count === 0) {
+          const provider = customProvider(name);
+          modelTable[`${name}@${provider?.id}`] = {
+            name,
+            displayName: displayName || name,
+            available,
+            provider, // Use optional chaining
+          };
+        }
       }
     });
 
@@ -100,3 +117,13 @@ export function collectModelsWithDefaultModel(
   const allModels = Object.values(modelTable);
   return allModels;
 }
+
+export function isModelAvailableInServer(
+  customModels: string,
+  modelName: string,
+  providerName: string,
+) {
+  const fullName = `${modelName}@${providerName}`;
+  const modelTable = collectModelTable(DEFAULT_MODELS, customModels);
+  return modelTable[fullName]?.available === false;
+}
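
With the table keyed by "<name>@<provider.id>", the override loop above matches a customModels entry by its bare name against every key, so a single entry (for example "-gpt-4", assuming the usual "+"/"-" prefix syntax) flips availability for that model under every provider; only when no key matches is a new entry created under the "Custom" provider. A standalone sketch of that matching step (table contents invented for illustration):

    // Standalone sketch of the bare-name match against provider-qualified keys.
    const table: Record<string, { available: boolean; displayName: string }> = {
      "gpt-4@openai": { available: true, displayName: "gpt-4" },
      "gpt-4@azure": { available: true, displayName: "gpt-4" },
      "gpt-3.5-turbo@openai": { available: true, displayName: "gpt-3.5-turbo" },
    };

    const name = "gpt-4";
    const available = false; // e.g. derived from a "-gpt-4" entry in customModels
    let count = 0;
    for (const fullName in table) {
      if (fullName.split("@").shift() === name) {
        count += 1;
        table[fullName].available = available; // flips gpt-4@openai and gpt-4@azure
      }
    }
    // count === 2 here; had it stayed 0, collectModelTable would instead create a fresh
    // entry under the provider returned by customProvider(name).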