Browse Source

fix: fix using different model

Fred Liang 1 năm trước cách đây
mục cha
commit
5c638251f8
3 tập tin đã thay đổi với 38 bổ sung13 xóa
  1. 14 7
      app/api/auth.ts
  2. 3 3
      app/api/google/[...path]/route.ts
  3. 21 3
      app/api/openai/[...path]/route.ts

+ 14 - 7
app/api/auth.ts

@@ -1,7 +1,7 @@
 import { NextRequest } from "next/server";
 import { getServerSideConfig } from "../config/server";
 import md5 from "spark-md5";
-import { ACCESS_CODE_PREFIX } from "../constant";
+import { ACCESS_CODE_PREFIX, ModelProvider } from "../constant";
 
 function getIP(req: NextRequest) {
   let ip = req.ip ?? req.headers.get("x-real-ip");
@@ -24,7 +24,7 @@ function parseApiKey(bearToken: string) {
   };
 }
 
-export function auth(req: NextRequest) {
+export function auth(req: NextRequest, modelProvider: ModelProvider) {
   const authToken = req.headers.get("Authorization") ?? "";
 
   // check if it is openai api key or user token
@@ -56,12 +56,19 @@ export function auth(req: NextRequest) {
   // if user does not provide an api key, inject system api key
   if (!apiKey) {
     const serverConfig = getServerSideConfig();
-    const systemApiKey = serverConfig.isAzure
-      ? serverConfig.azureApiKey
-      : serverConfig.isGoogle
-      ? serverConfig.googleApiKey
-      : serverConfig.apiKey;
 
+    // const systemApiKey = serverConfig.isAzure
+    //   ? serverConfig.azureApiKey
+    //   : serverConfig.isGoogle
+    //   ? serverConfig.googleApiKey
+    //   : serverConfig.apiKey;
+
+    const systemApiKey =
+      modelProvider === ModelProvider.GeminiPro
+        ? serverConfig.googleApiKey
+        : serverConfig.isAzure
+        ? serverConfig.azureApiKey
+        : serverConfig.apiKey;
     if (systemApiKey) {
       console.log("[Auth] use system api key");
       req.headers.set("Authorization", `Bearer ${systemApiKey}`);

+ 3 - 3
app/api/google/[...path]/route.ts

@@ -1,7 +1,7 @@
 import { NextRequest, NextResponse } from "next/server";
 import { auth } from "../../auth";
 import { getServerSideConfig } from "@/app/config/server";
-import { GEMINI_BASE_URL, Google } from "@/app/constant";
+import { GEMINI_BASE_URL, Google, ModelProvider } from "@/app/constant";
 
 async function handle(
   req: NextRequest,
@@ -39,7 +39,7 @@ async function handle(
     10 * 60 * 1000,
   );
 
-  const authResult = auth(req);
+  const authResult = auth(req, ModelProvider.GeminiPro);
   if (authResult.error) {
     return NextResponse.json(authResult, {
       status: 401,
@@ -50,6 +50,7 @@ async function handle(
   const token = bearToken.trim().replaceAll("Bearer ", "").trim();
 
   const key = token ? token : serverConfig.googleApiKey;
+
   if (!key) {
     return NextResponse.json(
       {
@@ -63,7 +64,6 @@ async function handle(
   }
 
   const fetchUrl = `${baseUrl}/${path}?key=${key}`;
-
   const fetchOptions: RequestInit = {
     headers: {
       "Content-Type": "application/json",

+ 21 - 3
app/api/openai/[...path]/route.ts

@@ -1,6 +1,6 @@
 import { type OpenAIListModelResponse } from "@/app/client/platforms/openai";
 import { getServerSideConfig } from "@/app/config/server";
-import { OpenaiPath } from "@/app/constant";
+import { ModelProvider, OpenaiPath } from "@/app/constant";
 import { prettyObject } from "@/app/utils/format";
 import { NextRequest, NextResponse } from "next/server";
 import { auth } from "../../auth";
@@ -45,7 +45,7 @@ async function handle(
     );
   }
 
-  const authResult = auth(req);
+  const authResult = auth(req, ModelProvider.GPT);
   if (authResult.error) {
     return NextResponse.json(authResult, {
       status: 401,
@@ -75,4 +75,22 @@ export const GET = handle;
 export const POST = handle;
 
 export const runtime = "edge";
-export const preferredRegion = ['arn1', 'bom1', 'cdg1', 'cle1', 'cpt1', 'dub1', 'fra1', 'gru1', 'hnd1', 'iad1', 'icn1', 'kix1', 'lhr1', 'pdx1', 'sfo1', 'sin1', 'syd1'];
+export const preferredRegion = [
+  "arn1",
+  "bom1",
+  "cdg1",
+  "cle1",
+  "cpt1",
+  "dub1",
+  "fra1",
+  "gru1",
+  "hnd1",
+  "iad1",
+  "icn1",
+  "kix1",
+  "lhr1",
+  "pdx1",
+  "sfo1",
+  "sin1",
+  "syd1",
+];