// google.ts — Gemini Pro chat client
  1. import { Google, REQUEST_TIMEOUT_MS } from "@/app/constant";
  2. import { ChatOptions, getHeaders, LLMApi, LLMModel, LLMUsage } from "../api";
  3. import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
  4. import { getClientConfig } from "@/app/config/client";
  5. import { DEFAULT_API_HOST } from "@/app/constant";
  6. export class GeminiProApi implements LLMApi {
  7. extractMessage(res: any) {
  8. console.log("[Response] gemini-pro response: ", res);
  9. return (
  10. res?.candidates?.at(0)?.content?.parts.at(0)?.text ||
  11. res?.error?.message ||
  12. ""
  13. );
  14. }
  15. async chat(options: ChatOptions): Promise<void> {
  16. // const apiClient = this;
  17. const messages = options.messages.map((v) => ({
  18. role: v.role.replace("assistant", "model").replace("system", "user"),
  19. parts: [{ text: v.content }],
  20. }));
  21. // google requires that role in neighboring messages must not be the same
  22. for (let i = 0; i < messages.length - 1; ) {
  23. // Check if current and next item both have the role "model"
  24. if (messages[i].role === messages[i + 1].role) {
  25. // Concatenate the 'parts' of the current and next item
  26. messages[i].parts = messages[i].parts.concat(messages[i + 1].parts);
  27. // Remove the next item
  28. messages.splice(i + 1, 1);
  29. } else {
  30. // Move to the next item
  31. i++;
  32. }
  33. }
  34. const modelConfig = {
  35. ...useAppConfig.getState().modelConfig,
  36. ...useChatStore.getState().currentSession().mask.modelConfig,
  37. ...{
  38. model: options.config.model,
  39. },
  40. };
  41. const requestPayload = {
  42. contents: messages,
  43. generationConfig: {
  44. // stopSequences: [
  45. // "Title"
  46. // ],
  47. temperature: modelConfig.temperature,
  48. maxOutputTokens: modelConfig.max_tokens,
  49. topP: modelConfig.top_p,
  50. // "topK": modelConfig.top_k,
  51. },
  52. safetySettings: [
  53. {
  54. category: "HARM_CATEGORY_HARASSMENT",
  55. threshold: "BLOCK_ONLY_HIGH",
  56. },
  57. {
  58. category: "HARM_CATEGORY_HATE_SPEECH",
  59. threshold: "BLOCK_ONLY_HIGH",
  60. },
  61. {
  62. category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
  63. threshold: "BLOCK_ONLY_HIGH",
  64. },
  65. {
  66. category: "HARM_CATEGORY_DANGEROUS_CONTENT",
  67. threshold: "BLOCK_ONLY_HIGH",
  68. },
  69. ],
  70. };
  71. const isApp = !!getClientConfig()?.isApp;
  72. const shouldStream = !!options.config.stream;
  73. const controller = new AbortController();
  74. options.onController?.(controller);
  75. const accessStore = useAccessStore.getState();
  76. try {
  77. let chatPath = this.path(Google.ChatPath);
  78. // let baseUrl = accessStore.googleUrl;
  79. chatPath = isApp
  80. ? DEFAULT_API_HOST +
  81. "/api/proxy/google/" +
  82. Google.ChatPath +
  83. `?key=${accessStore.googleApiKey}`
  84. : chatPath;
  85. const chatPayload = {
  86. method: "POST",
  87. body: JSON.stringify(requestPayload),
  88. signal: controller.signal,
  89. headers: getHeaders(),
  90. };
  91. console.log("[Request] google chatPath: ", chatPath, isApp);
  92. // make a fetch request
  93. const requestTimeoutId = setTimeout(
  94. () => controller.abort(),
  95. REQUEST_TIMEOUT_MS,
  96. );
  97. if (shouldStream) {
  98. let responseText = "";
  99. let remainText = "";
  100. let streamChatPath = chatPath.replace(
  101. "generateContent",
  102. "streamGenerateContent",
  103. );
  104. let finished = false;
  105. let existingTexts: string[] = [];
  106. const finish = () => {
  107. finished = true;
  108. options.onFinish(existingTexts.join(""));
  109. };
  110. // animate response to make it looks smooth
  111. function animateResponseText() {
  112. if (finished || controller.signal.aborted) {
  113. responseText += remainText;
  114. finish();
  115. return;
  116. }
  117. if (remainText.length > 0) {
  118. const fetchCount = Math.max(1, Math.round(remainText.length / 60));
  119. const fetchText = remainText.slice(0, fetchCount);
  120. responseText += fetchText;
  121. remainText = remainText.slice(fetchCount);
  122. options.onUpdate?.(responseText, fetchText);
  123. }
  124. requestAnimationFrame(animateResponseText);
  125. }
  126. // start animaion
  127. animateResponseText();
  128. console.log("[Proxy Endpoint] ", streamChatPath);
  129. fetch(streamChatPath, chatPayload)
  130. .then((response) => {
  131. const reader = response?.body?.getReader();
  132. const decoder = new TextDecoder();
  133. let partialData = "";
  134. return reader?.read().then(function processText({
  135. done,
  136. value,
  137. }): Promise<any> {
  138. if (done) {
  139. console.log("Stream complete");
  140. // options.onFinish(responseText + remainText);
  141. finished = true;
  142. return Promise.resolve();
  143. }
  144. partialData += decoder.decode(value, { stream: true });
  145. try {
  146. let data = JSON.parse(ensureProperEnding(partialData));
  147. const textArray = data.reduce(
  148. (acc: string[], item: { candidates: any[] }) => {
  149. const texts = item.candidates.map((candidate) =>
  150. candidate.content.parts
  151. .map((part: { text: any }) => part.text)
  152. .join(""),
  153. );
  154. return acc.concat(texts);
  155. },
  156. [],
  157. );
  158. if (textArray.length > existingTexts.length) {
  159. const deltaArray = textArray.slice(existingTexts.length);
  160. existingTexts = textArray;
  161. remainText += deltaArray.join("");
  162. }
  163. } catch (error) {
  164. // console.log("[Response Animation] error: ", error,partialData);
  165. // skip error message when parsing json
  166. }
  167. return reader.read().then(processText);
  168. });
  169. })
  170. .catch((error) => {
  171. console.error("Error:", error);
  172. });
  173. } else {
  174. const res = await fetch(chatPath, chatPayload);
  175. clearTimeout(requestTimeoutId);
  176. const resJson = await res.json();
  177. if (resJson?.promptFeedback?.blockReason) {
  178. // being blocked
  179. options.onError?.(
  180. new Error(
  181. "Message is being blocked for reason: " +
  182. resJson.promptFeedback.blockReason,
  183. ),
  184. );
  185. }
  186. const message = this.extractMessage(resJson);
  187. options.onFinish(message);
  188. }
  189. } catch (e) {
  190. console.log("[Request] failed to make a chat request", e);
  191. options.onError?.(e as Error);
  192. }
  193. }
  194. usage(): Promise<LLMUsage> {
  195. throw new Error("Method not implemented.");
  196. }
  197. async models(): Promise<LLMModel[]> {
  198. return [];
  199. }
  200. path(path: string): string {
  201. return "/api/google/" + path;
  202. }
  203. }
  204. function ensureProperEnding(str: string) {
  205. if (str.startsWith("[") && !str.endsWith("]")) {
  206. return str + "]";
  207. }
  208. return str;
  209. }