lloydzhou 1 vuosi sitten
vanhempi
commit
8455fefc8a

+ 3 - 0
app/api/[provider]/[...path]/route.ts

@@ -10,6 +10,7 @@ import { handle as alibabaHandler } from "../../alibaba";
 import { handle as moonshotHandler } from "../../moonshot";
 import { handle as stabilityHandler } from "../../stability";
 import { handle as iflytekHandler } from "../../iflytek";
+import { handle as xaiHandler } from "../../xai";
 import { handle as proxyHandler } from "../../proxy";
 
 async function handle(
@@ -38,6 +39,8 @@ async function handle(
       return stabilityHandler(req, { params });
     case ApiPath.Iflytek:
       return iflytekHandler(req, { params });
+    case ApiPath.XAI:
+      return xaiHandler(req, { params });
     case ApiPath.OpenAI:
       return openaiHandler(req, { params });
     default:

+ 4 - 1
app/api/auth.ts

@@ -92,6 +92,9 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) {
         systemApiKey =
           serverConfig.iflytekApiKey + ":" + serverConfig.iflytekApiSecret;
         break;
+      case ModelProvider.XAI:
+        systemApiKey = serverConfig.xaiApiKey;
+        break;
       case ModelProvider.GPT:
       default:
         if (req.nextUrl.pathname.includes("azure/deployments")) {
@@ -102,7 +105,7 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) {
     }
 
     if (systemApiKey) {
-      console.log("[Auth] use system api key");
+      console.log("[Auth] use system api key", systemApiKey);
       req.headers.set("Authorization", `Bearer ${systemApiKey}`);
     } else {
       console.log("[Auth] admin did not provide an api key");

+ 128 - 0
app/api/xai.ts

@@ -0,0 +1,128 @@
+import { getServerSideConfig } from "@/app/config/server";
+import {
+  XAI_BASE_URL,
+  ApiPath,
+  ModelProvider,
+  ServiceProvider,
+} from "@/app/constant";
+import { prettyObject } from "@/app/utils/format";
+import { NextRequest, NextResponse } from "next/server";
+import { auth } from "@/app/api/auth";
+import { isModelAvailableInServer } from "@/app/utils/model";
+
+const serverConfig = getServerSideConfig();
+
+export async function handle(
+  req: NextRequest,
+  { params }: { params: { path: string[] } },
+) {
+  console.log("[XAI Route] params ", params);
+
+  if (req.method === "OPTIONS") {
+    return NextResponse.json({ body: "OK" }, { status: 200 });
+  }
+
+  const authResult = auth(req, ModelProvider.XAI);
+  if (authResult.error) {
+    return NextResponse.json(authResult, {
+      status: 401,
+    });
+  }
+
+  try {
+    const response = await request(req);
+    return response;
+  } catch (e) {
+    console.error("[XAI] ", e);
+    return NextResponse.json(prettyObject(e));
+  }
+}
+
+async function request(req: NextRequest) {
+  const controller = new AbortController();
+
+  // alibaba use base url or just remove the path
+  let path = `${req.nextUrl.pathname}`.replaceAll(ApiPath.XAI, "");
+
+  let baseUrl = serverConfig.xaiUrl || XAI_BASE_URL;
+
+  if (!baseUrl.startsWith("http")) {
+    baseUrl = `https://${baseUrl}`;
+  }
+
+  if (baseUrl.endsWith("/")) {
+    baseUrl = baseUrl.slice(0, -1);
+  }
+
+  console.log("[Proxy] ", path);
+  console.log("[Base Url]", baseUrl);
+
+  const timeoutId = setTimeout(
+    () => {
+      controller.abort();
+    },
+    10 * 60 * 1000,
+  );
+
+  const fetchUrl = `${baseUrl}${path}`;
+  const fetchOptions: RequestInit = {
+    headers: {
+      "Content-Type": "application/json",
+      Authorization: req.headers.get("Authorization") ?? "",
+    },
+    method: req.method,
+    body: req.body,
+    redirect: "manual",
+    // @ts-ignore
+    duplex: "half",
+    signal: controller.signal,
+  };
+
+  // #1815 try to refuse some request to some models
+  if (serverConfig.customModels && req.body) {
+    try {
+      const clonedBody = await req.text();
+      fetchOptions.body = clonedBody;
+
+      const jsonBody = JSON.parse(clonedBody) as { model?: string };
+
+      // not undefined and is false
+      if (
+        isModelAvailableInServer(
+          serverConfig.customModels,
+          jsonBody?.model as string,
+          ServiceProvider.XAI as string,
+        )
+      ) {
+        return NextResponse.json(
+          {
+            error: true,
+            message: `you are not allowed to use ${jsonBody?.model} model`,
+          },
+          {
+            status: 403,
+          },
+        );
+      }
+    } catch (e) {
+      console.error(`[XAI] filter`, e);
+    }
+  }
+  try {
+    const res = await fetch(fetchUrl, fetchOptions);
+
+    // to prevent browser prompt for credentials
+    const newHeaders = new Headers(res.headers);
+    newHeaders.delete("www-authenticate");
+    // to disable nginx buffering
+    newHeaders.set("X-Accel-Buffering", "no");
+
+    return new Response(res.body, {
+      status: res.status,
+      statusText: res.statusText,
+      headers: newHeaders,
+    });
+  } finally {
+    clearTimeout(timeoutId);
+  }
+}

+ 10 - 0
app/client/api.ts

@@ -20,6 +20,7 @@ import { QwenApi } from "./platforms/alibaba";
 import { HunyuanApi } from "./platforms/tencent";
 import { MoonshotApi } from "./platforms/moonshot";
 import { SparkApi } from "./platforms/iflytek";
+import { XAIApi } from "./platforms/xai";
 
 export const ROLES = ["system", "user", "assistant"] as const;
 export type MessageRole = (typeof ROLES)[number];
@@ -152,6 +153,9 @@ export class ClientApi {
       case ModelProvider.Iflytek:
         this.llm = new SparkApi();
         break;
+      case ModelProvider.XAI:
+        this.llm = new XAIApi();
+        break;
       default:
         this.llm = new ChatGPTApi();
     }
@@ -239,6 +243,7 @@ export function getHeaders(ignoreHeaders: boolean = false) {
     const isAlibaba = modelConfig.providerName === ServiceProvider.Alibaba;
     const isMoonshot = modelConfig.providerName === ServiceProvider.Moonshot;
     const isIflytek = modelConfig.providerName === ServiceProvider.Iflytek;
+    const isXAI = modelConfig.providerName === ServiceProvider.XAI;
     const isEnabledAccessControl = accessStore.enabledAccessControl();
     const apiKey = isGoogle
       ? accessStore.googleApiKey
@@ -252,6 +257,8 @@ export function getHeaders(ignoreHeaders: boolean = false) {
       ? accessStore.alibabaApiKey
       : isMoonshot
       ? accessStore.moonshotApiKey
+      : isXAI
+      ? accessStore.xaiApiKey
       : isIflytek
       ? accessStore.iflytekApiKey && accessStore.iflytekApiSecret
         ? accessStore.iflytekApiKey + ":" + accessStore.iflytekApiSecret
@@ -266,6 +273,7 @@ export function getHeaders(ignoreHeaders: boolean = false) {
       isAlibaba,
       isMoonshot,
       isIflytek,
+      isXAI,
       apiKey,
       isEnabledAccessControl,
     };
@@ -328,6 +336,8 @@ export function getClientApi(provider: ServiceProvider): ClientApi {
       return new ClientApi(ModelProvider.Moonshot);
     case ServiceProvider.Iflytek:
       return new ClientApi(ModelProvider.Iflytek);
+    case ServiceProvider.XAI:
+      return new ClientApi(ModelProvider.XAI);
     default:
       return new ClientApi(ModelProvider.GPT);
   }

+ 195 - 0
app/client/platforms/xai.ts

@@ -0,0 +1,195 @@
+"use client";
+// azure and openai, using same models. so using same LLMApi.
+import { ApiPath, XAI_BASE_URL, XAI, REQUEST_TIMEOUT_MS } from "@/app/constant";
+import {
+  useAccessStore,
+  useAppConfig,
+  useChatStore,
+  ChatMessageTool,
+  usePluginStore,
+} from "@/app/store";
+import { stream } from "@/app/utils/chat";
+import {
+  ChatOptions,
+  getHeaders,
+  LLMApi,
+  LLMModel,
+  SpeechOptions,
+} from "../api";
+import { getClientConfig } from "@/app/config/client";
+import { getMessageTextContent } from "@/app/utils";
+import { RequestPayload } from "./openai";
+import { fetch } from "@/app/utils/stream";
+
+export class XAIApi implements LLMApi {
+  private disableListModels = true;
+
+  path(path: string): string {
+    const accessStore = useAccessStore.getState();
+
+    let baseUrl = "";
+
+    if (accessStore.useCustomConfig) {
+      baseUrl = accessStore.xaiUrl;
+    }
+
+    if (baseUrl.length === 0) {
+      const isApp = !!getClientConfig()?.isApp;
+      const apiPath = ApiPath.XAI;
+      baseUrl = isApp ? XAI_BASE_URL : apiPath;
+    }
+
+    if (baseUrl.endsWith("/")) {
+      baseUrl = baseUrl.slice(0, baseUrl.length - 1);
+    }
+    if (!baseUrl.startsWith("http") && !baseUrl.startsWith(ApiPath.XAI)) {
+      baseUrl = "https://" + baseUrl;
+    }
+
+    console.log("[Proxy Endpoint] ", baseUrl, path);
+
+    return [baseUrl, path].join("/");
+  }
+
+  extractMessage(res: any) {
+    return res.choices?.at(0)?.message?.content ?? "";
+  }
+
+  speech(options: SpeechOptions): Promise<ArrayBuffer> {
+    throw new Error("Method not implemented.");
+  }
+
+  async chat(options: ChatOptions) {
+    const messages: ChatOptions["messages"] = [];
+    for (const v of options.messages) {
+      const content = getMessageTextContent(v);
+      messages.push({ role: v.role, content });
+    }
+
+    const modelConfig = {
+      ...useAppConfig.getState().modelConfig,
+      ...useChatStore.getState().currentSession().mask.modelConfig,
+      ...{
+        model: options.config.model,
+        providerName: options.config.providerName,
+      },
+    };
+
+    const requestPayload: RequestPayload = {
+      messages,
+      stream: options.config.stream,
+      model: modelConfig.model,
+      temperature: modelConfig.temperature,
+      presence_penalty: modelConfig.presence_penalty,
+      frequency_penalty: modelConfig.frequency_penalty,
+      top_p: modelConfig.top_p,
+      // max_tokens: Math.max(modelConfig.max_tokens, 1024),
+      // Please do not ask me why not send max_tokens, no reason, this param is just shit, I dont want to explain anymore.
+    };
+
+    console.log("[Request] openai payload: ", requestPayload);
+
+    const shouldStream = !!options.config.stream;
+    const controller = new AbortController();
+    options.onController?.(controller);
+
+    try {
+      const chatPath = this.path(XAI.ChatPath);
+      const chatPayload = {
+        method: "POST",
+        body: JSON.stringify(requestPayload),
+        signal: controller.signal,
+        headers: getHeaders(),
+      };
+
+      // make a fetch request
+      const requestTimeoutId = setTimeout(
+        () => controller.abort(),
+        REQUEST_TIMEOUT_MS,
+      );
+
+      if (shouldStream) {
+        const [tools, funcs] = usePluginStore
+          .getState()
+          .getAsTools(
+            useChatStore.getState().currentSession().mask?.plugin || [],
+          );
+        return stream(
+          chatPath,
+          requestPayload,
+          getHeaders(),
+          tools as any,
+          funcs,
+          controller,
+          // parseSSE
+          (text: string, runTools: ChatMessageTool[]) => {
+            // console.log("parseSSE", text, runTools);
+            const json = JSON.parse(text);
+            const choices = json.choices as Array<{
+              delta: {
+                content: string;
+                tool_calls: ChatMessageTool[];
+              };
+            }>;
+            const tool_calls = choices[0]?.delta?.tool_calls;
+            if (tool_calls?.length > 0) {
+              const index = tool_calls[0]?.index;
+              const id = tool_calls[0]?.id;
+              const args = tool_calls[0]?.function?.arguments;
+              if (id) {
+                runTools.push({
+                  id,
+                  type: tool_calls[0]?.type,
+                  function: {
+                    name: tool_calls[0]?.function?.name as string,
+                    arguments: args,
+                  },
+                });
+              } else {
+                // @ts-ignore
+                runTools[index]["function"]["arguments"] += args;
+              }
+            }
+            return choices[0]?.delta?.content;
+          },
+          // processToolMessage, include tool_calls message and tool call results
+          (
+            requestPayload: RequestPayload,
+            toolCallMessage: any,
+            toolCallResult: any[],
+          ) => {
+            // @ts-ignore
+            requestPayload?.messages?.splice(
+              // @ts-ignore
+              requestPayload?.messages?.length,
+              0,
+              toolCallMessage,
+              ...toolCallResult,
+            );
+          },
+          options,
+        );
+      } else {
+        const res = await fetch(chatPath, chatPayload);
+        clearTimeout(requestTimeoutId);
+
+        const resJson = await res.json();
+        const message = this.extractMessage(resJson);
+        options.onFinish(message);
+      }
+    } catch (e) {
+      console.log("[Request] failed to make a chat request", e);
+      options.onError?.(e as Error);
+    }
+  }
+  async usage() {
+    return {
+      used: 0,
+      total: 0,
+    };
+  }
+
+  async models(): Promise<LLMModel[]> {
+    return [];
+  }
+}

+ 41 - 0
app/components/settings.tsx

@@ -59,6 +59,7 @@ import {
   ByteDance,
   Alibaba,
   Moonshot,
+  XAI,
   Google,
   GoogleSafetySettingsThreshold,
   OPENAI_BASE_URL,
@@ -1194,6 +1195,45 @@ export function Settings() {
     </>
   );
 
+  const XAIConfigComponent = accessStore.provider === ServiceProvider.XAI && (
+    <>
+      <ListItem
+        title={Locale.Settings.Access.XAI.Endpoint.Title}
+        subTitle={
+          Locale.Settings.Access.XAI.Endpoint.SubTitle + XAI.ExampleEndpoint
+        }
+      >
+        <input
+          aria-label={Locale.Settings.Access.XAI.Endpoint.Title}
+          type="text"
+          value={accessStore.moonshotUrl}
+          placeholder={XAI.ExampleEndpoint}
+          onChange={(e) =>
+            accessStore.update(
+              (access) => (access.moonshotUrl = e.currentTarget.value),
+            )
+          }
+        ></input>
+      </ListItem>
+      <ListItem
+        title={Locale.Settings.Access.XAI.ApiKey.Title}
+        subTitle={Locale.Settings.Access.XAI.ApiKey.SubTitle}
+      >
+        <PasswordInput
+          aria-label={Locale.Settings.Access.XAI.ApiKey.Title}
+          value={accessStore.moonshotApiKey}
+          type="text"
+          placeholder={Locale.Settings.Access.XAI.ApiKey.Placeholder}
+          onChange={(e) => {
+            accessStore.update(
+              (access) => (access.moonshotApiKey = e.currentTarget.value),
+            );
+          }}
+        />
+      </ListItem>
+    </>
+  );
+
   const stabilityConfigComponent = accessStore.provider ===
     ServiceProvider.Stability && (
     <>
@@ -1652,6 +1692,7 @@ export function Settings() {
                   {moonshotConfigComponent}
                   {stabilityConfigComponent}
                   {lflytekConfigComponent}
+                  {XAIConfigComponent}
                 </>
               )}
             </>

+ 9 - 0
app/config/server.ts

@@ -71,6 +71,10 @@ declare global {
       IFLYTEK_API_KEY?: string;
       IFLYTEK_API_SECRET?: string;
 
+      // xai only
+      XAI_URL?: string;
+      XAI_API_KEY?: string;
+
       // custom template for preprocessing user input
       DEFAULT_INPUT_TEMPLATE?: string;
     }
@@ -146,6 +150,7 @@ export const getServerSideConfig = () => {
   const isAlibaba = !!process.env.ALIBABA_API_KEY;
   const isMoonshot = !!process.env.MOONSHOT_API_KEY;
   const isIflytek = !!process.env.IFLYTEK_API_KEY;
+  const isXAI = !!process.env.XAI_API_KEY;
   // const apiKeyEnvVar = process.env.OPENAI_API_KEY ?? "";
   // const apiKeys = apiKeyEnvVar.split(",").map((v) => v.trim());
   // const randomIndex = Math.floor(Math.random() * apiKeys.length);
@@ -208,6 +213,10 @@ export const getServerSideConfig = () => {
     iflytekApiKey: process.env.IFLYTEK_API_KEY,
     iflytekApiSecret: process.env.IFLYTEK_API_SECRET,
 
+    isXAI,
+    xaiUrl: process.env.XAI_URL,
+    xaiApiKey: getApiKey(process.env.XAI_API_KEY),
+
     cloudflareAccountId: process.env.CLOUDFLARE_ACCOUNT_ID,
     cloudflareKVNamespaceId: process.env.CLOUDFLARE_KV_NAMESPACE_ID,
     cloudflareKVApiKey: getApiKey(process.env.CLOUDFLARE_KV_API_KEY),

+ 23 - 0
app/constant.ts

@@ -28,6 +28,8 @@ export const TENCENT_BASE_URL = "https://hunyuan.tencentcloudapi.com";
 export const MOONSHOT_BASE_URL = "https://api.moonshot.cn";
 export const IFLYTEK_BASE_URL = "https://spark-api-open.xf-yun.com";
 
+export const XAI_BASE_URL = "https://api.x.ai";
+
 export const CACHE_URL_PREFIX = "/api/cache";
 export const UPLOAD_URL = `${CACHE_URL_PREFIX}/upload`;
 
@@ -59,6 +61,7 @@ export enum ApiPath {
   Iflytek = "/api/iflytek",
   Stability = "/api/stability",
   Artifacts = "/api/artifacts",
+  XAI = "/api/xai",
 }
 
 export enum SlotID {
@@ -111,6 +114,7 @@ export enum ServiceProvider {
   Moonshot = "Moonshot",
   Stability = "Stability",
   Iflytek = "Iflytek",
+  XAI = "XAI",
 }
 
 // Google API safety settings, see https://ai.google.dev/gemini-api/docs/safety-settings
@@ -133,6 +137,7 @@ export enum ModelProvider {
   Hunyuan = "Hunyuan",
   Moonshot = "Moonshot",
   Iflytek = "Iflytek",
+  XAI = "XAI",
 }
 
 export const Stability = {
@@ -215,6 +220,11 @@ export const Iflytek = {
   ChatPath: "v1/chat/completions",
 };
 
+export const XAI = {
+  ExampleEndpoint: XAI_BASE_URL,
+  ChatPath: "v1/chat/completions",
+};
+
 export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang
 // export const DEFAULT_SYSTEM_TEMPLATE = `
 // You are ChatGPT, a large language model trained by {{ServiceProvider}}.
@@ -364,6 +374,8 @@ const iflytekModels = [
   "4.0Ultra",
 ];
 
+const xAIModes = ["grok-beta"];
+
 let seq = 1000; // 内置的模型序号生成器从1000开始
 export const DEFAULT_MODELS = [
   ...openaiModels.map((name) => ({
@@ -476,6 +488,17 @@ export const DEFAULT_MODELS = [
       sorted: 10,
     },
   })),
+  ...xAIModes.map((name) => ({
+    name,
+    available: true,
+    sorted: seq++,
+    provider: {
+      id: "xai",
+      providerName: "XAI",
+      providerType: "xai",
+      sorted: 11,
+    },
+  })),
 ] as const;
 
 export const CHAT_PAGE_SIZE = 15;

+ 11 - 0
app/locales/cn.ts

@@ -462,6 +462,17 @@ const cn = {
           SubTitle: "样例:",
         },
       },
+      XAI: {
+        ApiKey: {
+          Title: "接口密钥",
+          SubTitle: "使用自定义XAI API Key",
+          Placeholder: "XAI API Key",
+        },
+        Endpoint: {
+          Title: "接口地址",
+          SubTitle: "样例:",
+        },
+      },
       Stability: {
         ApiKey: {
           Title: "接口密钥",

+ 11 - 0
app/locales/en.ts

@@ -446,6 +446,17 @@ const en: LocaleType = {
           SubTitle: "Example: ",
         },
       },
+      XAI: {
+        ApiKey: {
+          Title: "XAI API Key",
+          SubTitle: "Use a custom XAI API Key",
+          Placeholder: "XAI API Key",
+        },
+        Endpoint: {
+          Title: "Endpoint Address",
+          SubTitle: "Example: ",
+        },
+      },
       Stability: {
         ApiKey: {
           Title: "Stability API Key",

+ 12 - 0
app/store/access.ts

@@ -13,6 +13,7 @@ import {
   MOONSHOT_BASE_URL,
   STABILITY_BASE_URL,
   IFLYTEK_BASE_URL,
+  XAI_BASE_URL,
 } from "../constant";
 import { getHeaders } from "../client/api";
 import { getClientConfig } from "../config/client";
@@ -44,6 +45,8 @@ const DEFAULT_STABILITY_URL = isApp ? STABILITY_BASE_URL : ApiPath.Stability;
 
 const DEFAULT_IFLYTEK_URL = isApp ? IFLYTEK_BASE_URL : ApiPath.Iflytek;
 
+const DEFAULT_XAI_URL = isApp ? XAI_BASE_URL : ApiPath.XAI;
+
 const DEFAULT_ACCESS_STATE = {
   accessCode: "",
   useCustomConfig: false,
@@ -101,6 +104,10 @@ const DEFAULT_ACCESS_STATE = {
   iflytekApiKey: "",
   iflytekApiSecret: "",
 
+  // moonshot
+  xaiUrl: DEFAULT_XAI_URL,
+  xaiApiKey: "",
+
   // server config
   needCode: true,
   hideUserApiKey: false,
@@ -169,6 +176,10 @@ export const useAccessStore = createPersistStore(
       return ensure(get(), ["iflytekApiKey"]);
     },
 
+    isValidXAI() {
+      return ensure(get(), ["xaiApiKey"]);
+    },
+
     isAuthorized() {
       this.fetch();
 
@@ -184,6 +195,7 @@ export const useAccessStore = createPersistStore(
         this.isValidTencent() ||
         this.isValidMoonshot() ||
         this.isValidIflytek() ||
+        this.isValidXAI() ||
         !this.enabledAccessControl() ||
         (this.enabledAccessControl() && ensure(get(), ["accessCode"]))
       );