diff --git a/src/handshakestate.rs b/src/handshakestate.rs index 98cbfbaf..c38c84ad 100644 --- a/src/handshakestate.rs +++ b/src/handshakestate.rs @@ -18,6 +18,7 @@ use crate::{ use std::{ convert::{TryFrom, TryInto}, fmt, + marker::PhantomData, }; /// A state machine encompassing the handshake phase of a Noise session. @@ -47,6 +48,24 @@ pub struct HandshakeState { pub(crate) pattern_position: usize, } +/// Result of [`simulate_write_message`](HandshakeState::simulate_write_message) if the call was successful. +pub struct SimulatedWriteInfo { + /// Would-be resulting message length of a call to [`write_message`](HandshakeState::write_message) + /// + /// This is useful to pad the message length to disguise re-handshakes. + pub result_length: usize, + + /// Flag that indicates if the contained payload would be encrypted + /// + /// NOTE: even when the payload would be encrypted, that doesn't necessarily mean that the + /// encryption is as strong as the encryption in transport mode. + /// See: [Noise Explorer, Patterns](https://noiseexplorer.com/patterns/) + pub is_encrypted: bool, + + /// Disallow exhausive matching + _non_exhausive: PhantomData<()>, +} + impl HandshakeState { #[allow(clippy::too_many_arguments)] pub(crate) fn new( @@ -195,7 +214,7 @@ impl HandshakeState { } /// Construct a message from `payload` (and pending handshake tokens if in handshake state), - /// and writes it to the `output` buffer. + /// and writes it to the `message` buffer. /// /// Returns the size of the written payload. /// @@ -204,11 +223,12 @@ impl HandshakeState { /// Will result in `Error::Input` if the size of the output exceeds the max message /// length in the Noise Protocol (65535 bytes). #[must_use] - pub fn write_message(&mut self, message: &[u8], payload: &mut [u8]) -> Result { + pub fn write_message(&mut self, payload: &[u8], message: &mut [u8]) -> Result { let checkpoint = self.symmetricstate.checkpoint(); - match self._write_message(message, payload) { + match self._write_message(payload, message) { Ok(res) => { self.pattern_position += 1; + self.my_turn = false; Ok(res) }, Err(err) => { @@ -218,6 +238,21 @@ impl HandshakeState { } } + /// Simulates a call to [`write_message`](HandshakeState::write_message) and returns + /// the result and additional information. + pub fn simulate_write_message(&mut self, payload: &[u8]) -> Result { + let mut tmp = [0u8; 65535]; + let checkpoint = self.symmetricstate.checkpoint(); + let ret = + self._write_message(payload, &mut tmp[..]).map(|result_length| SimulatedWriteInfo { + result_length, + is_encrypted: self.symmetricstate.has_key(), + _non_exhausive: PhantomData, + }); + self.symmetricstate.restore(checkpoint); + ret + } + fn _write_message(&mut self, payload: &[u8], message: &mut [u8]) -> Result { if !self.my_turn { bail!(StateProblem::NotTurnToWrite); @@ -317,7 +352,6 @@ impl HandshakeState { if self.pattern_position == (self.message_patterns.len() - 1) { self.symmetricstate.split(&mut self.cipherstates.0, &mut self.cipherstates.1); } - self.my_turn = false; Ok(byte_index) } @@ -338,6 +372,7 @@ impl HandshakeState { match self._read_message(message, payload) { Ok(res) => { self.pattern_position += 1; + self.my_turn = true; Ok(res) }, Err(err) => { @@ -447,7 +482,6 @@ impl HandshakeState { } self.symmetricstate.decrypt_and_mix_hash(ptr, payload).map_err(|_| Error::Decrypt)?; - self.my_turn = true; if last { self.symmetricstate.split(&mut self.cipherstates.0, &mut self.cipherstates.1); } @@ -516,10 +550,7 @@ impl HandshakeState { pub fn dangerously_get_raw_split(&mut self) -> ([u8; CIPHERKEYLEN], [u8; CIPHERKEYLEN]) { let mut output = ([0u8; MAXHASHLEN], [0u8; MAXHASHLEN]); self.symmetricstate.split_raw(&mut output.0, &mut output.1); - ( - output.0[..CIPHERKEYLEN].try_into().unwrap(), - output.1[..CIPHERKEYLEN].try_into().unwrap() - ) + (output.0[..CIPHERKEYLEN].try_into().unwrap(), output.1[..CIPHERKEYLEN].try_into().unwrap()) } /// Convert this `HandshakeState` into a `TransportState` with an internally stored nonce. diff --git a/src/lib.rs b/src/lib.rs index 9a92bfc4..370d0cf5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -89,7 +89,7 @@ pub mod types; pub use crate::{ builder::{Builder, Keypair}, error::Error, - handshakestate::HandshakeState, + handshakestate::{HandshakeState, SimulatedWriteInfo}, stateless_transportstate::StatelessTransportState, transportstate::TransportState, };