Skip to content

Commit

Permalink
pr: share the hash computation from slice with from vec
Browse files Browse the repository at this point in the history
  • Loading branch information
HaoranYi committed Jan 17, 2025
1 parent aeb53a4 commit fe25c2d
Showing 1 changed file with 37 additions and 185 deletions.
222 changes: 37 additions & 185 deletions accounts-db/src/accounts_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use {
sysvar::epoch_schedule::EpochSchedule,
},
std::{
borrow::Borrow,
clone,
convert::TryInto,
fs::File,
io::{BufReader, BufWriter, Read, Seek, SeekFrom, Write},
Expand Down Expand Up @@ -362,6 +362,32 @@ impl CumulativeHashesFromFiles {
}
}

/// A clonable, thread-shareable run of hashes that can be read by position.
///
/// The merkle code only needs to know how many hashes a source holds and to
/// borrow one of them by index; this trait captures exactly that so different
/// backing stores can be hashed by the same routine.
trait AsHashSlice: clone::Clone + std::marker::Send + std::marker::Sync {
    /// Number of hashes available in this run.
    fn num_hashes(&self) -> usize;

    /// Borrows the hash at position `i`, where `i` counts hashes (not bytes)
    /// from the start of this run.
    fn get(&self, i: usize) -> &Hash;
}

/// Hashes that are already in memory as a borrowed slice.
impl AsHashSlice for &[Hash] {
    /// One hash per element of the underlying slice.
    fn num_hashes(&self) -> usize {
        let hashes: &[Hash] = self;
        hashes.len()
    }

    /// Borrows element `i` of the underlying slice.
    ///
    /// Uses plain slice indexing, so an out-of-range `i` panics.
    fn get(&self, i: usize) -> &Hash {
        let hashes: &[Hash] = self;
        &hashes[i]
    }
}

/// Hashes stored back to back as raw bytes in a shared buffer.
///
/// Hash `i` occupies bytes `i * size_of::<Hash>() .. (i + 1) * size_of::<Hash>()`.
impl AsHashSlice for Arc<Vec<u8>> {
    /// Number of whole hashes in the buffer.
    ///
    /// This is an integer division: trailing bytes that do not make up a
    /// complete hash are not counted.
    fn num_hashes(&self) -> usize {
        self.len() / std::mem::size_of::<Hash>()
    }

    /// Borrows hash `i` in place, without copying it out of the byte buffer.
    ///
    /// # Panics
    ///
    /// Panics (via the range index) if the buffer does not contain a complete
    /// hash at position `i`.
    fn get(&self, i: usize) -> &Hash {
        const HASH_SIZE: usize = std::mem::size_of::<Hash>();
        // A `Vec<u8>` only guarantees byte alignment, so viewing its contents as
        // `Hash` is sound only if `Hash` itself needs no stricter alignment.
        // Enforce that at compile time instead of relying on it silently.
        const _: () = assert!(
            std::mem::align_of::<Hash>() == 1,
            "Hash must have byte alignment to be viewed in place inside a Vec<u8>"
        );

        let start = i * HASH_SIZE;
        // Bounds-checked: panics unless all `HASH_SIZE` bytes are present.
        let bytes = &self[start..start + HASH_SIZE];

        // SAFETY:
        // - `bytes` is exactly `size_of::<Hash>()` initialized bytes, as
        //   guaranteed by the range index above.
        // - `Hash` has alignment 1 (checked by the const assertion), so every
        //   address is suitably aligned.
        // - The returned reference borrows from `self`, so the buffer outlives it.
        // - NOTE(review): this also relies on every bit pattern being a valid
        //   `Hash` (i.e. `Hash` is plain bytes) — confirm against its definition.
        unsafe { &*bytes.as_ptr().cast::<Hash>() }
    }
}

impl CumulativeOffsets {
fn new<I>(iter: I) -> Self
where
Expand Down Expand Up @@ -559,183 +585,9 @@ impl AccountsHasher<'_> {
(num_hashes_per_chunk, levels_hashed, three_level)
}

