Skip to content

Commit 0e24ad5

Browse files
committed
Implement RFC 3151: Scoped threads.
1 parent a45b3ac commit 0e24ad5

File tree

2 files changed

+202
-26
lines changed

2 files changed

+202
-26
lines changed

library/std/src/thread/mod.rs

+70-26
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,12 @@ use crate::time::Duration;
180180
#[macro_use]
181181
mod local;
182182

183+
#[unstable(feature = "scoped_threads", issue = "none")]
184+
mod scoped;
185+
186+
#[unstable(feature = "scoped_threads", issue = "none")]
187+
pub use scoped::{scope, Scope, ScopedJoinHandle};
188+
183189
#[stable(feature = "rust1", since = "1.0.0")]
184190
pub use self::local::{AccessError, LocalKey};
185191

@@ -446,6 +452,20 @@ impl Builder {
446452
F: FnOnce() -> T,
447453
F: Send + 'a,
448454
T: Send + 'a,
455+
{
456+
Ok(JoinHandle(unsafe { self.spawn_unchecked_(f, None) }?))
457+
}
458+
459+
unsafe fn spawn_unchecked_<'a, 'scope, F, T>(
460+
self,
461+
f: F,
462+
scope_data: Option<&'scope scoped::ScopeData>,
463+
) -> io::Result<JoinInner<'scope, T>>
464+
where
465+
F: FnOnce() -> T,
466+
F: Send + 'a,
467+
T: Send + 'a,
468+
'scope: 'a,
449469
{
450470
let Builder { name, stack_size } = self;
451471

@@ -456,7 +476,8 @@ impl Builder {
456476
}));
457477
let their_thread = my_thread.clone();
458478

459-
let my_packet: Arc<UnsafeCell<Option<Result<T>>>> = Arc::new(UnsafeCell::new(None));
479+
let my_packet: Arc<Packet<'scope, T>> =
480+
Arc::new(Packet { scope: scope_data, result: UnsafeCell::new(None) });
460481
let their_packet = my_packet.clone();
461482

462483
let output_capture = crate::io::set_output_capture(None);
@@ -480,10 +501,14 @@ impl Builder {
480501
// closure (it is an Arc<...>) and `my_packet` will be stored in the
481502
// same `JoinInner` as this closure meaning the mutation will be
482503
// safe (not modify it and affect a value far away).
483-
unsafe { *their_packet.get() = Some(try_result) };
504+
unsafe { *their_packet.result.get() = Some(try_result) };
484505
};
485506

486-
Ok(JoinHandle(JoinInner {
507+
if let Some(scope_data) = scope_data {
508+
scope_data.increment_n_running_threads();
509+
}
510+
511+
Ok(JoinInner {
487512
// SAFETY:
488513
//
489514
// `imp::Thread::new` takes a closure with a `'static` lifetime, since it's passed
@@ -506,8 +531,8 @@ impl Builder {
506531
)?
507532
},
508533
thread: my_thread,
509-
packet: Packet(my_packet),
510-
}))
534+
packet: my_packet,
535+
})
511536
}
512537
}
513538

