Przeglądaj źródła

Merge branch 'main' of https://github.com/ConnectAI-E/ChatGPT-Next-Web into feature/ByteDance

Dogtiti 1 rok temu
rodzic
commit
2ec8b7a804
4 zmienionych plików z 71 dodań i 82 usunięć
  1. 56 32
      app/client/api.ts
  2. 3 16
      app/components/exporter.tsx
  3. 4 12
      app/components/home.tsx
  4. 8 22
      app/store/chat.ts

+ 56 - 32
app/client/api.ts

@@ -162,46 +162,70 @@ export class ClientApi {
 
 export function getHeaders() {
   const accessStore = useAccessStore.getState();
+  const chatStore = useChatStore.getState();
   const headers: Record<string, string> = {
     "Content-Type": "application/json",
     Accept: "application/json",
   };
-  const modelConfig = useChatStore.getState().currentSession().mask.modelConfig;
-  const isGoogle = modelConfig.providerName == ServiceProvider.Google;
-  const isAzure = modelConfig.providerName === ServiceProvider.Azure;
-  const isAnthropic = modelConfig.providerName === ServiceProvider.Anthropic;
-  const authHeader = isAzure
-    ? "api-key"
-    : isAnthropic
-    ? "x-api-key"
-    : "Authorization";
-  const apiKey = isGoogle
-    ? accessStore.googleApiKey
-    : isAzure
-    ? accessStore.azureApiKey
-    : isAnthropic
-    ? accessStore.anthropicApiKey
-    : accessStore.openaiApiKey;
+
   const clientConfig = getClientConfig();
-  const makeBearer = (s: string) =>
-    `${isAzure || isAnthropic ? "" : "Bearer "}${s.trim()}`;
-  const validString = (x: string) => x && x.length > 0;
 
+  function getConfig() {
+    const modelConfig = chatStore.currentSession().mask.modelConfig;
+    const isGoogle = modelConfig.providerName == ServiceProvider.Google;
+    const isAzure = modelConfig.providerName === ServiceProvider.Azure;
+    const isAnthropic = modelConfig.providerName === ServiceProvider.Anthropic;
+    const isEnabledAccessControl = accessStore.enabledAccessControl();
+    const apiKey = isGoogle
+      ? accessStore.googleApiKey
+      : isAzure
+      ? accessStore.azureApiKey
+      : isAnthropic
+      ? accessStore.anthropicApiKey
+      : accessStore.openaiApiKey;
+    return { isGoogle, isAzure, isAnthropic, apiKey, isEnabledAccessControl };
+  }
+
+  function getAuthHeader(): string {
+    return isAzure ? "api-key" : isAnthropic ? "x-api-key" : "Authorization";
+  }
+
+  function getBearerToken(apiKey: string, noBearer: boolean = false): string {
+    return validString(apiKey)
+      ? `${noBearer ? "" : "Bearer "}${apiKey.trim()}`
+      : "";
+  }
+
+  function validString(x: string): boolean {
+    return x?.length > 0;
+  }
+  const { isGoogle, isAzure, isAnthropic, apiKey, isEnabledAccessControl } =
+    getConfig();
   // when using google api in app, not set auth header
-  if (!(isGoogle && clientConfig?.isApp)) {
-    // use user's api key first
-    if (validString(apiKey)) {
-      headers[authHeader] = makeBearer(apiKey);
-    } else if (
-      accessStore.enabledAccessControl() &&
-      validString(accessStore.accessCode)
-    ) {
-      // access_code must send with header named `Authorization`, will using in auth middleware.
-      headers["Authorization"] = makeBearer(
-        ACCESS_CODE_PREFIX + accessStore.accessCode,
-      );
-    }
+  if (isGoogle && clientConfig?.isApp) return headers;
+
+  const authHeader = getAuthHeader();
+
+  const bearerToken = getBearerToken(apiKey, isAzure || isAnthropic);
+
+  if (bearerToken) {
+    headers[authHeader] = bearerToken;
+  } else if (isEnabledAccessControl && validString(accessStore.accessCode)) {
+    headers["Authorization"] = getBearerToken(
+      ACCESS_CODE_PREFIX + accessStore.accessCode,
+    );
   }
 
   return headers;
 }
+
+export function getClientApi(provider: ServiceProvider): ClientApi {
+  switch (provider) {
+    case ServiceProvider.Google:
+      return new ClientApi(ModelProvider.GeminiPro);
+    case ServiceProvider.Anthropic:
+      return new ClientApi(ModelProvider.Claude);
+    default:
+      return new ClientApi(ModelProvider.GPT);
+  }
+}

+ 3 - 16
app/components/exporter.tsx

@@ -36,13 +36,9 @@ import { toBlob, toPng } from "html-to-image";
 import { DEFAULT_MASK_AVATAR } from "../store/mask";
 
 import { prettyObject } from "../utils/format";
-import {
-  EXPORT_MESSAGE_CLASS_NAME,
-  ModelProvider,
-  ServiceProvider,
-} from "../constant";
+import { EXPORT_MESSAGE_CLASS_NAME } from "../constant";
 import { getClientConfig } from "../config/client";
-import { ClientApi } from "../client/api";
+import { type ClientApi, getClientApi } from "../client/api";
 import { getMessageTextContent } from "../utils";
 
 const Markdown = dynamic(async () => (await import("./markdown")).Markdown, {
@@ -316,16 +312,7 @@ export function PreviewActions(props: {
   const onRenderMsgs = (msgs: ChatMessage[]) => {
     setShouldExport(false);
 
-    var api: ClientApi;
-    if (config.modelConfig.providerName == ServiceProvider.Google) {
-      api = new ClientApi(ModelProvider.GeminiPro);
-    } else if (config.modelConfig.providerName == ServiceProvider.Anthropic) {
-      api = new ClientApi(ModelProvider.Claude);
-    } else if (config.modelConfig.providerName == ServiceProvider.ByteDance) {
-      api = new ClientApi(ModelProvider.Doubao);
-    } else {
-      api = new ClientApi(ModelProvider.GPT);
-    }
+    const api: ClientApi = getClientApi(config.modelConfig.providerName);
 
     api
       .share(msgs)

+ 4 - 12
app/components/home.tsx

@@ -12,7 +12,7 @@ import LoadingIcon from "../icons/three-dots.svg";
 import { getCSSVar, useMobileScreen } from "../utils";
 
 import dynamic from "next/dynamic";
-import { ServiceProvider, ModelProvider, Path, SlotID } from "../constant";
+import { Path, SlotID } from "../constant";
 import { ErrorBoundary } from "./error";
 
 import { getISOLang, getLang } from "../locales";
@@ -27,7 +27,7 @@ import { SideBar } from "./sidebar";
 import { useAppConfig } from "../store/config";
 import { AuthPage } from "./auth";
 import { getClientConfig } from "../config/client";
-import { ClientApi } from "../client/api";
+import { type ClientApi, getClientApi } from "../client/api";
 import { useAccessStore } from "../store";
 
 export function Loading(props: { noLogo?: boolean }) {
@@ -170,16 +170,8 @@ function Screen() {
 export function useLoadData() {
   const config = useAppConfig();
 
-  var api: ClientApi;
-  if (config.modelConfig.providerName == ServiceProvider.Google) {
-    api = new ClientApi(ModelProvider.GeminiPro);
-  } else if (config.modelConfig.providerName == ServiceProvider.Anthropic) {
-    api = new ClientApi(ModelProvider.Claude);
-  } else if (config.modelConfig.providerName == ServiceProvider.ByteDance) {
-    api = new ClientApi(ModelProvider.Doubao);
-  } else {
-    api = new ClientApi(ModelProvider.GPT);
-  }
+  const api: ClientApi = getClientApi(config.modelConfig.providerName);
+
   useEffect(() => {
     (async () => {
       const models = await api.llm.models();

+ 8 - 22
app/store/chat.ts

@@ -15,7 +15,12 @@ import {
   SUMMARIZE_MODEL,
   GEMINI_SUMMARIZE_MODEL,
 } from "../constant";
-import { ClientApi, RequestMessage, MultimodalContent } from "../client/api";
+import { getClientApi } from "../client/api";
+import type {
+  ClientApi,
+  RequestMessage,
+  MultimodalContent,
+} from "../client/api";
 import { ChatControllerPool } from "../client/controller";
 import { prettyObject } from "../utils/format";
 import { estimateTokenLength } from "../utils/token";
@@ -363,17 +368,7 @@ export const useChatStore = createPersistStore(
           ]);
         });
 
-        var api: ClientApi;
-        if (modelConfig.providerName == ServiceProvider.Google) {
-          api = new ClientApi(ModelProvider.GeminiPro);
-        } else if (modelConfig.providerName == ServiceProvider.Anthropic) {
-          api = new ClientApi(ModelProvider.Claude);
-        } else if (modelConfig.providerName == ServiceProvider.ByteDance) {
-          api = new ClientApi(ModelProvider.Doubao);
-        } else {
-          api = new ClientApi(ModelProvider.GPT);
-        }
-
+        const api: ClientApi = getClientApi(modelConfig.providerName);
         // make request
         api.llm.chat({
           messages: sendMessages,
@@ -549,16 +544,7 @@ export const useChatStore = createPersistStore(
         const session = get().currentSession();
         const modelConfig = session.mask.modelConfig;
 
-        var api: ClientApi;
-        if (modelConfig.providerName == ServiceProvider.Google) {
-          api = new ClientApi(ModelProvider.GeminiPro);
-        } else if (modelConfig.providerName == ServiceProvider.Anthropic) {
-          api = new ClientApi(ModelProvider.Claude);
-        } else if (modelConfig.providerName == ServiceProvider.ByteDance) {
-          api = new ClientApi(ModelProvider.Doubao);
-        } else {
-          api = new ClientApi(ModelProvider.GPT);
-        }
+        const api: ClientApi = getClientApi(modelConfig.providerName);
 
         // remove error messages if any
         const messages = session.messages;