Skip to content

Commit

Permalink
Refactor CertValidity to hold ArrayVec instead of &str
Browse files Browse the repository at this point in the history
This will allow platforms which use non-const cert-validities
to return the data correctly, since we cannot copy out a &str from
a function because it lives on the stack.
  • Loading branch information
sree-revoori1 committed Feb 21, 2024
1 parent 5c9f6ed commit e33e311
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 25 deletions.
2 changes: 1 addition & 1 deletion dpe/src/commands/certify_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ impl CommandExecution for CertifyKeyCmd {
&subject_name,
&pub_key,
&measurements,
cert_validity,
&cert_validity,
)?;
if bytes_written > MAX_CERT_SIZE {
return Err(DpeErrorCode::InternalError);
Expand Down
46 changes: 31 additions & 15 deletions dpe/src/x509.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,9 @@ impl CertWriter<'_> {
}

/// If `tagged`, include the tag and size fields
fn get_validity_size(validity: CertValidity<'_>, tagged: bool) -> Result<usize, DpeErrorCode> {
let len = Self::get_bytes_size(validity.not_before.as_bytes(), true)?
+ Self::get_bytes_size(validity.not_after.as_bytes(), true)?;
fn get_validity_size(validity: &CertValidity, tagged: bool) -> Result<usize, DpeErrorCode> {
let len = Self::get_bytes_size(validity.not_before.as_slice(), true)?
+ Self::get_bytes_size(validity.not_after.as_slice(), true)?;
Self::get_structure_size(len, tagged)
}

Expand Down Expand Up @@ -470,7 +470,7 @@ impl CertWriter<'_> {
subject_name: &Name,
pubkey: &EcdsaPub,
measurements: &MeasurementData,
validity: CertValidity<'_>,
validity: &CertValidity,
tagged: bool,
) -> Result<usize, DpeErrorCode> {
let tbs_size = Self::get_version_size(/*tagged=*/ true)?
Expand Down Expand Up @@ -886,19 +886,19 @@ impl CertWriter<'_> {
}

