diff --git a/accounts-db/src/accounts_hash.rs b/accounts-db/src/accounts_hash.rs index e735f07961aa0c..61db8b06b7e52d 100644 --- a/accounts-db/src/accounts_hash.rs +++ b/accounts-db/src/accounts_hash.rs @@ -337,28 +337,32 @@ impl CumulativeHashesFromFiles { // return the biggest hash data possible that starts at the overall index 'start' // start is the index of hashes - // The data is returned as raw bytes vec. The caller is responsible for casting the byte slice into Hash slices. - fn get_data(&self, start: usize) -> Arc> { + fn get_data(&self, start: usize) -> Box<[Hash]> { let (start, offset, num_hashes) = self.cumulative.find(start); let data_source_index = offset.index[0]; let mut data = self.readers[data_source_index].lock().unwrap(); - // unwrap here because we should never ask for data that doesn't exist. If we do, then cumulative calculated incorrectly. + // unwrap here because we should never ask for data that doesn't exist. + // If we do, then cumulative calculated incorrectly. let file_offset_in_bytes = std::mem::size_of::() * start; data.seek(SeekFrom::Start(file_offset_in_bytes.try_into().unwrap())) .unwrap(); #[cfg(test)] - const MAX_BUFFER_SIZE: usize = 128; // for testing + const MAX_BUFFER_SIZE_IN_HASH: usize = 4; // 4 hashes (total 128 bytes) for testing #[cfg(not(test))] - const MAX_BUFFER_SIZE: usize = 64 * 1024 * 1024; // 64MB + const MAX_BUFFER_SIZE_IN_HASH: usize = 2 * 1024 * 1024; // 2M hashes (total 64MB bytes) - let eof_offset_in_bytes = std::mem::size_of::() * num_hashes; - let buffer_size = (eof_offset_in_bytes - file_offset_in_bytes).min(MAX_BUFFER_SIZE); + let remaining_num_hashes = num_hashes - start; + let num_hashes_to_read = remaining_num_hashes.min(MAX_BUFFER_SIZE_IN_HASH); + let mut hashes = vec![Hash::default(); num_hashes_to_read].into_boxed_slice(); - let mut result_bytes: Vec = vec![0; buffer_size]; - data.read_exact(&mut result_bytes).unwrap(); - Arc::new(result_bytes) + // unwrap here because the slice that we are reading into is guaranteed + // to be the correct size. + data.read_exact(bytemuck::must_cast_slice_mut(hashes.as_mut())) + .unwrap(); + + hashes } } @@ -376,15 +380,12 @@ impl AsHashSlice for &[Hash] { } } -impl AsHashSlice for Arc> { +impl AsHashSlice for Arc> { fn num_hashes(&self) -> usize { - self.len() / std::mem::size_of::() + self.len() } fn get(&self, i: usize) -> &Hash { - let start = i * std::mem::size_of::(); - let end = start + std::mem::size_of::(); - let bytes = &self[start..end]; - unsafe { &*(bytes.as_ptr() as *const Hash) } + &self[i] } } @@ -1200,7 +1201,7 @@ impl AccountsHasher<'_> { cumulative.total_count(), MERKLE_FANOUT, None, - |start| cumulative.get_data(start), + |start| Arc::new(cumulative.get_data(start)), None, ); hash_time.stop(); @@ -2485,17 +2486,19 @@ mod tests { // Create a temporary directory for test files let temp_dir = tempdir().unwrap(); - const MAX_BUFFER_SIZE: usize = 128; + const MAX_BUFFER_SIZE_IN_BYTES: usize = 128; let extra_size = 64; // Create a test file and write some data to it let file_path = temp_dir.path().join("test_file"); let mut file = File::create(&file_path).unwrap(); - let test_data: Vec = (0..(MAX_BUFFER_SIZE + extra_size) as u8).collect(); // 128 + 64 bytes of test data + let test_data: Vec = (0..(MAX_BUFFER_SIZE_IN_BYTES + extra_size) as u8).collect(); // 128 + 64 bytes of test data file.write_all(&test_data).unwrap(); file.seek(SeekFrom::Start(0)).unwrap(); drop(file); + let test_data: &[Hash] = bytemuck::cast_slice(&test_data); + // Create a BufReader for the test file let file = File::open(&file_path).unwrap(); let reader = BufReader::new(file); @@ -2507,7 +2510,7 @@ mod tests { index: [0, 0], start_offset: 0, }], - total_count: test_data.len() / std::mem::size_of::(), + total_count: test_data.len(), }; // Create a CumulativeHashesFromFiles instance @@ -2517,18 +2520,21 @@ mod tests { }; // Test get_data function - // First read MAX_BUFFER_SIZE 128 bytes + // First read MAX_BUFFER_SIZE 128 bytes (4 hashes) let start_index = 0; let result = cumulative_hashes.get_data(start_index); - assert_eq!(result.len(), MAX_BUFFER_SIZE); + assert_eq!( + result.len(), + MAX_BUFFER_SIZE_IN_BYTES / std::mem::size_of::() + ); assert_eq!(&test_data[..result.len()], &result[..]); - // Second read 64 bytes - let start_index = 4; + // Second read extra 64 bytes (2 hashes) + let start_index = result.len(); let result = cumulative_hashes.get_data(start_index); - assert_eq!(result.len(), extra_size); + assert_eq!(result.len(), extra_size / std::mem::size_of::()); assert_eq!( - &test_data[start_index * 32..start_index * 32 + result.len()], + &test_data[start_index..start_index + result.len()], &result[..] ); }