sd.ts 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. import {
  2. Stability,
  3. StoreKey,
  4. ACCESS_CODE_PREFIX,
  5. ApiPath,
  6. } from "@/app/constant";
  7. import { getBearerToken } from "@/app/client/api";
  8. import { createPersistStore } from "@/app/utils/store";
  9. import { nanoid } from "nanoid";
  10. import { uploadImage, base64Image2Blob } from "@/app/utils/chat";
  11. import { models, getModelParamBasicData } from "@/app/components/sd/sd-panel";
  12. import { useAccessStore } from "./access";
  /** First entry of the imported `models` list; every default below is derived from it. */
  const initialModel = models[0];

  /** Name/value pair of the model that is selected before the user picks another one. */
  const defaultModel = {
    name: initialModel.name,
    value: initialModel.value,
  };

  /**
   * Result of `getModelParamBasicData` for the initial model's parameter
   * definitions, which are built from an empty input object.
   */
  const defaultParams = getModelParamBasicData(initialModel.params({}), {});

  /** State the store starts from: no draw tasks yet, counter at zero, defaults selected. */
  const DEFAULT_SD_STATE = {
    currentId: 0,
    draw: [],
    currentModel: defaultModel,
    currentParams: defaultParams,
  };
  /**
   * Store for image-generation ("draw") tasks sent to the Stability endpoint,
   * created through `createPersistStore` under the `StoreKey.SdList` key.
   *
   * State:
   * - `currentId`     counter bumped by `getNextId`
   * - `draw`          task list, newest first
   * - `currentModel`  currently selected model (name/value)
   * - `currentParams` parameter values for the selected model
   */
  export const useSdStore = createPersistStore<
    {
      currentId: number;
      draw: any[];
      currentModel: typeof defaultModel;
      currentParams: any;
    },
    {
      getNextId: () => number;
      sendTask: (data: any, okCall?: Function) => void;
      updateDraw: (draw: any) => void;
      setCurrentModel: (model: any) => void;
      setCurrentParams: (data: any) => void;
    }
  >(
    DEFAULT_SD_STATE,
    (set, _get) => {
      /**
       * Turns an arbitrary rejection value into a readable string.
       * `JSON.stringify` on an `Error` yields "{}", so the message is read
       * explicitly for `Error` instances.
       */
      function describeError(err: unknown): string {
        if (err instanceof Error) {
          return err.message;
        }
        return JSON.stringify(err);
      }

      // Sibling methods are reached through `methods` rather than `this`, so
      // they keep working when a caller destructures them from the store.
      const methods = {
        /**
         * Increments `currentId` and returns the new value.
         * The new id is computed from a read of the state; the state object
         * itself is only changed through `set`.
         */
        getNextId(): number {
          const id = _get().currentId + 1;
          set({ currentId: id });
          return id;
        },

        /**
         * Registers a new task and starts the generation request.
         *
         * @param data   Task description; it is copied and extended with a fresh
         *               `id` and `status: "running"` before being stored.
         * @param okCall Optional callback invoked right after the request has
         *               been started (not when it finishes).
         */
        sendTask(data: any, okCall?: Function): void {
          data = { ...data, id: nanoid(), status: "running" };
          set({ draw: [data, ..._get().draw] });
          methods.getNextId();
          methods.stabilityRequestCall(data);
          okCall?.();
        },

        /**
         * POSTs the task parameters as form data to the generate endpoint for
         * `data.model` and records the outcome on the task via `updateDraw`.
         *
         * The promise chain is fire-and-forget: every failure path ends in an
         * `updateDraw` with `status: "error"`, nothing is rethrown.
         *
         * @param data Task as stored by `sendTask`; `data.model` and
         *             `data.params` are read here.
         */
        stabilityRequestCall(data: any): void {
          const accessStore = useAccessStore.getState();

          // Custom configuration may override both the endpoint and the key.
          let prefix: string = ApiPath.Stability as string;
          let bearerToken = "";
          if (accessStore.useCustomConfig) {
            prefix = accessStore.stabilityUrl || (ApiPath.Stability as string);
            bearerToken = getBearerToken(accessStore.stabilityApiKey);
          }
          // Fall back to the access code when no key-based token was produced.
          if (!bearerToken && accessStore.enabledAccessControl()) {
            bearerToken = getBearerToken(
              ACCESS_CODE_PREFIX + accessStore.accessCode,
            );
          }

          const headers = {
            Accept: "application/json",
            Authorization: bearerToken,
          };
          const path = `${prefix}/${Stability.GeneratePath}/${data.model}`;

          const formData = new FormData();
          // `?? {}` keeps a task without params a no-op, as the original loop was.
          for (const [paramsKey, paramsValue] of Object.entries(
            data.params ?? {},
          )) {
            formData.append(paramsKey, paramsValue);
          }

          fetch(path, {
            method: "POST",
            headers,
            body: formData,
          })
            .then((response) => response.json())
            .then((resData) => {
              // The API reported errors: surface the first one on the task.
              if (resData.errors && resData.errors.length > 0) {
                methods.updateDraw({
                  ...data,
                  status: "error",
                  error: resData.errors[0],
                });
                methods.getNextId();
                return;
              }

              if (resData.finish_reason === "SUCCESS") {
                // The image arrives base64-encoded; upload it and store the
                // upload result on the task.
                uploadImage(base64Image2Blob(resData.image, "image/png"))
                  .then((img_data) => {
                    console.debug("uploadImage success", img_data);
                    methods.updateDraw({
                      ...data,
                      status: "success",
                      img_data,
                    });
                  })
                  .catch((e: unknown) => {
                    console.error("uploadImage error", e);
                    methods.updateDraw({
                      ...data,
                      status: "error",
                      error: describeError(e),
                    });
                  });
              } else {
                // Any other finish reason is treated as a failure; keep the
                // whole payload for diagnosis.
                methods.updateDraw({
                  ...data,
                  status: "error",
                  error: JSON.stringify(resData),
                });
              }
              // Bumped right away, i.e. before a pending upload has settled.
              methods.getNextId();
            })
            .catch((error: unknown) => {
              methods.updateDraw({
                ...data,
                status: "error",
                error: describeError(error),
              });
              console.error("Error:", error);
              methods.getNextId();
            });
        },

        /**
         * Replaces the first stored task whose `id` equals `_draw.id`.
         * Does nothing when no task matches.
         *
         * A new array is written to the store instead of mutating the current
         * one, so consumers that compare `draw` by reference see the change.
         *
         * @param _draw Replacement task object.
         */
        updateDraw(_draw: any): void {
          const draw = _get().draw || [];
          const index = draw.findIndex((item) => item.id === _draw.id);
          if (index === -1) {
            return;
          }
          const nextDraw = [...draw];
          nextDraw[index] = _draw;
          set({ draw: nextDraw });
        },

        /** Stores `model` as the currently selected model. */
        setCurrentModel(model: any): void {
          set({ currentModel: model });
        },

        /** Stores `data` as the current parameter values. */
        setCurrentParams(data: any): void {
          set({
            currentParams: data,
          });
        },
      };

      return methods;
    },
    {
      name: StoreKey.SdList,
      version: 1.0,
    },
  );