Przeglądaj źródła

chroe: update model name

Fred Liang 1 rok temu
rodzic
commit
ae19a0dc5f

+ 4 - 4
app/client/api.ts

@@ -7,7 +7,7 @@ import {
 } from "../constant";
 import { ChatMessage, ModelType, useAccessStore, useChatStore } from "../store";
 import { ChatGPTApi } from "./platforms/openai";
-import { GeminiApi } from "./platforms/google";
+import { GeminiProApi } from "./platforms/google";
 export const ROLES = ["system", "user", "assistant"] as const;
 export type MessageRole = (typeof ROLES)[number];
 
@@ -86,8 +86,8 @@ export class ClientApi {
   public llm: LLMApi;
 
   constructor(provider: ModelProvider = ModelProvider.GPT) {
-    if (provider === ModelProvider.Gemini) {
-      this.llm = new GeminiApi();
+    if (provider === ModelProvider.GeminiPro) {
+      this.llm = new GeminiProApi();
       return;
     }
     this.llm = new ChatGPTApi();
@@ -146,7 +146,7 @@ export function getHeaders() {
     "x-requested-with": "XMLHttpRequest",
   };
   const modelConfig = useChatStore.getState().currentSession().mask.modelConfig;
-  const isGoogle = modelConfig.model === "gemini";
+  const isGoogle = modelConfig.model === "gemini-pro";
   const isAzure = accessStore.provider === ServiceProvider.Azure;
   const authHeader = isAzure ? "api-key" : "Authorization";
   const apiKey = isGoogle

+ 2 - 2
app/client/platforms/google.ts

@@ -9,9 +9,9 @@ import { prettyObject } from "@/app/utils/format";
 import { getClientConfig } from "@/app/config/client";
 import Locale from "../../locales";
 import { getServerSideConfig } from "@/app/config/server";
-export class GeminiApi implements LLMApi {
+export class GeminiProApi implements LLMApi {
   extractMessage(res: any) {
-    console.log("[Response] gemini response: ", res);
+    console.log("[Response] gemini-pro response: ", res);
     return (
       res?.candidates?.at(0)?.content?.parts.at(0)?.text ||
       res?.error?.message ||

+ 2 - 2
app/components/exporter.tsx

@@ -307,8 +307,8 @@ export function PreviewActions(props: {
     setShouldExport(false);
 
     var api: ClientApi;
-    if (config.modelConfig.model === "gemini") {
-      api = new ClientApi(ModelProvider.Gemini);
+    if (config.modelConfig.model === "gemini-pro") {
+      api = new ClientApi(ModelProvider.GeminiPro);
     } else {
       api = new ClientApi(ModelProvider.GPT);
     }

+ 2 - 2
app/components/home.tsx

@@ -171,8 +171,8 @@ export function useLoadData() {
   const config = useAppConfig();
 
   var api: ClientApi;
-  if (config.modelConfig.model === "gemini") {
-    api = new ClientApi(ModelProvider.Gemini);
+  if (config.modelConfig.model === "gemini-pro") {
+    api = new ClientApi(ModelProvider.GeminiPro);
   } else {
     api = new ClientApi(ModelProvider.GPT);
   }

+ 2 - 2
app/constant.ts

@@ -72,7 +72,7 @@ export enum ServiceProvider {
 
 export enum ModelProvider {
   GPT = "GPT",
-  Gemini = "Gemini",
+  GeminiPro = "GeminiPro",
 }
 
 export const OpenaiPath = {
@@ -240,7 +240,7 @@ export const DEFAULT_MODELS = [
     },
   },
   {
-    name: "gemini",
+    name: "gemini-pro",
     available: true,
     provider: {
       id: "google",

+ 1 - 1
app/locales/cn.ts

@@ -325,7 +325,7 @@ const cn = {
         },
 
         ApiVerion: {
-          Title: "接口版本 (gemini api version)",
+          Title: "接口版本 (gemini-pro api version)",
           SubTitle: "选择指定的部分版本",
         },
       },

+ 1 - 1
app/locales/en.ts

@@ -333,7 +333,7 @@ const en: LocaleType = {
         },
 
         ApiVerion: {
-          Title: "API Version (gemini api version)",
+          Title: "API Version (gemini-pro api version)",
           SubTitle: "Select a specific part version",
         },
       },

+ 5 - 5
app/store/chat.ts

@@ -303,8 +303,8 @@ export const useChatStore = createPersistStore(
         });
 
         var api: ClientApi;
-        if (modelConfig.model === "gemini") {
-          api = new ClientApi(ModelProvider.Gemini);
+        if (modelConfig.model === "gemini-pro") {
+          api = new ClientApi(ModelProvider.GeminiPro);
         } else {
           api = new ClientApi(ModelProvider.GPT);
         }
@@ -389,7 +389,7 @@ export const useChatStore = createPersistStore(
         const shouldInjectSystemPrompts = modelConfig.enableInjectSystemPrompts;
 
         var systemPrompts: ChatMessage[] = [];
-        if (modelConfig.model !== "gemini") {
+        if (modelConfig.model !== "gemini-pro") {
           systemPrompts = shouldInjectSystemPrompts
             ? [
                 createMessage({
@@ -488,8 +488,8 @@ export const useChatStore = createPersistStore(
         const modelConfig = session.mask.modelConfig;
 
         var api: ClientApi;
-        if (modelConfig.model === "gemini") {
-          api = new ClientApi(ModelProvider.Gemini);
+        if (modelConfig.model === "gemini-pro") {
+          api = new ClientApi(ModelProvider.GeminiPro);
         } else {
           api = new ClientApi(ModelProvider.GPT);
         }