model.ts 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. import { DEFAULT_MODELS, ServiceProvider } from "../constant";
  2. import { LLMModel } from "../client/api";
  3. const CustomSeq = {
  4. val: -1000, //To ensure the custom model located at front, start from -1000, refer to constant.ts
  5. cache: new Map<string, number>(),
  6. next: (id: string) => {
  7. if (CustomSeq.cache.has(id)) {
  8. return CustomSeq.cache.get(id) as number;
  9. } else {
  10. let seq = CustomSeq.val++;
  11. CustomSeq.cache.set(id, seq);
  12. return seq;
  13. }
  14. },
  15. };
  16. const customProvider = (providerName: string) => ({
  17. id: providerName.toLowerCase(),
  18. providerName: providerName,
  19. providerType: "custom",
  20. sorted: CustomSeq.next(providerName),
  21. });
  22. /**
  23. * Sorts an array of models based on specified rules.
  24. *
  25. * First, sorted by provider; if the same, sorted by model
  26. */
  27. const sortModelTable = (models: ReturnType<typeof collectModels>) =>
  28. models.sort((a, b) => {
  29. if (a.provider && b.provider) {
  30. let cmp = a.provider.sorted - b.provider.sorted;
  31. return cmp === 0 ? a.sorted - b.sorted : cmp;
  32. } else {
  33. return a.sorted - b.sorted;
  34. }
  35. });
  36. /**
  37. * get model name and provider from a formatted string,
  38. * e.g. `gpt-4@OpenAi` or `claude-3-5-sonnet@20240620@Google`
  39. * @param modelWithProvider model name with provider separated by last `@` char,
  40. * @returns [model, provider] tuple, if no `@` char found, provider is undefined
  41. */
  42. export function getModelProvider(modelWithProvider: string): [string, string?] {
  43. const [model, provider] = modelWithProvider.split(/@(?!.*@)/);
  44. return [model, provider];
  45. }
  46. export function collectModelTable(
  47. models: readonly LLMModel[],
  48. customModels: string,
  49. ) {
  50. const modelTable: Record<
  51. string,
  52. {
  53. available: boolean;
  54. name: string;
  55. displayName: string;
  56. sorted: number;
  57. provider?: LLMModel["provider"]; // Marked as optional
  58. isDefault?: boolean;
  59. }
  60. > = {};
  61. // default models
  62. models.forEach((m) => {
  63. // using <modelName>@<providerId> as fullName
  64. modelTable[`${m.name}@${m?.provider?.id}`] = {
  65. ...m,
  66. displayName: m.name, // 'provider' is copied over if it exists
  67. };
  68. });
  69. // server custom models
  70. customModels
  71. .split(",")
  72. .filter((v) => !!v && v.length > 0)
  73. .forEach((m) => {
  74. const available = !m.startsWith("-");
  75. const nameConfig =
  76. m.startsWith("+") || m.startsWith("-") ? m.slice(1) : m;
  77. let [name, displayName] = nameConfig.split("=");
  78. // enable or disable all models
  79. if (name === "all") {
  80. Object.values(modelTable).forEach(
  81. (model) => (model.available = available),
  82. );
  83. } else {
  84. // 1. find model by name, and set available value
  85. const [customModelName, customProviderName] = getModelProvider(name);
  86. let count = 0;
  87. for (const fullName in modelTable) {
  88. const [modelName, providerName] = getModelProvider(fullName);
  89. if (
  90. customModelName == modelName &&
  91. (customProviderName === undefined ||
  92. customProviderName === providerName)
  93. ) {
  94. count += 1;
  95. modelTable[fullName]["available"] = available;
  96. // swap name and displayName for bytedance
  97. if (providerName === "bytedance") {
  98. [name, displayName] = [displayName, modelName];
  99. modelTable[fullName]["name"] = name;
  100. }
  101. if (displayName) {
  102. modelTable[fullName]["displayName"] = displayName;
  103. }
  104. }
  105. }
  106. // 2. if model not exists, create new model with available value
  107. if (count === 0) {
  108. let [customModelName, customProviderName] = getModelProvider(name);
  109. const provider = customProvider(
  110. customProviderName || customModelName,
  111. );
  112. // swap name and displayName for bytedance
  113. if (displayName && provider.providerName == "ByteDance") {
  114. [customModelName, displayName] = [displayName, customModelName];
  115. }
  116. modelTable[`${customModelName}@${provider?.id}`] = {
  117. name: customModelName,
  118. displayName: displayName || customModelName,
  119. available,
  120. provider, // Use optional chaining
  121. sorted: CustomSeq.next(`${customModelName}@${provider?.id}`),
  122. };
  123. }
  124. }
  125. });
  126. return modelTable;
  127. }
  128. export function collectModelTableWithDefaultModel(
  129. models: readonly LLMModel[],
  130. customModels: string,
  131. defaultModel: string,
  132. ) {
  133. let modelTable = collectModelTable(models, customModels);
  134. if (defaultModel && defaultModel !== "") {
  135. if (defaultModel.includes("@")) {
  136. if (defaultModel in modelTable) {
  137. modelTable[defaultModel].isDefault = true;
  138. }
  139. } else {
  140. for (const key of Object.keys(modelTable)) {
  141. if (
  142. modelTable[key].available &&
  143. getModelProvider(key)[0] == defaultModel
  144. ) {
  145. modelTable[key].isDefault = true;
  146. break;
  147. }
  148. }
  149. }
  150. }
  151. return modelTable;
  152. }
  153. /**
  154. * Generate full model table.
  155. */
  156. export function collectModels(
  157. models: readonly LLMModel[],
  158. customModels: string,
  159. ) {
  160. const modelTable = collectModelTable(models, customModels);
  161. let allModels = Object.values(modelTable);
  162. allModels = sortModelTable(allModels);
  163. return allModels;
  164. }
  165. export function collectModelsWithDefaultModel(
  166. models: readonly LLMModel[],
  167. customModels: string,
  168. defaultModel: string,
  169. ) {
  170. const modelTable = collectModelTableWithDefaultModel(
  171. models,
  172. customModels,
  173. defaultModel,
  174. );
  175. let allModels = Object.values(modelTable);
  176. allModels = sortModelTable(allModels);
  177. return allModels;
  178. }
  179. export function isModelAvailableInServer(
  180. customModels: string,
  181. modelName: string,
  182. providerName: string,
  183. ) {
  184. const fullName = `${modelName}@${providerName}`;
  185. const modelTable = collectModelTable(DEFAULT_MODELS, customModels);
  186. return modelTable[fullName]?.available === false;
  187. }
  188. /**
  189. * Check if the model name is a GPT-4 related model
  190. *
  191. * @param modelName The name of the model to check
  192. * @returns True if the model is a GPT-4 related model (excluding gpt-4o-mini)
  193. */
  194. export function isGPT4Model(modelName: string): boolean {
  195. return (
  196. (modelName.startsWith("gpt-4") ||
  197. modelName.startsWith("chatgpt-4o") ||
  198. modelName.startsWith("o1")) &&
  199. !modelName.startsWith("gpt-4o-mini")
  200. );
  201. }
  202. /**
  203. * Checks if a model is not available on any of the specified providers in the server.
  204. *
  205. * @param {string} customModels - A string of custom models, comma-separated.
  206. * @param {string} modelName - The name of the model to check.
  207. * @param {string|string[]} providerNames - A string or array of provider names to check against.
  208. *
  209. * @returns {boolean} True if the model is not available on any of the specified providers, false otherwise.
  210. */
  211. export function isModelNotavailableInServer(
  212. customModels: string,
  213. modelName: string,
  214. providerNames: string | string[],
  215. ): boolean {
  216. // Check DISABLE_GPT4 environment variable
  217. if (
  218. process.env.DISABLE_GPT4 === "1" &&
  219. isGPT4Model(modelName.toLowerCase())
  220. ) {
  221. return true;
  222. }
  223. const modelTable = collectModelTable(DEFAULT_MODELS, customModels);
  224. const providerNamesArray = Array.isArray(providerNames)
  225. ? providerNames
  226. : [providerNames];
  227. for (const providerName of providerNamesArray) {
  228. // if model provider is bytedance, use model config name to check if not avaliable
  229. if (providerName === ServiceProvider.ByteDance) {
  230. return !Object.values(modelTable).filter((v) => v.name === modelName)?.[0]
  231. ?.available;
  232. }
  233. const fullName = `${modelName}@${providerName.toLowerCase()}`;
  234. if (modelTable?.[fullName]?.available === true) return false;
  235. }
  236. return true;
  237. }