mirror of
https://github.com/Devolutions/IronRDP.git
synced 2025-08-04 15:18:17 +00:00
feat: add support for skip channel join flag (#373)
This commit is contained in:
parent
af7805a5fc
commit
5fd5ce946e
4 changed files with 124 additions and 51 deletions
|
@ -1,6 +1,6 @@
|
|||
use std::collections::HashSet;
|
||||
|
||||
use ironrdp_connector::{ConnectorError, ConnectorErrorExt, ConnectorResult, Sequence, State, Written};
|
||||
use ironrdp_connector::{reason_err, ConnectorError, ConnectorErrorExt, ConnectorResult, Sequence, State, Written};
|
||||
use ironrdp_pdu as pdu;
|
||||
use pdu::mcs;
|
||||
use pdu::write_buf::WriteBuf;
|
||||
|
@ -9,7 +9,7 @@ use pdu::write_buf::WriteBuf;
|
|||
pub struct ChannelConnectionSequence {
|
||||
state: ChannelConnectionState,
|
||||
user_channel_id: u16,
|
||||
channels: HashSet<u16>,
|
||||
channel_ids: Option<HashSet<u16>>,
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
|
@ -21,10 +21,10 @@ pub enum ChannelConnectionState {
|
|||
WaitAttachUserRequest,
|
||||
SendAttachUserConfirm,
|
||||
WaitChannelJoinRequest {
|
||||
joined: HashSet<u16>,
|
||||
remaining: HashSet<u16>,
|
||||
},
|
||||
SendChannelJoinConfirm {
|
||||
joined: HashSet<u16>,
|
||||
remaining: HashSet<u16>,
|
||||
channel_id: u16,
|
||||
},
|
||||
AllJoined,
|
||||
|
@ -61,7 +61,7 @@ impl Sequence for ChannelConnectionSequence {
|
|||
ChannelConnectionState::SendAttachUserConfirm => None,
|
||||
ChannelConnectionState::WaitChannelJoinRequest { .. } => Some(&pdu::X224_HINT),
|
||||
ChannelConnectionState::SendChannelJoinConfirm { .. } => None,
|
||||
ChannelConnectionState::AllJoined { .. } => None,
|
||||
ChannelConnectionState::AllJoined => None,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -99,28 +99,41 @@ impl Sequence for ChannelConnectionSequence {
|
|||
|
||||
let written = ironrdp_pdu::encode_buf(&attach_user_confirm, output).map_err(ConnectorError::pdu)?;
|
||||
|
||||
(
|
||||
Written::from_size(written)?,
|
||||
ChannelConnectionState::WaitChannelJoinRequest { joined: HashSet::new() },
|
||||
)
|
||||
let next_state = match self.channel_ids.take() {
|
||||
Some(channel_ids) => ChannelConnectionState::WaitChannelJoinRequest { remaining: channel_ids },
|
||||
None => ChannelConnectionState::AllJoined,
|
||||
};
|
||||
|
||||
(Written::from_size(written)?, next_state)
|
||||
}
|
||||
|
||||
// TODO(#165): support RNS_UD_CS_SUPPORT_SKIP_CHANNELJOIN
|
||||
ChannelConnectionState::WaitChannelJoinRequest { joined } => {
|
||||
ChannelConnectionState::WaitChannelJoinRequest { mut remaining } => {
|
||||
let channel_request =
|
||||
ironrdp_pdu::decode::<mcs::ChannelJoinRequest>(input).map_err(ConnectorError::pdu)?;
|
||||
|
||||
debug!(message = ?channel_request, "Received");
|
||||
|
||||
let channel_id = channel_request.channel_id;
|
||||
let is_expected = remaining.remove(&channel_request.channel_id);
|
||||
|
||||
if !is_expected {
|
||||
return Err(reason_err!(
|
||||
"ChannelJoinConfirm",
|
||||
"unexpected channel_id in MCS Channel Join Request: got {}, expected one of: {:?}",
|
||||
channel_request.channel_id,
|
||||
remaining,
|
||||
));
|
||||
}
|
||||
|
||||
(
|
||||
Written::Nothing,
|
||||
ChannelConnectionState::SendChannelJoinConfirm { joined, channel_id },
|
||||
ChannelConnectionState::SendChannelJoinConfirm {
|
||||
remaining,
|
||||
channel_id: channel_request.channel_id,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
ChannelConnectionState::SendChannelJoinConfirm { mut joined, channel_id } => {
|
||||
ChannelConnectionState::SendChannelJoinConfirm { remaining, channel_id } => {
|
||||
let channel_confirm = mcs::ChannelJoinConfirm {
|
||||
result: 0,
|
||||
initiator_id: self.user_channel_id,
|
||||
|
@ -132,15 +145,13 @@ impl Sequence for ChannelConnectionSequence {
|
|||
|
||||
let written = ironrdp_pdu::encode_buf(&channel_confirm, output).map_err(ConnectorError::pdu)?;
|
||||
|
||||
joined.insert(channel_id);
|
||||
|
||||
let state = if joined != self.channels {
|
||||
ChannelConnectionState::WaitChannelJoinRequest { joined }
|
||||
let next_state = if remaining.is_empty() {
|
||||
ChannelConnectionState::AllJoined
|
||||
} else {
|
||||
ChannelConnectionState::AllJoined {}
|
||||
ChannelConnectionState::WaitChannelJoinRequest { remaining }
|
||||
};
|
||||
|
||||
(Written::from_size(written)?, state)
|
||||
(Written::from_size(written)?, next_state)
|
||||
}
|
||||
|
||||
_ => unreachable!(),
|
||||
|
@ -156,10 +167,20 @@ impl ChannelConnectionSequence {
|
|||
Self {
|
||||
state: ChannelConnectionState::WaitErectDomainRequest,
|
||||
user_channel_id,
|
||||
channels: vec![user_channel_id, io_channel_id]
|
||||
.into_iter()
|
||||
.chain(other_channels)
|
||||
.collect(),
|
||||
channel_ids: Some(
|
||||
vec![user_channel_id, io_channel_id]
|
||||
.into_iter()
|
||||
.chain(other_channels)
|
||||
.collect(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn skip_channel_join(user_channel_id: u16) -> Self {
|
||||
Self {
|
||||
state: ChannelConnectionState::WaitErectDomainRequest,
|
||||
user_channel_id,
|
||||
channel_ids: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -295,7 +295,17 @@ impl Sequence for Acceptor {
|
|||
channels,
|
||||
} => {
|
||||
let channel_ids: Vec<u16> = channels.iter().map(|&(i, _)| i).collect();
|
||||
let server_blocks = create_gcc_blocks(self.io_channel_id, channel_ids.clone(), requested_protocol);
|
||||
|
||||
let skip_channel_join = early_capability
|
||||
.is_some_and(|client| client.contains(gcc::ClientEarlyCapabilityFlags::SUPPORT_SKIP_CHANNELJOIN));
|
||||
|
||||
let server_blocks = create_gcc_blocks(
|
||||
self.io_channel_id,
|
||||
channel_ids.clone(),
|
||||
requested_protocol,
|
||||
skip_channel_join,
|
||||
);
|
||||
|
||||
let settings_response = mcs::ConnectResponse {
|
||||
conference_create_response: gcc::ConferenceCreateResponse {
|
||||
user_id: self.user_channel_id,
|
||||
|
@ -315,11 +325,11 @@ impl Sequence for Acceptor {
|
|||
AcceptorState::ChannelConnection {
|
||||
early_capability,
|
||||
channels,
|
||||
connection: ChannelConnectionSequence::new(
|
||||
self.user_channel_id,
|
||||
self.io_channel_id,
|
||||
channel_ids,
|
||||
),
|
||||
connection: if skip_channel_join {
|
||||
ChannelConnectionSequence::skip_channel_join(self.user_channel_id)
|
||||
} else {
|
||||
ChannelConnectionSequence::new(self.user_channel_id, self.io_channel_id, channel_ids)
|
||||
},
|
||||
},
|
||||
)
|
||||
}
|
||||
|
@ -530,13 +540,15 @@ fn create_gcc_blocks(
|
|||
io_channel: u16,
|
||||
channel_ids: Vec<u16>,
|
||||
requested: nego::SecurityProtocol,
|
||||
skip_channel_join: bool,
|
||||
) -> gcc::ServerGccBlocks {
|
||||
pdu::gcc::ServerGccBlocks {
|
||||
core: gcc::ServerCoreData {
|
||||
version: gcc::RdpVersion::V5_PLUS,
|
||||
optional_data: gcc::ServerCoreOptionalData {
|
||||
client_requested_protocols: Some(requested),
|
||||
early_capability_flags: None,
|
||||
early_capability_flags: skip_channel_join
|
||||
.then_some(gcc::ServerEarlyCapabilityFlags::SKIP_CHANNELJOIN_SUPPORTED),
|
||||
},
|
||||
},
|
||||
security: gcc::ServerSecurityData::no_security(),
|
||||
|
|
|
@ -18,9 +18,11 @@ pub enum ChannelConnectionState {
|
|||
WaitAttachUserConfirm,
|
||||
SendChannelJoinRequest {
|
||||
user_channel_id: u16,
|
||||
join_channel_ids: HashSet<u16>,
|
||||
},
|
||||
WaitChannelJoinConfirm {
|
||||
user_channel_id: u16,
|
||||
remaining_channel_ids: HashSet<u16>,
|
||||
},
|
||||
AllJoined {
|
||||
user_channel_id: u16,
|
||||
|
@ -53,7 +55,7 @@ impl State for ChannelConnectionState {
|
|||
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
|
||||
pub struct ChannelConnectionSequence {
|
||||
pub state: ChannelConnectionState,
|
||||
pub channel_ids: HashSet<u16>,
|
||||
pub channel_ids: Option<HashSet<u16>>,
|
||||
}
|
||||
|
||||
impl ChannelConnectionSequence {
|
||||
|
@ -65,7 +67,14 @@ impl ChannelConnectionSequence {
|
|||
|
||||
Self {
|
||||
state: ChannelConnectionState::SendErectDomainRequest,
|
||||
channel_ids,
|
||||
channel_ids: Some(channel_ids),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn skip_channel_join() -> Self {
|
||||
Self {
|
||||
state: ChannelConnectionState::SendErectDomainRequest,
|
||||
channel_ids: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -126,17 +135,23 @@ impl Sequence for ChannelConnectionSequence {
|
|||
|
||||
let user_channel_id = attach_user_confirm.initiator_id;
|
||||
|
||||
// User channel ID must also be joined.
|
||||
self.channel_ids.insert(user_channel_id);
|
||||
|
||||
debug!(message = ?attach_user_confirm, user_channel_id, "Received");
|
||||
|
||||
debug_assert!(!self.channel_ids.is_empty());
|
||||
let next = match self.channel_ids.take() {
|
||||
Some(mut channel_ids) => {
|
||||
// User channel ID must also be joined.
|
||||
channel_ids.insert(user_channel_id);
|
||||
|
||||
(
|
||||
Written::Nothing,
|
||||
ChannelConnectionState::SendChannelJoinRequest { user_channel_id },
|
||||
)
|
||||
ChannelConnectionState::SendChannelJoinRequest {
|
||||
user_channel_id,
|
||||
join_channel_ids: channel_ids,
|
||||
}
|
||||
}
|
||||
|
||||
None => ChannelConnectionState::AllJoined { user_channel_id },
|
||||
};
|
||||
|
||||
(Written::Nothing, next)
|
||||
}
|
||||
|
||||
// Send all the join requests in a single batch.
|
||||
|
@ -145,10 +160,15 @@ impl Sequence for ChannelConnectionSequence {
|
|||
// > Channel Join Confirm for a previously sent request has been received. RDP 8.1,
|
||||
// > 10.0, and 10.1 clients send all of the Channel Join Requests to the server in a
|
||||
// > single batch to minimize the overall connection sequence time.
|
||||
ChannelConnectionState::SendChannelJoinRequest { user_channel_id } => {
|
||||
ChannelConnectionState::SendChannelJoinRequest {
|
||||
user_channel_id,
|
||||
join_channel_ids,
|
||||
} => {
|
||||
let mut total_written: usize = 0;
|
||||
|
||||
for channel_id in self.channel_ids.iter().copied() {
|
||||
debug_assert!(!join_channel_ids.is_empty());
|
||||
|
||||
for channel_id in join_channel_ids.iter().copied() {
|
||||
let channel_join_request = mcs::ChannelJoinRequest {
|
||||
initiator_id: user_channel_id,
|
||||
channel_id,
|
||||
|
@ -164,11 +184,17 @@ impl Sequence for ChannelConnectionSequence {
|
|||
|
||||
(
|
||||
Written::from_size(total_written)?,
|
||||
ChannelConnectionState::WaitChannelJoinConfirm { user_channel_id },
|
||||
ChannelConnectionState::WaitChannelJoinConfirm {
|
||||
user_channel_id,
|
||||
remaining_channel_ids: join_channel_ids,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
ChannelConnectionState::WaitChannelJoinConfirm { user_channel_id } => {
|
||||
ChannelConnectionState::WaitChannelJoinConfirm {
|
||||
user_channel_id,
|
||||
mut remaining_channel_ids,
|
||||
} => {
|
||||
let channel_join_confirm =
|
||||
ironrdp_pdu::decode::<mcs::ChannelJoinConfirm>(input).map_err(ConnectorError::pdu)?;
|
||||
|
||||
|
@ -181,14 +207,14 @@ impl Sequence for ChannelConnectionSequence {
|
|||
)
|
||||
}
|
||||
|
||||
let is_expected = self.channel_ids.remove(&channel_join_confirm.requested_channel_id);
|
||||
let is_expected = remaining_channel_ids.remove(&channel_join_confirm.requested_channel_id);
|
||||
|
||||
if !is_expected {
|
||||
return Err(reason_err!(
|
||||
"ChannelJoinConfirm",
|
||||
"unexpected requested_channel_id in MCS Channel Join Confirm: got {}, expected one of: {:?}",
|
||||
channel_join_confirm.requested_channel_id,
|
||||
self.channel_ids,
|
||||
remaining_channel_ids,
|
||||
));
|
||||
}
|
||||
|
||||
|
@ -202,10 +228,13 @@ impl Sequence for ChannelConnectionSequence {
|
|||
));
|
||||
}
|
||||
|
||||
let next_state = if self.channel_ids.is_empty() {
|
||||
let next_state = if remaining_channel_ids.is_empty() {
|
||||
ChannelConnectionState::AllJoined { user_channel_id }
|
||||
} else {
|
||||
ChannelConnectionState::WaitChannelJoinConfirm { user_channel_id }
|
||||
ChannelConnectionState::WaitChannelJoinConfirm {
|
||||
user_channel_id,
|
||||
remaining_channel_ids,
|
||||
}
|
||||
};
|
||||
|
||||
(Written::Nothing, next_state)
|
||||
|
|
|
@ -377,11 +377,21 @@ impl Sequence for ClientConnector {
|
|||
self.static_channels.attach_channel_id(channel, channel_id);
|
||||
});
|
||||
|
||||
let skip_channel_join = server_gcc_blocks
|
||||
.core
|
||||
.optional_data
|
||||
.early_capability_flags
|
||||
.is_some_and(|c| c.contains(gcc::ServerEarlyCapabilityFlags::SKIP_CHANNELJOIN_SUPPORTED));
|
||||
|
||||
(
|
||||
Written::Nothing,
|
||||
ClientConnectorState::ChannelConnection {
|
||||
io_channel_id,
|
||||
channel_connection: ChannelConnectionSequence::new(io_channel_id, static_channel_ids),
|
||||
channel_connection: if skip_channel_join {
|
||||
ChannelConnectionSequence::skip_channel_join()
|
||||
} else {
|
||||
ChannelConnectionSequence::new(io_channel_id, static_channel_ids)
|
||||
},
|
||||
},
|
||||
)
|
||||
}
|
||||
|
@ -677,7 +687,8 @@ fn create_gcc_blocks<'a>(
|
|||
early_capability_flags: {
|
||||
let mut early_capability_flags = ClientEarlyCapabilityFlags::VALID_CONNECTION_TYPE
|
||||
| ClientEarlyCapabilityFlags::SUPPORT_ERR_INFO_PDU
|
||||
| ClientEarlyCapabilityFlags::STRONG_ASYMMETRIC_KEYS;
|
||||
| ClientEarlyCapabilityFlags::STRONG_ASYMMETRIC_KEYS
|
||||
| ClientEarlyCapabilityFlags::SUPPORT_SKIP_CHANNELJOIN;
|
||||
|
||||
// TODO(#136): support for ClientEarlyCapabilityFlags::SUPPORT_STATUS_INFO_PDU
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue