Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
262 changes: 168 additions & 94 deletions src/element.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ use crate::{
Nl80211ElementHeCap, Nl80211ElementHtCap, Nl80211ElementVhtCap,
};

#[cfg(test)]
mod test;

/// [Nl80211Elements] Vec
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Nl80211Elements(pub Vec<Nl80211Element>);
Expand Down Expand Up @@ -214,9 +217,10 @@ const BSS_MEMBERSHIP_SELECTOR_HT_PHY: u8 = 127;
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[non_exhaustive]
pub enum Nl80211RateAndSelector {
/// BSS basic rate set in Mb/s.
/// BSS basic rate in units of 500 kb/s, if necessary rounded up to the
/// next 500 kbs.
BssBasicRateSet(u8),
/// Rate in Mb/s.
/// Rate in units of 500 kb/s, if necessary rounded up to the next 500 kbs.
Rate(u8),
SelectorHt,
SelectorVht,
Expand All @@ -237,43 +241,45 @@ pub enum Nl80211RateAndSelector {

impl From<u8> for Nl80211RateAndSelector {
fn from(d: u8) -> Self {
let msb: bool = (d & 1 << 7) > 0;
let value = d & 0b01111111;
const MSB_MASK: u8 = 0b1000_0000;
let msb: bool = (d & MSB_MASK) == MSB_MASK;
let value = d & !MSB_MASK;
if msb {
match value {
BSS_MEMBERSHIP_SELECTOR_SAE_HASH => Self::SelectorSaeHash,
BSS_MEMBERSHIP_SELECTOR_EPD => Self::SelectorEpd,
BSS_MEMBERSHIP_SELECTOR_GLK => Self::SelectorGlk,
BSS_MEMBERSHIP_SELECTOR_VHT_PHY => Self::SelectorVht,
BSS_MEMBERSHIP_SELECTOR_HT_PHY => Self::SelectorHt,
_ => Self::BssBasicRateSet(value / 2),
_ => Self::BssBasicRateSet(value),
}
} else {
Self::Rate(value / 2)
Self::Rate(value)
}
}
}

impl From<Nl80211RateAndSelector> for u8 {
fn from(v: Nl80211RateAndSelector) -> u8 {
const MSB: u8 = 0b1000_0000;
match v {
Nl80211RateAndSelector::BssBasicRateSet(r) => (r * 2) & 1 << 7,
Nl80211RateAndSelector::BssBasicRateSet(r) => r & !MSB | MSB,
Nl80211RateAndSelector::SelectorHt => {
BSS_MEMBERSHIP_SELECTOR_HT_PHY & 1 << 7
BSS_MEMBERSHIP_SELECTOR_HT_PHY | MSB
}
Nl80211RateAndSelector::SelectorVht => {
BSS_MEMBERSHIP_SELECTOR_VHT_PHY & 1 << 7
BSS_MEMBERSHIP_SELECTOR_VHT_PHY | MSB
}
Nl80211RateAndSelector::SelectorGlk => {
BSS_MEMBERSHIP_SELECTOR_GLK & 1 << 7
BSS_MEMBERSHIP_SELECTOR_GLK | MSB
}
Nl80211RateAndSelector::SelectorEpd => {
BSS_MEMBERSHIP_SELECTOR_EPD & 1 << 7
BSS_MEMBERSHIP_SELECTOR_EPD | MSB
}
Nl80211RateAndSelector::SelectorSaeHash => {
BSS_MEMBERSHIP_SELECTOR_SAE_HASH & 1 << 7
BSS_MEMBERSHIP_SELECTOR_SAE_HASH | MSB
}
Nl80211RateAndSelector::Rate(r) => r * 2,
Nl80211RateAndSelector::Rate(r) => r,
}
}
}
Expand Down Expand Up @@ -333,7 +339,7 @@ impl Emitable for Nl80211ElementCountry {
buffer[0] = self.country.as_bytes()[0];
buffer[1] = self.country.as_bytes()[1];
}
buffer[3] = self.environment.into();
buffer[2] = self.environment.into();
for (i, triplet) in self.triplets.as_slice().iter().enumerate() {
triplet.emit(&mut buffer[(i + 1) * 3..(i + 2) * 3]);
}
Expand Down Expand Up @@ -494,6 +500,11 @@ pub struct Nl80211ElementRsn {

impl Nl80211ElementRsn {
pub fn parse(payload: &[u8]) -> Result<Self, DecodeError> {
let wrong_buffer_len = || {
DecodeError::from(format!(
"Invalid buffer length for Nl80211ElementRsn, got {payload:?}"
))
};
if payload.len() != 2 && payload.len() < 8 {
return Err(format!(
"Invalid buffer length of Nl80211ElementRsn, \
Expand All @@ -513,83 +524,83 @@ impl Nl80211ElementRsn {
}

ret.group_cipher = Some(Nl80211CipherSuite::parse(
&payload[offset..offset + Nl80211CipherSuite::LENGTH],
payload
.get(offset..offset + Nl80211CipherSuite::LENGTH)
.ok_or_else(wrong_buffer_len)?,
)?);
offset += Nl80211CipherSuite::LENGTH;

if offset >= payload.len() || offset + 2 >= payload.len() {
return Ok(ret);
}
let pairwise_cipher_count =
u16::from_le_bytes([payload[offset], payload[offset + 1]]) as usize;
offset += 2;
if offset >= payload.len() {
return Ok(ret);
}

let pairwise_cipher_count = parse_u16_le(
payload
.get(offset..offset + 2)
.ok_or_else(wrong_buffer_len)?,
)? as usize;
offset += 2;
for _ in 0..pairwise_cipher_count {
if offset + Nl80211CipherSuite::LENGTH >= payload.len() {
return Ok(ret);
}
ret.pairwise_ciphers.push(Nl80211CipherSuite::parse(
&payload[offset..offset + Nl80211CipherSuite::LENGTH],
payload
.get(offset..offset + Nl80211CipherSuite::LENGTH)
.ok_or_else(wrong_buffer_len)?,
)?);
offset += Nl80211CipherSuite::LENGTH;
}

if offset >= payload.len() || offset + 2 >= payload.len() {
return Ok(ret);
}
let akm_count =
u16::from_le_bytes([payload[offset], payload[offset + 1]]) as usize;
offset += 2;
if offset >= payload.len() {
return Ok(ret);
}

let akm_count = parse_u16_le(
payload
.get(offset..offset + 2)
.ok_or_else(wrong_buffer_len)?,
)? as usize;
offset += 2;
for _ in 0..akm_count {
if offset + Nl80211AkmSuite::LENGTH >= payload.len() {
return Ok(ret);
}
ret.akm_suits.push(Nl80211AkmSuite::parse(
&payload[offset..offset + Nl80211AkmSuite::LENGTH],
payload
.get(offset..offset + Nl80211AkmSuite::LENGTH)
.ok_or_else(wrong_buffer_len)?,
)?);
offset += Nl80211AkmSuite::LENGTH;
}
if offset >= payload.len() || offset + 2 >= payload.len() {
if offset >= payload.len() {
return Ok(ret);
}

ret.rsn_capbilities =
Some(Nl80211RsnCapbilities::parse(&payload[offset..offset + 2])?);
offset += 2;

if offset >= payload.len() || offset + 2 >= payload.len() {
return Ok(ret);
}
let pmkids_count =
u16::from_le_bytes([payload[offset], payload[offset + 1]]) as usize;
ret.rsn_capbilities = Some(Nl80211RsnCapbilities::parse(
payload
.get(offset..offset + 2)
.ok_or_else(wrong_buffer_len)?,
)?);
offset += 2;
if offset >= payload.len() {
return Ok(ret);
}

let pmkids_count = parse_u16_le(
payload
.get(offset..offset + 2)
.ok_or_else(wrong_buffer_len)?,
)? as usize;
offset += 2;
for _ in 0..pmkids_count {
if offset + Nl80211Pmkid::LENGTH >= payload.len() {
return Ok(ret);
}
ret.pmkids.push(Nl80211Pmkid::parse(
&payload[offset..offset + Nl80211Pmkid::LENGTH],
payload
.get(offset..offset + Nl80211Pmkid::LENGTH)
.ok_or_else(wrong_buffer_len)?,
)?);
offset += Nl80211Pmkid::LENGTH;
}

if offset >= payload.len()
|| offset + Nl80211CipherSuite::LENGTH >= payload.len()
{
if offset >= payload.len() {
return Ok(ret);
}

ret.group_mgmt_cipher = Some(Nl80211CipherSuite::parse(
&payload[offset..offset + Nl80211CipherSuite::LENGTH],
payload
.get(offset..offset + Nl80211CipherSuite::LENGTH)
.ok_or_else(wrong_buffer_len)?,
)?);

Ok(ret)
Expand All @@ -600,56 +611,119 @@ impl Emitable for Nl80211ElementRsn {
fn buffer_len(&self) -> usize {
// version field
let mut len = 2usize;
if self.group_cipher.is_none() {
return len;
} else {
len += Nl80211CipherSuite::LENGTH;
}

if self.pairwise_ciphers.is_empty() {
return len;
} else {
len += 2 + self.pairwise_ciphers.len() * Nl80211CipherSuite::LENGTH;
// If any of the following fields do have some content, the next field
// has to be parsed into bytes even if is 0 or 0 for the length
// field.
let fields_with_content = [
(self.group_cipher.is_some(), Nl80211CipherSuite::LENGTH),
(
!self.pairwise_ciphers.is_empty(),
2 + self.pairwise_ciphers.len() * Nl80211CipherSuite::LENGTH,
),
(
!self.akm_suits.is_empty(),
2 + self.akm_suits.len() * Nl80211AkmSuite::LENGTH,
),
(self.rsn_capbilities.is_some(), 2),
(
!self.pmkids.is_empty(),
2 + self.pmkids.len() * Nl80211Pmkid::LENGTH,
),
(self.group_mgmt_cipher.is_some(), Nl80211CipherSuite::LENGTH),
];

let mut i = 0;
while fields_with_content[i..]
.iter()
.any(|&(has_data, _)| has_data)
{
len += fields_with_content[i].1;
i += 1;
}

if self.akm_suits.is_empty() {
return len;
} else {
len += 2 + self.akm_suits.len() * Nl80211AkmSuite::LENGTH;
len
}

fn emit(&self, buffer: &mut [u8]) {
let mut position = 0;

write_u16_le(&mut buffer[position..], self.version);
position += 2;

// If any of the following fields do have some content, the next field
// has to be parsed into bytes even if is 0 or 0 for the length
// field.
let fields_with_content = [
self.group_cipher.is_some(),
!self.pairwise_ciphers.is_empty(),
!self.akm_suits.is_empty(),
self.rsn_capbilities.is_some(),
!self.pmkids.is_empty(),
self.group_mgmt_cipher.is_some(),
];
let mut fields_with_content =
crate::helper::emit::FieldFlags::new(&fields_with_content);

if !fields_with_content.should_emit() {
return;
}

if self.rsn_capbilities.is_none() {
return len;
} else {
len += 2;
write_u32_le(
&mut buffer[position..],
u32::from(self.group_cipher.unwrap_or_default()),
);
position += 4;

if !fields_with_content.should_emit() {
return;
}
write_u16_le(
&mut buffer[position..],
self.pairwise_ciphers.len() as u16,
);
position += 2;
for cipher in &self.pairwise_ciphers {
write_u32_le(&mut buffer[position..], u32::from(*cipher));
position += 4;
}

if self.pmkids.is_empty() {
return len;
} else {
len += 2 + self.pmkids.len() * Nl80211Pmkid::LENGTH;
if !fields_with_content.should_emit() {
return;
}
if self.group_mgmt_cipher.is_none() {
return len;
} else {
len += Nl80211CipherSuite::LENGTH;
write_u16_le(&mut buffer[position..], self.akm_suits.len() as u16);
position += 2;
for akm_suite in &self.akm_suits {
write_u32_le(&mut buffer[position..], u32::from(*akm_suite));
position += 4;
}

len
}
if !fields_with_content.should_emit() {
return;
}
write_u16_le(
&mut buffer[position..],
self.rsn_capbilities.unwrap_or_default().bits(),
);
position += 2;

if !fields_with_content.should_emit() {
return;
}
write_u16_le(&mut buffer[6..8], self.pairwise_ciphers.len() as u16);
position += 2;
for pkmid in &self.pmkids {
buffer[position..].copy_from_slice(&pkmid.0);
position += pkmid.0.len();
}

fn emit(&self, buffer: &mut [u8]) {
write_u16_le(&mut buffer[0..2], self.version);
if let Some(g) = self.group_cipher {
write_u32_le(&mut buffer[2..6], u32::from(g));
write_u16_le(&mut buffer[6..8], self.pairwise_ciphers.len() as u16);
}
for (i, cipher) in self.pairwise_ciphers.as_slice().iter().enumerate() {
write_u32_le(
&mut buffer[(8 + i * 4)..(12 + i * 4)],
u32::from(*cipher),
);
if !fields_with_content.should_emit() {
return;
}
write_u32_le(
&mut buffer[position..],
u32::from(self.group_mgmt_cipher.unwrap_or_default()),
);
}
}

Expand Down
Loading