// google.ts — Gemini Pro chat client
import { Google, REQUEST_TIMEOUT_MS } from "@/app/constant";
import { ChatOptions, getHeaders, LLMApi, LLMModel, LLMUsage } from "../api";
import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
import {
  EventStreamContentType,
  fetchEventSource,
} from "@fortaine/fetch-event-source";
import { prettyObject } from "@/app/utils/format";
import { getClientConfig } from "@/app/config/client";
import Locale from "../../locales";
import { getServerSideConfig } from "@/app/config/server";
import de from "@/app/locales/de";
  13. export class GeminiProApi implements LLMApi {
  14. extractMessage(res: any) {
  15. console.log("[Response] gemini-pro response: ", res);
  16. return (
  17. res?.candidates?.at(0)?.content?.parts.at(0)?.text ||
  18. res?.error?.message ||
  19. ""
  20. );
  21. }
  22. async chat(options: ChatOptions): Promise<void> {
  23. const apiClient = this;
  24. const messages = options.messages.map((v) => ({
  25. role: v.role.replace("assistant", "model").replace("system", "user"),
  26. parts: [{ text: v.content }],
  27. }));
  28. // google requires that role in neighboring messages must not be the same
  29. for (let i = 0; i < messages.length - 1; ) {
  30. // Check if current and next item both have the role "model"
  31. if (messages[i].role === messages[i + 1].role) {
  32. // Concatenate the 'parts' of the current and next item
  33. messages[i].parts = messages[i].parts.concat(messages[i + 1].parts);
  34. // Remove the next item
  35. messages.splice(i + 1, 1);
  36. } else {
  37. // Move to the next item
  38. i++;
  39. }
  40. }
  41. const modelConfig = {
  42. ...useAppConfig.getState().modelConfig,
  43. ...useChatStore.getState().currentSession().mask.modelConfig,
  44. ...{
  45. model: options.config.model,
  46. },
  47. };
  48. const requestPayload = {
  49. contents: messages,
  50. generationConfig: {
  51. // stopSequences: [
  52. // "Title"
  53. // ],
  54. temperature: modelConfig.temperature,
  55. maxOutputTokens: modelConfig.max_tokens,
  56. topP: modelConfig.top_p,
  57. // "topK": modelConfig.top_k,
  58. },
  59. };
  60. console.log("[Request] google payload: ", requestPayload);
  61. const shouldStream = !!options.config.stream;
  62. const controller = new AbortController();
  63. options.onController?.(controller);
  64. try {
  65. const chatPath = this.path(Google.ChatPath);
  66. const chatPayload = {
  67. method: "POST",
  68. body: JSON.stringify(requestPayload),
  69. signal: controller.signal,
  70. headers: getHeaders(),
  71. };
  72. // make a fetch request
  73. const requestTimeoutId = setTimeout(
  74. () => controller.abort(),
  75. REQUEST_TIMEOUT_MS,
  76. );
  77. if (shouldStream) {
  78. let responseText = "";
  79. let remainText = "";
  80. let streamChatPath = chatPath.replace(
  81. "generateContent",
  82. "streamGenerateContent",
  83. );
  84. let finished = false;
  85. let existingTexts: string[] = [];
  86. const finish = () => {
  87. finished = true;
  88. options.onFinish(existingTexts.join(""));
  89. };
  90. // animate response to make it looks smooth
  91. function animateResponseText() {
  92. if (finished || controller.signal.aborted) {
  93. responseText += remainText;
  94. finish();
  95. return;
  96. }
  97. if (remainText.length > 0) {
  98. const fetchCount = Math.max(1, Math.round(remainText.length / 60));
  99. const fetchText = remainText.slice(0, fetchCount);
  100. responseText += fetchText;
  101. remainText = remainText.slice(fetchCount);
  102. options.onUpdate?.(responseText, fetchText);
  103. }
  104. requestAnimationFrame(animateResponseText);
  105. }
  106. // start animaion
  107. animateResponseText();
  108. fetch(streamChatPath, chatPayload)
  109. .then((response) => {
  110. const reader = response?.body?.getReader();
  111. const decoder = new TextDecoder();
  112. let partialData = "";
  113. return reader?.read().then(function processText({
  114. done,
  115. value,
  116. }): Promise<any> {
  117. if (done) {
  118. console.log("Stream complete");
  119. // options.onFinish(responseText + remainText);
  120. finished = true;
  121. return Promise.resolve();
  122. }
  123. partialData += decoder.decode(value, { stream: true });
  124. try {
  125. let data = JSON.parse(ensureProperEnding(partialData));
  126. const textArray = data.reduce(
  127. (acc: string[], item: { candidates: any[] }) => {
  128. const texts = item.candidates.map((candidate) =>
  129. candidate.content.parts
  130. .map((part: { text: any }) => part.text)
  131. .join(""),
  132. );
  133. return acc.concat(texts);
  134. },
  135. [],
  136. );
  137. if (textArray.length > existingTexts.length) {
  138. const deltaArray = textArray.slice(existingTexts.length);
  139. existingTexts = textArray;
  140. remainText += deltaArray.join("");
  141. }
  142. } catch (error) {
  143. // console.log("[Response Animation] error: ", error,partialData);
  144. // skip error message when parsing json
  145. }
  146. return reader.read().then(processText);
  147. });
  148. })
  149. .catch((error) => {
  150. console.error("Error:", error);
  151. });
  152. } else {
  153. const res = await fetch(chatPath, chatPayload);
  154. clearTimeout(requestTimeoutId);
  155. const resJson = await res.json();
  156. if (resJson?.promptFeedback?.blockReason) {
  157. // being blocked
  158. options.onError?.(
  159. new Error(
  160. "Message is being blocked for reason: " +
  161. resJson.promptFeedback.blockReason,
  162. ),
  163. );
  164. }
  165. const message = this.extractMessage(resJson);
  166. options.onFinish(message);
  167. }
  168. } catch (e) {
  169. console.log("[Request] failed to make a chat request", e);
  170. options.onError?.(e as Error);
  171. }
  172. }
  173. usage(): Promise<LLMUsage> {
  174. throw new Error("Method not implemented.");
  175. }
  176. async models(): Promise<LLMModel[]> {
  177. return [];
  178. }
  179. path(path: string): string {
  180. return "/api/google/" + path;
  181. }
  182. }
  183. function ensureProperEnding(str: string) {
  184. if (str.startsWith("[") && !str.endsWith("]")) {
  185. return str + "]";
  186. }
  187. return str;
  188. }