Переглянути джерело

support azure deployment name

lloydzhou 1 рік тому
батько
коміт
1c20137b0e

+ 1 - 1
app/api/auth.ts

@@ -75,7 +75,7 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) {
         break;
       case ModelProvider.GPT:
       default:
-        if (serverConfig.isAzure) {
+        if (req.nextUrl.pathname.includes("azure/deployments")) {
           systemApiKey = serverConfig.azureApiKey;
         } else {
           systemApiKey = serverConfig.apiKey;

+ 10 - 9
app/api/common.ts

@@ -14,9 +14,11 @@ const serverConfig = getServerSideConfig();
 export async function requestOpenai(req: NextRequest) {
   const controller = new AbortController();
 
+  const isAzure = req.nextUrl.pathname.includes("azure/deployments");
+
   var authValue,
     authHeaderName = "";
-  if (serverConfig.isAzure) {
+  if (isAzure) {
     authValue =
       req.headers
         .get("Authorization")
@@ -56,14 +58,13 @@ export async function requestOpenai(req: NextRequest) {
     10 * 60 * 1000,
   );
 
-  if (serverConfig.isAzure) {
-    if (!serverConfig.azureApiVersion) {
-      return NextResponse.json({
-        error: true,
-        message: `missing AZURE_API_VERSION in server env vars`,
-      });
-    }
-    path = makeAzurePath(path, serverConfig.azureApiVersion);
+  if (isAzure) {
+    const azureApiVersion = req?.nextUrl?.searchParams?.get("api-version");
+    baseUrl = baseUrl.split("/deployments").shift();
+    path = `${req.nextUrl.pathname.replaceAll(
+      "/api/azure/",
+      "",
+    )}?api-version=${azureApiVersion}`;
   }
 
   const fetchUrl = `${baseUrl}/${path}`;

+ 13 - 6
app/client/api.ts

@@ -30,6 +30,7 @@ export interface RequestMessage {
 
 export interface LLMConfig {
   model: string;
+  providerName?: string;
   temperature?: number;
   top_p?: number;
   stream?: boolean;
@@ -54,6 +55,7 @@ export interface LLMUsage {
 
 export interface LLMModel {
   name: string;
+  displayName?: string;
   available: boolean;
   provider: LLMModelProvider;
 }
@@ -160,10 +162,14 @@ export function getHeaders() {
     Accept: "application/json",
   };
   const modelConfig = useChatStore.getState().currentSession().mask.modelConfig;
-  const isGoogle = modelConfig.model.startsWith("gemini");
-  const isAzure = accessStore.provider === ServiceProvider.Azure;
-  const isAnthropic = accessStore.provider === ServiceProvider.Anthropic;
-  const authHeader = isAzure ? "api-key" : isAnthropic ? 'x-api-key' : "Authorization";
+  const isGoogle = modelConfig.providerName == ServiceProvider.Azure;
+  const isAzure = modelConfig.providerName === ServiceProvider.Azure;
+  const isAnthropic = modelConfig.providerName === ServiceProvider.Anthropic;
+  const authHeader = isAzure
+    ? "api-key"
+    : isAnthropic
+    ? "x-api-key"
+    : "Authorization";
   const apiKey = isGoogle
     ? accessStore.googleApiKey
     : isAzure
@@ -172,7 +178,8 @@ export function getHeaders() {
     ? accessStore.anthropicApiKey
     : accessStore.openaiApiKey;
   const clientConfig = getClientConfig();
-  const makeBearer = (s: string) => `${isAzure || isAnthropic ? "" : "Bearer "}${s.trim()}`;
+  const makeBearer = (s: string) =>
+    `${isAzure || isAnthropic ? "" : "Bearer "}${s.trim()}`;
   const validString = (x: string) => x && x.length > 0;
 
   // when using google api in app, not set auth header
@@ -185,7 +192,7 @@ export function getHeaders() {
       validString(accessStore.accessCode)
     ) {
       // access_code must send with header named `Authorization`, will using in auth middleware.
-      headers['Authorization'] = makeBearer(
+      headers["Authorization"] = makeBearer(
         ACCESS_CODE_PREFIX + accessStore.accessCode,
       );
     }

+ 40 - 1
app/client/platforms/openai.ts

@@ -1,13 +1,16 @@
 "use client";
+// azure and openai, using same models. so using same LLMApi.
 import {
   ApiPath,
   DEFAULT_API_HOST,
   DEFAULT_MODELS,
   OpenaiPath,
+  Azure,
   REQUEST_TIMEOUT_MS,
   ServiceProvider,
 } from "@/app/constant";
 import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
+import { collectModelsWithDefaultModel } from "@/app/utils/model";
 
 import {
   ChatOptions,
@@ -97,6 +100,15 @@ export class ChatGPTApi implements LLMApi {
     return [baseUrl, path].join("/");
   }
 
+  getBaseUrl(apiPath: string) {
+    const isApp = !!getClientConfig()?.isApp;
+    let baseUrl = isApp ? DEFAULT_API_HOST + "/proxy" + apiPath : apiPath;
+    if (baseUrl.endsWith("/")) {
+      baseUrl = baseUrl.slice(0, baseUrl.length - 1);
+    }
+    return baseUrl + "/";
+  }
+
   extractMessage(res: any) {
     return res.choices?.at(0)?.message?.content ?? "";
   }
@@ -113,6 +125,7 @@ export class ChatGPTApi implements LLMApi {
       ...useChatStore.getState().currentSession().mask.modelConfig,
       ...{
         model: options.config.model,
+        providerName: options.config.providerName,
       },
     };
 
@@ -140,7 +153,33 @@ export class ChatGPTApi implements LLMApi {
     options.onController?.(controller);
 
     try {
-      const chatPath = this.path(OpenaiPath.ChatPath);
+      let chatPath = "";
+      if (modelConfig.providerName == ServiceProvider.Azure) {
+        // find model, and get displayName as deployName
+        const { models: configModels, customModels: configCustomModels } =
+          useAppConfig.getState();
+        const { defaultModel, customModels: accessCustomModels } =
+          useAccessStore.getState();
+
+        const models = collectModelsWithDefaultModel(
+          configModels,
+          [configCustomModels, accessCustomModels].join(","),
+          defaultModel,
+        );
+        const model = models.find(
+          (model) =>
+            model.name == modelConfig.model &&
+            model?.provider.providerName == ServiceProvider.Azure,
+        );
+        chatPath =
+          this.getBaseUrl(ApiPath.Azure) +
+          Azure.ChatPath(
+            model?.displayName ?? model.name,
+            useAccessStore.getState().azureApiVersion,
+          );
+      } else {
+        chatPath = this.getBaseUrl(ApiPath.OpenAI) + OpenaiPath.ChatPath;
+      }
       const chatPayload = {
         method: "POST",
         body: JSON.stringify(requestPayload),

+ 23 - 12
app/components/chat.tsx

@@ -88,6 +88,7 @@ import {
   Path,
   REQUEST_TIMEOUT_MS,
   UNFINISHED_INPUT,
+  ServiceProvider,
 } from "../constant";
 import { Avatar } from "./emoji";
 import { ContextPrompts, MaskAvatar, MaskConfig } from "./mask";
@@ -448,6 +449,9 @@ export function ChatActions(props: {
 
   // switch model
   const currentModel = chatStore.currentSession().mask.modelConfig.model;
+  const currentProviderName =
+    chatStore.currentSession().mask.modelConfig?.providerName ||
+    ServiceProvider.OpenAI;
   const allModels = useAllModels();
   const models = useMemo(() => {
     const filteredModels = allModels.filter((m) => m.available);
@@ -479,13 +483,13 @@ export function ChatActions(props: {
     const isUnavaliableModel = !models.some((m) => m.name === currentModel);
     if (isUnavaliableModel && models.length > 0) {
       // show next model to default model if exist
-      let nextModel: ModelType = (
-        models.find((model) => model.isDefault) || models[0]
-      ).name;
-      chatStore.updateCurrentSession(
-        (session) => (session.mask.modelConfig.model = nextModel),
-      );
-      showToast(nextModel);
+      let nextModel = models.find((model) => model.isDefault) || models[0];
+      chatStore.updateCurrentSession((session) => {
+        session.mask.modelConfig.model = nextModel.name;
+        session.mask.modelConfig.providerName = nextModel?.provider
+          ?.providerName as ServiceProvider;
+      });
+      showToast(nextModel.name);
     }
   }, [chatStore, currentModel, models]);
 
@@ -573,19 +577,26 @@ export function ChatActions(props: {
 
       {showModelSelector && (
         <Selector
-          defaultSelectedValue={currentModel}
+          defaultSelectedValue={`${currentModel}@${currentProviderName}`}
           items={models.map((m) => ({
-            title: m.displayName,
-            value: m.name,
+            title: `${m.displayName}${
+              m?.provider?.providerName
+                ? "(" + m?.provider?.providerName + ")"
+                : ""
+            }`,
+            value: `${m.name}@${m?.provider?.providerName}`,
           }))}
           onClose={() => setShowModelSelector(false)}
           onSelection={(s) => {
             if (s.length === 0) return;
+            const [model, providerName] = s[0].split("@");
             chatStore.updateCurrentSession((session) => {
-              session.mask.modelConfig.model = s[0] as ModelType;
+              session.mask.modelConfig.model = model as ModelType;
+              session.mask.modelConfig.providerName =
+                providerName as ServiceProvider;
               session.mask.syncGlobalConfig = false;
             });
-            showToast(s[0]);
+            showToast(model);
           }}
         />
       )}

+ 7 - 4
app/components/exporter.tsx

@@ -36,11 +36,14 @@ import { toBlob, toPng } from "html-to-image";
 import { DEFAULT_MASK_AVATAR } from "../store/mask";
 
 import { prettyObject } from "../utils/format";
-import { EXPORT_MESSAGE_CLASS_NAME, ModelProvider } from "../constant";
+import {
+  EXPORT_MESSAGE_CLASS_NAME,
+  ModelProvider,
+  ServiceProvider,
+} from "../constant";
 import { getClientConfig } from "../config/client";
 import { ClientApi } from "../client/api";
 import { getMessageTextContent } from "../utils";
-import { identifyDefaultClaudeModel } from "../utils/checkers";
 
 const Markdown = dynamic(async () => (await import("./markdown")).Markdown, {
   loading: () => <LoadingIcon />,
@@ -314,9 +317,9 @@ export function PreviewActions(props: {
     setShouldExport(false);
 
     var api: ClientApi;
-    if (config.modelConfig.model.startsWith("gemini")) {
+    if (config.modelConfig.providerName == ServiceProvider.Google) {
       api = new ClientApi(ModelProvider.GeminiPro);
-    } else if (identifyDefaultClaudeModel(config.modelConfig.model)) {
+    } else if (config.modelConfig.providerName == ServiceProvider.Anthropic) {
       api = new ClientApi(ModelProvider.Claude);
     } else {
       api = new ClientApi(ModelProvider.GPT);

+ 3 - 4
app/components/home.tsx

@@ -12,7 +12,7 @@ import LoadingIcon from "../icons/three-dots.svg";
 import { getCSSVar, useMobileScreen } from "../utils";
 
 import dynamic from "next/dynamic";
-import { ModelProvider, Path, SlotID } from "../constant";
+import { ServiceProvider, ModelProvider, Path, SlotID } from "../constant";
 import { ErrorBoundary } from "./error";
 
 import { getISOLang, getLang } from "../locales";
@@ -29,7 +29,6 @@ import { AuthPage } from "./auth";
 import { getClientConfig } from "../config/client";
 import { ClientApi } from "../client/api";
 import { useAccessStore } from "../store";
-import { identifyDefaultClaudeModel } from "../utils/checkers";
 
 export function Loading(props: { noLogo?: boolean }) {
   return (
@@ -172,9 +171,9 @@ export function useLoadData() {
   const config = useAppConfig();
 
   var api: ClientApi;
-  if (config.modelConfig.model.startsWith("gemini")) {
+  if (config.modelConfig.providerName == ServiceProvider.Google) {
     api = new ClientApi(ModelProvider.GeminiPro);
-  } else if (identifyDefaultClaudeModel(config.modelConfig.model)) {
+  } else if (config.modelConfig.providerName == ServiceProvider.Anthropic) {
     api = new ClientApi(ModelProvider.Claude);
   } else {
     api = new ClientApi(ModelProvider.GPT);

+ 10 - 9
app/components/model-config.tsx

@@ -1,3 +1,4 @@
+import { ServiceProvider } from "@/app/constant";
 import { ModalConfigValidator, ModelConfig } from "../store";
 
 import Locale from "../locales";
@@ -10,25 +11,25 @@ export function ModelConfigList(props: {
   updateConfig: (updater: (config: ModelConfig) => void) => void;
 }) {
   const allModels = useAllModels();
+  const value = `${props.modelConfig.model}@${props.modelConfig?.providerName}`;
 
   return (
     <>
       <ListItem title={Locale.Settings.Model}>
         <Select
-          value={props.modelConfig.model}
+          value={value}
           onChange={(e) => {
-            props.updateConfig(
-              (config) =>
-                (config.model = ModalConfigValidator.model(
-                  e.currentTarget.value,
-                )),
-            );
+            const [model, providerName] = e.currentTarget.value.split("@");
+            props.updateConfig((config) => {
+              config.model = ModalConfigValidator.model(model);
+              config.providerName = providerName as ServiceProvider;
+            });
           }}
         >
           {allModels
             .filter((v) => v.available)
             .map((v, i) => (
-              <option value={v.name} key={i}>
+              <option value={`${v.name}@${v.provider?.providerName}`} key={i}>
                 {v.displayName}({v.provider?.providerName})
               </option>
             ))}
@@ -92,7 +93,7 @@ export function ModelConfigList(props: {
         ></input>
       </ListItem>
 
-      {props.modelConfig.model.startsWith("gemini") ? null : (
+      {props.modelConfig?.providerName == ServiceProvider.Google ? null : (
         <>
           <ListItem
             title={Locale.Settings.PresencePenalty.Title}

+ 12 - 0
app/constant.ts

@@ -25,6 +25,7 @@ export enum Path {
 
 export enum ApiPath {
   Cors = "",
+  Azure = "/api/azure",
   OpenAI = "/api/openai",
   Anthropic = "/api/anthropic",
 }
@@ -93,6 +94,8 @@ export const OpenaiPath = {
 };
 
 export const Azure = {
+  ChatPath: (deployName: string, apiVersion: string) =>
+    `deployments/${deployName}/chat/completions?api-version=${apiVersion}`,
   ExampleEndpoint: "https://{resource-url}/openai/deployments/{deploy-id}",
 };
 
@@ -179,6 +182,15 @@ export const DEFAULT_MODELS = [
       providerType: "openai",
     },
   })),
+  ...openaiModels.map((name) => ({
+    name,
+    available: true,
+    provider: {
+      id: "azure",
+      providerName: "Azure",
+      providerType: "azure",
+    },
+  })),
   ...googleModels.map((name) => ({
     name,
     available: true,

+ 6 - 1
app/store/access.ts

@@ -17,6 +17,11 @@ const DEFAULT_OPENAI_URL =
     ? DEFAULT_API_HOST + "/api/proxy/openai"
     : ApiPath.OpenAI;
 
+const DEFAULT_AZURE_URL =
+  getClientConfig()?.buildMode === "export"
+    ? DEFAULT_API_HOST + "/api/proxy/azure/{resource_name}"
+    : ApiPath.Azure;
+
 const DEFAULT_ACCESS_STATE = {
   accessCode: "",
   useCustomConfig: false,
@@ -28,7 +33,7 @@ const DEFAULT_ACCESS_STATE = {
   openaiApiKey: "",
 
   // azure
-  azureUrl: "",
+  azureUrl: DEFAULT_AZURE_URL,
   azureApiKey: "",
   azureApiVersion: "2023-08-01-preview",
 

+ 5 - 5
app/store/chat.ts

@@ -9,6 +9,7 @@ import {
   DEFAULT_MODELS,
   DEFAULT_SYSTEM_TEMPLATE,
   KnowledgeCutOffDate,
+  ServiceProvider,
   ModelProvider,
   StoreKey,
   SUMMARIZE_MODEL,
@@ -20,7 +21,6 @@ import { prettyObject } from "../utils/format";
 import { estimateTokenLength } from "../utils/token";
 import { nanoid } from "nanoid";
 import { createPersistStore } from "../utils/store";
-import { identifyDefaultClaudeModel } from "../utils/checkers";
 import { collectModelsWithDefaultModel } from "../utils/model";
 import { useAccessStore } from "./access";
 
@@ -364,9 +364,9 @@ export const useChatStore = createPersistStore(
         });
 
         var api: ClientApi;
-        if (modelConfig.model.startsWith("gemini")) {
+        if (modelConfig.providerName == ServiceProvider.Google) {
           api = new ClientApi(ModelProvider.GeminiPro);
-        } else if (identifyDefaultClaudeModel(modelConfig.model)) {
+        } else if (modelConfig.providerName == ServiceProvider.Anthropic) {
           api = new ClientApi(ModelProvider.Claude);
         } else {
           api = new ClientApi(ModelProvider.GPT);
@@ -548,9 +548,9 @@ export const useChatStore = createPersistStore(
         const modelConfig = session.mask.modelConfig;
 
         var api: ClientApi;
-        if (modelConfig.model.startsWith("gemini")) {
+        if (modelConfig.providerName == ServiceProvider.Google) {
           api = new ClientApi(ModelProvider.GeminiPro);
-        } else if (identifyDefaultClaudeModel(modelConfig.model)) {
+        } else if (modelConfig.providerName == ServiceProvider.Anthropic) {
           api = new ClientApi(ModelProvider.Claude);
         } else {
           api = new ClientApi(ModelProvider.GPT);

+ 2 - 0
app/store/config.ts

@@ -5,6 +5,7 @@ import {
   DEFAULT_MODELS,
   DEFAULT_SIDEBAR_WIDTH,
   StoreKey,
+  ServiceProvider,
 } from "../constant";
 import { createPersistStore } from "../utils/store";
 
@@ -48,6 +49,7 @@ export const DEFAULT_CONFIG = {
 
   modelConfig: {
     model: "gpt-3.5-turbo" as ModelType,
+    providerName: "Openai" as ServiceProvider,
     temperature: 0.5,
     top_p: 1,
     max_tokens: 4000,

+ 0 - 21
app/utils/checkers.ts

@@ -1,21 +0,0 @@
-import { useAccessStore } from "../store/access";
-import { useAppConfig } from "../store/config";
-import { collectModels } from "./model";
-
-export function identifyDefaultClaudeModel(modelName: string) {
-  const accessStore = useAccessStore.getState();
-  const configStore = useAppConfig.getState();
-
-  const allModals = collectModels(
-    configStore.models,
-    [configStore.customModels, accessStore.customModels].join(","),
-  );
-
-  const modelMeta = allModals.find((m) => m.name === modelName);
-
-  return (
-    modelName.startsWith("claude") &&
-    modelMeta &&
-    modelMeta.provider?.providerType === "anthropic"
-  );
-}

+ 6 - 1
app/utils/hooks.ts

@@ -11,7 +11,12 @@ export function useAllModels() {
       [configStore.customModels, accessStore.customModels].join(","),
       accessStore.defaultModel,
     );
-  }, [accessStore.customModels, configStore.customModels, configStore.models]);
+  }, [
+    accessStore.customModels,
+    accessStore.defaultModel,
+    configStore.customModels,
+    configStore.models,
+  ]);
 
   return models;
 }

+ 5 - 0
next.config.mjs

@@ -69,6 +69,11 @@ if (mode !== "export") {
         source: "/api/proxy/v1/:path*",
         destination: "https://api.openai.com/v1/:path*",
       },
+      {
+        // https://{resource_name}.openai.azure.com/openai/deployments/{deploy_name}/chat/completions
+        source: "/api/proxy/azure/:resource_name/deployments/:deploy_name/:path*",
+        destination: "https://:resource_name.openai.azure.com/openai/deployments/:deploy_name/:path*",
+      },
       {
         source: "/api/proxy/google/:path*",
         destination: "https://generativelanguage.googleapis.com/:path*",