feat: support for both Send and !Send Framed impls (#189)

Using GATs we achieve genericity over Send and Sync marker traits.
The `FramedRead` and `FramedWrite` traits can now be used in both
single-threaded and multi-threaded contexts.
This commit is contained in:
Benoît Cortier 2023-09-05 21:52:52 -04:00 committed by GitHub
parent 37ac7052aa
commit 783167f23d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 190 additions and 44 deletions

View file

@ -1,26 +1,27 @@
use std::io; use std::io;
use std::pin::Pin;
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use ironrdp_pdu::PduHint; use ironrdp_pdu::PduHint;
// TODO: use static async fn / return position impl trait in traits when stabiziled (https://github.com/rust-lang/rust/issues/91611) // TODO: investigate if we could use static async fn / return position impl trait in traits when stabilized:
// https://github.com/rust-lang/rust/issues/91611
pub trait FramedRead { pub trait FramedRead {
/// Reads from stream and fills internal buffer type ReadFut<'read>: std::future::Future<Output = io::Result<usize>> + 'read
fn read<'a>(
&'a mut self,
buf: &'a mut BytesMut,
) -> Pin<Box<dyn std::future::Future<Output = io::Result<usize>> + 'a>>
where where
Self: 'a; Self: 'read;
/// Reads from stream and fills internal buffer
fn read<'a>(&'a mut self, buf: &'a mut BytesMut) -> Self::ReadFut<'a>;
} }
pub trait FramedWrite { pub trait FramedWrite {
/// Writes an entire buffer into this stream. type WriteAllFut<'write>: std::future::Future<Output = io::Result<()>> + 'write
fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Pin<Box<dyn std::future::Future<Output = io::Result<()>> + 'a>>
where where
Self: 'a; Self: 'write;
/// Writes an entire buffer into this stream.
fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteAllFut<'a>;
} }
pub trait StreamWrapper: Sized { pub trait StreamWrapper: Sized {

View file

@ -73,7 +73,9 @@ pub struct Config {
pub no_server_pointer: bool, pub no_server_pointer: bool,
} }
pub trait State: Send + Sync + core::fmt::Debug { ironrdp_pdu::assert_impl!(Config: Send, Sync);
pub trait State: Send + Sync + core::fmt::Debug + 'static {
fn name(&self) -> &'static str; fn name(&self) -> &'static str;
fn is_terminal(&self) -> bool; fn is_terminal(&self) -> bool;
fn as_any(&self) -> &dyn Any; fn as_any(&self) -> &dyn Any;
@ -81,11 +83,11 @@ pub trait State: Send + Sync + core::fmt::Debug {
ironrdp_pdu::assert_obj_safe!(State); ironrdp_pdu::assert_obj_safe!(State);
pub fn state_downcast<T: State + Any>(state: &dyn State) -> Option<&T> { pub fn state_downcast<T: State>(state: &dyn State) -> Option<&T> {
state.as_any().downcast_ref() state.as_any().downcast_ref()
} }
pub fn state_is<T: State + Any>(state: &dyn State) -> bool { pub fn state_is<T: State>(state: &dyn State) -> bool {
state.as_any().is::<T>() state.as_any().is::<T>()
} }

View file

@ -1,12 +1,12 @@
#[rustfmt::skip] // do not re-order this pub use
pub use ironrdp_async::*;
use std::io; use std::io;
use std::pin::Pin; use std::pin::Pin;
use bytes::BytesMut; use bytes::BytesMut;
use futures_util::io::{AsyncRead, AsyncWrite}; use futures_util::io::{AsyncRead, AsyncWrite};
#[rustfmt::skip] // do not re-order this pub use
pub use ironrdp_async::*;
pub type FuturesFramed<S> = Framed<FuturesStream<S>>; pub type FuturesFramed<S> = Framed<FuturesStream<S>>;
pub struct FuturesStream<S> { pub struct FuturesStream<S> {
@ -35,15 +35,13 @@ impl<S> StreamWrapper for FuturesStream<S> {
impl<S> FramedRead for FuturesStream<S> impl<S> FramedRead for FuturesStream<S>
where where
S: Unpin + AsyncRead, S: Send + Sync + Unpin + AsyncRead,
{ {
fn read<'a>( type ReadFut<'read> = Pin<Box<dyn std::future::Future<Output = io::Result<usize>> + Send + Sync + 'read>>
&'a mut self,
buf: &'a mut BytesMut,
) -> Pin<Box<dyn std::future::Future<Output = io::Result<usize>> + 'a>>
where where
Self: 'a, Self: 'read;
{
fn read<'a>(&'a mut self, buf: &'a mut BytesMut) -> Self::ReadFut<'a> {
use futures_util::io::AsyncReadExt as _; use futures_util::io::AsyncReadExt as _;
Box::pin(async { Box::pin(async {
@ -59,12 +57,81 @@ where
impl<S> FramedWrite for FuturesStream<S> impl<S> FramedWrite for FuturesStream<S>
where where
S: Unpin + AsyncWrite, S: Send + Sync + Unpin + AsyncWrite,
{ {
fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Pin<Box<dyn std::future::Future<Output = io::Result<()>> + 'a>> type WriteAllFut<'write> = Pin<Box<dyn std::future::Future<Output = io::Result<()>> + Send + Sync + 'write>>
where where
Self: 'a, Self: 'write;
{
fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteAllFut<'a> {
use futures_util::io::AsyncWriteExt as _;
Box::pin(async {
self.inner.write_all(buf).await?;
self.inner.flush().await?;
Ok(())
})
}
}
pub type SingleThreadedFuturesFramed<S> = Framed<SingleThreadedFuturesStream<S>>;
pub struct SingleThreadedFuturesStream<S> {
inner: S,
}
impl<S> StreamWrapper for SingleThreadedFuturesStream<S> {
type InnerStream = S;
fn from_inner(stream: Self::InnerStream) -> Self {
Self { inner: stream }
}
fn into_inner(self) -> Self::InnerStream {
self.inner
}
fn get_inner(&self) -> &Self::InnerStream {
&self.inner
}
fn get_inner_mut(&mut self) -> &mut Self::InnerStream {
&mut self.inner
}
}
impl<S> FramedRead for SingleThreadedFuturesStream<S>
where
S: Unpin + AsyncRead,
{
type ReadFut<'read> = Pin<Box<dyn std::future::Future<Output = io::Result<usize>> + 'read>>
where
Self: 'read;
fn read<'a>(&'a mut self, buf: &'a mut BytesMut) -> Self::ReadFut<'a> {
use futures_util::io::AsyncReadExt as _;
Box::pin(async {
// NOTE(perf): tokio implementation is more efficient
let mut read_bytes = [0u8; 1024];
let len = self.inner.read(&mut read_bytes[..]).await?;
buf.extend_from_slice(&read_bytes[..len]);
Ok(len)
})
}
}
impl<S> FramedWrite for SingleThreadedFuturesStream<S>
where
S: Unpin + AsyncWrite,
{
type WriteAllFut<'write> = Pin<Box<dyn std::future::Future<Output = io::Result<()>> + 'write>>
where
Self: 'write;
fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteAllFut<'a> {
use futures_util::io::AsyncWriteExt as _; use futures_util::io::AsyncWriteExt as _;
Box::pin(async { Box::pin(async {

View file

@ -315,7 +315,7 @@ pub fn find_size(bytes: &[u8]) -> PduResult<Option<PduInfo>> {
} }
} }
pub trait PduHint: core::fmt::Debug { pub trait PduHint: Send + Sync + core::fmt::Debug + 'static {
/// Finds next PDU size by reading the next few bytes. /// Finds next PDU size by reading the next few bytes.
fn find_size(&self, bytes: &[u8]) -> PduResult<Option<usize>>; fn find_size(&self, bytes: &[u8]) -> PduResult<Option<usize>>;
} }

View file

@ -174,7 +174,7 @@ macro_rules! cast_int {
/// Asserts that the traits support dynamic dispatch. /// Asserts that the traits support dynamic dispatch.
/// ///
/// From <https://docs.rs/static_assertions/latest/src/static_assertions/assert_obj_safe.rs.html#72-76> /// From <https://docs.rs/static_assertions/1.1.0/src/static_assertions/assert_obj_safe.rs.html#72-76>
#[macro_export] #[macro_export]
macro_rules! assert_obj_safe { macro_rules! assert_obj_safe {
($($xs:path),+ $(,)?) => { ($($xs:path),+ $(,)?) => {
@ -182,6 +182,20 @@ macro_rules! assert_obj_safe {
}; };
} }
/// Asserts that the type implements _all_ of the given traits.
///
/// From <https://docs.rs/static_assertions/1.1.0/src/static_assertions/assert_impl.rs.html#113-121>
#[macro_export]
macro_rules! assert_impl {
($type:ty: $($trait:path),+ $(,)?) => {
const _: fn() = || {
// Only callable when `$type` implements all traits in `$($trait)+`.
fn assert_impl_all<T: ?Sized $(+ $trait)+>() {}
assert_impl_all::<$type>();
};
};
}
/// Implements additional traits for a plain old data structure (POD). /// Implements additional traits for a plain old data structure (POD).
#[macro_export] #[macro_export]
macro_rules! impl_pdu_pod { macro_rules! impl_pdu_pod {

View file

@ -1,8 +1,10 @@
#[rustfmt::skip] // do not re-order this pub use
pub use ironrdp_async::*;
use std::io; use std::io;
use std::pin::Pin; use std::pin::Pin;
use bytes::BytesMut; use bytes::BytesMut;
pub use ironrdp_async::*;
use tokio::io::{AsyncRead, AsyncWrite}; use tokio::io::{AsyncRead, AsyncWrite};
pub type TokioFramed<S> = Framed<TokioStream<S>>; pub type TokioFramed<S> = Framed<TokioStream<S>>;
@ -33,15 +35,13 @@ impl<S> StreamWrapper for TokioStream<S> {
impl<S> FramedRead for TokioStream<S> impl<S> FramedRead for TokioStream<S>
where where
S: Unpin + AsyncRead, S: Send + Sync + Unpin + AsyncRead,
{ {
fn read<'a>( type ReadFut<'read> = Pin<Box<dyn std::future::Future<Output = io::Result<usize>> + Send + Sync + 'read>>
&'a mut self,
buf: &'a mut BytesMut,
) -> Pin<Box<dyn std::future::Future<Output = io::Result<usize>> + 'a>>
where where
Self: 'a, Self: 'read;
{
fn read<'a>(&'a mut self, buf: &'a mut BytesMut) -> Self::ReadFut<'a> {
use tokio::io::AsyncReadExt as _; use tokio::io::AsyncReadExt as _;
Box::pin(async { self.inner.read_buf(buf).await }) Box::pin(async { self.inner.read_buf(buf).await })
@ -50,12 +50,74 @@ where
impl<S> FramedWrite for TokioStream<S> impl<S> FramedWrite for TokioStream<S>
where where
S: Unpin + AsyncWrite, S: Send + Sync + Unpin + AsyncWrite,
{ {
fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Pin<Box<dyn std::future::Future<Output = io::Result<()>> + 'a>> type WriteAllFut<'write> = Pin<Box<dyn std::future::Future<Output = io::Result<()>> + Send + Sync + 'write>>
where where
Self: 'a, Self: 'write;
{
fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteAllFut<'a> {
use tokio::io::AsyncWriteExt as _;
Box::pin(async {
self.inner.write_all(buf).await?;
self.inner.flush().await?;
Ok(())
})
}
}
pub type SingleThreadedTokioFramed<S> = Framed<SingleThreadedTokioStream<S>>;
pub struct SingleThreadedTokioStream<S> {
inner: S,
}
impl<S> StreamWrapper for SingleThreadedTokioStream<S> {
type InnerStream = S;
fn from_inner(stream: Self::InnerStream) -> Self {
Self { inner: stream }
}
fn into_inner(self) -> Self::InnerStream {
self.inner
}
fn get_inner(&self) -> &Self::InnerStream {
&self.inner
}
fn get_inner_mut(&mut self) -> &mut Self::InnerStream {
&mut self.inner
}
}
impl<S> FramedRead for SingleThreadedTokioStream<S>
where
S: Unpin + AsyncRead,
{
type ReadFut<'read> = Pin<Box<dyn std::future::Future<Output = io::Result<usize>> + 'read>>
where
Self: 'read;
fn read<'a>(&'a mut self, buf: &'a mut BytesMut) -> Self::ReadFut<'a> {
use tokio::io::AsyncReadExt as _;
Box::pin(async { self.inner.read_buf(buf).await })
}
}
impl<S> FramedWrite for SingleThreadedTokioStream<S>
where
S: Unpin + AsyncWrite,
{
type WriteAllFut<'write> = Pin<Box<dyn std::future::Future<Output = io::Result<()>> + 'write>>
where
Self: 'write;
fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteAllFut<'a> {
use tokio::io::AsyncWriteExt as _; use tokio::io::AsyncWriteExt as _;
Box::pin(async { Box::pin(async {

View file

@ -325,7 +325,7 @@ impl Session {
.take() .take()
.expect("run called only once"); .expect("run called only once");
let mut framed = ironrdp_futures::FuturesFramed::new(rdp_reader); let mut framed = ironrdp_futures::SingleThreadedFuturesFramed::new(rdp_reader);
info!("Start RDP session"); info!("Start RDP session");
@ -565,7 +565,7 @@ async fn connect(
destination: String, destination: String,
pcb: Option<String>, pcb: Option<String>,
) -> Result<(connector::ConnectionResult, WebSocketCompat), IronRdpError> { ) -> Result<(connector::ConnectionResult, WebSocketCompat), IronRdpError> {
let mut framed = ironrdp_futures::FuturesFramed::new(ws); let mut framed = ironrdp_futures::SingleThreadedFuturesFramed::new(ws);
let mut connector = connector::ClientConnector::new(config) let mut connector = connector::ClientConnector::new(config)
.with_server_name(&destination) .with_server_name(&destination)