// google.ts

import { Google, REQUEST_TIMEOUT_MS } from "@/app/constant";
import { ChatOptions, getHeaders, LLMApi, LLMModel, LLMUsage } from "../api";
import { useAppConfig, useChatStore } from "@/app/store";
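
/**
 * Client for Google's Gemini Pro chat API. Adapts the app's ChatOptions
 * message shape to Gemini's contents/parts format and supports both the
 * streaming and non-streaming generateContent endpoints.
 */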
export class GeminiProApi implements LLMApi {
  extractMessage(res: any) {
    console.log("[Response] gemini-pro response: ", res);

    return (
      res?.candidates?.at(0)?.content?.parts.at(0)?.text ||
      res?.error?.message ||
      ""
    );
  }
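
  /**
   * Send a chat completion request. Gemini accepts only the roles "user" and
   * "model", so "assistant" is renamed to "model" and "system" messages are
   * downgraded to "user" turns before sending.
   */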
  async chat(options: ChatOptions): Promise<void> {
    const messages = options.messages.map((v) => ({
      role: v.role.replace("assistant", "model").replace("system", "user"),
      parts: [{ text: v.content }],
    }));

    // Gemini requires that adjacent messages not share the same role.
    for (let i = 0; i < messages.length - 1; ) {
      if (messages[i].role === messages[i + 1].role) {
        // Same role twice in a row: merge the parts and drop the duplicate.
        messages[i].parts = messages[i].parts.concat(messages[i + 1].parts);
        messages.splice(i + 1, 1);
      } else {
        // Roles differ; move on to the next pair.
        i++;
      }
    }
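
    // Effective model config: the per-request model overrides the session
    // mask's config, which overrides the global app defaults.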
    const modelConfig = {
      ...useAppConfig.getState().modelConfig,
      ...useChatStore.getState().currentSession().mask.modelConfig,
      ...{
        model: options.config.model,
      },
    };
    const requestPayload = {
      contents: messages,
      generationConfig: {
        // stopSequences: [
        //   "Title"
        // ],
        temperature: modelConfig.temperature,
        maxOutputTokens: modelConfig.max_tokens,
        topP: modelConfig.top_p,
        // "topK": modelConfig.top_k,
      },
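      // Block a category only when the API rates the content as
      // high-probability harmful.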
  59. safetySettings: [
  60. {
  61. category: "HARM_CATEGORY_HARASSMENT",
  62. threshold: "BLOCK_ONLY_HIGH",
  63. },
  64. {
  65. category: "HARM_CATEGORY_HATE_SPEECH",
  66. threshold: "BLOCK_ONLY_HIGH",
  67. },
  68. {
  69. category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
  70. threshold: "BLOCK_ONLY_HIGH",
  71. },
  72. {
  73. category: "HARM_CATEGORY_DANGEROUS_CONTENT",
  74. threshold: "BLOCK_ONLY_HIGH",
  75. },
  76. ],
  77. };

    console.log("[Request] google payload: ", requestPayload);

    const shouldStream = !!options.config.stream;
    const controller = new AbortController();
    options.onController?.(controller);

    try {
      const chatPath = this.path(Google.ChatPath);
      const chatPayload = {
        method: "POST",
        body: JSON.stringify(requestPayload),
        signal: controller.signal,
        headers: getHeaders(),
      };

      // abort the request if it exceeds REQUEST_TIMEOUT_MS
      const requestTimeoutId = setTimeout(
        () => controller.abort(),
        REQUEST_TIMEOUT_MS,
      );
      if (shouldStream) {
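        // Streaming path: decoded chunks are buffered in remainText, and a
        // requestAnimationFrame loop drains them to the UI at a steady rate.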
        let responseText = "";
        let remainText = "";
        let streamChatPath = chatPath.replace(
          "generateContent",
          "streamGenerateContent",
        );
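        // streamGenerateContent returns a single JSON array that grows as
        // chunks arrive; ensureProperEnding() closes the still-open array so
        // each partial buffer can be parsed.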
        let finished = false;
        let existingTexts: string[] = [];
        const finish = () => {
          finished = true;
          options.onFinish(existingTexts.join(""));
        };

        // animate the response so it renders smoothly
        function animateResponseText() {
          if (finished || controller.signal.aborted) {
            responseText += remainText;
            finish();
            return;
          }

          if (remainText.length > 0) {
            // emit roughly 1/60th of the backlog per frame, so larger
            // backlogs drain proportionally faster
            const fetchCount = Math.max(1, Math.round(remainText.length / 60));
            const fetchText = remainText.slice(0, fetchCount);
            responseText += fetchText;
            remainText = remainText.slice(fetchCount);
            options.onUpdate?.(responseText, fetchText);
          }

          requestAnimationFrame(animateResponseText);
        }

        // start the animation
        animateResponseText();

        fetch(streamChatPath, chatPayload)
          .then((response) => {
            // the response has started arriving, so cancel the abort timer
            clearTimeout(requestTimeoutId);

            const reader = response?.body?.getReader();
            const decoder = new TextDecoder();
            let partialData = "";

            return reader?.read().then(function processText({
              done,
              value,
            }): Promise<any> {
              if (done) {
                console.log("Stream complete");
                finished = true;
                return Promise.resolve();
              }

              partialData += decoder.decode(value, { stream: true });

              try {
                let data = JSON.parse(ensureProperEnding(partialData));

                const textArray = data.reduce(
                  (acc: string[], item: { candidates: any[] }) => {
                    const texts = item.candidates.map((candidate) =>
                      candidate.content.parts
                        .map((part: { text: any }) => part.text)
                        .join(""),
                    );
                    return acc.concat(texts);
                  },
                  [],
                );
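
                // each parse sees the full array so far; entries beyond what
                // was already emitted are the new delta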
                if (textArray.length > existingTexts.length) {
                  const deltaArray = textArray.slice(existingTexts.length);
                  existingTexts = textArray;
                  remainText += deltaArray.join("");
                }
              } catch (error) {
                // the buffer is still an incomplete JSON fragment; wait for
                // the next chunk and retry
              }

              return reader.read().then(processText);
            });
          })
          .catch((error) => {
            console.error("Error:", error);
          });
      } else {
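        // non-streaming path: a single round trip, then extract the message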
        const res = await fetch(chatPath, chatPayload);
        clearTimeout(requestTimeoutId);

        const resJson = await res.json();

        if (resJson?.promptFeedback?.blockReason) {
          // the prompt was rejected by Gemini's safety filters
          options.onError?.(
            new Error(
              "Message is being blocked for reason: " +
                resJson.promptFeedback.blockReason,
            ),
          );
        }
        const message = this.extractMessage(resJson);
        options.onFinish(message);
      }
    } catch (e) {
      console.log("[Request] failed to make a chat request", e);
      options.onError?.(e as Error);
    }
  }

  usage(): Promise<LLMUsage> {
    throw new Error("Method not implemented.");
  }

  async models(): Promise<LLMModel[]> {
    return [];
  }

  path(path: string): string {
    return "/api/google/" + path;
  }
}
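
// Example usage (a sketch; the ChatOptions field shapes are inferred from the
// calls above, and "gemini-pro" stands in for whatever model id the app
// configures):
//
//   const api = new GeminiProApi();
//   await api.chat({
//     messages: [{ role: "user", content: "Hello" }],
//     config: { model: "gemini-pro", stream: true },
//     onUpdate: (full, delta) => console.log(delta),
//     onFinish: (message) => console.log("done:", message),
//     onError: (e) => console.error(e),
//   });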

// The streaming endpoint's payload is one JSON array that is missing its
// closing bracket until the final chunk arrives; append it so the partial
// buffer parses.
function ensureProperEnding(str: string) {
  if (str.startsWith("[") && !str.endsWith("]")) {
    return str + "]";
  }
  return str;
}