model.ts 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. import { DEFAULT_MODELS } from "../constant";
  2. import { LLMModel } from "../client/api";
  /**
   * Describe a provider that is not built in, i.e. one that is only named in
   * the custom-models config. Its id is the lower-cased provider name, which is
   * the form used after "@" in model-table keys.
   */
  function customProvider(providerName: string) {
    const id = providerName.toLowerCase();
    const providerType = "custom";
    return { id, providerName, providerType };
  }
  /**
   * Split a `<model>@<provider>` string at its LAST "@", so a model name that
   * itself contains "@" stays intact. Returns `[input, undefined]` when there
   * is no "@" at all.
   */
  function splitModelAndProvider(
    fullName: string,
  ): [string, string | undefined] {
    const at = fullName.lastIndexOf("@");
    return at === -1
      ? [fullName, undefined]
      : [fullName.slice(0, at), fullName.slice(at + 1)];
  }

  /**
   * Build the model lookup table from the built-in models plus the server's
   * custom-models config. Keys are `<modelName>@<providerId>`.
   *
   * `customModels` is a comma-separated list (whitespace around each entry is
   * ignored, empty entries are skipped). Each entry has the form
   * `[+|-]<name>[@<provider>][=<displayName>]`:
   * - a leading "-" disables, a leading "+" or no prefix enables;
   * - `all` toggles every model already in the table at that point;
   * - otherwise every existing model with that name (restricted to the given
   *   provider, if any) is toggled, and relabelled when a display name is given;
   * - if nothing matches, a new custom model is added whose provider is the
   *   given provider, or the model name itself when no provider is given.
   * For the "bytedance" provider, the "=" value becomes the model's `name`
   * and the original model name becomes its `displayName`.
   *
   * @param models       built-in models; each entry is shallow-copied.
   * @param customModels the custom-models config string described above.
   * @returns a fresh table keyed by `<modelName>@<providerId>`.
   */
  export function collectModelTable(
    models: readonly LLMModel[],
    customModels: string,
  ) {
    const modelTable: Record<
      string,
      {
        available: boolean;
        name: string;
        displayName: string;
        provider?: LLMModel["provider"]; // optional: not every entry carries one
        isDefault?: boolean;
      }
    > = {};

    // built-in models
    models.forEach((m) => {
      modelTable[`${m.name}@${m.provider?.id}`] = {
        ...m,
        displayName: m.name, // 'provider' is copied over by the spread if present
      };
    });

    // server custom models
    customModels
      .split(",")
      .map((entry) => entry.trim())
      .filter((entry) => entry.length > 0)
      .forEach((entry) => {
        const available = !entry.startsWith("-");
        const nameConfig =
          entry.startsWith("+") || entry.startsWith("-")
            ? entry.slice(1)
            : entry;
        // Split on the first "=" only, so a display name may itself contain "=".
        const eq = nameConfig.indexOf("=");
        const name = eq === -1 ? nameConfig : nameConfig.slice(0, eq);
        const displayName = eq === -1 ? undefined : nameConfig.slice(eq + 1);

        // enable or disable all models
        if (name === "all") {
          Object.values(modelTable).forEach((model) => {
            model.available = available;
          });
          return;
        }

        const [customModelName, customProviderName] =
          splitModelAndProvider(name);

        // 1. update every existing model with this name (and provider, if given)
        let count = 0;
        for (const [fullName, model] of Object.entries(modelTable)) {
          const [modelName, providerName] = splitModelAndProvider(fullName);
          if (
            customModelName !== modelName ||
            (customProviderName !== undefined &&
              customProviderName !== providerName)
          ) {
            continue;
          }
          count += 1;
          model.available = available;
          // Only this entry is updated; the parsed name/displayName stay as the
          // user wrote them, so later matches see the same values and an entry
          // without "=" never blanks out a model's name.
          if (!displayName) continue;
          if (providerName === "bytedance") {
            // swap name and displayName for bytedance
            model.name = displayName;
            model.displayName = modelName;
          } else {
            model.displayName = displayName;
          }
        }

        // 2. if no model matched, create a new custom model
        if (count === 0) {
          const provider = customProvider(
            customProviderName || customModelName,
          );
          let newName = customModelName;
          let newDisplayName = displayName;
          // swap name and displayName for bytedance; compare the lower-cased id
          // (as step 1 does) so "ByteDance" and "bytedance" behave the same
          if (newDisplayName && provider.id === "bytedance") {
            [newName, newDisplayName] = [newDisplayName, newName];
          }
          modelTable[`${newName}@${provider.id}`] = {
            name: newName,
            displayName: newDisplayName || newName,
            available,
            provider,
          };
        }
      });

    return modelTable;
  }
  /**
   * Build the model table via `collectModelTable` and flag the configured
   * default model with `isDefault: true`.
   *
   * @param models       built-in models, forwarded to `collectModelTable`.
   * @param customModels custom-models config, forwarded to `collectModelTable`.
   * @param defaultModel either `<model>@<provider>` or a bare `<model>`; an
   *   empty string means "no default".
   *   - With a provider, that exact table key is marked default and forced to
   *     `available: true`; if the key is absent, a new entry (without a
   *     `provider` object) is created and displayed as `<model>(<provider>)`.
   *   - Without a provider, the first *available* entry (in table order) whose
   *     model name equals `defaultModel` exactly is marked default.
   * @returns the table, keyed by `<modelName>@<providerId>`.
   */
  export function collectModelTableWithDefaultModel(
    models: readonly LLMModel[],
    customModels: string,
    defaultModel: string,
  ) {
    const modelTable = collectModelTable(models, customModels);
    if (!defaultModel) return modelTable;

    // Split at the LAST "@": the model name itself may contain "@".
    const at = defaultModel.lastIndexOf("@");
    const modelName = at === -1 ? defaultModel : defaultModel.slice(0, at);
    const providerName = at === -1 ? "" : defaultModel.slice(at + 1);

    if (providerName) {
      const existing = modelTable[defaultModel];
      modelTable[defaultModel] = {
        ...existing,
        name: existing?.name ?? modelName,
        displayName:
          existing?.displayName ?? modelName + "(" + providerName + ")",
        available: true,
        isDefault: true,
      };
    } else {
      for (const [key, entry] of Object.entries(modelTable)) {
        // Compare the exact model part of the key; a prefix test would let
        // "gpt-4" select "gpt-4o@...".
        const keyAt = key.lastIndexOf("@");
        const keyModel = keyAt === -1 ? key : key.slice(0, keyAt);
        if (entry.available && keyModel === modelName) {
          entry.isDefault = true;
          break;
        }
      }
    }
    return modelTable;
  }
  /**
   * Generate the full model list: every entry of the table produced by
   * `collectModelTable`, in table order.
   */
  export function collectModels(
    models: readonly LLMModel[],
    customModels: string,
  ) {
    return Object.values(collectModelTable(models, customModels));
  }
  /**
   * Same as `collectModels`, except the table is built by
   * `collectModelTableWithDefaultModel`, so the configured default model
   * carries `isDefault: true`.
   */
  export function collectModelsWithDefaultModel(
    models: readonly LLMModel[],
    customModels: string,
    defaultModel: string,
  ) {
    return Object.values(
      collectModelTableWithDefaultModel(models, customModels, defaultModel),
    );
  }
  /**
   * Check the server's custom-models config for an explicit disable.
   *
   * NOTE: despite its name, this returns `true` only when the model's entry is
   * explicitly disabled (`available === false`), and `false` when the entry is
   * enabled or absent from the table. The name is kept for existing callers.
   *
   * @param customModels the server custom-models config string.
   * @param modelName    model name to look up.
   * @param providerName provider id or name; matched exactly first, then
   *   lower-cased (custom provider ids are lower-cased by `customProvider`).
   */
  export function isModelAvailableInServer(
    customModels: string,
    modelName: string,
    providerName: string,
  ) {
    const modelTable = collectModelTable(DEFAULT_MODELS, customModels);
    const fullName = `${modelName}@${providerName}`;
    // Table keys end in a provider *id*; fall back to the lower-cased form so a
    // display-cased provider name still finds its entry.
    const entry =
      modelTable[fullName] ??
      modelTable[`${modelName}@${providerName.toLowerCase()}`];
    return entry?.available === false;
  }