From cfb97316e8543d3fdcb1631ece5b1d889d37657f Mon Sep 17 00:00:00 2001 From: 0xripleys <0xripleys@solend.fi> Date: Fri, 12 Jan 2024 02:22:25 -0500 Subject: [PATCH] more expressive ValueCmp filter --- rpc-client-api/src/filter.rs | 264 +++++++++++++++++++++++++---------- 1 file changed, 192 insertions(+), 72 deletions(-) diff --git a/rpc-client-api/src/filter.rs b/rpc-client-api/src/filter.rs index a14e0555f16533..f72d05e3b9f813 100644 --- a/rpc-client-api/src/filter.rs +++ b/rpc-client-api/src/filter.rs @@ -76,13 +76,7 @@ impl RpcFilterType { } } } - RpcFilterType::ValueCmp(compare) => { - if compare.num_bytes > 8 { - Err(RpcFilterError::DataTooLarge) - } else { - Ok(()) - } - } + RpcFilterType::ValueCmp(_) => Ok(()), RpcFilterType::TokenAccountState => Ok(()), } } @@ -91,7 +85,9 @@ impl RpcFilterType { match self { RpcFilterType::DataSize(size) => account.data().len() as u64 == *size, RpcFilterType::Memcmp(compare) => compare.bytes_match(account.data()), - RpcFilterType::ValueCmp(compare) => compare.values_match(account.data()), + RpcFilterType::ValueCmp(compare) => { + compare.values_match(account.data()).unwrap_or(false) + } RpcFilterType::TokenAccountState => Account::valid_account_data(account.data()), } } @@ -117,6 +113,8 @@ pub enum RpcFilterError { Base58DecodeError(#[from] bs58::decode::Error), #[error("base64 decode error")] Base64DecodeError(#[from] base64::DecodeError), + #[error("invalid filter")] + InvalidFilter, } #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] @@ -231,54 +229,154 @@ impl Memcmp { #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct ValueCmp { - pub offset: usize, - pub num_bytes: u8, - pub value: u64, - cmp_type: ValueCmpType, - endian: EndianType, + pub left: Operand, + comparator: Comparator, + pub right: Operand, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub enum Operand { + Mem { + offset: usize, + value_type: ValueType, + }, + Constant(String), +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub enum ValueType { + U8, + U16, + U32, + U64, + U128, +} + +enum WrappedValueType { + U8(u8), + U16(u16), + U32(u32), + U64(u64), + U128(u128), } impl ValueCmp { - pub fn values_match(&self, data: &[u8]) -> bool { - if self.offset > data.len() { - return false; - } - if data[self.offset..].len() < self.num_bytes as usize { - return false; + fn parse_mem_into_value_type( + o: &Operand, + data: &[u8], + ) -> Result { + match o { + Operand::Mem { offset, value_type } => match value_type { + ValueType::U8 => { + if *offset >= data.len() { + return Err(RpcFilterError::InvalidFilter); + } + + Ok(WrappedValueType::U8(data[*offset])) + } + ValueType::U16 => { + if *offset + 1 >= data.len() { + return Err(RpcFilterError::InvalidFilter); + } + Ok(WrappedValueType::U16(u16::from_le_bytes( + data[*offset..*offset + 2].try_into().unwrap(), + ))) + } + ValueType::U32 => { + if *offset + 3 >= data.len() { + return Err(RpcFilterError::InvalidFilter); + } + Ok(WrappedValueType::U32(u32::from_le_bytes( + data[*offset..*offset + 4].try_into().unwrap(), + ))) + } + ValueType::U64 => { + if *offset + 7 >= data.len() { + return Err(RpcFilterError::InvalidFilter); + } + Ok(WrappedValueType::U64(u64::from_le_bytes( + data[*offset..*offset + 8].try_into().unwrap(), + ))) + } + ValueType::U128 => { + if *offset + 15 >= data.len() { + return Err(RpcFilterError::InvalidFilter); + } + Ok(WrappedValueType::U128(u128::from_le_bytes( + data[*offset..*offset + 16].try_into().unwrap(), + ))) + } + }, + _ => Err(RpcFilterError::InvalidFilter), } - let bytes = &data[self.offset..self.offset + self.num_bytes as usize]; + } + + pub fn values_match(&self, data: &[u8]) -> Result { + match (&self.left, &self.right) { + (left @ Operand::Mem { .. }, right @ Operand::Mem { .. }) => { + let left = Self::parse_mem_into_value_type(left, data)?; + let right = Self::parse_mem_into_value_type(right, data)?; - let mut padded_bytes = [0u8; 8]; - let value = match self.endian { - EndianType::Big => { - padded_bytes[8 - self.num_bytes as usize..].copy_from_slice(bytes); - u64::from_be_bytes(padded_bytes) + match (left, right) { + (WrappedValueType::U8(left), WrappedValueType::U8(right)) => { + Ok(self.comparator.compare(left, right)) + } + (WrappedValueType::U16(left), WrappedValueType::U16(right)) => { + Ok(self.comparator.compare(left, right)) + } + (WrappedValueType::U32(left), WrappedValueType::U32(right)) => { + Ok(self.comparator.compare(left, right)) + } + (WrappedValueType::U64(left), WrappedValueType::U64(right)) => { + Ok(self.comparator.compare(left, right)) + } + (WrappedValueType::U128(left), WrappedValueType::U128(right)) => { + Ok(self.comparator.compare(left, right)) + } + _ => Err(RpcFilterError::InvalidFilter), + } } - EndianType::Little => { - padded_bytes[..self.num_bytes as usize].copy_from_slice(bytes); - u64::from_le_bytes(padded_bytes) + (left @ Operand::Mem { .. }, Operand::Constant(constant)) => { + match Self::parse_mem_into_value_type(left, data)? { + WrappedValueType::U8(left) => { + let right = constant + .parse::() + .map_err(|_| RpcFilterError::InvalidFilter)?; + Ok(self.comparator.compare(left, right)) + } + WrappedValueType::U16(left) => { + let right = constant + .parse::() + .map_err(|_| RpcFilterError::InvalidFilter)?; + Ok(self.comparator.compare(left, right)) + } + WrappedValueType::U32(left) => { + let right = constant + .parse::() + .map_err(|_| RpcFilterError::InvalidFilter)?; + Ok(self.comparator.compare(left, right)) + } + WrappedValueType::U64(left) => { + let right = constant + .parse::() + .map_err(|_| RpcFilterError::InvalidFilter)?; + Ok(self.comparator.compare(left, right)) + } + WrappedValueType::U128(left) => { + let right = constant + .parse::() + .map_err(|_| RpcFilterError::InvalidFilter)?; + Ok(self.comparator.compare(left, right)) + } + } } - }; - - match self.cmp_type { - ValueCmpType::Eq => value == self.value, - ValueCmpType::Ne => value != self.value, - ValueCmpType::Gt => value > self.value, - ValueCmpType::Ge => value >= self.value, - ValueCmpType::Lt => value < self.value, - ValueCmpType::Le => value <= self.value, + _ => Err(RpcFilterError::InvalidFilter), } } } #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] -pub enum EndianType { - Big = 0, - Little, -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] -pub enum ValueCmpType { +pub enum Comparator { Eq = 0, Ne, Gt, @@ -287,6 +385,20 @@ pub enum ValueCmpType { Le, } +impl Comparator { + // write a generic function to compare two values + pub fn compare(&self, left: T, right: T) -> bool { + match self { + Comparator::Eq => left == right, + Comparator::Ne => left != right, + Comparator::Gt => left > right, + Comparator::Ge => left >= right, + Comparator::Lt => left < right, + Comparator::Le => left <= right, + } + } +} + // Internal struct to hold Memcmp filter data as either encoded String or raw Bytes #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] #[serde(untagged)] @@ -467,44 +579,52 @@ mod tests { fn test_values_match() { // test all the ValueCmp cases let data = vec![1, 2, 3, 4, 5]; + + let filter = ValueCmp { + left: Operand::Mem { + offset: 1, + value_type: ValueType::U8, + }, + comparator: Comparator::Eq, + right: Operand::Constant("2".to_string()), + }; + assert!(ValueCmp { - offset: 0, - num_bytes: 4, - cmp_type: ValueCmpType::Gt, - value: 4, - endian: EndianType::Little, + left: Operand::Mem { + offset: 1, + value_type: ValueType::U8 + }, + comparator: Comparator::Eq, + right: Operand::Constant("2".to_string()) } - .values_match(data[0..4].as_ref())); + .values_match(&data) + .unwrap()); assert!(ValueCmp { - offset: 0, - num_bytes: 4, - cmp_type: ValueCmpType::Eq, - value: 67305985, - endian: EndianType::Little, + left: Operand::Mem { + offset: 1, + value_type: ValueType::U8 + }, + comparator: Comparator::Lt, + right: Operand::Constant("3".to_string()) } - .values_match(data[0..4].as_ref())); + .values_match(&data) + .unwrap()); assert!(ValueCmp { - offset: 0, - num_bytes: 2, - cmp_type: ValueCmpType::Eq, - value: 515, - endian: EndianType::Big, + left: Operand::Mem { + offset: 0, + value_type: ValueType::U32 + }, + comparator: Comparator::Eq, + right: Operand::Constant("67305985".to_string()) } - .values_match(data[1..3].as_ref())); - - let filter = ValueCmp { - offset: 0, - num_bytes: 2, - cmp_type: ValueCmpType::Eq, - value: 515, - endian: EndianType::Big, - }; + .values_match(&data) + .unwrap()); // serialize - // let s = serde_json::to_string(&filter).unwrap(); - // println!("{}", s); + let s = serde_json::to_string(&filter).unwrap(); + println!("{}", s); } #[test]