// This function is called at the top level to compute the merkle hash. It
// takes a closure that returns an owned vec of hash data at the leaf level
// of the merkle tree. The input data for this bottom level are read from a
// file. For non-leaf nodes, where the input data is already in memory, we
// will use `compute_merkle_root_from_slices`, which is a version that takes
// a borrowed slice of hash data instead.
//
// Parameters:
// - `total_hashes`: number of leaf hashes to consume overall.
// - `fanout`, `max_levels_per_pass`, `specific_level_count`: forwarded to
//   `calculate_three_level_chunks` to decide the chunk size and whether each
//   chunk hashes 1 or 3 levels.
// - `get_hash_slice_starting_at_index`: returns a byte buffer whose contents are
//   reinterpreted as `[T]` (via `bytemuck::cast_slice`), starting at the given
//   overall hash index.
//
// `T` does not appear in any argument or in the return type, so callers have to
// name it explicitly.
//
// Returns:
// - `(Hasher::default().result(), vec![])` when `total_hashes == 0`.
// - With `specific_level_count == None`: `(root, vec![])`.
// - With `specific_level_count == Some(n)`: if this pass hashed exactly `n`
//   levels, `(Hash::default(), per_chunk_hashes)`; otherwise whatever
//   `compute_merkle_root_from_slices_recurse` returns for the remaining levels.
//
// NOTE(review): `bytemuck::cast_slice` is expected to panic if a returned buffer's
// length is not a multiple of `size_of::<T>()` or the buffer is misaligned for `T`
// — confirm against the bytemuck documentation.
fn compute_merkle_root_from_start<F, T>(
total_hashes: usize,
fanout: usize,
max_levels_per_pass: Option<usize>,
get_hash_slice_starting_at_index: F,
specific_level_count: Option<usize>,
) -> (Hash, Vec<Hash>)
where
// returns a vec of hash bytes starting at the given overall index
F: Fn(usize) -> Arc<Vec<u8>> + std::marker::Sync,
T: AsRef<[u8]> + std::marker::Send + std::marker::Sync + bytemuck::Pod,
{
// nothing to hash: return the result of a hasher that was fed no input
if total_hashes == 0 {
return (Hasher::default().result(), vec![]);
}

let mut time = Measure::start("time");

let (num_hashes_per_chunk, levels_hashed, three_level) = Self::calculate_three_level_chunks(
total_hashes,
fanout,
max_levels_per_pass,
specific_level_count,
);

// number of chunks needed to cover all hashes; each chunk produces exactly one output hash
let chunks = Self::div_ceil(total_hashes, num_hashes_per_chunk);

// initial fetch - could return entire slice
let data_bytes = get_hash_slice_starting_at_index(0);
let data: &[T] = bytemuck::cast_slice(&data_bytes);
let data_len = data.len();

let result: Vec<_> = (0..chunks)
.into_par_iter()
.map(|i| {
// summary:
// this closure computes 1 or 3 levels of merkle tree (all chunks will be 1 or all will be 3)
// for a subset (our chunk) of the input data [start_index..end_index]

// index into get_hash_slice_starting_at_index where this chunk's range begins
let start_index = i * num_hashes_per_chunk;
// index into get_hash_slice_starting_at_index where this chunk's range ends
let end_index = std::cmp::min(start_index + num_hashes_per_chunk, total_hashes);

// will compute the final result for this closure
let mut hasher = Hasher::default();

// index into 'data' where we are currently pulling data
// if we exhaust our data, then we will request a new slice, and data_index resets to 0, the beginning of the new slice
// it starts at start_index because the initial buffer was fetched at overall index 0
let mut data_index = start_index;
// source data, which we may refresh when we exhaust
// each chunk starts from its own clone of the Arc handle to the initial buffer
let mut data_bytes = data_bytes.clone();
let mut data: &[T] = bytemuck::cast_slice(&data_bytes);
// len of the source data
let mut data_len = data_len;

if !three_level {
// 1 group of fanout
// The result of this loop is a single hash value from fanout input hashes.
// note: this `i` shadows the chunk index and is the overall hash index
for i in start_index..end_index {
if data_index >= data_len {
// we exhausted our data, fetch next slice starting at i
data_bytes = get_hash_slice_starting_at_index(i);
data = bytemuck::cast_slice(&data_bytes);
data_len = data.len();
data_index = 0;
}
hasher.hash(data[data_index].as_ref());
data_index += 1;
}
} else {
// hash 3 levels of fanout simultaneously.
// This codepath produces 1 hash value for between 1..=fanout^3 input hashes.
// It is equivalent to running the normal merkle tree calculation 3 iterations on the input.
//
// big idea:
// merkle trees usually reduce the input vector by a factor of fanout with each iteration
// example with fanout 2:
// start: [0,1,2,3,4,5,6,7] in our case: [...16M...] or really, 1B
// iteration0 [.5, 2.5, 4.5, 6.5] [... 1M...]
// iteration1 [1.5, 5.5] [...65k...]
// iteration2 3.5 [...4k... ]
// So iteration 0 consumes N elements, hashes them in groups of 'fanout' and produces a vector of N/fanout elements
// and the process repeats until there is only 1 hash left.
//
// With the three_level code path, we make each chunk we iterate of size fanout^3 (4096)
// So, the input could be 16M hashes and the output will be 4k hashes, or N/fanout^3
// The goal is to reduce the amount of data that has to be constructed and held in memory.
// When we know we have enough hashes, then, in 1 pass, we hash 3 levels simultaneously, storing far fewer intermediate hashes.
//
// Now, some details:
// The result of this loop is a single hash value from fanout^3 input hashes.
// concepts:
// what we're conceptually hashing: "raw_hashes"[start_index..end_index]
// example: [a,b,c,d,e,f]
// but... hashes[] may really be multiple vectors that are pieced together.
// example: [[a,b],[c],[d,e,f]]
// get_hash_slice_starting_at_index(any_index) abstracts that and returns a slice starting at raw_hashes[any_index..]
// such that the end of get_hash_slice_starting_at_index may be <, >, or = end_index
// example: get_hash_slice_starting_at_index(1) returns [b]
// get_hash_slice_starting_at_index(3) returns [d,e,f]
// This code is basically 3 iterations of merkle tree hashing occurring simultaneously.
// The first fanout raw hashes are hashed in hasher_k. This is iteration0
// Once hasher_k has hashed fanout hashes, hasher_k's result hash is hashed in hasher_j and then discarded
// hasher_k then starts over fresh and hashes the next fanout raw hashes. This is iteration0 again for a new set of data.
// Once hasher_j has hashed fanout hashes (from k), hasher_j's result hash is hashed in hasher and then discarded
// Once hasher has hashed fanout hashes (from j), then the result of hasher is the hash for fanout^3 raw hashes.
// If there are < fanout^3 hashes, then this code stops when it runs out of raw hashes and returns whatever it hashed.
// This is always how the very last elements work in a merkle tree.
// note: this `i` shadows the chunk index and is the overall hash index; it advances once per raw hash consumed
let mut i = start_index;
while i < end_index {
let mut hasher_j = Hasher::default();
for _j in 0..fanout {
let mut hasher_k = Hasher::default();
// the innermost group is short when fewer than fanout raw hashes remain in this chunk
let end = std::cmp::min(end_index - i, fanout);
for _k in 0..end {
if data_index >= data_len {
// we exhausted our data, fetch next slice starting at i
data_bytes = get_hash_slice_starting_at_index(i);
data = bytemuck::cast_slice(&data_bytes);
data_len = data.len();
data_index = 0;
}
hasher_k.hash(data[data_index].borrow().as_ref());
data_index += 1;
i += 1;
}
hasher_j.hash(hasher_k.result().as_ref());
// stop early once the chunk is exhausted so no empty hasher_k result is folded in
if i >= end_index {
break;
}
}
hasher.hash(hasher_j.result().as_ref());
}
}

hasher.result()
})
.collect();
time.stop();
debug!("hashing {} {}", total_hashes, time);

if let Some(mut specific_level_count_value) = specific_level_count {
// NOTE(review): this subtraction would underflow if levels_hashed exceeded the requested
// level count — confirm calculate_three_level_chunks rules that out.
specific_level_count_value -= levels_hashed;
if specific_level_count_value == 0 {
// exactly the requested number of levels was hashed: hand back the per-chunk hashes
(Hash::default(), result)
} else {
// the value is unsigned and non-zero in this branch, so this assert cannot fail here
assert!(specific_level_count_value > 0);
// We did not hash the number of levels required by 'specific_level_count', so repeat
Self::compute_merkle_root_from_slices_recurse(
result,
fanout,
max_levels_per_pass,
Some(specific_level_count_value),
)
}
} else {
(
// a single chunk hash is already the root; otherwise keep reducing the chunk hashes
if result.len() == 1 {
result[0]
} else {
Self::compute_merkle_root_recurse(result, fanout)
},
vec![], // no intermediate results needed by caller
)
}
}