@@ -1239,34 +1264,53 @@ impl fmt::Debug for Thread {
12391264
#[stable(feature = "rust1", since = "1.0.0")]
12401265
pub type Result<T> = crate::result::Result<T, Box<dyn Any + Send + 'static>>;
12411266

1242-
// This packet is used to communicate the return value between the spawned thread
1243-
// and the rest of the program. Memory is shared through the `Arc` within and there's
1244-
// no need for a mutex here because synchronization happens with `join()` (the
1245-
// caller will never read this packet until the thread has exited).
1267+
// This packet is used to communicate the return value between the spawned
1268+
// thread and the rest of the program. It is shared through an `Arc` and
1269+
// there's no need for a mutex here because synchronization happens with `join()`
1270+
// (the caller will never read this packet until the thread has exited).
12461271
//
1247-
// This packet itself is then stored into a `JoinInner` which in turns is placed
1248-
// in `JoinHandle` and `JoinGuard`. Due to the usage of `UnsafeCell` we need to
1249-
// manually worry about impls like Send and Sync. The type `T` should
1250-
// already always be Send (otherwise the thread could not have been created) and
1251-
// this type is inherently Sync because no methods take &self. Regardless,
1252-
// however, we add inheriting impls for Send/Sync to this type to ensure it's
1253-
// Send/Sync and that future modifications will still appropriately classify it.
1254-
struct Packet<T>(Arc<UnsafeCell<Option<Result<T>>>>);
1255-
1256-
unsafe impl<T: Send> Send for Packet<T> {}
1257-
unsafe impl<T: Sync> Sync for Packet<T> {}
1272+
// An Arc to the packet is stored into a `JoinInner` which in turns is placed
1273+
// in `JoinHandle`. Due to the usage of `UnsafeCell` we need to manually worry
1274+
// about impls like Send and Sync. The type `T` should already always be Send
1275+
// (otherwise the thread could not have been created) and this type is
1276+
// inherently Sync because no methods take &self. Regardless, however, we add
1277+
// inheriting impls for Send/Sync to this type to ensure it's Send/Sync and
1278+
// that future modifications will still appropriately classify it.
1279+
struct Packet<'scope, T> {
1280+
scope: Option<&'scope scoped::ScopeData>,
1281+
result: UnsafeCell<Option<Result<T>>>,
1282+
}
1283+
1284+
unsafe impl<'scope, T: Send> Send for Packet<'scope, T> {}
1285+
unsafe impl<'scope, T: Sync> Sync for Packet<'scope, T> {}
1286+
1287+
impl<'scope, T> Drop for Packet<'scope, T> {
1288+
fn drop(&mut self) {
1289+
if let Some(scope) = self.scope {
1290+
// If this packet was for a thread that ran in a scope, the thread
1291+
// panicked, and nobody consumed the panic payload, we put the
1292+
// panic payload in the scope so it can re-throw it, if it didn't
1293+
// already capture any panic yet.
1294+
if let Some(Err(e)) = self.result.get_mut().take() {
1295+
scope.panic_payload.lock().unwrap().get_or_insert(e);
1296+
}
1297+
// Book-keeping so the scope knows when it's done.
1298+
scope.decrement_n_running_threads();
1299+
}
1300+
}
1301+
}
12581302

12591303
/// Inner representation for JoinHandle
1260-
struct JoinInner<T> {
1304+
struct JoinInner<'scope, T> {
12611305
native: imp::Thread,
12621306
thread: Thread,
1263-
packet: Packet<T>,
1307+
packet: Arc<Packet<'scope, T>>,
12641308
}
12651309

