refactor: error handling

Base all library errors on `ironrdp_error::Error`, a lightweight and
`no_std`-compatible generic `Error` type.
A custom consumer-defined type (such as `PduErrorKind`) for
domain-specific details is wrapped by this type.
This commit is contained in:
Benoît CORTIER 2023-05-12 20:50:03 -04:00 committed by Benoît Cortier
parent 8857c3c25e
commit cf2287739d
56 changed files with 1416 additions and 999 deletions

View file

@ -24,6 +24,8 @@ Pay attention to the "**Architecture Invariant**" sections.
- `crates/ironrdp-session`: state machines to drive an RDP session.
- `crates/ironrdp-input`: utilities to manage and build input packets.
- `crates/ironrdp-rdcleanpath`: RDCleanPath PDU structure used by IronRDP web client and Devolutions Gateway.
- `crates/ironrdp-error`: lightweight and `no_std`-compatible generic `Error` and `Report` types.
The `Error` type wraps a custom consumer-defined type for domain-specific details (such as `PduErrorKind`).
**Architectural Invariant**: doing I/O is not allowed for these crates.

9
Cargo.lock generated
View file

@ -1544,6 +1544,7 @@ name = "ironrdp-connector"
version = "0.1.0"
dependencies = [
"arbitrary",
"ironrdp-error",
"ironrdp-pdu",
"rand_core 0.6.4",
"rstest",
@ -1551,6 +1552,10 @@ dependencies = [
"tracing",
]
[[package]]
name = "ironrdp-error"
version = "0.1.0"
[[package]]
name = "ironrdp-futures"
version = "0.1.0"
@ -1579,6 +1584,7 @@ dependencies = [
"bmp",
"byteorder",
"expect-test",
"ironrdp-error",
"ironrdp-pdu",
"lazy_static",
"num-derive 0.3.3",
@ -1604,6 +1610,7 @@ dependencies = [
"byteorder",
"der-parser 8.2.0",
"expect-test",
"ironrdp-error",
"ironrdp-testsuite-core",
"lazy_static",
"md-5 0.10.5",
@ -1638,6 +1645,7 @@ version = "0.1.0"
dependencies = [
"bitflags 2.0.2",
"ironrdp-connector",
"ironrdp-error",
"ironrdp-graphics",
"ironrdp-pdu",
"sspi",
@ -1659,6 +1667,7 @@ version = "0.0.0"
dependencies = [
"anyhow",
"array-concat",
"expect-test",
"hex",
"ironrdp-connector",
"ironrdp-fuzzing",

View file

@ -24,6 +24,7 @@ categories = ["network-programming"]
expect-test = "1"
ironrdp-async = { version = "0.1", path = "crates/ironrdp-async" }
ironrdp-connector = { version = "0.1", path = "crates/ironrdp-connector" }
ironrdp-error = { version = "0.1", path = "crates/ironrdp-error" }
ironrdp-futures = { version = "0.1", path = "crates/ironrdp-futures" }
ironrdp-fuzzing = { path = "crates/ironrdp-fuzzing" }
ironrdp-graphics = { version = "0.1", path = "crates/ironrdp-graphics" }

View file

@ -1,4 +1,6 @@
use ironrdp_connector::{ClientConnector, ClientConnectorState, ConnectionResult, Sequence as _, State as _};
use ironrdp_connector::{
ClientConnector, ClientConnectorState, ConnectionResult, ConnectorResult, Sequence as _, State as _,
};
use crate::framed::{Framed, FramedRead, FramedWrite};
@ -7,10 +9,7 @@ pub struct ShouldUpgrade {
}
#[instrument(skip_all)]
pub async fn connect_begin<S>(
framed: &mut Framed<S>,
connector: &mut ClientConnector,
) -> ironrdp_connector::Result<ShouldUpgrade>
pub async fn connect_begin<S>(framed: &mut Framed<S>, connector: &mut ClientConnector) -> ConnectorResult<ShouldUpgrade>
where
S: Sync + FramedRead + FramedWrite,
{
@ -47,7 +46,7 @@ pub async fn connect_finalize<S>(
_: Upgraded,
framed: &mut Framed<S>,
mut connector: ClientConnector,
) -> ironrdp_connector::Result<ConnectionResult>
) -> ConnectorResult<ConnectionResult>
where
S: FramedRead + FramedWrite,
{
@ -78,7 +77,7 @@ pub async fn single_connect_step<S>(
framed: &mut Framed<S>,
connector: &mut ClientConnector,
buf: &mut Vec<u8>,
) -> ironrdp_connector::Result<ironrdp_connector::Written>
) -> ConnectorResult<ironrdp_connector::Written>
where
S: FramedWrite + FramedRead,
{
@ -92,7 +91,7 @@ where
let pdu = framed
.read_by_hint(next_pdu_hint)
.await
.map_err(|e| ironrdp_connector::Error::new("read frame by hint").with_custom(e))?;
.map_err(|e| ironrdp_connector::custom_err!("read frame by hint", e))?;
trace!(length = pdu.len(), "PDU received");
@ -107,7 +106,7 @@ where
framed
.write_all(response)
.await
.map_err(|e| ironrdp_connector::Error::new("write all").with_custom(e))?;
.map_err(|e| ironrdp_connector::custom_err!("write all", e))?;
}
Ok(written)

View file

@ -301,10 +301,9 @@ impl Config {
.pipe(u32::try_from)
.unwrap(),
client_name: whoami::hostname(),
client_dir: std::env::current_dir()
.expect("current directory")
.to_string_lossy()
.into_owned(),
// NOTE: hardcode this value like in freerdp
// https://github.com/FreeRDP/FreeRDP/blob/4e24b966c86fdf494a782f0dfcfc43a057a2ea60/libfreerdp/core/settings.c#LL49C34-L49C70
client_dir: "C:\\Windows\\System32\\mstscax.dll".to_owned(),
platform: match whoami::platform() {
whoami::Platform::Windows => MajorPlatformType::Windows,
whoami::Platform::Linux => MajorPlatformType::Unix,

View file

@ -208,8 +208,8 @@ impl GuiContext {
graphics_context.set_buffer(image_buffer, width, height);
}
Event::UserEvent(RdpOutputEvent::ConnectionFailure(error)) => {
error!(%error);
println!("Connection error: {error:#}");
error!(?error);
println!("Connection error: {}", error.report());
control_flow.set_exit_with_code(exitcode::PROTOCOL);
}
Event::UserEvent(RdpOutputEvent::Terminated(result)) => {
@ -219,8 +219,8 @@ impl GuiContext {
exitcode::OK
}
Err(error) => {
error!(error = format!("{error:#}"));
println!("Active session error: {error:#}");
error!(?error);
println!("Active session error: {}", error.report());
exitcode::PROTOCOL
}
};

View file

@ -1,7 +1,8 @@
use ironrdp::connector::{ConnectionResult, ConnectorResult};
use ironrdp::graphics::image_processing::PixelFormat;
use ironrdp::pdu::input::fast_path::FastPathInputEvent;
use ironrdp::session::image::DecodedImage;
use ironrdp::session::{ActiveStage, ActiveStageOutput};
use ironrdp::session::{ActiveStage, ActiveStageOutput, SessionResult};
use ironrdp::{connector, session};
use smallvec::SmallVec;
use sspi::network_client::reqwest_network_client::RequestClientFactory;
@ -14,8 +15,8 @@ use crate::config::Config;
#[derive(Debug)]
pub enum RdpOutputEvent {
Image { buffer: Vec<u32>, width: u16, height: u16 },
ConnectionFailure(connector::Error),
Terminated(session::Result<()>),
ConnectionFailure(connector::ConnectorError),
Terminated(SessionResult<()>),
}
#[derive(Debug)]
@ -80,15 +81,15 @@ enum RdpControlFlow {
type UpgradedFramed = ironrdp_tokio::TokioFramed<ironrdp_tls::TlsStream<TcpStream>>;
async fn connect(config: &Config) -> connector::Result<(connector::ConnectionResult, UpgradedFramed)> {
async fn connect(config: &Config) -> ConnectorResult<(ConnectionResult, UpgradedFramed)> {
let server_addr = config
.destination
.lookup_addr()
.map_err(|e| connector::Error::new("lookup addr").with_custom(e))?;
.map_err(|e| connector::custom_err!("lookup addr", e))?;
let stream = TcpStream::connect(&server_addr)
.await
.map_err(|e| connector::Error::new("TCP connect").with_custom(e))?;
.map_err(|e| connector::custom_err!("TCP connect", e))?;
let mut framed = ironrdp_tokio::TokioFramed::new(stream);
@ -106,7 +107,7 @@ async fn connect(config: &Config) -> connector::Result<(connector::ConnectionRes
let (upgraded_stream, server_public_key) = ironrdp_tls::upgrade(initial_stream, config.destination.name())
.await
.map_err(|e| connector::Error::new("TLS upgrade").with_custom(e))?;
.map_err(|e| connector::custom_err!("TLS upgrade", e))?;
let upgraded = ironrdp_tokio::mark_as_upgraded(should_upgrade, &mut connector, server_public_key);
@ -119,10 +120,10 @@ async fn connect(config: &Config) -> connector::Result<(connector::ConnectionRes
async fn active_session(
mut framed: UpgradedFramed,
connection_result: connector::ConnectionResult,
connection_result: ConnectionResult,
event_loop_proxy: &EventLoopProxy<RdpOutputEvent>,
input_event_receiver: &mut mpsc::UnboundedReceiver<RdpInputEvent>,
) -> session::Result<RdpControlFlow> {
) -> SessionResult<RdpControlFlow> {
let mut image = DecodedImage::new(
PixelFormat::RgbA32,
connection_result.desktop_size.width,
@ -134,14 +135,14 @@ async fn active_session(
'outer: loop {
tokio::select! {
frame = framed.read_pdu() => {
let (action, payload) = frame.map_err(|e| session::Error::new("read frame").with_custom(e))?;
let (action, payload) = frame.map_err(|e| session::custom_err!("read frame", e))?;
trace!(?action, frame_length = payload.len(), "Frame received");
let outputs = active_stage.process(&mut image, action, &payload)?;
for out in outputs {
match out {
ActiveStageOutput::ResponseFrame(frame) => framed.write_all(&frame).await.map_err(|e| session::Error::new("write response").with_custom(e))?,
ActiveStageOutput::ResponseFrame(frame) => framed.write_all(&frame).await.map_err(|e| session::custom_err!("write response", e))?,
ActiveStageOutput::GraphicsUpdate(_region) => {
let buffer: Vec<u32> = image
.data()
@ -160,14 +161,14 @@ async fn active_session(
width: image.width(),
height: image.height(),
})
.map_err(|e| session::Error::new("event_loop_proxy").with_custom(e))?;
.map_err(|e| session::custom_err!("event_loop_proxy", e))?;
}
ActiveStageOutput::Terminate => break 'outer,
}
}
}
input_event = input_event_receiver.recv() => {
let input_event = input_event.ok_or(session::Error::new("GUI is stopped"))?;
let input_event = input_event.ok_or_else(|| session::general_err!("GUI is stopped"))?;
match input_event {
RdpInputEvent::Resize { mut width, mut height } => {
@ -201,9 +202,9 @@ async fn active_session(
let mut frame = Vec::new();
fastpath_input
.to_buffer(&mut frame)
.map_err(|e| session::Error::new("FastPathInput encode").with_custom(e))?;
.map_err(|e| session::custom_err!("FastPathInput encode", e))?;
framed.write_all(&frame).await.map_err(|e| session::Error::new("write FastPathInput PDU").with_custom(e))?;
framed.write_all(&frame).await.map_err(|e| session::custom_err!("write FastPathInput PDU", e))?;
}
RdpInputEvent::Close => {
// TODO: https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/27915739-8f77-487e-9927-55008af7fd68

View file

@ -19,9 +19,10 @@ test = false
arbitrary = ["dep:arbitrary"]
[dependencies]
ironrdp-pdu.workspace = true
tracing.workspace = true
sspi.workspace = true
rstest.workspace = true
rand_core = { version = "0.6.4", features = ["std"] } # TODO: dependency injection?
arbitrary = { version = "1", features = ["derive"], optional = true }
ironrdp-error.workspace = true
ironrdp-pdu.workspace = true
rand_core = { version = "0.6.4", features = ["std"] } # TODO: dependency injection?
rstest.workspace = true
sspi.workspace = true
tracing.workspace = true

View file

@ -2,7 +2,7 @@ use std::mem;
use ironrdp_pdu::{mcs, PduHint};
use crate::{Error, Result, Sequence, State, Written};
use crate::{ConnectorError, ConnectorErrorExt as _, ConnectorResult, Sequence, State, Written};
#[derive(Default, Debug)]
#[non_exhaustive]
@ -81,10 +81,10 @@ impl Sequence for ChannelConnectionSequence {
}
}
fn step(&mut self, input: &[u8], output: &mut Vec<u8>) -> Result<Written> {
fn step(&mut self, input: &[u8], output: &mut Vec<u8>) -> ConnectorResult<Written> {
let (written, next_state) = match mem::take(&mut self.state) {
ChannelConnectionState::Consumed => {
return Err(Error::new(
return Err(general_err!(
"channel connection sequence state is consumed (this is a bug)",
))
}
@ -97,7 +97,7 @@ impl Sequence for ChannelConnectionSequence {
debug!(message = ?erect_domain_request, "Send");
let written = ironrdp_pdu::encode_buf(&erect_domain_request, output)?;
let written = ironrdp_pdu::encode_buf(&erect_domain_request, output).map_err(ConnectorError::pdu)?;
(
Written::from_size(written)?,
@ -110,7 +110,7 @@ impl Sequence for ChannelConnectionSequence {
debug!(message = ?attach_user_request, "Send");
let written = ironrdp_pdu::encode_buf(&attach_user_request, output)?;
let written = ironrdp_pdu::encode_buf(&attach_user_request, output).map_err(ConnectorError::pdu)?;
(
Written::from_size(written)?,
@ -119,7 +119,8 @@ impl Sequence for ChannelConnectionSequence {
}
ChannelConnectionState::WaitAttachUserConfirm => {
let attach_user_confirm = ironrdp_pdu::decode::<mcs::AttachUserConfirm>(input)?;
let attach_user_confirm =
ironrdp_pdu::decode::<mcs::AttachUserConfirm>(input).map_err(ConnectorError::pdu)?;
let user_channel_id = attach_user_confirm.initiator_id;
@ -152,7 +153,7 @@ impl Sequence for ChannelConnectionSequence {
debug!(message = ?channel_join_request, "Send");
let written = ironrdp_pdu::encode_buf(&channel_join_request, output)?;
let written = ironrdp_pdu::encode_buf(&channel_join_request, output).map_err(ConnectorError::pdu)?;
(
Written::from_size(written)?,
@ -163,7 +164,8 @@ impl Sequence for ChannelConnectionSequence {
ChannelConnectionState::WaitChannelJoinConfirm { user_channel_id, index } => {
let channel_id = self.channel_ids[index];
let channel_join_confirm = ironrdp_pdu::decode::<mcs::ChannelJoinConfirm>(input)?;
let channel_join_confirm =
ironrdp_pdu::decode::<mcs::ChannelJoinConfirm>(input).map_err(ConnectorError::pdu)?;
debug!(message = ?channel_join_confirm, "Received");
@ -171,7 +173,7 @@ impl Sequence for ChannelConnectionSequence {
|| channel_join_confirm.channel_id != channel_join_confirm.requested_channel_id
|| channel_join_confirm.channel_id != channel_id
{
return Err(Error::new("received bad MCS Channel Join Confirm"));
return Err(general_err!("received bad MCS Channel Join Confirm"));
}
let next_index = index + 1;
@ -188,7 +190,7 @@ impl Sequence for ChannelConnectionSequence {
(Written::Nothing, next_state)
}
ChannelConnectionState::AllJoined { .. } => return Err(Error::new("all channels are already joined")),
ChannelConnectionState::AllJoined { .. } => return Err(general_err!("all channels are already joined")),
};
self.state = next_state;

View file

@ -8,7 +8,10 @@ use sspi::credssp;
use crate::channel_connection::{ChannelConnectionSequence, ChannelConnectionState};
use crate::connection_finalization::ConnectionFinalizationSequence;
use crate::license_exchange::LicenseExchangeSequence;
use crate::{legacy, Config, DesktopSize, Error, Result, Sequence, ServerName, State, StaticChannels, Written};
use crate::{
legacy, Config, ConnectorError, ConnectorErrorExt as _, ConnectorErrorKind, ConnectorResult, DesktopSize, Sequence,
ServerName, State, StaticChannels, Written,
};
#[derive(Clone, Copy, Debug)]
pub struct CredsspTsRequestHint;
@ -16,11 +19,11 @@ pub struct CredsspTsRequestHint;
pub const CREDSSP_TS_REQUEST_HINT: CredsspTsRequestHint = CredsspTsRequestHint;
impl PduHint for CredsspTsRequestHint {
fn find_size(&self, bytes: &[u8]) -> ironrdp_pdu::Result<Option<usize>> {
fn find_size(&self, bytes: &[u8]) -> ironrdp_pdu::PduResult<Option<usize>> {
match sspi::credssp::TsRequest::read_length(bytes) {
Ok(length) => Ok(Some(length)),
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => Ok(None),
Err(e) => Err(ironrdp_pdu::Error::custom(e)),
Err(e) => Err(ironrdp_pdu::custom_err!("CredsspTsRequestHint", e)),
}
}
}
@ -31,7 +34,7 @@ pub struct CredsspEarlyUserAuthResultHint;
pub const CREDSSP_EARLY_USER_AUTH_RESULT_HINT: CredsspEarlyUserAuthResultHint = CredsspEarlyUserAuthResultHint;
impl PduHint for CredsspEarlyUserAuthResultHint {
fn find_size(&self, _: &[u8]) -> ironrdp_pdu::Result<Option<usize>> {
fn find_size(&self, _: &[u8]) -> ironrdp_pdu::PduResult<Option<usize>> {
Ok(Some(sspi::credssp::EARLY_USER_AUTH_RESULT_PDU_SIZE))
}
}
@ -274,11 +277,11 @@ impl Sequence for ClientConnector {
&self.state
}
fn step(&mut self, input: &[u8], output: &mut Vec<u8>) -> Result<Written> {
fn step(&mut self, input: &[u8], output: &mut Vec<u8>) -> ConnectorResult<Written> {
let (written, next_state) = match mem::take(&mut self.state) {
// Invalid state
ClientConnectorState::Consumed => {
return Err(Error::new("connector sequence state is consumed (this is a bug)"))
return Err(general_err!("connector sequence state is consumed (this is a bug)",))
}
//== Connection Initiation ==//
@ -292,7 +295,7 @@ impl Sequence for ClientConnector {
debug!(message = ?connection_request, "Send");
let written = ironrdp_pdu::encode_buf(&connection_request, output)?;
let written = ironrdp_pdu::encode_buf(&connection_request, output).map_err(ConnectorError::pdu)?;
(
Written::from_size(written)?,
@ -300,7 +303,8 @@ impl Sequence for ClientConnector {
)
}
ClientConnectorState::ConnectionInitiationWaitConfirm => {
let connection_confirm = ironrdp_pdu::decode::<nego::ConnectionConfirm>(input)?;
let connection_confirm =
ironrdp_pdu::decode::<nego::ConnectionConfirm>(input).map_err(ConnectorError::pdu)?;
debug!(message = ?connection_confirm, "Received");
@ -308,14 +312,14 @@ impl Sequence for ClientConnector {
nego::ConnectionConfirm::Response { flags, protocol } => (flags, protocol),
nego::ConnectionConfirm::Failure { code } => {
error!(?code, "Received connection failure code");
return Err(Error::new("connection failed"));
return Err(general_err!("connection failed"));
}
};
info!(?selected_protocol, ?flags, "Server confirmed connection");
if !self.config.security_protocol.contains(selected_protocol) {
return Err(Error::new(
return Err(general_err!(
"server selected a security protocol that is unsupported by this client",
));
}
@ -352,17 +356,17 @@ impl Sequence for ClientConnector {
let server_public_key = self
.server_public_key
.take()
.ok_or(Error::new("server public key is missing"))?;
.ok_or_else(|| general_err!("server public key is missing"))?;
let network_client_factory = self
.network_client_factory
.take()
.ok_or(Error::new("CredSSP network client factory is missing"))?;
.ok_or_else(|| general_err!("CredSSP network client factory is missing"))?;
let server_name = self
.server_name
.take()
.ok_or(Error::new("server name is missing"))?
.ok_or_else(|| general_err!("server name is missing"))?
.into_inner();
let service_principal_name = format!("TERMSRV/{server_name}");
@ -378,11 +382,14 @@ impl Sequence for ClientConnector {
network_client_factory,
}),
service_principal_name,
)?;
)
.map_err(|e| ConnectorError::new("CredSSP", ConnectorErrorKind::Credssp(e)))?;
let initial_ts_request = credssp::TsRequest::default();
let result = credssp_client.process(initial_ts_request)?;
let result = credssp_client
.process(initial_ts_request)
.map_err(|e| ConnectorError::new("CredSSP", ConnectorErrorKind::Credssp(e)))?;
let (ts_request_from_client, next_state) = match result {
credssp::ClientState::ReplyNeeded(ts_request) => (
@ -409,9 +416,11 @@ impl Sequence for ClientConnector {
mut credssp_client,
} => {
let ts_request_from_server = credssp::TsRequest::from_buffer(input)
.map_err(|e| Error::new("CredSSP").with_reason(format!("TsRequest decode: {e}")))?;
.map_err(|e| reason_err!("CredSSP", "TsRequest decode: {e}"))?;
let result = credssp_client.process(ts_request_from_server)?;
let result = credssp_client
.process(ts_request_from_server)
.map_err(|e| ConnectorError::new("CredSSP", ConnectorErrorKind::Credssp(e)))?;
let (ts_request_from_client, next_state) = match result {
credssp::ClientState::ReplyNeeded(ts_request) => (
@ -439,10 +448,10 @@ impl Sequence for ClientConnector {
}
ClientConnectorState::CredsspEarlyUserAuthResult { selected_protocol } => {
let early_user_auth_result = credssp::EarlyUserAuthResult::from_buffer(input)
.map_err(|e| Error::new("CredSSP").with_reason(format!("EarlyUserAuthResult decode: {e}")))?;
.map_err(|e| custom_err!("credssp::EarlyUserAuthResult", e))?;
let credssp::EarlyUserAuthResult::Success = early_user_auth_result else {
return Err(Error::new("CredSSP").with_kind(crate::ErrorKind::AccessDenied));
return Err(ConnectorError::new("CredSSP", ConnectorErrorKind::AccessDenied));
};
(
@ -484,7 +493,7 @@ impl Sequence for ClientConnector {
if client_gcc_blocks.security == gcc::ClientSecurityData::no_security()
&& server_gcc_blocks.security != gcc::ServerSecurityData::no_security()
{
return Err(Error::new("cant satisfy server security settings"));
return Err(general_err!("cant satisfy server security settings"));
}
if server_gcc_blocks.message_channel.is_some() {
@ -561,7 +570,7 @@ impl Sequence for ClientConnector {
static_channels,
} => {
if selected_protocol == nego::SecurityProtocol::RDP {
return Err(Error::new("standard RDP Security (RC4 encryption) is not supported"));
return Err(general_err!("standard RDP Security (RC4 encryption) is not supported"));
}
(
@ -584,7 +593,7 @@ impl Sequence for ClientConnector {
let routing_addr = self
.server_addr
.as_ref()
.ok_or(Error::new("server address is missing"))?;
.ok_or_else(|| general_err!("server address is missing"))?;
let client_info = create_client_info_pdu(&self.config, routing_addr);
@ -690,7 +699,9 @@ impl Sequence for ClientConnector {
{
server_demand_active.pdu.capability_sets
} else {
return Err(Error::new("unexpected Share Control Pdu (expected ServerDemandActive)"));
return Err(general_err!(
"unexpected Share Control Pdu (expected ServerDemandActive)",
));
};
let desktop_size = capability_sets
@ -770,7 +781,7 @@ impl Sequence for ClientConnector {
//== Connected ==//
// The client connector job is done.
ClientConnectorState::Connected { .. } => return Err(Error::new("already connected")),
ClientConnectorState::Connected { .. } => return Err(general_err!("already connected")),
};
self.state = next_state;
@ -1024,7 +1035,7 @@ fn create_client_confirm_active(
}
}
fn write_credssp_request(ts_request: credssp::TsRequest, output: &mut Vec<u8>) -> crate::Result<usize> {
fn write_credssp_request(ts_request: credssp::TsRequest, output: &mut Vec<u8>) -> crate::ConnectorResult<usize> {
let length = usize::from(ts_request.buffer_len());
if output.len() < length {
@ -1033,7 +1044,7 @@ fn write_credssp_request(ts_request: credssp::TsRequest, output: &mut Vec<u8>) -
ts_request
.encode_ts_request(output.as_mut_slice())
.map_err(|e| Error::new("CredSSP").with_reason(format!("TsRequest encode: {e}")))?;
.map_err(|e| reason_err!("CredSSP", "TsRequest encode: {e}"))?;
Ok(length)
}

View file

@ -5,7 +5,7 @@ use ironrdp_pdu::rdp::headers::ShareDataPdu;
use ironrdp_pdu::rdp::{finalization_messages, server_error_info};
use ironrdp_pdu::PduHint;
use crate::{legacy, Error, Result, Sequence, State, Written};
use crate::{legacy, ConnectorResult, Sequence, State, Written};
#[derive(Default, Debug)]
#[non_exhaustive]
@ -81,10 +81,10 @@ impl Sequence for ConnectionFinalizationSequence {
&self.state
}
fn step(&mut self, input: &[u8], output: &mut Vec<u8>) -> Result<Written> {
fn step(&mut self, input: &[u8], output: &mut Vec<u8>) -> ConnectorResult<Written> {
let (written, next_state) = match mem::take(&mut self.state) {
ConnectionFinalizationState::Consumed => {
return Err(Error::new(
return Err(general_err!(
"connection finalization sequence state is consumed (this is a bug)",
))
}
@ -170,7 +170,7 @@ impl Sequence for ConnectionFinalizationSequence {
debug!("Server Control (Cooperate)");
ConnectionFinalizationState::WaitForResponse
} else {
return Err(Error::new("invalid Control Cooperate PDU"));
return Err(general_err!("invalid Control Cooperate PDU"));
}
}
finalization_messages::ControlAction::GrantedControl => {
@ -187,10 +187,10 @@ impl Sequence for ConnectionFinalizationSequence {
debug!("Server Control (Granted Control)");
ConnectionFinalizationState::WaitForResponse
} else {
return Err(Error::new("invalid Granted Control PDU"));
return Err(general_err!("invalid Granted Control PDU"));
}
}
_ => return Err(Error::new("unexpected control action")),
_ => return Err(general_err!("unexpected control action")),
},
ShareDataPdu::ServerSetErrorInfo(server_error_info::ServerSetErrorInfoPdu(error_info)) => {
match error_info {
@ -198,9 +198,11 @@ impl Sequence for ConnectionFinalizationSequence {
server_error_info::ProtocolIndependentCode::None,
) => ConnectionFinalizationState::WaitForResponse,
_ => {
return Err(
Error::new("server returned error info").with_reason(error_info.description())
)
return Err(reason_err!(
"ServerSetErrorInfo",
"server returned error info: {}",
error_info.description()
));
}
}
}
@ -214,13 +216,13 @@ impl Sequence for ConnectionFinalizationSequence {
ConnectionFinalizationState::Finished
}
_ => return Err(Error::new("unexpected server message")),
_ => return Err(general_err!("unexpected server message")),
};
(Written::Nothing, next_state)
}
ConnectionFinalizationState::Finished => return Err(Error::new("finalization already finished")),
ConnectionFinalizationState::Finished => return Err(general_err!("finalization already finished")),
};
self.state = next_state;

View file

@ -4,10 +4,12 @@ use std::borrow::Cow;
use ironrdp_pdu::{rdp, x224, PduParsing};
pub fn encode_x224_packet<T: PduParsing>(x224_msg: &T, buf: &mut Vec<u8>) -> crate::Result<usize>
use crate::{ConnectorError, ConnectorErrorExt as _, ConnectorResult};
pub fn encode_x224_packet<T: PduParsing>(x224_msg: &T, buf: &mut Vec<u8>) -> ConnectorResult<usize>
where
T: PduParsing,
crate::Error: From<T::Error>,
ConnectorError: From<T::Error>,
{
let x224_msg_len = x224_msg.buffer_length();
let mut x224_msg_buf = Vec::with_capacity(x224_msg_len);
@ -18,17 +20,17 @@ where
data: Cow::Owned(x224_msg_buf),
};
let written = ironrdp_pdu::encode_buf(&pdu, buf)?;
let written = ironrdp_pdu::encode_buf(&pdu, buf).map_err(ConnectorError::pdu)?;
Ok(written)
}
pub fn decode_x224_packet<T>(src: &[u8]) -> crate::Result<T>
pub fn decode_x224_packet<T>(src: &[u8]) -> ConnectorResult<T>
where
T: PduParsing,
crate::Error: From<T::Error>,
ConnectorError: From<T::Error>,
{
let x224_payload = ironrdp_pdu::decode::<x224::X224Data>(src)?;
let x224_payload = ironrdp_pdu::decode::<x224::X224Data>(src).map_err(ConnectorError::pdu)?;
let x224_msg = T::from_buffer(x224_payload.data.as_ref())?;
Ok(x224_msg)
}
@ -38,10 +40,10 @@ pub fn encode_send_data_request<T>(
channel_id: u16,
user_msg: &T,
buf: &mut Vec<u8>,
) -> crate::Result<usize>
) -> ConnectorResult<usize>
where
T: PduParsing,
crate::Error: From<T::Error>,
ConnectorError: From<T::Error>,
{
let user_data_len = user_msg.buffer_length();
let mut user_data = Vec::with_capacity(user_data_len);
@ -54,7 +56,7 @@ where
user_data: Cow::Owned(user_data),
};
let written = ironrdp_pdu::encode_buf(&pdu, buf)?;
let written = ironrdp_pdu::encode_buf(&pdu, buf).map_err(ConnectorError::pdu)?;
Ok(written)
}
@ -67,20 +69,20 @@ pub struct SendDataIndicationCtx<'a> {
}
impl SendDataIndicationCtx<'_> {
pub fn decode_user_data<T>(&self) -> crate::Result<T>
pub fn decode_user_data<T>(&self) -> ConnectorResult<T>
where
T: PduParsing,
crate::Error: From<T::Error>,
ConnectorError: From<T::Error>,
{
let msg = T::from_buffer(self.user_data)?;
Ok(msg)
}
}
pub fn decode_send_data_indication(src: &[u8]) -> crate::Result<SendDataIndicationCtx<'_>> {
pub fn decode_send_data_indication(src: &[u8]) -> ConnectorResult<SendDataIndicationCtx<'_>> {
use ironrdp_pdu::mcs::McsMessage;
let mcs_msg = ironrdp_pdu::decode::<McsMessage>(src)?;
let mcs_msg = ironrdp_pdu::decode::<McsMessage>(src).map_err(ConnectorError::pdu)?;
match mcs_msg {
McsMessage::SendDataIndication(msg) => {
@ -94,10 +96,16 @@ pub fn decode_send_data_indication(src: &[u8]) -> crate::Result<SendDataIndicati
user_data,
})
}
McsMessage::DisconnectProviderUltimatum(msg) => {
Err(crate::Error::new("received disconnect provider ultimatum").with_reason(format!("{:?}", msg.reason)))
}
unexpected => Err(crate::Error::new("unexpected MCS message").with_reason(ironrdp_pdu::name(&unexpected))),
McsMessage::DisconnectProviderUltimatum(msg) => Err(reason_err!(
"decode_send_data_indication",
"received disconnect provider ultimatum: {:?}",
msg.reason
)),
unexpected => Err(reason_err!(
"decode_send_data_indication",
"unexpected MCS message: {}",
ironrdp_pdu::name(&unexpected)
)),
}
}
@ -107,7 +115,7 @@ pub fn encode_share_control(
share_id: u32,
pdu: rdp::headers::ShareControlPdu,
buf: &mut Vec<u8>,
) -> crate::Result<usize> {
) -> ConnectorResult<usize> {
let pdu_source = initiator_id;
let share_control_header = rdp::headers::ShareControlHeader {
@ -128,7 +136,7 @@ pub struct ShareControlCtx {
pub pdu: rdp::headers::ShareControlPdu,
}
pub fn decode_share_control(ctx: SendDataIndicationCtx<'_>) -> crate::Result<ShareControlCtx> {
pub fn decode_share_control(ctx: SendDataIndicationCtx<'_>) -> ConnectorResult<ShareControlCtx> {
let user_msg = ctx.decode_user_data::<rdp::headers::ShareControlHeader>()?;
Ok(ShareControlCtx {
@ -146,7 +154,7 @@ pub fn encode_share_data(
share_id: u32,
pdu: rdp::headers::ShareDataPdu,
buf: &mut Vec<u8>,
) -> crate::Result<usize> {
) -> ConnectorResult<usize> {
let share_data_header = rdp::headers::ShareDataHeader {
share_data_pdu: pdu,
stream_priority: rdp::headers::StreamPriority::Medium,
@ -168,11 +176,11 @@ pub struct ShareDataCtx {
pub pdu: rdp::headers::ShareDataPdu,
}
pub fn decode_share_data(ctx: SendDataIndicationCtx<'_>) -> crate::Result<ShareDataCtx> {
pub fn decode_share_data(ctx: SendDataIndicationCtx<'_>) -> ConnectorResult<ShareDataCtx> {
let ctx = decode_share_control(ctx)?;
let rdp::headers::ShareControlPdu::Data(share_data_header) = ctx.pdu else {
return Err(crate::Error::new("received unexpected Share Control Pdu (expected SHare Data Header)"));
return Err(general_err!("received unexpected Share Control Pdu (expected SHare Data Header)"));
};
Ok(ShareDataCtx {
@ -184,26 +192,6 @@ pub fn decode_share_data(ctx: SendDataIndicationCtx<'_>) -> crate::Result<ShareD
})
}
impl From<ironrdp_pdu::mcs::McsError> for crate::Error {
fn from(e: ironrdp_pdu::mcs::McsError) -> Self {
Self::new("MCS").with_reason(e.to_string())
}
}
impl From<ironrdp_pdu::rdp::server_license::ServerLicenseError> for crate::Error {
fn from(e: ironrdp_pdu::rdp::server_license::ServerLicenseError) -> Self {
Self::new("server license").with_reason(e.to_string())
}
}
impl From<ironrdp_pdu::rdp::RdpError> for crate::Error {
fn from(e: ironrdp_pdu::rdp::RdpError) -> Self {
Self::new("RDP").with_reason(e.to_string())
}
}
impl From<ironrdp_pdu::rdp::vc::ChannelError> for crate::Error {
fn from(e: ironrdp_pdu::rdp::vc::ChannelError) -> Self {
Self::new("virtual channel").with_reason(e.to_string())
}
impl ironrdp_error::legacy::CatchAllKind for crate::ConnectorErrorKind {
const CATCH_ALL_VALUE: Self = crate::ConnectorErrorKind::General;
}

View file

@ -1,6 +1,9 @@
#[macro_use]
extern crate tracing;
#[macro_use]
mod macros;
pub mod legacy;
mod channel_connection;
@ -109,10 +112,10 @@ pub enum Written {
impl Written {
#[inline]
pub fn from_size(value: usize) -> Result<Self> {
pub fn from_size(value: usize) -> ConnectorResult<Self> {
core::num::NonZeroUsize::new(value)
.map(Self::Size)
.ok_or(Error::new("invalid written length (cant be zero)"))
.ok_or(ConnectorError::general("invalid written length (cant be zero)"))
}
#[inline]
@ -135,150 +138,94 @@ pub trait Sequence: Send + Sync {
fn state(&self) -> &dyn State;
fn step(&mut self, input: &[u8], output: &mut Vec<u8>) -> Result<Written>;
fn step(&mut self, input: &[u8], output: &mut Vec<u8>) -> ConnectorResult<Written>;
fn step_no_input(&mut self, output: &mut Vec<u8>) -> Result<Written> {
fn step_no_input(&mut self, output: &mut Vec<u8>) -> ConnectorResult<Written> {
self.step(&[], output)
}
}
ironrdp_pdu::assert_obj_safe!(Sequence);
pub type Result<T> = std::result::Result<T, Error>;
pub type ConnectorResult<T> = std::result::Result<T, ConnectorError>;
#[non_exhaustive]
#[derive(Debug)]
pub enum ErrorKind {
Pdu(ironrdp_pdu::Error),
pub enum ConnectorErrorKind {
Pdu(ironrdp_pdu::PduError),
Credssp(sspi::Error),
Reason(String),
AccessDenied,
Custom(Box<dyn std::error::Error + Sync + Send + 'static>),
General,
Custom,
}
#[derive(Debug)]
pub struct Error {
pub context: &'static str,
pub kind: ErrorKind,
pub reason: Option<String>,
impl fmt::Display for ConnectorErrorKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self {
ConnectorErrorKind::Pdu(_) => write!(f, "PDU error"),
ConnectorErrorKind::Credssp(_) => write!(f, "CredSSP"),
ConnectorErrorKind::Reason(description) => write!(f, "reason: {description}"),
ConnectorErrorKind::AccessDenied => write!(f, "access denied"),
ConnectorErrorKind::General => write!(f, "general"),
ConnectorErrorKind::Custom => write!(f, "custom"),
}
}
}
impl Error {
pub fn new(context: &'static str) -> Self {
Self {
context,
kind: ErrorKind::General,
reason: None,
impl std::error::Error for ConnectorErrorKind {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match &self {
ConnectorErrorKind::Pdu(e) => Some(e),
ConnectorErrorKind::Credssp(e) => Some(e),
ConnectorErrorKind::Reason(_) => None,
ConnectorErrorKind::AccessDenied => None,
ConnectorErrorKind::Custom => None,
ConnectorErrorKind::General => None,
}
}
}
pub type ConnectorError = ironrdp_error::Error<ConnectorErrorKind>;
pub trait ConnectorErrorExt {
fn pdu(error: ironrdp_pdu::PduError) -> Self;
fn general(context: &'static str) -> Self;
fn reason(context: &'static str, reason: impl Into<String>) -> Self;
fn custom<E>(context: &'static str, e: E) -> Self
where
E: std::error::Error + Sync + Send + 'static;
}
impl ConnectorErrorExt for ConnectorError {
fn pdu(error: ironrdp_pdu::PduError) -> Self {
Self::new("invalid payload", ConnectorErrorKind::Pdu(error))
}
pub fn with_kind(mut self, kind: ErrorKind) -> Self {
self.kind = kind;
self
fn general(context: &'static str) -> Self {
Self::new(context, ConnectorErrorKind::General)
}
pub fn with_custom<E>(mut self, custom_error: E) -> Self
fn reason(context: &'static str, reason: impl Into<String>) -> Self {
Self::new(context, ConnectorErrorKind::Reason(reason.into()))
}
fn custom<E>(context: &'static str, e: E) -> Self
where
E: std::error::Error + Sync + Send + 'static,
{
self.kind = ErrorKind::Custom(Box::new(custom_error));
self
}
pub fn with_reason(mut self, reason: impl Into<String>) -> Self {
self.reason = Some(reason.into());
self
}
}
impl std::error::Error for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match &self.kind {
ErrorKind::Pdu(e) => Some(e),
ErrorKind::Credssp(e) => Some(e),
ErrorKind::AccessDenied => None,
ErrorKind::Custom(e) => Some(e.as_ref()),
ErrorKind::General => None,
}
}
}
impl From<Error> for std::io::Error {
fn from(error: Error) -> Self {
std::io::Error::new(std::io::ErrorKind::Other, error)
}
}
impl From<ironrdp_pdu::Error> for Error {
fn from(value: ironrdp_pdu::Error) -> Self {
Self {
context: "invalid payload",
kind: ErrorKind::Pdu(value),
reason: None,
}
}
}
impl From<sspi::Error> for Error {
fn from(value: sspi::Error) -> Self {
Self {
context: "CredSSP",
kind: ErrorKind::Credssp(value),
reason: None,
}
}
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.context)?;
match &self.kind {
ErrorKind::Pdu(e) => {
if f.alternate() {
write!(f, ": {e}")?;
}
}
ErrorKind::Credssp(e) => {
if f.alternate() {
write!(f, ": {e}")?;
}
}
ErrorKind::AccessDenied => {
write!(f, ": access denied")?;
}
ErrorKind::Custom(e) => {
if f.alternate() {
write!(f, ": {e}")?;
let mut next_source = e.source();
while let Some(e) = next_source {
write!(f, ", caused by: {e}")?;
next_source = e.source();
}
}
}
ErrorKind::General => {}
}
if let Some(reason) = &self.reason {
write!(f, " ({reason})")?;
}
Ok(())
Self::new(context, ConnectorErrorKind::Custom).with_source(e)
}
}
pub trait ConnectorResultExt {
fn with_context(self, context: &'static str) -> Self;
fn with_kind(self, kind: ErrorKind) -> Self;
fn with_custom<E>(self, custom_error: E) -> Self
fn with_source<E>(self, source: E) -> Self
where
E: std::error::Error + Sync + Send + 'static;
fn with_reason(self, reason: impl Into<String>) -> Self;
}
impl<T> ConnectorResultExt for Result<T> {
impl<T> ConnectorResultExt for ConnectorResult<T> {
fn with_context(self, context: &'static str) -> Self {
self.map_err(|mut e| {
e.context = context;
@ -286,27 +233,10 @@ impl<T> ConnectorResultExt for Result<T> {
})
}
fn with_kind(self, kind: ErrorKind) -> Self {
self.map_err(|mut e| {
e.kind = kind;
e
})
}
fn with_custom<E>(self, custom_error: E) -> Self
fn with_source<E>(self, source: E) -> Self
where
E: std::error::Error + Sync + Send + 'static,
{
self.map_err(|mut e| {
e.kind = ErrorKind::Custom(Box::new(custom_error));
e
})
}
fn with_reason(self, reason: impl Into<String>) -> Self {
self.map_err(|mut e| {
e.reason = Some(reason.into());
e
})
self.map_err(|e| e.with_source(source))
}
}

View file

@ -5,7 +5,7 @@ use ironrdp_pdu::PduHint;
use rand_core::{OsRng, RngCore as _};
use super::legacy;
use crate::{Error, Result, Sequence, State, Written};
use crate::{ConnectorResult, Sequence, State, Written};
#[derive(Default, Debug)]
#[non_exhaustive]
@ -79,10 +79,10 @@ impl Sequence for LicenseExchangeSequence {
&self.state
}
fn step(&mut self, input: &[u8], output: &mut Vec<u8>) -> Result<Written> {
fn step(&mut self, input: &[u8], output: &mut Vec<u8>) -> ConnectorResult<Written> {
let (written, next_state) = match mem::take(&mut self.state) {
LicenseExchangeState::Consumed => {
return Err(Error::new(
return Err(general_err!(
"license exchange sequence state is consumed (this is a bug)",
))
}
@ -110,9 +110,7 @@ impl Sequence for LicenseExchangeSequence {
&self.username,
self.domain.as_deref().unwrap_or(""),
)
.map_err(|e| {
Error::new("unable to generate Client New License Request").with_reason(e.to_string())
})?;
.map_err(|e| custom_err!("ClientNewLicenseRequest", e))?;
trace!(?encryption_data, "Successfully generated Client New License Request");
info!(message = ?new_license_request, "Send");
@ -150,9 +148,7 @@ impl Sequence for LicenseExchangeSequence {
self.domain.as_deref().unwrap_or(""),
&encryption_data,
)
.map_err(|e| {
Error::new("unable to generate Client Platform Challenge Response").with_reason(e.to_string())
})?;
.map_err(|e| custom_err!("ClientPlatformChallengeResponse", e))?;
debug!(message = ?challenge_response, "Send");
@ -178,14 +174,14 @@ impl Sequence for LicenseExchangeSequence {
upgrade_license
.verify_server_license(&encryption_data)
.map_err(|e| Error::new("license verification failed").with_reason(e.to_string()))?;
.map_err(|e| custom_err!("license verification", e))?;
debug!("License verified with success");
(Written::Nothing, LicenseExchangeState::LicenseExchanged)
}
LicenseExchangeState::LicenseExchanged => return Err(Error::new("license already exchanged")),
LicenseExchangeState::LicenseExchanged => return Err(general_err!("license already exchanged")),
};
self.state = next_state;

View file

@ -0,0 +1,38 @@
/// Creates a `ConnectorError` with `General` kind
///
/// Shorthand for
/// ```rust
/// <crate::ConnectorError as crate::ConnectorErrorExt>::general(context)
/// ```
#[macro_export]
macro_rules! general_err {
( $context:expr $(,)? ) => {{
<$crate::ConnectorError as $crate::ConnectorErrorExt>::general($context)
}};
}
/// Creates a `ConnectorError` with `Reason` kind
///
/// Shorthand for
/// ```rust
/// <crate::ConnectorError as crate::ConnectorErrorExt>::reason(context, reason)
/// ```
#[macro_export]
macro_rules! reason_err {
( $context:expr, $($arg:tt)* ) => {{
<$crate::ConnectorError as $crate::ConnectorErrorExt>::reason($context, format!($($arg)*))
}};
}
/// Creates a `ConnectorError` with `Custom` kind and a source error attached to it
///
/// Shorthand for
/// ```rust
/// <crate::ConnectorError as crate::ConnectorErrorExt>::custom(context, source)
/// ```
#[macro_export]
macro_rules! custom_err {
( $context:expr, $source:expr $(,)? ) => {{
<$crate::ConnectorError as $crate::ConnectorErrorExt>::custom($context, $source)
}};
}

View file

@ -0,0 +1,17 @@
[package]
name = "ironrdp-error"
version = "0.1.0"
readme = "README.md"
description = "IronPDU generic error definition"
edition.workspace = true
license.workspace = true
homepage.workspace = true
repository.workspace = true
authors.workspace = true
keywords.workspace = true
categories.workspace = true
[features]
default = []
std = ["alloc"]
alloc = []

View file

@ -0,0 +1,189 @@
#![cfg_attr(not(feature = "std"), no_std)]
#[cfg(feature = "alloc")]
extern crate alloc;
use core::fmt;
#[cfg(all(not(feature = "std"), feature = "alloc"))]
trait NoAllocSource: fmt::Display + fmt::Debug {}
#[cfg(all(not(feature = "std"), feature = "alloc"))]
impl<T> NoAllocSource for T where T: fmt::Display + fmt::Debug {}
#[derive(Debug)]
pub struct Error<Kind> {
pub context: &'static str,
pub kind: Kind,
#[cfg(feature = "std")]
source: Option<alloc::boxed::Box<dyn std::error::Error + Sync + Send + 'static>>,
#[cfg(all(not(feature = "std"), feature = "alloc"))]
source: Option<alloc::boxed::Box<dyn NoAllocSource + Sync + Send + 'static>>,
}
impl<Kind> Error<Kind> {
#[cold]
pub fn new(context: &'static str, kind: Kind) -> Self {
Self {
context,
kind,
#[cfg(feature = "alloc")]
source: None,
}
}
#[cfg(feature = "std")]
#[cold]
pub fn with_source<E>(mut self, source: E) -> Self
where
E: std::error::Error + Sync + Send + 'static,
{
self.source = Some(Box::new(source));
self
}
#[cfg(all(not(feature = "std"), feature = "alloc"))]
#[cold]
pub fn with_source<E>(mut self, source: E) -> Self
where
E: fmt::Display + fmt::Debug + Sync + Send + 'static,
{
#[cfg(feature = "alloc")]
{
self.source = Some(alloc::boxed::Box::new(source));
}
// No source when no std and no alloc crates
#[cfg(not(feature = "alloc"))]
{
let _ = source;
}
self
}
pub fn into_other_kind<OtherKind>(self) -> Error<OtherKind>
where
Kind: Into<OtherKind>,
{
Error {
context: self.context,
kind: self.kind.into(),
#[cfg(any(feature = "std", feature = "alloc"))]
source: self.source,
}
}
pub fn kind(&self) -> &Kind {
&self.kind
}
pub fn report(&self) -> ErrorReport<'_, Kind> {
ErrorReport(self)
}
}
impl<Kind> fmt::Display for Error<Kind>
where
Kind: fmt::Display,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "[{}] {}", self.context, self.kind)
}
}
#[cfg(feature = "std")]
impl<Kind> std::error::Error for Error<Kind>
where
Kind: std::error::Error,
{
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
if let Some(source) = self.kind.source() {
Some(source)
} else {
// NOTE: we cant use Option::as_ref here because of type inference
if let Some(e) = &self.source {
Some(e.as_ref())
} else {
None
}
}
}
}
#[cfg(feature = "std")]
impl<Kind> From<Error<Kind>> for std::io::Error
where
Kind: std::error::Error + Send + Sync + 'static,
{
fn from(error: Error<Kind>) -> Self {
std::io::Error::new(std::io::ErrorKind::Other, error)
}
}
pub struct ErrorReport<'a, Kind>(&'a Error<Kind>);
#[cfg(feature = "std")]
impl<Kind> fmt::Display for ErrorReport<'_, Kind>
where
Kind: std::error::Error,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use std::error::Error;
write!(f, "{}", self.0)?;
let mut next_source = self.0.source();
while let Some(e) = next_source {
write!(f, ", caused by: {e}")?;
next_source = e.source();
}
Ok(())
}
}
#[cfg(not(feature = "std"))]
impl<E> fmt::Display for ErrorReport<'_, E>
where
E: fmt::Display,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)?;
#[cfg(feature = "alloc")]
if let Some(source) = &self.0.source {
write!(f, ", caused by: {source}")?;
}
Ok(())
}
}
/// Temporary compability traits to smooth transition from old style
#[cfg(feature = "std")]
#[doc(hidden)]
pub mod legacy {
#[doc(hidden)]
pub trait CatchAllKind {
const CATCH_ALL_VALUE: Self;
}
#[doc(hidden)]
pub trait ErrorContext: std::error::Error {
fn context(&self) -> &'static str;
}
#[doc(hidden)]
impl<E, Kind> From<E> for crate::Error<Kind>
where
E: ErrorContext + Send + Sync + 'static,
Kind: CatchAllKind,
{
#[cold]
fn from(error: E) -> Self {
Self::new(error.context(), Kind::CATCH_ALL_VALUE).with_source(error)
}
}
}

View file

@ -16,15 +16,16 @@ doctest = false
# test = false
[dependencies]
ironrdp-pdu.workspace = true
num-traits = "0.2.15"
num-derive = "0.3.3"
byteorder = "1.4.3"
thiserror = "1.0.40"
bitvec = "1.0.1"
bit_field = "0.10.2"
bitflags = "2"
bitvec = "1.0.1"
byteorder = "1.4.3"
ironrdp-error.workspace = true
ironrdp-pdu.workspace = true
lazy_static = "1.4.0"
num-derive = "0.3.3"
num-traits = "0.2.15"
thiserror = "1.0.40"
[dev-dependencies]
bmp = "0.5"

View file

@ -1,5 +1,5 @@
use ironrdp_pdu::bitmap::rdp6::{BitmapStream as BitmapStreamPdu, ColorPlanes};
use ironrdp_pdu::{decode, Error as PduError};
use ironrdp_pdu::{decode, PduError};
use thiserror::Error;
use crate::color_conversion::{Rgb, YCoCg};

View file

@ -447,6 +447,12 @@ pub enum ZgfxError {
TokenBitsNotFound,
}
impl ironrdp_error::legacy::ErrorContext for ZgfxError {
fn context(&self) -> &'static str {
"zgfx"
}
}
#[cfg(test)]
mod tests {
use super::*;

View file

@ -16,11 +16,13 @@ doctest = false
# test = false
[features]
default = []
std = []
default = ["std"]
std = ["alloc", "ironrdp-error/std"]
alloc = ["ironrdp-error/alloc"]
[dependencies]
bitflags = "2"
ironrdp-error.workspace = true
tap = "1.0.1"
# TODO: get rid of these dependencies (related code should probably go into another crate)

View file

@ -1,4 +1,4 @@
use crate::{Error as PduError, PduDecode, PduEncode, ReadCursor, Result as PduResult, WriteCursor};
use crate::{PduDecode, PduEncode, PduResult, ReadCursor, WriteCursor};
const NON_RLE_PADDING_SIZE: usize = 1;
@ -56,10 +56,10 @@ impl<'a> PduDecode<'a> for BitmapStream<'a> {
let color_planes_size = if !enable_rle_compression {
// Cut padding field if RLE flags is set to 0
if src.is_empty() {
return Err(PduError::Other {
context: Self::NAME,
reason: "Missing padding byte from zero-size Non-RLE bitmap data",
});
return Err(invalid_message_err!(
"padding",
"missing padding byte from zero-sized non-RLE bitmap data",
));
}
src.len() - NON_RLE_PADDING_SIZE
} else {
@ -250,10 +250,13 @@ mod tests {
assert_parsing_failure(
&[],
expect![[r#"
NotEnoughBytes {
name: "Rdp6BitmapStream",
Error {
context: "Rdp6BitmapStream",
kind: NotEnoughBytes {
received: 0,
expected: 1,
},
source: None,
}
"#]],
);
@ -262,9 +265,13 @@ mod tests {
assert_parsing_failure(
&[0x20],
expect![[r#"
Other {
Error {
context: "Rdp6BitmapStream",
reason: "Missing padding byte from zero-size Non-RLE bitmap data",
kind: InvalidMessage {
field: "padding",
reason: "missing padding byte from zero-size Non-RLE bitmap data",
},
source: None,
}
"#]],
);

View file

@ -290,3 +290,9 @@ pub enum FastPathError {
#[error("Invalid RDP Share Data Header: {0}")]
InvalidShareDataHeader(String),
}
impl ironrdp_error::legacy::ErrorContext for FastPathError {
fn context(&self) -> &'static str {
"Fast-Path"
}
}

View file

@ -250,3 +250,9 @@ pub enum RfxError {
#[error("Got invalid IT flag of TileSet: {0}")]
InvalidItFlag(bool),
}
impl ironrdp_error::legacy::ErrorContext for RfxError {
fn context(&self) -> &'static str {
"RFX"
}
}

View file

@ -19,6 +19,7 @@ pub mod pcb;
pub mod rdp;
pub mod tpdu;
pub mod tpkt;
pub mod utf16;
pub mod utils;
pub mod x224;
@ -30,100 +31,86 @@ pub(crate) mod per;
pub use crate::basic_output::{bitmap, fast_path, surface_commands};
pub use crate::rdp::vc::dvc;
pub type Result<T> = core::result::Result<T, Error>;
pub type PduResult<T> = core::result::Result<T, PduError>;
pub type PduError = ironrdp_error::Error<PduErrorKind>;
#[non_exhaustive]
#[derive(Debug)]
pub enum Error {
NotEnoughBytes {
name: &'static str,
received: usize,
expected: usize,
},
InvalidMessage {
name: &'static str,
field: &'static str,
reason: &'static str,
},
UnexpectedMessageType {
name: &'static str,
got: u8,
},
UnsupportedVersion {
name: &'static str,
got: u8,
},
Other {
context: &'static str,
reason: &'static str,
},
Custom(Box<dyn std::error::Error + Sync + Send + 'static>),
#[derive(Clone, Debug)]
pub enum PduErrorKind {
NotEnoughBytes { received: usize, expected: usize },
InvalidMessage { field: &'static str, reason: &'static str },
UnexpectedMessageType { got: u8 },
UnsupportedVersion { got: u8 },
Other { description: &'static str },
Custom,
}
impl std::error::Error for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
if let Error::Custom(e) = &self {
Some(e.as_ref())
} else {
None
}
}
}
impl std::error::Error for PduErrorKind {}
impl From<Error> for std::io::Error {
fn from(error: Error) -> Self {
std::io::Error::new(std::io::ErrorKind::Other, error)
}
}
impl fmt::Display for Error {
impl fmt::Display for PduErrorKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Error::NotEnoughBytes {
name,
received,
expected,
} => write!(
Self::NotEnoughBytes { received, expected } => write!(
f,
"not enough bytes provided to decode {name}: received {received} bytes, expected {expected} bytes"
"not enough bytes provided to decode: received {received} bytes, expected {expected} bytes"
),
Error::InvalidMessage { name, field, reason } => {
write!(f, "invalid `{field}` in {name}: {reason}")
Self::InvalidMessage { field, reason } => {
write!(f, "invalid `{field}`: {reason}")
}
Error::UnexpectedMessageType { name, got } => {
write!(f, "invalid message type ({got}) for {name}")
Self::UnexpectedMessageType { got } => {
write!(f, "invalid message type ({got})")
}
Error::UnsupportedVersion { name, got } => {
write!(f, "unsupported version ({got}) for {name}")
Self::UnsupportedVersion { got } => {
write!(f, "unsupported version ({got})")
}
Error::Other { context, reason } => {
write!(f, "{reason} ({context})")
Self::Other { description } => {
write!(f, "{description}")
}
Error::Custom(e) => {
if f.alternate() {
write!(f, "{e}")?;
let mut next_source = e.source();
while let Some(e) = next_source {
write!(f, ", caused by: {e}")?;
next_source = e.source();
}
} else {
write!(f, "custom")?;
}
Ok(())
Self::Custom => {
write!(f, "custom")
}
}
}
}
impl Error {
pub fn custom<E>(e: E) -> Self
pub trait PduErrorExt {
fn not_enough_bytes(context: &'static str, received: usize, expected: usize) -> Self;
fn invalid_message(context: &'static str, field: &'static str, reason: &'static str) -> Self;
fn unexpected_message_type(context: &'static str, got: u8) -> Self;
fn unsupported_version(context: &'static str, got: u8) -> Self;
fn other(context: &'static str, description: &'static str) -> Self;
fn custom<E>(context: &'static str, e: E) -> Self
where
E: std::error::Error + Sync + Send + 'static;
}
impl PduErrorExt for PduError {
fn not_enough_bytes(context: &'static str, received: usize, expected: usize) -> Self {
Self::new(context, PduErrorKind::NotEnoughBytes { received, expected })
}
fn invalid_message(context: &'static str, field: &'static str, reason: &'static str) -> Self {
Self::new(context, PduErrorKind::InvalidMessage { field, reason })
}
fn unexpected_message_type(context: &'static str, got: u8) -> Self {
Self::new(context, PduErrorKind::UnexpectedMessageType { got })
}
fn unsupported_version(context: &'static str, got: u8) -> Self {
Self::new(context, PduErrorKind::UnsupportedVersion { got })
}
fn other(context: &'static str, description: &'static str) -> Self {
Self::new(context, PduErrorKind::Other { description })
}
fn custom<E>(context: &'static str, e: E) -> Self
where
E: std::error::Error + Sync + Send + 'static,
{
Self::Custom(Box::new(e))
Self::new(context, PduErrorKind::Custom).with_source(e)
}
}
@ -140,7 +127,7 @@ pub trait Pdu {
/// This trait is object-safe and may be used in a dynamic context.
pub trait PduEncode {
/// Encodes this PDU in-place using the provided `WriteCursor`.
fn encode(&self, dst: &mut WriteCursor<'_>) -> Result<()>;
fn encode(&self, dst: &mut WriteCursor<'_>) -> PduResult<()>;
/// Returns the associated PDU name associated.
fn name(&self) -> &'static str;
@ -152,14 +139,14 @@ pub trait PduEncode {
assert_obj_safe!(PduEncode);
/// Encodes the given PDU in-place into the provided buffer and returns the number of bytes written.
pub fn encode<T: PduEncode>(pdu: &T, dst: &mut [u8]) -> Result<usize> {
pub fn encode<T: PduEncode>(pdu: &T, dst: &mut [u8]) -> PduResult<usize> {
let mut cursor = WriteCursor::new(dst);
encode_cursor(pdu, &mut cursor)?;
Ok(cursor.pos())
}
/// Same as `encode_pdu` but resizes the buffer when it is too small to fit the PDU.
pub fn encode_buf<T: PduEncode>(pdu: &T, buf: &mut Vec<u8>) -> Result<usize> {
pub fn encode_buf<T: PduEncode>(pdu: &T, buf: &mut Vec<u8>) -> PduResult<usize> {
let pdu_size = pdu.size();
if buf.len() < pdu_size {
@ -170,7 +157,7 @@ pub fn encode_buf<T: PduEncode>(pdu: &T, buf: &mut Vec<u8>) -> Result<usize> {
}
/// Encodes the given PDU in-place using the provided `WriteCursor`.
pub fn encode_cursor<T: PduEncode>(pdu: &T, dst: &mut WriteCursor<'_>) -> Result<()> {
pub fn encode_cursor<T: PduEncode>(pdu: &T, dst: &mut WriteCursor<'_>) -> PduResult<()> {
pdu.encode(dst)
}
@ -188,29 +175,29 @@ pub fn size<T: PduEncode>(pdu: &T) -> usize {
///
/// The binary payload must be a full PDU, not some subset of it.
pub trait PduDecode<'de>: Sized {
fn decode(src: &mut ReadCursor<'de>) -> Result<Self>;
fn decode(src: &mut ReadCursor<'de>) -> PduResult<Self>;
}
pub fn decode<'de, T: PduDecode<'de>>(src: &'de [u8]) -> Result<T> {
pub fn decode<'de, T: PduDecode<'de>>(src: &'de [u8]) -> PduResult<T> {
let mut cursor = ReadCursor::new(src);
T::decode(&mut cursor)
}
pub fn decode_cursor<'de, T: PduDecode<'de>>(src: &mut ReadCursor<'de>) -> Result<T> {
pub fn decode_cursor<'de, T: PduDecode<'de>>(src: &mut ReadCursor<'de>) -> PduResult<T> {
T::decode(src)
}
/// Similar to `PduDecode` but unconditionally returns an owned type.
pub trait PduDecodeOwned: Sized {
fn decode_owned(src: &mut ReadCursor<'_>) -> Result<Self>;
fn decode_owned(src: &mut ReadCursor<'_>) -> PduResult<Self>;
}
pub fn decode_owned<T: PduDecodeOwned>(src: &[u8]) -> Result<T> {
pub fn decode_owned<T: PduDecodeOwned>(src: &[u8]) -> PduResult<T> {
let mut cursor = ReadCursor::new(src);
T::decode_owned(&mut cursor)
}
pub fn decode_owned_cursor<T: PduDecodeOwned>(src: &mut ReadCursor<'_>) -> Result<T> {
pub fn decode_owned_cursor<T: PduDecodeOwned>(src: &mut ReadCursor<'_>) -> PduResult<T> {
T::decode_owned(src)
}
@ -249,7 +236,7 @@ pub struct PduInfo {
}
/// Finds next RDP PDU size by reading the next few bytes.
pub fn find_size(bytes: &[u8]) -> Result<Option<PduInfo>> {
pub fn find_size(bytes: &[u8]) -> PduResult<Option<PduInfo>> {
macro_rules! ensure_enough {
($bytes:expr, $len:expr) => {
if $bytes.len() < $len {
@ -261,11 +248,8 @@ pub fn find_size(bytes: &[u8]) -> Result<Option<PduInfo>> {
ensure_enough!(bytes, 1);
let fp_output_header = bytes[0];
let action =
Action::from_fp_output_header(fp_output_header).map_err(|unknown_action| Error::UnexpectedMessageType {
name: "fpOutputHeader",
got: unknown_action,
})?;
let action = Action::from_fp_output_header(fp_output_header)
.map_err(|unknown_action| PduError::unexpected_message_type("fpOutputHeader", unknown_action))?;
match action {
Action::X224 => {
@ -299,7 +283,7 @@ pub fn find_size(bytes: &[u8]) -> Result<Option<PduInfo>> {
pub trait PduHint: core::fmt::Debug {
/// Finds next PDU size by reading the next few bytes.
fn find_size(&self, bytes: &[u8]) -> Result<Option<usize>>;
fn find_size(&self, bytes: &[u8]) -> PduResult<Option<usize>>;
}
#[derive(Clone, Copy, Debug)]
@ -308,7 +292,7 @@ pub struct X224Hint;
pub const X224_HINT: X224Hint = X224Hint;
impl PduHint for X224Hint {
fn find_size(&self, bytes: &[u8]) -> Result<Option<usize>> {
fn find_size(&self, bytes: &[u8]) -> PduResult<Option<usize>> {
match find_size(bytes)? {
Some(pdu_info) => {
debug_assert_eq!(pdu_info.action, Action::X224);
@ -325,7 +309,7 @@ pub struct FastPathHint;
pub const FAST_PATH_HINT: FastPathHint = FastPathHint;
impl PduHint for FastPathHint {
fn find_size(&self, bytes: &[u8]) -> Result<Option<usize>> {
fn find_size(&self, bytes: &[u8]) -> PduResult<Option<usize>> {
match find_size(bytes)? {
Some(pdu_info) => {
debug_assert_eq!(pdu_info.action, Action::FastPath);

View file

@ -2,42 +2,156 @@
//!
//! Some are exported and available to external crates
/// Creates a `PduError` with `NotEnoughBytes` kind
///
/// Shorthand for
/// ```rust
/// <crate::PduError as crate::PduErrorExt>::not_enough_bytes(context, received, expected)
/// ```
/// and
/// ```rust
/// <crate::PduError as crate::PduErrorExt>::not_enough_bytes(Self::NAME, received, expected)
/// ```
#[macro_export]
macro_rules! not_enough_bytes_err {
( $context:expr, $received:expr , $expected:expr $(,)? ) => {{
<$crate::PduError as $crate::PduErrorExt>::not_enough_bytes($context, $received, $expected)
}};
( $received:expr , $expected:expr $(,)? ) => {{
not_enough_bytes_err!(Self::NAME, $received, $expected)
}};
}
/// Creates a `PduError` with `InvalidMessage` kind
///
/// Shorthand for
/// ```rust
/// <crate::PduError as crate::PduErrorExt>::invalid_message(context, field, reason)
/// ```
/// and
/// ```rust
/// <crate::PduError as crate::PduErrorExt>::invalid_message(Self::NAME, field, reason)
/// ```
#[macro_export]
macro_rules! invalid_message_err {
( $context:expr, $field:expr , $reason:expr $(,)? ) => {{
<$crate::PduError as $crate::PduErrorExt>::invalid_message($context, $field, $reason)
}};
( $field:expr , $reason:expr $(,)? ) => {{
invalid_message_err!(Self::NAME, $field, $reason)
}};
}
/// Creates a `PduError` with `UnexpectedMessageType` kind
///
/// Shorthand for
/// ```rust
/// <crate::PduError as crate::PduErrorExt>::unexpected_message_type(context, got)
/// ```
/// and
/// ```rust
/// <crate::PduError as crate::PduErrorExt>::unexpected_message_type(Self::NAME, got)
/// ```
#[macro_export]
macro_rules! unexpected_message_type_err {
( $context:expr, $got:expr $(,)? ) => {{
<$crate::PduError as $crate::PduErrorExt>::unexpected_message_type($context, $got)
}};
( $got:expr $(,)? ) => {{
unexpected_message_type_err!(Self::NAME, $got)
}};
}
/// Creates a `PduError` with `UnsupportedVersion` kind
///
/// Shorthand for
/// ```rust
/// <crate::PduError as crate::PduErrorExt>::unsupported_version(context, got)
/// ```
/// and
/// ```rust
/// <crate::PduError as crate::PduErrorExt>::unsupported_version(Self::NAME, got)
/// ```
#[macro_export]
macro_rules! unsupported_version_err {
( $context:expr, $got:expr $(,)? ) => {{
<$crate::PduError as $crate::PduErrorExt>::unsupported_version($context, $got)
}};
( $got:expr $(,)? ) => {{
unsupported_version_err!(Self::NAME, $got)
}};
}
/// Creates a `PduError` with `Other` kind
///
/// Shorthand for
/// ```rust
/// <crate::PduError as crate::PduErrorExt>::other(context, description)
/// ```
/// and
/// ```rust
/// <crate::PduError as crate::PduErrorExt>::other(Self::NAME, description)
/// ```
#[macro_export]
macro_rules! other_err {
( $context:expr, $description:expr $(,)? ) => {{
<$crate::PduError as $crate::PduErrorExt>::other($context, $description)
}};
( $description:expr $(,)? ) => {{
other_err!(Self::NAME, $description)
}};
}
/// Creates a `PduError` with `Custom` kind and a source error attached to it
///
/// Shorthand for
/// ```rust
/// <crate::PduError as crate::PduErrorExt>::custom(context, source)
/// ```
/// and
/// ```rust
/// <crate::PduError as crate::PduErrorExt>::custom(Self::NAME, source)
/// ```
#[macro_export]
macro_rules! custom_err {
( $context:expr, $source:expr $(,)? ) => {{
<$crate::PduError as $crate::PduErrorExt>::custom($context, $source)
}};
( $source:expr $(,)? ) => {{
custom_err!(Self::NAME, $source)
}};
}
#[macro_export]
macro_rules! ensure_size {
(name: $name:expr, in: $buf:ident, size: $expected:expr) => {{
(ctx: $ctx:expr, in: $buf:ident, size: $expected:expr) => {{
let received = $buf.len();
let expected = $expected;
if !(received >= expected) {
return Err($crate::Error::NotEnoughBytes {
name: $name,
received,
expected,
});
return Err(<$crate::PduError as $crate::PduErrorExt>::not_enough_bytes($ctx, received, expected));
}
}};
(in: $buf:ident, size: $expected:expr) => {{
$crate::ensure_size!(name: Self::NAME, in: $buf, size: $expected)
$crate::ensure_size!(ctx: Self::NAME, in: $buf, size: $expected)
}};
}
#[macro_export]
macro_rules! ensure_fixed_part_size {
(in: $buf:ident) => {{
$crate::ensure_size!(name: Self::NAME, in: $buf, size: Self::FIXED_PART_SIZE)
$crate::ensure_size!(ctx: Self::NAME, in: $buf, size: Self::FIXED_PART_SIZE)
}};
}
#[macro_export]
macro_rules! cast_length {
($len:expr, $name:expr, $field:expr) => {{
$len.try_into().map_err(|_| $crate::Error::InvalidMessage {
name: $name,
field: $field,
reason: "too many elements",
($ctx:expr, $field:expr, $len:expr) => {{
$len.try_into().map_err(|e| {
<$crate::PduError as $crate::PduErrorExt>::invalid_message($ctx, $field, "too many elements").with_source(e)
})
}};
($len:expr, $field:expr) => {{
$crate::cast_length!($len, <Self as $crate::Pdu>::NAME, $field)
($field:expr, $len:expr) => {{
$crate::cast_length!(<Self as $crate::Pdu>::NAME, $field, $len)
}};
}
@ -64,7 +178,7 @@ macro_rules! impl_pdu_pod {
}
impl $crate::PduDecodeOwned for $pdu_ty {
fn decode_owned(src: &mut $crate::cursor::ReadCursor<'_>) -> $crate::Result<Self> {
fn decode_owned(src: &mut $crate::cursor::ReadCursor<'_>) -> $crate::PduResult<Self> {
<Self as $crate::PduDecode>::decode(src)
}
}
@ -78,7 +192,7 @@ macro_rules! impl_pdu_borrowing {
pub type $owned_ty = $pdu_ty<'static>;
impl $crate::PduDecodeOwned for $owned_ty {
fn decode_owned(src: &mut $crate::cursor::ReadCursor<'_>) -> $crate::Result<Self> {
fn decode_owned(src: &mut $crate::cursor::ReadCursor<'_>) -> $crate::PduResult<Self> {
let pdu = <$pdu_ty as $crate::PduDecode>::decode(src)?;
Ok($crate::IntoOwnedPdu::into_owned_pdu(pdu))
}

View file

@ -5,7 +5,7 @@ use crate::gcc::{Channel, ClientGccBlocks, ConferenceCreateRequest, ConferenceCr
use crate::tpdu::{TpduCode, TpduHeader};
use crate::tpkt::TpktHeader;
use crate::x224::{user_data_size, X224Pdu};
use crate::{per, IntoOwnedPdu, Result};
use crate::{per, IntoOwnedPdu, PduError, PduErrorExt as _, PduResult};
// T.125 MCS is defined in:
//
@ -131,13 +131,28 @@ pub const RESULT_ENUM_LENGTH: u8 = 16;
const BASE_CHANNEL_ID: u16 = 1001;
const SEND_DATA_PDU_DATA_PRIORITY_AND_SEGMENTATION: u8 = 0x70;
/// Creates a closure mapping a `PerError` to a `PduError` with field-level context.
///
/// Shorthand for
/// ```rust
/// |e| <crate::PduError as crate::PduErrorExt>::invalid_message(Self::MCS_NAME, field_name, "PER").with_source(e)
/// ```
macro_rules! per_field_err {
($field_name:expr) => {{
|error| {
<$crate::PduError as $crate::PduErrorExt>::invalid_message(Self::MCS_NAME, $field_name, "PER")
.with_source(error)
}
}};
}
#[doc(hidden)]
pub trait McsPdu<'de>: Sized {
const MCS_NAME: &'static str;
fn mcs_body_encode(&self, dst: &mut WriteCursor<'_>) -> Result<()>;
fn mcs_body_encode(&self, dst: &mut WriteCursor<'_>) -> PduResult<()>;
fn mcs_body_decode(src: &mut ReadCursor<'de>, tpdu_user_data_size: usize) -> Result<Self>;
fn mcs_body_decode(src: &mut ReadCursor<'de>, tpdu_user_data_size: usize) -> PduResult<Self>;
fn mcs_size(&self) -> usize;
@ -154,11 +169,11 @@ where
const TPDU_CODE: TpduCode = TpduCode::DATA;
fn x224_body_encode(&self, dst: &mut WriteCursor<'_>) -> Result<()> {
fn x224_body_encode(&self, dst: &mut WriteCursor<'_>) -> PduResult<()> {
self.mcs_body_encode(dst)
}
fn x224_body_decode(src: &mut ReadCursor<'de>, tpkt: &TpktHeader, tpdu: &TpduHeader) -> Result<Self> {
fn x224_body_decode(src: &mut ReadCursor<'de>, tpkt: &TpktHeader, tpdu: &TpduHeader) -> PduResult<Self> {
let tpdu_user_data_size = user_data_size(tpkt, tpdu);
T::mcs_body_decode(src, tpdu_user_data_size)
}
@ -186,12 +201,9 @@ enum DomainMcsPdu {
}
impl DomainMcsPdu {
fn check_expected(self, name: &'static str, expected: DomainMcsPdu) -> crate::Result<()> {
fn check_expected(self, name: &'static str, expected: DomainMcsPdu) -> PduResult<()> {
if self != expected {
Err(crate::Error::UnexpectedMessageType {
name,
got: expected.as_u8(),
})
Err(PduError::unexpected_message_type(name, expected.as_u8()))
} else {
Ok(())
}
@ -224,26 +236,26 @@ impl DomainMcsPdu {
}
}
fn read_mcspdu_header(src: &mut ReadCursor<'_>, name: &'static str) -> crate::Result<DomainMcsPdu> {
ensure_size!(name: name, in: src, size: 1);
fn read_mcspdu_header(src: &mut ReadCursor<'_>, name: &'static str) -> PduResult<DomainMcsPdu> {
ensure_size!(ctx: name, in: src, size: 1);
let choice = src.read_u8();
DomainMcsPdu::from_choice(choice).ok_or(crate::Error::InvalidMessage {
DomainMcsPdu::from_choice(choice).ok_or(PduError::invalid_message(
name,
field: "domain-mcspdu",
reason: "unexpected application tag for CHOICE",
})
"domain-mcspdu",
"unexpected application tag for CHOICE",
))
}
fn peek_mcspdu_header(src: &mut ReadCursor<'_>, name: &'static str) -> crate::Result<DomainMcsPdu> {
ensure_size!(name: name, in: src, size: 1);
fn peek_mcspdu_header(src: &mut ReadCursor<'_>, name: &'static str) -> PduResult<DomainMcsPdu> {
ensure_size!(ctx: name, in: src, size: 1);
let choice = src.peek_u8();
DomainMcsPdu::from_choice(choice).ok_or(crate::Error::InvalidMessage {
DomainMcsPdu::from_choice(choice).ok_or(PduError::invalid_message(
name,
field: "domain-mcspdu",
reason: "unexpected application tag for CHOICE",
})
"domain-mcspdu",
"unexpected application tag for CHOICE",
))
}
fn write_mcspdu_header(dst: &mut WriteCursor<'_>, domain_mcspdu: DomainMcsPdu, options: u8) {
@ -290,7 +302,7 @@ impl IntoOwnedPdu for McsMessage<'_> {
impl<'de> McsPdu<'de> for McsMessage<'de> {
const MCS_NAME: &'static str = "McsMessage";
fn mcs_body_encode(&self, dst: &mut WriteCursor<'_>) -> Result<()> {
fn mcs_body_encode(&self, dst: &mut WriteCursor<'_>) -> PduResult<()> {
match self {
Self::ErectDomainRequest(msg) => msg.mcs_body_encode(dst),
Self::AttachUserRequest(msg) => msg.mcs_body_encode(dst),
@ -303,7 +315,7 @@ impl<'de> McsPdu<'de> for McsMessage<'de> {
}
}
fn mcs_body_decode(src: &mut ReadCursor<'de>, tpdu_user_data_size: usize) -> Result<Self> {
fn mcs_body_decode(src: &mut ReadCursor<'de>, tpdu_user_data_size: usize) -> PduResult<Self> {
match peek_mcspdu_header(src, Self::MCS_NAME)? {
DomainMcsPdu::ErectDomainRequest => Ok(McsMessage::ErectDomainRequest(ErectDomainPdu::mcs_body_decode(
src,
@ -374,7 +386,7 @@ impl_pdu_pod!(ErectDomainPdu);
impl<'de> McsPdu<'de> for ErectDomainPdu {
const MCS_NAME: &'static str = "ErectDomainPdu";
fn mcs_body_encode(&self, dst: &mut WriteCursor<'_>) -> Result<()> {
fn mcs_body_encode(&self, dst: &mut WriteCursor<'_>) -> PduResult<()> {
write_mcspdu_header(dst, DomainMcsPdu::ErectDomainRequest, 0);
per::write_u32(dst, self.sub_height);
@ -383,11 +395,11 @@ impl<'de> McsPdu<'de> for ErectDomainPdu {
Ok(())
}
fn mcs_body_decode(src: &mut ReadCursor<'de>, _: usize) -> Result<Self> {
fn mcs_body_decode(src: &mut ReadCursor<'de>, _: usize) -> PduResult<Self> {
read_mcspdu_header(src, Self::MCS_NAME)?.check_expected(Self::MCS_NAME, DomainMcsPdu::ErectDomainRequest)?;
let sub_height = per::read_u32(src)?;
let sub_interval = per::read_u32(src)?;
let sub_height = per::read_u32(src).map_err(per_field_err!("subHeight"))?;
let sub_interval = per::read_u32(src).map_err(per_field_err!("subInterval"))?;
Ok(Self {
sub_height,
@ -408,13 +420,13 @@ impl_pdu_pod!(AttachUserRequest);
impl<'de> McsPdu<'de> for AttachUserRequest {
const MCS_NAME: &'static str = "AttachUserRequest";
fn mcs_body_encode(&self, dst: &mut WriteCursor<'_>) -> Result<()> {
fn mcs_body_encode(&self, dst: &mut WriteCursor<'_>) -> PduResult<()> {
write_mcspdu_header(dst, DomainMcsPdu::AttachUserRequest, 0);
Ok(())
}
fn mcs_body_decode(src: &mut ReadCursor<'de>, _: usize) -> Result<Self> {
fn mcs_body_decode(src: &mut ReadCursor<'de>, _: usize) -> PduResult<Self> {
read_mcspdu_header(src, Self::MCS_NAME)?.check_expected(Self::MCS_NAME, DomainMcsPdu::AttachUserRequest)?;
Ok(Self)
@ -436,20 +448,20 @@ impl_pdu_pod!(AttachUserConfirm);
impl<'de> McsPdu<'de> for AttachUserConfirm {
const MCS_NAME: &'static str = "AttachUserConfirm";
fn mcs_body_encode(&self, dst: &mut WriteCursor<'_>) -> Result<()> {
fn mcs_body_encode(&self, dst: &mut WriteCursor<'_>) -> PduResult<()> {
write_mcspdu_header(dst, DomainMcsPdu::AttachUserConfirm, 2);
per::write_enum(dst, self.result);
per::write_u16(dst, self.initiator_id, BASE_CHANNEL_ID)?;
per::write_u16(dst, self.initiator_id, BASE_CHANNEL_ID).map_err(per_field_err!("initiator"))?;
Ok(())
}
fn mcs_body_decode(src: &mut ReadCursor<'de>, _: usize) -> Result<Self> {
fn mcs_body_decode(src: &mut ReadCursor<'de>, _: usize) -> PduResult<Self> {
read_mcspdu_header(src, Self::MCS_NAME)?.check_expected(Self::MCS_NAME, DomainMcsPdu::AttachUserConfirm)?;
let result = per::read_enum(src, RESULT_ENUM_LENGTH)?;
let user_id = per::read_u16(src, BASE_CHANNEL_ID)?;
let result = per::read_enum(src, RESULT_ENUM_LENGTH).map_err(per_field_err!("result"))?;
let user_id = per::read_u16(src, BASE_CHANNEL_ID).map_err(per_field_err!("userId"))?;
Ok(Self {
result,
@ -473,20 +485,20 @@ impl_pdu_pod!(ChannelJoinRequest);
impl<'de> McsPdu<'de> for ChannelJoinRequest {
const MCS_NAME: &'static str = "ChannelJoinRequest";
fn mcs_body_encode(&self, dst: &mut WriteCursor<'_>) -> Result<()> {
fn mcs_body_encode(&self, dst: &mut WriteCursor<'_>) -> PduResult<()> {
write_mcspdu_header(dst, DomainMcsPdu::ChannelJoinRequest, 0);
per::write_u16(dst, self.initiator_id, BASE_CHANNEL_ID)?;
per::write_u16(dst, self.channel_id, 0)?;
per::write_u16(dst, self.initiator_id, BASE_CHANNEL_ID).map_err(per_field_err!("initiator"))?;
per::write_u16(dst, self.channel_id, 0).map_err(per_field_err!("channelId"))?;
Ok(())
}
fn mcs_body_decode(src: &mut ReadCursor<'de>, _: usize) -> Result<Self> {
fn mcs_body_decode(src: &mut ReadCursor<'de>, _: usize) -> PduResult<Self> {
read_mcspdu_header(src, Self::MCS_NAME)?.check_expected(Self::MCS_NAME, DomainMcsPdu::ChannelJoinRequest)?;
let initiator_id = per::read_u16(src, BASE_CHANNEL_ID)?;
let channel_id = per::read_u16(src, 0)?;
let initiator_id = per::read_u16(src, BASE_CHANNEL_ID).map_err(per_field_err!("initiator"))?;
let channel_id = per::read_u16(src, 0).map_err(per_field_err!("channelID"))?;
Ok(Self {
initiator_id,
@ -512,24 +524,24 @@ impl_pdu_pod!(ChannelJoinConfirm);
impl<'de> McsPdu<'de> for ChannelJoinConfirm {
const MCS_NAME: &'static str = "ChannelJoinConfirm";
fn mcs_body_encode(&self, dst: &mut WriteCursor<'_>) -> Result<()> {
fn mcs_body_encode(&self, dst: &mut WriteCursor<'_>) -> PduResult<()> {
write_mcspdu_header(dst, DomainMcsPdu::ChannelJoinConfirm, 2);
per::write_enum(dst, self.result);
per::write_u16(dst, self.initiator_id, BASE_CHANNEL_ID)?;
per::write_u16(dst, self.requested_channel_id, 0)?;
per::write_u16(dst, self.channel_id, 0)?;
per::write_u16(dst, self.initiator_id, BASE_CHANNEL_ID).map_err(per_field_err!("initiator"))?;
per::write_u16(dst, self.requested_channel_id, 0).map_err(per_field_err!("requested"))?;
per::write_u16(dst, self.channel_id, 0).map_err(per_field_err!("channelId"))?;
Ok(())
}
fn mcs_body_decode(src: &mut ReadCursor<'de>, _: usize) -> Result<Self> {
fn mcs_body_decode(src: &mut ReadCursor<'de>, _: usize) -> PduResult<Self> {
read_mcspdu_header(src, Self::MCS_NAME)?.check_expected(Self::MCS_NAME, DomainMcsPdu::ChannelJoinConfirm)?;
let result = per::read_enum(src, RESULT_ENUM_LENGTH)?;
let initiator_id = per::read_u16(src, BASE_CHANNEL_ID)?;
let requested_channel_id = per::read_u16(src, 0)?;
let channel_id = per::read_u16(src, 0)?;
let result = per::read_enum(src, RESULT_ENUM_LENGTH).map_err(per_field_err!("result"))?;
let initiator_id = per::read_u16(src, BASE_CHANNEL_ID).map_err(per_field_err!("initiator"))?;
let requested_channel_id = per::read_u16(src, 0).map_err(per_field_err!("requested"))?;
let channel_id = per::read_u16(src, 0).map_err(per_field_err!("channelId"))?;
Ok(Self {
result,
@ -567,44 +579,44 @@ impl IntoOwnedPdu for SendDataRequest<'_> {
impl<'de> McsPdu<'de> for SendDataRequest<'de> {
const MCS_NAME: &'static str = "SendDataRequest";
fn mcs_body_encode(&self, dst: &mut WriteCursor<'_>) -> Result<()> {
fn mcs_body_encode(&self, dst: &mut WriteCursor<'_>) -> PduResult<()> {
write_mcspdu_header(dst, DomainMcsPdu::SendDataRequest, 0);
per::write_u16(dst, self.initiator_id, BASE_CHANNEL_ID)?;
per::write_u16(dst, self.channel_id, 0)?;
per::write_u16(dst, self.initiator_id, BASE_CHANNEL_ID).map_err(per_field_err!("initiator"))?;
per::write_u16(dst, self.channel_id, 0).map_err(per_field_err!("channelID"))?;
dst.write_u8(SEND_DATA_PDU_DATA_PRIORITY_AND_SEGMENTATION);
per::write_length(dst, cast_length!(self.user_data.len(), "user-data-length")?);
per::write_length(dst, cast_length!("user-data-length", self.user_data.len())?);
dst.write_slice(&self.user_data);
Ok(())
}
fn mcs_body_decode(src: &mut ReadCursor<'de>, tpdu_user_data_size: usize) -> Result<Self> {
fn mcs_body_decode(src: &mut ReadCursor<'de>, tpdu_user_data_size: usize) -> PduResult<Self> {
let src_len_before = src.len();
read_mcspdu_header(src, Self::MCS_NAME)?.check_expected(Self::MCS_NAME, DomainMcsPdu::SendDataRequest)?;
let initiator_id = per::read_u16(src, BASE_CHANNEL_ID)?;
let channel_id = per::read_u16(src, 0)?;
let initiator_id = per::read_u16(src, BASE_CHANNEL_ID).map_err(per_field_err!("initiator"))?;
let channel_id = per::read_u16(src, 0).map_err(per_field_err!("channelId"))?;
let _data_priority_and_segmentation = src.read_u8();
let (length, _) = per::read_length(src)?;
let (length, _) = per::read_length(src).map_err(per_field_err!("userDataLength"))?;
let length = usize::from(length);
let src_len_after = src.len();
if length > tpdu_user_data_size.saturating_sub(src_len_before - src_len_after) {
return Err(crate::Error::InvalidMessage {
name: Self::MCS_NAME,
field: "user-data-length",
reason: "inconsistent with user data size advertised in TPDU",
});
return Err(PduError::invalid_message(
Self::MCS_NAME,
"userDataLength",
"inconsistent with user data size advertised in TPDU",
));
}
ensure_size!(name: Self::MCS_NAME, in: src, size: length);
ensure_size!(ctx: Self::MCS_NAME, in: src, size: length);
let user_data = Cow::Borrowed(src.read_slice(length));
Ok(Self {
@ -646,44 +658,44 @@ impl IntoOwnedPdu for SendDataIndication<'_> {
impl<'de> McsPdu<'de> for SendDataIndication<'de> {
const MCS_NAME: &'static str = "SendDataIndication";
fn mcs_body_encode(&self, dst: &mut WriteCursor<'_>) -> Result<()> {
fn mcs_body_encode(&self, dst: &mut WriteCursor<'_>) -> PduResult<()> {
write_mcspdu_header(dst, DomainMcsPdu::SendDataIndication, 0);
per::write_u16(dst, self.initiator_id, BASE_CHANNEL_ID)?;
per::write_u16(dst, self.channel_id, 0)?;
per::write_u16(dst, self.initiator_id, BASE_CHANNEL_ID).map_err(per_field_err!("initiator"))?;
per::write_u16(dst, self.channel_id, 0).map_err(per_field_err!("channelId"))?;
dst.write_u8(SEND_DATA_PDU_DATA_PRIORITY_AND_SEGMENTATION);
per::write_length(dst, cast_length!(self.user_data.len(), "user-data-length")?);
per::write_length(dst, cast_length!("userDataLength", self.user_data.len())?);
dst.write_slice(&self.user_data);
Ok(())
}
fn mcs_body_decode(src: &mut ReadCursor<'de>, tpdu_user_data_size: usize) -> Result<Self> {
fn mcs_body_decode(src: &mut ReadCursor<'de>, tpdu_user_data_size: usize) -> PduResult<Self> {
let src_len_before = src.len();
read_mcspdu_header(src, Self::MCS_NAME)?.check_expected(Self::MCS_NAME, DomainMcsPdu::SendDataIndication)?;
let initiator_id = per::read_u16(src, BASE_CHANNEL_ID)?;
let channel_id = per::read_u16(src, 0)?;
let initiator_id = per::read_u16(src, BASE_CHANNEL_ID).map_err(per_field_err!("initiator"))?;
let channel_id = per::read_u16(src, 0).map_err(per_field_err!("channelId"))?;
let _data_priority_and_segmentation = src.read_u8();
let (length, _) = per::read_length(src)?;
let (length, _) = per::read_length(src).map_err(per_field_err!("userDataLength"))?;
let length = usize::from(length);
let src_len_after = src.len();
if length > tpdu_user_data_size.saturating_sub(src_len_before - src_len_after) {
return Err(crate::Error::InvalidMessage {
name: Self::MCS_NAME,
field: "user-data-length",
reason: "inconsistent with user data size advertised in TPDU",
});
return Err(PduError::invalid_message(
Self::MCS_NAME,
"userDataLength",
"inconsistent with user data size advertised in TPDU",
));
}
ensure_size!(name: Self::MCS_NAME, in: src, size: length);
ensure_size!(ctx: Self::MCS_NAME, in: src, size: length);
let user_data = Cow::Borrowed(src.read_slice(length));
Ok(Self {
@ -740,7 +752,7 @@ impl_pdu_pod!(DisconnectProviderUltimatum);
impl<'de> McsPdu<'de> for DisconnectProviderUltimatum {
const MCS_NAME: &'static str = "DisconnectProviderUltimatum";
fn mcs_body_encode(&self, dst: &mut WriteCursor<'_>) -> Result<()> {
fn mcs_body_encode(&self, dst: &mut WriteCursor<'_>) -> PduResult<()> {
let domain_mcspdu = DomainMcsPdu::DisconnectProviderUltimatum.as_u8();
let reason = self.reason.as_u8();
@ -752,7 +764,7 @@ impl<'de> McsPdu<'de> for DisconnectProviderUltimatum {
Ok(())
}
fn mcs_body_decode(src: &mut ReadCursor<'de>, _: usize) -> Result<Self> {
fn mcs_body_decode(src: &mut ReadCursor<'de>, _: usize) -> PduResult<Self> {
// http://msdn.microsoft.com/en-us/library/cc240872.aspx:
//
// PER encoded (ALIGNED variant of BASIC-PER) PDU contents:
@ -784,19 +796,19 @@ impl<'de> McsPdu<'de> for DisconnectProviderUltimatum {
let reason = (b1 & 0x03) << 1 | (b2 >> 7);
DomainMcsPdu::from_u8(domain_mcspdu_choice)
.ok_or(crate::Error::InvalidMessage {
name: Self::MCS_NAME,
field: "domain-mcspdu",
reason: "unexpected application tag for CHOICE",
})?
.ok_or(PduError::invalid_message(
Self::MCS_NAME,
"domain-mcspdu",
"unexpected application tag for CHOICE",
))?
.check_expected(Self::MCS_NAME, DomainMcsPdu::DisconnectProviderUltimatum)?;
Ok(Self {
reason: DisconnectReason::from_u8(reason).ok_or(crate::Error::InvalidMessage {
name: Self::MCS_NAME,
field: "reason",
reason: "unknown variant",
})?,
reason: DisconnectReason::from_u8(reason).ok_or(PduError::invalid_message(
Self::MCS_NAME,
"reason",
"unknown variant",
))?,
})
}
@ -916,6 +928,22 @@ mod legacy {
use crate::gcc::GccError;
use crate::{ber, PduParsing};
// impl<'de> McsPdu<'de> for ConnectInitial {
// const MCS_NAME: &'static str = "DisconnectProviderUltimatum";
// fn mcs_body_encode(&self, dst: &mut WriteCursor<'_>) -> Result<()> {
// todo!()
// }
// fn mcs_body_decode(src: &mut ReadCursor<'de>, tpdu_user_data_size: usize) -> Result<Self> {
// todo!()
// }
// fn mcs_size(&self) -> usize {
// todo!()
// }
// }
const MCS_TYPE_CONNECT_INITIAL: u8 = 0x65;
const MCS_TYPE_CONNECT_RESPONSE: u8 = 0x66;
@ -1107,4 +1135,10 @@ mod legacy {
io::Error::new(io::ErrorKind::Other, format!("MCS Connection Sequence error: {e}"))
}
}
impl ironrdp_error::legacy::ErrorContext for McsError {
fn context(&self) -> &'static str {
"mcs"
}
}
}

View file

@ -7,7 +7,7 @@ use crate::cursor::{ReadCursor, WriteCursor};
use crate::tpdu::{TpduCode, TpduHeader};
use crate::tpkt::TpktHeader;
use crate::x224::X224Pdu;
use crate::{Error, Pdu as _, Result};
use crate::{Pdu as _, PduError, PduErrorExt as _, PduResult};
bitflags! {
/// A 32-bit, unsigned integer that contains flags indicating the supported
@ -104,14 +104,14 @@ impl NegoRequestData {
Self::Cookie(Cookie(value))
}
pub fn read(src: &mut ReadCursor<'_>) -> Result<Option<Self>> {
pub fn read(src: &mut ReadCursor<'_>) -> PduResult<Option<Self>> {
match RoutingToken::read(src)? {
Some(token) => Ok(Some(Self::RoutingToken(token))),
None => Cookie::read(src)?.map(Self::Cookie).pipe(Ok),
}
}
pub fn write(&self, dst: &mut WriteCursor<'_>) -> Result<()> {
pub fn write(&self, dst: &mut WriteCursor<'_>) -> PduResult<()> {
match self {
NegoRequestData::RoutingToken(token) => token.write(dst),
NegoRequestData::Cookie(cookie) => cookie.write(dst),
@ -132,11 +132,11 @@ pub struct Cookie(pub String);
impl Cookie {
const PREFIX: &str = "Cookie: mstshash=";
pub fn read(src: &mut ReadCursor<'_>) -> Result<Option<Self>> {
pub fn read(src: &mut ReadCursor<'_>) -> PduResult<Option<Self>> {
read_nego_data(src, "Cookie", Self::PREFIX)?.map(Self).pipe(Ok)
}
pub fn write(&self, dst: &mut WriteCursor<'_>) -> Result<()> {
pub fn write(&self, dst: &mut WriteCursor<'_>) -> PduResult<()> {
write_nego_data(dst, "Cookie", Self::PREFIX, &self.0)
}
@ -151,11 +151,11 @@ pub struct RoutingToken(pub String);
impl RoutingToken {
const PREFIX: &str = "Cookie: msts=";
pub fn read(src: &mut ReadCursor<'_>) -> Result<Option<Self>> {
pub fn read(src: &mut ReadCursor<'_>) -> PduResult<Option<Self>> {
read_nego_data(src, "RoutingToken", Self::PREFIX)?.map(Self).pipe(Ok)
}
pub fn write(&self, dst: &mut WriteCursor<'_>) -> Result<()> {
pub fn write(&self, dst: &mut WriteCursor<'_>) -> PduResult<()> {
write_nego_data(dst, "RoutingToken", Self::PREFIX, &self.0)
}
@ -203,7 +203,7 @@ impl<'de> X224Pdu<'de> for ConnectionRequest {
const TPDU_CODE: TpduCode = TpduCode::CONNECTION_REQUEST;
fn x224_body_encode(&self, dst: &mut WriteCursor<'_>) -> Result<()> {
fn x224_body_encode(&self, dst: &mut WriteCursor<'_>) -> PduResult<()> {
if let Some(nego_data) = &self.nego_data {
nego_data.write(dst)?;
}
@ -218,7 +218,7 @@ impl<'de> X224Pdu<'de> for ConnectionRequest {
Ok(())
}
fn x224_body_decode(src: &mut ReadCursor<'de>, _: &TpktHeader, tpdu: &TpduHeader) -> Result<Self> {
fn x224_body_decode(src: &mut ReadCursor<'de>, _: &TpktHeader, tpdu: &TpduHeader) -> PduResult<Self> {
let variable_part_size = tpdu.variable_part_size();
ensure_size!(in: src, size: variable_part_size);
@ -226,28 +226,25 @@ impl<'de> X224Pdu<'de> for ConnectionRequest {
let nego_data = NegoRequestData::read(src)?;
let Some(variable_part_rest_size) = variable_part_size.checked_sub(nego_data.as_ref().map(|data| data.size()).unwrap_or(0)) else {
return Err(Error::InvalidMessage { name: Self::NAME, field: "TPDU header variable part", reason: "advertised size too small" })
return Err(PduError::invalid_message(Self::NAME, "TPDU header variable part", "advertised size too small"));
};
if variable_part_rest_size >= usize::from(Self::RDP_NEG_REQ_SIZE) {
let msg_type = NegoMsgType::from(src.read_u8());
if msg_type != NegoMsgType::REQUEST {
return Err(Error::UnexpectedMessageType {
name: Self::NAME,
got: u8::from(msg_type),
});
return Err(PduError::unexpected_message_type(Self::NAME, u8::from(msg_type)));
}
let flags = RequestFlags::from_bits_truncate(src.read_u8());
if flags.contains(RequestFlags::CORRELATION_INFO_PRESENT) {
// TODO: support for RDP_NEG_CORRELATION_INFO
return Err(Error::InvalidMessage {
name: Self::NAME,
field: "flags",
reason: "CORRECTION_INFO_PRESENT flag is set, but not supported by IronRDP",
});
return Err(PduError::invalid_message(
Self::NAME,
"flags",
"CORRECTION_INFO_PRESENT flag is set, but not supported by IronRDP",
));
}
let _length = src.read_u16();
@ -309,7 +306,7 @@ impl<'de> X224Pdu<'de> for ConnectionConfirm {
const TPDU_CODE: TpduCode = TpduCode::CONNECTION_CONFIRM;
fn x224_body_encode(&self, dst: &mut WriteCursor<'_>) -> Result<()> {
fn x224_body_encode(&self, dst: &mut WriteCursor<'_>) -> PduResult<()> {
match self {
ConnectionConfirm::Response { flags, protocol } => {
dst.write_u8(u8::from(NegoMsgType::RESPONSE));
@ -328,7 +325,7 @@ impl<'de> X224Pdu<'de> for ConnectionConfirm {
Ok(())
}
fn x224_body_decode(src: &mut ReadCursor<'de>, _: &TpktHeader, tpdu: &TpduHeader) -> Result<Self> {
fn x224_body_decode(src: &mut ReadCursor<'de>, _: &TpktHeader, tpdu: &TpduHeader) -> PduResult<Self> {
let variable_part_size = tpdu.variable_part_size();
ensure_size!(in: src, size: variable_part_size);
@ -349,10 +346,7 @@ impl<'de> X224Pdu<'de> for ConnectionConfirm {
Ok(Self::Failure { code })
}
unexpected => Err(Error::UnexpectedMessageType {
name: Self::X224_NAME,
got: u8::from(unexpected),
}),
unexpected => Err(PduError::unexpected_message_type(Self::X224_NAME, u8::from(unexpected))),
}
} else {
Ok(Self::Response {
@ -374,8 +368,8 @@ impl<'de> X224Pdu<'de> for ConnectionConfirm {
}
}
fn read_nego_data(src: &mut ReadCursor<'_>, name: &'static str, prefix: &str) -> Result<Option<String>> {
ensure_size!(name: name, in: src, size: prefix.len() + 2);
fn read_nego_data(src: &mut ReadCursor<'_>, name: &'static str, prefix: &str) -> PduResult<Option<String>> {
ensure_size!(ctx: name, in: src, size: prefix.len() + 2);
if src.peek_slice(prefix.len()) != prefix.as_bytes() {
return Ok(None);
@ -387,7 +381,7 @@ fn read_nego_data(src: &mut ReadCursor<'_>, name: &'static str, prefix: &str) ->
while src.peek_u16() != 0x0A0D {
src.advance(1);
ensure_size!(name: name, in: src, size: 2);
ensure_size!(ctx: name, in: src, size: 2);
}
let identifier_end = src.pos();
@ -395,18 +389,14 @@ fn read_nego_data(src: &mut ReadCursor<'_>, name: &'static str, prefix: &str) ->
src.advance(2);
let data = core::str::from_utf8(&src.inner()[identifier_start..identifier_end])
.map_err(|_| Error::InvalidMessage {
name,
field: "identifier",
reason: "not valid UTF-8",
})?
.map_err(|_| PduError::invalid_message(name, "identifier", "not valid UTF-8"))?
.to_owned();
Ok(Some(data))
}
fn write_nego_data(dst: &mut WriteCursor<'_>, name: &'static str, prefix: &str, value: &str) -> Result<()> {
ensure_size!(name: name, in: dst, size: prefix.len() + value.len() + 2);
fn write_nego_data(dst: &mut WriteCursor<'_>, name: &'static str, prefix: &str, value: &str) -> PduResult<()> {
ensure_size!(ctx: name, in: dst, size: prefix.len() + value.len() + 2);
dst.write_slice(prefix.as_bytes());
dst.write_slice(value.as_bytes());

View file

@ -2,7 +2,7 @@
use crate::cursor::ReadCursor;
use crate::padding::Padding;
use crate::{Error, Pdu, PduDecode, PduEncode, Result};
use crate::{Pdu, PduDecode, PduEncode, PduError, PduErrorExt as _, PduResult};
/// Preconnection PDU version
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@ -43,17 +43,17 @@ impl Pdu for PreconnectionBlob {
}
impl<'de> PduDecode<'de> for PreconnectionBlob {
fn decode(src: &mut ReadCursor<'de>) -> Result<Self> {
fn decode(src: &mut ReadCursor<'de>) -> PduResult<Self> {
ensure_fixed_part_size!(in: src);
let pcb_size: usize = cast_length!(src.read_u32(), "cbSize")?;
let pcb_size: usize = cast_length!("cbSize", src.read_u32())?;
if pcb_size < Self::FIXED_PART_SIZE {
return Err(Error::InvalidMessage {
name: Self::NAME,
field: "cbSize",
reason: "advertised size too small for Preconnection PDU V1",
});
return Err(PduError::invalid_message(
Self::NAME,
"cbSize",
"advertised size too small for Preconnection PDU V1",
));
}
Padding::<4>::read(src); // flags
@ -74,33 +74,17 @@ impl<'de> PduDecode<'de> for PreconnectionBlob {
let cb_pcb = cch_pcb * 2;
if remaining_size - 2 < cb_pcb {
return Err(Error::InvalidMessage {
name: Self::NAME,
field: "cchPCB",
reason: "PCB string bigger than advertised size",
});
return Err(PduError::invalid_message(
Self::NAME,
"cchPCB",
"PCB string bigger than advertised size",
));
}
let wsz_pcb_utf16 = src.read_slice(cb_pcb);
let mut trimmed_pcb_utf16: Vec<u16> = Vec::with_capacity(cch_pcb);
for chunk in wsz_pcb_utf16.chunks_exact(2) {
let code_unit = u16::from_le_bytes([chunk[0], chunk[1]]);
// Stop reading at the null terminator
if code_unit == 0 {
break;
}
trimmed_pcb_utf16.push(code_unit);
}
let payload = String::from_utf16(&trimmed_pcb_utf16).map_err(|_| Error::InvalidMessage {
name: Self::NAME,
field: "wszPCB",
reason: "invalid UTF-16",
})?;
let payload = crate::utf16::read_utf16_string(wsz_pcb_utf16, Some(cch_pcb))
.map_err(|e| PduError::invalid_message(Self::NAME, "wszPCB", "bad UTF-16 string").with_source(e))?;
let leftover_size = remaining_size - 2 - cb_pcb;
src.advance(leftover_size); // Consume (unused) leftover data
@ -121,28 +105,28 @@ impl<'de> PduDecode<'de> for PreconnectionBlob {
}
impl PduEncode for PreconnectionBlob {
fn encode(&self, dst: &mut crate::cursor::WriteCursor<'_>) -> Result<()> {
fn encode(&self, dst: &mut crate::cursor::WriteCursor<'_>) -> PduResult<()> {
if self.v2_payload.is_some() && self.version == PcbVersion::V1 {
return Err(Error::InvalidMessage {
name: Self::NAME,
field: "version",
reason: "there is no string payload in Preconnection PDU V1",
});
return Err(PduError::invalid_message(
Self::NAME,
"version",
"there is no string payload in Preconnection PDU V1",
));
}
let pcb_size = self.size();
ensure_size!(in: dst, size: pcb_size);
dst.write_u32(cast_length!(pcb_size, "cbSize")?); // cbSize
dst.write_u32(cast_length!("cbSize", pcb_size)?); // cbSize
Padding::<4>::write(dst); // flags
dst.write_u32(self.version.0); // version
dst.write_u32(self.id); // id
if let Some(v2_payload) = &self.v2_payload {
// cchPCB
let utf16_character_count = v2_payload.encode_utf16().count() + 1; // +1 for null terminator
dst.write_u16(cast_length!(utf16_character_count, "cchPCB")?);
let utf16_character_count = v2_payload.chars().count() + 1; // +1 for null terminator
dst.write_u16(cast_length!("cchPCB", utf16_character_count)?);
// wszPCB
v2_payload.encode_utf16().for_each(|c| dst.write_u16(c));
@ -160,8 +144,8 @@ impl PduEncode for PreconnectionBlob {
let fixed_part_size = Self::FIXED_PART_SIZE;
let variable_part = if let Some(v2_payload) = &self.v2_payload {
let utf16_character_count = v2_payload.encode_utf16().count() + 1; // +1 for null terminator
2 + utf16_character_count * 2
let utf16_encoded_len = crate::utf16::null_terminated_utf16_encoded_len(v2_payload);
2 + utf16_encoded_len
} else {
0
};
@ -174,6 +158,8 @@ impl PduEncode for PreconnectionBlob {
mod tests {
use super::*;
use expect_test::expect;
const PRECONNECTION_PDU_V1_NULL_SIZE_BUF: [u8; 16] = [
0x00, 0x00, 0x00, 0x00, // -> RDP_PRECONNECTION_PDU_V1::cbSize = 0x00 = 0 bytes
0x00, 0x00, 0x00, 0x00, // -> RDP_PRECONNECTION_PDU_V1::Flags = 0
@ -231,12 +217,17 @@ mod tests {
.err()
.unwrap();
if let Error::InvalidMessage { field, reason, .. } = e {
assert_eq!(field, "cbSize");
assert_eq!(reason, "advertised size too small for Preconnection PDU V1");
} else {
panic!("unexpected error: {e}");
expect![[r#"
Error {
context: "PreconnectionBlob",
kind: InvalidMessage {
field: "cbSize",
reason: "advertised size too small for Preconnection PDU V1",
},
source: None,
}
"#]]
.assert_debug_eq(&e);
}
#[test]
@ -245,12 +236,17 @@ mod tests {
.err()
.unwrap();
if let Error::NotEnoughBytes { received, expected, .. } = e {
assert_eq!(received, 0);
assert_eq!(expected, 239);
} else {
panic!("unexpected error: {e}");
expect![[r#"
Error {
context: "PreconnectionBlob",
kind: NotEnoughBytes {
received: 0,
expected: 239,
},
source: None,
}
"#]]
.assert_debug_eq(&e);
}
#[test]
@ -277,12 +273,17 @@ mod tests {
.err()
.unwrap();
if let Error::InvalidMessage { field, reason, .. } = e {
assert_eq!(field, "cchPCB");
assert_eq!(reason, "PCB string bigger than advertised size");
} else {
panic!("unexpected error: {e}");
expect![[r#"
Error {
context: "PreconnectionBlob",
kind: InvalidMessage {
field: "cchPCB",
reason: "PCB string bigger than advertised size",
},
source: None,
}
"#]]
.assert_debug_eq(&e);
}
#[test]

View file

@ -1,5 +1,7 @@
#![allow(dead_code)]
use core::fmt;
use crate::cursor::{ReadCursor, WriteCursor};
pub(crate) const CHOICE_SIZE: usize = 1;
@ -8,13 +10,89 @@ pub(crate) const U16_SIZE: usize = 2;
const OBJECT_ID_SIZE: usize = 6;
pub(crate) fn read_length(src: &mut ReadCursor<'_>) -> crate::Result<(u16, usize)> {
ensure_size!(name: "PER LENGTH", in: src, size: 1);
let a = src.read_u8();
#[derive(Clone, Debug)]
pub(crate) enum PerError {
NotEnoughBytes { available: usize, required: usize },
InvalidLength { reason: &'static str },
Overflow,
Underflow,
UnexpectedEnumVariant,
OctetStringTooSmall,
OctetStringTooBig,
NumericStringTooSmall,
NumericStringTooBig,
}
impl std::error::Error for PerError {}
impl fmt::Display for PerError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
PerError::NotEnoughBytes { available, required } => write!(
f,
"not enough bytes to read PEM element: {available} bytes availables, required {required} bytes"
),
PerError::InvalidLength { reason } => write!(f, "invalid length: {reason}"),
PerError::Overflow => write!(f, "overflow"),
PerError::Underflow => write!(f, "underflow"),
PerError::UnexpectedEnumVariant => write!(f, "enumerated value does not fall within the expected range"),
PerError::OctetStringTooSmall => write!(f, "octet string too small"),
PerError::OctetStringTooBig => write!(f, "octet string too big"),
PerError::NumericStringTooSmall => write!(f, "numeric string too small"),
PerError::NumericStringTooBig => write!(f, "numeric string too big"),
}
}
}
fn try_read_u8(src: &mut ReadCursor<'_>) -> Result<u8, PerError> {
if src.is_empty() {
Err(PerError::NotEnoughBytes {
available: src.len(),
required: 1,
})
} else {
Ok(src.read_u8())
}
}
fn try_read_u16_be(src: &mut ReadCursor<'_>) -> Result<u16, PerError> {
if src.len() >= 2 {
Ok(src.read_u16_be())
} else {
Err(PerError::NotEnoughBytes {
available: src.len(),
required: 2,
})
}
}
fn try_read_u32_be(src: &mut ReadCursor<'_>) -> Result<u32, PerError> {
if src.len() >= 4 {
Ok(src.read_u32_be())
} else {
Err(PerError::NotEnoughBytes {
available: src.len(),
required: 4,
})
}
}
fn try_read_slice<'a>(src: &mut ReadCursor<'a>, n: usize) -> Result<&'a [u8], PerError> {
if src.len() >= n {
Ok(src.read_slice(n))
} else {
Err(PerError::NotEnoughBytes {
available: src.len(),
required: n,
})
}
}
pub(crate) fn read_length(src: &mut ReadCursor<'_>) -> Result<(u16, usize), PerError> {
let a = try_read_u8(src)?;
if a & 0x80 != 0 {
ensure_size!(name: "PER LENGTH", in: src, size: 1);
let b = src.read_u8();
let b = try_read_u8(src)?;
let length = ((u16::from(a) & !0x80) << 8) + u16::from(b);
Ok((length, 2))
@ -83,26 +161,16 @@ pub(crate) fn write_padding(dst: &mut WriteCursor<'_>, padding_length: usize) {
}
}
pub(crate) fn read_u32(src: &mut ReadCursor<'_>) -> crate::Result<u32> {
pub(crate) fn read_u32(src: &mut ReadCursor<'_>) -> Result<u32, PerError> {
let (length, _) = read_length(src)?;
match length {
0 => Ok(0),
1 => {
ensure_size!(name: "PER U32", in: src, size: 1);
Ok(u32::from(src.read_u8()))
}
2 => {
ensure_size!(name: "PER U32", in: src, size: 2);
Ok(u32::from(src.read_u16_be()))
}
4 => {
ensure_size!(name: "PER U32", in: src, size: 4);
Ok(src.read_u32_be())
}
_ => Err(crate::Error::Other {
context: "PER",
reason: "invalid length for u32",
1 => Ok(u32::from(try_read_u8(src)?)),
2 => Ok(u32::from(try_read_u16_be(src)?)),
4 => Ok(try_read_u32_be(src)?),
_ => Err(PerError::InvalidLength {
reason: "U32 with length greater than 4 bytes",
}),
}
}
@ -120,35 +188,21 @@ pub(crate) fn write_u32(dst: &mut WriteCursor<'_>, value: u32) {
}
}
pub(crate) fn read_u16(src: &mut ReadCursor<'_>, min: u16) -> crate::Result<u16> {
ensure_size!(name: "PER U16", in: src, size: 2);
min.checked_add(src.read_u16_be()).ok_or(crate::Error::Other {
context: "PER",
reason: "invalid u16",
})
pub(crate) fn read_u16(src: &mut ReadCursor<'_>, min: u16) -> Result<u16, PerError> {
let value = try_read_u16_be(src)?;
min.checked_add(value).ok_or(PerError::Overflow)
}
pub(crate) fn write_u16(dst: &mut WriteCursor<'_>, value: u16, min: u16) -> crate::Result<()> {
if value < min {
Err(crate::Error::Other {
context: "PER",
reason: "u16 value greater than specified minimum",
})
} else {
dst.write_u16_be(value - min);
pub(crate) fn write_u16(dst: &mut WriteCursor<'_>, value: u16, min: u16) -> Result<(), PerError> {
dst.write_u16_be(value.checked_sub(min).ok_or(PerError::Underflow)?);
Ok(())
}
}
pub(crate) fn read_enum(src: &mut ReadCursor<'_>, count: u8) -> crate::Result<u8> {
ensure_size!(name: "PER ENUM", in: src, size: 1);
let enumerated = src.read_u8();
pub(crate) fn read_enum(src: &mut ReadCursor<'_>, count: u8) -> Result<u8, PerError> {
let enumerated = try_read_u8(src)?;
if u16::from(enumerated) + 1 > u16::from(count) {
Err(crate::Error::Other {
context: "PER",
reason: "enumerated value does not fall within expected range",
})
if enumerated >= count {
Err(PerError::UnexpectedEnumVariant)
} else {
Ok(enumerated)
}
@ -158,25 +212,22 @@ pub(crate) fn write_enum(dst: &mut WriteCursor<'_>, enumerated: u8) {
dst.write_u8(enumerated);
}
pub(crate) fn read_object_id(src: &mut ReadCursor<'_>) -> crate::Result<[u8; OBJECT_ID_SIZE]> {
pub(crate) fn read_object_id(src: &mut ReadCursor<'_>) -> Result<[u8; OBJECT_ID_SIZE], PerError> {
let (length, _) = read_length(src)?;
if length != 5 {
return Err(crate::Error::Other {
context: "PER",
reason: "invalid object id length",
return Err(PerError::InvalidLength {
reason: "invalid OID length advertised",
});
}
ensure_size!(name: "PER OID", in: src, size: 5);
let first_two_tuples = src.read_u8();
let first_two_tuples = try_read_u8(src)?;
let mut read_object_ids = [0u8; OBJECT_ID_SIZE];
read_object_ids[0] = first_two_tuples / 40;
read_object_ids[1] = first_two_tuples % 40;
for read_object_id in read_object_ids.iter_mut().skip(2) {
*read_object_id = src.read_u8();
*read_object_id = try_read_u8(src)?;
}
Ok(read_object_ids)
@ -193,36 +244,51 @@ pub(crate) fn write_object_id(dst: &mut WriteCursor<'_>, object_ids: [u8; OBJECT
}
}
pub(crate) fn read_octet_string(src: &mut ReadCursor<'_>, min: usize) -> crate::Result<Vec<u8>> {
pub(crate) fn read_octet_string<'a>(src: &mut ReadCursor<'a>, min: usize) -> Result<&'a [u8], PerError> {
let (length, _) = read_length(src)?;
let read_len = min + usize::from(length);
ensure_size!(name: "PER OCTET_STRING", in: src, size: read_len);
Ok(src.read_slice(read_len).to_owned())
let octet_string = try_read_slice(src, read_len)?;
Ok(octet_string)
}
pub(crate) fn write_octet_string(dst: &mut WriteCursor<'_>, octet_string: &[u8], min: usize) {
let length = if octet_string.len() >= min {
octet_string.len() - min
} else {
min
};
pub(crate) fn write_octet_string(dst: &mut WriteCursor<'_>, octet_string: &[u8], min: usize) -> Result<(), PerError> {
if octet_string.len() < min {
return Err(PerError::OctetStringTooSmall);
}
let length = octet_string.len() - min;
let length = u16::try_from(length).map_err(|_| PerError::OctetStringTooBig)?;
write_length(dst, length);
write_length(dst, length as u16);
dst.write_slice(octet_string);
}
pub(crate) fn read_numeric_string(src: &mut ReadCursor<'_>, min: u16) -> crate::Result<()> {
let (length, _) = read_length(src)?;
let length = usize::from((length + min + 1) / 2);
ensure_size!(name: "PER NUMERIC_STRING", in: src, size: length);
src.advance(length);
Ok(())
}
pub(crate) fn write_numeric_string(dst: &mut WriteCursor<'_>, num_str: &[u8], min: usize) {
let length = if num_str.len() >= min { num_str.len() - min } else { min };
pub(crate) fn read_numeric_string(src: &mut ReadCursor<'_>, min: u16) -> Result<(), PerError> {
let (length, _) = read_length(src)?;
let length = usize::from((length + min + 1) / 2);
write_length(dst, u16::try_from(length).unwrap());
if src.len() < length {
Err(PerError::NotEnoughBytes {
available: src.len(),
required: length,
})
} else {
src.advance(length);
Ok(())
}
}
pub(crate) fn write_numeric_string(dst: &mut WriteCursor<'_>, num_str: &[u8], min: usize) -> Result<(), PerError> {
if num_str.len() < min {
return Err(PerError::NumericStringTooSmall);
}
let length = num_str.len() - min;
let length = u16::try_from(length).map_err(|_| PerError::NumericStringTooBig)?;
write_length(dst, length);
let magic_transform = |elem| (elem - 0x30) % 10;
@ -234,6 +300,8 @@ pub(crate) fn write_numeric_string(dst: &mut WriteCursor<'_>, num_str: &[u8], mi
dst.write_u8(num);
}
Ok(())
}
pub(crate) mod legacy {
@ -481,6 +549,8 @@ pub(crate) mod legacy {
mod tests {
use super::*;
use expect_test::expect;
#[test]
fn read_length_is_correct_length() {
let mut src = ReadCursor::new(&[0x05]);
@ -632,12 +702,11 @@ mod tests {
let buf = [0xff, 0xff];
let mut src = ReadCursor::new(&buf);
match read_u16(&mut src, 1) {
Err(crate::Error::Other { reason, .. }) => {
assert_eq!(reason, "invalid u16");
}
_ => panic!("Unexpected result"),
};
let e = read_u16(&mut src, 1).err().unwrap();
expect![[r#"
Overflow
"#]].assert_debug_eq(&e)
}
#[test]
@ -659,12 +728,9 @@ mod tests {
let e = write_u16(&mut dst, 1000, 1001).err().unwrap();
if let crate::Error::Other { context, reason } = e {
assert_eq!(context, "PER");
assert_eq!(reason, "u16 value greater than specified minimum");
} else {
panic!("unexpected error: {e}");
}
expect![[r#"
Underflow
"#]].assert_debug_eq(&e);
}
#[test]
@ -691,12 +757,11 @@ mod tests {
let buf = [0x05];
let mut src = ReadCursor::new(&buf);
match read_enum(&mut src, 1) {
Err(crate::Error::Other { reason, .. }) => {
assert_eq!(reason, "enumerated value does not fall within expected range");
}
_ => panic!("Unexpected result"),
};
let e = read_enum(&mut src, 1).err().unwrap();
expect![[r#"
UnexpectedEnumVariant
"#]].assert_debug_eq(&e);
}
#[test]
@ -712,12 +777,11 @@ mod tests {
let buf = [0xff];
let mut src = ReadCursor::new(&buf);
match read_enum(&mut src, 0xff) {
Err(crate::Error::Other { reason, .. }) => {
assert_eq!(reason, "enumerated value does not fall within expected range");
}
_ => panic!("Unexpected result"),
};
let e = read_enum(&mut src, 0xff).err().unwrap();
expect![[r#"
UnexpectedEnumVariant
"#]].assert_debug_eq(&e);
}
#[test]
@ -736,7 +800,7 @@ mod tests {
let mut buf = [0; 2];
let mut dst = WriteCursor::new(&mut buf);
write_numeric_string(&mut dst, octet_string, 1);
write_numeric_string(&mut dst, octet_string, 1).unwrap();
assert_eq!(dst.len(), 0);
assert_eq!(buf, expected_buf);
@ -747,7 +811,7 @@ mod tests {
let buf = [0x00, 0x44, 0x75, 0x63, 0x61];
let mut src = ReadCursor::new(&buf);
assert_eq!(b"Duca", read_octet_string(&mut src, 4).unwrap().as_slice());
assert_eq!(b"Duca", read_octet_string(&mut src, 4).unwrap());
}
#[test]
@ -758,7 +822,7 @@ mod tests {
let mut buf = [0; 5];
let mut dst = WriteCursor::new(&mut buf);
write_octet_string(&mut dst, octet_string, 4);
write_octet_string(&mut dst, octet_string, 4).unwrap();
assert_eq!(dst.len(), 0);
assert_eq!(buf, expected_buf);

View file

@ -96,3 +96,9 @@ impl From<RdpError> for io::Error {
io::Error::new(io::ErrorKind::Other, format!("RDP Connection Sequence error: {e}"))
}
}
impl ironrdp_error::legacy::ErrorContext for RdpError {
fn context(&self) -> &'static str {
"RDP"
}
}

View file

@ -570,12 +570,8 @@ pub enum ClientInfoError {
}
fn string_len(value: &str, character_set: CharacterSet) -> u16 {
// FIXME: this is not the right way to compute the number of bytes for unicode strings.
// This is a time bomb: both UTF-8 and UTF-16 are using a variable-length encoding and code points may be encoded
// using multiple code units. The thing is, UTF-16 uses one or two 16-bit code units and
// UTF-8 uses between one and four 8-bit code units. Its really not always the case that a code point
// in UTF-16 is twice as big as the same code point in UTF-8.
// Refer to `ironrdp_pdu::pcb` module for a correct implementation.
// Something like that: value.encode_utf16().count() * 2 (+2 if we need to insert a null terminator: 0x0000)
value.len() as u16 * character_set.to_u16().unwrap()
match character_set {
CharacterSet::Ansi => u16::try_from(value.len()).unwrap(),
CharacterSet::Unicode => u16::try_from(value.encode_utf16().count() * 2).unwrap(),
}
}

View file

@ -233,6 +233,12 @@ pub enum ServerLicenseError {
BlobTooSmall,
}
impl ironrdp_error::legacy::ErrorContext for ServerLicenseError {
fn context(&self) -> &'static str {
"server license"
}
}
pub struct BlobHeader {
pub blob_type: BlobType,
pub length: usize,

View file

@ -95,3 +95,9 @@ impl From<ChannelError> for io::Error {
io::Error::new(io::ErrorKind::Other, format!("Virtual channel error: {e}"))
}
}
impl ironrdp_error::legacy::ErrorContext for ChannelError {
fn context(&self) -> &'static str {
"virtual channel error"
}
}

View file

@ -297,3 +297,9 @@ pub enum DisplayPipelineError {
#[error("Invalid PDU length: expected ({expected}) != actual ({actual})")]
InvalidPduLength { expected: usize, actual: usize },
}
impl ironrdp_error::legacy::ErrorContext for DisplayPipelineError {
fn context(&self) -> &'static str {
"display pipeline"
}
}

View file

@ -313,3 +313,9 @@ pub enum GraphicsPipelineError {
#[error("Invalid PDU length: expected ({expected}) != actual ({actual})")]
InvalidPduLength { expected: usize, actual: usize },
}
impl ironrdp_error::legacy::ErrorContext for GraphicsPipelineError {
fn context(&self) -> &'static str {
"graphics pipeline"
}
}

View file

@ -1,6 +1,7 @@
use crate::cursor::{ReadCursor, WriteCursor};
use crate::padding::Padding;
use crate::tpkt::TpktHeader;
use crate::{PduError, PduErrorExt as _, PduResult};
/// TPDU type used during X.224 messages exchange
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
@ -23,14 +24,11 @@ impl TpduCode {
}
}
pub fn check_expected(self, expected: TpduCode) -> crate::Result<()> {
pub fn check_expected(self, expected: TpduCode) -> PduResult<()> {
if self == expected {
Ok(())
} else {
Err(crate::Error::UnexpectedMessageType {
name: TpduHeader::NAME,
got: self.0,
})
Err(PduError::unexpected_message_type(TpduHeader::NAME, self.0))
}
}
}
@ -117,27 +115,27 @@ impl TpduHeader {
const FIXED_PART_SIZE: usize = Self::DATA_FIXED_PART_SIZE;
pub fn read(src: &mut ReadCursor<'_>, tpkt: &TpktHeader) -> crate::Result<Self> {
pub fn read(src: &mut ReadCursor<'_>, tpkt: &TpktHeader) -> PduResult<Self> {
ensure_fixed_part_size!(in: src);
let li = src.read_u8(); // LI
let code = TpduCode::from(src.read_u8()); // Code
if usize::from(li) + 1 + TpktHeader::SIZE > usize::from(tpkt.packet_length) {
return Err(crate::Error::InvalidMessage {
name: Self::NAME,
field: "li",
reason: "tpdu length greater than tpkt length",
});
return Err(PduError::invalid_message(
Self::NAME,
"li",
"tpdu length greater than tpkt length",
));
}
// The value 255 (1111 1111) is reserved for possible extensions.
if li == 0b1111_1111 {
return Err(crate::Error::InvalidMessage {
name: Self::NAME,
field: "li",
reason: "unsupported X.224 extension (suggested by LI field set to 255)",
});
return Err(PduError::invalid_message(
Self::NAME,
"li",
"unsupported X.224 extension (suggested by LI field set to 255)",
));
}
if code == TpduCode::DATA {
@ -150,7 +148,7 @@ impl TpduHeader {
Ok(Self { li, code })
}
pub fn write(&self, dst: &mut WriteCursor<'_>) -> crate::Result<()> {
pub fn write(&self, dst: &mut WriteCursor<'_>) -> PduResult<()> {
const EOT_BYTE: u8 = 0x80;
ensure_fixed_part_size!(in: dst);

View file

@ -1,5 +1,6 @@
use crate::cursor::{ReadCursor, WriteCursor};
use crate::padding::Padding;
use crate::{PduError, PduErrorExt as _, PduResult};
/// TPKT header
///
@ -47,16 +48,13 @@ impl TpktHeader {
const FIXED_PART_SIZE: usize = Self::SIZE;
pub fn read(src: &mut ReadCursor<'_>) -> crate::Result<Self> {
pub fn read(src: &mut ReadCursor<'_>) -> PduResult<Self> {
ensure_fixed_part_size!(in: src);
let version = src.read_u8();
if version != Self::VERSION {
return Err(crate::Error::UnsupportedVersion {
name: "TPKT version",
got: version,
});
return Err(PduError::unsupported_version("TPKT version", version));
}
Padding::<1>::read(src);
@ -66,7 +64,7 @@ impl TpktHeader {
Ok(Self { packet_length })
}
pub fn write(&self, dst: &mut WriteCursor<'_>) -> crate::Result<()> {
pub fn write(&self, dst: &mut WriteCursor<'_>) -> PduResult<()> {
ensure_fixed_part_size!(in: dst);
dst.write_u8(Self::VERSION);

View file

@ -0,0 +1,26 @@
use std::string::FromUtf16Error;
pub fn read_utf16_string(utf16_payload: &[u8], utf16_size_hint: Option<usize>) -> Result<String, FromUtf16Error> {
let mut trimmed_utf16: Vec<u16> = if let Some(size_hint) = utf16_size_hint {
Vec::with_capacity(size_hint)
} else {
Vec::with_capacity(utf16_payload.len() / 2)
};
for chunk in utf16_payload.chunks_exact(2) {
let code_unit = u16::from_le_bytes([chunk[0], chunk[1]]);
// Stop reading at the null terminator
if code_unit == 0 {
break;
}
trimmed_utf16.push(code_unit);
}
String::from_utf16(&trimmed_utf16)
}
pub fn null_terminated_utf16_encoded_len(utf8: &str) -> usize {
utf8.encode_utf16().count() * 2 + 2
}

View file

@ -3,16 +3,16 @@ use std::borrow::Cow;
use crate::cursor::{ReadCursor, WriteCursor};
use crate::tpdu::{TpduCode, TpduHeader};
use crate::tpkt::TpktHeader;
use crate::{IntoOwnedPdu, Pdu, PduDecode, PduEncode, Result};
use crate::{IntoOwnedPdu, Pdu, PduDecode, PduEncode, PduError, PduErrorExt as _, PduResult};
pub trait X224Pdu<'de>: Sized {
const X224_NAME: &'static str;
const TPDU_CODE: TpduCode;
fn x224_body_encode(&self, dst: &mut WriteCursor<'_>) -> Result<()>;
fn x224_body_encode(&self, dst: &mut WriteCursor<'_>) -> PduResult<()>;
fn x224_body_decode(src: &mut ReadCursor<'de>, tpkt: &TpktHeader, tpdu: &TpduHeader) -> Result<Self>;
fn x224_body_decode(src: &mut ReadCursor<'de>, tpkt: &TpktHeader, tpdu: &TpduHeader) -> PduResult<Self>;
fn tpdu_header_variable_part_size(&self) -> usize;
@ -30,7 +30,7 @@ impl<'de, T> PduEncode for T
where
T: X224Pdu<'de>,
{
fn encode(&self, dst: &mut WriteCursor<'_>) -> Result<()> {
fn encode(&self, dst: &mut WriteCursor<'_>) -> PduResult<()> {
let packet_length = self.size();
ensure_size!(in: dst, size: packet_length);
@ -66,7 +66,7 @@ impl<'de, T> PduDecode<'de> for T
where
T: X224Pdu<'de>,
{
fn decode(src: &mut ReadCursor<'de>) -> Result<Self> {
fn decode(src: &mut ReadCursor<'de>) -> PduResult<Self> {
let tpkt = TpktHeader::read(src)?;
ensure_size!(in: src, size: tpkt.packet_length().saturating_sub(TpktHeader::SIZE));
@ -75,11 +75,11 @@ where
tpdu.code.check_expected(T::TPDU_CODE)?;
if tpdu.size() < tpdu.fixed_part_size() {
return Err(crate::Error::InvalidMessage {
name: "TpduHeader",
field: "li",
reason: "fixed part bigger than total header size",
});
return Err(PduError::invalid_message(
"TpduHeader",
"li",
"fixed part bigger than total header size",
));
}
T::x224_body_decode(src, &tpkt, &tpdu)
@ -107,14 +107,14 @@ impl<'de> X224Pdu<'de> for X224Data<'de> {
const TPDU_CODE: TpduCode = TpduCode::DATA;
fn x224_body_encode(&self, dst: &mut WriteCursor<'_>) -> Result<()> {
fn x224_body_encode(&self, dst: &mut WriteCursor<'_>) -> PduResult<()> {
ensure_size!(in: dst, size: self.data.len());
dst.write_slice(&self.data);
Ok(())
}
fn x224_body_decode(src: &mut ReadCursor<'de>, tpkt: &TpktHeader, tpdu: &TpduHeader) -> Result<Self> {
fn x224_body_decode(src: &mut ReadCursor<'de>, tpkt: &TpktHeader, tpdu: &TpduHeader) -> PduResult<Self> {
let user_data_size = user_data_size(tpkt, tpdu);
ensure_size!(in: src, size: user_data_size);

View file

@ -16,9 +16,10 @@ doctest = false
test = false
[dependencies]
ironrdp-pdu.workspace = true
bitflags = "2" # TODO: investigate usage in this crate
ironrdp-connector.workspace = true # TODO: at some point, this dependency could be removed (good for compilation speed)
ironrdp-error.workspace = true
ironrdp-graphics.workspace = true
ironrdp-pdu.workspace = true
sspi.workspace = true
tracing.workspace = true
bitflags = "2" # TODO: investigate usage in this crate

View file

@ -4,7 +4,7 @@ use ironrdp_pdu::Action;
use crate::image::DecodedImage;
use crate::x224::GfxHandler;
use crate::{fast_path, utils, x224, Result};
use crate::{fast_path, utils, x224, SessionResult};
pub struct ActiveStage {
x224_processor: x224::Processor,
@ -38,7 +38,7 @@ impl ActiveStage {
image: &mut DecodedImage,
action: Action,
frame: &[u8],
) -> Result<Vec<ActiveStageOutput>> {
) -> SessionResult<Vec<ActiveStageOutput>> {
let mut graphics_update_region = None;
let output = match action {
@ -64,12 +64,16 @@ impl ActiveStage {
}
/// Sends a PDU on the dynamic channel.
pub fn encode_dynamic(&self, output: &mut Vec<u8>, channel_name: &str, dvc_data: &[u8]) -> Result<usize> {
pub fn encode_dynamic(&self, output: &mut Vec<u8>, channel_name: &str, dvc_data: &[u8]) -> SessionResult<usize> {
self.x224_processor.encode_dynamic(output, channel_name, dvc_data)
}
/// Send a pdu on the static global channel. Typically used to send input events
pub fn encode_static(&self, output: &mut Vec<u8>, pdu: ironrdp_pdu::rdp::headers::ShareDataPdu) -> Result<usize> {
pub fn encode_static(
&self,
output: &mut Vec<u8>,
pdu: ironrdp_pdu::rdp::headers::ShareDataPdu,
) -> SessionResult<usize> {
self.x224_processor.encode_static(output, pdu)
}
}

View file

@ -11,7 +11,7 @@ use ironrdp_pdu::PduBufferParsing;
use crate::image::DecodedImage;
use crate::utils::CodecId;
use crate::{rfx, Error, Result};
use crate::{rfx, SessionResult};
pub struct Processor {
complete_data: CompleteData,
@ -27,7 +27,7 @@ impl Processor {
image: &mut DecodedImage,
mut input: &[u8],
output: &mut Vec<u8>,
) -> Result<Option<Rectangle>> {
) -> SessionResult<Option<Rectangle>> {
use ironrdp_pdu::PduParsing as _;
let header = FastPathHeader::from_buffer(&mut input)?;
@ -150,7 +150,7 @@ impl Processor {
warn!(?error, "Received invalid bitmap");
Ok(None)
}
Err(e) => Err(Error::new("Fast-Path").with_custom(e)),
Err(e) => Err(custom_err!("Fast-Path", e)),
}
}
@ -159,7 +159,7 @@ impl Processor {
image: &mut DecodedImage,
output: &mut Vec<u8>,
surface_commands: Vec<SurfaceCommand<'_>>,
) -> Result<Rectangle> {
) -> SessionResult<Rectangle> {
let mut update_rectangle = Rectangle::empty();
for command in surface_commands {
@ -168,8 +168,11 @@ impl Processor {
trace!("Surface bits");
let codec_id = CodecId::from_u8(bits.extended_bitmap_data.codec_id).ok_or_else(|| {
Error::new("unexpected codec ID")
.with_reason(format!("{:x}", bits.extended_bitmap_data.codec_id))
reason_err!(
"Fast-Path",
"unexpected codec ID: {:x}",
bits.extended_bitmap_data.codec_id
)
})?;
match codec_id {
@ -283,7 +286,7 @@ impl FrameMarkerProcessor {
}
}
fn process(&mut self, marker: &FrameMarkerPdu, output: &mut Vec<u8>) -> Result<()> {
fn process(&mut self, marker: &FrameMarkerPdu, output: &mut Vec<u8>) -> SessionResult<()> {
match marker.frame_action {
FrameAction::Begin => Ok(()),
FrameAction::End => {
@ -295,7 +298,8 @@ impl FrameMarkerProcessor {
frame_id: marker.frame_id.unwrap_or(0),
}),
output,
)?;
)
.map_err(crate::legacy::map_error)?;
output.truncate(written);

View file

@ -2,7 +2,7 @@ use ironrdp_graphics::image_processing::{ImageRegion, ImageRegionMut, PixelForma
use ironrdp_graphics::rectangle_processing::Region;
use ironrdp_pdu::geometry::Rectangle;
use crate::{Error, Result};
use crate::SessionResult;
const TILE_SIZE: u16 = 64;
const SOURCE_PIXEL_FORMAT: PixelFormat = PixelFormat::BgrX32;
@ -49,7 +49,7 @@ impl DecodedImage {
clipping_rectangles: &Region,
update_rectangle: &Rectangle,
width: u16,
) -> Result<()> {
) -> SessionResult<()> {
debug!("Tile: {:?}", update_rectangle);
let update_region = clipping_rectangles.intersect_rectangle(update_rectangle);
@ -80,7 +80,7 @@ impl DecodedImage {
source_image_region
.copy_to(&mut destination_image_region)
.map_err(|e| Error::new("copy_to").with_custom(e))?;
.map_err(|e| custom_err!("copy_to", e))?;
}
Ok(())

View file

@ -2,13 +2,15 @@ use ironrdp_connector::legacy::{encode_send_data_request, SendDataIndicationCtx}
use ironrdp_pdu::rdp::vc;
use ironrdp_pdu::PduParsing as _;
use crate::{SessionError, SessionResult};
pub fn encode_dvc_message(
initiator_id: u16,
drdynvc_id: u16,
dvc_pdu: vc::dvc::ClientPdu,
dvc_data: &[u8],
mut buf: &mut Vec<u8>,
) -> crate::Result<usize> {
) -> SessionResult<usize> {
let dvc_length = dvc_pdu.buffer_length() + dvc_data.len();
let channel_header = vc::ChannelPduHeader {
@ -17,7 +19,7 @@ pub fn encode_dvc_message(
};
// [ TPKT | TPDU | SendDataRequest | vc::ChannelPduHeader | …
let written = encode_send_data_request(initiator_id, drdynvc_id, &channel_header, buf)?;
let written = encode_send_data_request(initiator_id, drdynvc_id, &channel_header, buf).map_err(map_error)?;
buf.truncate(written);
// … | dvc::ClientPdu | …
@ -36,7 +38,7 @@ pub struct DynamicChannelCtx<'a> {
pub dvc_data: &'a [u8],
}
pub fn decode_dvc_message(ctx: SendDataIndicationCtx<'_>) -> crate::Result<DynamicChannelCtx<'_>> {
pub fn decode_dvc_message(ctx: SendDataIndicationCtx<'_>) -> SessionResult<DynamicChannelCtx<'_>> {
let mut user_data = ctx.user_data;
let user_data_len = user_data.len();
@ -53,57 +55,25 @@ pub fn decode_dvc_message(ctx: SendDataIndicationCtx<'_>) -> crate::Result<Dynam
Ok(DynamicChannelCtx { dvc_pdu, dvc_data })
}
impl From<ironrdp_pdu::rdp::vc::ChannelError> for crate::Error {
fn from(e: ironrdp_pdu::rdp::vc::ChannelError) -> Self {
Self::new("virtual channel error").with_custom(e)
}
}
// FIXME: code should be fixed so that we never need this conversion
// For that, some code from this ironrdp_session::legacy and ironrdp_connector::legacy modules should be moved to ironrdp_pdu itself
impl From<ironrdp_connector::Error> for crate::Error {
fn from(value: ironrdp_connector::Error) -> Self {
Self {
context: value.context,
kind: match value.kind {
ironrdp_connector::ErrorKind::Pdu(e) => crate::ErrorKind::Pdu(e),
ironrdp_connector::ErrorKind::Credssp(_) => panic!("unexpected"),
ironrdp_connector::ErrorKind::AccessDenied => panic!("unexpected"),
ironrdp_connector::ErrorKind::Custom(e) => crate::ErrorKind::Custom(e),
ironrdp_connector::ErrorKind::General => crate::ErrorKind::General,
_ => crate::ErrorKind::General,
},
reason: value.reason,
impl From<ironrdp_connector::ConnectorErrorKind> for crate::SessionErrorKind {
fn from(value: ironrdp_connector::ConnectorErrorKind) -> Self {
match value {
ironrdp_connector::ConnectorErrorKind::Pdu(e) => crate::SessionErrorKind::Pdu(e),
ironrdp_connector::ConnectorErrorKind::Credssp(_) => panic!("unexpected"),
ironrdp_connector::ConnectorErrorKind::AccessDenied => panic!("unexpected"),
ironrdp_connector::ConnectorErrorKind::General => crate::SessionErrorKind::General,
ironrdp_connector::ConnectorErrorKind::Custom => crate::SessionErrorKind::Custom,
_ => crate::SessionErrorKind::General,
}
}
}
impl From<ironrdp_pdu::fast_path::FastPathError> for crate::Error {
fn from(e: ironrdp_pdu::fast_path::FastPathError) -> Self {
Self::new("Fast-Path").with_custom(e)
}
pub(crate) fn map_error(error: ironrdp_connector::ConnectorError) -> SessionError {
error.into_other_kind()
}
impl From<ironrdp_pdu::codecs::rfx::RfxError> for crate::Error {
fn from(e: ironrdp_pdu::codecs::rfx::RfxError) -> Self {
Self::new("RFX").with_custom(e)
}
}
impl From<ironrdp_pdu::dvc::display::DisplayPipelineError> for crate::Error {
fn from(e: ironrdp_pdu::dvc::display::DisplayPipelineError) -> Self {
Self::new("display pipeline").with_custom(e)
}
}
impl From<ironrdp_graphics::zgfx::ZgfxError> for crate::Error {
fn from(e: ironrdp_graphics::zgfx::ZgfxError) -> Self {
Self::new("zgfx").with_reason(e.to_string())
}
}
impl From<ironrdp_pdu::dvc::gfx::GraphicsPipelineError> for crate::Error {
fn from(e: ironrdp_pdu::dvc::gfx::GraphicsPipelineError) -> Self {
Self::new("graphics pipeline").with_reason(e.to_string())
}
impl ironrdp_error::legacy::CatchAllKind for crate::SessionErrorKind {
const CATCH_ALL_VALUE: Self = crate::SessionErrorKind::General;
}

View file

@ -1,6 +1,9 @@
#[macro_use]
extern crate tracing;
#[macro_use]
mod macros;
pub mod image;
pub mod legacy;
pub mod rfx; // FIXME: maybe this module should not be in this crate
@ -14,119 +17,79 @@ use core::fmt;
pub use active_stage::{ActiveStage, ActiveStageOutput};
pub type Result<T> = std::result::Result<T, Error>;
pub type SessionResult<T> = std::result::Result<T, SessionError>;
#[non_exhaustive]
#[derive(Debug)]
pub enum ErrorKind {
Pdu(ironrdp_pdu::Error),
Custom(Box<dyn std::error::Error + Sync + Send + 'static>),
pub enum SessionErrorKind {
Pdu(ironrdp_pdu::PduError),
Reason(String),
General,
Custom,
}
#[derive(Debug)]
pub struct Error {
pub context: &'static str,
pub kind: ErrorKind,
pub reason: Option<String>,
impl fmt::Display for SessionErrorKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self {
SessionErrorKind::Pdu(_) => write!(f, "PDU error"),
SessionErrorKind::Reason(description) => write!(f, "reason: {description}"),
SessionErrorKind::General => write!(f, "general"),
SessionErrorKind::Custom => write!(f, "custom"),
}
}
}
impl Error {
pub const fn new(context: &'static str) -> Self {
Self {
context,
kind: ErrorKind::General,
reason: None,
impl std::error::Error for SessionErrorKind {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match &self {
SessionErrorKind::Pdu(e) => Some(e),
SessionErrorKind::Reason(_) => None,
SessionErrorKind::General => None,
SessionErrorKind::Custom => None,
}
}
}
pub type SessionError = ironrdp_error::Error<SessionErrorKind>;
pub trait SessionErrorExt {
fn pdu(error: ironrdp_pdu::PduError) -> Self;
fn general(context: &'static str) -> Self;
fn reason(context: &'static str, reason: impl Into<String>) -> Self;
fn custom<E>(context: &'static str, e: E) -> Self
where
E: std::error::Error + Sync + Send + 'static;
}
impl SessionErrorExt for SessionError {
fn pdu(error: ironrdp_pdu::PduError) -> Self {
Self::new("invalid payload", SessionErrorKind::Pdu(error))
}
pub fn with_kind(mut self, kind: ErrorKind) -> Self {
self.kind = kind;
self
fn general(context: &'static str) -> Self {
Self::new(context, SessionErrorKind::General)
}
pub fn with_custom<E>(mut self, custom_error: E) -> Self
fn reason(context: &'static str, reason: impl Into<String>) -> Self {
Self::new(context, SessionErrorKind::Reason(reason.into()))
}
fn custom<E>(context: &'static str, e: E) -> Self
where
E: std::error::Error + Sync + Send + 'static,
{
self.kind = ErrorKind::Custom(Box::new(custom_error));
self
}
pub fn with_reason(mut self, reason: impl Into<String>) -> Self {
self.reason = Some(reason.into());
self
}
}
impl std::error::Error for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match &self.kind {
ErrorKind::Pdu(e) => Some(e),
ErrorKind::Custom(e) => Some(e.as_ref()),
ErrorKind::General => None,
}
}
}
impl From<Error> for std::io::Error {
fn from(error: Error) -> Self {
std::io::Error::new(std::io::ErrorKind::Other, error)
}
}
impl From<ironrdp_pdu::Error> for Error {
fn from(value: ironrdp_pdu::Error) -> Self {
Self {
context: "invalid payload",
kind: ErrorKind::Pdu(value),
reason: None,
}
}
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.context)?;
match &self.kind {
ErrorKind::Pdu(e) => {
if f.alternate() {
write!(f, ": {e}")?;
}
}
ErrorKind::Custom(e) => {
if f.alternate() {
write!(f, ": {e}")?;
let mut next_source = e.source();
while let Some(e) = next_source {
write!(f, ", caused by: {e}")?;
next_source = e.source();
}
}
}
ErrorKind::General => {}
}
if let Some(reason) = &self.reason {
write!(f, " ({reason})")?;
}
Ok(())
Self::new(context, SessionErrorKind::Custom).with_source(e)
}
}
pub trait SessionResultExt {
fn with_context(self, context: &'static str) -> Self;
fn with_kind(self, kind: ErrorKind) -> Self;
fn with_custom<E>(self, custom_error: E) -> Self
fn with_source<E>(self, source: E) -> Self
where
E: std::error::Error + Sync + Send + 'static;
fn with_reason(self, reason: impl Into<String>) -> Self;
}
impl<T> SessionResultExt for Result<T> {
impl<T> SessionResultExt for SessionResult<T> {
fn with_context(self, context: &'static str) -> Self {
self.map_err(|mut e| {
e.context = context;
@ -134,27 +97,10 @@ impl<T> SessionResultExt for Result<T> {
})
}
fn with_kind(self, kind: ErrorKind) -> Self {
self.map_err(|mut e| {
e.kind = kind;
e
})
}
fn with_custom<E>(self, custom_error: E) -> Self
fn with_source<E>(self, source: E) -> Self
where
E: std::error::Error + Sync + Send + 'static,
{
self.map_err(|mut e| {
e.kind = ErrorKind::Custom(Box::new(custom_error));
e
})
}
fn with_reason(self, reason: impl Into<String>) -> Self {
self.map_err(|mut e| {
e.reason = Some(reason.into());
e
})
self.map_err(|e| e.with_source(source))
}
}

View file

@ -0,0 +1,38 @@
/// Creates a `SessionError` with `General` kind
///
/// Shorthand for
/// ```rust
/// <crate::SessionError as crate::SessionErrorExt>::general(context)
/// ```
#[macro_export]
macro_rules! general_err {
( $context:expr $(,)? ) => {{
<$crate::SessionError as $crate::SessionErrorExt>::general($context)
}};
}
/// Creates a `SessionError` with `Reason` kind
///
/// Shorthand for
/// ```rust
/// <crate::SessionError as crate::SessionErrorExt>::reason(context, reason)
/// ```
#[macro_export]
macro_rules! reason_err {
( $context:expr, $($arg:tt)* ) => {{
<$crate::SessionError as $crate::SessionErrorExt>::reason($context, format!($($arg)*))
}};
}
/// Creates a `SessionError` with `Custom` kind and a source error attached to it
///
/// Shorthand for
/// ```rust
/// <crate::SessionError as crate::SessionErrorExt>::custom(context, source)
/// ```
#[macro_export]
macro_rules! custom_err {
( $context:expr, $source:expr $(,)? ) => {{
<$crate::SessionError as $crate::SessionErrorExt>::custom($context, $source)
}};
}

View file

@ -8,7 +8,7 @@ use ironrdp_pdu::geometry::Rectangle;
use ironrdp_pdu::PduBufferParsing;
use crate::image::DecodedImage;
use crate::{Error, Result};
use crate::SessionResult;
const TILE_SIZE: u16 = 64;
@ -45,7 +45,7 @@ impl DecodingContext {
image: &mut DecodedImage,
destination: &Rectangle,
input: &mut &[u8],
) -> Result<(FrameId, Rectangle)> {
) -> SessionResult<(FrameId, Rectangle)> {
loop {
match self.state {
SequenceState::HeaderMessages => {
@ -58,7 +58,7 @@ impl DecodingContext {
}
}
fn process_headers(&mut self, input: &mut &[u8]) -> Result<()> {
fn process_headers(&mut self, input: &mut &[u8]) -> SessionResult<()> {
let _sync = rfx::SyncPdu::from_buffer_consume(input)?;
let mut context = None;
@ -72,11 +72,11 @@ impl DecodingContext {
Headers::CodecVersions(_) => (),
}
}
let context = context.ok_or(Error::new("context header is missing"))?;
let channels = channels.ok_or(Error::new("channels header is missing"))?;
let context = context.ok_or(general_err!("context header is missing"))?;
let channels = channels.ok_or(general_err!("channels header is missing"))?;
if channels.0.is_empty() {
return Err(Error::new("no RFX channel announced"));
return Err(general_err!("no RFX channel annouced"));
}
self.context = context;
@ -92,7 +92,7 @@ impl DecodingContext {
image: &mut DecodedImage,
destination: &Rectangle,
input: &mut &[u8],
) -> Result<(FrameId, Rectangle)> {
) -> SessionResult<(FrameId, Rectangle)> {
let channel = self.channels.0.first().unwrap();
let width = channel.width as u16;
let height = channel.height as u16;
@ -172,7 +172,7 @@ fn decode_tile(
output: &mut [u8],
ycbcr_temp: &mut [Vec<i16>],
temp: &mut [i16],
) -> Result<()> {
) -> SessionResult<()> {
for ((quant, data), ycbcr_buffer) in tile.quants.iter().zip(tile.data.iter()).zip(ycbcr_temp.iter_mut()) {
decode_component(quant, entropy_algorithm, data, ycbcr_buffer.as_mut_slice(), temp)?;
}
@ -183,7 +183,7 @@ fn decode_tile(
cr: ycbcr_temp[2].as_slice(),
};
color_conversion::ycbcr_to_bgra(ycbcr_buffer, output).map_err(|e| Error::new("decode_tile").with_custom(e))?;
color_conversion::ycbcr_to_bgra(ycbcr_buffer, output).map_err(|e| custom_err!("decode_tile", e))?;
Ok(())
}
@ -194,8 +194,8 @@ fn decode_component(
data: &[u8],
output: &mut [i16],
temp: &mut [i16],
) -> Result<()> {
rlgr::decode(entropy_algorithm, data, output).map_err(|e| Error::new("decode_component").with_custom(e))?;
) -> SessionResult<()> {
rlgr::decode(entropy_algorithm, data, output).map_err(|e| custom_err!("decode_component", e))?;
subband_reconstruction::decode(&mut output[4032..]);
quantization::decode(output, quant);
dwt::decode(output, temp);

View file

@ -2,12 +2,12 @@ use ironrdp_pdu::dvc::display::ServerPdu;
use ironrdp_pdu::PduParsing;
use super::DynamicChannelDataHandler;
use crate::Result;
use crate::SessionResult;
pub struct Handler;
impl DynamicChannelDataHandler for Handler {
fn process_complete_data(&mut self, complete_data: Vec<u8>) -> Result<Option<Vec<u8>>> {
fn process_complete_data(&mut self, complete_data: Vec<u8>) -> SessionResult<Option<Vec<u8>>> {
let gfx_pdu = ServerPdu::from_buffer(&mut complete_data.as_slice())?;
debug!("Got Display PDU: {:?}", gfx_pdu);
Ok(None)

View file

@ -9,10 +9,10 @@ use ironrdp_pdu::dvc::gfx::{
use ironrdp_pdu::PduParsing;
use crate::x224::DynamicChannelDataHandler;
use crate::{Error, Result};
use crate::SessionResult;
pub trait GfxHandler {
fn on_message(&self, message: ServerPdu) -> Result<Option<ClientPdu>>;
fn on_message(&self, message: ServerPdu) -> SessionResult<Option<ClientPdu>>;
}
pub struct Handler {
@ -34,7 +34,7 @@ impl Handler {
}
impl DynamicChannelDataHandler for Handler {
fn process_complete_data(&mut self, complete_data: Vec<u8>) -> Result<Option<Vec<u8>>> {
fn process_complete_data(&mut self, complete_data: Vec<u8>) -> SessionResult<Option<Vec<u8>>> {
let mut client_pdu_buffer: Vec<u8> = vec![];
self.decompressed_buffer.clear();
self.decompressor
@ -96,13 +96,12 @@ bitflags! {
}
}
pub fn create_capabilities_advertise(graphics_config: &Option<GraphicsConfig>) -> Result<Vec<u8>> {
pub fn create_capabilities_advertise(graphics_config: &Option<GraphicsConfig>) -> SessionResult<Vec<u8>> {
let mut capabilities = vec![];
if let Some(config) = graphics_config {
let capability_version = CapabilityVersion::from_bits(config.capabilities).ok_or_else(|| {
Error::new("invalid capabilities mask provided").with_reason(format!("received: {:x}", config.capabilities))
})?;
let capability_version = CapabilityVersion::from_bits(config.capabilities)
.ok_or_else(|| reason_err!("GFX", "invalid capabilities mask: {:x}", config.capabilities))?;
if capability_version.contains(CapabilityVersion::V8) {
let flags = if config.thin_client {

View file

@ -12,7 +12,7 @@ use ironrdp_pdu::rdp::server_error_info::{ErrorInfo, ProtocolIndependentCode, Se
use ironrdp_pdu::rdp::vc::{self, dvc};
pub use self::gfx::GfxHandler;
use crate::{Error, Result};
use crate::SessionResult;
pub const RDP8_GRAPHICS_PIPELINE_NAME: &str = "Microsoft::Windows::RDS::Graphics";
pub const RDP8_DISPLAY_PIPELINE_NAME: &str = "Microsoft::Windows::RDS::DisplayControl";
@ -54,8 +54,9 @@ impl Processor {
}
}
pub fn process(&mut self, frame: &[u8]) -> Result<Vec<u8>> {
let data_ctx = ironrdp_connector::legacy::decode_send_data_indication(frame)?;
pub fn process(&mut self, frame: &[u8]) -> SessionResult<Vec<u8>> {
let data_ctx =
ironrdp_connector::legacy::decode_send_data_indication(frame).map_err(crate::legacy::map_error)?;
let channel_id = data_ctx.channel_id;
if channel_id == self.io_channel_id {
@ -64,15 +65,15 @@ impl Processor {
} else {
match self.drdynvc_channel_id {
Some(dyvc_id) if channel_id == dyvc_id => self.process_dyvc(data_ctx),
_ => Err(Error::new("unexpected channel").with_reason(format!("received ID {channel_id}"))),
_ => Err(reason_err!("X224", "unexpected channel received: ID {channel_id}")),
}
}
}
fn process_io_channel(&self, data_ctx: SendDataIndicationCtx<'_>) -> Result<()> {
fn process_io_channel(&self, data_ctx: SendDataIndicationCtx<'_>) -> SessionResult<()> {
debug_assert_eq!(data_ctx.channel_id, self.io_channel_id);
let ctx = ironrdp_connector::legacy::decode_share_data(data_ctx)?;
let ctx = ironrdp_connector::legacy::decode_share_data(data_ctx).map_err(crate::legacy::map_error)?;
match ctx.pdu {
ShareDataPdu::SaveSessionInfo(session_info) => {
@ -86,16 +87,17 @@ impl Processor {
Ok(())
}
ShareDataPdu::ServerSetErrorInfo(ServerSetErrorInfoPdu(e)) => {
Err(Error::new("ServerSetErrorInfo").with_reason(e.description()))
Err(reason_err!("ServerSetErrorInfo", "{}", e.description()))
}
_ => Err(Error::new("unexpected PDU").with_reason(format!(
"Expected Session Save Info PDU, got: {:?}",
_ => Err(reason_err!(
"IO channel",
"unexpected PDU: expected Session Save Info PDU, got: {:?}",
ctx.pdu.as_short_name()
))),
)),
}
}
fn process_dyvc(&mut self, data_ctx: SendDataIndicationCtx<'_>) -> Result<Vec<u8>> {
fn process_dyvc(&mut self, data_ctx: SendDataIndicationCtx<'_>) -> SessionResult<Vec<u8>> {
debug_assert_eq!(Some(data_ctx.channel_id), self.drdynvc_channel_id);
let dvc_ctx = crate::legacy::decode_dvc_message(data_ctx)?;
@ -188,9 +190,7 @@ impl Processor {
if let Some(dvc_data) = self
.dynamic_channels
.get_mut(&data.channel_id)
.ok_or_else(|| {
Error::new("access to non existing channel").with_reason(data.channel_id.to_string())
})?
.ok_or_else(|| reason_err!("DVC", "access to non existing channel: {}", data.channel_id))?
.process_data_first_pdu(data.total_data_size as usize, dvc_data.to_vec())?
{
let client_data = dvc::ClientPdu::Data(dvc::DataPdu {
@ -218,9 +218,7 @@ impl Processor {
if let Some(dvc_data) = self
.dynamic_channels
.get_mut(&data.channel_id)
.ok_or_else(|| {
Error::new("access to non existing channel").with_reason(data.channel_id.to_string())
})?
.ok_or_else(|| reason_err!("DVC", "access to non existing channel: {}", data.channel_id))?
.process_data_pdu(dvc_data.to_vec())?
{
let client_data = dvc::ClientPdu::Data(dvc::DataPdu {
@ -244,20 +242,20 @@ impl Processor {
}
/// Sends a PDU on the dynamic channel.
pub fn encode_dynamic(&self, output: &mut Vec<u8>, channel_name: &str, dvc_data: &[u8]) -> Result<usize> {
pub fn encode_dynamic(&self, output: &mut Vec<u8>, channel_name: &str, dvc_data: &[u8]) -> SessionResult<usize> {
let drdynvc_channel_id = self
.drdynvc_channel_id
.ok_or(Error::new("dynamic virtual channel not connected"))?;
.ok_or_else(|| general_err!("dynamic virtual channel not connected"))?;
let dvc_channel_id = self
.channel_map
.get(channel_name)
.ok_or_else(|| Error::new("access to non existing channel name").with_reason(channel_name))?;
.ok_or_else(|| reason_err!("DVC", "access to non existing channel name: {}", channel_name))?;
let dvc_channel = self
.dynamic_channels
.get(dvc_channel_id)
.ok_or_else(|| Error::new("access to non existing channel").with_reason(dvc_channel_id.to_string()))?;
.ok_or_else(|| reason_err!("DVC", "access to non existing channel: {}", dvc_channel_id))?;
let dvc_client_data = dvc::ClientPdu::Data(dvc::DataPdu {
channel_id_type: dvc_channel.channel_id_type,
@ -277,9 +275,10 @@ impl Processor {
}
/// Send a pdu on the static global channel. Typically used to send input events
pub fn encode_static(&self, output: &mut Vec<u8>, pdu: ShareDataPdu) -> Result<usize> {
pub fn encode_static(&self, output: &mut Vec<u8>, pdu: ShareDataPdu) -> SessionResult<usize> {
let written =
ironrdp_connector::legacy::encode_share_data(self.user_channel_id, self.io_channel_id, 0, pdu, output)?;
ironrdp_connector::legacy::encode_share_data(self.user_channel_id, self.io_channel_id, 0, pdu, output)
.map_err(crate::legacy::map_error)?;
Ok(written)
}
}
@ -317,7 +316,7 @@ fn negotiate_dvc(
channel_id: u16,
mut stream: impl io::Write,
graphics_config: &Option<GraphicsConfig>,
) -> Result<()> {
) -> SessionResult<()> {
if create_request.channel_name == RDP8_GRAPHICS_PIPELINE_NAME {
let dvc_data = gfx::create_capabilities_advertise(graphics_config)?;
let dvc_pdu = dvc::ClientPdu::Data(dvc::DataPdu {
@ -329,16 +328,14 @@ fn negotiate_dvc(
debug!("Send GFX Capabilities Advertise PDU");
let mut buf = Vec::new();
crate::legacy::encode_dvc_message(initiator_id, channel_id, dvc_pdu, &dvc_data, &mut buf)?;
stream
.write_all(&buf)
.map_err(|e| Error::new("write negotiation dvc").with_custom(e))?;
stream.write_all(&buf).map_err(|e| custom_err!("DVC write", e))?;
}
Ok(())
}
trait DynamicChannelDataHandler {
fn process_complete_data(&mut self, complete_data: Vec<u8>) -> Result<Option<Vec<u8>>>;
fn process_complete_data(&mut self, complete_data: Vec<u8>) -> SessionResult<Option<Vec<u8>>>;
}
pub struct DynamicChannel {
@ -358,7 +355,7 @@ impl DynamicChannel {
}
}
fn process_data_first_pdu(&mut self, total_data_size: usize, data: Vec<u8>) -> Result<Option<Vec<u8>>> {
fn process_data_first_pdu(&mut self, total_data_size: usize, data: Vec<u8>) -> SessionResult<Option<Vec<u8>>> {
if let Some(complete_data) = self.data.process_data_first_pdu(total_data_size, data) {
self.handler.process_complete_data(complete_data)
} else {
@ -366,7 +363,7 @@ impl DynamicChannel {
}
}
fn process_data_pdu(&mut self, data: Vec<u8>) -> Result<Option<Vec<u8>>> {
fn process_data_pdu(&mut self, data: Vec<u8>) -> SessionResult<Option<Vec<u8>>> {
if let Some(complete_data) = self.data.process_data_pdu(data) {
self.handler.process_complete_data(complete_data)
} else {

View file

@ -18,6 +18,7 @@ harness = true
[dependencies]
anyhow = "1"
array-concat = "0.5.2"
expect-test.workspace = true
ironrdp-pdu.workspace = true
lazy_static = "1.4.0"
paste = "1"

View file

@ -1,9 +1,10 @@
use expect_test::expect;
use ironrdp_pdu::mcs::*;
use ironrdp_pdu::{Error, PduParsing as _};
use ironrdp_pdu::PduParsing as _;
use ironrdp_testsuite_core::mcs::*;
use ironrdp_testsuite_core::mcs_encode_decode_test;
fn mcs_decode<'de, T: McsPdu<'de>>(src: &'de [u8]) -> ironrdp_pdu::Result<T> {
fn mcs_decode<'de, T: McsPdu<'de>>(src: &'de [u8]) -> ironrdp_pdu::PduResult<T> {
let mut cursor = ironrdp_pdu::cursor::ReadCursor::new(src);
T::mcs_body_decode(&mut cursor, src.len())
}
@ -14,13 +15,17 @@ fn invalid_domain_mcspdu() {
.err()
.unwrap();
if let Error::InvalidMessage { name, field, reason } = e {
assert_eq!(name, "McsMessage");
assert_eq!(field, "domain-mcspdu");
assert_eq!(reason, "unexpected application tag for CHOICE");
} else {
panic!("unexpected error: {e}");
expect![[r#"
Error {
context: "McsMessage",
kind: InvalidMessage {
field: "domain-mcspdu",
reason: "unexpected application tag for CHOICE",
},
source: None,
}
"#]]
.assert_debug_eq(&e);
}
mcs_encode_decode_test! {

View file

@ -1,3 +1,4 @@
use expect_test::expect;
use ironrdp_pdu::cursor::{ReadCursor, WriteCursor};
use ironrdp_pdu::nego::{
ConnectionConfirm, ConnectionRequest, Cookie, FailureCode, NegoRequestData, RequestFlags, ResponseFlags,
@ -6,7 +7,6 @@ use ironrdp_pdu::nego::{
use ironrdp_pdu::tpdu::{TpduCode, TpduHeader};
use ironrdp_pdu::tpkt::TpktHeader;
use ironrdp_pdu::x224::user_data_size;
use ironrdp_pdu::Error;
use ironrdp_testsuite_core::encode_decode_test;
const SAMPLE_TPKT_HEADER_BINARY: [u8; 4] = [
@ -217,12 +217,15 @@ fn nego_request_unexpected_rdp_msg_type() {
let e = ironrdp_pdu::decode::<ConnectionRequest>(&payload).err().unwrap();
if let Error::UnexpectedMessageType { name, got } = e {
assert_eq!(name, "Client X.224 Connection Request");
assert_eq!(got, 0x03);
} else {
panic!("unexpected error: {e}");
expect![[r#"
Error {
context: "Client X.224 Connection Request",
kind: UnexpectedMessageType {
got: 3,
},
source: None,
}
"#]].assert_debug_eq(&e);
}
#[test]
@ -247,12 +250,15 @@ fn nego_confirm_unexpected_rdp_msg_type() {
let e = ironrdp_pdu::decode::<ConnectionConfirm>(&payload).err().unwrap();
if let Error::UnexpectedMessageType { name, got } = e {
assert_eq!(name, "Server X.224 Connection Confirm");
assert_eq!(got, 0xAF);
} else {
panic!("unexpected error: {e}");
expect![[r#"
Error {
context: "Server X.224 Connection Confirm",
kind: UnexpectedMessageType {
got: 175,
},
source: None,
}
"#]].assert_debug_eq(&e);
}
#[test]
@ -305,16 +311,14 @@ fn cookie_without_cr_lf_error_decode() {
let e = Cookie::read(&mut ReadCursor::new(&payload)).err().unwrap();
if let Error::NotEnoughBytes {
name,
received,
expected,
} = e
{
assert_eq!(name, "Cookie");
assert_eq!(received, 1);
assert_eq!(expected, 2);
} else {
panic!("unexpected error: {e}");
expect![[r#"
Error {
context: "Cookie",
kind: NotEnoughBytes {
received: 1,
expected: 2,
},
source: None,
}
"#]].assert_debug_eq(&e);
}

View file

@ -1,5 +1,5 @@
use ironrdp::connector;
use ironrdp::connector::sspi;
use ironrdp::connector::{self, ConnectorErrorKind};
use wasm_bindgen::prelude::*;
#[wasm_bindgen]
@ -36,20 +36,20 @@ impl IronRdpError {
}
}
impl From<connector::Error> for IronRdpError {
fn from(e: connector::Error) -> Self {
impl From<connector::ConnectorError> for IronRdpError {
fn from(e: connector::ConnectorError) -> Self {
use sspi::credssp::NStatusCode;
let kind = match e.kind {
connector::ErrorKind::Credssp(sspi::Error {
ConnectorErrorKind::Credssp(sspi::Error {
nstatus: Some(NStatusCode::WRONG_PASSWORD),
..
}) => IronRdpErrorKind::WrongPassword,
connector::ErrorKind::Credssp(sspi::Error {
ConnectorErrorKind::Credssp(sspi::Error {
nstatus: Some(NStatusCode::LOGON_FAILURE),
..
}) => IronRdpErrorKind::LogonFailure,
connector::ErrorKind::AccessDenied => IronRdpErrorKind::AccessDenied,
ConnectorErrorKind::AccessDenied => IronRdpErrorKind::AccessDenied,
_ => IronRdpErrorKind::General,
};
@ -60,8 +60,8 @@ impl From<connector::Error> for IronRdpError {
}
}
impl From<ironrdp::session::Error> for IronRdpError {
fn from(e: ironrdp::session::Error) -> Self {
impl From<ironrdp::session::SessionError> for IronRdpError {
fn from(e: ironrdp::session::SessionError) -> Self {
Self {
kind: IronRdpErrorKind::General,
source: anyhow::Error::new(e),

View file

@ -458,14 +458,14 @@ where
pub const RDCLEANPATH_HINT: RDCleanPathHint = RDCleanPathHint;
impl ironrdp::pdu::PduHint for RDCleanPathHint {
fn find_size(&self, bytes: &[u8]) -> ironrdp::pdu::Result<Option<usize>> {
fn find_size(&self, bytes: &[u8]) -> ironrdp::pdu::PduResult<Option<usize>> {
match ironrdp_rdcleanpath::RDCleanPathPdu::detect(bytes) {
ironrdp_rdcleanpath::DetectionResult::Detected { total_length, .. } => Ok(Some(total_length)),
ironrdp_rdcleanpath::DetectionResult::NotEnoughBytes => Ok(None),
ironrdp_rdcleanpath::DetectionResult::Failed => Err(ironrdp::pdu::Error::Other {
context: "RDCleanPathHint",
reason: "detection failed (invalid PDU)",
}),
ironrdp_rdcleanpath::DetectionResult::Failed => Err(ironrdp::pdu::other_err!(
"RDCleanPathHint",
"detection failed (invalid PDU)"
)),
}
}
}