Browse Source

review code

lloydzhou 1 year ago
parent
commit
f68cd2c5c0
3 changed files with 24 additions and 15 deletions
  1. 5 5
      app/client/platforms/baidu.ts
  2. 17 6
      app/constant.ts
  3. 2 4
      app/utils/model.ts

+ 5 - 5
app/client/platforms/baidu.ts

@@ -2,7 +2,7 @@
 import {
   ApiPath,
   Baidu,
-  DEFAULT_API_HOST,
+  BAIDU_BASE_URL,
   REQUEST_TIMEOUT_MS,
 } from "@/app/constant";
 import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
@@ -21,7 +21,7 @@ import {
 } from "@fortaine/fetch-event-source";
 import { prettyObject } from "@/app/utils/format";
 import { getClientConfig } from "@/app/config/client";
-import { getMessageTextContent, isVisionModel } from "@/app/utils";
+import { getMessageTextContent } from "@/app/utils";
 
 export interface OpenAIListModelResponse {
   object: string;
@@ -58,7 +58,8 @@ export class ErnieApi implements LLMApi {
 
     if (baseUrl.length === 0) {
       const isApp = !!getClientConfig()?.isApp;
-      baseUrl = isApp ? DEFAULT_API_HOST + "/api/proxy/baidu" : ApiPath.Baidu;
+      // do not use proxy for baidubce api
+      baseUrl = isApp ? BAIDU_BASE_URL : ApiPath.Baidu;
     }
 
     if (baseUrl.endsWith("/")) {
@@ -78,10 +79,9 @@ export class ErnieApi implements LLMApi {
   }
 
   async chat(options: ChatOptions) {
-    const visionModel = isVisionModel(options.config.model);
     const messages = options.messages.map((v) => ({
       role: v.role,
-      content: visionModel ? v.content : getMessageTextContent(v),
+      content: getMessageTextContent(v),
     }));
 
     const modelConfig = {

+ 17 - 6
app/constant.ts

@@ -112,9 +112,20 @@ export const Google = {
 };
 
 export const Baidu = {
-  ExampleEndpoint: "https://aip.baidubce.com",
-  ChatPath: (modelName: string) =>
-    `/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/${modelName}`,
+  ExampleEndpoint: BAIDU_BASE_URL,
+  ChatPath: (modelName: string) => {
+    let endpoint = modelName;
+    if (modelName === "ernie-4.0-8k") {
+      endpoint = "completions_pro";
+    }
+    if (modelName === "ernie-4.0-8k-preview-0518") {
+      endpoint = "completions_adv_pro";
+    }
+    if (modelName === "ernie-3.5-8k") {
+      endpoint = "completions";
+    }
+    return `/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/${endpoint}`;
+  },
 };
 
 export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang
@@ -188,11 +199,11 @@ const anthropicModels = [
 
 const baiduModels = [
   "ernie-4.0-turbo-8k",
-  "completions_pro=ernie-4.0-8k",
+  "ernie-4.0-8k",
   "ernie-4.0-8k-preview",
-  "completions_adv_pro=ernie-4.0-8k-preview-0518",
+  "ernie-4.0-8k-preview-0518",
   "ernie-4.0-8k-latest",
-  "completions=ernie-3.5-8k",
+  "ernie-3.5-8k",
   "ernie-3.5-8k-0205",
 ];
 

+ 2 - 4
app/utils/model.ts

@@ -24,13 +24,11 @@ export function collectModelTable(
 
   // default models
   models.forEach((m) => {
-    // supoort name=displayName eg:completions_pro=ernie-4.0-8k
-    const [name, displayName] = m.name?.split("=");
     // using <modelName>@<providerId> as fullName
-    modelTable[`${name}@${m?.provider?.id}`] = {
+    modelTable[`${m.name}@${m?.provider?.id}`] = {
       ...m,
       name,
-      displayName: displayName || name, // 'provider' is copied over if it exists
+      displayName: m.name, // 'provider' is copied over if it exists
     };
   });