
wip: doubao

Dogtiti committed 1 year ago
parent commit 9b3b4494ba

+ 3 - 0
app/api/auth.ts

@@ -73,6 +73,9 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) {
       case ModelProvider.Claude:
         systemApiKey = serverConfig.anthropicApiKey;
         break;
+      case ModelProvider.Doubao:
+        systemApiKey = serverConfig.bytedanceApiKey;
+        break;
       case ModelProvider.GPT:
       default:
         if (req.nextUrl.pathname.includes("azure/deployments")) {

+ 160 - 0
app/api/bytedance/[...path]/route.ts

@@ -0,0 +1,160 @@
+import { getServerSideConfig } from "@/app/config/server";
+import {
+  BYTEDANCE_BASE_URL,
+  ApiPath,
+  ModelProvider,
+  ServiceProvider,
+} from "@/app/constant";
+import { prettyObject } from "@/app/utils/format";
+import { NextRequest, NextResponse } from "next/server";
+import { auth } from "@/app/api/auth";
+import { isModelAvailableInServer } from "@/app/utils/model";
+
+const serverConfig = getServerSideConfig();
+
+async function handle(
+  req: NextRequest,
+  { params }: { params: { path: string[] } },
+) {
+  console.log("[ByteDance Route] params ", params);
+
+  if (req.method === "OPTIONS") {
+    return NextResponse.json({ body: "OK" }, { status: 200 });
+  }
+
+  const authResult = auth(req, ModelProvider.Doubao);
+  if (authResult.error) {
+    return NextResponse.json(authResult, {
+      status: 401,
+    });
+  }
+
+  try {
+    const response = await request(req);
+    return response;
+  } catch (e) {
+    console.error("[ByteDance] ", e);
+    return NextResponse.json(prettyObject(e));
+  }
+}
+
+export const GET = handle;
+export const POST = handle;
+
+export const runtime = "edge";
+export const preferredRegion = [
+  "arn1",
+  "bom1",
+  "cdg1",
+  "cle1",
+  "cpt1",
+  "dub1",
+  "fra1",
+  "gru1",
+  "hnd1",
+  "iad1",
+  "icn1",
+  "kix1",
+  "lhr1",
+  "pdx1",
+  "sfo1",
+  "sin1",
+  "syd1",
+];
+
+async function request(req: NextRequest) {
+  const controller = new AbortController();
+
+  let path = `${req.nextUrl.pathname}`.replaceAll(ApiPath.ByteDance, "");
+
+  let baseUrl = serverConfig.bytedanceUrl || BYTEDANCE_BASE_URL;
+
+  if (!baseUrl.startsWith("http")) {
+    baseUrl = `https://${baseUrl}`;
+  }
+
+  if (baseUrl.endsWith("/")) {
+    baseUrl = baseUrl.slice(0, -1);
+  }
+
+  console.log("[Proxy] ", path);
+  console.log("[Base Url]", baseUrl);
+
+  const timeoutId = setTimeout(
+    () => {
+      controller.abort();
+    },
+    10 * 60 * 1000,
+  );
+
+  const fetchUrl = `${baseUrl}${path}`;
+
+  const fetchOptions: RequestInit = {
+    headers: {
+      "Content-Type": "application/json",
+      Authorization: req.headers.get("Authorization") ?? "",
+    },
+    method: req.method,
+    body: req.body,
+    redirect: "manual",
+    // @ts-ignore
+    duplex: "half",
+    signal: controller.signal,
+  };
+
+  // #1815 try to refuse requests to some models
+  if (serverConfig.customModels && req.body) {
+    try {
+      const clonedBody = await req.text();
+      fetchOptions.body = clonedBody;
+
+      const jsonBody = JSON.parse(clonedBody) as { model?: string };
+
+      // deny the request if this model is disabled via CUSTOM_MODELS (available === false)
+      if (
+        isModelAvailableInServer(
+          serverConfig.customModels,
+          jsonBody?.model as string,
+          ServiceProvider.ByteDance as string,
+        )
+      ) {
+        return NextResponse.json(
+          {
+            error: true,
+            message: `you are not allowed to use ${jsonBody?.model} model`,
+          },
+          {
+            status: 403,
+          },
+        );
+      }
+    } catch (e) {
+      console.error(`[ByteDance] filter`, e);
+    }
+  }
+  console.log("[ByteDance request]", fetchOptions.headers, req.method);
+  try {
+    const res = await fetch(fetchUrl, fetchOptions);
+
+    console.log(
+      "[ByteDance response]",
+      res.status,
+      "   ",
+      res.headers,
+      res.url,
+    );
+    // prevent the browser from prompting for credentials
+    const newHeaders = new Headers(res.headers);
+    newHeaders.delete("www-authenticate");
+    // to disable nginx buffering
+    newHeaders.set("X-Accel-Buffering", "no");
+
+    return new Response(res.body, {
+      status: res.status,
+      statusText: res.statusText,
+      headers: newHeaders,
+    });
+  } finally {
+    clearTimeout(timeoutId);
+  }
+}
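
For orientation, a minimal sketch of a request travelling through this new edge route: the handler strips the /api/bytedance prefix (ApiPath.ByteDance) from the incoming pathname and forwards the remainder, along with the Authorization header and JSON body, to BYTEDANCE_BASE_URL (or BYTEDANCE_URL if set). The endpoint ID and API key below are placeholders, not values from this commit, and it assumes the shared auth() helper swaps in serverConfig.bytedanceApiKey for access-code users as it does for the other providers.

// Sketch only: what a browser-side call to the new proxy route could look like.
// The "/api/bytedance" prefix is stripped by the route, so this request is
// forwarded to https://ark.cn-beijing.volces.com/api/v3/chat/completions.
async function probeDoubaoProxy() {
  const res = await fetch("/api/bytedance/api/v3/chat/completions", {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
      Authorization: "Bearer <endpoint-api-key>", // placeholder
    },
    body: JSON.stringify({
      model: "ep-xxxxxxxxxxxxxx-xxxx", // placeholder Volcengine endpoint ID
      stream: false,
      messages: [{ role: "user", content: "hello" }],
    }),
  });
  console.log(res.status, await res.json());
}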

+ 5 - 0
app/client/api.ts

@@ -9,6 +9,8 @@ import { ChatMessage, ModelType, useAccessStore, useChatStore } from "../store";
 import { ChatGPTApi } from "./platforms/openai";
 import { GeminiProApi } from "./platforms/google";
 import { ClaudeApi } from "./platforms/anthropic";
+import { DoubaoApi } from "./platforms/bytedance";
+
 export const ROLES = ["system", "user", "assistant"] as const;
 export type MessageRole = (typeof ROLES)[number];
 
@@ -104,6 +106,9 @@ export class ClientApi {
       case ModelProvider.Claude:
         this.llm = new ClaudeApi();
         break;
+      case ModelProvider.Doubao:
+        this.llm = new DoubaoApi();
+        break;
       default:
         this.llm = new ChatGPTApi();
     }

+ 260 - 0
app/client/platforms/bytedance.ts

@@ -0,0 +1,260 @@
+"use client";
+import {
+  ApiPath,
+  ByteDance,
+  DEFAULT_API_HOST,
+  REQUEST_TIMEOUT_MS,
+} from "@/app/constant";
+import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
+
+import {
+  ChatOptions,
+  getHeaders,
+  LLMApi,
+  LLMModel,
+  MultimodalContent,
+} from "../api";
+import Locale from "../../locales";
+import {
+  EventStreamContentType,
+  fetchEventSource,
+} from "@fortaine/fetch-event-source";
+import { prettyObject } from "@/app/utils/format";
+import { getClientConfig } from "@/app/config/client";
+import { getMessageTextContent, isVisionModel } from "@/app/utils";
+
+export interface OpenAIListModelResponse {
+  object: string;
+  data: Array<{
+    id: string;
+    object: string;
+    root: string;
+  }>;
+}
+
+interface RequestPayload {
+  messages: {
+    role: "system" | "user" | "assistant";
+    content: string | MultimodalContent[];
+  }[];
+  stream?: boolean;
+  model: string;
+  temperature: number;
+  presence_penalty: number;
+  frequency_penalty: number;
+  top_p: number;
+  max_tokens?: number;
+}
+
+export class DoubaoApi implements LLMApi {
+  path(path: string): string {
+    const accessStore = useAccessStore.getState();
+
+    let baseUrl = "";
+
+    if (accessStore.useCustomConfig) {
+      baseUrl = accessStore.bytedanceUrl;
+    }
+
+    if (baseUrl.length === 0) {
+      const isApp = !!getClientConfig()?.isApp;
+      baseUrl = isApp
+        ? DEFAULT_API_HOST + "/api/proxy/bytedance"
+        : ApiPath.ByteDance;
+    }
+
+    if (baseUrl.endsWith("/")) {
+      baseUrl = baseUrl.slice(0, baseUrl.length - 1);
+    }
+    if (!baseUrl.startsWith("http") && !baseUrl.startsWith(ApiPath.ByteDance)) {
+      baseUrl = "https://" + baseUrl;
+    }
+
+    console.log("[Proxy Endpoint] ", baseUrl, path);
+
+    return [baseUrl, path].join("/");
+  }
+
+  extractMessage(res: any) {
+    return res.choices?.at(0)?.message?.content ?? "";
+  }
+
+  async chat(options: ChatOptions) {
+    const visionModel = isVisionModel(options.config.model);
+    const messages = options.messages.map((v) => ({
+      role: v.role,
+      content: visionModel ? v.content : getMessageTextContent(v),
+    }));
+
+    const modelConfig = {
+      ...useAppConfig.getState().modelConfig,
+      ...useChatStore.getState().currentSession().mask.modelConfig,
+      ...{
+        model: options.config.model,
+      },
+    };
+
+    const requestPayload: RequestPayload = {
+      messages,
+      stream: options.config.stream,
+      model: modelConfig.model,
+      temperature: modelConfig.temperature,
+      presence_penalty: modelConfig.presence_penalty,
+      frequency_penalty: modelConfig.frequency_penalty,
+      top_p: modelConfig.top_p,
+    };
+
+    console.log("[Request] ByteDance payload: ", requestPayload);
+
+    const shouldStream = !!options.config.stream;
+    const controller = new AbortController();
+    options.onController?.(controller);
+
+    try {
+      const chatPath = this.path(ByteDance.ChatPath);
+      const chatPayload = {
+        method: "POST",
+        body: JSON.stringify(requestPayload),
+        signal: controller.signal,
+        headers: getHeaders(),
+      };
+
+      // make a fetch request
+      const requestTimeoutId = setTimeout(
+        () => controller.abort(),
+        REQUEST_TIMEOUT_MS,
+      );
+
+      if (shouldStream) {
+        let responseText = "";
+        let remainText = "";
+        let finished = false;
+
+        // animate the response to make it look smooth
+        function animateResponseText() {
+          if (finished || controller.signal.aborted) {
+            responseText += remainText;
+            console.log("[Response Animation] finished");
+            if (responseText?.length === 0) {
+              options.onError?.(new Error("empty response from server"));
+            }
+            return;
+          }
+
+          if (remainText.length > 0) {
+            const fetchCount = Math.max(1, Math.round(remainText.length / 60));
+            const fetchText = remainText.slice(0, fetchCount);
+            responseText += fetchText;
+            remainText = remainText.slice(fetchCount);
+            options.onUpdate?.(responseText, fetchText);
+          }
+
+          requestAnimationFrame(animateResponseText);
+        }
+
+        // start animation
+        animateResponseText();
+
+        const finish = () => {
+          if (!finished) {
+            finished = true;
+            options.onFinish(responseText + remainText);
+          }
+        };
+
+        controller.signal.onabort = finish;
+
+        fetchEventSource(chatPath, {
+          ...chatPayload,
+          async onopen(res) {
+            clearTimeout(requestTimeoutId);
+            const contentType = res.headers.get("content-type");
+            console.log(
+              "[ByteDance] request response content type: ",
+              contentType,
+            );
+
+            if (contentType?.startsWith("text/plain")) {
+              responseText = await res.clone().text();
+              return finish();
+            }
+
+            if (
+              !res.ok ||
+              !res.headers
+                .get("content-type")
+                ?.startsWith(EventStreamContentType) ||
+              res.status !== 200
+            ) {
+              const responseTexts = [responseText];
+              let extraInfo = await res.clone().text();
+              try {
+                const resJson = await res.clone().json();
+                extraInfo = prettyObject(resJson);
+              } catch {}
+
+              if (res.status === 401) {
+                responseTexts.push(Locale.Error.Unauthorized);
+              }
+
+              if (extraInfo) {
+                responseTexts.push(extraInfo);
+              }
+
+              responseText = responseTexts.join("\n\n");
+
+              return finish();
+            }
+          },
+          onmessage(msg) {
+            if (msg.data === "[DONE]" || finished) {
+              return finish();
+            }
+            const text = msg.data;
+            try {
+              const json = JSON.parse(text);
+              const choices = json.choices as Array<{
+                delta: { content: string };
+              }>;
+              const delta = choices[0]?.delta?.content;
+              if (delta) {
+                remainText += delta;
+              }
+            } catch (e) {
+              console.error("[Request] parse error", text, msg);
+            }
+          },
+          onclose() {
+            finish();
+          },
+          onerror(e) {
+            options.onError?.(e);
+            throw e;
+          },
+          openWhenHidden: true,
+        });
+      } else {
+        const res = await fetch(chatPath, chatPayload);
+        clearTimeout(requestTimeoutId);
+
+        const resJson = await res.json();
+        const message = this.extractMessage(resJson);
+        options.onFinish(message);
+      }
+    } catch (e) {
+      console.log("[Request] failed to make a chat request", e);
+      options.onError?.(e as Error);
+    }
+  }
+  async usage() {
+    return {
+      used: 0,
+      total: 0,
+    };
+  }
+
+  async models(): Promise<LLMModel[]> {
+    return [];
+  }
+}
+export { ByteDance };
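
A hedged usage sketch of the new client, limited to the ChatOptions fields that DoubaoApi.chat exercises above (messages, config.model, config.stream, and the onUpdate/onFinish/onError callbacks); in app code the class is normally reached through new ClientApi(ModelProvider.Doubao) rather than instantiated directly. The endpoint ID is a placeholder.

// Sketch only: drive DoubaoApi directly with the options chat() uses above.
import { DoubaoApi } from "@/app/client/platforms/bytedance";

const doubao = new DoubaoApi();
await doubao.chat({
  messages: [{ role: "user", content: "Hello, Doubao" }],
  config: { model: "ep-xxxxxxxxxxxxxx-xxxx", stream: true }, // placeholder endpoint ID
  onUpdate: (full, delta) => console.log("[stream]", delta),
  onFinish: (message) => console.log("[finish]", message),
  onError: (e) => console.error("[error]", e),
});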

+ 2 - 0
app/components/exporter.tsx

@@ -321,6 +321,8 @@ export function PreviewActions(props: {
       api = new ClientApi(ModelProvider.GeminiPro);
     } else if (config.modelConfig.providerName == ServiceProvider.Anthropic) {
       api = new ClientApi(ModelProvider.Claude);
+    } else if (config.modelConfig.providerName == ServiceProvider.ByteDance) {
+      api = new ClientApi(ModelProvider.Doubao);
     } else {
       api = new ClientApi(ModelProvider.GPT);
     }

+ 2 - 0
app/components/home.tsx

@@ -175,6 +175,8 @@ export function useLoadData() {
     api = new ClientApi(ModelProvider.GeminiPro);
   } else if (config.modelConfig.providerName == ServiceProvider.Anthropic) {
     api = new ClientApi(ModelProvider.Claude);
+  } else if (config.modelConfig.providerName == ServiceProvider.ByteDance) {
+    api = new ClientApi(ModelProvider.Doubao);
   } else {
     api = new ClientApi(ModelProvider.GPT);
   }

+ 9 - 0
app/config/server.ts

@@ -32,6 +32,10 @@ declare global {
       GOOGLE_API_KEY?: string;
       GOOGLE_URL?: string;
 
+      // bytedance only
+      BYTEDANCE_URL?: string;
+      BYTEDANCE_API_KEY?: string;
+
       // google tag manager
       GTM_ID?: string;
 
@@ -92,6 +96,7 @@ export const getServerSideConfig = () => {
   const isAzure = !!process.env.AZURE_URL;
   const isGoogle = !!process.env.GOOGLE_API_KEY;
   const isAnthropic = !!process.env.ANTHROPIC_API_KEY;
+  const isBytedance = !!process.env.BYTEDANCE_API_KEY;
 
   // const apiKeyEnvVar = process.env.OPENAI_API_KEY ?? "";
   // const apiKeys = apiKeyEnvVar.split(",").map((v) => v.trim());
@@ -126,6 +131,10 @@ export const getServerSideConfig = () => {
 
     gtmId: process.env.GTM_ID,
 
+    isBytedance,
+    bytedanceApiKey: getApiKey(process.env.BYTEDANCE_API_KEY),
+    bytedanceUrl: process.env.BYTEDANCE_URL,
+
     needCode: ACCESS_CODES.size > 0,
     code: process.env.CODE,
     codes: ACCESS_CODES,

+ 21 - 0
app/constant.ts

@@ -14,6 +14,8 @@ export const ANTHROPIC_BASE_URL = "https://api.anthropic.com";
 
 export const GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/";
 
+export const BYTEDANCE_BASE_URL = "https://ark.cn-beijing.volces.com";
+
 export enum Path {
   Home = "/",
   Chat = "/chat",
@@ -28,6 +30,7 @@ export enum ApiPath {
   Azure = "/api/azure",
   OpenAI = "/api/openai",
   Anthropic = "/api/anthropic",
+  ByteDance = "/api/bytedance",
 }
 
 export enum SlotID {
@@ -71,12 +74,14 @@ export enum ServiceProvider {
   Azure = "Azure",
   Google = "Google",
   Anthropic = "Anthropic",
+  ByteDance = "ByteDance",
 }
 
 export enum ModelProvider {
   GPT = "GPT",
   GeminiPro = "GeminiPro",
   Claude = "Claude",
+  Doubao = "Doubao",
 }
 
 export const Anthropic = {
@@ -104,6 +109,11 @@ export const Google = {
   ChatPath: (modelName: string) => `v1beta/models/${modelName}:generateContent`,
 };
 
+export const ByteDance = {
+  ExampleEndpoint: "https://ark.cn-beijing.volces.com/api/v3/chat/completions",
+  ChatPath: "/api/v3/chat/completions",
+};
+
 export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang
 // export const DEFAULT_SYSTEM_TEMPLATE = `
 // You are ChatGPT, a large language model trained by {{ServiceProvider}}.
@@ -173,6 +183,8 @@ const anthropicModels = [
   "claude-3-5-sonnet-20240620",
 ];
 
+const bytedanceModels = ["ep-20240520082937-424bw=Doubao-lite-4k"];
+
 export const DEFAULT_MODELS = [
   ...openaiModels.map((name) => ({
     name,
@@ -210,6 +222,15 @@ export const DEFAULT_MODELS = [
       providerType: "anthropic",
     },
   })),
+  ...bytedanceModels.map((name) => ({
+    name,
+    available: true,
+    provider: {
+      id: "bytedance",
+      providerName: "ByteDance",
+      providerType: "bytedance",
+    },
+  })),
 ] as const;
 
 export const CHAT_PAGE_SIZE = 15;
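
To see how the new constants fit together (illustrative only): DoubaoApi.path() joins the web proxy prefix with ByteDance.ChatPath, and the edge route replaces that prefix with BYTEDANCE_BASE_URL, landing on the ExampleEndpoint shown above. ChatPath already begins with a slash, so the join() in path() produces a doubled slash; whether that needs normalizing is worth checking while this is still wip.

// Sketch only: how the new ByteDance constants compose along the request path.
import { ApiPath, ByteDance, BYTEDANCE_BASE_URL } from "@/app/constant";

const clientUrl = [ApiPath.ByteDance, ByteDance.ChatPath].join("/");
// -> "/api/bytedance//api/v3/chat/completions" (ChatPath starts with "/")

const upstreamUrl =
  BYTEDANCE_BASE_URL + clientUrl.replaceAll(ApiPath.ByteDance, "");
// -> "https://ark.cn-beijing.volces.com//api/v3/chat/completions",
//    i.e. ByteDance.ExampleEndpoint modulo the doubled slash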

+ 9 - 0
app/store/access.ts

@@ -47,6 +47,10 @@ const DEFAULT_ACCESS_STATE = {
   anthropicApiVersion: "2023-06-01",
   anthropicUrl: "",
 
+  // bytedance
+  bytedanceApiKey: "",
+  bytedanceUrl: "",
+
   // server config
   needCode: true,
   hideUserApiKey: false,
@@ -83,6 +87,10 @@ export const useAccessStore = createPersistStore(
       return ensure(get(), ["anthropicApiKey"]);
     },
 
+    isValidByteDance() {
+      return ensure(get(), ["bytedanceApiKey"]);
+    },
+
     isAuthorized() {
       this.fetch();
 
@@ -92,6 +100,7 @@ export const useAccessStore = createPersistStore(
         this.isValidAzure() ||
         this.isValidGoogle() ||
         this.isValidAnthropic() ||
+        this.isValidByteDance() ||
         !this.enabledAccessControl() ||
         (this.enabledAccessControl() && ensure(get(), ["accessCode"]))
       );

+ 4 - 0
app/store/chat.ts

@@ -368,6 +368,8 @@ export const useChatStore = createPersistStore(
           api = new ClientApi(ModelProvider.GeminiPro);
         } else if (modelConfig.providerName == ServiceProvider.Anthropic) {
           api = new ClientApi(ModelProvider.Claude);
+        } else if (modelConfig.providerName == ServiceProvider.ByteDance) {
+          api = new ClientApi(ModelProvider.Doubao);
         } else {
           api = new ClientApi(ModelProvider.GPT);
         }
@@ -552,6 +554,8 @@ export const useChatStore = createPersistStore(
           api = new ClientApi(ModelProvider.GeminiPro);
         } else if (modelConfig.providerName == ServiceProvider.Anthropic) {
           api = new ClientApi(ModelProvider.Claude);
+        } else if (modelConfig.providerName == ServiceProvider.ByteDance) {
+          api = new ClientApi(ModelProvider.Doubao);
         } else {
           api = new ClientApi(ModelProvider.GPT);
         }