// Encode ASN.1 Validity according to Platform
fn encode_validity(&mut self, validity: CertValidity<'_>) -> Result<usize, DpeErrorCode> {
fn encode_validity(&mut self, validity: &CertValidity) -> Result<usize, DpeErrorCode> {
let seq_size = Self::get_validity_size(validity, /*tagged=*/ false)?;

let mut bytes_written = self.encode_tag_field(Self::SEQUENCE_TAG)?;
bytes_written += self.encode_size_field(seq_size)?;

bytes_written += self.encode_tag_field(Self::GENERALIZE_TIME_TAG)?;
bytes_written += self.encode_size_field(validity.not_before.len())?;
bytes_written += self.encode_bytes(validity.not_before.as_bytes())?;
bytes_written += self.encode_bytes(validity.not_before.as_slice())?;

bytes_written += self.encode_tag_field(Self::GENERALIZE_TIME_TAG)?;
bytes_written += self.encode_size_field(validity.not_after.len())?;
bytes_written += self.encode_bytes(validity.not_after.as_bytes())?;
bytes_written += self.encode_bytes(validity.not_after.as_slice())?;

Ok(bytes_written)
}
Expand Down Expand Up @@ -1647,7 +1647,7 @@ impl CertWriter<'_> {
subject_name: &Name,
pubkey: &EcdsaPub,
measurements: &MeasurementData,
validity: CertValidity<'_>,
validity: &CertValidity,
) -> Result<usize, DpeErrorCode> {
let tbs_size = Self::get_tbs_size(
serial_number,
Expand Down Expand Up @@ -1840,7 +1840,7 @@ mod tests {
use crate::x509::{CertWriter, DirectoryString, MeasurementData, Name};
use crate::DPE_PROFILE;
use crypto::{CryptoBuf, EcdsaPub, EcdsaSig};
use platform::CertValidity;
use platform::{ArrayVec, CertValidity};
use std::str;
use x509_parser::certificate::X509CertificateParser;
use x509_parser::nom::Parser;
Expand Down Expand Up @@ -2086,9 +2086,17 @@ mod tests {
supports_recursive: true,
};

let mut not_before = ArrayVec::new();
not_before
.try_extend_from_slice("20230227000000Z".as_bytes())
.unwrap();
let mut not_after = ArrayVec::new();
not_after
.try_extend_from_slice("99991231235959Z".as_bytes())
.unwrap();
let validity = CertValidity {
not_before: "20230227000000Z",
not_after: "99991231235959Z",
not_before,
not_after,
};

let bytes_written = w
Expand All @@ -2098,7 +2106,7 @@ mod tests {
&test_subject_name,
&test_pub,
&measurements,
validity,
&validity,
)
.unwrap();

Expand Down Expand Up @@ -2152,9 +2160,17 @@ mod tests {
supports_recursive: true,
};

let mut not_before = ArrayVec::new();
not_before
.try_extend_from_slice("20230227000000Z".as_bytes())
.unwrap();
let mut not_after = ArrayVec::new();
not_after
.try_extend_from_slice("99991231235959Z".as_bytes())
.unwrap();
let validity = CertValidity {
not_before: "20230227000000Z",
not_after: "99991231235959Z",
not_before,
not_after,
};

let mut tbs_writer = CertWriter::new(cert_buf, true);
Expand All @@ -2165,7 +2181,7 @@ mod tests {
&TEST_SUBJECT_NAME,
&test_pub,
&measurements,
validity,
&validity,
)
.unwrap();

Expand Down
14 changes: 11 additions & 3 deletions platform/src/default.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,18 @@ impl Platform for DefaultPlatform {
Ok(())
}

fn get_cert_validity<'a>(&mut self) -> Result<CertValidity<'a>, PlatformError> {
fn get_cert_validity(&mut self) -> Result<CertValidity, PlatformError> {
let mut not_before_vec = ArrayVec::new();
not_before_vec
.try_extend_from_slice(NOT_BEFORE.as_bytes())
.map_err(|_| PlatformError::CertValidityError(0))?;
let mut not_after_vec = ArrayVec::new();
not_after_vec
.try_extend_from_slice(NOT_AFTER.as_bytes())
.map_err(|_| PlatformError::CertValidityError(0))?;
Ok(CertValidity {
not_before: NOT_BEFORE,
not_after: NOT_AFTER,
not_before: not_before_vec,
not_after: not_after_vec,
})
}
}
15 changes: 9 additions & 6 deletions platform/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ pub const MAX_CHUNK_SIZE: usize = 2048;
pub const MAX_ISSUER_NAME_SIZE: usize = 128;
pub const MAX_SN_SIZE: usize = 20;
pub const MAX_SKI_SIZE: usize = 20;
pub const MAX_VALIDITY_SIZE: usize = 24;

#[allow(variant_size_differences)]
#[derive(Debug, PartialEq, Eq)]
pub enum SignerIdentifier {
IssuerAndSerialNumber {
issuer_name: ArrayVec<u8, { MAX_ISSUER_NAME_SIZE }>,
Expand All @@ -29,10 +30,10 @@ pub enum SignerIdentifier {
SubjectKeyIdentifier(ArrayVec<u8, { MAX_SKI_SIZE }>),
}

#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct CertValidity<'a> {
pub not_before: &'a str,
pub not_after: &'a str,
#[derive(Debug, PartialEq, Eq)]
pub struct CertValidity {
pub not_before: ArrayVec<u8, { MAX_VALIDITY_SIZE }>,
pub not_after: ArrayVec<u8, { MAX_VALIDITY_SIZE }>,
}

#[derive(Debug, PartialEq, Eq, Clone, Copy)]
Expand All @@ -44,6 +45,7 @@ pub enum PlatformError {
PrintError(u32) = 0x4,
SerialNumberError(u32) = 0x5,
SubjectKeyIdentifierError(u32) = 0x6,
CertValidityError(u32) = 0x7,
}

impl PlatformError {
Expand All @@ -62,6 +64,7 @@ impl PlatformError {
PlatformError::PrintError(code) => Some(*code),
PlatformError::SerialNumberError(code) => Some(*code),
PlatformError::SubjectKeyIdentifierError(code) => Some(*code),
PlatformError::CertValidityError(code) => Some(*code),
}
}
}
Expand Down Expand Up @@ -111,5 +114,5 @@ pub trait Platform {
/// in the yyyyMMddHHmmss format followed by a timezone.
///
/// Example: 99991231235959Z is December 31st, 9999 23:59:59 UTC
fn get_cert_validity<'a>(&mut self) -> Result<CertValidity<'a>, PlatformError>;
fn get_cert_validity(&mut self) -> Result<CertValidity, PlatformError>;
}

0 comments on commit e33e311

Please sign in to comment.