diff --git a/crates/polars-stream/src/async_primitives/connector.rs b/crates/polars-stream/src/async_primitives/connector.rs index f36098ec60ec..f35d43af5c38 100644 --- a/crates/polars-stream/src/async_primitives/connector.rs +++ b/crates/polars-stream/src/async_primitives/connector.rs @@ -22,7 +22,9 @@ pub fn connector() -> (Sender, Receiver) { /* For UnsafeCell safety, a sender may only set the FULL_BIT (giving exclusive access to value to the receiver), and a receiver may only unset the FULL_BIT - (giving exclusive access back to the sender). + (giving exclusive access back to the sender). Setting/clearing the FULL_BIT + must be done with a Release ordering, and before reading/writing the value + the FULL_BIT must be checked with an Acquire ordering. The exception is when the closed bit is set, at that point the unclosed end has full exclusive access. @@ -66,14 +68,14 @@ pub enum RecvError { impl Connector { unsafe fn poll_send(&self, value: &mut Option, waker: &Waker) -> Poll> { if let Some(v) = value.take() { - let mut state = self.state.load(Ordering::Relaxed); + let mut state = self.state.load(Ordering::Acquire); if state & FULL_BIT == FULL_BIT { self.send_waker.register(waker); let (Ok(s) | Err(s)) = self.state.compare_exchange( state, state | WAITING_BIT, - Ordering::Release, Ordering::Relaxed, + Ordering::Acquire, // Receiver updated, re-acquire. ); state = s; } @@ -101,11 +103,13 @@ impl Connector { unsafe { self.value.get().write(MaybeUninit::new(value)); - let state = self.state.swap(FULL_BIT, Ordering::AcqRel); + let state = self.state.swap(FULL_BIT, Ordering::Release); if state & WAITING_BIT == WAITING_BIT { self.recv_waker.wake(); } if state & CLOSED_BIT == CLOSED_BIT { + // SAFETY: no synchronization needed, we are the only one left. + // Restore the closed bit we just overwrote. self.state.store(CLOSED_BIT, Ordering::Relaxed); return Err(SendError::Closed(self.value.get().read().assume_init())); } @@ -121,8 +125,8 @@ impl Connector { let (Ok(s) | Err(s)) = self.state.compare_exchange( state, state | WAITING_BIT, - Ordering::Release, - Ordering::Acquire, + Ordering::Relaxed, + Ordering::Acquire, // Sender updated, re-acquire. ); state = s; } @@ -138,11 +142,12 @@ impl Connector { if state & FULL_BIT == FULL_BIT { unsafe { let ret = self.value.get().read().assume_init(); - let state = self.state.swap(0, Ordering::Acquire); + let state = self.state.swap(0, Ordering::Release); if state & WAITING_BIT == WAITING_BIT { self.send_waker.wake(); } if state & CLOSED_BIT == CLOSED_BIT { + // Restore the closed bit we just overwrote. self.state.store(CLOSED_BIT, Ordering::Relaxed); } return Ok(ret); @@ -159,7 +164,7 @@ impl Connector { } unsafe fn try_send(&self, value: T) -> Result<(), SendError> { - self.try_send_impl(value, self.state.load(Ordering::Relaxed)) + self.try_send_impl(value, self.state.load(Ordering::Acquire)) } unsafe fn try_recv(&self) -> Result { @@ -176,8 +181,8 @@ impl Connector { /// # Safety /// You may not access this connector anymore as a receiver after this call. unsafe fn close_recv(&self) { - self.state.fetch_or(CLOSED_BIT, Ordering::Relaxed); - drop(self.try_recv()); + let state = self.state.fetch_or(CLOSED_BIT, Ordering::Acquire); + drop(self.try_recv_impl(state)); self.send_waker.wake(); } }