Prechádzať zdrojové kódy

feat: move sd config to store

Dogtiti 1 rok pred
rodič
commit
82e6fd7bb5

+ 42 - 0
app/components/sd/sd-new.tsx

@@ -17,6 +17,9 @@ import {
   useDragSideBar,
   useHotKey,
 } from "@/app/components/sidebar";
+import { getParams, getModelParamBasicData } from "./sd-panel";
+import { useSdStore } from "@/app/store/sd";
+import { showToast } from "@/app/components/ui-lib";
 
 const SdPanel = dynamic(
   async () => (await import("@/app/components/sd")).SdPanel,
@@ -29,6 +32,37 @@ export function SdNew() {
   useHotKey();
   const { onDragStart, shouldNarrow } = useDragSideBar();
   const navigate = useNavigate();
+  const sdStore = useSdStore();
+  const currentModel = sdStore.currentModel;
+  const params = sdStore.currentParams;
+  const setParams = sdStore.setCurrentParams;
+
+  const handleSubmit = () => {
+    const columns = getParams?.(currentModel, params);
+    const reqParams: any = {};
+    for (let i = 0; i < columns.length; i++) {
+      const item = columns[i];
+      reqParams[item.value] = params[item.value] ?? null;
+      if (item.required) {
+        if (!reqParams[item.value]) {
+          showToast(Locale.SdPanel.ParamIsRequired(item.name));
+          return;
+        }
+      }
+    }
+    let data: any = {
+      model: currentModel.value,
+      model_name: currentModel.name,
+      status: "wait",
+      params: reqParams,
+      created_at: new Date().toLocaleString(),
+      img_data: "",
+    };
+    sdStore.sendTask(data, () => {
+      setParams(getModelParamBasicData(columns, params, true));
+      navigate(Path.Sd);
+    });
+  };
   return (
     <SideBarContainer
       onDragStart={onDragStart}
@@ -79,6 +113,14 @@ export function SdNew() {
             <IconButton icon={<GithubIcon />} shadow />
           </a>
         }
+        secondaryAction={
+          <IconButton
+            text={Locale.SdPanel.Submit}
+            type="primary"
+            shadow
+            onClick={handleSubmit}
+          ></IconButton>
+        }
       />
     </SideBarContainer>
   );

+ 114 - 140
app/components/sd/sd-panel.tsx

@@ -1,110 +1,110 @@
 import styles from "./sd-panel.module.scss";
-import React, { useState } from "react";
-import { Select, showToast } from "@/app/components/ui-lib";
+import React from "react";
+import { Select } from "@/app/components/ui-lib";
 import { IconButton } from "@/app/components/button";
 import Locale from "@/app/locales";
-import { nanoid } from "nanoid";
-import { StoreKey } from "@/app/constant";
 import { useSdStore } from "@/app/store/sd";
 
+export const params = [
+  {
+    name: Locale.SdPanel.Prompt,
+    value: "prompt",
+    type: "textarea",
+    placeholder: Locale.SdPanel.PleaseInput(Locale.SdPanel.Prompt),
+    required: true,
+  },
+  {
+    name: Locale.SdPanel.ModelVersion,
+    value: "model",
+    type: "select",
+    default: "sd3-medium",
+    support: ["sd3"],
+    options: [
+      { name: "SD3 Medium", value: "sd3-medium" },
+      { name: "SD3 Large", value: "sd3-large" },
+      { name: "SD3 Large Turbo", value: "sd3-large-turbo" },
+    ],
+  },
+  {
+    name: Locale.SdPanel.NegativePrompt,
+    value: "negative_prompt",
+    type: "textarea",
+    placeholder: Locale.SdPanel.PleaseInput(Locale.SdPanel.NegativePrompt),
+  },
+  {
+    name: Locale.SdPanel.AspectRatio,
+    value: "aspect_ratio",
+    type: "select",
+    default: "1:1",
+    options: [
+      { name: "1:1", value: "1:1" },
+      { name: "16:9", value: "16:9" },
+      { name: "21:9", value: "21:9" },
+      { name: "2:3", value: "2:3" },
+      { name: "3:2", value: "3:2" },
+      { name: "4:5", value: "4:5" },
+      { name: "5:4", value: "5:4" },
+      { name: "9:16", value: "9:16" },
+      { name: "9:21", value: "9:21" },
+    ],
+  },
+  {
+    name: Locale.SdPanel.ImageStyle,
+    value: "style",
+    type: "select",
+    default: "3d-model",
+    support: ["core"],
+    options: [
+      { name: Locale.SdPanel.Styles.D3Model, value: "3d-model" },
+      { name: Locale.SdPanel.Styles.AnalogFilm, value: "analog-film" },
+      { name: Locale.SdPanel.Styles.Anime, value: "anime" },
+      { name: Locale.SdPanel.Styles.Cinematic, value: "cinematic" },
+      { name: Locale.SdPanel.Styles.ComicBook, value: "comic-book" },
+      { name: Locale.SdPanel.Styles.DigitalArt, value: "digital-art" },
+      { name: Locale.SdPanel.Styles.Enhance, value: "enhance" },
+      { name: Locale.SdPanel.Styles.FantasyArt, value: "fantasy-art" },
+      { name: Locale.SdPanel.Styles.Isometric, value: "isometric" },
+      { name: Locale.SdPanel.Styles.LineArt, value: "line-art" },
+      { name: Locale.SdPanel.Styles.LowPoly, value: "low-poly" },
+      {
+        name: Locale.SdPanel.Styles.ModelingCompound,
+        value: "modeling-compound",
+      },
+      { name: Locale.SdPanel.Styles.NeonPunk, value: "neon-punk" },
+      { name: Locale.SdPanel.Styles.Origami, value: "origami" },
+      { name: Locale.SdPanel.Styles.Photographic, value: "photographic" },
+      { name: Locale.SdPanel.Styles.PixelArt, value: "pixel-art" },
+      { name: Locale.SdPanel.Styles.TileTexture, value: "tile-texture" },
+    ],
+  },
+  {
+    name: "Seed",
+    value: "seed",
+    type: "number",
+    default: 0,
+    min: 0,
+    max: 4294967294,
+  },
+  {
+    name: Locale.SdPanel.OutFormat,
+    value: "output_format",
+    type: "select",
+    default: "png",
+    options: [
+      { name: "PNG", value: "png" },
+      { name: "JPEG", value: "jpeg" },
+      { name: "WebP", value: "webp" },
+    ],
+  },
+];
+
 const sdCommonParams = (model: string, data: any) => {
-  return [
-    {
-      name: Locale.SdPanel.Prompt,
-      value: "prompt",
-      type: "textarea",
-      placeholder: Locale.SdPanel.PleaseInput(Locale.SdPanel.Prompt),
-      required: true,
-    },
-    {
-      name: Locale.SdPanel.ModelVersion,
-      value: "model",
-      type: "select",
-      default: "sd3-medium",
-      support: ["sd3"],
-      options: [
-        { name: "SD3 Medium", value: "sd3-medium" },
-        { name: "SD3 Large", value: "sd3-large" },
-        { name: "SD3 Large Turbo", value: "sd3-large-turbo" },
-      ],
-    },
-    {
-      name: Locale.SdPanel.NegativePrompt,
-      value: "negative_prompt",
-      type: "textarea",
-      placeholder: Locale.SdPanel.PleaseInput(Locale.SdPanel.NegativePrompt),
-    },
-    {
-      name: Locale.SdPanel.AspectRatio,
-      value: "aspect_ratio",
-      type: "select",
-      default: "1:1",
-      options: [
-        { name: "1:1", value: "1:1" },
-        { name: "16:9", value: "16:9" },
-        { name: "21:9", value: "21:9" },
-        { name: "2:3", value: "2:3" },
-        { name: "3:2", value: "3:2" },
-        { name: "4:5", value: "4:5" },
-        { name: "5:4", value: "5:4" },
-        { name: "9:16", value: "9:16" },
-        { name: "9:21", value: "9:21" },
-      ],
-    },
-    {
-      name: Locale.SdPanel.ImageStyle,
-      value: "style",
-      type: "select",
-      default: "3d",
-      support: ["core"],
-      options: [
-        { name: Locale.SdPanel.Styles.D3Model, value: "3d-model" },
-        { name: Locale.SdPanel.Styles.AnalogFilm, value: "analog-film" },
-        { name: Locale.SdPanel.Styles.Anime, value: "anime" },
-        { name: Locale.SdPanel.Styles.Cinematic, value: "cinematic" },
-        { name: Locale.SdPanel.Styles.ComicBook, value: "comic-book" },
-        { name: Locale.SdPanel.Styles.DigitalArt, value: "digital-art" },
-        { name: Locale.SdPanel.Styles.Enhance, value: "enhance" },
-        { name: Locale.SdPanel.Styles.FantasyArt, value: "fantasy-art" },
-        { name: Locale.SdPanel.Styles.Isometric, value: "isometric" },
-        { name: Locale.SdPanel.Styles.LineArt, value: "line-art" },
-        { name: Locale.SdPanel.Styles.LowPoly, value: "low-poly" },
-        {
-          name: Locale.SdPanel.Styles.ModelingCompound,
-          value: "modeling-compound",
-        },
-        { name: Locale.SdPanel.Styles.NeonPunk, value: "neon-punk" },
-        { name: Locale.SdPanel.Styles.Origami, value: "origami" },
-        { name: Locale.SdPanel.Styles.Photographic, value: "photographic" },
-        { name: Locale.SdPanel.Styles.PixelArt, value: "pixel-art" },
-        { name: Locale.SdPanel.Styles.TileTexture, value: "tile-texture" },
-      ],
-    },
-    {
-      name: "Seed",
-      value: "seed",
-      type: "number",
-      default: 0,
-      min: 0,
-      max: 4294967294,
-    },
-    {
-      name: Locale.SdPanel.OutFormat,
-      value: "output_format",
-      type: "select",
-      default: "png",
-      options: [
-        { name: "PNG", value: "png" },
-        { name: "JPEG", value: "jpeg" },
-        { name: "WebP", value: "webp" },
-      ],
-    },
-  ].filter((item) => {
+  return params.filter((item) => {
     return !(item.support && !item.support.includes(model));
   });
 };
 
-const models = [
+export const models = [
   {
     name: "Stable Image Ultra",
     value: "ultra",
@@ -162,7 +162,7 @@ export function ControlParam(props: {
 }) {
   return (
     <>
-      {props.columns.map((item) => {
+      {props.columns?.map((item) => {
         let element: null | JSX.Element;
         switch (item.type) {
           case "textarea":
@@ -251,7 +251,7 @@ export function ControlParam(props: {
   );
 }
 
-const getModelParamBasicData = (
+export const getModelParamBasicData = (
   columns: any[],
   data: any,
   clearText?: boolean,
@@ -268,47 +268,28 @@ const getModelParamBasicData = (
   return newParams;
 };
 
+export const getParams = (model: any, params: any) => {
+  return models.find((m) => m.value === model.value)?.params(params) || [];
+};
+
 export function SdPanel() {
-  const [currentModel, setCurrentModel] = useState(models[0]);
-  const [params, setParams] = useState(
-    getModelParamBasicData(currentModel.params({}), {}),
-  );
+  const sdStore = useSdStore();
+  const currentModel = sdStore.currentModel;
+  const setCurrentModel = sdStore.setCurrentModel;
+  const params = sdStore.currentParams;
+  const setParams = sdStore.setCurrentParams;
+
   const handleValueChange = (field: string, val: any) => {
-    setParams((prevParams: any) => ({
-      ...prevParams,
+    setParams({
+      ...params,
       [field]: val,
-    }));
+    });
   };
   const handleModelChange = (model: any) => {
     setCurrentModel(model);
     setParams(getModelParamBasicData(model.params({}), params));
   };
-  const sdStore = useSdStore();
-  const handleSubmit = () => {
-    const columns = currentModel.params(params);
-    const reqParams: any = {};
-    for (let i = 0; i < columns.length; i++) {
-      const item = columns[i];
-      reqParams[item.value] = params[item.value] ?? null;
-      if (item.required) {
-        if (!reqParams[item.value]) {
-          showToast(Locale.SdPanel.ParamIsRequired(item.name));
-          return;
-        }
-      }
-    }
-    let data: any = {
-      model: currentModel.value,
-      model_name: currentModel.name,
-      status: "wait",
-      params: reqParams,
-      created_at: new Date().toLocaleString(),
-      img_data: "",
-    };
-    sdStore.sendTask(data, () => {
-      setParams(getModelParamBasicData(columns, params, true));
-    });
-  };
+
   return (
     <>
       <ControlParamItem title={Locale.SdPanel.AIModel}>
@@ -327,17 +308,10 @@ export function SdPanel() {
         </div>
       </ControlParamItem>
       <ControlParam
-        columns={currentModel.params(params) as any[]}
+        columns={getParams?.(currentModel, params) as any[]}
         data={params}
         onChange={handleValueChange}
       ></ControlParam>
-      <IconButton
-        text={Locale.SdPanel.Submit}
-        type="primary"
-        style={{ marginTop: "20px" }}
-        shadow
-        onClick={handleSubmit}
-      ></IconButton>
     </>
   );
 }

+ 42 - 0
app/components/sd/sd-sidebar.tsx

@@ -17,6 +17,10 @@ import {
   useHotKey,
 } from "@/app/components/sidebar";
 
+import { getParams, getModelParamBasicData } from "./sd-panel";
+import { useSdStore } from "@/app/store/sd";
+import { showToast } from "@/app/components/ui-lib";
+
 const SdPanel = dynamic(
   async () => (await import("@/app/components/sd")).SdPanel,
   {
@@ -28,6 +32,36 @@ export function SideBar(props: { className?: string }) {
   useHotKey();
   const { onDragStart, shouldNarrow } = useDragSideBar();
   const navigate = useNavigate();
+  const sdStore = useSdStore();
+  const currentModel = sdStore.currentModel;
+  const params = sdStore.currentParams;
+  const setParams = sdStore.setCurrentParams;
+
+  const handleSubmit = () => {
+    const columns = getParams?.(currentModel, params);
+    const reqParams: any = {};
+    for (let i = 0; i < columns.length; i++) {
+      const item = columns[i];
+      reqParams[item.value] = params[item.value] ?? null;
+      if (item.required) {
+        if (!reqParams[item.value]) {
+          showToast(Locale.SdPanel.ParamIsRequired(item.name));
+          return;
+        }
+      }
+    }
+    let data: any = {
+      model: currentModel.value,
+      model_name: currentModel.name,
+      status: "wait",
+      params: reqParams,
+      created_at: new Date().toLocaleString(),
+      img_data: "",
+    };
+    sdStore.sendTask(data, () => {
+      setParams(getModelParamBasicData(columns, params, true));
+    });
+  };
 
   return (
     <SideBarContainer
@@ -55,6 +89,14 @@ export function SideBar(props: { className?: string }) {
             <IconButton icon={<GithubIcon />} shadow />
           </a>
         }
+        secondaryAction={
+          <IconButton
+            text={Locale.SdPanel.Submit}
+            type="primary"
+            shadow
+            onClick={handleSubmit}
+          ></IconButton>
+        }
       />
     </SideBarContainer>
   );

+ 48 - 9
app/components/sd/sd.tsx

@@ -33,6 +33,7 @@ import {
 import { removeImage } from "@/app/utils/chat";
 import { SideBar } from "./sd-sidebar";
 import { WindowContent } from "@/app/components/home";
+import { params } from "./sd-panel";
 
 function getSdTaskStatus(item: any) {
   let s: string;
@@ -216,15 +217,53 @@ export function Sd() {
                                   title: locales.Sd.GenerateParams,
                                   children: (
                                     <div style={{ userSelect: "text" }}>
-                                      {Object.keys(item.params).map((key) => (
-                                        <div
-                                          key={key}
-                                          style={{ margin: "10px" }}
-                                        >
-                                          <strong>{key}: </strong>
-                                          {item.params[key]}
-                                        </div>
-                                      ))}
+                                      {Object.keys(item.params).map((key) => {
+                                        let label = key;
+                                        let value = item.params[key];
+                                        switch (label) {
+                                          case "prompt":
+                                            label = Locale.SdPanel.Prompt;
+                                            break;
+                                          case "negative_prompt":
+                                            label =
+                                              Locale.SdPanel.NegativePrompt;
+                                            break;
+                                          case "aspect_ratio":
+                                            label = Locale.SdPanel.AspectRatio;
+                                            break;
+                                          case "seed":
+                                            label = "Seed";
+                                            value = value || 0;
+                                            break;
+                                          case "output_format":
+                                            label = Locale.SdPanel.OutFormat;
+                                            value = value?.toUpperCase();
+                                            break;
+                                          case "style":
+                                            label = Locale.SdPanel.ImageStyle;
+                                            value = params
+                                              .find(
+                                                (item) =>
+                                                  item.value === "style",
+                                              )
+                                              ?.options?.find(
+                                                (item) => item.value === value,
+                                              )?.name;
+                                            break;
+                                          default:
+                                            break;
+                                        }
+
+                                        return (
+                                          <div
+                                            key={key}
+                                            style={{ margin: "10px" }}
+                                          >
+                                            <strong>{label}: </strong>
+                                            {value}
+                                          </div>
+                                        );
+                                      })}
                                     </div>
                                   ),
                                 });

+ 28 - 5
app/store/sd.ts

@@ -1,25 +1,40 @@
 import { StabilityPath, StoreKey } from "@/app/constant";
-import { showToast } from "@/app/components/ui-lib";
 import { getHeaders } from "@/app/client/api";
 import { createPersistStore } from "@/app/utils/store";
 import { nanoid } from "nanoid";
 import { uploadImage, base64Image2Blob } from "@/app/utils/chat";
+import { models, getModelParamBasicData } from "@/app/components/sd/sd-panel";
+
+const defaultModel = {
+  name: models[0].name,
+  value: models[0].value,
+};
+
+const defaultParams = getModelParamBasicData(models[0].params({}), {});
+
+const DEFAULT_SD_STATE = {
+  currentId: 0,
+  draw: [],
+  currentModel: defaultModel,
+  currentParams: defaultParams,
+};
 
 export const useSdStore = createPersistStore<
   {
     currentId: number;
     draw: any[];
+    currentModel: typeof defaultModel;
+    currentParams: any;
   },
   {
     getNextId: () => number;
     sendTask: (data: any, okCall?: Function) => void;
     updateDraw: (draw: any) => void;
+    setCurrentModel: (model: any) => void;
+    setCurrentParams: (data: any) => void;
   }
 >(
-  {
-    currentId: 0,
-    draw: [],
-  },
+  DEFAULT_SD_STATE,
   (set, _get) => {
     function get() {
       return {
@@ -111,6 +126,14 @@ export const useSdStore = createPersistStore<
           }
         });
       },
+      setCurrentModel(model: any) {
+        set({ currentModel: model });
+      },
+      setCurrentParams(data: any) {
+        set({
+          currentParams: data,
+        });
+      },
     };
 
     return methods;