sd.ts 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. import { Stability, StoreKey } from "@/app/constant";
  2. import { getHeaders } from "@/app/client/api";
  3. import { createPersistStore } from "@/app/utils/store";
  4. import { nanoid } from "nanoid";
  5. import { uploadImage, base64Image2Blob } from "@/app/utils/chat";
  6. import { models, getModelParamBasicData } from "@/app/components/sd/sd-panel";
  // All defaults are taken from the first entry of the imported `models` list.
  const firstModel = models[0];

  /** Model selected initially: the name/value pair of the first entry in `models`. */
  const defaultModel = {
    name: firstModel.name,
    value: firstModel.value,
  };

  /** Initial parameter values: `getModelParamBasicData` applied to the first model's `params({})`. */
  const defaultParams = getModelParamBasicData(firstModel.params({}), {});

  /** Initial store state: counter at 0, no draw tasks, first model and its params selected. */
  const DEFAULT_SD_STATE = {
    currentId: 0,
    draw: [],
    currentModel: defaultModel,
    currentParams: defaultParams,
  };
  /**
   * Store for Stable Diffusion drawing tasks, created via `createPersistStore`
   * under the key `StoreKey.SdList`.
   *
   * State:
   * - `currentId`: counter incremented by `getNextId` when a task is queued and
   *   again when its request settles.
   * - `draw`: task records, newest first. Each record carries a nanoid `id` and
   *   a `status` of "running", "success" or "error".
   * - `currentModel` / `currentParams`: the currently selected model and params.
   */
  export const useSdStore = createPersistStore<
    {
      currentId: number;
      draw: any[];
      currentModel: typeof defaultModel;
      currentParams: any;
    },
    {
      getNextId: () => number;
      sendTask: (data: any, okCall?: Function) => void;
      updateDraw: (draw: any) => void;
      setCurrentModel: (model: any) => void;
      setCurrentParams: (data: any) => void;
    }
  >(
    DEFAULT_SD_STATE,
    (set, _get) => {
      // Current state merged with the actions. Actions call each other through
      // this instead of `this`, which is lost when an action is destructured.
      function get() {
        return {
          ..._get(),
          ...methods,
        };
      }

      // Turn a rejection value into a readable message.
      // JSON.stringify(new Error("x")) is "{}", so Error messages are extracted explicitly.
      function describeError(e: unknown): string {
        if (e instanceof Error) return e.message;
        if (typeof e === "string") return e;
        return JSON.stringify(e);
      }

      const methods = {
        /** Increments `currentId`, stores it, and returns the new value. */
        getNextId(): number {
          // Derive the next value instead of `++state.currentId`, which would
          // mutate the current state snapshot in place.
          const id = _get().currentId + 1;
          set({ currentId: id });
          return id;
        },

        /**
         * Queues a drawing task. The task is prepended to `draw` with a fresh id
         * and status "running", and its Stability request is started. `okCall` is
         * then invoked without waiting for the request to finish.
         */
        sendTask(data: any, okCall?: Function): void {
          const task = { ...data, id: nanoid(), status: "running" };
          set({ draw: [task, ..._get().draw] });
          get().getNextId();
          get().stabilityRequestCall(task);
          okCall?.();
        },

        /**
         * POSTs `data.params` as multipart form data to
         * `/api/stability/${Stability.GeneratePath}/${data.model}` and records the
         * outcome on the matching `draw` entry. When `finish_reason` is "SUCCESS",
         * the base64 `image` in the response is uploaded via `uploadImage`, and
         * the result is stored as `img_data`.
         */
        stabilityRequestCall(data: any): void {
          const formData = new FormData();
          for (const [key, value] of Object.entries(data.params ?? {})) {
            formData.append(key, value);
          }
          const headers = getHeaders();
          // With a FormData body, fetch sets the multipart Content-Type (including
          // the boundary) itself, so an explicit header must not override it.
          delete headers["Content-Type"];
          void fetch(`/api/stability/${Stability.GeneratePath}/${data.model}`, {
            method: "POST",
            headers: {
              ...headers,
              Accept: "application/json",
            },
            body: formData,
          })
            .then((response) => response.json())
            .then((resData) => {
              if (resData.errors && resData.errors.length > 0) {
                get().updateDraw({
                  ...data,
                  status: "error",
                  error: resData.errors[0],
                });
                get().getNextId();
                return;
              }
              if (resData.finish_reason === "SUCCESS") {
                uploadImage(base64Image2Blob(resData.image, "image/png"))
                  .then((img_data) => {
                    console.debug("uploadImage success", img_data);
                    get().updateDraw({
                      ...data,
                      status: "success",
                      img_data,
                    });
                  })
                  .catch((e) => {
                    console.error("uploadImage error", e);
                    get().updateDraw({
                      ...data,
                      status: "error",
                      error: describeError(e),
                    });
                  });
              } else {
                get().updateDraw({
                  ...data,
                  status: "error",
                  error: JSON.stringify(resData),
                });
              }
              get().getNextId();
            })
            .catch((error) => {
              get().updateDraw({
                ...data,
                status: "error",
                error: describeError(error),
              });
              console.error("Error:", error);
              get().getNextId();
            });
        },

        /**
         * Replaces the `draw` entry whose `id` equals `_draw.id` with `_draw`.
         * Does nothing when no entry matches.
         */
        updateDraw(_draw: any): void {
          const draw = _get().draw || [];
          const index = draw.findIndex((item) => item.id === _draw.id);
          if (index === -1) return;
          // Write a new array instead of mutating the stored one, so subscribers
          // that compare `draw` by reference observe the update.
          const next = [...draw];
          next[index] = _draw;
          set({ draw: next });
        },

        /** Sets the currently selected model. */
        setCurrentModel(model: any): void {
          set({ currentModel: model });
        },

        /** Sets the current generation parameters. */
        setCurrentParams(data: any): void {
          set({
            currentParams: data,
          });
        },
      };
      return methods;
    },
    {
      name: StoreKey.SdList,
      version: 1.0,
    },
  );