model.ts 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. import { DEFAULT_MODELS } from "../constant";
  2. import { LLMModel } from "../client/api";
  /**
   * Build the provider descriptor used for models that are introduced only
   * through the custom-model configuration. The model name is reused as the
   * provider id.
   */
  const customProvider = (modelName: string) => {
    const provider = {
      id: modelName,
      providerName: "Custom",
      providerType: "custom",
    };
    return provider;
  };
  /**
   * Build a lookup table of models keyed by `<modelName>@<providerId>`.
   *
   * The table is seeded from `models` and then adjusted by `customModels`, a
   * comma-separated list of directives applied from left to right:
   *
   * - `name` or `+name`: mark every entry called `name` as available
   * - `-name`: mark every entry called `name` as unavailable
   * - `name=displayName`: additionally override the display name
   * - `all`, `+all`, `-all`: set the availability of every entry present so far
   *
   * A directive naming a model that is not in the table creates a new entry
   * backed by a custom provider.
   *
   * @param models - Built-in models. A `name` may embed a display name as
   *   `name=displayName`, e.g. `completions_pro=ernie-4.0-8k`.
   * @param customModels - Comma-separated directives; empty segments are ignored.
   * @returns The model table after all directives have been applied.
   */
  export function collectModelTable(
    models: readonly LLMModel[],
    customModels: string,
  ) {
    const modelTable: Record<
      string,
      {
        available: boolean;
        name: string;
        displayName: string;
        provider?: LLMModel["provider"]; // Marked as optional
        isDefault?: boolean;
      }
    > = {};

    // default models
    for (const model of models) {
      // A model without a usable name cannot be keyed; skip it instead of
      // throwing while destructuring the result of `split`.
      if (typeof model?.name !== "string") {
        continue;
      }
      // support name=displayName eg:completions_pro=ernie-4.0-8k
      const [name, displayName] = model.name.split("=");
      // using <modelName>@<providerId> as fullName
      modelTable[`${name}@${model.provider?.id}`] = {
        ...model,
        name,
        displayName: displayName || name, // 'provider' is copied over if it exists
      };
    }

    // server custom models
    for (const directive of (customModels ?? "").split(",")) {
      if (!directive) {
        continue;
      }
      // Only a leading "-" disables; "+" and no prefix both enable.
      const available = !directive.startsWith("-");
      const nameConfig =
        directive.startsWith("+") || directive.startsWith("-")
          ? directive.slice(1)
          : directive;
      const [name, displayName] = nameConfig.split("=");

      // enable or disable all models
      if (name === "all") {
        for (const entry of Object.values(modelTable)) {
          entry.available = available;
        }
        continue;
      }

      // 1. update every entry whose bare name (the part before "@") matches
      let matched = false;
      for (const [fullName, entry] of Object.entries(modelTable)) {
        if (fullName.split("@")[0] !== name) {
          continue;
        }
        matched = true;
        entry.available = available;
        if (displayName) {
          entry.displayName = displayName;
        }
      }

      // 2. if the model does not exist yet, create it with a custom provider
      if (!matched) {
        const provider = customProvider(name);
        modelTable[`${name}@${provider.id}`] = {
          name,
          displayName: displayName || name,
          available,
          provider,
        };
      }
    }

    return modelTable;
  }
  /**
   * Build the model table and flag the configured default model.
   *
   * Table keys have the form `<modelName>@<providerId>`, so `defaultModel` is
   * matched against both the full key and the bare model name. The matching
   * entry (preferring one that is already available) is marked available and
   * default. When nothing matches, a new entry is inserted under
   * `defaultModel` so the default is always present in the result.
   *
   * @param models - Built-in models used to seed the table.
   * @param customModels - Comma-separated custom-model directives.
   * @param defaultModel - Full key or bare name of the default model; an empty
   *   value leaves the table untouched.
   * @returns The model table with at most one entry flagged `isDefault` here.
   */
  export function collectModelTableWithDefaultModel(
    models: readonly LLMModel[],
    customModels: string,
    defaultModel: string,
  ) {
    const modelTable = collectModelTable(models, customModels);
    if (!defaultModel) {
      return modelTable;
    }

    // Looking up `modelTable[defaultModel]` directly would miss entries when a
    // bare name is given, because keys always carry the `@<providerId>` suffix.
    const candidates = Object.entries(modelTable)
      .filter(
        ([fullName, model]) =>
          fullName === defaultModel || model.name === defaultModel,
      )
      .map(([, model]) => model);
    const target = candidates.find((model) => model.available) ?? candidates[0];

    if (target) {
      target.available = true;
      target.isDefault = true;
    } else {
      // Unknown model: keep it selectable, with a complete record shape.
      modelTable[defaultModel] = {
        name: defaultModel,
        displayName: defaultModel,
        available: true,
        isDefault: true,
      };
    }

    return modelTable;
  }
  /**
   * Generate the full model list.
   *
   * @param models - Built-in models used to seed the table.
   * @param customModels - Comma-separated custom-model directives.
   * @returns Every entry of the table produced by `collectModelTable`.
   */
  export function collectModels(
    models: readonly LLMModel[],
    customModels: string,
  ) {
    return Object.values(collectModelTable(models, customModels));
  }
  /**
   * Generate the full model list with the default model applied.
   *
   * @param models - Built-in models used to seed the table.
   * @param customModels - Comma-separated custom-model directives.
   * @param defaultModel - Name of the model to flag as default.
   * @returns Every entry of the table produced by
   *   `collectModelTableWithDefaultModel`.
   */
  export function collectModelsWithDefaultModel(
    models: readonly LLMModel[],
    customModels: string,
    defaultModel: string,
  ) {
    const table = collectModelTableWithDefaultModel(
      models,
      customModels,
      defaultModel,
    );
    return Object.values(table);
  }
  /**
   * Look up `<modelName>@<providerName>` in the table built from
   * `DEFAULT_MODELS` and `customModels`.
   *
   * NOTE(review): despite the name, this returns `true` only when the entry
   * exists and its `available` flag is exactly `false`; unknown models and
   * available models both yield `false`. Confirm what callers expect before
   * renaming or inverting.
   *
   * NOTE(review): `collectModelTable` keys entries by provider id, so
   * `providerName` presumably has to equal that id for the lookup to hit.
   * Verify against callers.
   *
   * @param customModels - Comma-separated custom-model directives.
   * @param modelName - Bare model name (the part before `@` in the key).
   * @param providerName - Provider part of the key (the part after `@`).
   */
  export function isModelAvailableInServer(
    customModels: string,
    modelName: string,
    providerName: string,
  ) {
    const table = collectModelTable(DEFAULT_MODELS, customModels);
    const entry = table[`${modelName}@${providerName}`];
    return entry?.available === false;
  }