diff --git a/types/src/protocol_params.rs b/types/src/protocol_params.rs index f3ea1f04..149a354a 100644 --- a/types/src/protocol_params.rs +++ b/types/src/protocol_params.rs @@ -3,6 +3,9 @@ use anyhow::anyhow; use bytes::{Buf, BufMut}; use commonware_codec::{EncodeSize, Error, Read, Write}; +pub const MIN_EPOCH_LENGTH: u64 = 10; +pub const MAX_EPOCH_LENGTH: u64 = 1_814_400; + #[derive(Clone, Debug)] pub enum ProtocolParam { MinimumStake(u64), @@ -47,8 +50,15 @@ impl TryFrom for ProtocolParam { } let bytes: [u8; 8] = request.param.as_slice().try_into()?; let epoch_length = u64::from_le_bytes(bytes); - if epoch_length == 0 { - return Err(anyhow!("Epoch length must be nonzero")); + if epoch_length < MIN_EPOCH_LENGTH { + return Err(anyhow!( + "Epoch length {epoch_length} is below minimum {MIN_EPOCH_LENGTH}" + )); + } + if epoch_length > MAX_EPOCH_LENGTH { + return Err(anyhow!( + "Epoch length {epoch_length} exceeds maximum {MAX_EPOCH_LENGTH}" + )); } Ok(ProtocolParam::EpochLength(epoch_length)) } @@ -93,7 +103,15 @@ impl Read for ProtocolParam { match tag { 0x00 => Ok(ProtocolParam::MinimumStake(value)), 0x01 => Ok(ProtocolParam::MaximumStake(value)), - 0x02 => Ok(ProtocolParam::EpochLength(value)), + 0x02 => { + if !(MIN_EPOCH_LENGTH..=MAX_EPOCH_LENGTH).contains(&value) { + return Err(Error::Invalid( + "ProtocolParam", + "epoch length out of bounds", + )); + } + Ok(ProtocolParam::EpochLength(value)) + } _ => Err(Error::Invalid("ProtocolParam", "unknown tag")), } } @@ -341,4 +359,75 @@ mod tests { let result = ProtocolParam::try_from(request); assert!(result.is_err()); } + + #[test] + fn test_try_from_epoch_length_below_minimum() { + let request = ProtocolParamRequest { + param_id: 0x02, + param: (MIN_EPOCH_LENGTH - 1).to_le_bytes().to_vec(), + }; + assert!(ProtocolParam::try_from(request).is_err()); + } + + #[test] + fn test_try_from_epoch_length_at_minimum() { + let request = ProtocolParamRequest { + param_id: 0x02, + param: MIN_EPOCH_LENGTH.to_le_bytes().to_vec(), + }; + let param = ProtocolParam::try_from(request).unwrap(); + match param { + ProtocolParam::EpochLength(v) => assert_eq!(v, MIN_EPOCH_LENGTH), + _ => panic!("Expected EpochLength"), + } + } + + #[test] + fn test_try_from_epoch_length_above_maximum() { + let request = ProtocolParamRequest { + param_id: 0x02, + param: (MAX_EPOCH_LENGTH + 1).to_le_bytes().to_vec(), + }; + assert!(ProtocolParam::try_from(request).is_err()); + } + + #[test] + fn test_try_from_epoch_length_at_maximum() { + let request = ProtocolParamRequest { + param_id: 0x02, + param: MAX_EPOCH_LENGTH.to_le_bytes().to_vec(), + }; + let param = ProtocolParam::try_from(request).unwrap(); + match param { + ProtocolParam::EpochLength(v) => assert_eq!(v, MAX_EPOCH_LENGTH), + _ => panic!("Expected EpochLength"), + } + } + + #[test] + fn test_decode_epoch_length_out_of_bounds() { + // Below minimum + let mut buf = BytesMut::new(); + buf.put_u8(0x02); + buf.put_u64(1); + assert!(ProtocolParam::read(&mut buf.as_ref()).is_err()); + + // Above maximum + let mut buf = BytesMut::new(); + buf.put_u8(0x02); + buf.put_u64(MAX_EPOCH_LENGTH + 1); + assert!(ProtocolParam::read(&mut buf.as_ref()).is_err()); + } + + #[test] + fn test_decode_epoch_length_within_bounds() { + let mut buf = BytesMut::new(); + buf.put_u8(0x02); + buf.put_u64(500); + let param = ProtocolParam::read(&mut buf.as_ref()).unwrap(); + match param { + ProtocolParam::EpochLength(v) => assert_eq!(v, 500), + _ => panic!("Expected EpochLength"), + } + } }