Przeglądaj źródła

Add vision support (#4076)

TheRam_ 1 rok temu
rodzic
commit
e2da3406d2

+ 9 - 1
app/client/api.ts

@@ -14,9 +14,17 @@ export type MessageRole = (typeof ROLES)[number];
 export const Models = ["gpt-3.5-turbo", "gpt-4"] as const;
 export type ChatModel = ModelType;
 
+export interface MultimodalContent {
+  type: "text" | "image_url";
+  text?: string;
+  image_url?: {
+    url: string;
+  };
+}
+
 export interface RequestMessage {
   role: MessageRole;
-  content: string;
+  content: string | MultimodalContent[];
 }
 
 export interface LLMConfig {

+ 54 - 7
app/client/platforms/google.ts

@@ -3,6 +3,12 @@ import { ChatOptions, getHeaders, LLMApi, LLMModel, LLMUsage } from "../api";
 import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
 import { getClientConfig } from "@/app/config/client";
 import { DEFAULT_API_HOST } from "@/app/constant";
+import {
+  getMessageTextContent,
+  getMessageImages,
+  isVisionModel,
+} from "@/app/utils";
+
 export class GeminiProApi implements LLMApi {
   extractMessage(res: any) {
     console.log("[Response] gemini-pro response: ", res);
@@ -15,10 +21,33 @@ export class GeminiProApi implements LLMApi {
   }
   async chat(options: ChatOptions): Promise<void> {
     // const apiClient = this;
-    const messages = options.messages.map((v) => ({
-      role: v.role.replace("assistant", "model").replace("system", "user"),
-      parts: [{ text: v.content }],
-    }));
+    const visionModel = isVisionModel(options.config.model);
+    let multimodal = false;
+    const messages = options.messages.map((v) => {
+      let parts: any[] = [{ text: getMessageTextContent(v) }];
+      if (visionModel) {
+        const images = getMessageImages(v);
+        if (images.length > 0) {
+          multimodal = true;
+          parts = parts.concat(
+            images.map((image) => {
+              const imageType = image.split(";")[0].split(":")[1];
+              const imageData = image.split(",")[1];
+              return {
+                inline_data: {
+                  mime_type: imageType,
+                  data: imageData,
+                },
+              };
+            }),
+          );
+        }
+      }
+      return {
+        role: v.role.replace("assistant", "model").replace("system", "user"),
+        parts: parts,
+      };
+    });
 
     // google requires that role in neighboring messages must not be the same
     for (let i = 0; i < messages.length - 1; ) {
@@ -33,7 +62,9 @@ export class GeminiProApi implements LLMApi {
         i++;
       }
     }
-
+    // if (visionModel && messages.length > 1) {
+    //   options.onError?.(new Error("Multiturn chat is not enabled for models/gemini-pro-vision"));
+    // }
     const modelConfig = {
       ...useAppConfig.getState().modelConfig,
       ...useChatStore.getState().currentSession().mask.modelConfig,
@@ -80,13 +111,16 @@ export class GeminiProApi implements LLMApi {
     const controller = new AbortController();
     options.onController?.(controller);
     try {
-      let chatPath = this.path(Google.ChatPath);
+      let googleChatPath = visionModel
+        ? Google.VisionChatPath
+        : Google.ChatPath;
+      let chatPath = this.path(googleChatPath);
 
       // let baseUrl = accessStore.googleUrl;
 
       if (!baseUrl) {
         baseUrl = isApp
-          ? DEFAULT_API_HOST + "/api/proxy/google/" + Google.ChatPath
+          ? DEFAULT_API_HOST + "/api/proxy/google/" + googleChatPath
           : chatPath;
       }
 
@@ -152,6 +186,19 @@ export class GeminiProApi implements LLMApi {
               value,
             }): Promise<any> {
               if (done) {
+                if (response.status !== 200) {
+                  try {
+                    let data = JSON.parse(ensureProperEnding(partialData));
+                    if (data && data[0].error) {
+                      options.onError?.(new Error(data[0].error.message));
+                    } else {
+                      options.onError?.(new Error("Request failed"));
+                    }
+                  } catch (_) {
+                    options.onError?.(new Error("Request failed"));
+                  }
+                }
+
                 console.log("Stream complete");
                 // options.onFinish(responseText + remainText);
                 finished = true;

+ 15 - 2
app/client/platforms/openai.ts

@@ -9,7 +9,14 @@ import {
 } from "@/app/constant";
 import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
 
-import { ChatOptions, getHeaders, LLMApi, LLMModel, LLMUsage } from "../api";
+import {
+  ChatOptions,
+  getHeaders,
+  LLMApi,
+  LLMModel,
+  LLMUsage,
+  MultimodalContent,
+} from "../api";
 import Locale from "../../locales";
 import {
   EventStreamContentType,
@@ -18,6 +25,11 @@ import {
 import { prettyObject } from "@/app/utils/format";
 import { getClientConfig } from "@/app/config/client";
 import { makeAzurePath } from "@/app/azure";
+import {
+  getMessageTextContent,
+  getMessageImages,
+  isVisionModel,
+} from "@/app/utils";
 
 export interface OpenAIListModelResponse {
   object: string;
@@ -72,9 +84,10 @@ export class ChatGPTApi implements LLMApi {
   }
 
   async chat(options: ChatOptions) {
+    const visionModel = isVisionModel(options.config.model);
     const messages = options.messages.map((v) => ({
       role: v.role,
-      content: v.content,
+      content: visionModel ? v.content : getMessageTextContent(v),
     }));
 
     const modelConfig = {

+ 122 - 13
app/components/chat.module.scss

@@ -1,5 +1,47 @@
 @import "../styles/animation.scss";
 
+.attach-images {
+  position: absolute;
+  left: 30px;
+  bottom: 32px;
+  display: flex;
+}
+
+.attach-image {
+  cursor: default;
+  width: 64px;
+  height: 64px;
+  border: rgba($color: #888, $alpha: 0.2) 1px solid;
+  border-radius: 5px;
+  margin-right: 10px;
+  background-size: cover;
+  background-position: center;
+  background-color: var(--white);
+
+  .attach-image-mask {
+    width: 100%;
+    height: 100%;
+    opacity: 0;
+    transition: all ease 0.2s;
+  }
+
+  .attach-image-mask:hover {
+    opacity: 1;
+  }
+
+  .delete-image {
+    width: 24px;
+    height: 24px;
+    cursor: pointer;
+    display: flex;
+    align-items: center;
+    justify-content: center;
+    border-radius: 5px;
+    float: right;
+    background-color: var(--white);
+  }
+}
+
 .chat-input-actions {
   display: flex;
   flex-wrap: wrap;
@@ -189,12 +231,10 @@
 
   animation: slide-in ease 0.3s;
 
-  $linear: linear-gradient(
-    to right,
-    rgba(0, 0, 0, 0),
-    rgba(0, 0, 0, 1),
-    rgba(0, 0, 0, 0)
-  );
+  $linear: linear-gradient(to right,
+      rgba(0, 0, 0, 0),
+      rgba(0, 0, 0, 1),
+      rgba(0, 0, 0, 0));
   mask-image: $linear;
 
   @mixin show {
@@ -327,7 +367,7 @@
   }
 }
 
-.chat-message-user > .chat-message-container {
+.chat-message-user>.chat-message-container {
   align-items: flex-end;
 }
 
@@ -349,6 +389,7 @@
       padding: 7px;
     }
   }
+
   /* Specific styles for iOS devices */
   @media screen and (max-device-width: 812px) and (-webkit-min-device-pixel-ratio: 2) {
     @supports (-webkit-touch-callout: none) {
@@ -381,6 +422,64 @@
   transition: all ease 0.3s;
 }
 
+.chat-message-item-image {
+  width: 100%;
+  margin-top: 10px;
+}
+
+.chat-message-item-images {
+  width: 100%;
+  display: grid;
+  justify-content: left;
+  grid-gap: 10px;
+  grid-template-columns: repeat(var(--image-count), auto);
+  margin-top: 10px;
+}
+
+.chat-message-item-image-multi {
+  object-fit: cover;
+  background-size: cover;
+  background-position: center;
+  background-repeat: no-repeat;
+}
+
+.chat-message-item-image,
+.chat-message-item-image-multi {
+  box-sizing: border-box;
+  border-radius: 10px;
+  border: rgba($color: #888, $alpha: 0.2) 1px solid;
+}
+
+
+@media only screen and (max-width: 600px) {
+  $calc-image-width: calc(100vw/3*2/var(--image-count));
+
+  .chat-message-item-image-multi {
+    width: $calc-image-width;
+    height: $calc-image-width;
+  }
+  
+  .chat-message-item-image {
+    max-width: calc(100vw/3*2);
+  }
+}
+
+@media screen and (min-width: 600px) {
+  $max-image-width: calc(calc(1200px - var(--sidebar-width))/3*2/var(--image-count));
+  $image-width: calc(calc(var(--window-width) - var(--sidebar-width))/3*2/var(--image-count));
+
+  .chat-message-item-image-multi {
+    width: $image-width;
+    height: $image-width;
+    max-width: $max-image-width;
+    max-height: $max-image-width;
+  }
+
+  .chat-message-item-image {
+    max-width: calc(calc(1200px - var(--sidebar-width))/3*2);
+  }
+}
+
 .chat-message-action-date {
   font-size: 12px;
   opacity: 0.2;
@@ -395,7 +494,7 @@
   z-index: 1;
 }
 
-.chat-message-user > .chat-message-container > .chat-message-item {
+.chat-message-user>.chat-message-container>.chat-message-item {
   background-color: var(--second);
 
   &:hover {
@@ -460,6 +559,7 @@
 
       @include single-line();
     }
+
     .hint-content {
       font-size: 12px;
 
@@ -474,15 +574,26 @@
 }
 
 .chat-input-panel-inner {
+  cursor: text;
   display: flex;
   flex: 1;
+  border-radius: 10px;
+  border: var(--border-in-light);
+}
+
+.chat-input-panel-inner-attach {
+  padding-bottom: 80px;
+}
+
+.chat-input-panel-inner:has(.chat-input:focus) {
+  border: 1px solid var(--primary);
 }
 
 .chat-input {
   height: 100%;
   width: 100%;
   border-radius: 10px;
-  border: var(--border-in-light);
+  border: none;
   box-shadow: 0 -2px 5px rgba(0, 0, 0, 0.03);
   background-color: var(--white);
   color: var(--black);
@@ -494,9 +605,7 @@
   min-height: 68px;
 }
 
-.chat-input:focus {
-  border: 1px solid var(--primary);
-}
+.chat-input:focus {}
 
 .chat-input-send {
   background-color: var(--primary);
@@ -515,4 +624,4 @@
   .chat-input-send {
     bottom: 30px;
   }
-}
+}

+ 180 - 11
app/components/chat.tsx

@@ -15,6 +15,7 @@ import ExportIcon from "../icons/share.svg";
 import ReturnIcon from "../icons/return.svg";
 import CopyIcon from "../icons/copy.svg";
 import LoadingIcon from "../icons/three-dots.svg";
+import LoadingButtonIcon from "../icons/loading.svg";
 import PromptIcon from "../icons/prompt.svg";
 import MaskIcon from "../icons/mask.svg";
 import MaxIcon from "../icons/max.svg";
@@ -27,6 +28,7 @@ import PinIcon from "../icons/pin.svg";
 import EditIcon from "../icons/rename.svg";
 import ConfirmIcon from "../icons/confirm.svg";
 import CancelIcon from "../icons/cancel.svg";
+import ImageIcon from "../icons/image.svg";
 
 import LightIcon from "../icons/light.svg";
 import DarkIcon from "../icons/dark.svg";
@@ -53,6 +55,10 @@ import {
   selectOrCopy,
   autoGrowTextArea,
   useMobileScreen,
+  getMessageTextContent,
+  getMessageImages,
+  isVisionModel,
+  compressImage,
 } from "../utils";
 
 import dynamic from "next/dynamic";
@@ -89,6 +95,7 @@ import { prettyObject } from "../utils/format";
 import { ExportMessageModal } from "./exporter";
 import { getClientConfig } from "../config/client";
 import { useAllModels } from "../utils/hooks";
+import { MultimodalContent } from "../client/api";
 
 const Markdown = dynamic(async () => (await import("./markdown")).Markdown, {
   loading: () => <LoadingIcon />,
@@ -406,10 +413,14 @@ function useScrollToBottom() {
 }
 
 export function ChatActions(props: {
+  uploadImage: () => void;
+  setAttachImages: (images: string[]) => void;
+  setUploading: (uploading: boolean) => void;
   showPromptModal: () => void;
   scrollToBottom: () => void;
   showPromptHints: () => void;
   hitBottom: boolean;
+  uploading: boolean;
 }) {
   const config = useAppConfig();
   const navigate = useNavigate();
@@ -437,8 +448,16 @@ export function ChatActions(props: {
     [allModels],
   );
   const [showModelSelector, setShowModelSelector] = useState(false);
+  const [showUploadImage, setShowUploadImage] = useState(false);
 
   useEffect(() => {
+    const show = isVisionModel(currentModel);
+    setShowUploadImage(show);
+    if (!show) {
+      props.setAttachImages([]);
+      props.setUploading(false);
+    }
+
     // if current model is not available
     // switch to first available model
     const isUnavaliableModel = !models.some((m) => m.name === currentModel);
@@ -475,6 +494,13 @@ export function ChatActions(props: {
         />
       )}
 
+      {showUploadImage && (
+        <ChatAction
+          onClick={props.uploadImage}
+          text={Locale.Chat.InputActions.UploadImage}
+          icon={props.uploading ? <LoadingButtonIcon /> : <ImageIcon />}
+        />
+      )}
       <ChatAction
         onClick={nextTheme}
         text={Locale.Chat.InputActions.Theme[theme]}
@@ -610,6 +636,14 @@ export function EditMessageModal(props: { onClose: () => void }) {
   );
 }
 
+export function DeleteImageButton(props: { deleteImage: () => void }) {
+  return (
+    <div className={styles["delete-image"]} onClick={props.deleteImage}>
+      <DeleteIcon />
+    </div>
+  );
+}
+
 function _Chat() {
   type RenderMessage = ChatMessage & { preview?: boolean };
 
@@ -628,6 +662,8 @@ function _Chat() {
   const [hitBottom, setHitBottom] = useState(true);
   const isMobileScreen = useMobileScreen();
   const navigate = useNavigate();
+  const [attachImages, setAttachImages] = useState<string[]>([]);
+  const [uploading, setUploading] = useState(false);
 
   // prompt hints
   const promptStore = usePromptStore();
@@ -705,7 +741,10 @@ function _Chat() {
       return;
     }
     setIsLoading(true);
-    chatStore.onUserInput(userInput).then(() => setIsLoading(false));
+    chatStore
+      .onUserInput(userInput, attachImages)
+      .then(() => setIsLoading(false));
+    setAttachImages([]);
     localStorage.setItem(LAST_INPUT_KEY, userInput);
     setUserInput("");
     setPromptHints([]);
@@ -783,9 +822,9 @@ function _Chat() {
   };
   const onRightClick = (e: any, message: ChatMessage) => {
     // copy to clipboard
-    if (selectOrCopy(e.currentTarget, message.content)) {
+    if (selectOrCopy(e.currentTarget, getMessageTextContent(message))) {
       if (userInput.length === 0) {
-        setUserInput(message.content);
+        setUserInput(getMessageTextContent(message));
       }
 
       e.preventDefault();
@@ -853,7 +892,9 @@ function _Chat() {
 
     // resend the message
     setIsLoading(true);
-    chatStore.onUserInput(userMessage.content).then(() => setIsLoading(false));
+    const textContent = getMessageTextContent(userMessage);
+    const images = getMessageImages(userMessage);
+    chatStore.onUserInput(textContent, images).then(() => setIsLoading(false));
     inputRef.current?.focus();
   };
 
@@ -1048,6 +1089,51 @@ function _Chat() {
     // eslint-disable-next-line react-hooks/exhaustive-deps
   }, []);
 
+  async function uploadImage() {
+    const images: string[] = [];
+    images.push(...attachImages);
+
+    images.push(
+      ...(await new Promise<string[]>((res, rej) => {
+        const fileInput = document.createElement("input");
+        fileInput.type = "file";
+        fileInput.accept =
+          "image/png, image/jpeg, image/webp, image/heic, image/heif";
+        fileInput.multiple = true;
+        fileInput.onchange = (event: any) => {
+          setUploading(true);
+          const files = event.target.files;
+          const imagesData: string[] = [];
+          for (let i = 0; i < files.length; i++) {
+            const file = event.target.files[i];
+            compressImage(file, 256 * 1024)
+              .then((dataUrl) => {
+                imagesData.push(dataUrl);
+                if (
+                  imagesData.length === 3 ||
+                  imagesData.length === files.length
+                ) {
+                  setUploading(false);
+                  res(imagesData);
+                }
+              })
+              .catch((e) => {
+                setUploading(false);
+                rej(e);
+              });
+          }
+        };
+        fileInput.click();
+      })),
+    );
+
+    const imagesLength = images.length;
+    if (imagesLength > 3) {
+      images.splice(3, imagesLength - 3);
+    }
+    setAttachImages(images);
+  }
+
   return (
     <div className={styles.chat} key={session.id}>
       <div className="window-header" data-tauri-drag-region>
@@ -1154,15 +1240,29 @@ function _Chat() {
                           onClick={async () => {
                             const newMessage = await showPrompt(
                               Locale.Chat.Actions.Edit,
-                              message.content,
+                              getMessageTextContent(message),
                               10,
                             );
+                            let newContent: string | MultimodalContent[] =
+                              newMessage;
+                            const images = getMessageImages(message);
+                            if (images.length > 0) {
+                              newContent = [{ type: "text", text: newMessage }];
+                              for (let i = 0; i < images.length; i++) {
+                                newContent.push({
+                                  type: "image_url",
+                                  image_url: {
+                                    url: images[i],
+                                  },
+                                });
+                              }
+                            }
                             chatStore.updateCurrentSession((session) => {
                               const m = session.mask.context
                                 .concat(session.messages)
                                 .find((m) => m.id === message.id);
                               if (m) {
-                                m.content = newMessage;
+                                m.content = newContent;
                               }
                             });
                           }}
@@ -1217,7 +1317,11 @@ function _Chat() {
                               <ChatAction
                                 text={Locale.Chat.Actions.Copy}
                                 icon={<CopyIcon />}
-                                onClick={() => copyToClipboard(message.content)}
+                                onClick={() =>
+                                  copyToClipboard(
+                                    getMessageTextContent(message),
+                                  )
+                                }
                               />
                             </>
                           )}
@@ -1232,7 +1336,7 @@ function _Chat() {
                   )}
                   <div className={styles["chat-message-item"]}>
                     <Markdown
-                      content={message.content}
+                      content={getMessageTextContent(message)}
                       loading={
                         (message.preview || message.streaming) &&
                         message.content.length === 0 &&
@@ -1241,12 +1345,42 @@ function _Chat() {
                       onContextMenu={(e) => onRightClick(e, message)}
                       onDoubleClickCapture={() => {
                         if (!isMobileScreen) return;
-                        setUserInput(message.content);
+                        setUserInput(getMessageTextContent(message));
                       }}
                       fontSize={fontSize}
                       parentRef={scrollRef}
                       defaultShow={i >= messages.length - 6}
                     />
+                    {getMessageImages(message).length == 1 && (
+                      <img
+                        className={styles["chat-message-item-image"]}
+                        src={getMessageImages(message)[0]}
+                        alt=""
+                      />
+                    )}
+                    {getMessageImages(message).length > 1 && (
+                      <div
+                        className={styles["chat-message-item-images"]}
+                        style={
+                          {
+                            "--image-count": getMessageImages(message).length,
+                          } as React.CSSProperties
+                        }
+                      >
+                        {getMessageImages(message).map((image, index) => {
+                          return (
+                            <img
+                              className={
+                                styles["chat-message-item-image-multi"]
+                              }
+                              key={index}
+                              src={image}
+                              alt=""
+                            />
+                          );
+                        })}
+                      </div>
+                    )}
                   </div>
 
                   <div className={styles["chat-message-action-date"]}>
@@ -1266,9 +1400,13 @@ function _Chat() {
         <PromptHints prompts={promptHints} onPromptSelect={onPromptSelect} />
 
         <ChatActions
+          uploadImage={uploadImage}
+          setAttachImages={setAttachImages}
+          setUploading={setUploading}
           showPromptModal={() => setShowPromptModal(true)}
           scrollToBottom={scrollToBottom}
           hitBottom={hitBottom}
+          uploading={uploading}
           showPromptHints={() => {
             // Click again to close
             if (promptHints.length > 0) {
@@ -1281,8 +1419,16 @@ function _Chat() {
             onSearch("");
           }}
         />
-        <div className={styles["chat-input-panel-inner"]}>
+        <label
+          className={`${styles["chat-input-panel-inner"]} ${
+            attachImages.length != 0
+              ? styles["chat-input-panel-inner-attach"]
+              : ""
+          }`}
+          htmlFor="chat-input"
+        >
           <textarea
+            id="chat-input"
             ref={inputRef}
             className={styles["chat-input"]}
             placeholder={Locale.Chat.Input(submitKey)}
@@ -1297,6 +1443,29 @@ function _Chat() {
               fontSize: config.fontSize,
             }}
           />
+          {attachImages.length != 0 && (
+            <div className={styles["attach-images"]}>
+              {attachImages.map((image, index) => {
+                return (
+                  <div
+                    key={index}
+                    className={styles["attach-image"]}
+                    style={{ backgroundImage: `url("${image}")` }}
+                  >
+                    <div className={styles["attach-image-mask"]}>
+                      <DeleteImageButton
+                        deleteImage={() => {
+                          setAttachImages(
+                            attachImages.filter((_, i) => i !== index),
+                          );
+                        }}
+                      />
+                    </div>
+                  </div>
+                );
+              })}
+            </div>
+          )}
           <IconButton
             icon={<SendWhiteIcon />}
             text={Locale.Chat.Send}
@@ -1304,7 +1473,7 @@ function _Chat() {
             type="primary"
             onClick={() => doSubmit(userInput)}
           />
-        </div>
+        </label>
       </div>
 
       {showExport && (

+ 56 - 3
app/components/exporter.module.scss

@@ -94,6 +94,7 @@
 
   button {
     flex-grow: 1;
+
     &:not(:last-child) {
       margin-right: 10px;
     }
@@ -190,6 +191,59 @@
         pre {
           overflow: hidden;
         }
+
+        .message-image {
+          width: 100%;
+          margin-top: 10px;
+        }
+
+        .message-images {
+          display: grid;
+          justify-content: left;
+          grid-gap: 10px;
+          grid-template-columns: repeat(var(--image-count), auto);
+          margin-top: 10px;
+        }
+
+        @media screen and (max-width: 600px) {
+          $image-width: calc(calc(100vw/2)/var(--image-count));
+
+          .message-image-multi {
+            width: $image-width;
+            height: $image-width;
+          }
+
+          .message-image {
+            max-width: calc(100vw/3*2);
+          }
+        }
+
+        @media screen and (min-width: 600px) {
+          $max-image-width: calc(900px/3*2/var(--image-count));
+          $image-width: calc(80vw/3*2/var(--image-count));
+
+          .message-image-multi {
+            width: $image-width;
+            height: $image-width;
+            max-width: $max-image-width;
+            max-height: $max-image-width;
+          }
+
+          .message-image {
+            max-width: calc(100vw/3*2);
+          }
+        }
+
+        .message-image-multi {
+          object-fit: cover;
+        }
+
+        .message-image,
+        .message-image-multi {
+          box-sizing: border-box;
+          border-radius: 10px;
+          border: rgba($color: #888, $alpha: 0.2) 1px solid;
+        }
       }
 
       &-assistant {
@@ -213,6 +267,5 @@
     }
   }
 
-  .default-theme {
-  }
-}
+  .default-theme {}
+}

+ 40 - 5
app/components/exporter.tsx

@@ -12,7 +12,12 @@ import {
   showToast,
 } from "./ui-lib";
 import { IconButton } from "./button";
-import { copyToClipboard, downloadAs, useMobileScreen } from "../utils";
+import {
+  copyToClipboard,
+  downloadAs,
+  getMessageImages,
+  useMobileScreen,
+} from "../utils";
 
 import CopyIcon from "../icons/copy.svg";
 import LoadingIcon from "../icons/three-dots.svg";
@@ -34,6 +39,7 @@ import { prettyObject } from "../utils/format";
 import { EXPORT_MESSAGE_CLASS_NAME, ModelProvider } from "../constant";
 import { getClientConfig } from "../config/client";
 import { ClientApi } from "../client/api";
+import { getMessageTextContent } from "../utils";
 
 const Markdown = dynamic(async () => (await import("./markdown")).Markdown, {
   loading: () => <LoadingIcon />,
@@ -287,7 +293,7 @@ export function RenderExport(props: {
           id={`${m.role}:${i}`}
           className={EXPORT_MESSAGE_CLASS_NAME}
         >
-          <Markdown content={m.content} defaultShow />
+          <Markdown content={getMessageTextContent(m)} defaultShow />
         </div>
       ))}
     </div>
@@ -580,10 +586,37 @@ export function ImagePreviewer(props: {
 
               <div className={styles["body"]}>
                 <Markdown
-                  content={m.content}
+                  content={getMessageTextContent(m)}
                   fontSize={config.fontSize}
                   defaultShow
                 />
+                {getMessageImages(m).length == 1 && (
+                  <img
+                    key={i}
+                    src={getMessageImages(m)[0]}
+                    alt="message"
+                    className={styles["message-image"]}
+                  />
+                )}
+                {getMessageImages(m).length > 1 && (
+                  <div
+                    className={styles["message-images"]}
+                    style={
+                      {
+                        "--image-count": getMessageImages(m).length,
+                      } as React.CSSProperties
+                    }
+                  >
+                    {getMessageImages(m).map((src, i) => (
+                      <img
+                        key={i}
+                        src={src}
+                        alt="message"
+                        className={styles["message-image-multi"]}
+                      />
+                    ))}
+                  </div>
+                )}
               </div>
             </div>
           );
@@ -602,8 +635,10 @@ export function MarkdownPreviewer(props: {
     props.messages
       .map((m) => {
         return m.role === "user"
-          ? `## ${Locale.Export.MessageFromYou}:\n${m.content}`
-          : `## ${Locale.Export.MessageFromChatGPT}:\n${m.content.trim()}`;
+          ? `## ${Locale.Export.MessageFromYou}:\n${getMessageTextContent(m)}`
+          : `## ${Locale.Export.MessageFromChatGPT}:\n${getMessageTextContent(
+              m,
+            ).trim()}`;
       })
       .join("\n\n");
 

+ 21 - 4
app/components/mask.tsx

@@ -22,7 +22,7 @@ import {
   useAppConfig,
   useChatStore,
 } from "../store";
-import { ROLES } from "../client/api";
+import { MultimodalContent, ROLES } from "../client/api";
 import {
   Input,
   List,
@@ -38,7 +38,12 @@ import { useNavigate } from "react-router-dom";
 
 import chatStyle from "./chat.module.scss";
 import { useEffect, useState } from "react";
-import { copyToClipboard, downloadAs, readFromFile } from "../utils";
+import {
+  copyToClipboard,
+  downloadAs,
+  getMessageImages,
+  readFromFile,
+} from "../utils";
 import { Updater } from "../typing";
 import { ModelConfigList } from "./model-config";
 import { FileName, Path } from "../constant";
@@ -50,6 +55,7 @@ import {
   Draggable,
   OnDragEndResponder,
 } from "@hello-pangea/dnd";
+import { getMessageTextContent } from "../utils";
 
 // drag and drop helper function
 function reorder<T>(list: T[], startIndex: number, endIndex: number): T[] {
@@ -244,7 +250,7 @@ function ContextPromptItem(props: {
         </>
       )}
       <Input
-        value={props.prompt.content}
+        value={getMessageTextContent(props.prompt)}
         type="text"
         className={chatStyle["context-content"]}
         rows={focusingInput ? 5 : 1}
@@ -289,7 +295,18 @@ export function ContextPrompts(props: {
   };
 
   const updateContextPrompt = (i: number, prompt: ChatMessage) => {
-    props.updateContext((context) => (context[i] = prompt));
+    props.updateContext((context) => {
+      const images = getMessageImages(context[i]);
+      context[i] = prompt;
+      if (images.length > 0) {
+        const text = getMessageTextContent(context[i]);
+        const newContext: MultimodalContent[] = [{ type: "text", text }];
+        for (const img of images) {
+          newContext.push({ type: "image_url", image_url: { url: img } });
+        }
+        context[i].content = newContext;
+      }
+    });
   };
 
   const onDragEnd: OnDragEndResponder = (result) => {

+ 5 - 2
app/components/message-selector.tsx

@@ -7,6 +7,7 @@ import { MaskAvatar } from "./mask";
 import Locale from "../locales";
 
 import styles from "./message-selector.module.scss";
+import { getMessageTextContent } from "../utils";
 
 function useShiftRange() {
   const [startIndex, setStartIndex] = useState<number>();
@@ -103,7 +104,9 @@ export function MessageSelector(props: {
     const searchResults = new Set<string>();
     if (text.length > 0) {
       messages.forEach((m) =>
-        m.content.includes(text) ? searchResults.add(m.id!) : null,
+        getMessageTextContent(m).includes(text)
+          ? searchResults.add(m.id!)
+          : null,
       );
     }
     setSearchIds(searchResults);
@@ -219,7 +222,7 @@ export function MessageSelector(props: {
                   {new Date(m.date).toLocaleString()}
                 </div>
                 <div className={`${styles["content"]} one-line`}>
-                  {m.content}
+                  {getMessageTextContent(m)}
                 </div>
               </div>
 

+ 11 - 0
app/constant.ts

@@ -88,6 +88,7 @@ export const Azure = {
 export const Google = {
   ExampleEndpoint: "https://generativelanguage.googleapis.com/",
   ChatPath: "v1beta/models/gemini-pro:generateContent",
+  VisionChatPath: "v1beta/models/gemini-pro-vision:generateContent",
 
   // /api/openai/v1/chat/completions
 };
@@ -103,6 +104,7 @@ Latex block: $$e=mc^2$$
 `;
 
 export const SUMMARIZE_MODEL = "gpt-3.5-turbo";
+export const GEMINI_SUMMARIZE_MODEL = "gemini-pro";
 
 export const KnowledgeCutOffDate: Record<string, string> = {
   default: "2021-09",
@@ -278,6 +280,15 @@ export const DEFAULT_MODELS = [
       providerType: "google",
     },
   },
+  {
+    name: "gemini-pro-vision",
+    available: true,
+    provider: {
+      id: "google",
+      providerName: "Google",
+      providerType: "google",
+    },
+  },
 ] as const;
 
 export const CHAT_PAGE_SIZE = 15;

Plik diff jest za duży
+ 0 - 0
app/icons/image.svg


+ 1 - 0
app/icons/loading.svg

@@ -0,0 +1 @@
+<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="#fff" style=""><rect id="backgroundrect" width="100%" height="100%" x="0" y="0" fill="none" stroke="none" style="" class="" /><g class="currentLayer" style=""><title>Layer 1</title><circle cx="4" cy="8" r="1.926" fill="#333" id="svg_1" class=""><animate attributeName="r" begin="0s" calcMode="linear" dur="0.8s" from="2" repeatCount="indefinite" to="2" values="2;1.2;2" /><animate attributeName="fill-opacity" begin="0s" calcMode="linear" dur="0.8s" from="1" repeatCount="indefinite" to="1" values="1;.5;1" /></circle><circle cx="8" cy="8" r="1.2736" fill="#333" fill-opacity=".3" id="svg_2" class=""><animate attributeName="r" begin="0s" calcMode="linear" dur="0.8s" from="1.2" repeatCount="indefinite" to="1.2" values="1.2;2;1.2" /><animate attributeName="fill-opacity" begin="0s" calcMode="linear" dur="0.8s" from=".5" repeatCount="indefinite" to=".5" values=".5;1;.5" /></circle><circle cx="12" cy="8" r="1.926" fill="#333" id="svg_3" class=""><animate attributeName="r" begin="0s" calcMode="linear" dur="0.8s" from="2" repeatCount="indefinite" to="2" values="2;1.2;2" /><animate attributeName="fill-opacity" begin="0s" calcMode="linear" dur="0.8s" from="1" repeatCount="indefinite" to="1" values="1;.5;1" /></circle></g></svg>

+ 2 - 2
app/locales/cn.ts

@@ -63,6 +63,7 @@ const cn = {
       Masks: "所有面具",
       Clear: "清除聊天",
       Settings: "对话设置",
+      UploadImage: "上传图片",
     },
     Rename: "重命名对话",
     Typing: "正在输入…",
@@ -315,8 +316,7 @@ const cn = {
       Google: {
         ApiKey: {
           Title: "API 密钥",
-          SubTitle:
-            "从 Google AI 获取您的 API 密钥",
+          SubTitle: "从 Google AI 获取您的 API 密钥",
           Placeholder: "输入您的 Google AI Studio API 密钥",
         },
 

+ 2 - 2
app/locales/en.ts

@@ -65,6 +65,7 @@ const en: LocaleType = {
       Masks: "Masks",
       Clear: "Clear Context",
       Settings: "Settings",
+      UploadImage: "Upload Images",
     },
     Rename: "Rename Chat",
     Typing: "Typing…",
@@ -322,8 +323,7 @@ const en: LocaleType = {
       Google: {
         ApiKey: {
           Title: "API Key",
-          SubTitle:
-            "Obtain your API Key from Google AI",
+          SubTitle: "Obtain your API Key from Google AI",
           Placeholder: "Enter your Google AI Studio API Key",
         },
 

+ 39 - 9
app/store/chat.ts

@@ -1,4 +1,4 @@
-import { trimTopic } from "../utils";
+import { trimTopic, getMessageTextContent } from "../utils";
 
 import Locale, { getLang } from "../locales";
 import { showToast } from "../components/ui-lib";
@@ -12,8 +12,9 @@ import {
   ModelProvider,
   StoreKey,
   SUMMARIZE_MODEL,
+  GEMINI_SUMMARIZE_MODEL,
 } from "../constant";
-import { ClientApi, RequestMessage } from "../client/api";
+import { ClientApi, RequestMessage, MultimodalContent } from "../client/api";
 import { ChatControllerPool } from "../client/controller";
 import { prettyObject } from "../utils/format";
 import { estimateTokenLength } from "../utils/token";
@@ -84,11 +85,20 @@ function createEmptySession(): ChatSession {
 
 function getSummarizeModel(currentModel: string) {
   // if it is using gpt-* models, force to use 3.5 to summarize
-  return currentModel.startsWith("gpt") ? SUMMARIZE_MODEL : currentModel;
+  if (currentModel.startsWith("gpt")) {
+    return SUMMARIZE_MODEL;
+  }
+  if (currentModel.startsWith("gemini-pro")) {
+    return GEMINI_SUMMARIZE_MODEL;
+  }
+  return currentModel;
 }
 
 function countMessages(msgs: ChatMessage[]) {
-  return msgs.reduce((pre, cur) => pre + estimateTokenLength(cur.content), 0);
+  return msgs.reduce(
+    (pre, cur) => pre + estimateTokenLength(getMessageTextContent(cur)),
+    0,
+  );
 }
 
 function fillTemplateWith(input: string, modelConfig: ModelConfig) {
@@ -280,16 +290,36 @@ export const useChatStore = createPersistStore(
         get().summarizeSession();
       },
 
-      async onUserInput(content: string) {
+      async onUserInput(content: string, attachImages?: string[]) {
         const session = get().currentSession();
         const modelConfig = session.mask.modelConfig;
 
         const userContent = fillTemplateWith(content, modelConfig);
         console.log("[User Input] after template: ", userContent);
 
-        const userMessage: ChatMessage = createMessage({
+        let mContent: string | MultimodalContent[] = userContent;
+
+        if (attachImages && attachImages.length > 0) {
+          mContent = [
+            {
+              type: "text",
+              text: userContent,
+            },
+          ];
+          mContent = mContent.concat(
+            attachImages.map((url) => {
+              return {
+                type: "image_url",
+                image_url: {
+                  url: url,
+                },
+              };
+            }),
+          );
+        }
+        let userMessage: ChatMessage = createMessage({
           role: "user",
-          content: userContent,
+          content: mContent,
         });
 
         const botMessage: ChatMessage = createMessage({
@@ -307,7 +337,7 @@ export const useChatStore = createPersistStore(
         get().updateCurrentSession((session) => {
           const savedUserMessage = {
             ...userMessage,
-            content,
+            content: mContent,
           };
           session.messages = session.messages.concat([
             savedUserMessage,
@@ -461,7 +491,7 @@ export const useChatStore = createPersistStore(
         ) {
           const msg = messages[i];
           if (!msg || msg.isError) continue;
-          tokenCount += estimateTokenLength(msg.content);
+          tokenCount += estimateTokenLength(getMessageTextContent(msg));
           reversedRecentMessages.push(msg);
         }
 

+ 92 - 12
app/utils.ts

@@ -1,12 +1,16 @@
 import { useEffect, useState } from "react";
 import { showToast } from "./components/ui-lib";
 import Locale from "./locales";
+import { RequestMessage } from "./client/api";
+import { DEFAULT_MODELS } from "./constant";
 
 export function trimTopic(topic: string) {
   // Fix an issue where double quotes still show in the Indonesian language
   // This will remove the specified punctuation from the end of the string
   // and also trim quotes from both the start and end if they exist.
-  return topic.replace(/^["“”]+|["“”]+$/g, "").replace(/[,。!?”“"、,.!?]*$/, "");
+  return topic
+    .replace(/^["“”]+|["“”]+$/g, "")
+    .replace(/[,。!?”“"、,.!?]*$/, "");
 }
 
 export async function copyToClipboard(text: string) {
@@ -40,8 +44,8 @@ export async function downloadAs(text: string, filename: string) {
       defaultPath: `${filename}`,
       filters: [
         {
-          name: `${filename.split('.').pop()} files`,
-          extensions: [`${filename.split('.').pop()}`],
+          name: `${filename.split(".").pop()} files`,
+          extensions: [`${filename.split(".").pop()}`],
         },
         {
           name: "All Files",
@@ -54,7 +58,7 @@ export async function downloadAs(text: string, filename: string) {
       try {
         await window.__TAURI__.fs.writeBinaryFile(
           result,
-          new Uint8Array([...text].map((c) => c.charCodeAt(0)))
+          new Uint8Array([...text].map((c) => c.charCodeAt(0))),
         );
         showToast(Locale.Download.Success);
       } catch (error) {
@@ -69,16 +73,59 @@ export async function downloadAs(text: string, filename: string) {
       "href",
       "data:text/plain;charset=utf-8," + encodeURIComponent(text),
     );
-  element.setAttribute("download", filename);
+    element.setAttribute("download", filename);
 
-  element.style.display = "none";
-  document.body.appendChild(element);
+    element.style.display = "none";
+    document.body.appendChild(element);
 
-  element.click();
+    element.click();
 
-  document.body.removeChild(element);
+    document.body.removeChild(element);
+  }
 }
+
+export function compressImage(file: File, maxSize: number): Promise<string> {
+  return new Promise((resolve, reject) => {
+    const reader = new FileReader();
+    reader.onload = (readerEvent: any) => {
+      const image = new Image();
+      image.onload = () => {
+        let canvas = document.createElement("canvas");
+        let ctx = canvas.getContext("2d");
+        let width = image.width;
+        let height = image.height;
+        let quality = 0.9;
+        let dataUrl;
+
+        do {
+          canvas.width = width;
+          canvas.height = height;
+          ctx?.clearRect(0, 0, canvas.width, canvas.height);
+          ctx?.drawImage(image, 0, 0, width, height);
+          dataUrl = canvas.toDataURL("image/jpeg", quality);
+
+          if (dataUrl.length < maxSize) break;
+
+          if (quality > 0.5) {
+            // Prioritize quality reduction
+            quality -= 0.1;
+          } else {
+            // Then reduce the size
+            width *= 0.9;
+            height *= 0.9;
+          }
+        } while (dataUrl.length > maxSize);
+
+        resolve(dataUrl);
+      };
+      image.onerror = reject;
+      image.src = readerEvent.target.result;
+    };
+    reader.onerror = reject;
+    reader.readAsDataURL(file);
+  });
 }
+
 export function readFromFile() {
   return new Promise<string>((res, rej) => {
     const fileInput = document.createElement("input");
@@ -212,8 +259,41 @@ export function getCSSVar(varName: string) {
 export function isMacOS(): boolean {
   if (typeof window !== "undefined") {
     let userAgent = window.navigator.userAgent.toLocaleLowerCase();
-    const macintosh = /iphone|ipad|ipod|macintosh/.test(userAgent)
-    return !!macintosh
+    const macintosh = /iphone|ipad|ipod|macintosh/.test(userAgent);
+    return !!macintosh;
+  }
+  return false;
+}
+
+export function getMessageTextContent(message: RequestMessage) {
+  if (typeof message.content === "string") {
+    return message.content;
+  }
+  for (const c of message.content) {
+    if (c.type === "text") {
+      return c.text ?? "";
+    }
+  }
+  return "";
+}
+
+export function getMessageImages(message: RequestMessage): string[] {
+  if (typeof message.content === "string") {
+    return [];
+  }
+  const urls: string[] = [];
+  for (const c of message.content) {
+    if (c.type === "image_url") {
+      urls.push(c.image_url?.url ?? "");
+    }
   }
-  return false
+  return urls;
+}
+
+export function isVisionModel(model: string) {
+  return (
+    model.startsWith("gpt-4-vision") ||
+    model.startsWith("gemini-pro-vision") ||
+    !DEFAULT_MODELS.find((m) => m.name == model)
+  );
 }

Niektóre pliki nie zostały wyświetlone z powodu dużej ilości zmienionych plików