Jelajahi Sumber

using stream: schema to fetch in App

lloydzhou 1 tahun lalu
induk
melakukan
2d920f7ccc
7 mengubah file dengan 204 tambahan dan 122 penghapusan
  1. 1 0
      app/global.d.ts
  2. 1 40
      app/utils.ts
  3. 100 0
      app/utils/stream.ts
  4. 1 35
      src-tauri/Cargo.lock
  5. 1 0
      src-tauri/Cargo.toml
  6. 4 47
      src-tauri/src/main.rs
  7. 96 0
      src-tauri/src/stream.rs

+ 1 - 0
app/global.d.ts

@@ -12,6 +12,7 @@ declare module "*.svg";
 
 declare interface Window {
   __TAURI__?: {
+    convertFileSrc(url: string, protocol?: string): string;
     writeText(text: string): Promise<void>;
     invoke(command: string, payload?: Record<string, unknown>): Promise<any>;
     dialog: {

+ 1 - 40
app/utils.ts

@@ -3,6 +3,7 @@ import { showToast } from "./components/ui-lib";
 import Locale from "./locales";
 import { RequestMessage } from "./client/api";
 import { ServiceProvider } from "./constant";
+import { fetch } from "./utils/stream";
 
 export function trimTopic(topic: string) {
   // Fix an issue where double quotes still show in the Indonesian language
@@ -286,46 +287,6 @@ export function showPlugins(provider: ServiceProvider, model: string) {
   return false;
 }
 
-export function fetch(
-  url: string,
-  options?: Record<string, unknown>,
-): Promise<any> {
-  if (window.__TAURI__) {
-    const tauriUri = window.__TAURI__.convertFileSrc(url, "sse");
-    return window.fetch(tauriUri, options).then((r) => {
-      // 1. create response,
-      // TODO using event to get status and statusText and headers
-      const { status, statusText } = r;
-      const { readable, writable } = new TransformStream();
-      const res = new Response(readable, { status, statusText });
-      // 2. call fetch_read_body multi times, and write to Response.body
-      const writer = writable.getWriter();
-      let unlisten;
-      window.__TAURI__.event
-        .listen("sse-response", (e) => {
-          const { id, payload } = e;
-          console.log("event", id, payload);
-          writer.ready.then(() => {
-            if (payload !== 0) {
-              writer.write(new Uint8Array(payload));
-            } else {
-              writer.releaseLock();
-              writable.close();
-              unlisten && unlisten();
-            }
-          });
-        })
-        .then((u) => (unlisten = u));
-      return res;
-    });
-  }
-  return window.fetch(url, options);
-}
-
-if (undefined !== window) {
-  window.tauriFetch = fetch;
-}
-
 export function adapter(config: Record<string, unknown>) {
   const { baseURL, url, params, ...rest } = config;
   const path = baseURL ? `${baseURL}${url}` : url;

+ 100 - 0
app/utils/stream.ts

@@ -0,0 +1,100 @@
+// using tauri register_uri_scheme_protocol, register `stream:` protocol
+// see src-tauri/src/stream.rs, and src-tauri/src/main.rs
+// 1. window.fetch(`stream://localhost/${fetchUrl}`), get request_id
+// 2. listen event: `stream-response` multi times to get response headers and body
+
+type ResponseEvent = {
+  id: number;
+  payload: {
+    request_id: number;
+    status?: number;
+    error?: string;
+    name?: string;
+    value?: string;
+    chunk?: number[];
+  };
+};
+
+export function fetch(url: string, options?: RequestInit): Promise<any> {
+  if (window.__TAURI__) {
+    const tauriUri = window.__TAURI__.convertFileSrc(url, "stream");
+    const { signal, ...rest } = options || {};
+    return window
+      .fetch(tauriUri, rest)
+      .then((r) => r.text())
+      .then((rid) => parseInt(rid))
+      .then((request_id: number) => {
+        // 1. using event to get status and statusText and headers, and resolve it
+        let resolve: Function | undefined;
+        let reject: Function | undefined;
+        let status: number;
+        let writable: WritableStream | undefined;
+        let writer: WritableStreamDefaultWriter | undefined;
+        const headers = new Headers();
+        let unlisten: Function | undefined;
+
+        if (signal) {
+          signal.addEventListener("abort", () => {
+            // Reject the promise with the abort reason.
+            unlisten && unlisten();
+            reject && reject(signal.reason);
+          });
+        }
+        // @ts-ignore 2. listen response multi times, and write to Response.body
+        window.__TAURI__.event
+          .listen("stream-response", (e: ResponseEvent) => {
+            const { id, payload } = e;
+            const {
+              request_id: rid,
+              status: _status,
+              name,
+              value,
+              error,
+              chunk,
+            } = payload;
+            if (request_id != rid) {
+              return;
+            }
+            /**
+             * 1. get status code
+             * 2. get headers
+             * 3. start get body, then resolve response
+             * 4. get body chunk
+             */
+            if (error) {
+              unlisten && unlisten();
+              return reject && reject(error);
+            } else if (_status) {
+              status = _status;
+            } else if (name && value) {
+              headers.append(name, value);
+            } else if (chunk) {
+              if (resolve) {
+                const ts = new TransformStream();
+                writable = ts.writable;
+                writer = writable.getWriter();
+                resolve(new Response(ts.readable, { status, headers }));
+                resolve = undefined;
+              }
+              writer &&
+                writer.ready.then(() => {
+                  writer && writer.write(new Uint8Array(chunk));
+                });
+            } else if (_status === 0) {
+              // end of body
+              unlisten && unlisten();
+              writer &&
+                writer.ready.then(() => {
+                  writer && writer.releaseLock();
+                  writable && writable.close();
+                });
+            }
+          })
+          .then((u: Function) => (unlisten = u));
+        return new Promise(
+          (_resolve, _reject) => ([resolve, reject] = [_resolve, _reject]),
+        );
+      });
+  }
+  return window.fetch(url, options);
+}

+ 1 - 35
src-tauri/Cargo.lock

@@ -1986,6 +1986,7 @@ checksum = "e4a24736216ec316047a1fc4252e27dabb04218aa4a3f37c6e7ddbf1f9782b54"
 name = "nextchat"
 version = "0.1.0"
 dependencies = [
+ "bytes",
  "futures-util",
  "percent-encoding",
  "reqwest",
@@ -2216,17 +2217,6 @@ dependencies = [
  "pin-project-lite",
 ]
 
-[[package]]
-name = "os_info"
-version = "3.8.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ae99c7fa6dd38c7cafe1ec085e804f8f555a2f8659b0dbe03f1f9963a9b51092"
-dependencies = [
- "log",
- "serde",
- "windows-sys 0.52.0",
-]
-
 [[package]]
 name = "overload"
 version = "0.1.1"
@@ -3251,19 +3241,6 @@ dependencies = [
  "unicode-ident",
 ]
 
-[[package]]
-name = "sys-locale"
-version = "0.2.4"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f8a11bd9c338fdba09f7881ab41551932ad42e405f61d01e8406baea71c07aee"
-dependencies = [
- "js-sys",
- "libc",
- "wasm-bindgen",
- "web-sys",
- "windows-sys 0.45.0",
-]
-
 [[package]]
 name = "system-configuration"
 version = "0.5.1"
@@ -3412,7 +3389,6 @@ dependencies = [
  "objc",
  "once_cell",
  "open",
- "os_info",
  "percent-encoding",
  "rand 0.8.5",
  "raw-window-handle",
@@ -3425,7 +3401,6 @@ dependencies = [
  "serde_repr",
  "serialize-to-javascript",
  "state",
- "sys-locale",
  "tar",
  "tauri-macros",
  "tauri-runtime",
@@ -4345,15 +4320,6 @@ dependencies = [
  "windows-targets 0.48.0",
 ]
 
-[[package]]
-name = "windows-sys"
-version = "0.52.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d"
-dependencies = [
- "windows-targets 0.52.0",
-]
-
 [[package]]
 name = "windows-targets"
 version = "0.42.2"

+ 1 - 0
src-tauri/Cargo.toml

@@ -41,6 +41,7 @@ tauri-plugin-window-state = { git = "https://github.com/tauri-apps/plugins-works
 percent-encoding = "2.3.1"
 reqwest = "0.11.18"
 futures-util = "0.3.30"
+bytes = "1.7.2"
 
 [features]
 # this feature is used for production builds or when `devPath` points to the filesystem and the built-in dev server is disabled.

+ 4 - 47
src-tauri/src/main.rs

@@ -1,57 +1,14 @@
 // Prevents additional console window on Windows in release, DO NOT REMOVE!!
 #![cfg_attr(not(debug_assertions), windows_subsystem = "windows")]
 
-use futures_util::{StreamExt};
-use reqwest::Client;
-use tauri::{ Manager};
-use tauri::http::{ResponseBuilder};
+mod stream;
 
 fn main() {
   tauri::Builder::default()
     .plugin(tauri_plugin_window_state::Builder::default().build())
-    .register_uri_scheme_protocol("sse", |app_handle, request| {
-      let path = request.uri().strip_prefix("sse://localhost/").unwrap();
-      let path = percent_encoding::percent_decode(path.as_bytes())
-        .decode_utf8_lossy()
-        .to_string();
-      // println!("path : {}", path);
-      let client = Client::new();
-      let window = app_handle.get_window("main").unwrap();
-      // send http request
-      let body = reqwest::Body::from(request.body().clone());
-      let response_future = client.request(request.method().clone(), path)
-        .headers(request.headers().clone())
-        .body(body).send();
-
-      // get response and emit to client
-      tauri::async_runtime::spawn(async move {
-        let res = response_future.await;
-
-        match res {
-          Ok(res) => {
-            let mut stream = res.bytes_stream();
-
-            while let Some(chunk) = stream.next().await {
-              match chunk {
-                Ok(bytes) => {
-                  window.emit("sse-response", bytes).unwrap();
-                }
-                Err(err) => {
-                  println!("Error: {:?}", err);
-                }
-              }
-            }
-            window.emit("sse-response", 0).unwrap();
-          }
-          Err(err) => {
-            println!("Error: {:?}", err);
-          }
-        }
-      });
-      ResponseBuilder::new()
-        .header("Access-Control-Allow-Origin", "*")
-        .status(200).body("OK".into())
-      })
+    .register_uri_scheme_protocol("stream", move |app_handle, request| {
+      stream::stream(app_handle, request)
+    })
     .run(tauri::generate_context!())
     .expect("error while running tauri application");
 }

+ 96 - 0
src-tauri/src/stream.rs

@@ -0,0 +1,96 @@
+
+use std::error::Error;
+use futures_util::{StreamExt};
+use reqwest::Client;
+use tauri::{ Manager, AppHandle };
+use tauri::http::{Request, ResponseBuilder};
+use tauri::http::Response;
+
+static mut REQUEST_COUNTER: u32 = 0;
+
+#[derive(Clone, serde::Serialize)]
+pub struct ErrorPayload {
+  request_id: u32,
+  error: String,
+}
+
+#[derive(Clone, serde::Serialize)]
+pub struct StatusPayload {
+  request_id: u32,
+  status: u16,
+}
+
+#[derive(Clone, serde::Serialize)]
+pub struct HeaderPayload {
+  request_id: u32,
+  name: String,
+  value: String,
+}
+
+#[derive(Clone, serde::Serialize)]
+pub struct ChunkPayload {
+  request_id: u32,
+  chunk: bytes::Bytes,
+}
+
+pub fn stream(app_handle: &AppHandle, request: &Request) -> Result<Response, Box<dyn Error>> {
+  let mut request_id = 0;
+  let event_name = "stream-response";
+  unsafe {
+    REQUEST_COUNTER += 1;
+    request_id = REQUEST_COUNTER;
+  }
+  let path = request.uri().to_string().replace("stream://localhost/", "").replace("http://stream.localhost/", "");
+  let path = percent_encoding::percent_decode(path.as_bytes())
+    .decode_utf8_lossy()
+    .to_string();
+  // println!("path : {}", path);
+  let client = Client::new();
+  let handle = app_handle.app_handle();
+  // send http request
+  let body = reqwest::Body::from(request.body().clone());
+  let response_future = client.request(request.method().clone(), path)
+    .headers(request.headers().clone())
+    .body(body).send();
+
+  // get response and emit to client
+  tauri::async_runtime::spawn(async move {
+    let res = response_future.await;
+
+    match res {
+      Ok(res) => {
+        handle.emit_all(event_name, StatusPayload{ request_id, status: res.status().as_u16() }).unwrap();
+        for (name, value) in res.headers() {
+          handle.emit_all(event_name, HeaderPayload {
+            request_id,
+            name: name.to_string(),
+            value: std::str::from_utf8(value.as_bytes()).unwrap().to_string()
+          }).unwrap();
+        }
+        let mut stream = res.bytes_stream();
+
+        while let Some(chunk) = stream.next().await {
+          match chunk {
+            Ok(bytes) => {
+              handle.emit_all(event_name, ChunkPayload{ request_id, chunk: bytes }).unwrap();
+            }
+            Err(err) => {
+              println!("Error: {:?}", err);
+            }
+          }
+        }
+        handle.emit_all(event_name, StatusPayload { request_id, status: 0 }).unwrap();
+      }
+      Err(err) => {
+        println!("Error: {:?}", err.source().expect("REASON").to_string());
+        handle.emit_all(event_name, ErrorPayload {
+          request_id,
+          error: err.source().expect("REASON").to_string()
+        }).unwrap();
+      }
+    }
+  });
+  return ResponseBuilder::new()
+    .header("Access-Control-Allow-Origin", "*")
+    .status(200).body(request_id.to_string().into())
+}