Browse Source

Custom model names can include the `@` symbol by itself.

To specify the model's provider, append it after the model name using `@` as before.

This format supports cases like `google vertex ai` with a model name like `claude-3-5-sonnet@20240620`.

For instance, `claude-3-5-sonnet@20240620@vertex-ai` will be split by `split(/@(?!.*@)/)` into:

`[ 'claude-3-5-sonnet@20240620', 'vertex-ai' ]`, where the former is the model name and the latter is the custom provider.
ryanhex53 1 năm trước cách đây
mục cha
commit
b844045d23

+ 1 - 1
app/api/common.ts

@@ -71,7 +71,7 @@ export async function requestOpenai(req: NextRequest) {
         .filter((v) => !!v && !v.startsWith("-") && v.includes(modelName))
         .forEach((m) => {
           const [fullName, displayName] = m.split("=");
-          const [_, providerName] = fullName.split("@");
+          const [_, providerName] = fullName.split(/@(?!.*@)/);
           if (providerName === "azure" && !displayName) {
             const [_, deployId] = (serverConfig?.azureUrl ?? "").split(
               "deployments/",

+ 1 - 1
app/components/chat.tsx

@@ -645,7 +645,7 @@ export function ChatActions(props: {
           onClose={() => setShowModelSelector(false)}
           onSelection={(s) => {
             if (s.length === 0) return;
-            const [model, providerName] = s[0].split("@");
+            const [model, providerName] = s[0].split(/@(?!.*@)/);
             chatStore.updateCurrentSession((session) => {
               session.mask.modelConfig.model = model as ModelType;
               session.mask.modelConfig.providerName =

+ 4 - 2
app/components/model-config.tsx

@@ -28,7 +28,8 @@ export function ModelConfigList(props: {
           value={value}
           align="left"
           onChange={(e) => {
-            const [model, providerName] = e.currentTarget.value.split("@");
+            const [model, providerName] =
+              e.currentTarget.value.split(/@(?!.*@)/);
             props.updateConfig((config) => {
               config.model = ModalConfigValidator.model(model);
               config.providerName = providerName as ServiceProvider;
@@ -247,7 +248,8 @@ export function ModelConfigList(props: {
           aria-label={Locale.Settings.CompressModel.Title}
           value={compressModelValue}
           onChange={(e) => {
-            const [model, providerName] = e.currentTarget.value.split("@");
+            const [model, providerName] =
+              e.currentTarget.value.split(/@(?!.*@)/);
             props.updateConfig((config) => {
               config.compressModel = ModalConfigValidator.model(model);
               config.compressProviderName = providerName as ServiceProvider;

+ 1 - 1
app/store/access.ts

@@ -226,7 +226,7 @@ export const useAccessStore = createPersistStore(
         .then((res) => {
           const defaultModel = res.defaultModel ?? "";
           if (defaultModel !== "") {
-            const [model, providerName] = defaultModel.split("@");
+            const [model, providerName] = defaultModel.split(/@(?!.*@)/);
             DEFAULT_CONFIG.modelConfig.model = model;
             DEFAULT_CONFIG.modelConfig.providerName = providerName;
           }

+ 4 - 4
app/utils/model.ts

@@ -79,10 +79,10 @@ export function collectModelTable(
         );
       } else {
         // 1. find model by name, and set available value
-        const [customModelName, customProviderName] = name.split("@");
+        const [customModelName, customProviderName] = name.split(/@(?!.*@)/);
         let count = 0;
         for (const fullName in modelTable) {
-          const [modelName, providerName] = fullName.split("@");
+          const [modelName, providerName] = fullName.split(/@(?!.*@)/);
           if (
             customModelName == modelName &&
             (customProviderName === undefined ||
@@ -102,7 +102,7 @@ export function collectModelTable(
         }
         // 2. if model not exists, create new model with available value
         if (count === 0) {
-          let [customModelName, customProviderName] = name.split("@");
+          let [customModelName, customProviderName] = name.split(/@(?!.*@)/);
           const provider = customProvider(
             customProviderName || customModelName,
           );
@@ -139,7 +139,7 @@ export function collectModelTableWithDefaultModel(
       for (const key of Object.keys(modelTable)) {
         if (
           modelTable[key].available &&
-          key.split("@").shift() == defaultModel
+          key.split(/@(?!.*@)/).shift() == defaultModel
         ) {
           modelTable[key].isDefault = true;
           break;