From b87f210dbdd90e5f65caefac1eeb053b0f0f612e Mon Sep 17 00:00:00 2001 From: Marco Neumann Date: Thu, 25 Apr 2024 14:31:55 +0200 Subject: [PATCH] fix: reduce lock contention in distributor channels (#10026) * fix: lock contention in distributor channels Reduce lock contention in distributor channels via: - use atomic counters instead of "counter behind mutex" where appropriate - use less state - only lock when needed - move "wake" operation out of lock scopes (they are eventual operations anyways and many wake operations results in "futex wake" operations -- i.e. a syscall -- which you should avoid while holding the lock) * refactor: add more docs and tests for distributor channels --------- Co-authored-by: Andrew Lamb --- .../src/repartition/distributor_channels.rs | 358 ++++++++++++------ 1 file changed, 245 insertions(+), 113 deletions(-) diff --git a/datafusion/physical-plan/src/repartition/distributor_channels.rs b/datafusion/physical-plan/src/repartition/distributor_channels.rs index e71b88467bcc..bad923ce9e82 100644 --- a/datafusion/physical-plan/src/repartition/distributor_channels.rs +++ b/datafusion/physical-plan/src/repartition/distributor_channels.rs @@ -40,8 +40,12 @@ use std::{ collections::VecDeque, future::Future, + ops::DerefMut, pin::Pin, - sync::Arc, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, task::{Context, Poll, Waker}, }; @@ -52,20 +56,12 @@ pub fn channels( n: usize, ) -> (Vec>, Vec>) { let channels = (0..n) - .map(|id| { - Arc::new(Mutex::new(Channel { - data: VecDeque::default(), - n_senders: 1, - recv_alive: true, - recv_wakers: Vec::default(), - id, - })) - }) + .map(|id| Arc::new(Channel::new_with_one_sender(id))) .collect::>(); - let gate = Arc::new(Mutex::new(Gate { - empty_channels: n, - send_wakers: Vec::default(), - })); + let gate = Arc::new(Gate { + empty_channels: AtomicUsize::new(n), + send_wakers: Mutex::new(None), + }); let senders = channels .iter() .map(|channel| DistributionSender { @@ -143,8 +139,7 @@ impl DistributionSender { impl Clone for DistributionSender { fn clone(&self) -> Self { - let mut guard = self.channel.lock(); - guard.n_senders += 1; + self.channel.n_senders.fetch_add(1, Ordering::SeqCst); Self { channel: Arc::clone(&self.channel), @@ -155,19 +150,46 @@ impl Clone for DistributionSender { impl Drop for DistributionSender { fn drop(&mut self) { - let mut guard_channel = self.channel.lock(); - guard_channel.n_senders -= 1; + let n_senders_pre = self.channel.n_senders.fetch_sub(1, Ordering::SeqCst); + // is the the last copy of the sender side? + if n_senders_pre > 1 { + return; + } - if guard_channel.n_senders == 0 { - // Note: the recv_alive check is so that we don't double-clear the status - if guard_channel.data.is_empty() && guard_channel.recv_alive { + let receivers = { + let mut state = self.channel.state.lock(); + + // During the shutdown of a empty channel, both the sender and the receiver side will be dropped. However we + // only want to decrement the "empty channels" counter once. + // + // We are within a critical section here, so we we can safely assume that either the last sender or the + // receiver (there's only one) will be dropped first. + // + // If the last sender is dropped first, `state.data` will still exists and the sender side decrements the + // signal. The receiver side then MUST check the `n_senders` counter during the section and if it is zero, + // it inferres that it is dropped afterwards and MUST NOT decrement the counter. + // + // If the receiver end is dropped first, it will inferr -- based on `n_senders` -- that there are still + // senders and it will decrement the `empty_channels` counter. It will also set `data` to `None`. The sender + // side will then see that `data` is `None` and can therefore inferr that the receiver end was dropped, and + // hence it MUST NOT decrement the `empty_channels` counter. + if state + .data + .as_ref() + .map(|data| data.is_empty()) + .unwrap_or_default() + { // channel is gone, so we need to clear our signal - let mut guard_gate = self.gate.lock(); - guard_gate.empty_channels -= 1; + self.gate.decr_empty_channels(); } - // receiver may be waiting for data, but should return `None` now since the channel is closed - guard_channel.wake_receivers(); + // make sure that nobody can add wakers anymore + state.recv_wakers.take().expect("not closed yet") + }; + + // wake outside of lock scope + for recv in receivers { + recv.wake(); } } } @@ -188,33 +210,41 @@ impl<'a, T> Future for SendFuture<'a, T> { let this = &mut *self; assert!(this.element.is_some(), "polled ready future"); - let mut guard_channel = this.channel.lock(); - - // receiver end still alive? - if !guard_channel.recv_alive { - return Poll::Ready(Err(SendError( - this.element.take().expect("just checked"), - ))); - } + // lock scope + let to_wake = { + let mut guard_channel_state = this.channel.state.lock(); + + let Some(data) = guard_channel_state.data.as_mut() else { + // receiver end dead + return Poll::Ready(Err(SendError( + this.element.take().expect("just checked"), + ))); + }; + + // does ANY receiver need data? + // if so, allow sender to create another + if this.gate.empty_channels.load(Ordering::SeqCst) == 0 { + let mut guard = this.gate.send_wakers.lock(); + if let Some(send_wakers) = guard.deref_mut() { + send_wakers.push((cx.waker().clone(), this.channel.id)); + return Poll::Pending; + } + } - let mut guard_gate = this.gate.lock(); + let was_empty = data.is_empty(); + data.push_back(this.element.take().expect("just checked")); - // does ANY receiver need data? - // if so, allow sender to create another - if guard_gate.empty_channels == 0 { - guard_gate - .send_wakers - .push((cx.waker().clone(), guard_channel.id)); - return Poll::Pending; - } + if was_empty { + this.gate.decr_empty_channels(); + guard_channel_state.take_recv_wakers() + } else { + Vec::with_capacity(0) + } + }; - let was_empty = guard_channel.data.is_empty(); - guard_channel - .data - .push_back(this.element.take().expect("just checked")); - if was_empty { - guard_gate.empty_channels -= 1; - guard_channel.wake_receivers(); + // wake outside of lock scope + for receiver in to_wake { + receiver.wake(); } Poll::Ready(Ok(())) @@ -243,21 +273,18 @@ impl DistributionReceiver { impl Drop for DistributionReceiver { fn drop(&mut self) { - let mut guard_channel = self.channel.lock(); - let mut guard_gate = self.gate.lock(); - guard_channel.recv_alive = false; + let mut guard_channel_state = self.channel.state.lock(); + let data = guard_channel_state.data.take().expect("not dropped yet"); - // Note: n_senders check is here so we don't double-clear the signal - if guard_channel.data.is_empty() && (guard_channel.n_senders > 0) { + // See `DistributedSender::drop` for an explanation of the drop order and when the "empty channels" counter is + // decremented. + if data.is_empty() && (self.channel.n_senders.load(Ordering::SeqCst) > 0) { // channel is gone, so we need to clear our signal - guard_gate.empty_channels -= 1; + self.gate.decr_empty_channels(); } // senders may be waiting for gate to open but should error now that the channel is closed - guard_gate.wake_channel_senders(guard_channel.id); - - // clear potential remaining data from channel - guard_channel.data.clear(); + self.gate.wake_channel_senders(self.channel.id); } } @@ -275,37 +302,51 @@ impl<'a, T> Future for RecvFuture<'a, T> { let this = &mut *self; assert!(!this.rdy, "polled ready future"); - let mut guard_channel = this.channel.lock(); + let mut guard_channel_state = this.channel.state.lock(); + let channel_state = guard_channel_state.deref_mut(); + let data = channel_state.data.as_mut().expect("not dropped yet"); - match guard_channel.data.pop_front() { + match data.pop_front() { Some(element) => { // change "empty" signal for this channel? - if guard_channel.data.is_empty() && (guard_channel.n_senders > 0) { - let mut guard_gate = this.gate.lock(); - + if data.is_empty() && channel_state.recv_wakers.is_some() { // update counter - let old_counter = guard_gate.empty_channels; - guard_gate.empty_channels += 1; + let old_counter = + this.gate.empty_channels.fetch_add(1, Ordering::SeqCst); // open gate? - if old_counter == 0 { - guard_gate.wake_all_senders(); + let to_wake = if old_counter == 0 { + let mut guard = this.gate.send_wakers.lock(); + + // check after lock to see if we should still change the state + if this.gate.empty_channels.load(Ordering::SeqCst) > 0 { + guard.take().unwrap_or_default() + } else { + Vec::with_capacity(0) + } + } else { + Vec::with_capacity(0) + }; + + drop(guard_channel_state); + + // wake outside of lock scope + for (waker, _channel_id) in to_wake { + waker.wake(); } - - drop(guard_gate); - drop(guard_channel); } this.rdy = true; Poll::Ready(Some(element)) } - None if guard_channel.n_senders == 0 => { - this.rdy = true; - Poll::Ready(None) - } None => { - guard_channel.recv_wakers.push(cx.waker().clone()); - Poll::Pending + if let Some(recv_wakers) = channel_state.recv_wakers.as_mut() { + recv_wakers.push(cx.waker().clone()); + Poll::Pending + } else { + this.rdy = true; + Poll::Ready(None) + } } } } @@ -314,78 +355,122 @@ impl<'a, T> Future for RecvFuture<'a, T> { /// Links senders and receivers. #[derive(Debug)] struct Channel { - /// Buffered data. - data: VecDeque, - /// Reference counter for the sender side. - n_senders: usize, - - /// Reference "counter"/flag for the single receiver. - recv_alive: bool, - - /// Wakers for the receiver side. - /// - /// The receiver will be pending if the [buffer](Self::data) is empty and - /// there are senders left (according to the [reference counter](Self::n_senders)). - recv_wakers: Vec, + n_senders: AtomicUsize, /// Channel ID. /// /// This is used to address [send wakers](Gate::send_wakers). id: usize, + + /// Mutable state. + state: Mutex>, } impl Channel { - fn wake_receivers(&mut self) { - for waker in self.recv_wakers.drain(..) { - waker.wake(); + /// Create new channel with one sender (so we don't need to [fetch-add](AtomicUsize::fetch_add) directly afterwards). + fn new_with_one_sender(id: usize) -> Self { + Channel { + n_senders: AtomicUsize::new(1), + id, + state: Mutex::new(ChannelState { + data: Some(VecDeque::default()), + recv_wakers: Some(Vec::default()), + }), } } } +#[derive(Debug)] +struct ChannelState { + /// Buffered data. + /// + /// This is [`None`] when the receiver is gone. + data: Option>, + + /// Wakers for the receiver side. + /// + /// The receiver will be pending if the [buffer](Self::data) is empty and + /// there are senders left (otherwise this is set to [`None`]). + recv_wakers: Option>, +} + +impl ChannelState { + /// Get all [`recv_wakers`](Self::recv_wakers) and replace with identically-sized buffer. + /// + /// The wakers should be woken AFTER the lock to [this state](Self) was dropped. + /// + /// # Panics + /// Assumes that channel is NOT closed yet, i.e. that [`recv_wakers`](Self::recv_wakers) is not [`None`]. + fn take_recv_wakers(&mut self) -> Vec { + let to_wake = self.recv_wakers.as_mut().expect("not closed"); + let mut tmp = Vec::with_capacity(to_wake.capacity()); + std::mem::swap(to_wake, &mut tmp); + tmp + } +} + /// Shared channel. /// /// One or multiple senders and a single receiver will share a channel. -type SharedChannel = Arc>>; +type SharedChannel = Arc>; /// The "all channels have data" gate. #[derive(Debug)] struct Gate { /// Number of currently empty (and still open) channels. - empty_channels: usize, + empty_channels: AtomicUsize, /// Wakers for the sender side, including their channel IDs. - send_wakers: Vec<(Waker, usize)>, + /// + /// This is `None` if the there are non-empty channels. + send_wakers: Mutex>>, } impl Gate { - //// Wake all senders. + /// Wake senders for a specific channel. /// - /// This is helpful to signal that there are some channels empty now and hence the gate was opened. - fn wake_all_senders(&mut self) { - for (waker, _id) in self.send_wakers.drain(..) { + /// This is helpful to signal that the receiver side is gone and the senders shall now error. + fn wake_channel_senders(&self, id: usize) { + // lock scope + let to_wake = { + let mut guard = self.send_wakers.lock(); + + if let Some(send_wakers) = guard.deref_mut() { + // `drain_filter` is unstable, so implement our own + let (wake, keep) = + send_wakers.drain(..).partition(|(_waker, id2)| id == *id2); + + *send_wakers = keep; + + wake + } else { + Vec::with_capacity(0) + } + }; + + // wake outside of lock scope + for (waker, _id) in to_wake { waker.wake(); } } - /// Wake senders for a specific channel. - /// - /// This is helpful to signal that the receiver side is gone and the senders shall now error. - fn wake_channel_senders(&mut self, id: usize) { - // `drain_filter` is unstable, so implement our own - let (wake, keep) = self - .send_wakers - .drain(..) - .partition(|(_waker, id2)| id == *id2); - self.send_wakers = keep; - for (waker, _id) in wake { - waker.wake(); + fn decr_empty_channels(&self) { + let old_count = self.empty_channels.fetch_sub(1, Ordering::SeqCst); + + if old_count == 1 { + let mut guard = self.send_wakers.lock(); + + // double-check state during lock + if self.empty_channels.load(Ordering::SeqCst) == 0 && guard.is_none() { + *guard = Some(Vec::new()); + } } } } /// Gate shared by all senders and receivers. -type SharedGate = Arc>; +type SharedGate = Arc; #[cfg(test)] mod tests { @@ -596,6 +681,52 @@ mod tests { assert_eq!(counter.strong_count(), 0); } + /// Ensure that polling "pending" futures work even when you poll them too often (which happens under some circumstances). + #[test] + fn test_poll_empty_channel_twice() { + let (txs, mut rxs) = channels(1); + + let mut recv_fut = rxs[0].recv(); + let waker_1a = poll_pending(&mut recv_fut); + let waker_1b = poll_pending(&mut recv_fut); + + let mut recv_fut = rxs[0].recv(); + let waker_2 = poll_pending(&mut recv_fut); + + poll_ready(&mut txs[0].send("a")).unwrap(); + assert!(waker_1a.woken()); + assert!(waker_1b.woken()); + assert!(waker_2.woken()); + assert_eq!(poll_ready(&mut recv_fut), Some("a"),); + + poll_ready(&mut txs[0].send("b")).unwrap(); + let mut send_fut = txs[0].send("c"); + let waker_3 = poll_pending(&mut send_fut); + assert_eq!(poll_ready(&mut rxs[0].recv()), Some("b"),); + assert!(waker_3.woken()); + poll_ready(&mut send_fut).unwrap(); + assert_eq!(poll_ready(&mut rxs[0].recv()), Some("c")); + + let mut recv_fut = rxs[0].recv(); + let waker_4 = poll_pending(&mut recv_fut); + + let mut recv_fut = rxs[0].recv(); + let waker_5 = poll_pending(&mut recv_fut); + + poll_ready(&mut txs[0].send("d")).unwrap(); + let mut send_fut = txs[0].send("e"); + let waker_6a = poll_pending(&mut send_fut); + let waker_6b = poll_pending(&mut send_fut); + + assert!(waker_4.woken()); + assert!(waker_5.woken()); + assert_eq!(poll_ready(&mut recv_fut), Some("d"),); + + assert!(waker_6a.woken()); + assert!(waker_6b.woken()); + poll_ready(&mut send_fut).unwrap(); + } + #[test] #[should_panic(expected = "polled ready future")] fn test_panic_poll_send_future_after_ready_ok() { @@ -655,6 +786,7 @@ mod tests { poll_pending(&mut fut); } + /// Test [`poll_pending`] (i.e. the testing utils, not the actual library code). #[test] fn test_meta_poll_pending_waker() { let (tx, mut rx) = futures::channel::oneshot::channel();