Jelajahi Sumber

try getAccessToken in app, fixbug to fetch in none stream mode

lloydzhou 1 tahun lalu
induk
melakukan
fadd7f6eb4
3 mengubah file dengan 47 tambahan dan 33 penghapusan
  1. 17 24
      app/api/baidu/[...path]/route.ts
  2. 29 8
      app/client/platforms/baidu.ts
  3. 1 1
      app/constant.ts

+ 17 - 24
app/api/baidu/[...path]/route.ts

@@ -10,6 +10,7 @@ import { prettyObject } from "@/app/utils/format";
 import { NextRequest, NextResponse } from "next/server";
 import { auth } from "@/app/api/auth";
 import { isModelAvailableInServer } from "@/app/utils/model";
+import { getAccessToken } from "@/app/utils/baidu";
 
 const serverConfig = getServerSideConfig();
 
@@ -30,6 +31,18 @@ async function handle(
     });
   }
 
+  if (!serverConfig.baiduApiKey || !serverConfig.baiduSecretKey) {
+    return NextResponse.json(
+      {
+        error: true,
+        message: `missing BAIDU_API_KEY or BAIDU_SECRET_KEY in server env vars`,
+      },
+      {
+        status: 401,
+      },
+    );
+  }
+
   try {
     const response = await request(req);
     return response;
@@ -88,7 +101,10 @@ async function request(req: NextRequest) {
     10 * 60 * 1000,
   );
 
-  const { access_token } = await getAccessToken();
+  const { access_token } = await getAccessToken(
+    serverConfig.baiduApiKey,
+    serverConfig.baiduSecretKey,
+  );
   const fetchUrl = `${baseUrl}${path}?access_token=${access_token}`;
 
   const fetchOptions: RequestInit = {
@@ -133,11 +149,9 @@ async function request(req: NextRequest) {
       console.error(`[Baidu] filter`, e);
     }
   }
-  console.log("[Baidu request]", fetchOptions.headers, req.method);
   try {
     const res = await fetch(fetchUrl, fetchOptions);
 
-    console.log("[Baidu response]", res.status, "   ", res.headers, res.url);
     // to prevent browser prompt for credentials
     const newHeaders = new Headers(res.headers);
     newHeaders.delete("www-authenticate");
@@ -153,24 +167,3 @@ async function request(req: NextRequest) {
     clearTimeout(timeoutId);
   }
 }
-
-/**
- * 使用 AK,SK 生成鉴权签名(Access Token)
- * @return 鉴权签名信息
- */
-async function getAccessToken(): Promise<{
-  access_token: string;
-  expires_in: number;
-  error?: number;
-}> {
-  const AK = serverConfig.baiduApiKey;
-  const SK = serverConfig.baiduSecretKey;
-  const res = await fetch(
-    `${BAIDU_OATUH_URL}?grant_type=client_credentials&client_id=${AK}&client_secret=${SK}`,
-    {
-      method: "POST",
-    },
-  );
-  const resJson = await res.json();
-  return resJson;
-}

+ 29 - 8
app/client/platforms/baidu.ts

@@ -6,6 +6,7 @@ import {
   REQUEST_TIMEOUT_MS,
 } from "@/app/constant";
 import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
+import { getAccessToken } from "@/app/utils/baidu";
 
 import {
   ChatOptions,
@@ -74,16 +75,20 @@ export class ErnieApi implements LLMApi {
     return [baseUrl, path].join("/");
   }
 
-  extractMessage(res: any) {
-    return res.choices?.at(0)?.message?.content ?? "";
-  }
-
   async chat(options: ChatOptions) {
     const messages = options.messages.map((v) => ({
       role: v.role,
       content: getMessageTextContent(v),
     }));
 
+    // "error_code": 336006, "error_msg": "the length of messages must be an odd number",
+    if (messages.length % 2 === 0) {
+      messages.unshift({
+        role: "user",
+        content: " ",
+      });
+    }
+
     const modelConfig = {
       ...useAppConfig.getState().modelConfig,
       ...useChatStore.getState().currentSession().mask.modelConfig,
@@ -92,9 +97,10 @@ export class ErnieApi implements LLMApi {
       },
     };
 
+    const shouldStream = !!options.config.stream;
     const requestPayload: RequestPayload = {
       messages,
-      stream: options.config.stream,
+      stream: shouldStream,
       model: modelConfig.model,
       temperature: modelConfig.temperature,
       presence_penalty: modelConfig.presence_penalty,
@@ -104,12 +110,27 @@ export class ErnieApi implements LLMApi {
 
     console.log("[Request] Baidu payload: ", requestPayload);
 
-    const shouldStream = !!options.config.stream;
     const controller = new AbortController();
     options.onController?.(controller);
 
     try {
-      const chatPath = this.path(Baidu.ChatPath(modelConfig.model));
+      let chatPath = this.path(Baidu.ChatPath(modelConfig.model));
+
+      // getAccessToken can not run in browser, because cors error
+      if (!!getClientConfig()?.isApp) {
+        const accessStore = useAccessStore.getState();
+        if (accessStore.useCustomConfig) {
+          if (accessStore.isValidBaidu()) {
+            const { access_token } = await getAccessToken(
+              accessStore.baiduApiKey,
+              accessStore.baiduSecretKey,
+            );
+            chatPath = `${chatPath}${
+              chatPath.includes("?") ? "&" : "?"
+            }access_token=${access_token}`;
+          }
+        }
+      }
       const chatPayload = {
         method: "POST",
         body: JSON.stringify(requestPayload),
@@ -230,7 +251,7 @@ export class ErnieApi implements LLMApi {
         clearTimeout(requestTimeoutId);
 
         const resJson = await res.json();
-        const message = this.extractMessage(resJson);
+        const message = resJson?.result;
         options.onFinish(message);
       }
     } catch (e) {

+ 1 - 1
app/constant.ts

@@ -124,7 +124,7 @@ export const Baidu = {
     if (modelName === "ernie-3.5-8k") {
       endpoint = "completions";
     }
-    return `/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/${endpoint}`;
+    return `rpc/2.0/ai_custom/v1/wenxinworkshop/chat/${endpoint}`;
   },
 };