浏览代码

change build messages for qwen in client

lloydzhou 1 年之前
父节点
当前提交
32b82b9cb3
共有 2 个文件被更改，包括 36 次插入和 38 次删除
  1. 10 25
      app/api/alibaba/[...path]/route.ts
  2. 26 13
      app/client/platforms/alibaba.ts

+ 10 - 25
app/api/alibaba/[...path]/route.ts

@@ -91,34 +91,14 @@ async function request(req: NextRequest) {
   );
 
   const fetchUrl = `${baseUrl}${path}`;
-
-  const clonedBody = await req.text();
-
-  const { messages, model, stream, top_p, ...rest } = JSON.parse(
-    clonedBody,
-  ) as RequestPayload;
-
-  const requestBody = {
-    model,
-    input: {
-      messages,
-    },
-    parameters: {
-      ...rest,
-      top_p: top_p === 1 ? 0.99 : top_p, // qwen top_p is should be < 1
-      result_format: "message",
-      incremental_output: true,
-    },
-  };
-
   const fetchOptions: RequestInit = {
     headers: {
       "Content-Type": "application/json",
       Authorization: req.headers.get("Authorization") ?? "",
-      "X-DashScope-SSE": stream ? "enable" : "disable",
+      "X-DashScope-SSE": req.headers.get("X-DashScope-SSE") ?? "disable",
     },
     method: req.method,
-    body: JSON.stringify(requestBody),
+    body: req.body,
     redirect: "manual",
     // @ts-ignore
     duplex: "half",
@@ -128,18 +108,23 @@ async function request(req: NextRequest) {
   // #1815 try to refuse some request to some models
   if (serverConfig.customModels && req.body) {
     try {
+      const clonedBody = await req.text();
+      fetchOptions.body = clonedBody;
+
+      const jsonBody = JSON.parse(clonedBody) as { model?: string };
+
       // not undefined and is false
       if (
         isModelAvailableInServer(
           serverConfig.customModels,
-          model as string,
-          ServiceProvider.Alibaba as string,
+          jsonBody?.model as string,
+          ServiceProvider.ByteDance as string,
         )
       ) {
         return NextResponse.json(
           {
             error: true,
-            message: `you are not allowed to use ${model} model`,
+            message: `you are not allowed to use ${jsonBody?.model} model`,
           },
           {
             status: 403,

+ 26 - 13
app/client/platforms/alibaba.ts

@@ -32,19 +32,25 @@ export interface OpenAIListModelResponse {
   }>;
 }
 
-interface RequestPayload {
+interface RequestInput {
   messages: {
     role: "system" | "user" | "assistant";
     content: string | MultimodalContent[];
   }[];
-  stream?: boolean;
-  model: string;
+}
+interface RequestParam {
+  result_format: string;
+  incremental_output?: boolean;
   temperature: number;
-  presence_penalty: number;
-  frequency_penalty: number;
+  repetition_penalty: number;
   top_p: number;
   max_tokens?: number;
 }
+interface RequestPayload {
+  model: string;
+  input: RequestInput;
+  parameters: RequestParam;
+}
 
 export class QwenApi implements LLMApi {
   path(path: string): string {
@@ -91,17 +97,21 @@ export class QwenApi implements LLMApi {
       },
     };
 
+    const shouldStream = !!options.config.stream;
     const requestPayload: RequestPayload = {
-      messages,
-      stream: options.config.stream,
       model: modelConfig.model,
-      temperature: modelConfig.temperature,
-      presence_penalty: modelConfig.presence_penalty,
-      frequency_penalty: modelConfig.frequency_penalty,
-      top_p: modelConfig.top_p,
+      input: {
+        messages,
+      },
+      parameters: {
+        result_format: "message",
+        incremental_output: shouldStream,
+        temperature: modelConfig.temperature,
+        // max_tokens: modelConfig.max_tokens,
+        top_p: modelConfig.top_p === 1 ? 0.99 : modelConfig.top_p, // qwen top_p is should be < 1
+      },
     };
 
-    const shouldStream = !!options.config.stream;
     const controller = new AbortController();
     options.onController?.(controller);
 
@@ -111,7 +121,10 @@ export class QwenApi implements LLMApi {
         method: "POST",
         body: JSON.stringify(requestPayload),
         signal: controller.signal,
-        headers: getHeaders(),
+        headers: {
+          ...getHeaders(),
+          "X-DashScope-SSE": shouldStream ? "enable" : "disable",
+        },
       };
 
       // make a fetch request