1266-
impl<T> JoinInner<T> {
1310+
impl<'scope, T> JoinInner<'scope, T> {
12671311
fn join(mut self) -> Result<T> {
12681312
self.native.join();
1269-
Arc::get_mut(&mut self.packet.0).unwrap().get_mut().take().unwrap()
1313+
Arc::get_mut(&mut self.packet).unwrap().result.get_mut().take().unwrap()
12701314
}
12711315
}
12721316

@@ -1333,7 +1377,7 @@ impl<T> JoinInner<T> {
13331377
/// [`thread::Builder::spawn`]: Builder::spawn
13341378
/// [`thread::spawn`]: spawn
13351379
#[stable(feature = "rust1", since = "1.0.0")]
1336-
pub struct JoinHandle<T>(JoinInner<T>);
1380+
pub struct JoinHandle<T>(JoinInner<'static, T>);
13371381

13381382
#[stable(feature = "joinhandle_impl_send_sync", since = "1.29.0")]
13391383
unsafe impl<T> Send for JoinHandle<T> {}
@@ -1407,7 +1451,7 @@ impl<T> JoinHandle<T> {
14071451
/// function has returned, but before the thread itself has stopped running.
14081452
#[unstable(feature = "thread_is_running", issue = "90470")]
14091453
pub fn is_running(&self) -> bool {
1410-
Arc::strong_count(&self.0.packet.0) > 1
1454+
Arc::strong_count(&self.0.packet) > 1
14111455
}
14121456
}
14131457

library/std/src/thread/scoped.rs

+132
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
use super::{current, park, Builder, JoinInner, Result, Thread};
2+
use crate::any::Any;
3+
use crate::fmt;
4+
use crate::io;
5+
use crate::marker::PhantomData;
6+
use crate::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
7+
use crate::sync::atomic::{AtomicUsize, Ordering};
8+
use crate::sync::Mutex;
9+
10+
/// TODO: documentation
11+
pub struct Scope<'env> {
12+
data: ScopeData,
13+
env: PhantomData<&'env ()>,
14+
}
15+
16+
/// TODO: documentation
17+
pub struct ScopedJoinHandle<'scope, T>(JoinInner<'scope, T>);
18+
19+
pub(super) struct ScopeData {
20+
n_running_threads: AtomicUsize,
21+
main_thread: Thread,
22+
pub(super) panic_payload: Mutex<Option<Box<dyn Any + Send>>>,
23+
}
24+
25+
impl ScopeData {
26+
pub(super) fn increment_n_running_threads(&self) {
27+
// We check for 'overflow' with usize::MAX / 2, to make sure there's no
28+
// chance it overflows to 0, which would result in unsoundness.
29+
if self.n_running_threads.fetch_add(1, Ordering::Relaxed) == usize::MAX / 2 {
30+
// This can only reasonably happen by mem::forget()'ing many many ScopedJoinHandles.
31+
self.decrement_n_running_threads();
32+
panic!("too many running threads in thread scope");
33+
}
34+
}
35+
pub(super) fn decrement_n_running_threads(&self) {
36+
if self.n_running_threads.fetch_sub(1, Ordering::Release) == 1 {
37+
self.main_thread.unpark();
38+
}
39+
}
40+
}
41+
42+
/// TODO: documentation
43+
pub fn scope<'env, F, T>(f: F) -> T
44+
where
45+
F: FnOnce(&Scope<'env>) -> T,
46+
{
47+
let mut scope = Scope {
48+
data: ScopeData {
49+
n_running_threads: AtomicUsize::new(0),
50+
main_thread: current(),
51+
panic_payload: Mutex::new(None),
52+
},
53+
env: PhantomData,
54+
};
55+
56+
// Run `f`, but catch panics so we can make sure to wait for all the threads to join.
57+
let result = catch_unwind(AssertUnwindSafe(|| f(&scope)));
58+
59+
// Wait until all the threads are finished.
60+
while scope.data.n_running_threads.load(Ordering::Acquire) != 0 {
61+
park();
62+
}
63+
64+
// Throw any panic from `f` or from any panicked thread, or the return value of `f` otherwise.
65+
match result {
66+
Err(e) => {
67+
// `f` itself panicked.
68+
resume_unwind(e);
69+
}
70+
Ok(result) => {
71+
if let Some(panic_payload) = scope.data.panic_payload.get_mut().unwrap().take() {
72+
// A thread panicked.
73+
resume_unwind(panic_payload);
74+
} else {
75+
// Nothing panicked.
76+
result
77+
}
78+
}
79+
}
80+
}
81+
82+
impl<'env> Scope<'env> {
83+
/// TODO: documentation
84+
pub fn spawn<'scope, F, T>(&'scope self, f: F) -> ScopedJoinHandle<'scope, T>
85+
where
86+
F: FnOnce(&Scope<'env>) -> T + Send + 'env,
87+
T: Send + 'env,
88+
{
89+
Builder::new().spawn_scoped(self, f).expect("failed to spawn thread")
90+
}
91+
}
92+
93+
impl Builder {
94+
fn spawn_scoped<'scope, 'env, F, T>(
95+
self,
96+
scope: &'scope Scope<'env>,
97+
f: F,
98+
) -> io::Result<ScopedJoinHandle<'scope, T>>
99+
where
100+
F: FnOnce(&Scope<'env>) -> T + Send + 'env,
101+
T: Send + 'env,
102+
{
103+
Ok(ScopedJoinHandle(unsafe { self.spawn_unchecked_(|| f(scope), Some(&scope.data)) }?))
104+
}
105+
}
106+
107+
impl<'scope, T> ScopedJoinHandle<'scope, T> {
108+
/// TODO
109+
pub fn join(self) -> Result<T> {
110+
self.0.join()
111+
}
112+
113+
/// TODO
114+
pub fn thread(&self) -> &Thread {
115+
&self.0.thread
116+
}
117+
}
118+
119+
impl<'env> fmt::Debug for Scope<'env> {
120+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
121+
f.debug_struct("Scope")
122+
.field("n_running_threads", &self.data.n_running_threads.load(Ordering::Relaxed))
123+
.field("panic_payload", &self.data.panic_payload)
124+
.finish_non_exhaustive()
125+
}
126+
}
127+
128+
impl<'scope, T> fmt::Debug for ScopedJoinHandle<'scope, T> {
129+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130+
f.debug_struct("ScopedJoinHandle").finish_non_exhaustive()
131+
}
132+
}

0 commit comments

Comments
 (0)