diff --git a/server_fn/src/client.rs b/server_fn/src/client.rs index a6c413895..34766aa86 100644 --- a/server_fn/src/client.rs +++ b/server_fn/src/client.rs @@ -42,7 +42,7 @@ pub trait Client { Output = Result< ( impl Stream> + Send + 'static, - impl Sink> + Send + 'static, + impl Sink + Send + 'static, ), Error, >, @@ -62,8 +62,8 @@ pub mod browser { response::browser::BrowserResponse, }; use bytes::Bytes; - use futures::{Sink, SinkExt, StreamExt, TryStreamExt}; - use gloo_net::websocket::{events::CloseEvent, Message, WebSocketError}; + use futures::{Sink, SinkExt, StreamExt}; + use gloo_net::websocket::{Message, WebSocketError}; use send_wrapper::SendWrapper; use std::future::Future; @@ -115,7 +115,7 @@ pub mod browser { impl futures::Stream> + Send + 'static, - impl futures::Sink> + Send + 'static, + impl futures::Sink + Send + 'static, ), Error, >, @@ -131,18 +131,23 @@ pub mod browser { })?; let (sink, stream) = websocket.split(); - let stream = stream - .map_err(|err| { + let stream = stream.map(|message| match message { + Ok(message) => { + crate::deserialize_result::( + match message { + Message::Text(text) => Bytes::from(text), + Message::Bytes(bytes) => Bytes::from(bytes), + }, + ) + } + Err(err) => { web_sys::console::error_1(&err.to_string().into()); - OutputStreamError::from_server_fn_error( + Err(OutputStreamError::from_server_fn_error( ServerFnErrorErr::Request(err.to_string()), ) - .ser() - }) - .map_ok(move |msg| match msg { - Message::Text(text) => Bytes::from(text), - Message::Bytes(bytes) => Bytes::from(bytes), - }); + .ser()) + } + }); let stream = SendWrapper::new(stream); struct SendWrapperSink { @@ -195,29 +200,11 @@ pub mod browser { } } - let sink = - sink.with(|message: Result| async move { - match message { - Ok(message) => Ok(Message::Bytes(message.into())), - Err(err) => { - let err = InputStreamError::de(err); - let formatted_err = format!("{:?}", err); - web_sys::console::error_1( - &js_sys::JsString::from( - formatted_err.clone(), - ), - ); - const CLOSE_CODE_ERROR: u16 = 1011; - Err(WebSocketError::ConnectionClose( - CloseEvent { - code: CLOSE_CODE_ERROR, - reason: formatted_err, - was_clean: true, - }, - )) - } - } - }); + let sink = sink.with(|message: Bytes| async move { + Ok::(Message::Bytes( + message.into(), + )) + }); let sink = SendWrapperSink::new(Box::pin(sink)); Ok((stream, sink)) @@ -246,13 +233,19 @@ pub mod reqwest { /// Implements [`Client`] for a request made by [`reqwest`]. pub struct ReqwestClient; - impl Client for ReqwestClient { + impl< + Error: FromServerFnError, + InputStreamError: FromServerFnError, + OutputStreamError: FromServerFnError, + > Client for ReqwestClient + { type Request = Request; type Response = Response; fn send( req: Self::Request, - ) -> impl Future> + Send { + ) -> impl Future> + Send + { CLIENT.execute(req).map_err(|e| { ServerFnErrorErr::Request(e.to_string()).into_app_error() }) @@ -262,12 +255,10 @@ pub mod reqwest { path: &str, ) -> Result< ( - impl futures::Stream> - + Send - + 'static, - impl futures::Sink> + Send + 'static, + impl futures::Stream> + Send + 'static, + impl futures::Sink + Send + 'static, ), - E, + Error, > { let mut websocket_server_url = get_server_url().to_string(); if let Some(postfix) = websocket_server_url.strip_prefix("http://") @@ -281,7 +272,7 @@ pub mod reqwest { let url = format!("{}{}", websocket_server_url, path); let (ws_stream, _) = tokio_tungstenite::connect_async(url).await.map_err(|e| { - E::from_server_fn_error(ServerFnErrorErr::Request( + Error::from_server_fn_error(ServerFnErrorErr::Request( e.to_string(), )) })?; @@ -290,26 +281,21 @@ pub mod reqwest { Ok(( read.map(|msg| match msg { - Ok(msg) => Ok(msg.into_data()), - Err(e) => Err(E::from_server_fn_error( + Ok(msg) => crate::deserialize_result::( + msg.into_data(), + ), + Err(e) => Err(OutputStreamError::from_server_fn_error( ServerFnErrorErr::Request(e.to_string()), ) .ser()), }), - write.with(|msg: Result| async move { - match msg { - Ok(msg) => { - Ok(tokio_tungstenite::tungstenite::Message::Binary( - msg, - )) - } - Err(err) => { - let err = E::de(err); - Err(tokio_tungstenite::tungstenite::Error::Io( - std::io::Error::other(format!("{:?}", err)), - )) - } - } + write.with(|msg: Bytes| async move { + Ok::< + tokio_tungstenite::tungstenite::Message, + tokio_tungstenite::tungstenite::Error, + >( + tokio_tungstenite::tungstenite::Message::Binary(msg) + ) }), )) } diff --git a/server_fn/src/lib.rs b/server_fn/src/lib.rs index f006e9f3f..705e26266 100644 --- a/server_fn/src/lib.rs +++ b/server_fn/src/lib.rs @@ -136,6 +136,7 @@ use base64::{engine::general_purpose::STANDARD_NO_PAD, DecodeError, Engine}; // re-exported to make it possible to implement a custom Client without adding a separate // dependency on `bytes` pub use bytes::Bytes; +use bytes::{BufMut, BytesMut}; use client::Client; use codec::{Encoding, FromReq, FromRes, IntoReq, IntoRes}; #[doc(hidden)] @@ -656,14 +657,17 @@ where let output = server_fn(input.into()).await?; - let output = output.stream.map(|output| match output { - Ok(output) => OutputEncoding::encode(&output).map_err(|e| { - OutputStreamError::from_server_fn_error( - ServerFnErrorErr::Serialization(e.to_string()), - ) - .ser() - }), - Err(err) => Err(err.ser()), + let output = output.stream.map(|output| { + let result = match output { + Ok(output) => OutputEncoding::encode(&output).map_err(|e| { + OutputStreamError::from_server_fn_error( + ServerFnErrorErr::Serialization(e.to_string()), + ) + .ser() + }), + Err(err) => Err(err.ser()), + }; + serialize_result(result) }); Server::spawn(async move { @@ -695,23 +699,21 @@ where pin_mut!(input); pin_mut!(sink); while let Some(input) = input.stream.next().await { - if sink - .send( - input - .and_then(|input| { - InputEncoding::encode(&input).map_err(|e| { - InputStreamError::from_server_fn_error( - ServerFnErrorErr::Serialization( - e.to_string(), - ), - ) - }) - }) - .map_err(|e| e.ser()), - ) - .await - .is_err() - { + let result = match input { + Ok(input) => { + InputEncoding::encode(&input).map_err(|e| { + InputStreamError::from_server_fn_error( + ServerFnErrorErr::Serialization( + e.to_string(), + ), + ) + .ser() + }) + } + Err(err) => Err(err.ser()), + }; + let result = serialize_result(result); + if sink.send(result).await.is_err() { break; } } @@ -740,6 +742,53 @@ where } } +/// Serializes a Result into a single Bytes instance. +/// Format: [tag: u8][content: Bytes] +/// - Tag 0: Ok variant +/// - Tag 1: Err variant +pub(crate) fn serialize_result(result: Result) -> Bytes { + match result { + Ok(bytes) => { + let mut buf = BytesMut::with_capacity(1 + bytes.len()); + buf.put_u8(0); // Tag for Ok variant + buf.extend_from_slice(&bytes); + buf.freeze() + } + Err(bytes) => { + let mut buf = BytesMut::with_capacity(1 + bytes.len()); + buf.put_u8(1); // Tag for Err variant + buf.extend_from_slice(&bytes); + buf.freeze() + } + } +} + +/// Deserializes a Bytes instance back into a Result. +pub(crate) fn deserialize_result( + bytes: Bytes, +) -> Result { + if bytes.is_empty() { + return Err(E::from_server_fn_error( + ServerFnErrorErr::Deserialization("Data is empty".into()), + ) + .ser()); + } + + let tag = bytes[0]; + let content = bytes.slice(1..); + + match tag { + 0 => Ok(content), + 1 => Err(content), + _ => { + return Err(E::from_server_fn_error( + ServerFnErrorErr::Deserialization("Invalid data tag".into()), + ) + .ser()) + } // Invalid tag + } +} + /// Encode format type pub enum Format { /// Binary representation @@ -1218,3 +1267,45 @@ pub mod mock { } } } + +#[cfg(test)] +mod tests { + + use super::*; + use crate::codec::JsonEncoding; + use serde::{Deserialize, Serialize}; + + #[derive(Debug, Serialize, Deserialize)] + enum TestError { + ServerFnError(ServerFnErrorErr), + } + + impl FromServerFnError for TestError { + type Encoder = JsonEncoding; + + fn from_server_fn_error(value: ServerFnErrorErr) -> Self { + Self::ServerFnError(value) + } + } + #[test] + fn test_result_serialization() { + // Test Ok variant + let ok_result: Result = + Ok(Bytes::from_static(b"success data")); + let serialized = serialize_result(ok_result); + let deserialized = deserialize_result::(serialized); + assert!(deserialized.is_ok()); + assert_eq!(deserialized.unwrap(), Bytes::from_static(b"success data")); + + // Test Err variant + let err_result: Result = + Err(Bytes::from_static(b"error details")); + let serialized = serialize_result(err_result); + let deserialized = deserialize_result::(serialized); + assert!(deserialized.is_err()); + assert_eq!( + deserialized.unwrap_err(), + Bytes::from_static(b"error details") + ); + } +} diff --git a/server_fn/src/request/actix.rs b/server_fn/src/request/actix.rs index 261ef7f16..a8f25eff7 100644 --- a/server_fn/src/request/actix.rs +++ b/server_fn/src/request/actix.rs @@ -117,7 +117,7 @@ where ) -> Result< ( impl Stream> + Send + 'static, - impl futures::Sink> + Send + 'static, + impl futures::Sink + Send + 'static, Self::WebsocketResponse, ), Error, @@ -133,7 +133,7 @@ where let (mut response_stream_tx, response_stream_rx) = futures::channel::mpsc::channel(2048); let (response_sink_tx, mut response_sink_rx) = - futures::channel::mpsc::channel::>(2048); + futures::channel::mpsc::channel::(2048); actix_web::rt::spawn(async move { loop { @@ -142,16 +142,9 @@ where let Some(incoming) = incoming else { break; }; - match incoming { - Ok(message) => { - if let Err(err) = session.binary(message).await { + if let Err(err) = session.binary(incoming).await { _ = response_stream_tx.start_send(Err(InputStreamError::from_server_fn_error(ServerFnErrorErr::Request(err.to_string())).ser())); } - } - Err(err) => { - _ = response_stream_tx.start_send(Err(err)); - } - } }, outgoing = msg_stream.next().fuse() => { let Some(outgoing) = outgoing else { @@ -166,11 +159,11 @@ where Ok(Message::Binary(bytes)) => { _ = response_stream_tx .start_send( - Ok(bytes), + crate::deserialize_result::(bytes), ); } Ok(Message::Text(text)) => { - _ = response_stream_tx.start_send(Ok(text.into_bytes())); + _ = response_stream_tx.start_send(crate::deserialize_result::(text.into_bytes())); } Ok(_other) => { } diff --git a/server_fn/src/request/axum.rs b/server_fn/src/request/axum.rs index 99f676666..5f2e4e64b 100644 --- a/server_fn/src/request/axum.rs +++ b/server_fn/src/request/axum.rs @@ -79,7 +79,7 @@ where ) -> Result< ( impl Stream> + Send + 'static, - impl Sink> + Send + 'static, + impl Sink + Send + 'static, Self::WebsocketResponse, ), Error, @@ -91,7 +91,7 @@ where futures::stream::Once< std::future::Ready>, >, - futures::sink::Drain>, + futures::sink::Drain, Self::WebsocketResponse, ), Error, @@ -117,9 +117,9 @@ where )) })?; let (mut outgoing_tx, outgoing_rx) = - futures::channel::mpsc::channel(2048); - let (incoming_tx, mut incoming_rx) = futures::channel::mpsc::channel::>(2048); + let (incoming_tx, mut incoming_rx) = + futures::channel::mpsc::channel::(2048); let response = upgrade .on_failed_upgrade({ let mut outgoing_tx = outgoing_tx.clone(); @@ -134,18 +134,11 @@ where let Some(incoming) = incoming else { break; }; - match incoming { - Ok(message) => { - if let Err(err) = session.send(Message::Binary(message)).await { - _ = outgoing_tx.start_send(Err(InputStreamError::from_server_fn_error(ServerFnErrorErr::Request(err.to_string())).ser())); - } - } - Err(err) => { - _ = outgoing_tx.start_send(Err(err)); - } + if let Err(err) = session.send(Message::Binary(incoming)).await { + _ = outgoing_tx.start_send(Err(InputStreamError::from_server_fn_error(ServerFnErrorErr::Request(err.to_string())).ser())); } }, - outgoing = session.recv().fuse() => { + outgoing = session.recv().fuse() => { let Some(outgoing) = outgoing else { break; }; @@ -153,14 +146,20 @@ where Ok(Message::Binary(bytes)) => { _ = outgoing_tx .start_send( - Ok(bytes), + crate::deserialize_result::(bytes), ); } Ok(Message::Text(text)) => { - _ = outgoing_tx.start_send(Ok(Bytes::from(text))); + _ = outgoing_tx.start_send(crate::deserialize_result::(Bytes::from(text))); + } + Ok(Message::Ping(bytes)) => { + if session.send(Message::Pong(bytes)).await.is_err() { + break; + } } Ok(_other) => {} Err(e) => { + println!("2"); _ = outgoing_tx.start_send(Err(InputStreamError::from_server_fn_error(ServerFnErrorErr::Response(e.to_string())).ser())); } } diff --git a/server_fn/src/request/generic.rs b/server_fn/src/request/generic.rs index 7da7beeb0..3e72adb27 100644 --- a/server_fn/src/request/generic.rs +++ b/server_fn/src/request/generic.rs @@ -79,7 +79,7 @@ where ) -> Result< ( impl Stream> + Send + 'static, - impl Sink> + Send + 'static, + impl Sink + Send + 'static, Self::WebsocketResponse, ), Error, @@ -87,7 +87,7 @@ where Err::< ( futures::stream::Once>>, - futures::sink::Drain>, + futures::sink::Drain, Self::WebsocketResponse, ), _, diff --git a/server_fn/src/request/mod.rs b/server_fn/src/request/mod.rs index a1c808f45..f10a1b361 100644 --- a/server_fn/src/request/mod.rs +++ b/server_fn/src/request/mod.rs @@ -360,7 +360,7 @@ where Output = Result< ( impl Stream> + Send + 'static, - impl Sink> + Send + 'static, + impl Sink + Send + 'static, Self::WebsocketResponse, ), Error, @@ -415,7 +415,7 @@ where ) -> Result< ( impl Stream> + Send + 'static, - impl Sink> + Send + 'static, + impl Sink + Send + 'static, Self::WebsocketResponse, ), Error, @@ -424,7 +424,7 @@ where Err::< ( futures::stream::Once>>, - futures::sink::Drain>, + futures::sink::Drain, Self::WebsocketResponse, ), _,