model.ts 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. import { DEFAULT_MODELS } from "../constant";
  2. import { LLMModel } from "../client/api";
  3. const customProvider = (providerName: string) => ({
  4. id: providerName.toLowerCase(),
  5. providerName: providerName,
  6. providerType: "custom",
  7. });
  8. export function collectModelTable(
  9. models: readonly LLMModel[],
  10. customModels: string,
  11. ) {
  12. const modelTable: Record<
  13. string,
  14. {
  15. available: boolean;
  16. name: string;
  17. displayName: string;
  18. provider?: LLMModel["provider"]; // Marked as optional
  19. isDefault?: boolean;
  20. }
  21. > = {};
  22. // server custom models
  23. customModels
  24. .split(",")
  25. .filter((v) => !!v && v.length > 0)
  26. .forEach((m) => {
  27. const available = !m.startsWith("-");
  28. const nameConfig =
  29. m.startsWith("+") || m.startsWith("-") ? m.slice(1) : m;
  30. let [name, displayName] = nameConfig.split("=");
  31. // enable or disable all models
  32. if (name === "all") {
  33. Object.values(modelTable).forEach(
  34. (model) => (model.available = available),
  35. );
  36. } else {
  37. // 1. find model by name, and set available value
  38. const [customModelName, customProviderName] = name.split("@");
  39. let count = 0;
  40. for (const fullName in modelTable) {
  41. const [modelName, providerName] = fullName.split("@");
  42. if (
  43. customModelName == modelName &&
  44. (customProviderName === undefined ||
  45. customProviderName === providerName)
  46. ) {
  47. count += 1;
  48. modelTable[fullName]["available"] = available;
  49. // swap name and displayName for bytedance
  50. if (providerName === "bytedance") {
  51. [name, displayName] = [displayName, modelName];
  52. modelTable[fullName]["name"] = name;
  53. }
  54. if (displayName) {
  55. modelTable[fullName]["displayName"] = displayName;
  56. }
  57. }
  58. }
  59. // 2. if model not exists, create new model with available value
  60. if (count === 0) {
  61. let [customModelName, customProviderName] = name.split("@");
  62. const provider = customProvider(
  63. customProviderName || customModelName,
  64. );
  65. // swap name and displayName for bytedance
  66. if (displayName && provider.providerName == "ByteDance") {
  67. [customModelName, displayName] = [displayName, customModelName];
  68. }
  69. modelTable[`${customModelName}@${provider?.id}`] = {
  70. name: customModelName,
  71. displayName: displayName || customModelName,
  72. available,
  73. provider, // Use optional chaining
  74. };
  75. }
  76. }
  77. });
  78. // default models
  79. models.forEach((m) => {
  80. // using <modelName>@<providerId> as fullName
  81. modelTable[`${m.name}@${m?.provider?.id}`] = {
  82. ...m,
  83. displayName: m.name, // 'provider' is copied over if it exists
  84. };
  85. });
  86. return modelTable;
  87. }
  88. export function collectModelTableWithDefaultModel(
  89. models: readonly LLMModel[],
  90. customModels: string,
  91. defaultModel: string,
  92. ) {
  93. let modelTable = collectModelTable(models, customModels);
  94. if (defaultModel && defaultModel !== "") {
  95. if (defaultModel.includes("@")) {
  96. if (defaultModel in modelTable) {
  97. modelTable[defaultModel].isDefault = true;
  98. }
  99. } else {
  100. for (const key of Object.keys(modelTable)) {
  101. if (
  102. modelTable[key].available &&
  103. key.split("@").shift() == defaultModel
  104. ) {
  105. modelTable[key].isDefault = true;
  106. break;
  107. }
  108. }
  109. }
  110. }
  111. return modelTable;
  112. }
  113. /**
  114. * Generate full model table.
  115. */
  116. export function collectModels(
  117. models: readonly LLMModel[],
  118. customModels: string,
  119. ) {
  120. const modelTable = collectModelTable(models, customModels);
  121. const allModels = Object.values(modelTable);
  122. return allModels;
  123. }
  124. export function collectModelsWithDefaultModel(
  125. models: readonly LLMModel[],
  126. customModels: string,
  127. defaultModel: string,
  128. ) {
  129. const modelTable = collectModelTableWithDefaultModel(
  130. models,
  131. customModels,
  132. defaultModel,
  133. );
  134. const allModels = Object.values(modelTable);
  135. return allModels;
  136. }
  137. export function isModelAvailableInServer(
  138. customModels: string,
  139. modelName: string,
  140. providerName: string,
  141. ) {
  142. const fullName = `${modelName}@${providerName}`;
  143. const modelTable = collectModelTable(DEFAULT_MODELS, customModels);
  144. return modelTable[fullName]?.available === false;
  145. }