Explorar el Código

✨ feat(model): 增加 sorted 字段,并使用该字段对模型列表进行排序

1. 在 Model 和 Provider 类型中增加 sorted 字段(api.ts)
2. 默认模型在初始化的时候,自动设置默认 sorted 字段,从 1000 开始自增长(constant.ts)
3. 自定义模型更新的时候,自动分配 sorted 字段(model.ts)
frostime hace 1 año
padre
commit
150fc84b9b
Se han modificado 3 ficheros con 50 adiciones y 20 borrados
  1. 2 0
      app/client/api.ts
  2. 19 0
      app/constant.ts
  3. 29 20
      app/utils/model.ts

+ 2 - 0
app/client/api.ts

@@ -64,12 +64,14 @@ export interface LLMModel {
   displayName?: string;
   available: boolean;
   provider: LLMModelProvider;
+  sorted: number;
 }
 
 export interface LLMModelProvider {
   id: string;
   providerName: string;
   providerType: string;
+  sorted: number;
 }
 
 export abstract class LLMApi {

+ 19 - 0
app/constant.ts

@@ -320,86 +320,105 @@ const tencentModels = [
 
 const moonshotModes = ["moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k"];
 
+let seq = 1000; // 内置的模型序号生成器从1000开始
 export const DEFAULT_MODELS = [
   ...openaiModels.map((name) => ({
     name,
     available: true,
+    sorted: seq++, // Global sequence sort(index)
     provider: {
       id: "openai",
       providerName: "OpenAI",
       providerType: "openai",
+      sorted: 1, // 这里是固定的,确保顺序与之前内置的版本一致
     },
   })),
   ...openaiModels.map((name) => ({
     name,
     available: true,
+    sorted: seq++,
     provider: {
       id: "azure",
       providerName: "Azure",
       providerType: "azure",
+      sorted: 2,
     },
   })),
   ...googleModels.map((name) => ({
     name,
     available: true,
+    sorted: seq++,
     provider: {
       id: "google",
       providerName: "Google",
       providerType: "google",
+      sorted: 3,
     },
   })),
   ...anthropicModels.map((name) => ({
     name,
     available: true,
+    sorted: seq++,
     provider: {
       id: "anthropic",
       providerName: "Anthropic",
       providerType: "anthropic",
+      sorted: 4,
     },
   })),
   ...baiduModels.map((name) => ({
     name,
     available: true,
+    sorted: seq++,
     provider: {
       id: "baidu",
       providerName: "Baidu",
       providerType: "baidu",
+      sorted: 5,
     },
   })),
   ...bytedanceModels.map((name) => ({
     name,
     available: true,
+    sorted: seq++,
     provider: {
       id: "bytedance",
       providerName: "ByteDance",
       providerType: "bytedance",
+      sorted: 6,
     },
   })),
   ...alibabaModes.map((name) => ({
     name,
     available: true,
+    sorted: seq++,
     provider: {
       id: "alibaba",
       providerName: "Alibaba",
       providerType: "alibaba",
+      sorted: 7,
     },
   })),
   ...tencentModels.map((name) => ({
     name,
     available: true,
+    sorted: seq++,
     provider: {
       id: "tencent",
       providerName: "Tencent",
       providerType: "tencent",
+      sorted: 8,
     },
   })),
   ...moonshotModes.map((name) => ({
     name,
     available: true,
+    sorted: seq++,
     provider: {
       id: "moonshot",
       providerName: "Moonshot",
       providerType: "moonshot",
+      sorted: 9,
     },
   })),
 ] as const;

+ 29 - 20
app/utils/model.ts

@@ -1,32 +1,39 @@
 import { DEFAULT_MODELS } from "../constant";
 import { LLMModel } from "../client/api";
 
+const CustomSeq = {
+  val: -1000, //To ensure the custom model located at front, start from -1000, refer to constant.ts
+  cache: new Map<string, number>(),
+  next: (id: string) => {
+    if (CustomSeq.cache.has(id)) {
+      return CustomSeq.cache.get(id) as number;
+    } else {
+      let seq = CustomSeq.val++;
+      CustomSeq.cache.set(id, seq);
+      return seq;
+    }
+  },
+};
+
 const customProvider = (providerName: string) => ({
   id: providerName.toLowerCase(),
   providerName: providerName,
   providerType: "custom",
+  sorted: CustomSeq.next(providerName),
 });
 
-const sortModelTable = (
-  models: ReturnType<typeof collectModels>,
-  rule: "custom-first" | "default-first",
-) =>
+/**
+ * Sorts an array of models based on specified rules.
+ *
+ * First, sorted by provider; if the same, sorted by model
+ */
+const sortModelTable = (models: ReturnType<typeof collectModels>) =>
   models.sort((a, b) => {
-    if (a.provider === undefined && b.provider === undefined) {
-      return 0;
-    }
-
-    let aIsCustom = a.provider?.providerType === "custom";
-    let bIsCustom = b.provider?.providerType === "custom";
-
-    if (aIsCustom === bIsCustom) {
-      return 0;
-    }
-
-    if (aIsCustom) {
-      return rule === "custom-first" ? -1 : 1;
+    if (a.provider && b.provider) {
+      let cmp = a.provider.sorted - b.provider.sorted;
+      return cmp === 0 ? a.sorted - b.sorted : cmp;
     } else {
-      return rule === "custom-first" ? 1 : -1;
+      return a.sorted - b.sorted;
     }
   });
 
@@ -40,6 +47,7 @@ export function collectModelTable(
       available: boolean;
       name: string;
       displayName: string;
+      sorted: number;
       provider?: LLMModel["provider"]; // Marked as optional
       isDefault?: boolean;
     }
@@ -107,6 +115,7 @@ export function collectModelTable(
             displayName: displayName || customModelName,
             available,
             provider, // Use optional chaining
+            sorted: CustomSeq.next(`${customModelName}@${provider?.id}`),
           };
         }
       }
@@ -151,7 +160,7 @@ export function collectModels(
   const modelTable = collectModelTable(models, customModels);
   let allModels = Object.values(modelTable);
 
-  allModels = sortModelTable(allModels, "custom-first");
+  allModels = sortModelTable(allModels);
 
   return allModels;
 }
@@ -168,7 +177,7 @@ export function collectModelsWithDefaultModel(
   );
   let allModels = Object.values(modelTable);
 
-  allModels = sortModelTable(allModels, "custom-first");
+  allModels = sortModelTable(allModels);
 
   return allModels;
 }