fix: send/receive websocket data

This commit is contained in:
Saber Haj Rabiee
2025-04-15 07:43:47 -07:00
parent 6d5ab73594
commit 30c445a419
6 changed files with 187 additions and 118 deletions

View File

@@ -42,7 +42,7 @@ pub trait Client<Error, InputStreamError = Error, OutputStreamError = Error> {
Output = Result<
(
impl Stream<Item = Result<Bytes, Bytes>> + Send + 'static,
impl Sink<Result<Bytes, Bytes>> + Send + 'static,
impl Sink<Bytes> + Send + 'static,
),
Error,
>,
@@ -62,8 +62,8 @@ pub mod browser {
response::browser::BrowserResponse,
};
use bytes::Bytes;
use futures::{Sink, SinkExt, StreamExt, TryStreamExt};
use gloo_net::websocket::{events::CloseEvent, Message, WebSocketError};
use futures::{Sink, SinkExt, StreamExt};
use gloo_net::websocket::{Message, WebSocketError};
use send_wrapper::SendWrapper;
use std::future::Future;
@@ -115,7 +115,7 @@ pub mod browser {
impl futures::Stream<Item = Result<Bytes, Bytes>>
+ Send
+ 'static,
impl futures::Sink<Result<Bytes, Bytes>> + Send + 'static,
impl futures::Sink<Bytes> + Send + 'static,
),
Error,
>,
@@ -131,18 +131,23 @@ pub mod browser {
})?;
let (sink, stream) = websocket.split();
let stream = stream
.map_err(|err| {
let stream = stream.map(|message| match message {
Ok(message) => {
crate::deserialize_result::<OutputStreamError>(
match message {
Message::Text(text) => Bytes::from(text),
Message::Bytes(bytes) => Bytes::from(bytes),
},
)
}
Err(err) => {
web_sys::console::error_1(&err.to_string().into());
OutputStreamError::from_server_fn_error(
Err(OutputStreamError::from_server_fn_error(
ServerFnErrorErr::Request(err.to_string()),
)
.ser()
})
.map_ok(move |msg| match msg {
Message::Text(text) => Bytes::from(text),
Message::Bytes(bytes) => Bytes::from(bytes),
});
.ser())
}
});
let stream = SendWrapper::new(stream);
struct SendWrapperSink<S> {
@@ -195,29 +200,11 @@ pub mod browser {
}
}
let sink =
sink.with(|message: Result<Bytes, Bytes>| async move {
match message {
Ok(message) => Ok(Message::Bytes(message.into())),
Err(err) => {
let err = InputStreamError::de(err);
let formatted_err = format!("{:?}", err);
web_sys::console::error_1(
&js_sys::JsString::from(
formatted_err.clone(),
),
);
const CLOSE_CODE_ERROR: u16 = 1011;
Err(WebSocketError::ConnectionClose(
CloseEvent {
code: CLOSE_CODE_ERROR,
reason: formatted_err,
was_clean: true,
},
))
}
}
});
let sink = sink.with(|message: Bytes| async move {
Ok::<Message, WebSocketError>(Message::Bytes(
message.into(),
))
});
let sink = SendWrapperSink::new(Box::pin(sink));
Ok((stream, sink))
@@ -246,13 +233,19 @@ pub mod reqwest {
/// Implements [`Client`] for a request made by [`reqwest`].
pub struct ReqwestClient;
impl<E: FromServerFnError + Send + 'static> Client<E> for ReqwestClient {
impl<
Error: FromServerFnError,
InputStreamError: FromServerFnError,
OutputStreamError: FromServerFnError,
> Client<Error, InputStreamError, OutputStreamError> for ReqwestClient
{
type Request = Request;
type Response = Response;
fn send(
req: Self::Request,
) -> impl Future<Output = Result<Self::Response, E>> + Send {
) -> impl Future<Output = Result<Self::Response, Error>> + Send
{
CLIENT.execute(req).map_err(|e| {
ServerFnErrorErr::Request(e.to_string()).into_app_error()
})
@@ -262,12 +255,10 @@ pub mod reqwest {
path: &str,
) -> Result<
(
impl futures::Stream<Item = Result<bytes::Bytes, Bytes>>
+ Send
+ 'static,
impl futures::Sink<Result<bytes::Bytes, Bytes>> + Send + 'static,
impl futures::Stream<Item = Result<Bytes, Bytes>> + Send + 'static,
impl futures::Sink<Bytes> + Send + 'static,
),
E,
Error,
> {
let mut websocket_server_url = get_server_url().to_string();
if let Some(postfix) = websocket_server_url.strip_prefix("http://")
@@ -281,7 +272,7 @@ pub mod reqwest {
let url = format!("{}{}", websocket_server_url, path);
let (ws_stream, _) =
tokio_tungstenite::connect_async(url).await.map_err(|e| {
E::from_server_fn_error(ServerFnErrorErr::Request(
Error::from_server_fn_error(ServerFnErrorErr::Request(
e.to_string(),
))
})?;
@@ -290,26 +281,21 @@ pub mod reqwest {
Ok((
read.map(|msg| match msg {
Ok(msg) => Ok(msg.into_data()),
Err(e) => Err(E::from_server_fn_error(
Ok(msg) => crate::deserialize_result::<OutputStreamError>(
msg.into_data(),
),
Err(e) => Err(OutputStreamError::from_server_fn_error(
ServerFnErrorErr::Request(e.to_string()),
)
.ser()),
}),
write.with(|msg: Result<Bytes, Bytes>| async move {
match msg {
Ok(msg) => {
Ok(tokio_tungstenite::tungstenite::Message::Binary(
msg,
))
}
Err(err) => {
let err = E::de(err);
Err(tokio_tungstenite::tungstenite::Error::Io(
std::io::Error::other(format!("{:?}", err)),
))
}
}
write.with(|msg: Bytes| async move {
Ok::<
tokio_tungstenite::tungstenite::Message,
tokio_tungstenite::tungstenite::Error,
>(
tokio_tungstenite::tungstenite::Message::Binary(msg)
)
}),
))
}

View File

@@ -136,6 +136,7 @@ use base64::{engine::general_purpose::STANDARD_NO_PAD, DecodeError, Engine};
// re-exported to make it possible to implement a custom Client without adding a separate
// dependency on `bytes`
pub use bytes::Bytes;
use bytes::{BufMut, BytesMut};
use client::Client;
use codec::{Encoding, FromReq, FromRes, IntoReq, IntoRes};
#[doc(hidden)]
@@ -656,14 +657,17 @@ where
let output = server_fn(input.into()).await?;
let output = output.stream.map(|output| match output {
Ok(output) => OutputEncoding::encode(&output).map_err(|e| {
OutputStreamError::from_server_fn_error(
ServerFnErrorErr::Serialization(e.to_string()),
)
.ser()
}),
Err(err) => Err(err.ser()),
let output = output.stream.map(|output| {
let result = match output {
Ok(output) => OutputEncoding::encode(&output).map_err(|e| {
OutputStreamError::from_server_fn_error(
ServerFnErrorErr::Serialization(e.to_string()),
)
.ser()
}),
Err(err) => Err(err.ser()),
};
serialize_result(result)
});
Server::spawn(async move {
@@ -695,23 +699,21 @@ where
pin_mut!(input);
pin_mut!(sink);
while let Some(input) = input.stream.next().await {
if sink
.send(
input
.and_then(|input| {
InputEncoding::encode(&input).map_err(|e| {
InputStreamError::from_server_fn_error(
ServerFnErrorErr::Serialization(
e.to_string(),
),
)
})
})
.map_err(|e| e.ser()),
)
.await
.is_err()
{
let result = match input {
Ok(input) => {
InputEncoding::encode(&input).map_err(|e| {
InputStreamError::from_server_fn_error(
ServerFnErrorErr::Serialization(
e.to_string(),
),
)
.ser()
})
}
Err(err) => Err(err.ser()),
};
let result = serialize_result(result);
if sink.send(result).await.is_err() {
break;
}
}
@@ -740,6 +742,53 @@ where
}
}
/// Serializes a Result<Bytes, Bytes> into a single Bytes instance.
/// Format: [tag: u8][content: Bytes]
/// - Tag 0: Ok variant
/// - Tag 1: Err variant
pub(crate) fn serialize_result(result: Result<Bytes, Bytes>) -> Bytes {
match result {
Ok(bytes) => {
let mut buf = BytesMut::with_capacity(1 + bytes.len());
buf.put_u8(0); // Tag for Ok variant
buf.extend_from_slice(&bytes);
buf.freeze()
}
Err(bytes) => {
let mut buf = BytesMut::with_capacity(1 + bytes.len());
buf.put_u8(1); // Tag for Err variant
buf.extend_from_slice(&bytes);
buf.freeze()
}
}
}
/// Deserializes a Bytes instance back into a Result<Bytes, Bytes>.
pub(crate) fn deserialize_result<E: FromServerFnError>(
bytes: Bytes,
) -> Result<Bytes, Bytes> {
if bytes.is_empty() {
return Err(E::from_server_fn_error(
ServerFnErrorErr::Deserialization("Data is empty".into()),
)
.ser());
}
let tag = bytes[0];
let content = bytes.slice(1..);
match tag {
0 => Ok(content),
1 => Err(content),
_ => {
return Err(E::from_server_fn_error(
ServerFnErrorErr::Deserialization("Invalid data tag".into()),
)
.ser())
} // Invalid tag
}
}
/// Encode format type
pub enum Format {
/// Binary representation
@@ -1218,3 +1267,45 @@ pub mod mock {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::codec::JsonEncoding;
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize)]
enum TestError {
ServerFnError(ServerFnErrorErr),
}
impl FromServerFnError for TestError {
type Encoder = JsonEncoding;
fn from_server_fn_error(value: ServerFnErrorErr) -> Self {
Self::ServerFnError(value)
}
}
#[test]
fn test_result_serialization() {
// Test Ok variant
let ok_result: Result<Bytes, Bytes> =
Ok(Bytes::from_static(b"success data"));
let serialized = serialize_result(ok_result);
let deserialized = deserialize_result::<TestError>(serialized);
assert!(deserialized.is_ok());
assert_eq!(deserialized.unwrap(), Bytes::from_static(b"success data"));
// Test Err variant
let err_result: Result<Bytes, Bytes> =
Err(Bytes::from_static(b"error details"));
let serialized = serialize_result(err_result);
let deserialized = deserialize_result::<TestError>(serialized);
assert!(deserialized.is_err());
assert_eq!(
deserialized.unwrap_err(),
Bytes::from_static(b"error details")
);
}
}

View File

@@ -117,7 +117,7 @@ where
) -> Result<
(
impl Stream<Item = Result<Bytes, Bytes>> + Send + 'static,
impl futures::Sink<Result<Bytes, Bytes>> + Send + 'static,
impl futures::Sink<Bytes> + Send + 'static,
Self::WebsocketResponse,
),
Error,
@@ -133,7 +133,7 @@ where
let (mut response_stream_tx, response_stream_rx) =
futures::channel::mpsc::channel(2048);
let (response_sink_tx, mut response_sink_rx) =
futures::channel::mpsc::channel::<Result<Bytes, Bytes>>(2048);
futures::channel::mpsc::channel::<Bytes>(2048);
actix_web::rt::spawn(async move {
loop {
@@ -142,16 +142,9 @@ where
let Some(incoming) = incoming else {
break;
};
match incoming {
Ok(message) => {
if let Err(err) = session.binary(message).await {
if let Err(err) = session.binary(incoming).await {
_ = response_stream_tx.start_send(Err(InputStreamError::from_server_fn_error(ServerFnErrorErr::Request(err.to_string())).ser()));
}
}
Err(err) => {
_ = response_stream_tx.start_send(Err(err));
}
}
},
outgoing = msg_stream.next().fuse() => {
let Some(outgoing) = outgoing else {
@@ -166,11 +159,11 @@ where
Ok(Message::Binary(bytes)) => {
_ = response_stream_tx
.start_send(
Ok(bytes),
crate::deserialize_result::<InputStreamError>(bytes),
);
}
Ok(Message::Text(text)) => {
_ = response_stream_tx.start_send(Ok(text.into_bytes()));
_ = response_stream_tx.start_send(crate::deserialize_result::<InputStreamError>(text.into_bytes()));
}
Ok(_other) => {
}

View File

@@ -79,7 +79,7 @@ where
) -> Result<
(
impl Stream<Item = Result<Bytes, Bytes>> + Send + 'static,
impl Sink<Result<Bytes, Bytes>> + Send + 'static,
impl Sink<Bytes> + Send + 'static,
Self::WebsocketResponse,
),
Error,
@@ -91,7 +91,7 @@ where
futures::stream::Once<
std::future::Ready<Result<Bytes, Bytes>>,
>,
futures::sink::Drain<Result<Bytes, Bytes>>,
futures::sink::Drain<Bytes>,
Self::WebsocketResponse,
),
Error,
@@ -117,9 +117,9 @@ where
))
})?;
let (mut outgoing_tx, outgoing_rx) =
futures::channel::mpsc::channel(2048);
let (incoming_tx, mut incoming_rx) =
futures::channel::mpsc::channel::<Result<Bytes, Bytes>>(2048);
let (incoming_tx, mut incoming_rx) =
futures::channel::mpsc::channel::<Bytes>(2048);
let response = upgrade
.on_failed_upgrade({
let mut outgoing_tx = outgoing_tx.clone();
@@ -134,18 +134,11 @@ where
let Some(incoming) = incoming else {
break;
};
match incoming {
Ok(message) => {
if let Err(err) = session.send(Message::Binary(message)).await {
_ = outgoing_tx.start_send(Err(InputStreamError::from_server_fn_error(ServerFnErrorErr::Request(err.to_string())).ser()));
}
}
Err(err) => {
_ = outgoing_tx.start_send(Err(err));
}
if let Err(err) = session.send(Message::Binary(incoming)).await {
_ = outgoing_tx.start_send(Err(InputStreamError::from_server_fn_error(ServerFnErrorErr::Request(err.to_string())).ser()));
}
},
outgoing = session.recv().fuse() => {
outgoing = session.recv().fuse() => {
let Some(outgoing) = outgoing else {
break;
};
@@ -153,14 +146,20 @@ where
Ok(Message::Binary(bytes)) => {
_ = outgoing_tx
.start_send(
Ok(bytes),
crate::deserialize_result::<InputStreamError>(bytes),
);
}
Ok(Message::Text(text)) => {
_ = outgoing_tx.start_send(Ok(Bytes::from(text)));
_ = outgoing_tx.start_send(crate::deserialize_result::<InputStreamError>(Bytes::from(text)));
}
Ok(Message::Ping(bytes)) => {
if session.send(Message::Pong(bytes)).await.is_err() {
break;
}
}
Ok(_other) => {}
Err(e) => {
println!("2");
_ = outgoing_tx.start_send(Err(InputStreamError::from_server_fn_error(ServerFnErrorErr::Response(e.to_string())).ser()));
}
}

View File

@@ -79,7 +79,7 @@ where
) -> Result<
(
impl Stream<Item = Result<Bytes, Bytes>> + Send + 'static,
impl Sink<Result<Bytes, Bytes>> + Send + 'static,
impl Sink<Bytes> + Send + 'static,
Self::WebsocketResponse,
),
Error,
@@ -87,7 +87,7 @@ where
Err::<
(
futures::stream::Once<std::future::Ready<Result<Bytes, Bytes>>>,
futures::sink::Drain<Result<Bytes, Bytes>>,
futures::sink::Drain<Bytes>,
Self::WebsocketResponse,
),
_,

View File

@@ -360,7 +360,7 @@ where
Output = Result<
(
impl Stream<Item = Result<Bytes, Bytes>> + Send + 'static,
impl Sink<Result<Bytes, Bytes>> + Send + 'static,
impl Sink<Bytes> + Send + 'static,
Self::WebsocketResponse,
),
Error,
@@ -415,7 +415,7 @@ where
) -> Result<
(
impl Stream<Item = Result<Bytes, Bytes>> + Send + 'static,
impl Sink<Result<Bytes, Bytes>> + Send + 'static,
impl Sink<Bytes> + Send + 'static,
Self::WebsocketResponse,
),
Error,
@@ -424,7 +424,7 @@ where
Err::<
(
futures::stream::Once<std::future::Ready<Result<Bytes, Bytes>>>,
futures::sink::Drain<Result<Bytes, Bytes>>,
futures::sink::Drain<Bytes>,
Self::WebsocketResponse,
),
_,