Dogtiti committed 1 year ago
commit f6e1f8398b

app/components/realtime-chat/realtime-chat.module.scss (+7 -0)

@@ -24,12 +24,19 @@
   .bottom-icons {
     display: flex;
     justify-content: space-between;
+    align-items: center;
     width: 100%;
     position: absolute;
     bottom: 20px;
     box-sizing: border-box;
     padding: 0 20px;
   }
+  .icon-center {
+    display: flex;
+    justify-content: center;
+    align-items: center;
+    gap: 4px;
+  }
 
   .icon-left,
   .icon-right {

app/components/realtime-chat/realtime-chat.tsx (+229 -15)

@@ -1,34 +1,220 @@
 import VoiceIcon from "@/app/icons/voice.svg";
 import VoiceOffIcon from "@/app/icons/voice-off.svg";
 import Close24Icon from "@/app/icons/close-24.svg";
+import PowerIcon from "@/app/icons/power.svg";
+
 import styles from "./realtime-chat.module.scss";
 import clsx from "clsx";
 
-import { useState, useRef, useCallback } from "react";
+import { useState, useRef, useCallback, useEffect } from "react";
 
 import { useAccessStore, useChatStore, ChatMessage } from "@/app/store";
 
+import { IconButton } from "@/app/components/button";
+
+import {
+  Modality,
+  RTClient,
+  RTInputAudioItem,
+  RTResponse,
+  TurnDetection,
+} from "rt-client";
+import { AudioHandler } from "@/app/lib/audio";
+
 interface RealtimeChatProps {
   onClose?: () => void;
   onStartVoice?: () => void;
   onPausedVoice?: () => void;
-  sampleRate?: number;
 }
 
 export function RealtimeChat({
   onClose,
   onStartVoice,
   onPausedVoice,
-  sampleRate = 24000,
 }: RealtimeChatProps) {
-  const [isVoicePaused, setIsVoicePaused] = useState(true);
-  const clientRef = useRef<null>(null);
   const currentItemId = useRef<string>("");
   const currentBotMessage = useRef<ChatMessage | null>();
   const currentUserMessage = useRef<ChatMessage | null>();
   const accessStore = useAccessStore.getState();
   const chatStore = useChatStore();
 
+  const [isRecording, setIsRecording] = useState(false);
+  const [isConnected, setIsConnected] = useState(false);
+  const [isConnecting, setIsConnecting] = useState(false);
+  const [modality, setModality] = useState("audio");
+  const [isAzure, setIsAzure] = useState(false);
+  const [endpoint, setEndpoint] = useState("");
+  const [deployment, setDeployment] = useState("");
+  const [useVAD, setUseVAD] = useState(true);
+
+  const clientRef = useRef<RTClient | null>(null);
+  const audioHandlerRef = useRef<AudioHandler | null>(null);
+
+  const apiKey = accessStore.openaiApiKey;
+
+  const handleConnect = async () => {
+    if (!isConnected) {
+      try {
+        setIsConnecting(true);
+        clientRef.current = isAzure
+          ? new RTClient(new URL(endpoint), { key: apiKey }, { deployment })
+          : new RTClient(
+              { key: apiKey },
+              { model: "gpt-4o-realtime-preview-2024-10-01" },
+            );
+        const modalities: Modality[] =
+          modality === "audio" ? ["text", "audio"] : ["text"];
+        const turnDetection: TurnDetection = useVAD
+          ? { type: "server_vad" }
+          : null;
+        clientRef.current.configure({
+          instructions: "Hi",
+          input_audio_transcription: { model: "whisper-1" },
+          turn_detection: turnDetection,
+          tools: [],
+          temperature: 0.9,
+          modalities,
+        });
+        startResponseListener();
+
+        setIsConnected(true);
+      } catch (error) {
+        console.error("Connection failed:", error);
+      } finally {
+        setIsConnecting(false);
+      }
+    } else {
+      await disconnect();
+    }
+  };
+
+  const disconnect = async () => {
+    if (clientRef.current) {
+      try {
+        await clientRef.current.close();
+        clientRef.current = null;
+        setIsConnected(false);
+      } catch (error) {
+        console.error("Disconnect failed:", error);
+      }
+    }
+  };
+
+  const startResponseListener = async () => {
+    if (!clientRef.current) return;
+
+    try {
+      for await (const serverEvent of clientRef.current.events()) {
+        if (serverEvent.type === "response") {
+          await handleResponse(serverEvent);
+        } else if (serverEvent.type === "input_audio") {
+          await handleInputAudio(serverEvent);
+        }
+      }
+    } catch (error) {
+      if (clientRef.current) {
+        console.error("Response iteration error:", error);
+      }
+    }
+  };
+
+  const handleResponse = async (response: RTResponse) => {
+    for await (const item of response) {
+      if (item.type === "message" && item.role === "assistant") {
+        const message = {
+          type: item.role,
+          content: "",
+        };
+        // setMessages((prevMessages) => [...prevMessages, message]);
+        for await (const content of item) {
+          if (content.type === "text") {
+            for await (const text of content.textChunks()) {
+              message.content += text;
+              //   setMessages((prevMessages) => {
+              //     prevMessages[prevMessages.length - 1].content = message.content;
+              //     return [...prevMessages];
+              //   });
+            }
+          } else if (content.type === "audio") {
+            const textTask = async () => {
+              for await (const text of content.transcriptChunks()) {
+                message.content += text;
+                // setMessages((prevMessages) => {
+                //   prevMessages[prevMessages.length - 1].content =
+                //     message.content;
+                //   return [...prevMessages];
+                // });
+              }
+            };
+            const audioTask = async () => {
+              audioHandlerRef.current?.startStreamingPlayback();
+              for await (const audio of content.audioChunks()) {
+                audioHandlerRef.current?.playChunk(audio);
+              }
+            };
+            await Promise.all([textTask(), audioTask()]);
+          }
+        }
+      }
+    }
+  };
+
+  const handleInputAudio = async (item: RTInputAudioItem) => {
+    audioHandlerRef.current?.stopStreamingPlayback();
+    await item.waitForCompletion();
+    // setMessages((prevMessages) => [
+    //   ...prevMessages,
+    //   {
+    //     type: "user",
+    //     content: item.transcription || "",
+    //   },
+    // ]);
+  };
+
+  const toggleRecording = async () => {
+    if (!isRecording && clientRef.current) {
+      try {
+        if (!audioHandlerRef.current) {
+          audioHandlerRef.current = new AudioHandler();
+          await audioHandlerRef.current.initialize();
+        }
+        await audioHandlerRef.current.startRecording(async (chunk) => {
+          await clientRef.current?.sendAudio(chunk);
+        });
+        setIsRecording(true);
+      } catch (error) {
+        console.error("Failed to start recording:", error);
+      }
+    } else if (audioHandlerRef.current) {
+      try {
+        audioHandlerRef.current.stopRecording();
+        if (!useVAD) {
+          const inputAudio = await clientRef.current?.commitAudio();
+          await handleInputAudio(inputAudio!);
+          await clientRef.current?.generateResponse();
+        }
+        setIsRecording(false);
+      } catch (error) {
+        console.error("Failed to stop recording:", error);
+      }
+    }
+  };
+
+  useEffect(() => {
+    const initAudioHandler = async () => {
+      const handler = new AudioHandler();
+      await handler.initialize();
+      audioHandlerRef.current = handler;
+    };
+
+    initAudioHandler().catch(console.error);
+
+    return () => {
+      disconnect();
+      audioHandlerRef.current?.close().catch(console.error);
+    };
+  }, []);
+
   //   useEffect(() => {
   //     if (
   //       clientRef.current?.getTurnDetectionType() === "server_vad" &&
@@ -223,12 +409,16 @@ export function RealtimeChat({
 
   const handleStartVoice = useCallback(() => {
     onStartVoice?.();
-    setIsVoicePaused(false);
+    handleConnect();
   }, []);
 
   const handlePausedVoice = () => {
     onPausedVoice?.();
-    setIsVoicePaused(true);
+  };
+
+  const handleClose = () => {
+    onClose?.();
+    disconnect();
   };
 
   return (
@@ -241,15 +431,39 @@ export function RealtimeChat({
         <div className={styles["icon-center"]}></div>
       </div>
       <div className={styles["bottom-icons"]}>
-        <div className={styles["icon-left"]}>
-          {isVoicePaused ? (
-            <VoiceOffIcon onClick={handleStartVoice} />
-          ) : (
-            <VoiceIcon onClick={handlePausedVoice} />
-          )}
+        <div>
+          <IconButton
+            icon={isRecording ? <VoiceOffIcon /> : <VoiceIcon />}
+            onClick={toggleRecording}
+            disabled={!isConnected}
+            bordered
+            shadow
+          />
+        </div>
+        <div className={styles["icon-center"]}>
+          <IconButton
+            icon={<PowerIcon />}
+            text={
+              isConnecting
+                ? "Connecting..."
+                : isConnected
+                ? "Disconnect"
+                : "Connect"
+            }
+            onClick={handleConnect}
+            disabled={isConnecting}
+            bordered
+            shadow
+          />
         </div>
-        <div className={styles["icon-right"]} onClick={onClose}>
-          <Close24Icon />
+        <div>
+          <IconButton
+            icon={<Close24Icon />}
+            onClick={handleClose}
+            bordered
+            shadow
+          />
         </div>
       </div>
     </div>
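
Note on the control flow above: handleConnect builds an RTClient (Azure endpoint plus deployment, or the OpenAI model), calls configure(), and then startResponseListener() drives a single async loop over client.events(), routing "response" events to handleResponse and "input_audio" events to handleInputAudio. A minimal standalone sketch of that loop, assuming the same rt-client API the diff imports (apiKey stands in for accessStore.openaiApiKey):

    import { RTClient } from "rt-client";

    // Sketch only: connect, configure, then consume the server event stream.
    const client = new RTClient(
      { key: apiKey },
      { model: "gpt-4o-realtime-preview-2024-10-01" },
    );
    client.configure({
      modalities: ["text", "audio"],
      turn_detection: { type: "server_vad" },
    });
    for await (const event of client.events()) {
      if (event.type === "response") {
        // assistant turn: iterate items, stream text and audio chunks
      } else if (event.type === "input_audio") {
        // user speech item: stop playback, wait for transcription
      }
    }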

app/icons/power.svg (+7 -0)

@@ -0,0 +1,7 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<svg width="24" height="24" viewBox="0 0 48 48" fill="none" xmlns="http://www.w3.org/2000/svg">
+    <path
+        d="M14.5 8C13.8406 8.37652 13.2062 8.79103 12.6 9.24051C11.5625 10.0097 10.6074 10.8814 9.75 11.8402C6.79377 15.1463 5 19.4891 5 24.2455C5 34.6033 13.5066 43 24 43C34.4934 43 43 34.6033 43 24.2455C43 19.4891 41.2062 15.1463 38.25 11.8402C37.3926 10.8814 36.4375 10.0097 35.4 9.24051C34.7938 8.79103 34.1594 8.37652 33.5 8"
+        stroke="#333" stroke-width="4" stroke-linecap="round" stroke-linejoin="round" />
+    <path d="M24 4V24" stroke="#333" stroke-width="4" stroke-linecap="round" stroke-linejoin="round" />
+</svg>

app/lib/audio.ts (+134 -0)

@@ -0,0 +1,134 @@
+export class AudioHandler {
+  private context: AudioContext;
+  private workletNode: AudioWorkletNode | null = null;
+  private stream: MediaStream | null = null;
+  private source: MediaStreamAudioSourceNode | null = null;
+  private readonly sampleRate = 24000;
+
+  private nextPlayTime: number = 0;
+  private isPlaying: boolean = false;
+  private playbackQueue: AudioBufferSourceNode[] = [];
+
+  constructor() {
+    this.context = new AudioContext({ sampleRate: this.sampleRate });
+  }
+
+  async initialize() {
+    await this.context.audioWorklet.addModule("/audio-processor.js");
+  }
+
+  async startRecording(onChunk: (chunk: Uint8Array) => void) {
+    try {
+      if (!this.workletNode) {
+        await this.initialize();
+      }
+
+      this.stream = await navigator.mediaDevices.getUserMedia({
+        audio: {
+          channelCount: 1,
+          sampleRate: this.sampleRate,
+          echoCancellation: true,
+          noiseSuppression: true,
+        },
+      });
+
+      await this.context.resume();
+      this.source = this.context.createMediaStreamSource(this.stream);
+      this.workletNode = new AudioWorkletNode(
+        this.context,
+        "audio-recorder-processor",
+      );
+
+      this.workletNode.port.onmessage = (event) => {
+        if (event.data.eventType === "audio") {
+          const float32Data = event.data.audioData;
+          const int16Data = new Int16Array(float32Data.length);
+
+          for (let i = 0; i < float32Data.length; i++) {
+            const s = Math.max(-1, Math.min(1, float32Data[i]));
+            int16Data[i] = s < 0 ? s * 0x8000 : s * 0x7fff;
+          }
+
+          const uint8Data = new Uint8Array(int16Data.buffer);
+          onChunk(uint8Data);
+        }
+      };
+
+      this.source.connect(this.workletNode);
+      this.workletNode.connect(this.context.destination);
+
+      this.workletNode.port.postMessage({ command: "START_RECORDING" });
+    } catch (error) {
+      console.error("Error starting recording:", error);
+      throw error;
+    }
+  }
+
+  stopRecording() {
+    if (!this.workletNode || !this.source || !this.stream) {
+      throw new Error("Recording not started");
+    }
+
+    this.workletNode.port.postMessage({ command: "STOP_RECORDING" });
+
+    this.workletNode.disconnect();
+    this.source.disconnect();
+    this.stream.getTracks().forEach((track) => track.stop());
+  }
+  startStreamingPlayback() {
+    this.isPlaying = true;
+    this.nextPlayTime = this.context.currentTime;
+  }
+
+  stopStreamingPlayback() {
+    this.isPlaying = false;
+    this.playbackQueue.forEach((source) => source.stop());
+    this.playbackQueue = [];
+  }
+
+  playChunk(chunk: Uint8Array) {
+    if (!this.isPlaying) return;
+
+    const int16Data = new Int16Array(chunk.buffer);
+
+    const float32Data = new Float32Array(int16Data.length);
+    for (let i = 0; i < int16Data.length; i++) {
+      float32Data[i] = int16Data[i] / (int16Data[i] < 0 ? 0x8000 : 0x7fff);
+    }
+
+    const audioBuffer = this.context.createBuffer(
+      1,
+      float32Data.length,
+      this.sampleRate,
+    );
+    audioBuffer.getChannelData(0).set(float32Data);
+
+    const source = this.context.createBufferSource();
+    source.buffer = audioBuffer;
+    source.connect(this.context.destination);
+
+    const chunkDuration = audioBuffer.length / this.sampleRate;
+
+    source.start(this.nextPlayTime);
+
+    this.playbackQueue.push(source);
+    source.onended = () => {
+      const index = this.playbackQueue.indexOf(source);
+      if (index > -1) {
+        this.playbackQueue.splice(index, 1);
+      }
+    };
+
+    this.nextPlayTime += chunkDuration;
+
+    if (this.nextPlayTime < this.context.currentTime) {
+      this.nextPlayTime = this.context.currentTime;
+    }
+  }
+  async close() {
+    this.workletNode?.disconnect();
+    this.source?.disconnect();
+    this.stream?.getTracks().forEach((track) => track.stop());
+    await this.context.close();
+  }
+}
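
A note on the sample format in startRecording and playChunk: PCM16 full scale is asymmetric (-32768..32767), so encoding scales negative floats by 0x8000 and positive floats by 0x7fff, and decoding divides by the matching constant. The same mapping, lifted out of the class as a self-contained round-trip check (the class relies on Int16Array assignment to truncate rather than Math.round):

    // Asymmetric PCM16 mapping, as used by AudioHandler above.
    function floatToPcm16(s: number): number {
      const c = Math.max(-1, Math.min(1, s)); // clamp to [-1, 1]
      return Math.round(c < 0 ? c * 0x8000 : c * 0x7fff);
    }
    function pcm16ToFloat(i: number): number {
      return i / (i < 0 ? 0x8000 : 0x7fff);
    }
    floatToPcm16(-1); // -32768
    floatToPcm16(1); // 32767
    pcm16ToFloat(floatToPcm16(0.5)); // ~0.5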

package.json (+2 -1)

@@ -52,7 +52,8 @@
     "sass": "^1.59.2",
     "spark-md5": "^3.0.2",
     "use-debounce": "^9.0.4",
-    "zustand": "^4.3.8"
+    "zustand": "^4.3.8",
+    "rt-client": "https://github.com/Azure-Samples/aoai-realtime-audio-sdk/releases/download/js/v0.5.0/rt-client-0.5.0.tgz"
   },
   "devDependencies": {
     "@tauri-apps/api": "^1.6.0",

public/audio-processor.js (+48 -0)

@@ -0,0 +1,48 @@
+// @ts-nocheck
+class AudioRecorderProcessor extends AudioWorkletProcessor {
+  constructor() {
+    super();
+    this.isRecording = false;
+    this.bufferSize = 2400; // 100ms at 24kHz
+    this.currentBuffer = [];
+
+    this.port.onmessage = (event) => {
+      if (event.data.command === "START_RECORDING") {
+        this.isRecording = true;
+      } else if (event.data.command === "STOP_RECORDING") {
+        this.isRecording = false;
+
+        if (this.currentBuffer.length > 0) {
+          this.sendBuffer();
+        }
+      }
+    };
+  }
+
+  sendBuffer() {
+    if (this.currentBuffer.length > 0) {
+      const audioData = new Float32Array(this.currentBuffer);
+      this.port.postMessage({
+        eventType: "audio",
+        audioData: audioData,
+      });
+      this.currentBuffer = [];
+    }
+  }
+
+  process(inputs) {
+    const input = inputs[0];
+    if (input.length > 0 && this.isRecording) {
+      const audioData = input[0];
+
+      this.currentBuffer.push(...audioData);
+
+      if (this.currentBuffer.length >= this.bufferSize) {
+        this.sendBuffer();
+      }
+    }
+    return true;
+  }
+}
+
+registerProcessor("audio-recorder-processor", AudioRecorderProcessor);
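
The processor accumulates the 128-frame render quanta it receives until bufferSize is reached (2400 samples, i.e. 100 ms at 24 kHz), posts one Float32Array over the port, and flushes any partial buffer on STOP_RECORDING. The main-thread wiring mirrors AudioHandler.startRecording above; a hedged sketch:

    // Sketch of the main-thread side (same names AudioHandler uses).
    const ctx = new AudioContext({ sampleRate: 24000 });
    await ctx.audioWorklet.addModule("/audio-processor.js");
    const node = new AudioWorkletNode(ctx, "audio-recorder-processor");
    node.port.onmessage = (e) => {
      if (e.data.eventType === "audio") {
        // e.data.audioData: Float32Array of ~2400 mono samples (100 ms)
      }
    };
    node.port.postMessage({ command: "START_RECORDING" });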

yarn.lock (+8 -2)

@@ -7455,6 +7455,12 @@ robust-predicates@^3.0.0:
   resolved "https://registry.npmmirror.com/robust-predicates/-/robust-predicates-3.0.1.tgz#ecde075044f7f30118682bd9fb3f123109577f9a"
   integrity sha512-ndEIpszUHiG4HtDsQLeIuMvRsDnn8c8rYStabochtUeCvfuvNptb5TUbVD68LRAILPX7p9nqQGh4xJgn3EHS/g==
 
+"rt-client@https://github.com/Azure-Samples/aoai-realtime-audio-sdk/releases/download/js/v0.5.0/rt-client-0.5.0.tgz":
+  version "0.5.0"
+  resolved "https://github.com/Azure-Samples/aoai-realtime-audio-sdk/releases/download/js/v0.5.0/rt-client-0.5.0.tgz#abf2e9a850201e3571b8d36830f77bc52af3de9b"
+  dependencies:
+    ws "^8.18.0"
+
 run-parallel@^1.1.9:
   version "1.2.0"
   resolved "https://registry.yarnpkg.com/run-parallel/-/run-parallel-1.2.0.tgz#66d1368da7bdf921eb9d95bd1a9229e7f21a43ee"
@@ -8498,9 +8504,9 @@ write-file-atomic@^4.0.2:
     imurmurhash "^0.1.4"
     signal-exit "^3.0.7"
 
-ws@^8.11.0:
+ws@^8.11.0, ws@^8.18.0:
   version "8.18.0"
-  resolved "https://registry.npmmirror.com/ws/-/ws-8.18.0.tgz#0d7505a6eafe2b0e712d232b42279f53bc289bbc"
+  resolved "https://registry.yarnpkg.com/ws/-/ws-8.18.0.tgz#0d7505a6eafe2b0e712d232b42279f53bc289bbc"
   integrity sha512-8VbfWfHLbbwu3+N6OKsOMpBdT4kXPDDB9cJk2bJ6mh9ucxdlnNvH1e+roYkKmN9Nxw2yjz7VzeO9oOz2zJ04Pw==
 
 xml-name-validator@^4.0.0: