diff --git a/accounts-db/src/accounts_hash.rs b/accounts-db/src/accounts_hash.rs
index a3e406472586d7..e735f07961aa0c 100644
--- a/accounts-db/src/accounts_hash.rs
+++ b/accounts-db/src/accounts_hash.rs
@@ -18,7 +18,7 @@ use {
         sysvar::epoch_schedule::EpochSchedule,
     },
     std::{
-        borrow::Borrow,
+        clone,
         convert::TryInto,
         fs::File,
         io::{BufReader, BufWriter, Read, Seek, SeekFrom, Write},
@@ -362,6 +362,39 @@ impl CumulativeHashesFromFiles {
     }
 }
 
+/// A clonable, thread-shareable view over a run of hashes that can be indexed
+/// the same way whether the hashes are stored as `Hash` values or as raw bytes.
+trait AsHashSlice: std::marker::Sync + std::marker::Send + clone::Clone {
+    /// Number of whole hashes available in this view.
+    fn num_hashes(&self) -> usize;
+    /// Returns the hash at index `i`. Panics if `i` is out of range.
+    fn get(&self, i: usize) -> &Hash;
+}
+
+impl AsHashSlice for &[Hash] {
+    fn num_hashes(&self) -> usize {
+        self.len()
+    }
+    fn get(&self, i: usize) -> &Hash {
+        &self[i]
+    }
+}
+
+// A shared byte buffer holding hashes packed back to back.
+impl AsHashSlice for Arc<Vec<u8>> {
+    fn num_hashes(&self) -> usize {
+        // integer division: trailing bytes that do not form a whole hash are not counted
+        self.len() / std::mem::size_of::<Hash>()
+    }
+    fn get(&self, i: usize) -> &Hash {
+        let start = i * std::mem::size_of::<Hash>();
+        let end = start + std::mem::size_of::<Hash>();
+        // `from_bytes` checks size and alignment before reinterpreting the bytes,
+        // so no unchecked pointer cast (and no `unsafe`) is needed here.
+        bytemuck::from_bytes(&self[start..end])
+    }
+}
+
 impl CumulativeOffsets {
     fn new<I>(iter: I) -> Self
     where
@@ -559,183 +592,9 @@ impl AccountsHasher<'_> {
         (num_hashes_per_chunk, levels_hashed, three_level)
     }
 
