Browse Source

using stream_fetch in App

lloydzhou 1 năm trước cách đây
mục cha
commit
3898c507c4
5 tập tin đã thay đổi với 156 bổ sung137 xóa
  1. 6 2
      app/utils.ts
  2. 2 0
      app/utils/chat.ts
  3. 73 79
      app/utils/stream.ts
  4. 1 3
      src-tauri/src/main.rs
  5. 74 53
      src-tauri/src/stream.rs

+ 6 - 2
app/utils.ts

@@ -288,12 +288,16 @@ export function showPlugins(provider: ServiceProvider, model: string) {
 }
 
 export function adapter(config: Record<string, unknown>) {
-  const { baseURL, url, params, ...rest } = config;
+  const { baseURL, url, params, method, data, ...rest } = config;
   const path = baseURL ? `${baseURL}${url}` : url;
   const fetchUrl = params
     ? `${path}?${new URLSearchParams(params as any).toString()}`
     : path;
-  return fetch(fetchUrl as string, rest)
+  return fetch(fetchUrl as string, {
+    ...rest,
+    method,
+    body: method.toUpperCase() == "GET" ? undefined : data,
+  })
     .then((res) => res.text())
     .then((data) => ({ data }));
 }

+ 2 - 0
app/utils/chat.ts

@@ -10,6 +10,7 @@ import {
   fetchEventSource,
 } from "@fortaine/fetch-event-source";
 import { prettyObject } from "./format";
