diff --git a/src/element.rs b/src/element.rs index cf7bc20..9920749 100644 --- a/src/element.rs +++ b/src/element.rs @@ -9,6 +9,9 @@ use crate::{ Nl80211ElementHeCap, Nl80211ElementHtCap, Nl80211ElementVhtCap, }; +#[cfg(test)] +mod test; + /// [Nl80211Elements] Vec #[derive(Debug, PartialEq, Eq, Clone)] pub struct Nl80211Elements(pub Vec); @@ -214,9 +217,10 @@ const BSS_MEMBERSHIP_SELECTOR_HT_PHY: u8 = 127; #[derive(Debug, PartialEq, Eq, Clone, Copy)] #[non_exhaustive] pub enum Nl80211RateAndSelector { - /// BSS basic rate set in Mb/s. + /// BSS basic rate in units of 500 kb/s, if necessary rounded up to the + /// next 500 kbs. BssBasicRateSet(u8), - /// Rate in Mb/s. + /// Rate in units of 500 kb/s, if necessary rounded up to the next 500 kbs. Rate(u8), SelectorHt, SelectorVht, @@ -237,8 +241,9 @@ pub enum Nl80211RateAndSelector { impl From for Nl80211RateAndSelector { fn from(d: u8) -> Self { - let msb: bool = (d & 1 << 7) > 0; - let value = d & 0b01111111; + const MSB_MASK: u8 = 0b1000_0000; + let msb: bool = (d & MSB_MASK) == MSB_MASK; + let value = d & !MSB_MASK; if msb { match value { BSS_MEMBERSHIP_SELECTOR_SAE_HASH => Self::SelectorSaeHash, @@ -246,34 +251,35 @@ impl From for Nl80211RateAndSelector { BSS_MEMBERSHIP_SELECTOR_GLK => Self::SelectorGlk, BSS_MEMBERSHIP_SELECTOR_VHT_PHY => Self::SelectorVht, BSS_MEMBERSHIP_SELECTOR_HT_PHY => Self::SelectorHt, - _ => Self::BssBasicRateSet(value / 2), + _ => Self::BssBasicRateSet(value), } } else { - Self::Rate(value / 2) + Self::Rate(value) } } } impl From for u8 { fn from(v: Nl80211RateAndSelector) -> u8 { + const MSB: u8 = 0b1000_0000; match v { - Nl80211RateAndSelector::BssBasicRateSet(r) => (r * 2) & 1 << 7, + Nl80211RateAndSelector::BssBasicRateSet(r) => r & !MSB | MSB, Nl80211RateAndSelector::SelectorHt => { - BSS_MEMBERSHIP_SELECTOR_HT_PHY & 1 << 7 + BSS_MEMBERSHIP_SELECTOR_HT_PHY | MSB } Nl80211RateAndSelector::SelectorVht => { - BSS_MEMBERSHIP_SELECTOR_VHT_PHY & 1 << 7 + BSS_MEMBERSHIP_SELECTOR_VHT_PHY | MSB } Nl80211RateAndSelector::SelectorGlk => { - BSS_MEMBERSHIP_SELECTOR_GLK & 1 << 7 + BSS_MEMBERSHIP_SELECTOR_GLK | MSB } Nl80211RateAndSelector::SelectorEpd => { - BSS_MEMBERSHIP_SELECTOR_EPD & 1 << 7 + BSS_MEMBERSHIP_SELECTOR_EPD | MSB } Nl80211RateAndSelector::SelectorSaeHash => { - BSS_MEMBERSHIP_SELECTOR_SAE_HASH & 1 << 7 + BSS_MEMBERSHIP_SELECTOR_SAE_HASH | MSB } - Nl80211RateAndSelector::Rate(r) => r * 2, + Nl80211RateAndSelector::Rate(r) => r, } } } @@ -333,7 +339,7 @@ impl Emitable for Nl80211ElementCountry { buffer[0] = self.country.as_bytes()[0]; buffer[1] = self.country.as_bytes()[1]; } - buffer[3] = self.environment.into(); + buffer[2] = self.environment.into(); for (i, triplet) in self.triplets.as_slice().iter().enumerate() { triplet.emit(&mut buffer[(i + 1) * 3..(i + 2) * 3]); } @@ -494,6 +500,11 @@ pub struct Nl80211ElementRsn { impl Nl80211ElementRsn { pub fn parse(payload: &[u8]) -> Result { + let wrong_buffer_len = || { + DecodeError::from(format!( + "Invalid buffer length for Nl80211ElementRsn, got {payload:?}" + )) + }; if payload.len() != 2 && payload.len() < 8 { return Err(format!( "Invalid buffer length of Nl80211ElementRsn, \ @@ -513,83 +524,83 @@ impl Nl80211ElementRsn { } ret.group_cipher = Some(Nl80211CipherSuite::parse( - &payload[offset..offset + Nl80211CipherSuite::LENGTH], + payload + .get(offset..offset + Nl80211CipherSuite::LENGTH) + .ok_or_else(wrong_buffer_len)?, )?); offset += Nl80211CipherSuite::LENGTH; - - if offset >= payload.len() || offset + 2 >= payload.len() { - return Ok(ret); - } - let pairwise_cipher_count = - u16::from_le_bytes([payload[offset], payload[offset + 1]]) as usize; - offset += 2; if offset >= payload.len() { return Ok(ret); } + let pairwise_cipher_count = parse_u16_le( + payload + .get(offset..offset + 2) + .ok_or_else(wrong_buffer_len)?, + )? as usize; + offset += 2; for _ in 0..pairwise_cipher_count { - if offset + Nl80211CipherSuite::LENGTH >= payload.len() { - return Ok(ret); - } ret.pairwise_ciphers.push(Nl80211CipherSuite::parse( - &payload[offset..offset + Nl80211CipherSuite::LENGTH], + payload + .get(offset..offset + Nl80211CipherSuite::LENGTH) + .ok_or_else(wrong_buffer_len)?, )?); offset += Nl80211CipherSuite::LENGTH; } - - if offset >= payload.len() || offset + 2 >= payload.len() { - return Ok(ret); - } - let akm_count = - u16::from_le_bytes([payload[offset], payload[offset + 1]]) as usize; - offset += 2; if offset >= payload.len() { return Ok(ret); } + + let akm_count = parse_u16_le( + payload + .get(offset..offset + 2) + .ok_or_else(wrong_buffer_len)?, + )? as usize; + offset += 2; for _ in 0..akm_count { - if offset + Nl80211AkmSuite::LENGTH >= payload.len() { - return Ok(ret); - } ret.akm_suits.push(Nl80211AkmSuite::parse( - &payload[offset..offset + Nl80211AkmSuite::LENGTH], + payload + .get(offset..offset + Nl80211AkmSuite::LENGTH) + .ok_or_else(wrong_buffer_len)?, )?); offset += Nl80211AkmSuite::LENGTH; } - if offset >= payload.len() || offset + 2 >= payload.len() { + if offset >= payload.len() { return Ok(ret); } - ret.rsn_capbilities = - Some(Nl80211RsnCapbilities::parse(&payload[offset..offset + 2])?); - offset += 2; - - if offset >= payload.len() || offset + 2 >= payload.len() { - return Ok(ret); - } - let pmkids_count = - u16::from_le_bytes([payload[offset], payload[offset + 1]]) as usize; + ret.rsn_capbilities = Some(Nl80211RsnCapbilities::parse( + payload + .get(offset..offset + 2) + .ok_or_else(wrong_buffer_len)?, + )?); offset += 2; if offset >= payload.len() { return Ok(ret); } + + let pmkids_count = parse_u16_le( + payload + .get(offset..offset + 2) + .ok_or_else(wrong_buffer_len)?, + )? as usize; + offset += 2; for _ in 0..pmkids_count { - if offset + Nl80211Pmkid::LENGTH >= payload.len() { - return Ok(ret); - } ret.pmkids.push(Nl80211Pmkid::parse( - &payload[offset..offset + Nl80211Pmkid::LENGTH], + payload + .get(offset..offset + Nl80211Pmkid::LENGTH) + .ok_or_else(wrong_buffer_len)?, )?); offset += Nl80211Pmkid::LENGTH; } - - if offset >= payload.len() - || offset + Nl80211CipherSuite::LENGTH >= payload.len() - { + if offset >= payload.len() { return Ok(ret); } ret.group_mgmt_cipher = Some(Nl80211CipherSuite::parse( - &payload[offset..offset + Nl80211CipherSuite::LENGTH], + payload + .get(offset..offset + Nl80211CipherSuite::LENGTH) + .ok_or_else(wrong_buffer_len)?, )?); Ok(ret) @@ -600,56 +611,119 @@ impl Emitable for Nl80211ElementRsn { fn buffer_len(&self) -> usize { // version field let mut len = 2usize; - if self.group_cipher.is_none() { - return len; - } else { - len += Nl80211CipherSuite::LENGTH; - } - if self.pairwise_ciphers.is_empty() { - return len; - } else { - len += 2 + self.pairwise_ciphers.len() * Nl80211CipherSuite::LENGTH; + // If any of the following fields do have some content, the next field + // has to be parsed into bytes even if is 0 or 0 for the length + // field. + let fields_with_content = [ + (self.group_cipher.is_some(), Nl80211CipherSuite::LENGTH), + ( + !self.pairwise_ciphers.is_empty(), + 2 + self.pairwise_ciphers.len() * Nl80211CipherSuite::LENGTH, + ), + ( + !self.akm_suits.is_empty(), + 2 + self.akm_suits.len() * Nl80211AkmSuite::LENGTH, + ), + (self.rsn_capbilities.is_some(), 2), + ( + !self.pmkids.is_empty(), + 2 + self.pmkids.len() * Nl80211Pmkid::LENGTH, + ), + (self.group_mgmt_cipher.is_some(), Nl80211CipherSuite::LENGTH), + ]; + + let mut i = 0; + while fields_with_content[i..] + .iter() + .any(|&(has_data, _)| has_data) + { + len += fields_with_content[i].1; + i += 1; } - if self.akm_suits.is_empty() { - return len; - } else { - len += 2 + self.akm_suits.len() * Nl80211AkmSuite::LENGTH; + len + } + + fn emit(&self, buffer: &mut [u8]) { + let mut position = 0; + + write_u16_le(&mut buffer[position..], self.version); + position += 2; + + // If any of the following fields do have some content, the next field + // has to be parsed into bytes even if is 0 or 0 for the length + // field. + let fields_with_content = [ + self.group_cipher.is_some(), + !self.pairwise_ciphers.is_empty(), + !self.akm_suits.is_empty(), + self.rsn_capbilities.is_some(), + !self.pmkids.is_empty(), + self.group_mgmt_cipher.is_some(), + ]; + let mut fields_with_content = + crate::helper::emit::FieldFlags::new(&fields_with_content); + + if !fields_with_content.should_emit() { + return; } - if self.rsn_capbilities.is_none() { - return len; - } else { - len += 2; + write_u32_le( + &mut buffer[position..], + u32::from(self.group_cipher.unwrap_or_default()), + ); + position += 4; + + if !fields_with_content.should_emit() { + return; + } + write_u16_le( + &mut buffer[position..], + self.pairwise_ciphers.len() as u16, + ); + position += 2; + for cipher in &self.pairwise_ciphers { + write_u32_le(&mut buffer[position..], u32::from(*cipher)); + position += 4; } - if self.pmkids.is_empty() { - return len; - } else { - len += 2 + self.pmkids.len() * Nl80211Pmkid::LENGTH; + if !fields_with_content.should_emit() { + return; } - if self.group_mgmt_cipher.is_none() { - return len; - } else { - len += Nl80211CipherSuite::LENGTH; + write_u16_le(&mut buffer[position..], self.akm_suits.len() as u16); + position += 2; + for akm_suite in &self.akm_suits { + write_u32_le(&mut buffer[position..], u32::from(*akm_suite)); + position += 4; } - len - } + if !fields_with_content.should_emit() { + return; + } + write_u16_le( + &mut buffer[position..], + self.rsn_capbilities.unwrap_or_default().bits(), + ); + position += 2; + + if !fields_with_content.should_emit() { + return; + } + write_u16_le(&mut buffer[6..8], self.pairwise_ciphers.len() as u16); + position += 2; + for pkmid in &self.pmkids { + buffer[position..].copy_from_slice(&pkmid.0); + position += pkmid.0.len(); + } - fn emit(&self, buffer: &mut [u8]) { - write_u16_le(&mut buffer[0..2], self.version); - if let Some(g) = self.group_cipher { - write_u32_le(&mut buffer[2..6], u32::from(g)); - write_u16_le(&mut buffer[6..8], self.pairwise_ciphers.len() as u16); - } - for (i, cipher) in self.pairwise_ciphers.as_slice().iter().enumerate() { - write_u32_le( - &mut buffer[(8 + i * 4)..(12 + i * 4)], - u32::from(*cipher), - ); + if !fields_with_content.should_emit() { + return; } + write_u32_le( + &mut buffer[position..], + u32::from(self.group_mgmt_cipher.unwrap_or_default()), + ); } } diff --git a/src/element/test.rs b/src/element/test.rs new file mode 100644 index 0000000..b4eb555 --- /dev/null +++ b/src/element/test.rs @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT + +use netlink_packet_core::{Emitable, Parseable}; + +use super::{ + Nl80211AkmSuite, Nl80211CipherSuite, Nl80211Element, Nl80211ElementCountry, + Nl80211ElementCountryEnvironment, Nl80211ElementCountryTriplet, + Nl80211ElementRsn, Nl80211ElementSubBand, Nl80211RateAndSelector, +}; + +#[test] +fn ssid() { + let val: Nl80211Element = Nl80211Element::Ssid("test-ssid".to_owned()); + let mut buffer = vec![0; val.buffer_len() + 1]; + val.emit(buffer.as_mut_slice()); + assert_eq!( + ::parse(&buffer[0..val.buffer_len()]).unwrap(), + val, + ); +} + +#[test] +fn rates_and_selectors() { + let val: Nl80211Element = Nl80211Element::SupportedRatesAndSelectors(vec![ + Nl80211RateAndSelector::BssBasicRateSet(1), + Nl80211RateAndSelector::Rate(1), + Nl80211RateAndSelector::SelectorHt, + Nl80211RateAndSelector::SelectorVht, + Nl80211RateAndSelector::SelectorGlk, + ]); + let mut buffer = vec![0; val.buffer_len() + 1]; + val.emit(buffer.as_mut_slice()); + assert_eq!( + ::parse(&buffer[0..val.buffer_len()]).unwrap(), + val, + ); +} + +#[test] +fn channel() { + let val: Nl80211Element = Nl80211Element::Channel(7); + let mut buffer = vec![0; val.buffer_len() + 1]; + val.emit(buffer.as_mut_slice()); + assert_eq!( + ::parse(&buffer[0..val.buffer_len()]).unwrap(), + val, + ); +} + +#[test] +fn country() { + let val: Nl80211Element = Nl80211Element::Country(Nl80211ElementCountry { + country: "DE".to_owned(), + environment: Nl80211ElementCountryEnvironment::IndoorAndOutdoor, + triplets: vec![Nl80211ElementCountryTriplet::Subband( + Nl80211ElementSubBand { + channel_start: 1, + channel_count: 13, + max_power_level: 20, + }, + )], + }); + let mut buffer = vec![0; val.buffer_len() + 1]; + val.emit(buffer.as_mut_slice()); + assert_eq!( + ::parse(&buffer[0..val.buffer_len()]).unwrap(), + val, + ); +} + +#[test] +fn rsn() { + let val: Nl80211Element = Nl80211Element::Rsn(Nl80211ElementRsn { + version: 1, + group_cipher: Some(Nl80211CipherSuite::Ccmp128), + pairwise_ciphers: vec![Nl80211CipherSuite::Ccmp128], + akm_suits: vec![Nl80211AkmSuite::Psk], + rsn_capbilities: None, + pmkids: Vec::new(), + group_mgmt_cipher: None, + }); + let mut buffer = vec![0; val.buffer_len() + 1]; + val.emit(buffer.as_mut_slice()); + assert_eq!( + ::parse(&buffer[0..val.buffer_len()]).unwrap(), + val, + ); +} diff --git a/src/helper.rs b/src/helper.rs new file mode 100644 index 0000000..79096e3 --- /dev/null +++ b/src/helper.rs @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT + +pub(crate) mod emit { + pub(crate) struct FieldFlags<'a> { + flags: &'a [bool], + pos: usize, + } + + impl<'a> FieldFlags<'a> { + pub(crate) fn new(flags: &'a [bool]) -> Self { + Self { flags, pos: 0 } + } + + pub(crate) fn should_emit(&mut self) -> bool { + let emit = self.flags[self.pos..].iter().any(|&f| f); + self.pos += 1; + emit + } + } +} diff --git a/src/lib.rs b/src/lib.rs index ab602b6..5bc83da 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,6 +11,7 @@ mod ext_cap; mod feature; mod frame_type; mod handle; +mod helper; mod iface; mod macros; mod message; diff --git a/src/wifi4.rs b/src/wifi4.rs index 8695dfd..1e4b5e5 100644 --- a/src/wifi4.rs +++ b/src/wifi4.rs @@ -9,6 +9,9 @@ use netlink_packet_core::{ use crate::bytes::{get_bit, get_bits_as_u8, write_u16_le}; +#[cfg(test)] +mod test; + const NL80211_CHAN_NO_HT: u32 = 0; const NL80211_CHAN_HT20: u32 = 1; const NL80211_CHAN_HT40MINUS: u32 = 2; @@ -123,7 +126,7 @@ impl Emitable for Nl80211HtCaps { } fn emit(&self, buffer: &mut [u8]) { - buffer.copy_from_slice(&self.bits().to_ne_bytes()) + buffer[0..self.buffer_len()].copy_from_slice(&self.bits().to_ne_bytes()) } } @@ -469,9 +472,9 @@ impl From for [u8; 2] { fn from(v: Nl80211HtExtendedCap) -> [u8; 2] { [ v.pco as u8 | (v.pco_trans_time << 1) | (v.mcs_feedback & 0b1) << 7, - ((v.mcs_feedback & 0b10) >> 1) - | ((v.support_ht_control as u8) << 1) - | ((v.rd_responder as u8) << 2), + (v.mcs_feedback & 0b11) + | ((v.support_ht_control as u8) << 2) + | ((v.rd_responder as u8) << 3), ] } } @@ -626,7 +629,7 @@ impl Emitable for Nl80211HtTransmitBeamformingCaps { } fn emit(&self, buffer: &mut [u8]) { - buffer.copy_from_slice(&self.bits().to_ne_bytes()) + buffer[0..self.buffer_len()].copy_from_slice(&self.bits().to_ne_bytes()) } } @@ -674,6 +677,6 @@ impl Emitable for Nl80211HtAselCaps { } fn emit(&self, buffer: &mut [u8]) { - buffer.copy_from_slice(&self.bits().to_ne_bytes()) + buffer[0..self.buffer_len()].copy_from_slice(&self.bits().to_ne_bytes()) } } diff --git a/src/wifi4/test.rs b/src/wifi4/test.rs new file mode 100644 index 0000000..3570581 --- /dev/null +++ b/src/wifi4/test.rs @@ -0,0 +1,157 @@ +// SPDX-License-Identifier: MIT + +use super::{ + Emitable, Nl80211ElementHtCap, Nl80211HtAMpduPara, Nl80211HtAselCaps, + Nl80211HtCaps, Nl80211HtExtendedCap, Nl80211HtMcsInfo, + Nl80211HtTransmitBeamformingCaps, Nl80211HtTxParameter, + Nl80211HtWiphyChannelType, Parseable, IEEE80211_HT_MCS_MASK_LEN, + NL80211_CHAN_HT40PLUS, +}; + +#[test] +fn caps() { + let val: Nl80211HtCaps = Nl80211HtCaps::all(); + let mut buffer = vec![0; val.buffer_len() + 1]; + val.emit(buffer.as_mut_slice()); + assert_eq!( + ::parse(&buffer[0..val.buffer_len()]).unwrap(), + val, + ); +} + +#[test] +fn asel_caps() { + let val: Nl80211HtAselCaps = Nl80211HtAselCaps::all(); + let mut buffer = vec![0; val.buffer_len() + 1]; + val.emit(buffer.as_mut_slice()); + assert_eq!( + ::parse(&buffer[0..val.buffer_len()]).unwrap(), + val, + ); +} + +#[test] +fn transmit_beamforming_cap() { + let val: Nl80211HtTransmitBeamformingCaps = + Nl80211HtTransmitBeamformingCaps::all(); + let mut buffer = vec![0; val.buffer_len() + 1]; + val.emit(buffer.as_mut_slice()); + assert_eq!( + ::parse(&buffer[0..val.buffer_len()]) + .unwrap(), + val, + ); +} + +#[test] +fn tx_params() { + let val: Nl80211HtTxParameter = Nl80211HtTxParameter { + mcs_set_defined: false, + tx_rx_mcs_set_not_equal: false, + max_spatial_streams: 1, + unequal_modulation_supported: false, + }; + let into: u8 = val.into(); + assert_eq!(::from(into), val,); +} + +#[test] +fn ht_wiphy_no_ht() { + let val: Nl80211HtWiphyChannelType = Nl80211HtWiphyChannelType::NoHt; + let into: u32 = val.into(); + assert_eq!(::from(into), val,); +} + +#[test] +fn ht_wiphy_ht_20() { + let val: Nl80211HtWiphyChannelType = Nl80211HtWiphyChannelType::Ht20; + let into: u32 = val.into(); + assert_eq!(::from(into), val,); +} + +#[test] +fn ht_wiphy_other() { + let val: Nl80211HtWiphyChannelType = + Nl80211HtWiphyChannelType::Other(NL80211_CHAN_HT40PLUS + 1); + let into: u32 = val.into(); + assert_eq!(::from(into), val,); +} + +#[test] +fn mcs_info() { + let val: Nl80211HtMcsInfo = Nl80211HtMcsInfo { + rx_mask: [0xA5; IEEE80211_HT_MCS_MASK_LEN], + rx_highest: u16::MAX, + tx_params: Nl80211HtTxParameter { + mcs_set_defined: false, + tx_rx_mcs_set_not_equal: false, + max_spatial_streams: 1, + unequal_modulation_supported: false, + }, + }; + let mut buffer = vec![0; val.buffer_len() + 1]; + val.emit(buffer.as_mut_slice()); + assert_eq!( + ::parse(&buffer[0..val.buffer_len()]).unwrap(), + val, + ); +} + +#[test] +fn a_mpdu_para() { + let val: Nl80211HtAMpduPara = Nl80211HtAMpduPara { + max_len_exponent: u8::MAX & 0b11, + min_space: u8::MAX & 0b111, + }; + let into: u8 = val.into(); + assert_eq!(::from(into), val,); +} + +#[test] +fn extend_cap() { + let val: Nl80211HtExtendedCap = Nl80211HtExtendedCap { + pco: true, + pco_trans_time: 1, + mcs_feedback: 1, + support_ht_control: true, + rd_responder: true, + }; + let into: [u8; 2] = val.into(); + assert_eq!(::from(into), val,); +} + +#[test] +fn cap_mask() { + let val: Nl80211ElementHtCap = Nl80211ElementHtCap { + caps: Nl80211HtCaps::all(), + a_mpdu_para: Nl80211HtAMpduPara { + max_len_exponent: 3, + min_space: 7, + }, + mcs_set: Nl80211HtMcsInfo { + rx_mask: [0xA5; IEEE80211_HT_MCS_MASK_LEN], + rx_highest: u16::MAX, + tx_params: Nl80211HtTxParameter { + mcs_set_defined: false, + tx_rx_mcs_set_not_equal: false, + max_spatial_streams: 1, + unequal_modulation_supported: false, + }, + }, + ht_ext_cap: Nl80211HtExtendedCap { + pco: true, + pco_trans_time: 2, + mcs_feedback: 2, + support_ht_control: true, + rd_responder: true, + }, + transmit_beamforming_cap: Nl80211HtTransmitBeamformingCaps::all(), + asel_cap: Nl80211HtAselCaps::all(), + }; + let mut buffer = vec![0; val.buffer_len() + 1]; + val.emit(buffer.as_mut_slice()); + assert_eq!( + ::parse(&buffer[0..val.buffer_len()]).unwrap(), + val, + ); +}