diff --git a/zebra-state/src/service/non_finalized_state.rs b/zebra-state/src/service/non_finalized_state.rs index ebcbb2cfd35..91ae30ae23d 100644 --- a/zebra-state/src/service/non_finalized_state.rs +++ b/zebra-state/src/service/non_finalized_state.rs @@ -9,7 +9,7 @@ use std::{ }; use zebra_chain::{ - block::{self, Block}, + block::{self, Block, Hash}, parameters::Network, sprout, transparent, }; @@ -45,6 +45,10 @@ pub struct NonFinalizedState { /// callers should migrate to `chain_iter().next()`. chain_set: BTreeSet>, + /// Blocks that have been invalidated in, and removed from, the non finalized + /// state. + invalidated_blocks: HashMap>>, + // Configuration // /// The configured Zcash network. @@ -92,6 +96,7 @@ impl Clone for NonFinalizedState { Self { chain_set: self.chain_set.clone(), network: self.network.clone(), + invalidated_blocks: self.invalidated_blocks.clone(), #[cfg(feature = "getblocktemplate-rpcs")] should_count_metrics: self.should_count_metrics, @@ -112,6 +117,7 @@ impl NonFinalizedState { NonFinalizedState { chain_set: Default::default(), network: network.clone(), + invalidated_blocks: Default::default(), #[cfg(feature = "getblocktemplate-rpcs")] should_count_metrics: true, #[cfg(feature = "progress-bar")] @@ -264,6 +270,37 @@ impl NonFinalizedState { Ok(()) } + /// Invalidate block with hash `block_hash` and all descendants from the non-finalized state. Insert + /// the new chain into the chain_set and discard the previous. + pub fn invalidate_block(&mut self, block_hash: Hash) { + let Some(chain) = self.find_chain(|chain| chain.contains_block_hash(block_hash)) else { + return; + }; + + let invalidated_blocks = if chain.non_finalized_root_hash() == block_hash { + self.chain_set.remove(&chain); + chain.blocks.values().cloned().collect() + } else { + let (new_chain, invalidated_blocks) = chain + .invalidate_block(block_hash) + .expect("already checked that chain contains hash"); + + // Add the new chain fork or updated chain to the set of recent chains, and + // remove the chain containing the hash of the block from chain set + self.insert_with(Arc::new(new_chain.clone()), |chain_set| { + chain_set.retain(|c| !c.contains_block_hash(block_hash)) + }); + + invalidated_blocks + }; + + self.invalidated_blocks + .insert(block_hash, Arc::new(invalidated_blocks)); + + self.update_metrics_for_chains(); + self.update_metrics_bars(); + } + /// Commit block to the non-finalized state as a new chain where its parent /// is the finalized tip. #[tracing::instrument(level = "debug", skip(self, finalized_state, prepared))] @@ -586,6 +623,11 @@ impl NonFinalizedState { self.chain_set.len() } + /// Return the invalidated blocks. + pub fn invalidated_blocks(&self) -> HashMap>> { + self.invalidated_blocks.clone() + } + /// Return the chain whose tip block hash is `parent_hash`. /// /// The chain can be an existing chain in the non-finalized state, or a freshly diff --git a/zebra-state/src/service/non_finalized_state/chain.rs b/zebra-state/src/service/non_finalized_state/chain.rs index 6ad284a23f5..c7d0d2877c6 100644 --- a/zebra-state/src/service/non_finalized_state/chain.rs +++ b/zebra-state/src/service/non_finalized_state/chain.rs @@ -359,6 +359,26 @@ impl Chain { (block, treestate) } + // Returns the block at the provided height and all of its descendant blocks. + pub fn child_blocks(&self, block_height: &block::Height) -> Vec { + self.blocks + .range(block_height..) + .map(|(_h, b)| b.clone()) + .collect() + } + + // Returns a new chain without the invalidated block or its descendants. + pub fn invalidate_block( + &self, + block_hash: block::Hash, + ) -> Option<(Self, Vec)> { + let block_height = self.height_by_hash(block_hash)?; + let mut new_chain = self.fork(block_hash)?; + new_chain.pop_tip(); + new_chain.last_fork_height = None; + Some((new_chain, self.child_blocks(&block_height))) + } + /// Returns the height of the chain root. pub fn non_finalized_root_height(&self) -> block::Height { self.blocks @@ -1600,7 +1620,7 @@ impl DerefMut for Chain { /// The revert position being performed on a chain. #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] -enum RevertPosition { +pub(crate) enum RevertPosition { /// The chain root is being reverted via [`Chain::pop_root`], when a block /// is finalized. Root, @@ -1619,7 +1639,7 @@ enum RevertPosition { /// and [`Chain::pop_tip`] functions, and fear that it would be easy to /// introduce bugs when updating them, unless the code was reorganized to keep /// related operations adjacent to each other. -trait UpdateWith { +pub(crate) trait UpdateWith { /// When `T` is added to the chain tip, /// update [`Chain`] cumulative data members to add data that are derived from `T`. fn update_chain_tip_with(&mut self, _: &T) -> Result<(), ValidateContextError>; diff --git a/zebra-state/src/service/non_finalized_state/tests/vectors.rs b/zebra-state/src/service/non_finalized_state/tests/vectors.rs index b489d6f94f0..5b392e4a0b9 100644 --- a/zebra-state/src/service/non_finalized_state/tests/vectors.rs +++ b/zebra-state/src/service/non_finalized_state/tests/vectors.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use zebra_chain::{ amount::NonNegative, - block::{Block, Height}, + block::{self, Block, Height}, history_tree::NonEmptyHistoryTree, parameters::{Network, NetworkUpgrade}, serialization::ZcashDeserializeInto, @@ -216,6 +216,94 @@ fn finalize_pops_from_best_chain_for_network(network: Network) -> Result<()> { Ok(()) } +fn invalidate_block_removes_block_and_descendants_from_chain_for_network( + network: Network, +) -> Result<()> { + let block1: Arc = Arc::new(network.test_block(653599, 583999).unwrap()); + let block2 = block1.make_fake_child().set_work(10); + let block3 = block2.make_fake_child().set_work(1); + + let mut state = NonFinalizedState::new(&network); + let finalized_state = FinalizedState::new( + &Config::ephemeral(), + &network, + #[cfg(feature = "elasticsearch")] + false, + ); + + let fake_value_pool = ValueBalance::::fake_populated_pool(); + finalized_state.set_finalized_value_pool(fake_value_pool); + + state.commit_new_chain(block1.clone().prepare(), &finalized_state)?; + state.commit_block(block2.clone().prepare(), &finalized_state)?; + state.commit_block(block3.clone().prepare(), &finalized_state)?; + + assert_eq!( + state + .best_chain() + .unwrap_or(&Arc::new(Chain::default())) + .blocks + .len(), + 3 + ); + + state.invalidate_block(block2.hash()); + + let post_invalidated_chain = state.best_chain().unwrap(); + + assert_eq!(post_invalidated_chain.blocks.len(), 1); + assert!( + post_invalidated_chain.contains_block_hash(block1.hash()), + "the new modified chain should contain block1" + ); + + assert!( + !post_invalidated_chain.contains_block_hash(block2.hash()), + "the new modified chain should not contain block2" + ); + assert!( + !post_invalidated_chain.contains_block_hash(block3.hash()), + "the new modified chain should not contain block3" + ); + + let invalidated_blocks_state = &state.invalidated_blocks; + assert!( + invalidated_blocks_state.contains_key(&block2.hash()), + "invalidated blocks map should reference the hash of block2" + ); + + let invalidated_blocks_state_descendants = + invalidated_blocks_state.get(&block2.hash()).unwrap(); + + match network { + Network::Mainnet => assert!( + invalidated_blocks_state_descendants + .iter() + .any(|block| block.height == block::Height(653601)), + "invalidated descendants vec should contain block3" + ), + Network::Testnet(_parameters) => assert!( + invalidated_blocks_state_descendants + .iter() + .any(|block| block.height == block::Height(584001)), + "invalidated descendants vec should contain block3" + ), + } + + Ok(()) +} + +#[test] +fn invalidate_block_removes_block_and_descendants_from_chain() -> Result<()> { + let _init_guard = zebra_test::init(); + + for network in Network::iter() { + invalidate_block_removes_block_and_descendants_from_chain_for_network(network)?; + } + + Ok(()) +} + #[test] // This test gives full coverage for `take_chain_if` fn commit_block_extending_best_chain_doesnt_drop_worst_chains() -> Result<()> {