+import { fetch as tauriFetch } from "./stream";
 
 export function compressImage(file: Blob, maxSize: number): Promise<string> {
   return new Promise((resolve, reject) => {
@@ -287,6 +288,7 @@ export function stream(
       REQUEST_TIMEOUT_MS,
     );
     fetchEventSource(chatPath, {
+      fetch: tauriFetch,
       ...chatPayload,
       async onopen(res) {
         clearTimeout(requestTimeoutId);

+ 73 - 79
app/utils/stream.ts

@@ -1,100 +1,94 @@
-// using tauri register_uri_scheme_protocol, register `stream:` protocol
+// using tauri command to send request
 // see src-tauri/src/stream.rs, and src-tauri/src/main.rs
-// 1. window.fetch(`stream://localhost/${fetchUrl}`), get request_id
-// 2. listen event: `stream-response` multi times to get response headers and body
+// 1. invoke('stream_fetch', {url, method, headers, body}), get response with headers.
+// 2. listen event: `stream-response` multi times to get body
 
 type ResponseEvent = {
   id: number;
   payload: {
     request_id: number;
     status?: number;
-    error?: string;
-    name?: string;
-    value?: string;
     chunk?: number[];
   };
 };
 
 export function fetch(url: string, options?: RequestInit): Promise<any> {
   if (window.__TAURI__) {
-    const tauriUri = window.__TAURI__.convertFileSrc(url, "stream");
-    const { signal, ...rest } = options || {};
-    return window
-      .fetch(tauriUri, rest)
-      .then((r) => r.text())
-      .then((rid) => parseInt(rid))
-      .then((request_id: number) => {
-        // 1. using event to get status and statusText and headers, and resolve it
-        let resolve: Function | undefined;
-        let reject: Function | undefined;
-        let status: number;
-        let writable: WritableStream | undefined;
-        let writer: WritableStreamDefaultWriter | undefined;
-        const headers = new Headers();
-        let unlisten: Function | undefined;
+    const { signal, method = "GET", headers = {}, body = [] } = options || {};
+    return window.__TAURI__
+      .invoke("stream_fetch", {
+        method,
+        url,
+        headers,
+        // TODO FormData
+        body:
+          typeof body === "string"
+            ? Array.from(new TextEncoder().encode(body))
+            : [],
+      })
+      .then(
+        (res: {
+          request_id: number;
+          status: number;
+          status_text: string;
+          headers: Record<string, string>;
+        }) => {
+          const { request_id, status, status_text: statusText, headers } = res;
+          console.log("send request_id", request_id, status, statusText);
+          let unlisten: Function | undefined;
+          const ts = new TransformStream();
+          const writer = ts.writable.getWriter();
 
-        if (signal) {
-          signal.addEventListener("abort", () => {
-            // Reject the promise with the abort reason.
+          const close = () => {
             unlisten && unlisten();
-            reject && reject(signal.reason);
+            writer.ready.then(() => {
+              try {
+                writer.releaseLock();
+              } catch (e) {
+                console.error(e);
+              }
+              ts.writable.close();
+            });
+          };
+
+          const response = new Response(ts.readable, {
+            status,
+            statusText,
+            headers,
           });
-        }
-        // @ts-ignore 2. listen response multi times, and write to Response.body
-        window.__TAURI__.event
-          .listen("stream-response", (e: ResponseEvent) => {
-            const { id, payload } = e;
-            const {
-              request_id: rid,
-              status: _status,
-              name,
-              value,
-              error,
-              chunk,
-            } = payload;
-            if (request_id != rid) {
-              return;
-            }
-            /**
-             * 1. get status code
-             * 2. get headers
-             * 3. start get body, then resolve response
-             * 4. get body chunk
-             */
-            if (error) {
-              unlisten && unlisten();
-              return reject && reject(error);
-            } else if (_status) {
-              status = _status;
-            } else if (name && value) {
-              headers.append(name, value);
-            } else if (chunk) {
-              if (resolve) {
-                const ts = new TransformStream();
-                writable = ts.writable;
-                writer = writable.getWriter();
-                resolve(new Response(ts.readable, { status, headers }));
-                resolve = undefined;
+          if (signal) {
+            signal.addEventListener("abort", () => close());
+          }
+          // @ts-ignore 2. listen response multi times, and write to Response.body
+          window.__TAURI__.event
+            .listen("stream-response", (e: ResponseEvent) => {
+              const { id, payload } = e;
+              const { request_id: rid, chunk, status } = payload;
+              if (request_id != rid) {
+                return;
+              }
+              if (chunk) {
+                writer &&
+                  writer.ready.then(() => {
+                    writer && writer.write(new Uint8Array(chunk));
+                  });
+              } else if (status === 0) {
+                // end of body
+                close();
               }
-              writer &&
-                writer.ready.then(() => {
-                  writer && writer.write(new Uint8Array(chunk));
-                });
-            } else if (_status === 0) {
-              // end of body
-              unlisten && unlisten();
-              writer &&
-                writer.ready.then(() => {
-                  writer && writer.releaseLock();
-                  writable && writable.close();
-                });
-            }
-          })
-          .then((u: Function) => (unlisten = u));
-        return new Promise(
-          (_resolve, _reject) => ([resolve, reject] = [_resolve, _reject]),
-        );
+            })
+            .then((u: Function) => (unlisten = u));
+          return response;
+        },
+      )
+      .catch((e) => {
+        console.error("stream error", e);
+        throw e;
       });
   }
   return window.fetch(url, options);
 }
+
+if (undefined !== window) {
+  window.tauriFetch = fetch;
+}

+ 1 - 3
src-tauri/src/main.rs

@@ -5,10 +5,8 @@ mod stream;
 
 fn main() {
   tauri::Builder::default()
+    .invoke_handler(tauri::generate_handler![stream::stream_fetch])
     .plugin(tauri_plugin_window_state::Builder::default().build())
-    .register_uri_scheme_protocol("stream", move |app_handle, request| {
-      stream::stream(app_handle, request)
-    })
     .run(tauri::generate_context!())
     .expect("error while running tauri application");
 }

+ 74 - 53
src-tauri/src/stream.rs

@@ -1,30 +1,25 @@
+//
+//
 
 use std::error::Error;
 use futures_util::{StreamExt};
 use reqwest::Client;
-use tauri::{ Manager, AppHandle };
-use tauri::http::{Request, ResponseBuilder};
-use tauri::http::Response;
+use reqwest::header::{HeaderName, HeaderMap};
 
 static mut REQUEST_COUNTER: u32 = 0;
 
 #[derive(Clone, serde::Serialize)]
-pub struct ErrorPayload {
-  request_id: u32,
-  error: String,
-}
-
-#[derive(Clone, serde::Serialize)]
-pub struct StatusPayload {
+pub struct StreamResponse {
   request_id: u32,
   status: u16,
+  status_text: String,
+  headers: HashMap<String, String>
 }
 
 #[derive(Clone, serde::Serialize)]
-pub struct HeaderPayload {
+pub struct EndPayload {
   request_id: u32,
-  name: String,
-  value: String,
+  status: u16,
 }
 
 #[derive(Clone, serde::Serialize)]
@@ -33,64 +28,90 @@ pub struct ChunkPayload {
   chunk: bytes::Bytes,
 }
 
-pub fn stream(app_handle: &AppHandle, request: &Request) -> Result<Response, Box<dyn Error>> {
+use std::collections::HashMap;
+
+#[derive(serde::Serialize)]
+pub struct CustomResponse {
+  message: String,
+  other_val: usize,
+}
+
+#[tauri::command]
+pub async fn stream_fetch(
+  window: tauri::Window,
+  method: String,
+  url: String,
+  headers: HashMap<String, String>,
+  body: Vec<u8>,
+) -> Result<StreamResponse, String> {
+
   let mut request_id = 0;
   let event_name = "stream-response";
   unsafe {
     REQUEST_COUNTER += 1;
     request_id = REQUEST_COUNTER;
   }
-  let path = request.uri().to_string().replace("stream://localhost/", "").replace("http://stream.localhost/", "");
-  let path = percent_encoding::percent_decode(path.as_bytes())
-    .decode_utf8_lossy()
-    .to_string();
-  // println!("path : {}", path);
-  let client = Client::new();
-  let handle = app_handle.app_handle();
-  // send http request
-  let body = reqwest::Body::from(request.body().clone());
-  let response_future = client.request(request.method().clone(), path)
-    .headers(request.headers().clone())
-    .body(body).send();
-
-  // get response and emit to client
-  tauri::async_runtime::spawn(async move {
-    let res = response_future.await;
-
-    match res {
-      Ok(res) => {
-        handle.emit_all(event_name, StatusPayload{ request_id, status: res.status().as_u16() }).unwrap();
-        for (name, value) in res.headers() {
-          handle.emit_all(event_name, HeaderPayload {
-            request_id,
-            name: name.to_string(),
-            value: std::str::from_utf8(value.as_bytes()).unwrap().to_string()
-          }).unwrap();
-        }
+
+  let mut _headers = HeaderMap::new();
+  for (key, value) in headers {
+      _headers.insert(key.parse::<HeaderName>().unwrap(), value.parse().unwrap());
+  }
+  let body = bytes::Bytes::from(body);
+
+  let response_future = Client::new().request(
+    method.parse::<reqwest::Method>().map_err(|err| format!("failed to parse method: {}", err))?,
+    url.parse::<reqwest::Url>().map_err(|err| format!("failed to parse url: {}", err))?
+  ).headers(_headers).body(body).send();
+
+  let res = response_future.await;
+  let response = match res {
+    Ok(res) => {
+      println!("Error: {:?}", res);
+      // get response and emit to client
+      // .register_uri_scheme_protocol("stream", move |app_handle, request| {
+      let mut headers = HashMap::new();
+      for (name, value) in res.headers() {
+        headers.insert(
+          name.as_str().to_string(),
+          std::str::from_utf8(value.as_bytes()).unwrap().to_string()
+        );
+      }
+      let status = res.status().as_u16();
+
+      tauri::async_runtime::spawn(async move {
         let mut stream = res.bytes_stream();
 
         while let Some(chunk) = stream.next().await {
           match chunk {
             Ok(bytes) => {
-              handle.emit_all(event_name, ChunkPayload{ request_id, chunk: bytes }).unwrap();
+              println!("chunk: {:?}", bytes);
+              window.emit(event_name, ChunkPayload{ request_id, chunk: bytes }).unwrap();
             }
             Err(err) => {
               println!("Error: {:?}", err);
             }
           }
         }
-        handle.emit_all(event_name, StatusPayload { request_id, status: 0 }).unwrap();
+        window.emit(event_name, EndPayload { request_id, status: 0 }).unwrap();
+      });
+
+      StreamResponse {
+        request_id,
+        status,
+        status_text: "OK".to_string(),
+        headers,
       }
-      Err(err) => {
-        println!("Error: {:?}", err.source().expect("REASON").to_string());
-        handle.emit_all(event_name, ErrorPayload {
-          request_id,
-          error: err.source().expect("REASON").to_string()
-        }).unwrap();
+    }
+    Err(err) => {
+      println!("Error: {:?}", err.source().expect("REASON").to_string());
+      StreamResponse {
+        request_id,
+        status: 599,
+        status_text: err.source().expect("REASON").to_string(),
+        headers: HashMap::new(),
       }
     }
-  });
-  return ResponseBuilder::new()
-    .header("Access-Control-Allow-Origin", "*")
-    .status(200).body(request_id.to_string().into())
+  };
+  Ok(response)
 }
+