feat: add support for skip channel join flag (#373)

This commit is contained in:
Mihnea Buzatu 2024-02-28 15:39:50 +02:00 committed by GitHub
parent af7805a5fc
commit 5fd5ce946e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 124 additions and 51 deletions

View file

@ -1,6 +1,6 @@
use std::collections::HashSet;
use ironrdp_connector::{ConnectorError, ConnectorErrorExt, ConnectorResult, Sequence, State, Written};
use ironrdp_connector::{reason_err, ConnectorError, ConnectorErrorExt, ConnectorResult, Sequence, State, Written};
use ironrdp_pdu as pdu;
use pdu::mcs;
use pdu::write_buf::WriteBuf;
@ -9,7 +9,7 @@ use pdu::write_buf::WriteBuf;
pub struct ChannelConnectionSequence {
state: ChannelConnectionState,
user_channel_id: u16,
channels: HashSet<u16>,
channel_ids: Option<HashSet<u16>>,
}
#[derive(Default, Debug)]
@ -21,10 +21,10 @@ pub enum ChannelConnectionState {
WaitAttachUserRequest,
SendAttachUserConfirm,
WaitChannelJoinRequest {
joined: HashSet<u16>,
remaining: HashSet<u16>,
},
SendChannelJoinConfirm {
joined: HashSet<u16>,
remaining: HashSet<u16>,
channel_id: u16,
},
AllJoined,
@ -61,7 +61,7 @@ impl Sequence for ChannelConnectionSequence {
ChannelConnectionState::SendAttachUserConfirm => None,
ChannelConnectionState::WaitChannelJoinRequest { .. } => Some(&pdu::X224_HINT),
ChannelConnectionState::SendChannelJoinConfirm { .. } => None,
ChannelConnectionState::AllJoined { .. } => None,
ChannelConnectionState::AllJoined => None,
}
}
@ -99,28 +99,41 @@ impl Sequence for ChannelConnectionSequence {
let written = ironrdp_pdu::encode_buf(&attach_user_confirm, output).map_err(ConnectorError::pdu)?;
(
Written::from_size(written)?,
ChannelConnectionState::WaitChannelJoinRequest { joined: HashSet::new() },
)
let next_state = match self.channel_ids.take() {
Some(channel_ids) => ChannelConnectionState::WaitChannelJoinRequest { remaining: channel_ids },
None => ChannelConnectionState::AllJoined,
};
(Written::from_size(written)?, next_state)
}
// TODO(#165): support RNS_UD_CS_SUPPORT_SKIP_CHANNELJOIN
ChannelConnectionState::WaitChannelJoinRequest { joined } => {
ChannelConnectionState::WaitChannelJoinRequest { mut remaining } => {
let channel_request =
ironrdp_pdu::decode::<mcs::ChannelJoinRequest>(input).map_err(ConnectorError::pdu)?;
debug!(message = ?channel_request, "Received");
let channel_id = channel_request.channel_id;
let is_expected = remaining.remove(&channel_request.channel_id);
if !is_expected {
return Err(reason_err!(
"ChannelJoinConfirm",
"unexpected channel_id in MCS Channel Join Request: got {}, expected one of: {:?}",
channel_request.channel_id,
remaining,
));
}
(
Written::Nothing,
ChannelConnectionState::SendChannelJoinConfirm { joined, channel_id },
ChannelConnectionState::SendChannelJoinConfirm {
remaining,
channel_id: channel_request.channel_id,
},
)
}
ChannelConnectionState::SendChannelJoinConfirm { mut joined, channel_id } => {
ChannelConnectionState::SendChannelJoinConfirm { remaining, channel_id } => {
let channel_confirm = mcs::ChannelJoinConfirm {
result: 0,
initiator_id: self.user_channel_id,
@ -132,15 +145,13 @@ impl Sequence for ChannelConnectionSequence {
let written = ironrdp_pdu::encode_buf(&channel_confirm, output).map_err(ConnectorError::pdu)?;
joined.insert(channel_id);
let state = if joined != self.channels {
ChannelConnectionState::WaitChannelJoinRequest { joined }
let next_state = if remaining.is_empty() {
ChannelConnectionState::AllJoined
} else {
ChannelConnectionState::AllJoined {}
ChannelConnectionState::WaitChannelJoinRequest { remaining }
};
(Written::from_size(written)?, state)
(Written::from_size(written)?, next_state)
}
_ => unreachable!(),
@ -156,10 +167,20 @@ impl ChannelConnectionSequence {
Self {
state: ChannelConnectionState::WaitErectDomainRequest,
user_channel_id,
channels: vec![user_channel_id, io_channel_id]
.into_iter()
.chain(other_channels)
.collect(),
channel_ids: Some(
vec![user_channel_id, io_channel_id]
.into_iter()
.chain(other_channels)
.collect(),
),
}
}
pub fn skip_channel_join(user_channel_id: u16) -> Self {
Self {
state: ChannelConnectionState::WaitErectDomainRequest,
user_channel_id,
channel_ids: None,
}
}

View file

@ -295,7 +295,17 @@ impl Sequence for Acceptor {
channels,
} => {
let channel_ids: Vec<u16> = channels.iter().map(|&(i, _)| i).collect();
let server_blocks = create_gcc_blocks(self.io_channel_id, channel_ids.clone(), requested_protocol);
let skip_channel_join = early_capability
.is_some_and(|client| client.contains(gcc::ClientEarlyCapabilityFlags::SUPPORT_SKIP_CHANNELJOIN));
let server_blocks = create_gcc_blocks(
self.io_channel_id,
channel_ids.clone(),
requested_protocol,
skip_channel_join,
);
let settings_response = mcs::ConnectResponse {
conference_create_response: gcc::ConferenceCreateResponse {
user_id: self.user_channel_id,
@ -315,11 +325,11 @@ impl Sequence for Acceptor {
AcceptorState::ChannelConnection {
early_capability,
channels,
connection: ChannelConnectionSequence::new(
self.user_channel_id,
self.io_channel_id,
channel_ids,
),
connection: if skip_channel_join {
ChannelConnectionSequence::skip_channel_join(self.user_channel_id)
} else {
ChannelConnectionSequence::new(self.user_channel_id, self.io_channel_id, channel_ids)
},
},
)
}
@ -530,13 +540,15 @@ fn create_gcc_blocks(
io_channel: u16,
channel_ids: Vec<u16>,
requested: nego::SecurityProtocol,
skip_channel_join: bool,
) -> gcc::ServerGccBlocks {
pdu::gcc::ServerGccBlocks {
core: gcc::ServerCoreData {
version: gcc::RdpVersion::V5_PLUS,
optional_data: gcc::ServerCoreOptionalData {
client_requested_protocols: Some(requested),
early_capability_flags: None,
early_capability_flags: skip_channel_join
.then_some(gcc::ServerEarlyCapabilityFlags::SKIP_CHANNELJOIN_SUPPORTED),
},
},
security: gcc::ServerSecurityData::no_security(),

View file

@ -18,9 +18,11 @@ pub enum ChannelConnectionState {
WaitAttachUserConfirm,
SendChannelJoinRequest {
user_channel_id: u16,
join_channel_ids: HashSet<u16>,
},
WaitChannelJoinConfirm {
user_channel_id: u16,
remaining_channel_ids: HashSet<u16>,
},
AllJoined {
user_channel_id: u16,
@ -53,7 +55,7 @@ impl State for ChannelConnectionState {
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub struct ChannelConnectionSequence {
pub state: ChannelConnectionState,
pub channel_ids: HashSet<u16>,
pub channel_ids: Option<HashSet<u16>>,
}
impl ChannelConnectionSequence {
@ -65,7 +67,14 @@ impl ChannelConnectionSequence {
Self {
state: ChannelConnectionState::SendErectDomainRequest,
channel_ids,
channel_ids: Some(channel_ids),
}
}
pub fn skip_channel_join() -> Self {
Self {
state: ChannelConnectionState::SendErectDomainRequest,
channel_ids: None,
}
}
}
@ -126,17 +135,23 @@ impl Sequence for ChannelConnectionSequence {
let user_channel_id = attach_user_confirm.initiator_id;
// User channel ID must also be joined.
self.channel_ids.insert(user_channel_id);
debug!(message = ?attach_user_confirm, user_channel_id, "Received");
debug_assert!(!self.channel_ids.is_empty());
let next = match self.channel_ids.take() {
Some(mut channel_ids) => {
// User channel ID must also be joined.
channel_ids.insert(user_channel_id);
(
Written::Nothing,
ChannelConnectionState::SendChannelJoinRequest { user_channel_id },
)
ChannelConnectionState::SendChannelJoinRequest {
user_channel_id,
join_channel_ids: channel_ids,
}
}
None => ChannelConnectionState::AllJoined { user_channel_id },
};
(Written::Nothing, next)
}
// Send all the join requests in a single batch.
@ -145,10 +160,15 @@ impl Sequence for ChannelConnectionSequence {
// > Channel Join Confirm for a previously sent request has been received. RDP 8.1,
// > 10.0, and 10.1 clients send all of the Channel Join Requests to the server in a
// > single batch to minimize the overall connection sequence time.
ChannelConnectionState::SendChannelJoinRequest { user_channel_id } => {
ChannelConnectionState::SendChannelJoinRequest {
user_channel_id,
join_channel_ids,
} => {
let mut total_written: usize = 0;
for channel_id in self.channel_ids.iter().copied() {
debug_assert!(!join_channel_ids.is_empty());
for channel_id in join_channel_ids.iter().copied() {
let channel_join_request = mcs::ChannelJoinRequest {
initiator_id: user_channel_id,
channel_id,
@ -164,11 +184,17 @@ impl Sequence for ChannelConnectionSequence {
(
Written::from_size(total_written)?,
ChannelConnectionState::WaitChannelJoinConfirm { user_channel_id },
ChannelConnectionState::WaitChannelJoinConfirm {
user_channel_id,
remaining_channel_ids: join_channel_ids,
},
)
}
ChannelConnectionState::WaitChannelJoinConfirm { user_channel_id } => {
ChannelConnectionState::WaitChannelJoinConfirm {
user_channel_id,
mut remaining_channel_ids,
} => {
let channel_join_confirm =
ironrdp_pdu::decode::<mcs::ChannelJoinConfirm>(input).map_err(ConnectorError::pdu)?;
@ -181,14 +207,14 @@ impl Sequence for ChannelConnectionSequence {
)
}
let is_expected = self.channel_ids.remove(&channel_join_confirm.requested_channel_id);
let is_expected = remaining_channel_ids.remove(&channel_join_confirm.requested_channel_id);
if !is_expected {
return Err(reason_err!(
"ChannelJoinConfirm",
"unexpected requested_channel_id in MCS Channel Join Confirm: got {}, expected one of: {:?}",
channel_join_confirm.requested_channel_id,
self.channel_ids,
remaining_channel_ids,
));
}
@ -202,10 +228,13 @@ impl Sequence for ChannelConnectionSequence {
));
}
let next_state = if self.channel_ids.is_empty() {
let next_state = if remaining_channel_ids.is_empty() {
ChannelConnectionState::AllJoined { user_channel_id }
} else {
ChannelConnectionState::WaitChannelJoinConfirm { user_channel_id }
ChannelConnectionState::WaitChannelJoinConfirm {
user_channel_id,
remaining_channel_ids,
}
};
(Written::Nothing, next_state)

View file

@ -377,11 +377,21 @@ impl Sequence for ClientConnector {
self.static_channels.attach_channel_id(channel, channel_id);
});
let skip_channel_join = server_gcc_blocks
.core
.optional_data
.early_capability_flags
.is_some_and(|c| c.contains(gcc::ServerEarlyCapabilityFlags::SKIP_CHANNELJOIN_SUPPORTED));
(
Written::Nothing,
ClientConnectorState::ChannelConnection {
io_channel_id,
channel_connection: ChannelConnectionSequence::new(io_channel_id, static_channel_ids),
channel_connection: if skip_channel_join {
ChannelConnectionSequence::skip_channel_join()
} else {
ChannelConnectionSequence::new(io_channel_id, static_channel_ids)
},
},
)
}
@ -677,7 +687,8 @@ fn create_gcc_blocks<'a>(
early_capability_flags: {
let mut early_capability_flags = ClientEarlyCapabilityFlags::VALID_CONNECTION_TYPE
| ClientEarlyCapabilityFlags::SUPPORT_ERR_INFO_PDU
| ClientEarlyCapabilityFlags::STRONG_ASYMMETRIC_KEYS;
| ClientEarlyCapabilityFlags::STRONG_ASYMMETRIC_KEYS
| ClientEarlyCapabilityFlags::SUPPORT_SKIP_CHANNELJOIN;
// TODO(#136): support for ClientEarlyCapabilityFlags::SUPPORT_STATUS_INFO_PDU