refactor(credssp): follow up to #260 (#287)

This commit is contained in:
Benoît Cortier 2023-11-17 09:50:03 -05:00 committed by GitHub
parent 5530550ef3
commit 8fc213e699
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 172 additions and 206 deletions

View file

@ -1,16 +1,14 @@
use ironrdp_connector::credssp::{CredsspProcessGenerator, CredsspSequence, KerberosConfig};
use ironrdp_connector::sspi::credssp::ClientState;
use ironrdp_connector::sspi::generator::GeneratorState;
use ironrdp_connector::{
credssp_sequence::{CredsspProcessGenerator, CredsspSequence},
custom_err,
sspi::{credssp::ClientState, generator::GeneratorState},
ClientConnector, ClientConnectorState, ConnectionResult, ConnectorResult, KerberosConfig, Sequence as _,
ServerName, State as _, Written,
custom_err, ClientConnector, ClientConnectorState, ConnectionResult, ConnectorError, ConnectorResult,
Sequence as _, ServerName, State as _,
};
use ironrdp_pdu::write_buf::WriteBuf;
use crate::{
framed::{Framed, FramedRead, FramedWrite},
AsyncNetworkClient,
};
use crate::framed::{Framed, FramedRead, FramedWrite};
use crate::AsyncNetworkClient;
#[non_exhaustive]
pub struct ShouldUpgrade;
@ -50,10 +48,10 @@ pub fn mark_as_upgraded(_: ShouldUpgrade, connector: &mut ClientConnector) -> Up
pub async fn connect_finalize<S>(
_: Upgraded,
framed: &mut Framed<S>,
mut connector: ClientConnector,
server_name: ServerName,
server_public_key: Vec<u8>,
network_client: Option<&mut dyn AsyncNetworkClient>,
mut connector: ClientConnector,
kerberos_config: Option<KerberosConfig>,
) -> ConnectorResult<ConnectionResult>
where
@ -92,6 +90,7 @@ async fn resolve_generator(
network_client: &mut dyn AsyncNetworkClient,
) -> ConnectorResult<ClientState> {
let mut state = generator.start();
loop {
match state {
GeneratorState::Suspended(request) => {
@ -99,13 +98,14 @@ async fn resolve_generator(
state = generator.resume(Ok(response));
}
GeneratorState::Completed(client_state) => {
break Ok(client_state.map_err(|e| custom_err!("cannot resolve generator state", e))?)
break client_state
.map_err(|e| ConnectorError::new("CredSSP", ironrdp_connector::ConnectorErrorKind::Credssp(e)))
}
}
}
}
#[instrument(level = "trace", skip(network_client, framed, buf, server_name, server_public_key))]
#[instrument(level = "trace", skip_all)]
async fn perform_credssp_step<S>(
framed: &mut Framed<S>,
connector: &mut ClientConnector,
@ -119,10 +119,13 @@ where
S: FramedRead + FramedWrite,
{
assert!(connector.should_perform_credssp());
let mut credssp_sequence = CredsspSequence::new(connector, server_name, server_public_key, kerberos_config)?;
while !credssp_sequence.is_done() {
buf.clear();
let input = if let Some(next_pdu_hint) = credssp_sequence.next_pdu_hint() {
if let Some(next_pdu_hint) = credssp_sequence.next_pdu_hint() {
debug!(
connector.state = connector.state.name(),
hint = ?next_pdu_hint,
@ -135,25 +138,23 @@ where
.map_err(|e| ironrdp_connector::custom_err!("read frame by hint", e))?;
trace!(length = pdu.len(), "PDU received");
Some(pdu.to_vec())
} else {
None
};
if credssp_sequence.wants_request_from_server() {
credssp_sequence.read_request_from_server(&input.unwrap_or_else(|| [].to_vec()))?;
credssp_sequence.read_request_from_server(&pdu)?;
}
let client_state = {
let mut generator = credssp_sequence.process();
if let Some(network_client_ref) = network_client.as_deref_mut() {
trace!("resolving network");
resolve_generator(&mut generator, network_client_ref).await?
} else {
generator
.resolve_to_result()
.map_err(|e| custom_err!(" cannot resolve generator without a network client", e))?
.map_err(|e| custom_err!("resolve without network client", e))?
}
}; // drop generator
let written = credssp_sequence.handle_process_result(client_state, buf)?;
if let Some(response_len) = written.size() {
@ -165,7 +166,9 @@ where
.map_err(|e| ironrdp_connector::custom_err!("write all", e))?;
}
}
connector.mark_credssp_as_done();
Ok(())
}
@ -179,7 +182,7 @@ where
{
buf.clear();
let written: Written = if let Some(next_pdu_hint) = connector.next_pdu_hint() {
let written = if let Some(next_pdu_hint) = connector.next_pdu_hint() {
debug!(
connector.state = connector.state.name(),
hint = ?next_pdu_hint,