// This function is designed to allow hashes to be located in multiple, perhaps multiply deep vecs.
// The caller provides a function to return a slice from the source data.
fn compute_merkle_root_from_slices<'b, F, T>(
fn compute_merkle_root_from_slices<'b, F, U>(
total_hashes: usize,
fanout: usize,
max_levels_per_pass: Option<usize>,
Expand All @@ -744,8 +596,8 @@ impl AccountsHasher<'_> {
) -> (Hash, Vec<Hash>)
where
// returns a slice of hashes starting at the given overall index
F: Fn(usize) -> &'b [T] + std::marker::Sync,
T: Borrow<Hash> + std::marker::Sync + 'b,
F: Fn(usize) -> U + std::marker::Sync,
U: AsHashSlice + 'b,
{
if total_hashes == 0 {
return (Hasher::default().result(), vec![]);
Expand All @@ -764,7 +616,7 @@ impl AccountsHasher<'_> {

// initial fetch - could return entire slice
let data = get_hash_slice_starting_at_index(0);
let data_len = data.len();
let data_len = data.num_hashes();

let result: Vec<_> = (0..chunks)
.into_par_iter()
Expand All @@ -785,7 +637,7 @@ impl AccountsHasher<'_> {
// if we exhaust our data, then we will request a new slice, and data_index resets to 0, the beginning of the new slice
let mut data_index = start_index;
// source data, which we may refresh when we exhaust
let mut data = data;
let mut data = data.clone();
// len of the source data
let mut data_len = data_len;

Expand All @@ -796,10 +648,10 @@ impl AccountsHasher<'_> {
if data_index >= data_len {
// we exhausted our data, fetch next slice starting at i
data = get_hash_slice_starting_at_index(i);
data_len = data.len();
data_len = data.num_hashes();
data_index = 0;
}
hasher.hash(data[data_index].borrow().as_ref());
hasher.hash(data.get(data_index).as_ref());
data_index += 1;
}
} else {
Expand Down Expand Up @@ -851,10 +703,10 @@ impl AccountsHasher<'_> {
if data_index >= data_len {
// we exhausted our data, fetch next slice starting at i
data = get_hash_slice_starting_at_index(i);
data_len = data.len();
data_len = data.num_hashes();
data_index = 0;
}
hasher_k.hash(data[data_index].borrow().as_ref());
hasher_k.hash(data.get(data_index).as_ref());
data_index += 1;
i += 1;
}
Expand Down Expand Up @@ -1344,7 +1196,7 @@ impl AccountsHasher<'_> {

let _guard = self.active_stats.activate(ActiveStatItem::HashMerkleTree);
let mut hash_time = Measure::start("hash");
let (hash, _) = Self::compute_merkle_root_from_start::<_, Hash>(
let (hash, _) = Self::compute_merkle_root_from_slices(
cumulative.total_count(),
MERKLE_FANOUT,
None,
Expand Down

0 comments on commit fe25c2d

Please sign in to comment.