Browse Source

🔨 refactor(model): 更改原先的实现方法,在 collect table 函数后面增加额外的 sort 处理

frostime 1 year ago
parent
commit
b023a00445
1 changed files with 39 additions and 11 deletions
  1. 39 11
      app/utils/model.ts

+ 39 - 11
app/utils/model.ts

@@ -7,6 +7,29 @@ const customProvider = (providerName: string) => ({
   providerType: "custom",
 });
 
+const sortModelTable = (
+  models: ReturnType<typeof collectModels>,
+  rule: "custom-first" | "default-first",
+) =>
+  models.sort((a, b) => {
+    if (a.provider === undefined && b.provider === undefined) {
+      return 0;
+    }
+
+    let aIsCustom = a.provider?.providerType === "custom";
+    let bIsCustom = b.provider?.providerType === "custom";
+
+    if (aIsCustom === bIsCustom) {
+      return 0;
+    }
+
+    if (aIsCustom) {
+      return rule === "custom-first" ? -1 : 1;
+    } else {
+      return rule === "custom-first" ? 1 : -1;
+    }
+  });
+
 export function collectModelTable(
   models: readonly LLMModel[],
   customModels: string,
@@ -22,6 +45,15 @@ export function collectModelTable(
     }
   > = {};
 
+  // default models
+  models.forEach((m) => {
+    // using <modelName>@<providerId> as fullName
+    modelTable[`${m.name}@${m?.provider?.id}`] = {
+      ...m,
+      displayName: m.name, // 'provider' is copied over if it exists
+    };
+  });
+
   // server custom models
   customModels
     .split(",")
@@ -80,15 +112,6 @@ export function collectModelTable(
       }
     });
 
-  // default models
-  models.forEach((m) => {
-    // using <modelName>@<providerId> as fullName
-    modelTable[`${m.name}@${m?.provider?.id}`] = {
-      ...m,
-      displayName: m.name, // 'provider' is copied over if it exists
-    };
-  });
-
   return modelTable;
 }
 
@@ -126,7 +149,9 @@ export function collectModels(
   customModels: string,
 ) {
   const modelTable = collectModelTable(models, customModels);
-  const allModels = Object.values(modelTable);
+  let allModels = Object.values(modelTable);
+
+  allModels = sortModelTable(allModels, "custom-first");
 
   return allModels;
 }
@@ -141,7 +166,10 @@ export function collectModelsWithDefaultModel(
     customModels,
     defaultModel,
   );
-  const allModels = Object.values(modelTable);
+  let allModels = Object.values(modelTable);
+
+  allModels = sortModelTable(allModels, "custom-first");
+
   return allModels;
 }