Ver código fonte

feat: Support a way to define default model by adding DEFAULT_MODEL env.

Wayland Zhan 1 ano atrás
pai
commit
c96e4b7966

+ 1 - 0
app/api/config/route.ts

@@ -13,6 +13,7 @@ const DANGER_CONFIG = {
   hideBalanceQuery: serverConfig.hideBalanceQuery,
   disableFastLink: serverConfig.disableFastLink,
   customModels: serverConfig.customModels,
+  defaultModel: serverConfig.defaultModel,
 };
 
 declare global {

+ 22 - 7
app/components/chat.tsx

@@ -448,10 +448,20 @@ export function ChatActions(props: {
   // switch model
   const currentModel = chatStore.currentSession().mask.modelConfig.model;
   const allModels = useAllModels();
-  const models = useMemo(
-    () => allModels.filter((m) => m.available),
-    [allModels],
-  );
+  const models = useMemo(() => {
+    const filteredModels = allModels.filter((m) => m.available);
+    const defaultModel = filteredModels.find((m) => m.isDefault);
+
+    if (defaultModel) {
+      const arr = [
+        defaultModel,
+        ...filteredModels.filter((m) => m !== defaultModel),
+      ];
+      return arr;
+    } else {
+      return filteredModels;
+    }
+  }, [allModels]);
   const [showModelSelector, setShowModelSelector] = useState(false);
   const [showUploadImage, setShowUploadImage] = useState(false);
 
@@ -467,7 +477,10 @@ export function ChatActions(props: {
     // switch to first available model
     const isUnavaliableModel = !models.some((m) => m.name === currentModel);
     if (isUnavaliableModel && models.length > 0) {
-      const nextModel = models[0].name as ModelType;
+      // show next model to default model if exist
+      let nextModel: ModelType = (
+        models.find((model) => model.isDefault) || models[0]
+      ).name;
       chatStore.updateCurrentSession(
         (session) => (session.mask.modelConfig.model = nextModel),
       );
@@ -1102,11 +1115,13 @@ function _Chat() {
     };
     // eslint-disable-next-line react-hooks/exhaustive-deps
   }, []);
-  
+
   const handlePaste = useCallback(
     async (event: React.ClipboardEvent<HTMLTextAreaElement>) => {
       const currentModel = chatStore.currentSession().mask.modelConfig.model;
-      if(!isVisionModel(currentModel)){return;}
+      if (!isVisionModel(currentModel)) {
+        return;
+      }
       const items = (event.clipboardData || window.clipboardData).items;
       for (const item of items) {
         if (item.kind === "file" && item.type.startsWith("image/")) {

+ 4 - 0
app/config/server.ts

@@ -21,6 +21,7 @@ declare global {
       ENABLE_BALANCE_QUERY?: string; // allow user to query balance or not
       DISABLE_FAST_LINK?: string; // disallow parse settings from url or not
       CUSTOM_MODELS?: string; // to control custom models
+      DEFAULT_MODEL?: string; // to cnntrol default model in every new chat window
 
       // azure only
       AZURE_URL?: string; // https://{azure-url}/openai/deployments/{deploy-name}
@@ -59,12 +60,14 @@ export const getServerSideConfig = () => {
 
   const disableGPT4 = !!process.env.DISABLE_GPT4;
   let customModels = process.env.CUSTOM_MODELS ?? "";
+  let defaultModel = process.env.DEFAULT_MODEL ?? "";
 
   if (disableGPT4) {
     if (customModels) customModels += ",";
     customModels += DEFAULT_MODELS.filter((m) => m.name.startsWith("gpt-4"))
       .map((m) => "-" + m.name)
       .join(",");
+    if (defaultModel.startsWith("gpt-4")) defaultModel = "";
   }
 
   const isAzure = !!process.env.AZURE_URL;
@@ -116,6 +119,7 @@ export const getServerSideConfig = () => {
     hideBalanceQuery: !process.env.ENABLE_BALANCE_QUERY,
     disableFastLink: !!process.env.DISABLE_FAST_LINK,
     customModels,
+    defaultModel,
     whiteWebDevEndpoints,
   };
 };

+ 9 - 0
app/store/access.ts

@@ -8,6 +8,7 @@ import { getHeaders } from "../client/api";
 import { getClientConfig } from "../config/client";
 import { createPersistStore } from "../utils/store";
 import { ensure } from "../utils/clone";
+import { DEFAULT_CONFIG } from "./config";
 
 let fetchState = 0; // 0 not fetch, 1 fetching, 2 done
 
@@ -48,6 +49,7 @@ const DEFAULT_ACCESS_STATE = {
   disableGPT4: false,
   disableFastLink: false,
   customModels: "",
+  defaultModel: "",
 };
 
 export const useAccessStore = createPersistStore(
@@ -100,6 +102,13 @@ export const useAccessStore = createPersistStore(
         },
       })
         .then((res) => res.json())
+        .then((res) => {
+          // Set default model from env request
+          let defaultModel = res.defaultModel ?? "";
+          DEFAULT_CONFIG.modelConfig.model =
+            defaultModel !== "" ? defaultModel : "gpt-3.5-turbo";
+          return res;
+        })
         .then((res: DangerConfig) => {
           console.log("[Config] got config from server", res);
           set(() => ({ ...res }));

+ 3 - 2
app/utils/hooks.ts

@@ -1,14 +1,15 @@
 import { useMemo } from "react";
 import { useAccessStore, useAppConfig } from "../store";
-import { collectModels } from "./model";
+import { collectModels, collectModelsWithDefaultModel } from "./model";
 
 export function useAllModels() {
   const accessStore = useAccessStore();
   const configStore = useAppConfig();
   const models = useMemo(() => {
-    return collectModels(
+    return collectModelsWithDefaultModel(
       configStore.models,
       [configStore.customModels, accessStore.customModels].join(","),
+      accessStore.defaultModel,
     );
   }, [accessStore.customModels, configStore.customModels, configStore.models]);
 

+ 42 - 6
app/utils/model.ts

@@ -1,5 +1,11 @@
 import { LLMModel } from "../client/api";
 
+const customProvider = (modelName: string) => ({
+  id: modelName,
+  providerName: "",
+  providerType: "custom",
+});
+
 export function collectModelTable(
   models: readonly LLMModel[],
   customModels: string,
@@ -11,6 +17,7 @@ export function collectModelTable(
       name: string;
       displayName: string;
       provider?: LLMModel["provider"]; // Marked as optional
+      isDefault?: boolean;
     }
   > = {};
 
@@ -22,12 +29,6 @@ export function collectModelTable(
     };
   });
 
-  const customProvider = (modelName: string) => ({
-    id: modelName,
-    providerName: "",
-    providerType: "custom",
-  });
-
   // server custom models
   customModels
     .split(",")
@@ -52,6 +53,27 @@ export function collectModelTable(
         };
       }
     });
+
+  return modelTable;
+}
+
+export function collectModelTableWithDefaultModel(
+  models: readonly LLMModel[],
+  customModels: string,
+  defaultModel: string,
+) {
+  let modelTable = collectModelTable(models, customModels);
+  if (defaultModel && defaultModel !== "") {
+    delete modelTable[defaultModel];
+    modelTable[defaultModel] = {
+      name: defaultModel,
+      displayName: defaultModel,
+      available: true,
+      provider:
+        modelTable[defaultModel]?.provider ?? customProvider(defaultModel),
+      isDefault: true,
+    };
+  }
   return modelTable;
 }
 
@@ -67,3 +89,17 @@ export function collectModels(
 
   return allModels;
 }
+
+export function collectModelsWithDefaultModel(
+  models: readonly LLMModel[],
+  customModels: string,
+  defaultModel: string,
+) {
+  const modelTable = collectModelTableWithDefaultModel(
+    models,
+    customModels,
+    defaultModel,
+  );
+  const allModels = Object.values(modelTable);
+  return allModels;
+}