From 189b4d0d43c75b5e3ab9031ef0008c7ba9ae9378 Mon Sep 17 00:00:00 2001 From: Andrew Pan <3821575+tnytown@users.noreply.github.com> Date: Thu, 4 Jan 2024 12:41:09 -0600 Subject: [PATCH] tls_codec: implement `U24` (#1284) --- tls_codec/src/lib.rs | 50 +++++++++++++++++++++++++++++++-- tls_codec/src/primitives.rs | 5 ++-- tls_codec/src/tls_vec.rs | 23 +++++++++------ tls_codec/tests/decode.rs | 7 +++-- tls_codec/tests/decode_bytes.rs | 11 +++++++- tls_codec/tests/encode.rs | 13 +++++++-- tls_codec/tests/encode_bytes.rs | 19 +++++++++++-- 7 files changed, 106 insertions(+), 22 deletions(-) diff --git a/tls_codec/src/lib.rs b/tls_codec/src/lib.rs index 3a01567ee..768f79dd6 100644 --- a/tls_codec/src/lib.rs +++ b/tls_codec/src/lib.rs @@ -41,9 +41,10 @@ mod quic_vec; mod tls_vec; pub use tls_vec::{ - SecretTlsVecU16, SecretTlsVecU32, SecretTlsVecU8, TlsByteSliceU16, TlsByteSliceU32, - TlsByteSliceU8, TlsByteVecU16, TlsByteVecU32, TlsByteVecU8, TlsSliceU16, TlsSliceU32, - TlsSliceU8, TlsVecU16, TlsVecU32, TlsVecU8, + SecretTlsVecU16, SecretTlsVecU24, SecretTlsVecU32, SecretTlsVecU8, TlsByteSliceU16, + TlsByteSliceU24, TlsByteSliceU32, TlsByteSliceU8, TlsByteVecU16, TlsByteVecU24, TlsByteVecU32, + TlsByteVecU8, TlsSliceU16, TlsSliceU24, TlsSliceU32, TlsSliceU8, TlsVecU16, TlsVecU24, + TlsVecU32, TlsVecU8, }; #[cfg(feature = "std")] @@ -226,3 +227,46 @@ pub trait DeserializeBytes: Size { Ok(out) } } + +/// A 3 byte wide unsigned integer type as defined in [RFC 5246]. +/// +/// [RFC 5246]: https://datatracker.ietf.org/doc/html/rfc5246#section-4.4 +#[derive(Copy, Clone, Debug, Default, PartialEq)] +pub struct U24([u8; 3]); + +impl U24 { + pub const MAX: Self = Self([255u8; 3]); + pub const MIN: Self = Self([0u8; 3]); + + pub fn from_be_bytes(bytes: [u8; 3]) -> Self { + U24(bytes) + } + + pub fn to_be_bytes(self) -> [u8; 3] { + self.0 + } +} + +impl From for usize { + fn from(value: U24) -> usize { + const LEN: usize = core::mem::size_of::(); + let mut usize_bytes = [0u8; LEN]; + usize_bytes[LEN - 3..].copy_from_slice(&value.0); + usize::from_be_bytes(usize_bytes) + } +} + +impl TryFrom for U24 { + type Error = Error; + + fn try_from(value: usize) -> Result { + const LEN: usize = core::mem::size_of::(); + // In practice, our usages of this conversion should never be invalid, as the values + // have to come from `TryFrom for usize`. + if value > (1 << 24) - 1 { + Err(Error::LibraryError) + } else { + Ok(U24(value.to_be_bytes()[LEN - 3..].try_into()?)) + } + } +} diff --git a/tls_codec/src/primitives.rs b/tls_codec/src/primitives.rs index f788f6398..2677f6fc9 100644 --- a/tls_codec/src/primitives.rs +++ b/tls_codec/src/primitives.rs @@ -2,7 +2,7 @@ use alloc::vec::Vec; -use crate::{DeserializeBytes, SerializeBytes}; +use crate::{DeserializeBytes, SerializeBytes, U24}; use super::{Deserialize, Error, Serialize, Size}; @@ -115,7 +115,7 @@ macro_rules! impl_unsigned { #[cfg(feature = "std")] #[inline] fn tls_deserialize(bytes: &mut R) -> Result { - let mut x = (0 as $t).to_be_bytes(); + let mut x = <$t>::default().to_be_bytes(); bytes.read_exact(&mut x)?; Ok(<$t>::from_be_bytes(x)) } @@ -187,6 +187,7 @@ macro_rules! impl_unsigned { impl_unsigned!(u8, 1); impl_unsigned!(u16, 2); +impl_unsigned!(U24, 3); impl_unsigned!(u32, 4); impl_unsigned!(u64, 8); diff --git a/tls_codec/src/tls_vec.rs b/tls_codec/src/tls_vec.rs index 7e91bd0b6..80a948543 100644 --- a/tls_codec/src/tls_vec.rs +++ b/tls_codec/src/tls_vec.rs @@ -13,7 +13,7 @@ use serde::ser::SerializeStruct; use std::io::{Read, Write}; use zeroize::Zeroize; -use crate::{Deserialize, DeserializeBytes, Error, Serialize, SerializeBytes, Size}; +use crate::{Deserialize, DeserializeBytes, Error, Serialize, SerializeBytes, Size, U24}; macro_rules! impl_size { ($self:ident, $size:ty, $name:ident, $len_len:literal) => { @@ -42,7 +42,7 @@ macro_rules! impl_byte_deserialize { #[cfg(feature = "std")] #[inline(always)] fn deserialize_bytes(bytes: &mut R) -> Result { - let len = <$size>::tls_deserialize(bytes)? as usize; + let len = <$size>::tls_deserialize(bytes)?.try_into().unwrap(); // When fuzzing we limit the maximum size to allocate. // XXX: We should think about a configurable limit for the allocation // here. @@ -63,7 +63,7 @@ macro_rules! impl_byte_deserialize { #[inline(always)] fn deserialize_bytes_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error> { let (type_len, remainder) = <$size>::tls_deserialize_bytes(bytes)?; - let len = type_len as usize; + let len = type_len.try_into().unwrap(); // When fuzzing we limit the maximum size to allocate. // XXX: We should think about a configurable limit for the allocation // here. @@ -92,7 +92,7 @@ macro_rules! impl_deserialize { let len = <$size>::tls_deserialize(bytes)?; let mut read = len.tls_serialized_len(); let len_len = read; - while (read - len_len) < len as usize { + while (read - len_len) < len.try_into().unwrap() { let element = T::tls_deserialize(bytes)?; read += element.tls_serialized_len(); result.push(element); @@ -110,7 +110,7 @@ macro_rules! impl_deserialize_bytes { let (len, mut remainder) = <$size>::tls_deserialize_bytes(bytes)?; let mut read = len.tls_serialized_len(); let len_len = read; - while (read - len_len) < len as usize { + while (read - len_len) < len.try_into().unwrap() { let (element, next_remainder) = T::tls_deserialize_bytes(remainder)?; remainder = next_remainder; read += element.tls_serialized_len(); @@ -130,7 +130,7 @@ macro_rules! impl_serialize { // large and write it out. let (tls_serialized_len, byte_length) = $self.get_content_lengths()?; - let mut written = <$size as Serialize>::tls_serialize(&(byte_length as $size), writer)?; + let mut written = <$size as Serialize>::tls_serialize(&<$size>::try_from(byte_length).unwrap(), writer)?; // Now serialize the elements for e in $self.as_slice().iter() { @@ -152,7 +152,7 @@ macro_rules! impl_byte_serialize { // large and write it out. let (tls_serialized_len, byte_length) = $self.get_content_lengths()?; - let mut written = <$size as Serialize>::tls_serialize(&(byte_length as $size), writer)?; + let mut written = <$size as Serialize>::tls_serialize(&<$size>::try_from(byte_length).unwrap(), writer)?; // Now serialize the elements written += writer.write($self.as_slice())?; @@ -170,7 +170,7 @@ macro_rules! impl_serialize_common { let tls_serialized_len = $self.tls_serialized_len(); let byte_length = tls_serialized_len - $len_len; - let max_len = <$size>::MAX as usize; + let max_len = <$size>::MAX.try_into().unwrap(); debug_assert!( byte_length <= max_len, "Vector length can't be encoded in the vector length a {} >= {}", @@ -207,7 +207,7 @@ macro_rules! impl_serialize_bytes_bytes { let (tls_serialized_len, byte_length) = $self.get_content_lengths()?; let mut vec = Vec::::with_capacity(tls_serialized_len); - let length_vec = <$size as SerializeBytes>::tls_serialize(&(byte_length as $size))?; + let length_vec = <$size as SerializeBytes>::tls_serialize(&byte_length.try_into().unwrap())?; let mut written = length_vec.len(); vec.extend_from_slice(&length_vec); @@ -885,15 +885,18 @@ macro_rules! impl_tls_byte_vec { impl_public_tls_vec!(u8, TlsVecU8, 1); impl_public_tls_vec!(u16, TlsVecU16, 2); +impl_public_tls_vec!(U24, TlsVecU24, 3); impl_public_tls_vec!(u32, TlsVecU32, 4); impl_tls_byte_vec!(u8, TlsByteVecU8, 1); impl_tls_byte_vec!(u16, TlsByteVecU16, 2); +impl_tls_byte_vec!(U24, TlsByteVecU24, 3); impl_tls_byte_vec!(u32, TlsByteVecU32, 4); // Secrets should be put into these Secret tls vectors as they implement zeroize. impl_secret_tls_vec!(u8, SecretTlsVecU8, 1); impl_secret_tls_vec!(u16, SecretTlsVecU16, 2); +impl_secret_tls_vec!(U24, SecretTlsVecU24, 3); impl_secret_tls_vec!(u32, SecretTlsVecU32, 4); // We also implement shallow serialization for slices @@ -948,6 +951,7 @@ macro_rules! impl_tls_byte_slice { impl_tls_byte_slice!(u8, TlsByteSliceU8, 1); impl_tls_byte_slice!(u16, TlsByteSliceU16, 2); +impl_tls_byte_slice!(U24, TlsByteSliceU24, 3); impl_tls_byte_slice!(u32, TlsByteSliceU32, 4); macro_rules! impl_tls_slice { @@ -1003,6 +1007,7 @@ macro_rules! impl_tls_slice { impl_tls_slice!(u8, TlsSliceU8, 1); impl_tls_slice!(u16, TlsSliceU16, 2); +impl_tls_slice!(U24, TlsSliceU24, 3); impl_tls_slice!(u32, TlsSliceU32, 4); impl From for Error { diff --git a/tls_codec/tests/decode.rs b/tls_codec/tests/decode.rs index 38cfca02b..aa7ba6dbb 100644 --- a/tls_codec/tests/decode.rs +++ b/tls_codec/tests/decode.rs @@ -2,7 +2,7 @@ use tls_codec::{ Error, Serialize, Size, TlsByteSliceU16, TlsByteVecU16, TlsByteVecU8, TlsSliceU16, TlsVecU16, - TlsVecU32, TlsVecU8, VLByteSlice, VLBytes, + TlsVecU32, TlsVecU8, VLByteSlice, VLBytes, U24, }; #[test] @@ -41,7 +41,7 @@ fn deserialize_option_bytes() { #[test] fn deserialize_bytes_primitives() { use tls_codec::DeserializeBytes; - let b = &[77u8, 88, 1, 99] as &[u8]; + let b = &[77u8, 88, 1, 99, 1, 0, 73] as &[u8]; let (a, remainder) = u8::tls_deserialize_bytes(b).expect("Unable to tls_deserialize"); assert_eq!(1, a.tls_serialized_len()); @@ -52,6 +52,9 @@ fn deserialize_bytes_primitives() { let (a, remainder) = u16::tls_deserialize_bytes(remainder).expect("Unable to tls_deserialize"); assert_eq!(2, a.tls_serialized_len()); assert_eq!(355, a); + let (a, remainder) = U24::tls_deserialize_bytes(remainder).expect("Unable to tls_deserialize"); + assert_eq!(3, a.tls_serialized_len()); + assert_eq!(U24::try_from(65609usize).unwrap(), a); // It's empty now. assert!(remainder.is_empty()); diff --git a/tls_codec/tests/decode_bytes.rs b/tls_codec/tests/decode_bytes.rs index ff1484cca..67b104805 100644 --- a/tls_codec/tests/decode_bytes.rs +++ b/tls_codec/tests/decode_bytes.rs @@ -1,4 +1,4 @@ -use tls_codec::{DeserializeBytes, TlsByteVecU16, TlsByteVecU32, TlsByteVecU8}; +use tls_codec::{DeserializeBytes, TlsByteVecU16, TlsByteVecU24, TlsByteVecU32, TlsByteVecU8}; #[test] fn deserialize_tls_byte_vec_u8() { @@ -18,6 +18,15 @@ fn deserialize_tls_byte_vec_u16() { assert_eq!(rest, []); } +#[test] +fn deserialize_tls_byte_vec_u24() { + let bytes = [0, 0, 3, 2, 1, 0]; + let (result, rest) = TlsByteVecU24::tls_deserialize_bytes(&bytes).unwrap(); + let expected_result = [2, 1, 0]; + assert_eq!(result.as_slice(), expected_result); + assert_eq!(rest, []); +} + #[test] fn deserialize_tls_byte_vec_u32() { let bytes = [0, 0, 0, 3, 2, 1, 0]; diff --git a/tls_codec/tests/encode.rs b/tls_codec/tests/encode.rs index 3bb91c83c..1c232a2f2 100644 --- a/tls_codec/tests/encode.rs +++ b/tls_codec/tests/encode.rs @@ -1,6 +1,6 @@ #![cfg(feature = "std")] -use tls_codec::{Serialize, TlsVecU16, VLByteSlice, VLBytes}; +use tls_codec::{Serialize, TlsVecU16, TlsVecU24, VLByteSlice, VLBytes, U24}; #[test] fn serialize_primitives() { @@ -8,7 +8,11 @@ fn serialize_primitives() { 77u8.tls_serialize(&mut v).expect("Error encoding u8"); 88u8.tls_serialize(&mut v).expect("Error encoding u8"); 355u16.tls_serialize(&mut v).expect("Error encoding u16"); - let b = [77u8, 88, 1, 99]; + U24::try_from(65609usize) + .unwrap() + .tls_serialize(&mut v) + .expect("Error encoding U24"); + let b = [77u8, 88, 1, 99, 1, 0, 73]; assert_eq!(&b[..], &v[..]); } @@ -19,8 +23,11 @@ fn serialize_tls_vec() { TlsVecU16::::from_slice(&[77, 88, 1, 99]) .tls_serialize(&mut v) .expect("Error encoding u8"); + TlsVecU24::::from_slice(&[255, 42, 73]) + .tls_serialize(&mut v) + .expect("Error encoding u8"); - let b = [1u8, 0, 4, 77, 88, 1, 99]; + let b = [1u8, 0, 4, 77, 88, 1, 99, 0, 0, 3, 255, 42, 73]; assert_eq!(&b[..], &v[..]); } diff --git a/tls_codec/tests/encode_bytes.rs b/tls_codec/tests/encode_bytes.rs index f9b30224b..febc9b15e 100644 --- a/tls_codec/tests/encode_bytes.rs +++ b/tls_codec/tests/encode_bytes.rs @@ -1,4 +1,4 @@ -use tls_codec::{SerializeBytes, TlsByteVecU16, TlsByteVecU32, TlsByteVecU8}; +use tls_codec::{SerializeBytes, TlsByteVecU16, TlsByteVecU24, TlsByteVecU32, TlsByteVecU8, U24}; #[test] fn serialize_primitives() { @@ -6,7 +6,13 @@ fn serialize_primitives() { v.append(&mut 77u8.tls_serialize().expect("Error encoding u8")); v.append(&mut 88u8.tls_serialize().expect("Error encoding u8")); v.append(&mut 355u16.tls_serialize().expect("Error encoding u16")); - let b = [77u8, 88, 1, 99]; + v.append( + &mut U24::try_from(65609usize) + .unwrap() + .tls_serialize() + .expect("Error encoding U24"), + ); + let b = [77u8, 88, 1, 99, 1, 0, 73]; assert_eq!(&b[..], &v[..]); } @@ -59,6 +65,15 @@ fn serialize_tls_byte_vec_u16() { assert_eq!(actual_result, vec![0, 3, 1, 2, 3]); } +#[test] +fn serialize_tls_byte_vec_u24() { + let byte_vec = TlsByteVecU24::from_slice(&[1, 2, 3]); + let actual_result = byte_vec + .tls_serialize() + .expect("Error encoding byte vector"); + assert_eq!(actual_result, vec![0, 0, 3, 1, 2, 3]); +} + #[test] fn serialize_tls_byte_vec_u32() { let byte_vec = TlsByteVecU32::from_slice(&[1, 2, 3]);