Forráskód Böngészése

feat: Add Stability API server relay sending

licoy 1 éve
szülő
commit
2b0153807c

+ 3 - 0
app/api/auth.ts

@@ -73,6 +73,9 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) {
       case ModelProvider.Claude:
         systemApiKey = serverConfig.anthropicApiKey;
         break;
+      case ModelProvider.Stability:
+        systemApiKey = serverConfig.stabilityApiKey;
+        break;
       case ModelProvider.GPT:
       default:
         if (serverConfig.isAzure) {

+ 104 - 0
app/api/stability/[...path]/route.ts

@@ -0,0 +1,104 @@
+import { NextRequest, NextResponse } from "next/server";
+import { getServerSideConfig } from "@/app/config/server";
+import { ModelProvider, STABILITY_BASE_URL } from "@/app/constant";
+import { auth } from "@/app/api/auth";
+
+async function handle(
+  req: NextRequest,
+  { params }: { params: { path: string[] } },
+) {
+  console.log("[Stability] params ", params);
+
+  if (req.method === "OPTIONS") {
+    return NextResponse.json({ body: "OK" }, { status: 200 });
+  }
+
+  const controller = new AbortController();
+
+  const serverConfig = getServerSideConfig();
+
+  let baseUrl = serverConfig.stabilityUrl || STABILITY_BASE_URL;
+
+  if (!baseUrl.startsWith("http")) {
+    baseUrl = `https://${baseUrl}`;
+  }
+
+  if (baseUrl.endsWith("/")) {
+    baseUrl = baseUrl.slice(0, -1);
+  }
+
+  let path = `${req.nextUrl.pathname}`.replaceAll("/api/stability/", "");
+
+  console.log("[Stability Proxy] ", path);
+  console.log("[Stability Base Url]", baseUrl);
+
+  const timeoutId = setTimeout(
+    () => {
+      controller.abort();
+    },
+    10 * 60 * 1000,
+  );
+
+  const authResult = auth(req, ModelProvider.Stability);
+
+  if (authResult.error) {
+    return NextResponse.json(authResult, {
+      status: 401,
+    });
+  }
+
+  const bearToken = req.headers.get("Authorization") ?? "";
+  const token = bearToken.trim().replaceAll("Bearer ", "").trim();
+
+  const key = token ? token : serverConfig.stabilityApiKey;
+
+  if (!key) {
+    return NextResponse.json(
+      {
+        error: true,
+        message: `missing STABILITY_API_KEY in server env vars`,
+      },
+      {
+        status: 401,
+      },
+    );
+  }
+
+  const fetchUrl = `${baseUrl}/${path}`;
+  console.log("[Stability Url] ", fetchUrl);
+  const fetchOptions: RequestInit = {
+    headers: {
+      "Content-Type": req.headers.get("Content-Type") || "multipart/form-data",
+      Accept: req.headers.get("Accept") || "application/json",
+      Authorization: `Bearer ${key}`,
+    },
+    method: req.method,
+    body: req.body,
+    // to fix #2485: https://stackoverflow.com/questions/55920957/cloudflare-worker-typeerror-one-time-use-body
+    redirect: "manual",
+    // @ts-ignore
+    duplex: "half",
+    signal: controller.signal,
+  };
+
+  try {
+    const res = await fetch(fetchUrl, fetchOptions);
+    // to prevent browser prompt for credentials
+    const newHeaders = new Headers(res.headers);
+    newHeaders.delete("www-authenticate");
+    // to disable nginx buffering
+    newHeaders.set("X-Accel-Buffering", "no");
+    return new Response(res.body, {
+      status: res.status,
+      statusText: res.statusText,
+      headers: newHeaders,
+    });
+  } finally {
+    clearTimeout(timeoutId);
+  }
+}
+
+export const GET = handle;
+export const POST = handle;
+
+export const runtime = "edge";

+ 4 - 14
app/components/sd-panel.tsx

@@ -307,22 +307,12 @@ export function SdPanel() {
       model_name: currentModel.name,
       status: "wait",
       params: reqParams,
-      created_at: new Date().toISOString(),
+      created_at: new Date().toLocaleString(),
       img_data: "",
     };
-    sdListDb.add(data).then(
-      (id) => {
-        data = { ...data, id, status: "running" };
-        sdListDb.update(data);
-        execCountInc();
-        sendSdTask(data, sdListDb, execCountInc);
-        setParams(getModelParamBasicData(columns, params, true));
-      },
-      (error) => {
-        console.error(error);
-        showToast(`error: ` + error.message);
-      },
-    );
+    sendSdTask(data, sdListDb, execCountInc, () => {
+      setParams(getModelParamBasicData(columns, params, true));
+    });
   };
   return (
     <>

+ 71 - 12
app/components/sd.tsx

@@ -9,7 +9,6 @@ import {
   copyToClipboard,
   getMessageTextContent,
   useMobileScreen,
-  useWindowSize,
 } from "@/app/utils";
 import { useNavigate } from "react-router-dom";
 import { useAppConfig } from "@/app/store";
@@ -22,14 +21,18 @@ import CopyIcon from "@/app/icons/copy.svg";
 import PromptIcon from "@/app/icons/prompt.svg";
 import ResetIcon from "@/app/icons/reload.svg";
 import { useIndexedDB } from "react-indexed-db-hook";
-import { useSdStore } from "@/app/store/sd";
+import { sendSdTask, useSdStore } from "@/app/store/sd";
 import locales from "@/app/locales";
 import LoadingIcon from "../icons/three-dots.svg";
 import ErrorIcon from "../icons/delete.svg";
 import { Property } from "csstype";
-import { showConfirm } from "@/app/components/ui-lib";
+import {
+  showConfirm,
+  showImageModal,
+  showModal,
+} from "@/app/components/ui-lib";
 
-function openBase64ImgUrl(base64Data: string, contentType: string) {
+function getBase64ImgUrl(base64Data: string, contentType: string) {
   const byteCharacters = atob(base64Data);
   const byteNumbers = new Array(byteCharacters.length);
   for (let i = 0; i < byteCharacters.length; i++) {
@@ -37,8 +40,7 @@ function openBase64ImgUrl(base64Data: string, contentType: string) {
   }
   const byteArray = new Uint8Array(byteNumbers);
   const blob = new Blob([byteArray], { type: contentType });
-  const blobUrl = URL.createObjectURL(blob);
-  window.open(blobUrl);
+  return URL.createObjectURL(blob);
 }
 
 function getSdTaskStatus(item: any) {
@@ -69,7 +71,24 @@ function getSdTaskStatus(item: any) {
       <span>
         {locales.Sd.Status.Name}: {s}
       </span>
-      {item.status === "error" && <span> - {item.error}</span>}
+      {item.status === "error" && (
+        <span
+          className="clickable"
+          onClick={() => {
+            showModal({
+              title: locales.Sd.Detail,
+              children: (
+                <div style={{ color: color, userSelect: "text" }}>
+                  {item.error}
+                </div>
+              ),
+            });
+          }}
+        >
+          {" "}
+          - {item.error}
+        </span>
+      )}
     </p>
   );
 }
@@ -83,7 +102,7 @@ export function Sd() {
   const scrollRef = useRef<HTMLDivElement>(null);
   const sdListDb = useIndexedDB(StoreKey.SdList);
   const [sdImages, setSdImages] = useState([]);
-  const { execCount } = useSdStore();
+  const { execCount, execCountInc } = useSdStore();
 
   useEffect(() => {
     sdListDb.getAll().then((data) => {
@@ -145,7 +164,10 @@ export function Sd() {
                       src={`data:image/png;base64,${item.img_data}`}
                       alt={`${item.id}`}
                       onClick={(e) => {
-                        openBase64ImgUrl(item.img_data, "image/png");
+                        showImageModal(
+                          getBase64ImgUrl(item.img_data, "image/png"),
+                          true,
+                        );
                       }}
                     />
                   ) : item.status === "error" ? (
@@ -163,7 +185,20 @@ export function Sd() {
                   >
                     <p className={styles["line-1"]}>
                       {locales.SdPanel.Prompt}:{" "}
-                      <span title={item.params.prompt}>
+                      <span
+                        className="clickable"
+                        title={item.params.prompt}
+                        onClick={() => {
+                          showModal({
+                            title: locales.Sd.Detail,
+                            children: (
+                              <div style={{ userSelect: "text" }}>
+                                {item.params.prompt}
+                              </div>
+                            ),
+                          });
+                        }}
+                      >
                         {item.params.prompt}
                       </span>
                     </p>
@@ -177,7 +212,21 @@ export function Sd() {
                         <ChatAction
                           text={Locale.Sd.Actions.Params}
                           icon={<PromptIcon />}
-                          onClick={() => console.log(1)}
+                          onClick={() => {
+                            showModal({
+                              title: locales.Sd.GenerateParams,
+                              children: (
+                                <div style={{ userSelect: "text" }}>
+                                  {Object.keys(item.params).map((key) => (
+                                    <div key={key} style={{ margin: "10px" }}>
+                                      <strong>{key}: </strong>
+                                      {item.params[key]}
+                                    </div>
+                                  ))}
+                                </div>
+                              ),
+                            });
+                          }}
                         />
                         <ChatAction
                           text={Locale.Sd.Actions.Copy}
@@ -194,7 +243,17 @@ export function Sd() {
                         <ChatAction
                           text={Locale.Sd.Actions.Retry}
                           icon={<ResetIcon />}
-                          onClick={() => console.log(1)}
+                          onClick={() => {
+                            const reqData = {
+                              model: item.model,
+                              model_name: item.model_name,
+                              status: "wait",
+                              params: { ...item.params },
+                              created_at: new Date().toLocaleString(),
+                              img_data: "",
+                            };
+                            sendSdTask(reqData, sdListDb, execCountInc);
+                          }}
                         />
                         <ChatAction
                           text={Locale.Sd.Actions.Delete}

+ 3 - 2
app/components/ui-lib.tsx

@@ -425,11 +425,12 @@ export function showPrompt(content: any, value = "", rows = 3) {
   });
 }
 
-export function showImageModal(img: string) {
+export function showImageModal(img: string, defaultMax?: boolean) {
   showModal({
     title: Locale.Export.Image.Modal,
+    defaultMax: defaultMax,
     children: (
-      <div>
+      <div style={{ display: "flex", justifyContent: "center" }}>
         <img
           src={img}
           alt="preview"

+ 3 - 0
app/config/server.ts

@@ -124,6 +124,9 @@ export const getServerSideConfig = () => {
     anthropicApiVersion: process.env.ANTHROPIC_API_VERSION,
     anthropicUrl: process.env.ANTHROPIC_URL,
 
+    stabilityUrl: process.env.STABILITY_URL,
+    stabilityApiKey: process.env.STABILITY_API_KEY,
+
     gtmId: process.env.GTM_ID,
 
     needCode: ACCESS_CODES.size > 0,

+ 8 - 0
app/constant.ts

@@ -1,3 +1,5 @@
+import { stabilityRequestCall } from "@/app/store/sd";
+
 export const OWNER = "Yidadaa";
 export const REPO = "ChatGPT-Next-Web";
 export const REPO_URL = `https://github.com/${OWNER}/${REPO}`;
@@ -13,6 +15,7 @@ export const OPENAI_BASE_URL = "https://api.openai.com";
 export const ANTHROPIC_BASE_URL = "https://api.anthropic.com";
 
 export const GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/";
+export const STABILITY_BASE_URL = "https://api.stability.ai";
 
 export enum Path {
   Home = "/",
@@ -79,6 +82,7 @@ export enum ModelProvider {
   GPT = "GPT",
   GeminiPro = "GeminiPro",
   Claude = "Claude",
+  Stability = "Stability",
 }
 
 export const Anthropic = {
@@ -104,6 +108,10 @@ export const Google = {
   ChatPath: (modelName: string) => `v1beta/models/${modelName}:generateContent`,
 };
 
+export const StabilityPath = {
+  GeneratePath: "v2beta/stable-image/generate",
+};
+
 export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang
 // export const DEFAULT_SYSTEM_TEMPLATE = `
 // You are ChatGPT, a large language model trained by {{ServiceProvider}}.

+ 2 - 0
app/locales/cn.ts

@@ -534,6 +534,8 @@ const cn = {
     Danger: {
       Delete: "确认删除?",
     },
+    GenerateParams: "生成参数",
+    Detail: "详情",
   },
 };
 

+ 2 - 0
app/locales/en.ts

@@ -540,6 +540,8 @@ const en: LocaleType = {
     Danger: {
       Delete: "Confirm to delete?",
     },
+    GenerateParams: "Generate Params",
+    Detail: "Detail",
   },
 };
 

+ 20 - 3
app/store/sd.ts

@@ -1,6 +1,7 @@
 import { initDB, useIndexedDB } from "react-indexed-db-hook";
-import { StoreKey } from "@/app/constant";
+import { StabilityPath, StoreKey } from "@/app/constant";
 import { create, StoreApi } from "zustand";
+import { showToast } from "@/app/components/ui-lib";
 
 export const SdDbConfig = {
   name: "@chatgpt-next-web/sd",
@@ -44,12 +45,28 @@ export const useSdStore = create<SdStore>()((set) => ({
   execCountInc: () => set((state) => ({ execCount: state.execCount + 1 })),
 }));
 
-export function sendSdTask(data: any, db: any, inc: any) {
+export function sendSdTask(data: any, db: any, inc: any, okCall?: Function) {
+  db.add(data).then(
+    (id: number) => {
+      data = { ...data, id, status: "running" };
+      db.update(data);
+      inc();
+      stabilityRequestCall(data, db, inc);
+      okCall?.();
+    },
+    (error: any) => {
+      console.error(error);
+      showToast(`error: ` + error.message);
+    },
+  );
+}
+
+export function stabilityRequestCall(data: any, db: any, inc: any) {
   const formData = new FormData();
   for (let paramsKey in data.params) {
     formData.append(paramsKey, data.params[paramsKey]);
   }
-  fetch("https://api.stability.ai/v2beta/stable-image/generate/" + data.model, {
+  fetch(`/api/stability/${StabilityPath.GeneratePath}/${data.model}`, {
     method: "POST",
     headers: {
       Accept: "application/json",