diff --git a/lightning-background-processor/src/lib.rs b/lightning-background-processor/src/lib.rs index 1f7147f3203..27ae32d006f 100644 --- a/lightning-background-processor/src/lib.rs +++ b/lightning-background-processor/src/lib.rs @@ -36,8 +36,17 @@ use lightning::onion_message::messenger::AOnionMessenger; use lightning::routing::gossip::{NetworkGraph, P2PGossipSync}; use lightning::routing::scoring::{ScoreUpdate, WriteableScore}; use lightning::routing::utxo::UtxoLookup; +#[cfg(feature = "futures")] +use lightning::sign::ChangeDestinationSource; +#[cfg(feature = "std")] +use lightning::sign::ChangeDestinationSourceSync; +use lightning::sign::OutputSpender; use lightning::util::logger::Logger; -use lightning::util::persist::Persister; +use lightning::util::persist::{KVStore, Persister}; +#[cfg(feature = "futures")] +use lightning::util::sweep::OutputSweeper; +#[cfg(feature = "std")] +use lightning::util::sweep::OutputSweeperSync; #[cfg(feature = "std")] use lightning::util::wakers::Sleeper; use lightning_rapid_gossip_sync::RapidGossipSync; @@ -132,6 +141,11 @@ const REBROADCAST_TIMER: u64 = 30; #[cfg(test)] const REBROADCAST_TIMER: u64 = 1; +#[cfg(not(test))] +const SWEEPER_TIMER: u64 = 30; +#[cfg(test)] +const SWEEPER_TIMER: u64 = 1; + #[cfg(feature = "futures")] /// core::cmp::min is not currently const, so we define a trivial (and equivalent) replacement const fn min_u64(a: u64, b: u64) -> u64 { @@ -308,6 +322,7 @@ macro_rules! define_run_body { $channel_manager: ident, $process_channel_manager_events: expr, $onion_messenger: ident, $process_onion_message_handler_events: expr, $peer_manager: ident, $gossip_sync: ident, + $process_sweeper: expr, $logger: ident, $scorer: ident, $loop_exit_check: expr, $await: expr, $get_timer: expr, $timer_elapsed: expr, $check_slow_await: expr, $time_fetch: expr, ) => { { @@ -322,6 +337,7 @@ macro_rules! define_run_body { let mut last_prune_call = $get_timer(FIRST_NETWORK_PRUNE_TIMER); let mut last_scorer_persist_call = $get_timer(SCORER_PERSIST_TIMER); let mut last_rebroadcast_call = $get_timer(REBROADCAST_TIMER); + let mut last_sweeper_call = $get_timer(SWEEPER_TIMER); let mut have_pruned = false; let mut have_decayed_scorer = false; @@ -465,6 +481,12 @@ macro_rules! define_run_body { $chain_monitor.rebroadcast_pending_claims(); last_rebroadcast_call = $get_timer(REBROADCAST_TIMER); } + + if $timer_elapsed(&mut last_sweeper_call, SWEEPER_TIMER) { + log_trace!($logger, "Regenerating sweeper spends if necessary"); + let _ = $process_sweeper; + last_sweeper_call = $get_timer(SWEEPER_TIMER); + } } // After we exit, ensure we persist the ChannelManager one final time - this avoids @@ -627,6 +649,7 @@ use futures_util::{dummy_waker, OptionalSelector, Selector, SelectorOutput}; /// ``` /// # use lightning::io; /// # use lightning::events::ReplayEvent; +/// # use lightning::util::sweep::OutputSweeper; /// # use std::sync::{Arc, RwLock}; /// # use std::sync::atomic::{AtomicBool, Ordering}; /// # use std::time::SystemTime; @@ -666,6 +689,9 @@ use futures_util::{dummy_waker, OptionalSelector, Selector, SelectorOutput}; /// # F: lightning::chain::Filter + Send + Sync + 'static, /// # FE: lightning::chain::chaininterface::FeeEstimator + Send + Sync + 'static, /// # UL: lightning::routing::utxo::UtxoLookup + Send + Sync + 'static, +/// # D: lightning::sign::ChangeDestinationSource + Send + Sync + 'static, +/// # K: lightning::util::persist::KVStore + Send + Sync + 'static, +/// # O: lightning::sign::OutputSpender + Send + Sync + 'static, /// # > { /// # peer_manager: Arc>, /// # event_handler: Arc, @@ -677,6 +703,7 @@ use futures_util::{dummy_waker, OptionalSelector, Selector, SelectorOutput}; /// # persister: Arc, /// # logger: Arc, /// # scorer: Arc, +/// # sweeper: Arc, Arc, Arc, Arc, Arc, Arc, Arc>>, /// # } /// # /// # async fn setup_background_processing< @@ -684,7 +711,10 @@ use futures_util::{dummy_waker, OptionalSelector, Selector, SelectorOutput}; /// # F: lightning::chain::Filter + Send + Sync + 'static, /// # FE: lightning::chain::chaininterface::FeeEstimator + Send + Sync + 'static, /// # UL: lightning::routing::utxo::UtxoLookup + Send + Sync + 'static, -/// # >(node: Node) { +/// # D: lightning::sign::ChangeDestinationSource + Send + Sync + 'static, +/// # K: lightning::util::persist::KVStore + Send + Sync + 'static, +/// # O: lightning::sign::OutputSpender + Send + Sync + 'static, +/// # >(node: Node) { /// let background_persister = Arc::clone(&node.persister); /// let background_event_handler = Arc::clone(&node.event_handler); /// let background_chain_mon = Arc::clone(&node.chain_monitor); @@ -695,7 +725,8 @@ use futures_util::{dummy_waker, OptionalSelector, Selector, SelectorOutput}; /// let background_liquidity_manager = Arc::clone(&node.liquidity_manager); /// let background_logger = Arc::clone(&node.logger); /// let background_scorer = Arc::clone(&node.scorer); -/// +/// let background_sweeper = Arc::clone(&node.sweeper); + /// // Setup the sleeper. #[cfg_attr( feature = "std", @@ -729,6 +760,7 @@ use futures_util::{dummy_waker, OptionalSelector, Selector, SelectorOutput}; /// background_gossip_sync, /// background_peer_man, /// Some(background_liquidity_manager), +/// Some(background_sweeper), /// background_logger, /// Some(background_scorer), /// sleeper, @@ -767,6 +799,10 @@ pub async fn process_events_async< RGS: 'static + Deref>, PM: 'static + Deref, LM: 'static + Deref, + D: 'static + Deref, + O: 'static + Deref, + K: 'static + Deref, + OS: 'static + Deref>, S: 'static + Deref + Send + Sync, SC: for<'b> WriteableScore<'b>, SleepFuture: core::future::Future + core::marker::Unpin, @@ -775,12 +811,12 @@ pub async fn process_events_async< >( persister: PS, event_handler: EventHandler, chain_monitor: M, channel_manager: CM, onion_messenger: Option, gossip_sync: GossipSync, peer_manager: PM, - liquidity_manager: Option, logger: L, scorer: Option, sleeper: Sleeper, - mobile_interruptable_platform: bool, fetch_time: FetchTime, + liquidity_manager: Option, sweeper: Option, logger: L, scorer: Option, + sleeper: Sleeper, mobile_interruptable_platform: bool, fetch_time: FetchTime, ) -> Result<(), lightning::io::Error> where UL::Target: 'static + UtxoLookup, - CF::Target: 'static + chain::Filter, + CF::Target: 'static + chain::Filter + Sync + Send, T::Target: 'static + BroadcasterInterface, F::Target: 'static + FeeEstimator, L::Target: 'static + Logger, @@ -790,6 +826,9 @@ where OM::Target: AOnionMessenger, PM::Target: APeerManager, LM::Target: ALiquidityManager, + O::Target: 'static + OutputSpender, + D::Target: 'static + ChangeDestinationSource, + K::Target: 'static + KVStore, { let mut should_break = false; let async_event_handler = |event| { @@ -833,6 +872,13 @@ where }, peer_manager, gossip_sync, + { + if let Some(ref sweeper) = sweeper { + sweeper.regenerate_and_broadcast_spend_if_necessary().await + } else { + Ok(()) + } + }, logger, scorer, should_break, @@ -953,14 +999,18 @@ impl BackgroundProcessor { LM: 'static + Deref + Send, S: 'static + Deref + Send + Sync, SC: for<'b> WriteableScore<'b>, + D: 'static + Deref, + O: 'static + Deref, + K: 'static + Deref, + OS: 'static + Deref> + Send + Sync, >( persister: PS, event_handler: EH, chain_monitor: M, channel_manager: CM, onion_messenger: Option, gossip_sync: GossipSync, peer_manager: PM, - liquidity_manager: Option, logger: L, scorer: Option, + liquidity_manager: Option, sweeper: Option, logger: L, scorer: Option, ) -> Self where UL::Target: 'static + UtxoLookup, - CF::Target: 'static + chain::Filter, + CF::Target: 'static + chain::Filter + Sync + Send, T::Target: 'static + BroadcasterInterface, F::Target: 'static + FeeEstimator, L::Target: 'static + Logger, @@ -970,6 +1020,9 @@ impl BackgroundProcessor { OM::Target: AOnionMessenger, PM::Target: APeerManager, LM::Target: ALiquidityManager, + D::Target: ChangeDestinationSourceSync, + O::Target: 'static + OutputSpender, + K::Target: 'static + KVStore, { let stop_thread = Arc::new(AtomicBool::new(false)); let stop_thread_clone = stop_thread.clone(); @@ -1005,6 +1058,13 @@ impl BackgroundProcessor { }, peer_manager, gossip_sync, + { + if let Some(ref sweeper) = sweeper { + sweeper.regenerate_and_broadcast_spend_if_necessary() + } else { + Ok(()) + } + }, logger, scorer, stop_thread.load(Ordering::Acquire), @@ -1127,7 +1187,7 @@ mod tests { use lightning::routing::gossip::{NetworkGraph, P2PGossipSync}; use lightning::routing::router::{CandidateRouteHop, DefaultRouter, Path, RouteHop}; use lightning::routing::scoring::{ChannelUsage, LockableScore, ScoreLookUp, ScoreUpdate}; - use lightning::sign::{ChangeDestinationSource, InMemorySigner, KeysManager}; + use lightning::sign::{ChangeDestinationSourceSync, InMemorySigner, KeysManager}; use lightning::types::features::{ChannelFeatures, NodeFeatures}; use lightning::types::payment::PaymentHash; use lightning::util::config::UserConfig; @@ -1139,7 +1199,7 @@ mod tests { SCORER_PERSISTENCE_SECONDARY_NAMESPACE, }; use lightning::util::ser::Writeable; - use lightning::util::sweep::{OutputSpendStatus, OutputSweeper, PRUNE_DELAY_BLOCKS}; + use lightning::util::sweep::{OutputSpendStatus, OutputSweeperSync, PRUNE_DELAY_BLOCKS}; use lightning::util::test_utils; use lightning::{get_event, get_event_msg}; use lightning_liquidity::LiquidityManager; @@ -1265,11 +1325,11 @@ mod tests { best_block: BestBlock, scorer: Arc>, sweeper: Arc< - OutputSweeper< + OutputSweeperSync< Arc, Arc, Arc, - Arc, + Arc, Arc, Arc, Arc, @@ -1566,7 +1626,7 @@ mod tests { struct TestWallet {} - impl ChangeDestinationSource for TestWallet { + impl ChangeDestinationSourceSync for TestWallet { fn get_change_destination_script(&self) -> Result { Ok(ScriptBuf::new()) } @@ -1644,11 +1704,11 @@ mod tests { IgnoringMessageHandler {}, )); let wallet = Arc::new(TestWallet {}); - let sweeper = Arc::new(OutputSweeper::new( + let sweeper = Arc::new(OutputSweeperSync::new( best_block, Arc::clone(&tx_broadcaster), Arc::clone(&fee_estimator), - None::>, + None::>, Arc::clone(&keys_manager), wallet, Arc::clone(&kv_store), @@ -1888,6 +1948,7 @@ mod tests { nodes[0].p2p_gossip_sync(), nodes[0].peer_manager.clone(), Some(Arc::clone(&nodes[0].liquidity_manager)), + Some(nodes[0].sweeper.clone()), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), ); @@ -1982,6 +2043,7 @@ mod tests { nodes[0].no_gossip_sync(), nodes[0].peer_manager.clone(), Some(Arc::clone(&nodes[0].liquidity_manager)), + Some(nodes[0].sweeper.clone()), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), ); @@ -2025,6 +2087,7 @@ mod tests { nodes[0].no_gossip_sync(), nodes[0].peer_manager.clone(), Some(Arc::clone(&nodes[0].liquidity_manager)), + Some(nodes[0].sweeper.clone()), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), ); @@ -2058,6 +2121,7 @@ mod tests { nodes[0].rapid_gossip_sync(), nodes[0].peer_manager.clone(), Some(Arc::clone(&nodes[0].liquidity_manager)), + Some(nodes[0].sweeper.sweeper_async()), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), move |dur: Duration| { @@ -2095,6 +2159,7 @@ mod tests { nodes[0].p2p_gossip_sync(), nodes[0].peer_manager.clone(), Some(Arc::clone(&nodes[0].liquidity_manager)), + Some(nodes[0].sweeper.clone()), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), ); @@ -2125,6 +2190,7 @@ mod tests { nodes[0].no_gossip_sync(), nodes[0].peer_manager.clone(), Some(Arc::clone(&nodes[0].liquidity_manager)), + Some(nodes[0].sweeper.clone()), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), ); @@ -2172,6 +2238,7 @@ mod tests { nodes[0].no_gossip_sync(), nodes[0].peer_manager.clone(), Some(Arc::clone(&nodes[0].liquidity_manager)), + Some(nodes[0].sweeper.clone()), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), ); @@ -2235,6 +2302,7 @@ mod tests { nodes[0].no_gossip_sync(), nodes[0].peer_manager.clone(), Some(Arc::clone(&nodes[0].liquidity_manager)), + Some(nodes[0].sweeper.clone()), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), ); @@ -2280,10 +2348,22 @@ mod tests { advance_chain(&mut nodes[0], 3); + let tx_broadcaster = nodes[0].tx_broadcaster.clone(); + let wait_for_sweep_tx = || -> Transaction { + loop { + let sweep_tx = tx_broadcaster.txn_broadcasted.lock().unwrap().pop(); + if let Some(sweep_tx) = sweep_tx { + return sweep_tx; + } + + std::thread::sleep(Duration::from_millis(100)); + } + }; + // Check we generate an initial sweeping tx. assert_eq!(nodes[0].sweeper.tracked_spendable_outputs().len(), 1); + let sweep_tx_0 = wait_for_sweep_tx(); let tracked_output = nodes[0].sweeper.tracked_spendable_outputs().first().unwrap().clone(); - let sweep_tx_0 = nodes[0].tx_broadcaster.txn_broadcasted.lock().unwrap().pop().unwrap(); match tracked_output.status { OutputSpendStatus::PendingFirstConfirmation { latest_spending_tx, .. } => { assert_eq!(sweep_tx_0.compute_txid(), latest_spending_tx.compute_txid()); @@ -2294,8 +2374,8 @@ mod tests { // Check we regenerate and rebroadcast the sweeping tx each block. advance_chain(&mut nodes[0], 1); assert_eq!(nodes[0].sweeper.tracked_spendable_outputs().len(), 1); + let sweep_tx_1 = wait_for_sweep_tx(); let tracked_output = nodes[0].sweeper.tracked_spendable_outputs().first().unwrap().clone(); - let sweep_tx_1 = nodes[0].tx_broadcaster.txn_broadcasted.lock().unwrap().pop().unwrap(); match tracked_output.status { OutputSpendStatus::PendingFirstConfirmation { latest_spending_tx, .. } => { assert_eq!(sweep_tx_1.compute_txid(), latest_spending_tx.compute_txid()); @@ -2306,8 +2386,8 @@ mod tests { advance_chain(&mut nodes[0], 1); assert_eq!(nodes[0].sweeper.tracked_spendable_outputs().len(), 1); + let sweep_tx_2 = wait_for_sweep_tx(); let tracked_output = nodes[0].sweeper.tracked_spendable_outputs().first().unwrap().clone(); - let sweep_tx_2 = nodes[0].tx_broadcaster.txn_broadcasted.lock().unwrap().pop().unwrap(); match tracked_output.status { OutputSpendStatus::PendingFirstConfirmation { latest_spending_tx, .. } => { assert_eq!(sweep_tx_2.compute_txid(), latest_spending_tx.compute_txid()); @@ -2387,6 +2467,7 @@ mod tests { nodes[0].no_gossip_sync(), nodes[0].peer_manager.clone(), Some(Arc::clone(&nodes[0].liquidity_manager)), + Some(nodes[0].sweeper.clone()), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), ); @@ -2417,6 +2498,7 @@ mod tests { nodes[0].no_gossip_sync(), nodes[0].peer_manager.clone(), Some(Arc::clone(&nodes[0].liquidity_manager)), + Some(nodes[0].sweeper.clone()), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), ); @@ -2513,6 +2595,7 @@ mod tests { nodes[0].rapid_gossip_sync(), nodes[0].peer_manager.clone(), Some(Arc::clone(&nodes[0].liquidity_manager)), + Some(nodes[0].sweeper.clone()), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), ); @@ -2546,6 +2629,7 @@ mod tests { nodes[0].rapid_gossip_sync(), nodes[0].peer_manager.clone(), Some(Arc::clone(&nodes[0].liquidity_manager)), + Some(nodes[0].sweeper.sweeper_async()), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), move |dur: Duration| { @@ -2709,6 +2793,7 @@ mod tests { nodes[0].no_gossip_sync(), nodes[0].peer_manager.clone(), Some(Arc::clone(&nodes[0].liquidity_manager)), + Some(nodes[0].sweeper.clone()), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), ); @@ -2760,6 +2845,7 @@ mod tests { nodes[0].no_gossip_sync(), nodes[0].peer_manager.clone(), Some(Arc::clone(&nodes[0].liquidity_manager)), + Some(nodes[0].sweeper.sweeper_async()), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), move |dur: Duration| { diff --git a/lightning/src/lib.rs b/lightning/src/lib.rs index 5c608d9607a..f5f95667e96 100644 --- a/lightning/src/lib.rs +++ b/lightning/src/lib.rs @@ -30,7 +30,6 @@ //! * `grind_signatures` #![cfg_attr(not(any(test, fuzzing, feature = "_test_utils")), deny(missing_docs))] -#![cfg_attr(not(any(test, feature = "_test_utils")), forbid(unsafe_code))] #![deny(rustdoc::broken_intra_doc_links)] #![deny(rustdoc::private_intra_doc_links)] diff --git a/lightning/src/sign/mod.rs b/lightning/src/sign/mod.rs index eb3d57e6dec..5e5758b2cc6 100644 --- a/lightning/src/sign/mod.rs +++ b/lightning/src/sign/mod.rs @@ -67,6 +67,9 @@ use crate::sign::ecdsa::EcdsaChannelSigner; use crate::sign::taproot::TaprootChannelSigner; use crate::util::atomic_counter::AtomicCounter; use core::convert::TryInto; +use core::future::Future; +use core::ops::Deref; +use core::pin::Pin; use core::sync::atomic::{AtomicUsize, Ordering}; #[cfg(taproot)] use musig2::types::{PartialSignature, PublicNonce}; @@ -975,17 +978,56 @@ pub trait SignerProvider { fn get_shutdown_scriptpubkey(&self) -> Result; } +/// A type alias for a future that returns a result of type T. +pub type AsyncResult<'a, T> = Pin> + 'a + Send>>; + /// A helper trait that describes an on-chain wallet capable of returning a (change) destination /// script. pub trait ChangeDestinationSource { /// Returns a script pubkey which can be used as a change destination for /// [`OutputSpender::spend_spendable_outputs`]. /// + /// This method should return a different value each time it is called, to avoid linking + /// on-chain funds controlled to the same user. + fn get_change_destination_script<'a>(&self) -> AsyncResult<'a, ScriptBuf>; +} + +/// A synchronous helper trait that describes an on-chain wallet capable of returning a (change) destination script. +pub trait ChangeDestinationSourceSync { /// This method should return a different value each time it is called, to avoid linking /// on-chain funds controlled to the same user. fn get_change_destination_script(&self) -> Result; } +/// A wrapper around [`ChangeDestinationSource`] to allow for async calls. +#[cfg(any(test, feature = "_test_utils"))] +pub struct ChangeDestinationSourceSyncWrapper(T) +where + T::Target: ChangeDestinationSourceSync; +#[cfg(not(any(test, feature = "_test_utils")))] +pub(crate) struct ChangeDestinationSourceSyncWrapper(T) +where + T::Target: ChangeDestinationSourceSync; + +impl ChangeDestinationSourceSyncWrapper +where + T::Target: ChangeDestinationSourceSync, +{ + /// Creates a new [`ChangeDestinationSourceSyncWrapper`]. + pub fn new(source: T) -> Self { + Self(source) + } +} +impl ChangeDestinationSource for ChangeDestinationSourceSyncWrapper +where + T::Target: ChangeDestinationSourceSync, +{ + fn get_change_destination_script<'a>(&self) -> AsyncResult<'a, ScriptBuf> { + let script = self.0.get_change_destination_script(); + Box::pin(async move { script }) + } +} + mod sealed { use bitcoin::secp256k1::{Scalar, SecretKey}; diff --git a/lightning/src/util/async_poll.rs b/lightning/src/util/async_poll.rs index c18ada73a47..bca490ce9c8 100644 --- a/lightning/src/util/async_poll.rs +++ b/lightning/src/util/async_poll.rs @@ -13,7 +13,7 @@ use crate::prelude::*; use core::future::Future; use core::marker::Unpin; use core::pin::Pin; -use core::task::{Context, Poll}; +use core::task::{Context, Poll, RawWaker, RawWakerVTable, Waker}; pub(crate) enum ResultFuture>, E: Copy + Unpin> { Pending(F), @@ -74,3 +74,24 @@ impl> + Unpin, E: Copy + Unpin> Future } } } + +// If we want to poll a future without an async context to figure out if it has completed or +// not without awaiting, we need a Waker, which needs a vtable...we fill it with dummy values +// but sadly there's a good bit of boilerplate here. +// +// Waker::noop() would be preferable, but requires an MSRV of 1.85. +fn dummy_waker_clone(_: *const ()) -> RawWaker { + RawWaker::new(core::ptr::null(), &DUMMY_WAKER_VTABLE) +} +fn dummy_waker_action(_: *const ()) {} + +const DUMMY_WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new( + dummy_waker_clone, + dummy_waker_action, + dummy_waker_action, + dummy_waker_action, +); + +pub(crate) fn dummy_waker() -> Waker { + unsafe { Waker::from_raw(RawWaker::new(core::ptr::null(), &DUMMY_WAKER_VTABLE)) } +} diff --git a/lightning/src/util/sweep.rs b/lightning/src/util/sweep.rs index 5d856b9affb..3b0ce5e5e7d 100644 --- a/lightning/src/util/sweep.rs +++ b/lightning/src/util/sweep.rs @@ -15,7 +15,10 @@ use crate::io; use crate::ln::msgs::DecodeError; use crate::ln::types::ChannelId; use crate::prelude::*; -use crate::sign::{ChangeDestinationSource, OutputSpender, SpendableOutputDescriptor}; +use crate::sign::{ + ChangeDestinationSource, ChangeDestinationSourceSync, ChangeDestinationSourceSyncWrapper, + OutputSpender, SpendableOutputDescriptor, +}; use crate::sync::Mutex; use crate::util::logger::Logger; use crate::util::persist::{ @@ -28,9 +31,14 @@ use crate::{impl_writeable_tlv_based, log_debug, log_error}; use bitcoin::block::Header; use bitcoin::locktime::absolute::LockTime; use bitcoin::secp256k1::Secp256k1; -use bitcoin::{BlockHash, Transaction, Txid}; +use bitcoin::{BlockHash, ScriptBuf, Transaction, Txid}; +use crate::sync::Arc; +use core::future::Future; use core::ops::Deref; +use core::task; + +use super::async_poll::dummy_waker; /// The number of blocks we wait before we prune the tracked spendable outputs. pub const PRUNE_DELAY_BLOCKS: u32 = ARCHIVAL_DELAY_BLOCKS + ANTI_REORG_DELAY; @@ -342,7 +350,7 @@ where L::Target: Logger, O::Target: OutputSpender, { - sweeper_state: Mutex, + sweeper_state: Mutex, broadcaster: B, fee_estimator: E, chain_data_source: Option, @@ -372,7 +380,10 @@ where output_spender: O, change_destination_source: D, kv_store: K, logger: L, ) -> Self { let outputs = Vec::new(); - let sweeper_state = Mutex::new(SweeperState { outputs, best_block }); + let sweeper_state = Mutex::new(RuntimeSweeperState { + persistent: SweeperState { outputs, best_block }, + sweep_pending: false, + }); Self { sweeper_state, broadcaster, @@ -416,59 +427,42 @@ where return Ok(()); } - let spending_tx_opt; - { - let mut state_lock = self.sweeper_state.lock().unwrap(); - for descriptor in relevant_descriptors { - let output_info = TrackedSpendableOutput { - descriptor, - channel_id, - status: OutputSpendStatus::PendingInitialBroadcast { - delayed_until_height: delay_until_height, - }, - }; - - if state_lock - .outputs - .iter() - .find(|o| o.descriptor == output_info.descriptor) - .is_some() - { - continue; - } - - state_lock.outputs.push(output_info); + let state_lock = &mut self.sweeper_state.lock().unwrap().persistent; + for descriptor in relevant_descriptors { + let output_info = TrackedSpendableOutput { + descriptor, + channel_id, + status: OutputSpendStatus::PendingInitialBroadcast { + delayed_until_height: delay_until_height, + }, + }; + + if state_lock.outputs.iter().find(|o| o.descriptor == output_info.descriptor).is_some() + { + continue; } - spending_tx_opt = self.regenerate_spend_if_necessary(&mut *state_lock); - self.persist_state(&*state_lock).map_err(|e| { - log_error!(self.logger, "Error persisting OutputSweeper: {:?}", e); - })?; - } - if let Some(spending_tx) = spending_tx_opt { - self.broadcaster.broadcast_transactions(&[&spending_tx]); + state_lock.outputs.push(output_info); } - - Ok(()) + self.persist_state(&state_lock).map_err(|e| { + log_error!(self.logger, "Error persisting OutputSweeper: {:?}", e); + }) } /// Returns a list of the currently tracked spendable outputs. pub fn tracked_spendable_outputs(&self) -> Vec { - self.sweeper_state.lock().unwrap().outputs.clone() + self.sweeper_state.lock().unwrap().persistent.outputs.clone() } /// Gets the latest best block which was connected either via the [`Listen`] or /// [`Confirm`] interfaces. pub fn current_best_block(&self) -> BestBlock { - self.sweeper_state.lock().unwrap().best_block + self.sweeper_state.lock().unwrap().persistent.best_block } - fn regenerate_spend_if_necessary( - &self, sweeper_state: &mut SweeperState, - ) -> Option { - let cur_height = sweeper_state.best_block.height; - let cur_hash = sweeper_state.best_block.block_hash; - let filter_fn = |o: &TrackedSpendableOutput| { + /// Regenerates and broadcasts the spending transaction for any outputs that are pending + pub async fn regenerate_and_broadcast_spend_if_necessary(&self) -> Result<(), ()> { + let filter_fn = |o: &TrackedSpendableOutput, cur_height: u32| { if o.status.is_confirmed() { // Don't rebroadcast confirmed txs. return false; @@ -487,42 +481,96 @@ where true }; - let respend_descriptors: Vec<&SpendableOutputDescriptor> = - sweeper_state.outputs.iter().filter(|o| filter_fn(*o)).map(|o| &o.descriptor).collect(); + // See if there is anything to sweep before requesting a change address. + { + let mut sweeper_state = self.sweeper_state.lock().unwrap(); + + // Prevent concurrent sweeping. + if sweeper_state.sweep_pending { + return Ok(()); + } - if respend_descriptors.is_empty() { - // Nothing to do. - return None; + let cur_height = sweeper_state.persistent.best_block.height; + let has_respends = + sweeper_state.persistent.outputs.iter().any(|o| filter_fn(o, cur_height)); + if !has_respends { + return Ok(()); + } + + // There is something to sweep. Block concurrent sweeps. + sweeper_state.sweep_pending = true; } - let spending_tx = match self.spend_outputs(&*sweeper_state, respend_descriptors) { - Ok(spending_tx) => { - log_debug!( - self.logger, - "Generating and broadcasting sweeping transaction {}", - spending_tx.compute_txid() - ); - spending_tx - }, - Err(e) => { - log_error!(self.logger, "Error spending outputs: {:?}", e); - return None; - }, - }; + // Request a new change address outside of the mutex to avoid the mutex crossing await. + let change_destination_script_result = + self.change_destination_source.get_change_destination_script().await; - // As we didn't modify the state so far, the same filter_fn yields the same elements as - // above. - let respend_outputs = sweeper_state.outputs.iter_mut().filter(|o| filter_fn(&**o)); - for output_info in respend_outputs { - if let Some(filter) = self.chain_data_source.as_ref() { - let watched_output = output_info.to_watched_output(cur_hash); - filter.register_output(watched_output); + // Sweep the outputs. + { + let mut runtime_sweeper_state = self.sweeper_state.lock().unwrap(); + + // Always allow a new sweep after this spend, also in the error case. + runtime_sweeper_state.sweep_pending = false; + + let sweeper_state = &mut runtime_sweeper_state.persistent; + + let change_destination_script = change_destination_script_result?; + + let cur_height = sweeper_state.best_block.height; + let cur_hash = sweeper_state.best_block.block_hash; + + let respend_descriptors: Vec<&SpendableOutputDescriptor> = sweeper_state + .outputs + .iter() + .filter(|o| filter_fn(*o, cur_height)) + .map(|o| &o.descriptor) + .collect(); + + if respend_descriptors.is_empty() { + // It could be that a tx confirmed and there is now nothing to sweep anymore. + return Ok(()); + } + + let spending_tx = match self.spend_outputs( + sweeper_state, + &respend_descriptors, + change_destination_script, + ) { + Ok(spending_tx) => { + log_debug!( + self.logger, + "Generating and broadcasting sweeping transaction {}", + spending_tx.compute_txid() + ); + spending_tx + }, + Err(e) => { + log_error!(self.logger, "Error spending outputs: {:?}", e); + return Ok(()); + }, + }; + + // As we didn't modify the state so far, the same filter_fn yields the same elements as + // above. + let respend_outputs = + sweeper_state.outputs.iter_mut().filter(|o| filter_fn(&**o, cur_height)); + for output_info in respend_outputs { + if let Some(filter) = self.chain_data_source.as_ref() { + let watched_output = output_info.to_watched_output(cur_hash); + filter.register_output(watched_output); + } + + output_info.status.broadcast(cur_hash, cur_height, spending_tx.clone()); } - output_info.status.broadcast(cur_hash, cur_height, spending_tx.clone()); + self.persist_state(&sweeper_state).map_err(|e| { + log_error!(self.logger, "Error persisting OutputSweeper: {:?}", e); + })?; + + self.broadcaster.broadcast_transactions(&[&spending_tx]); } - Some(spending_tx) + Ok(()) } fn prune_confirmed_outputs(&self, sweeper_state: &mut SweeperState) { @@ -567,16 +615,15 @@ where } fn spend_outputs( - &self, sweeper_state: &SweeperState, descriptors: Vec<&SpendableOutputDescriptor>, + &self, sweeper_state: &SweeperState, descriptors: &Vec<&SpendableOutputDescriptor>, + change_destination_script: ScriptBuf, ) -> Result { let tx_feerate = self.fee_estimator.get_est_sat_per_1000_weight(ConfirmationTarget::OutputSpendingFee); - let change_destination_script = - self.change_destination_source.get_change_destination_script()?; let cur_height = sweeper_state.best_block.height; let locktime = Some(LockTime::from_height(cur_height).unwrap_or(LockTime::ZERO)); self.output_spender.spend_spendable_outputs( - &descriptors, + descriptors, Vec::new(), change_destination_script, tx_feerate, @@ -601,11 +648,9 @@ where fn best_block_updated_internal( &self, sweeper_state: &mut SweeperState, header: &Header, height: u32, - ) -> Option { + ) { sweeper_state.best_block = BestBlock::new(header.block_hash(), height); self.prune_confirmed_outputs(sweeper_state); - let spending_tx_opt = self.regenerate_spend_if_necessary(sweeper_state); - spending_tx_opt } } @@ -623,31 +668,22 @@ where fn filtered_block_connected( &self, header: &Header, txdata: &chain::transaction::TransactionData, height: u32, ) { - let mut spending_tx_opt; - { - let mut state_lock = self.sweeper_state.lock().unwrap(); - assert_eq!(state_lock.best_block.block_hash, header.prev_blockhash, - "Blocks must be connected in chain-order - the connected header must build on the last connected header"); - assert_eq!(state_lock.best_block.height, height - 1, - "Blocks must be connected in chain-order - the connected block height must be one greater than the previous height"); + let state_lock = &mut self.sweeper_state.lock().unwrap().persistent; + assert_eq!(state_lock.best_block.block_hash, header.prev_blockhash, + "Blocks must be connected in chain-order - the connected header must build on the last connected header"); + assert_eq!(state_lock.best_block.height, height - 1, + "Blocks must be connected in chain-order - the connected block height must be one greater than the previous height"); - self.transactions_confirmed_internal(&mut *state_lock, header, txdata, height); - spending_tx_opt = self.best_block_updated_internal(&mut *state_lock, header, height); + self.transactions_confirmed_internal(state_lock, header, txdata, height); + self.best_block_updated_internal(state_lock, header, height); - self.persist_state(&*state_lock).unwrap_or_else(|e| { - log_error!(self.logger, "Error persisting OutputSweeper: {:?}", e); - // Skip broadcasting if the persist failed. - spending_tx_opt = None; - }); - } - - if let Some(spending_tx) = spending_tx_opt { - self.broadcaster.broadcast_transactions(&[&spending_tx]); - } + let _ = self.persist_state(&state_lock).map_err(|e| { + log_error!(self.logger, "Error persisting OutputSweeper: {:?}", e); + }); } fn block_disconnected(&self, header: &Header, height: u32) { - let mut state_lock = self.sweeper_state.lock().unwrap(); + let state_lock = &mut self.sweeper_state.lock().unwrap().persistent; let new_height = height - 1; let block_hash = header.block_hash(); @@ -685,15 +721,15 @@ where fn transactions_confirmed( &self, header: &Header, txdata: &chain::transaction::TransactionData, height: u32, ) { - let mut state_lock = self.sweeper_state.lock().unwrap(); - self.transactions_confirmed_internal(&mut *state_lock, header, txdata, height); - self.persist_state(&*state_lock).unwrap_or_else(|e| { + let state_lock = &mut self.sweeper_state.lock().unwrap().persistent; + self.transactions_confirmed_internal(state_lock, header, txdata, height); + self.persist_state(state_lock).unwrap_or_else(|e| { log_error!(self.logger, "Error persisting OutputSweeper: {:?}", e); }); } fn transaction_unconfirmed(&self, txid: &Txid) { - let mut state_lock = self.sweeper_state.lock().unwrap(); + let state_lock = &mut self.sweeper_state.lock().unwrap().persistent; // Get what height was unconfirmed. let unconf_height = state_lock @@ -710,31 +746,22 @@ where .filter(|o| o.status.confirmation_height() >= Some(unconf_height)) .for_each(|o| o.status.unconfirmed()); - self.persist_state(&*state_lock).unwrap_or_else(|e| { + self.persist_state(state_lock).unwrap_or_else(|e| { log_error!(self.logger, "Error persisting OutputSweeper: {:?}", e); }); } } fn best_block_updated(&self, header: &Header, height: u32) { - let mut spending_tx_opt; - { - let mut state_lock = self.sweeper_state.lock().unwrap(); - spending_tx_opt = self.best_block_updated_internal(&mut *state_lock, header, height); - self.persist_state(&*state_lock).unwrap_or_else(|e| { - log_error!(self.logger, "Error persisting OutputSweeper: {:?}", e); - // Skip broadcasting if the persist failed. - spending_tx_opt = None; - }); - } - - if let Some(spending_tx) = spending_tx_opt { - self.broadcaster.broadcast_transactions(&[&spending_tx]); - } + let state_lock = &mut self.sweeper_state.lock().unwrap().persistent; + self.best_block_updated_internal(state_lock, header, height); + let _ = self.persist_state(state_lock).map_err(|e| { + log_error!(self.logger, "Error persisting OutputSweeper: {:?}", e); + }); } fn get_relevant_txids(&self) -> Vec<(Txid, u32, Option)> { - let state_lock = self.sweeper_state.lock().unwrap(); + let state_lock = &self.sweeper_state.lock().unwrap().persistent; state_lock .outputs .iter() @@ -755,6 +782,11 @@ where } } +struct RuntimeSweeperState { + persistent: SweeperState, + sweep_pending: bool, +} + #[derive(Debug, Clone)] struct SweeperState { outputs: Vec, @@ -817,7 +849,8 @@ where } } - let sweeper_state = Mutex::new(state); + let sweeper_state = + Mutex::new(RuntimeSweeperState { persistent: state, sweep_pending: false }); Ok(Self { sweeper_state, broadcaster, @@ -865,7 +898,8 @@ where } } - let sweeper_state = Mutex::new(state); + let sweeper_state = + Mutex::new(RuntimeSweeperState { persistent: state, sweep_pending: false }); Ok(( best_block, OutputSweeper { @@ -881,3 +915,121 @@ where )) } } + +/// A synchronous wrapper around [`OutputSweeper`] to be used in contexts where async is not available. +pub struct OutputSweeperSync +where + B::Target: BroadcasterInterface, + D::Target: ChangeDestinationSourceSync, + E::Target: FeeEstimator, + F::Target: Filter + Sync + Send, + K::Target: KVStore, + L::Target: Logger, + O::Target: OutputSpender, +{ + sweeper: Arc>, E, F, K, L, O>>, +} + +impl + OutputSweeperSync +where + B::Target: BroadcasterInterface, + D::Target: ChangeDestinationSourceSync, + E::Target: FeeEstimator, + F::Target: Filter + Sync + Send, + K::Target: KVStore, + L::Target: Logger, + O::Target: OutputSpender, +{ + /// Constructs a new [`OutputSweeperSync`] instance. + pub fn new( + best_block: BestBlock, broadcaster: B, fee_estimator: E, chain_data_source: Option, + output_spender: O, change_destination_source: D, kv_store: K, logger: L, + ) -> Self { + let change_destination_source = + Arc::new(ChangeDestinationSourceSyncWrapper::new(change_destination_source)); + + let sweeper = OutputSweeper::new( + best_block, + broadcaster, + fee_estimator, + chain_data_source, + output_spender, + change_destination_source, + kv_store, + logger, + ); + Self { sweeper: Arc::new(sweeper) } + } + + /// Regenerates and broadcasts the spending transaction for any outputs that are pending. Wraps + /// [`OutputSweeper::regenerate_and_broadcast_spend_if_necessary`]. + pub fn regenerate_and_broadcast_spend_if_necessary(&self) -> Result<(), ()> { + let mut fut = Box::pin(self.sweeper.regenerate_and_broadcast_spend_if_necessary()); + let mut waker = dummy_waker(); + let mut ctx = task::Context::from_waker(&mut waker); + match fut.as_mut().poll(&mut ctx) { + task::Poll::Ready(result) => result, + task::Poll::Pending => { + // In a sync context, we can't wait for the future to complete. + panic!("task not ready"); + }, + } + } + + /// Tells the sweeper to track the given outputs descriptors. Wraps [`OutputSweeper::track_spendable_outputs`]. + pub fn track_spendable_outputs( + &self, output_descriptors: Vec, channel_id: Option, + exclude_static_outputs: bool, delay_until_height: Option, + ) -> Result<(), ()> { + self.sweeper.track_spendable_outputs( + output_descriptors, + channel_id, + exclude_static_outputs, + delay_until_height, + ) + } + + /// Returns a list of the currently tracked spendable outputs. Wraps [`OutputSweeper::tracked_spendable_outputs`]. + pub fn tracked_spendable_outputs(&self) -> Vec { + self.sweeper.tracked_spendable_outputs() + } + + /// Returns the inner async sweeper for testing purposes. + #[cfg(any(test, feature = "_test_utils"))] + pub fn sweeper_async( + &self, + ) -> Arc>, E, F, K, L, O>> { + self.sweeper.clone() + } +} + +impl Confirm + for OutputSweeperSync +where + B::Target: BroadcasterInterface, + D::Target: ChangeDestinationSourceSync, + E::Target: FeeEstimator, + F::Target: Filter + Sync + Send, + K::Target: KVStore, + L::Target: Logger, + O::Target: OutputSpender, +{ + fn transactions_confirmed( + &self, header: &Header, txdata: &chain::transaction::TransactionData, height: u32, + ) { + self.sweeper.transactions_confirmed(header, txdata, height) + } + + fn transaction_unconfirmed(&self, txid: &Txid) { + self.sweeper.transaction_unconfirmed(txid) + } + + fn best_block_updated(&self, header: &Header, height: u32) { + self.sweeper.best_block_updated(header, height); + } + + fn get_relevant_txids(&self) -> Vec<(Txid, u32, Option)> { + self.sweeper.get_relevant_txids() + } +}