|
|
@@ -1,99 +1,122 @@
|
|
|
-import { initDB, useIndexedDB } from "react-indexed-db-hook";
|
|
|
import { StabilityPath, StoreKey } from "@/app/constant";
|
|
|
-import { create, StoreApi } from "zustand";
|
|
|
import { showToast } from "@/app/components/ui-lib";
|
|
|
import { getHeaders } from "@/app/client/api";
|
|
|
+import { createPersistStore } from "@/app/utils/store";
|
|
|
+import { nanoid } from "nanoid";
|
|
|
+import { uploadImage, base64Image2Blob } from "@/app/utils/chat";
|
|
|
|
|
|
-export const SdDbConfig = {
|
|
|
- name: "@chatgpt-next-web/sd",
|
|
|
- version: 1,
|
|
|
- objectStoresMeta: [
|
|
|
- {
|
|
|
- store: StoreKey.SdList,
|
|
|
- storeConfig: { keyPath: "id", autoIncrement: true },
|
|
|
- storeSchema: [
|
|
|
- { name: "model", keypath: "model", options: { unique: false } },
|
|
|
- {
|
|
|
- name: "model_name",
|
|
|
- keypath: "model_name",
|
|
|
- options: { unique: false },
|
|
|
- },
|
|
|
- { name: "status", keypath: "status", options: { unique: false } },
|
|
|
- { name: "params", keypath: "params", options: { unique: false } },
|
|
|
- { name: "img_data", keypath: "img_data", options: { unique: false } },
|
|
|
- { name: "error", keypath: "error", options: { unique: false } },
|
|
|
- {
|
|
|
- name: "created_at",
|
|
|
- keypath: "created_at",
|
|
|
- options: { unique: false },
|
|
|
- },
|
|
|
- ],
|
|
|
- },
|
|
|
- ],
|
|
|
-};
|
|
|
-
|
|
|
-export function SdDbInit() {
|
|
|
- initDB(SdDbConfig);
|
|
|
-}
|
|
|
-
|
|
|
-type SdStore = {
|
|
|
- execCount: number;
|
|
|
- execCountInc: () => void;
|
|
|
-};
|
|
|
-
|
|
|
-export const useSdStore = create<SdStore>()((set) => ({
|
|
|
- execCount: 1,
|
|
|
- execCountInc: () => set((state) => ({ execCount: state.execCount + 1 })),
|
|
|
-}));
|
|
|
+export const useSdStore = createPersistStore<
|
|
|
+ {
|
|
|
+ currentId: number;
|
|
|
+ draw: any[];
|
|
|
+ },
|
|
|
+ {
|
|
|
+ getNextId: () => number;
|
|
|
+ sendTask: (data: any, okCall?: Function) => void;
|
|
|
+ updateDraw: (draw: any) => void;
|
|
|
+ }
|
|
|
+>(
|
|
|
+ {
|
|
|
+ currentId: 0,
|
|
|
+ draw: [],
|
|
|
+ },
|
|
|
+ (set, _get) => {
|
|
|
+ function get() {
|
|
|
+ return {
|
|
|
+ ..._get(),
|
|
|
+ ...methods,
|
|
|
+ };
|
|
|
+ }
|
|
|
|
|
|
-export function sendSdTask(data: any, db: any, inc: any, okCall?: Function) {
|
|
|
- db.add(data).then(
|
|
|
- (id: number) => {
|
|
|
- data = { ...data, id, status: "running" };
|
|
|
- db.update(data);
|
|
|
- inc();
|
|
|
- stabilityRequestCall(data, db, inc);
|
|
|
- okCall?.();
|
|
|
- },
|
|
|
- (error: any) => {
|
|
|
- console.error(error);
|
|
|
- showToast(`error: ` + error.message);
|
|
|
- },
|
|
|
- );
|
|
|
-}
|
|
|
+ const methods = {
|
|
|
+ getNextId() {
|
|
|
+ const id = ++_get().currentId;
|
|
|
+ set({ currentId: id });
|
|
|
+ return id;
|
|
|
+ },
|
|
|
+ sendTask(data: any, okCall?: Function) {
|
|
|
+ data = { ...data, id: nanoid(), status: "running" };
|
|
|
+ set({ draw: [data, ..._get().draw] });
|
|
|
+ this.getNextId();
|
|
|
+ this.stabilityRequestCall(data);
|
|
|
+ okCall?.();
|
|
|
+ },
|
|
|
+ stabilityRequestCall(data: any) {
|
|
|
+ const formData = new FormData();
|
|
|
+ for (let paramsKey in data.params) {
|
|
|
+ formData.append(paramsKey, data.params[paramsKey]);
|
|
|
+ }
|
|
|
+ const headers = getHeaders();
|
|
|
+ delete headers["Content-Type"];
|
|
|
+ fetch(`/api/stability/${StabilityPath.GeneratePath}/${data.model}`, {
|
|
|
+ method: "POST",
|
|
|
+ headers: {
|
|
|
+ ...headers,
|
|
|
+ Accept: "application/json",
|
|
|
+ },
|
|
|
+ body: formData,
|
|
|
+ })
|
|
|
+ .then((response) => response.json())
|
|
|
+ .then((resData) => {
|
|
|
+ if (resData.errors && resData.errors.length > 0) {
|
|
|
+ this.updateDraw({
|
|
|
+ ...data,
|
|
|
+ status: "error",
|
|
|
+ error: resData.errors[0],
|
|
|
+ });
|
|
|
+ this.getNextId();
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ if (resData.finish_reason === "SUCCESS") {
|
|
|
+ const self = this;
|
|
|
+ uploadImage(base64Image2Blob(resData.image, "image/png"))
|
|
|
+ .then((img_data) => {
|
|
|
+ console.debug("uploadImage success", img_data, self);
|
|
|
+ self.updateDraw({
|
|
|
+ ...data,
|
|
|
+ status: "success",
|
|
|
+ img_data,
|
|
|
+ });
|
|
|
+ })
|
|
|
+ .catch((e) => {
|
|
|
+ console.error("uploadImage error", e);
|
|
|
+ self.updateDraw({
|
|
|
+ ...data,
|
|
|
+ status: "error",
|
|
|
+ error: JSON.stringify(resData),
|
|
|
+ });
|
|
|
+ });
|
|
|
+ } else {
|
|
|
+ self.updateDraw({
|
|
|
+ ...data,
|
|
|
+ status: "error",
|
|
|
+ error: JSON.stringify(resData),
|
|
|
+ });
|
|
|
+ }
|
|
|
+ this.getNextId();
|
|
|
+ })
|
|
|
+ .catch((error) => {
|
|
|
+ this.updateDraw({ ...data, status: "error", error: error.message });
|
|
|
+ console.error("Error:", error);
|
|
|
+ this.getNextId();
|
|
|
+ });
|
|
|
+ },
|
|
|
+ updateDraw(_draw: any) {
|
|
|
+ const draw = _get().draw || [];
|
|
|
+ draw.some((item, index) => {
|
|
|
+ if (item.id === _draw.id) {
|
|
|
+ draw[index] = _draw;
|
|
|
+ set(() => ({ draw }));
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ });
|
|
|
+ },
|
|
|
+ };
|
|
|
|
|
|
-export function stabilityRequestCall(data: any, db: any, inc: any) {
|
|
|
- const formData = new FormData();
|
|
|
- for (let paramsKey in data.params) {
|
|
|
- formData.append(paramsKey, data.params[paramsKey]);
|
|
|
- }
|
|
|
- const headers = getHeaders();
|
|
|
- delete headers["Content-Type"];
|
|
|
- fetch(`/api/stability/${StabilityPath.GeneratePath}/${data.model}`, {
|
|
|
- method: "POST",
|
|
|
- headers: {
|
|
|
- ...headers,
|
|
|
- Accept: "application/json",
|
|
|
- },
|
|
|
- body: formData,
|
|
|
- })
|
|
|
- .then((response) => response.json())
|
|
|
- .then((resData) => {
|
|
|
- if (resData.errors && resData.errors.length > 0) {
|
|
|
- db.update({ ...data, status: "error", error: resData.errors[0] });
|
|
|
- inc();
|
|
|
- return;
|
|
|
- }
|
|
|
- if (resData.finish_reason === "SUCCESS") {
|
|
|
- db.update({ ...data, status: "success", img_data: resData.image });
|
|
|
- } else {
|
|
|
- db.update({ ...data, status: "error", error: JSON.stringify(resData) });
|
|
|
- }
|
|
|
- inc();
|
|
|
- })
|
|
|
- .catch((error) => {
|
|
|
- db.update({ ...data, status: "error", error: error.message });
|
|
|
- console.error("Error:", error);
|
|
|
- inc();
|
|
|
- });
|
|
|
-}
|
|
|
+ return methods;
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: StoreKey.SdList,
|
|
|
+ version: 1.0,
|
|
|
+ },
|
|
|
+);
|