-    // This function is called at the top level to compute the merkle hash. It
-    // takes a closure that returns an owned vec of hash data at the leaf level
-    // of the merkle tree. The input data for this bottom level are read from a
-    // file. For non-leaves nodes, where the input data is already in memory, we
-    // will use `compute_merkle_root_from_slices`, which is a version that takes
-    // a borrowed slice of hash data instead.
- fn compute_merkle_root_from_start( - total_hashes: usize, - fanout: usize, - max_levels_per_pass: Option, - get_hash_slice_starting_at_index: F, - specific_level_count: Option, - ) -> (Hash, Vec) - where - // returns a vec hash bytes starting at the given overall index - F: Fn(usize) -> Arc> + std::marker::Sync, - T: AsRef<[u8]> + std::marker::Send + std::marker::Sync + bytemuck::Pod, - { - if total_hashes == 0 { - return (Hasher::default().result(), vec![]); - } - - let mut time = Measure::start("time"); - - let (num_hashes_per_chunk, levels_hashed, three_level) = Self::calculate_three_level_chunks( - total_hashes, - fanout, - max_levels_per_pass, - specific_level_count, - ); - - let chunks = Self::div_ceil(total_hashes, num_hashes_per_chunk); - - // initial fetch - could return entire slice - let data_bytes = get_hash_slice_starting_at_index(0); - let data: &[T] = bytemuck::cast_slice(&data_bytes); - let data_len = data.len(); - - let result: Vec<_> = (0..chunks) - .into_par_iter() - .map(|i| { - // summary: - // this closure computes 1 or 3 levels of merkle tree (all chunks will be 1 or all will be 3) - // for a subset (our chunk) of the input data [start_index..end_index] - - // index into get_hash_slice_starting_at_index where this chunk's range begins - let start_index = i * num_hashes_per_chunk; - // index into get_hash_slice_starting_at_index where this chunk's range ends - let end_index = std::cmp::min(start_index + num_hashes_per_chunk, total_hashes); - - // will compute the final result for this closure - let mut hasher = Hasher::default(); - - // index into 'data' where we are currently pulling data - // if we exhaust our data, then we will request a new slice, and data_index resets to 0, the beginning of the new slice - let mut data_index = start_index; - // source data, which we may refresh when we exhaust - let mut data_bytes = data_bytes.clone(); - let mut data: &[T] = bytemuck::cast_slice(&data_bytes); - // len of the source data - let mut 
data_len = data_len; - - if !three_level { - // 1 group of fanout - // The result of this loop is a single hash value from fanout input hashes. - for i in start_index..end_index { - if data_index >= data_len { - // we exhausted our data, fetch next slice starting at i - data_bytes = get_hash_slice_starting_at_index(i); - data = bytemuck::cast_slice(&data_bytes); - data_len = data.len(); - data_index = 0; - } - hasher.hash(data[data_index].as_ref()); - data_index += 1; - } - } else { - // hash 3 levels of fanout simultaneously. - // This codepath produces 1 hash value for between 1..=fanout^3 input hashes. - // It is equivalent to running the normal merkle tree calculation 3 iterations on the input. - // - // big idea: - // merkle trees usually reduce the input vector by a factor of fanout with each iteration - // example with fanout 2: - // start: [0,1,2,3,4,5,6,7] in our case: [...16M...] or really, 1B - // iteration0 [.5, 2.5, 4.5, 6.5] [... 1M...] - // iteration1 [1.5, 5.5] [...65k...] - // iteration2 3.5 [...4k... ] - // So iteration 0 consumes N elements, hashes them in groups of 'fanout' and produces a vector of N/fanout elements - // and the process repeats until there is only 1 hash left. - // - // With the three_level code path, we make each chunk we iterate of size fanout^3 (4096) - // So, the input could be 16M hashes and the output will be 4k hashes, or N/fanout^3 - // The goal is to reduce the amount of data that has to be constructed and held in memory. - // When we know we have enough hashes, then, in 1 pass, we hash 3 levels simultaneously, storing far fewer intermediate hashes. - // - // Now, some details: - // The result of this loop is a single hash value from fanout^3 input hashes. - // concepts: - // what we're conceptually hashing: "raw_hashes"[start_index..end_index] - // example: [a,b,c,d,e,f] - // but... hashes[] may really be multiple vectors that are pieced together. 
- // example: [[a,b],[c],[d,e,f]] - // get_hash_slice_starting_at_index(any_index) abstracts that and returns a slice starting at raw_hashes[any_index..] - // such that the end of get_hash_slice_starting_at_index may be <, >, or = end_index - // example: get_hash_slice_starting_at_index(1) returns [b] - // get_hash_slice_starting_at_index(3) returns [d,e,f] - // This code is basically 3 iterations of merkle tree hashing occurring simultaneously. - // The first fanout raw hashes are hashed in hasher_k. This is iteration0 - // Once hasher_k has hashed fanout hashes, hasher_k's result hash is hashed in hasher_j and then discarded - // hasher_k then starts over fresh and hashes the next fanout raw hashes. This is iteration0 again for a new set of data. - // Once hasher_j has hashed fanout hashes (from k), hasher_j's result hash is hashed in hasher and then discarded - // Once hasher has hashed fanout hashes (from j), then the result of hasher is the hash for fanout^3 raw hashes. - // If there are < fanout^3 hashes, then this code stops when it runs out of raw hashes and returns whatever it hashed. - // This is always how the very last elements work in a merkle tree. 
- let mut i = start_index; - while i < end_index { - let mut hasher_j = Hasher::default(); - for _j in 0..fanout { - let mut hasher_k = Hasher::default(); - let end = std::cmp::min(end_index - i, fanout); - for _k in 0..end { - if data_index >= data_len { - // we exhausted our data, fetch next slice starting at i - data_bytes = get_hash_slice_starting_at_index(i); - data = bytemuck::cast_slice(&data_bytes); - data_len = data.len(); - data_index = 0; - } - hasher_k.hash(data[data_index].borrow().as_ref()); - data_index += 1; - i += 1; - } - hasher_j.hash(hasher_k.result().as_ref()); - if i >= end_index { - break; - } - } - hasher.hash(hasher_j.result().as_ref()); - } - } - - hasher.result() - }) - .collect(); - time.stop(); - debug!("hashing {} {}", total_hashes, time); - - if let Some(mut specific_level_count_value) = specific_level_count { - specific_level_count_value -= levels_hashed; - if specific_level_count_value == 0 { - (Hash::default(), result) - } else { - assert!(specific_level_count_value > 0); - // We did not hash the number of levels required by 'specific_level_count', so repeat - Self::compute_merkle_root_from_slices_recurse( - result, - fanout, - max_levels_per_pass, - Some(specific_level_count_value), - ) - } - } else { - ( - if result.len() == 1 { - result[0] - } else { - Self::compute_merkle_root_recurse(result, fanout) - }, - vec![], // no intermediate results needed by caller - ) - } - } - // This function is designed to allow hashes to be located in multiple, perhaps multiply deep vecs. // The caller provides a function to return a slice from the source data. 
- fn compute_merkle_root_from_slices<'b, F, T>( + fn compute_merkle_root_from_slices<'b, F, U>( total_hashes: usize, fanout: usize, max_levels_per_pass: Option, @@ -744,8 +596,8 @@ impl AccountsHasher<'_> { ) -> (Hash, Vec) where // returns a slice of hashes starting at the given overall index - F: Fn(usize) -> &'b [T] + std::marker::Sync, - T: Borrow + std::marker::Sync + 'b, + F: Fn(usize) -> U + std::marker::Sync, + U: AsHashSlice + 'b, { if total_hashes == 0 { return (Hasher::default().result(), vec![]); @@ -764,7 +616,7 @@ impl AccountsHasher<'_> { // initial fetch - could return entire slice let data = get_hash_slice_starting_at_index(0); - let data_len = data.len(); + let data_len = data.num_hashes(); let result: Vec<_> = (0..chunks) .into_par_iter() @@ -785,7 +637,7 @@ impl AccountsHasher<'_> { // if we exhaust our data, then we will request a new slice, and data_index resets to 0, the beginning of the new slice let mut data_index = start_index; // source data, which we may refresh when we exhaust - let mut data = data; + let mut data = data.clone(); // len of the source data let mut data_len = data_len; @@ -796,10 +648,10 @@ impl AccountsHasher<'_> { if data_index >= data_len { // we exhausted our data, fetch next slice starting at i data = get_hash_slice_starting_at_index(i); - data_len = data.len(); + data_len = data.num_hashes(); data_index = 0; } - hasher.hash(data[data_index].borrow().as_ref()); + hasher.hash(data.get(data_index).as_ref()); data_index += 1; } } else { @@ -851,10 +703,10 @@ impl AccountsHasher<'_> { if data_index >= data_len { // we exhausted our data, fetch next slice starting at i data = get_hash_slice_starting_at_index(i); - data_len = data.len(); + data_len = data.num_hashes(); data_index = 0; } - hasher_k.hash(data[data_index].borrow().as_ref()); + hasher_k.hash(data.get(data_index).as_ref()); data_index += 1; i += 1; } @@ -1344,7 +1196,7 @@ impl AccountsHasher<'_> { let _guard = 
self.active_stats.activate(ActiveStatItem::HashMerkleTree); let mut hash_time = Measure::start("hash"); - let (hash, _) = Self::compute_merkle_root_from_start::<_, Hash>( + let (hash, _) = Self::compute_merkle_root_from_slices( cumulative.total_count(), MERKLE_FANOUT, None,