From 706997304bcddc65d684850b70ad62db700f34a5 Mon Sep 17 00:00:00 2001
From: HaoranYi
Date: Thu, 16 Jan 2025 15:27:54 +0000
Subject: [PATCH] pr: cap max file read size to 64M

---
 accounts-db/src/accounts_hash.rs | 109 ++++++++++++++++++++++++++-----
 1 file changed, 92 insertions(+), 17 deletions(-)

diff --git a/accounts-db/src/accounts_hash.rs b/accounts-db/src/accounts_hash.rs
index b71cd4a1f02896..7cd3e1f96d4a0b 100644
--- a/accounts-db/src/accounts_hash.rs
+++ b/accounts-db/src/accounts_hash.rs
@@ -337,8 +337,9 @@ impl CumulativeHashesFromFiles {
 
     // return the biggest hash data possible that starts at the overall index 'start'
     // start is the index of hashes
+    // The data is returned as a raw byte vec. The caller is responsible for casting the bytes into Hash slices.
     fn get_data(&self, start: usize) -> Arc<Vec<u8>> {
-        let (start, offset) = self.cumulative.find(start);
+        let (start, offset, num_hashes) = self.cumulative.find(start);
         let data_source_index = offset.index[0];
         let mut data = self.readers[data_source_index].lock().unwrap();
 
@@ -347,8 +348,16 @@
         data.seek(SeekFrom::Start(file_offset_in_bytes.try_into().unwrap()))
             .unwrap();
 
-        let mut result_bytes: Vec<u8> = vec![];
-        data.read_to_end(&mut result_bytes).unwrap();
+        #[cfg(test)]
+        const MAX_BUFFER_SIZE: usize = 128; // for testing
+        #[cfg(not(test))]
+        const MAX_BUFFER_SIZE: usize = 64 * 1024 * 1024; // 64MB
+
+        let eof_offset_in_bytes = std::mem::size_of::<Hash>() * num_hashes;
+        let buffer_size = (eof_offset_in_bytes - file_offset_in_bytes).min(MAX_BUFFER_SIZE);
+
+        let mut result_bytes: Vec<u8> = vec![0; buffer_size];
+        data.read_exact(&mut result_bytes).unwrap();
         Arc::new(result_bytes)
     }
 }
@@ -397,11 +406,21 @@
     /// given overall start index 'start'
     /// return ('start', which is the offset into the data source at 'index',
     /// and 'index', which is the data source to use)
-    fn find(&self, start: usize) -> (usize, &CumulativeOffset) {
-        let index = self.find_index(start);
-        let index = &self.cumulative_offsets[index];
+    /// and number of hashes stored in the data source
+    fn find(&self, start: usize) -> (usize, &CumulativeOffset, usize) {
+        let i = self.find_index(start);
+        let index = &self.cumulative_offsets[i];
         let start = start - index.start_offset;
-        (start, index)
+
+        let i_next = i + 1;
+        let next_start_offset = if i_next == self.cumulative_offsets.len() {
+            self.total_count
+        } else {
+            let next = &self.cumulative_offsets[i_next];
+            next.start_offset
+        };
+        let num_hashes = next_start_offset - index.start_offset;
+        (start, index, num_hashes)
     }
 
     // return the biggest slice possible that starts at 'start'
@@ -409,7 +428,7 @@
     where
         U: ExtractSliceFromRawData<'b, T> + 'b,
     {
-        let (start, index) = self.find(start);
+        let (start, index, _) = self.find(start);
         raw.extract(index, start)
     }
 }
@@ -540,7 +559,7 @@ impl AccountsHasher<'_> {
         (num_hashes_per_chunk, levels_hashed, three_level)
     }
 
-    // This function is called at the top lover level to compute the merkle. It
+    // This function is called at the top level to compute the merkle hash. It
     // takes a closure that returns an owned vec of hash data at the leaf level
     // of the merkle tree. The input data for this bottom level are read from a
     // file. For non-leaves nodes, where the input data is already in memory, we
@@ -1470,8 +1489,11 @@ impl From<IncrementalAccountsHash> for SerdeIncrementalAccountsHash {
 #[cfg(test)]
 mod tests {
     use {
-        super::*, crate::accounts_db::DEFAULT_HASH_CALCULATION_PUBKEY_BINS, itertools::Itertools,
-        std::str::FromStr, tempfile::tempdir,
+        super::*,
+        crate::accounts_db::DEFAULT_HASH_CALCULATION_PUBKEY_BINS,
+        itertools::Itertools,
+        std::str::FromStr,
+        tempfile::tempdir,
     };
 
     lazy_static! {
@@ -2284,7 +2306,7 @@
             }],
             total_count: 0,
         };
-        assert_eq!(input.find(0), (0, &input.cumulative_offsets[0]));
+        assert_eq!(input.find(0), (0, &input.cumulative_offsets[0], 0));
 
         let input = CumulativeOffsets {
             cumulative_offsets: vec![
@@ -2297,12 +2319,12 @@
                     start_offset: 2,
                 },
             ],
-            total_count: 0,
+            total_count: 2,
         };
-        assert_eq!(input.find(0), (0, &input.cumulative_offsets[0])); // = first start_offset
-        assert_eq!(input.find(1), (1, &input.cumulative_offsets[0])); // > first start_offset
-        assert_eq!(input.find(2), (0, &input.cumulative_offsets[1])); // = last start_offset
-        assert_eq!(input.find(3), (1, &input.cumulative_offsets[1])); // > last start_offset
+        assert_eq!(input.find(0), (0, &input.cumulative_offsets[0], 2)); // = first start_offset
+        assert_eq!(input.find(1), (1, &input.cumulative_offsets[0], 2)); // > first start_offset
+        assert_eq!(input.find(2), (0, &input.cumulative_offsets[1], 0)); // = last start_offset
+        assert_eq!(input.find(3), (1, &input.cumulative_offsets[1], 0)); // > last start_offset
     }
 
     #[test]
@@ -2608,4 +2630,57 @@
             2, // accounts above are in 2 groups
         );
     }
+
+    #[test]
+    fn test_get_data() {
+        // Create a temporary directory for test files
+        let temp_dir = tempdir().unwrap();
+
+        const MAX_BUFFER_SIZE: usize = 128; // must match get_data's test-mode MAX_BUFFER_SIZE
+        let extra_size = 64;
+
+        // Create a test file and write some data to it
+        let file_path = temp_dir.path().join("test_file");
+        let mut file = File::create(&file_path).unwrap();
+        let test_data: Vec<u8> = (0..(MAX_BUFFER_SIZE + extra_size) as u8).collect(); // 128 + 64 bytes of test data
+        file.write_all(&test_data).unwrap();
+        drop(file);
+
+        // Create a BufReader for the test file
+        let file = File::open(&file_path).unwrap();
+        let reader = BufReader::new(file);
+        let readers = vec![Mutex::new(reader)];
+
+        // Create a CumulativeOffsets instance
+        let cumulative_offsets = CumulativeOffsets {
+            cumulative_offsets: vec![CumulativeOffset {
+                index: [0, 0],
+                start_offset: 0,
+            }],
+            total_count: test_data.len() / std::mem::size_of::<Hash>(),
+        };
+
+        // Create a CumulativeHashesFromFiles instance
+        let cumulative_hashes = CumulativeHashesFromFiles {
+            readers,
+            cumulative: cumulative_offsets,
+        };
+
+        // Test the get_data function
+        // First read: capped at MAX_BUFFER_SIZE (128) bytes
+        let start_index = 0;
+        let result = cumulative_hashes.get_data(start_index);
+        assert_eq!(result.len(), MAX_BUFFER_SIZE);
+        assert_eq!(&test_data[..result.len()], &result[..]);
+
+        // Second read: the remaining 64 bytes
+        let start_index = 4;
+        let hash_size = std::mem::size_of::<Hash>();
+        let result = cumulative_hashes.get_data(start_index);
+        assert_eq!(result.len(), extra_size);
+        assert_eq!(
+            &test_data[start_index * hash_size..start_index * hash_size + result.len()],
+            &result[..]
+        );
+    }
 }
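
Note: the core change above replaces an unbounded read_to_end with a read_exact into a
buffer sized as (eof_offset_in_bytes - file_offset_in_bytes).min(MAX_BUFFER_SIZE);
callers pick up the remainder by calling get_data again with an advanced start index.
Below is a minimal self-contained sketch of that capped-read loop. It is illustrative
only: read_capped, demo.bin, the 300-byte file, and the 128-byte cap are invented for
the example (the patch itself uses a 64MB cap outside tests).

use std::{
    fs::File,
    io::{BufReader, Read, Seek, SeekFrom, Write},
};

// Illustrative cap; the patch uses 64 * 1024 * 1024 bytes outside of tests.
const MAX_BUFFER_SIZE: usize = 128;

// Mirrors the capped read in get_data: size the buffer from the bytes
// remaining before EOF, cap it, then read_exact into it.
fn read_capped(
    reader: &mut BufReader<File>,
    offset: usize,
    eof: usize,
) -> std::io::Result<Vec<u8>> {
    reader.seek(SeekFrom::Start(offset as u64))?;
    let buffer_size = (eof - offset).min(MAX_BUFFER_SIZE);
    let mut buf = vec![0u8; buffer_size];
    reader.read_exact(&mut buf)?;
    Ok(buf)
}

fn main() -> std::io::Result<()> {
    // Write 300 bytes of sample data to a temporary file.
    let path = std::env::temp_dir().join("demo.bin");
    let data: Vec<u8> = (0u16..300).map(|b| b as u8).collect();
    File::create(&path)?.write_all(&data)?;

    // The caller advances the offset by each chunk's length, the same way
    // callers of get_data advance the overall hash index between calls.
    let mut reader = BufReader::new(File::open(&path)?);
    let mut offset = 0;
    while offset < data.len() {
        let chunk = read_capped(&mut reader, offset, data.len())?;
        assert_eq!(&chunk[..], &data[offset..offset + chunk.len()]);
        offset += chunk.len(); // reads 128, 128, then 44 bytes
    }
    Ok(())
}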