Add JoinHandle::into_join_future(). #131389

Open
wants to merge 4 commits into base: master
Changes from 2 commits
1 change: 1 addition & 0 deletions library/std/src/lib.rs
@@ -348,6 +348,7 @@
#![feature(lazy_get)]
#![feature(maybe_uninit_slice)]
#![feature(maybe_uninit_write_slice)]
#![feature(noop_waker)]
#![feature(panic_can_unwind)]
#![feature(panic_internals)]
#![feature(pin_coerce_unsized_trait)]
197 changes: 182 additions & 15 deletions library/std/src/thread/mod.rs
@@ -164,17 +164,18 @@ use core::mem::MaybeUninit;

use crate::any::Any;
use crate::cell::UnsafeCell;
use crate::future::Future;
use crate::marker::PhantomData;
use crate::mem::{self, ManuallyDrop, forget};
use crate::num::NonZero;
use crate::pin::Pin;
use crate::sync::Arc;
use crate::sync::atomic::{AtomicUsize, Ordering};
use crate::sync::{Arc, Mutex, PoisonError};
use crate::sys::sync::Parker;
use crate::sys::thread as imp;
use crate::sys_common::{AsInner, IntoInner};
use crate::time::{Duration, Instant};
use crate::{env, fmt, io, panic, panicking, str};
use crate::{env, fmt, io, panic, panicking, str, task};

#[stable(feature = "scoped_threads", since = "1.63.0")]
mod scoped;
@@ -490,6 +491,7 @@ impl Builder {
let my_packet: Arc<Packet<'scope, T>> = Arc::new(Packet {
scope: scope_data,
result: UnsafeCell::new(None),
waker: Mutex::new(task::Waker::noop().clone()),
_marker: PhantomData,
});
let their_packet = my_packet.clone();
@@ -540,15 +542,35 @@ impl Builder {
let try_result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
crate::sys::backtrace::__rust_begin_short_backtrace(f)
}));

// Store the `Result` of the thread so that the `JoinHandle` can retrieve it.
//
// SAFETY: `their_packet` has been built just above and moved by the
// closure (it is an `Arc<...>`), and `my_packet` will be stored in the
// same `JoinInner` as this closure, meaning the mutation will be
// safe (the write does not race with any reader elsewhere).
unsafe { *their_packet.result.get() = Some(try_result) };
// Here `their_packet` gets dropped, and if this is the last `Arc` for that packet that
// will call `decrement_num_running_threads` and therefore signal that this thread is
// done.

// Fetch the `Waker` from the packet; this is needed to support `.into_join_future()`.
// If unused, this just returns `Waker::noop()` which will do nothing.
let waker: task::Waker = {
let placeholder = task::Waker::noop().clone();
let mut guard = their_packet.waker.lock().unwrap_or_else(PoisonError::into_inner);
mem::replace(&mut *guard, placeholder)
};

// Here `their_packet` gets dropped, and if this is the last `Arc` for that packet
Contributor:

I was just randomly looking at this patch and found a potential issue.

Don't you have a race here?
What if the Future is polled right now?
We already took the waker (which was still the noop when we cloned it) and released the Mutex on the waker.
Now the Future sets the waker and then checks again whether it is finished. But since we haven't decremented the ref-count yet (we do that only in the next line), the Future will return Pending and nobody is ever going to wake it.

Since you are using a Mutex for the synchronization of the Waker, I suggest also using this lock to communicate to the future that we are finished and that the result can be taken out even if the ref-count is still 2.
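
A minimal sketch of that suggested synchronization (hypothetical: the `WakeState`, `finish`, and `poll_finished` names are illustrative, not from the PR; it reuses `Waker::noop()` from the `noop_waker` feature this PR already enables) keeps the completion flag and the waker under one lock:

```rust
use std::mem;
use std::sync::{Mutex, PoisonError};
use std::task::{Context, Poll, Waker};

/// Completion flag and waker guarded by the same lock.
struct WakeState {
    finished: bool,
    waker: Waker,
}

struct Shared(Mutex<WakeState>);

impl Shared {
    /// Called by the spawned thread once its result has been stored.
    fn finish(&self) {
        let waker = {
            let mut st = self.0.lock().unwrap_or_else(PoisonError::into_inner);
            st.finished = true;
            mem::replace(&mut st.waker, Waker::noop().clone())
        };
        // Wake outside the lock so the wakeup cannot deadlock against `poll`.
        waker.wake();
    }

    /// Called from the future's `poll`.
    fn poll_finished(&self, cx: &mut Context<'_>) -> Poll<()> {
        let mut st = self.0.lock().unwrap_or_else(PoisonError::into_inner);
        if st.finished {
            Poll::Ready(())
        } else {
            // Registering the waker under the same lock closes the race:
            // either we see `finished == true` here, or `finish()` sees
            // (and wakes) the waker we store now.
            st.waker = cx.waker().clone();
            Poll::Pending
        }
    }
}
```

Either `poll_finished` observes `finished == true`, or `finish()` observes the waker that `poll_finished` just stored, so no wakeup can be lost.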

// (which happens if the `JoinHandle` has been dropped) that will call
// `decrement_num_running_threads` and therefore signal to the scope (if there is one)
// that this thread is done.
drop(their_packet);

// Now that we have become visibly “finished” by dropping the packet
// (`JoinInner::is_finished` will return true), we can use the `Waker` to signal
// any waiting `JoinFuture`. If instead we are being waited for by
// `JoinHandle::join()`, the actual platform thread termination will be the wakeup.
waker.wake();

// Here, the lifetime `'scope` can end. `main` keeps running for a bit
// after that before returning itself.
};
@@ -1192,8 +1214,6 @@ impl ThreadId {
}
}
} else {
use crate::sync::{Mutex, PoisonError};

static COUNTER: Mutex<u64> = Mutex::new(0);

let mut counter = COUNTER.lock().unwrap_or_else(PoisonError::into_inner);
@@ -1635,16 +1655,30 @@ impl fmt::Debug for Thread {
#[stable(feature = "rust1", since = "1.0.0")]
pub type Result<T> = crate::result::Result<T, Box<dyn Any + Send + 'static>>;

// This packet is used to communicate the return value between the spawned
// thread and the rest of the program. It is shared through an `Arc` and
// there's no need for a mutex here because synchronization happens with `join()`
// (the caller will never read this packet until the thread has exited).
//
// An Arc to the packet is stored into a `JoinInner` which in turns is placed
// in `JoinHandle`.
/// This packet is used to communicate the return value between the spawned
/// thread and the rest of the program. It is shared through an [`Arc`].
///
/// An Arc to the packet is stored into a [`JoinInner`] which in turn is placed
/// in [`JoinHandle`] or [`ScopedJoinHandle`].
struct Packet<'scope, T> {
/// Communication with the enclosing thread scope if there is one.
scope: Option<Arc<scoped::ScopeData>>,

/// Holds the return value.
///
/// Synchronization happens via reference counting: as long as the `Arc<Packet>`
/// has two or more references, this field is never read, and will only be written
/// once as the thread terminates. After that happens, either the packet is dropped,
/// or [`JoinInner::join()`] will `take()` the result value from here.
result: UnsafeCell<Option<Result<T>>>,

/// If a [`JoinFuture`] for this thread exists and has been polled,
/// this is the waker from that poll. If it does not exist or has not
/// been polled yet, this is [`task::Waker::noop()`].
// FIXME: This should be an `AtomicWaker` instead of a `Mutex`,
// to be cheaper and impossible to deadlock.
waker: Mutex<task::Waker>,

_marker: PhantomData<Option<&'scope scoped::ScopeData>>,
}

@@ -1698,6 +1732,10 @@ impl<'scope, T> JoinInner<'scope, T> {
self.native.join();
Arc::get_mut(&mut self.packet).unwrap().result.get_mut().take().unwrap()
}

fn is_finished(&self) -> bool {
Arc::strong_count(&self.packet) == 1
}
}

/// An owned permission to join on a thread (block on its termination).
@@ -1844,6 +1882,50 @@ impl<T> JoinHandle<T> {
self.0.join()
}

/// Returns a [`Future`] that resolves when the thread has finished.
///
/// Its [output](Future::Output) value is identical to that of [`JoinHandle::join()`];
/// this is the approximate `async` equivalent of that blocking function.
///
/// # Details
///
/// * If the returned future is dropped (cancelled), the thread will become *detached*;
/// there will be no way to observe or wait for the thread’s termination.
/// This is identical to the behavior of `JoinHandle` itself.
///
/// * Unlike [`JoinHandle::join()`], the thread may still exist when the future resolves.
/// In particular, it may still be executing destructors for thread-local values.
///
/// # Example
///
// FIXME: ideally we would actually run this example, with the help of a trivial async executor
/// ```no_run
/// #![feature(thread_join_future)]
/// use std::thread;
///
/// async fn do_some_heavy_tasks_in_parallel() -> thread::Result<()> {
/// let future_1 = thread::spawn(|| {
/// // ... do something ...
/// }).into_join_future();
/// let future_2 = thread::spawn(|| {
/// // ... do something else ...
/// }).into_join_future();
///
/// // Both threads have been started; now await the completion of both.
/// future_1.await?;
/// future_2.await?;
/// Ok(())
/// }
/// ```
#[unstable(feature = "thread_join_future", issue = "none")]
pub fn into_join_future(self) -> JoinFuture<'static, T> {
// The method is not named `into_future()` to avoid overlapping with the stable
// `IntoFuture::into_future()`. We're not implementing `IntoFuture` in order to
// keep this unstable and preserve the *option* of compatibly making this obey structured
// concurrency via an async-Drop that waits for the thread to end.
JoinFuture::new(self.0)
}

/// Checks if the associated thread has finished running its main function.
///
/// `is_finished` supports implementing a non-blocking join operation, by checking
@@ -1856,7 +1938,7 @@ impl<T> JoinHandle<T> {
/// to return quickly, without blocking for any significant amount of time.
#[stable(feature = "thread_is_running", since = "1.61.0")]
pub fn is_finished(&self) -> bool {
Arc::strong_count(&self.0.packet) == 1
self.0.is_finished()
}
}

@@ -1882,9 +1964,94 @@ impl<T> fmt::Debug for JoinHandle<T> {
fn _assert_sync_and_send() {
fn _assert_both<T: Send + Sync>() {}
_assert_both::<JoinHandle<()>>();
_assert_both::<JoinFuture<'static, ()>>();
_assert_both::<Thread>();
}

/// A [`Future`] that resolves when a thread has finished.
///
/// Its [output](Future::Output) value is identical to that of [`JoinHandle::join()`];
/// this is the `async` equivalent of that blocking function.
/// Obtain it by calling [`JoinHandle::into_join_future()`] or
/// [`ScopedJoinHandle::into_join_future()`].
///
/// # Behavior details
///
/// * If a `JoinFuture` is dropped (cancelled), and the thread does not belong to a [scope],
/// the associated thread will become *detached*;
/// there will be no way to observe or wait for the thread’s termination.
///
/// * Unlike [`JoinHandle::join()`], the thread may still exist when the future resolves.
/// In particular, it may still be executing destructors for thread-local values.
Member:

This is not correct. Since JoinInner::join is called in take_result, the thread will be properly joined, resulting in blocking during the execution of the TLS destructors. IMHO that's a very reasonable behaviour as TLS destructors shouldn't be doing anything except freeing resources anyway, but it's not what's described here.

Member:

That's a good point: take_result will actually block here if the thread is marked as finished but still executing TLS destructors. This is arguably the more correct behavior, but maybe not what async users are expecting?

Contributor Author:

I've now updated the documentation to say that it currently blocks on TLS destruction, but is not guaranteed to continue doing that (because that seems an appropriately conservative starting point).
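
As a standalone illustration of the behavior discussed in this thread (my sketch, not part of the PR): on platforms that run TLS destructors, `is_finished()` flips to true as soon as the packet is dropped, which happens before the destructors run, yet `join()` (and therefore `take_result`, which goes through `JoinInner::join`) still waits for the platform thread, TLS destructors included:

```rust
use std::thread;
use std::time::{Duration, Instant};

/// A thread-local whose destructor takes noticeable time.
struct SlowTlsDrop;
impl Drop for SlowTlsDrop {
    fn drop(&mut self) {
        thread::sleep(Duration::from_millis(200));
    }
}

thread_local! {
    static SLOW: SlowTlsDrop = SlowTlsDrop;
}

fn main() {
    let handle = thread::spawn(|| {
        // Force initialization so the destructor will run at thread exit.
        SLOW.with(|_| ());
    });

    // The packet is dropped when the closure returns, i.e. before TLS
    // destructors run, so `is_finished()` can already report true here.
    while !handle.is_finished() {
        thread::yield_now();
    }

    // `join()` still waits for the platform thread to terminate, which
    // includes running the slow TLS destructor above.
    let start = Instant::now();
    handle.join().unwrap();
    println!("join() blocked for ~{:?} on thread exit", start.elapsed());
}
```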

///
#[unstable(feature = "thread_join_future", issue = "none")]
pub struct JoinFuture<'scope, T>(Option<JoinInner<'scope, T>>);

impl<'scope, T> JoinFuture<'scope, T> {
fn new(inner: JoinInner<'scope, T>) -> Self {
Self(Some(inner))
}

/// Implements the “getting a result” part of joining/polling, without blocking or changing
/// the `Waker`. Part of the implementation of `poll()`.
///
/// If this returns `Some`, then `self.0` is now `None` and the future will panic
/// if polled again.
fn take_result(&mut self) -> Option<Result<T>> {
self.0.take_if(|i| i.is_finished()).map(JoinInner::join)
}
}

#[unstable(feature = "thread_join_future", issue = "none")]
impl<T> Future for JoinFuture<'_, T> {
type Output = Result<T>;
fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
if let Some(result) = self.take_result() {
return task::Poll::Ready(result);
}

// Update the `Waker` the thread should wake when it completes.
{
let Some(inner) = &mut self.0 else {
panic!("polled after complete");
};

let new_waker = cx.waker();

// Lock the mutex, and ignore the poison state because there are no meaningful ways
// the existing contents can be corrupted; they will be overwritten completely and the
// overwrite is atomic-in-the-database-sense.
let mut current_waker_guard =
inner.packet.waker.lock().unwrap_or_else(PoisonError::into_inner);

// Overwrite the waker. Note that we are executing the new waker’s clone and the old
// waker’s destructor; these could panic (which will merely poison the lock) or hang,
// which will hold the lock, but the most that can do is prevent the thread from
// exiting because it's trying to acquire `packet.waker`, which it won't do while
// holding any *other* locks (...unless the thread’s data includes a lock guard that
// the waker also wants).
if !new_waker.will_wake(&*current_waker_guard) {
*current_waker_guard = new_waker.clone();
}
}

// Check for completion again in case the thread finished while we were busy
// setting the waker, to prevent a lost wakeup in that case.
if let Some(result) = self.take_result() {
task::Poll::Ready(result)
} else {
task::Poll::Pending
}
}
}

#[unstable(feature = "thread_join_future", issue = "none")]
impl<T> fmt::Debug for JoinFuture<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("JoinHandle").finish_non_exhaustive()
}
}

/// Returns an estimate of the default amount of parallelism a program should use.
///
/// Parallelism is a resource. A given machine provides a certain capacity for
28 changes: 27 additions & 1 deletion library/std/src/thread/scoped.rs
@@ -313,6 +313,32 @@ impl<'scope, T> ScopedJoinHandle<'scope, T> {
self.0.join()
}

/// Returns a [`Future`] that resolves when the thread has finished.
///
/// Its [output] value is identical to that of [`ScopedJoinHandle::join()`];
/// this is the `async` equivalent of that blocking function.
///
/// # Behavior details
///
/// * Unlike [`JoinHandle::join()`], the thread may still exist when the future resolves.
/// In particular, it may still be executing destructors for thread-local values.
///
/// * While this function allows waiting for a scoped thread from `async`
/// functions, the original [`scope()`] is still a blocking function which should
/// not be used in `async` functions.
///
/// [`Future`]: crate::future::Future
/// [output]: crate::future::Future::Output
/// [`JoinHandle::join()`]: super::JoinHandle::join()
#[unstable(feature = "thread_join_future", issue = "none")]
pub fn into_join_future(self) -> super::JoinFuture<'scope, T> {
// There is no `ScopedJoinFuture` because the only difference between `JoinHandle`
// and `ScopedJoinHandle` is that `JoinHandle` has no lifetime parameter, because
// it was introduced before scoped threads. `JoinFuture` is new enough that we don’t
// need to make two versions of it.
super::JoinFuture::new(self.0)
}

/// Checks if the associated thread has finished running its main function.
///
/// `is_finished` supports implementing a non-blocking join operation, by checking
@@ -325,7 +351,7 @@ impl<'scope, T> ScopedJoinHandle<'scope, T> {
/// to return quickly, without blocking for any significant amount of time.
#[stable(feature = "scoped_threads", since = "1.63.0")]
pub fn is_finished(&self) -> bool {
Arc::strong_count(&self.0.packet) == 1
self.0.is_finished()
}
}

48 changes: 47 additions & 1 deletion library/std/src/thread/tests.rs
@@ -1,12 +1,14 @@
use super::Builder;
use crate::any::Any;
use crate::assert_matches::assert_matches;
use crate::future::Future as _;
use crate::panic::panic_any;
use crate::sync::atomic::{AtomicBool, Ordering};
use crate::sync::mpsc::{Sender, channel};
use crate::sync::{Arc, Barrier};
use crate::thread::{self, Scope, ThreadId};
use crate::time::{Duration, Instant};
use crate::{mem, result};
use crate::{mem, result, task};

// !!! These tests are dangerous. If something is buggy, they will hang, !!!
// !!! instead of exiting cleanly. This might wedge the buildbots. !!!
@@ -410,3 +412,47 @@ fn test_minimal_thread_stack() {
assert_eq!(before, 0);
assert_eq!(COUNT.load(Ordering::Relaxed), 1);
}

fn join_future_test(scoped: bool) {
/// Simple `Waker` implementation.
/// If `std` ever gains a `block_on()`, we can consider replacing this with that.
struct MyWaker(Sender<()>);
impl task::Wake for MyWaker {
fn wake(self: Arc<Self>) {
_ = self.0.send(());
}
}

// Communication setup.
let (thread_delay_tx, thread_delay_rx) = channel();
let (waker_tx, waker_rx) = channel();
let waker = task::Waker::from(Arc::new(MyWaker(waker_tx)));
let ctx = &mut task::Context::from_waker(&waker);

thread::scope(|s| {
// Create the thread and the future under test
let thread_body = move || {
thread_delay_rx.recv().unwrap();
"hello"
};
let mut future = crate::pin::pin!(if scoped {
s.spawn(thread_body).into_join_future()
} else {
thread::spawn(thread_body).into_join_future()
});

// Actual test
assert_matches!(future.as_mut().poll(ctx), task::Poll::Pending);
thread_delay_tx.send(()).unwrap(); // Unblock the thread
waker_rx.recv().unwrap(); // Wait for waking (as an executor would)
assert_matches!(future.as_mut().poll(ctx), task::Poll::Ready(Ok("hello")));
});
}
#[test]
fn join_future_unscoped() {
join_future_test(false)
}
#[test]
fn join_future_scoped() {
join_future_test(true)
}
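
On the FIXME above about running the doc example with a trivial async executor, and the test's note that a future `std::block_on()` could replace `MyWaker`: a minimal `block_on` built on `thread::park`/`unpark` is enough to drive `JoinFuture`. The following is a sketch under that assumption, not anything this PR adds; the `main` additionally assumes a toolchain that includes this PR's unstable `thread_join_future` feature.

```rust
#![feature(thread_join_future)] // assumes a toolchain that includes this PR

use std::future::Future;
use std::pin::pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};
use std::thread::{self, Thread};

/// A waker that unparks the thread blocked inside `block_on`.
struct ThreadWaker(Thread);

impl Wake for ThreadWaker {
    fn wake(self: Arc<Self>) {
        self.0.unpark();
    }
}

/// Polls `future` to completion, parking the current thread between polls.
fn block_on<F: Future>(future: F) -> F::Output {
    let mut future = pin!(future);
    let waker = Waker::from(Arc::new(ThreadWaker(thread::current())));
    let mut cx = Context::from_waker(&waker);
    loop {
        match future.as_mut().poll(&mut cx) {
            Poll::Ready(output) => return output,
            // Spurious unparks are harmless: we simply poll again.
            Poll::Pending => thread::park(),
        }
    }
}

fn main() {
    let sum = block_on(async {
        let future = thread::spawn(|| (1u64..=10).sum::<u64>()).into_join_future();
        future.await.unwrap()
    });
    assert_eq!(sum, 55);
}
```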