From 477e4106ebe59a33dcd5d9584dfbbcc68b5931c2 Mon Sep 17 00:00:00 2001 From: SparrowLii Date: Wed, 22 Mar 2023 18:07:03 +0800 Subject: [PATCH] refactor `WorkerLocal` --- compiler/rustc_data_structures/Cargo.toml | 4 +- compiler/rustc_data_structures/src/sync.rs | 75 +++++++++++++--------- 2 files changed, 48 insertions(+), 31 deletions(-) diff --git a/compiler/rustc_data_structures/Cargo.toml b/compiler/rustc_data_structures/Cargo.toml index 29cb2c0a33e6..f2ddc8a82fed 100644 --- a/compiler/rustc_data_structures/Cargo.toml +++ b/compiler/rustc_data_structures/Cargo.toml @@ -14,7 +14,7 @@ indexmap = { version = "1.9.1" } jobserver_crate = { version = "0.1.13", package = "jobserver" } libc = "0.2" measureme = "10.0.0" -rayon-core = { version = "0.4.0", package = "rustc-rayon-core", optional = true } +rayon-core = { version = "0.4.0", package = "rustc-rayon-core" } rayon = { version = "0.4.0", package = "rustc-rayon", optional = true } rustc_graphviz = { path = "../rustc_graphviz" } rustc-hash = "1.1.0" @@ -43,4 +43,4 @@ winapi = { version = "0.3", features = ["fileapi", "psapi", "winerror"] } memmap2 = "0.2.1" [features] -rustc_use_parallel_compiler = ["indexmap/rustc-rayon", "rayon", "rayon-core"] +rustc_use_parallel_compiler = ["indexmap/rustc-rayon", "rayon"] diff --git a/compiler/rustc_data_structures/src/sync.rs b/compiler/rustc_data_structures/src/sync.rs index 31323c21df00..91617f825592 100644 --- a/compiler/rustc_data_structures/src/sync.rs +++ b/compiler/rustc_data_structures/src/sync.rs @@ -20,6 +20,7 @@ use crate::owning_ref::{Erased, OwningRef}; use std::collections::HashMap; use std::hash::{BuildHasher, Hash}; +use std::mem::MaybeUninit; use std::ops::{Deref, DerefMut}; use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe}; @@ -30,6 +31,8 @@ pub use vec::AppendOnlyVec; mod vec; +static PARALLEL: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false); + cfg_if! { if #[cfg(not(parallel_compiler))] { pub auto trait Send {} @@ -182,33 +185,6 @@ cfg_if! { use std::cell::Cell; - #[derive(Debug)] - pub struct WorkerLocal(OneThread); - - impl WorkerLocal { - /// Creates a new worker local where the `initial` closure computes the - /// value this worker local should take for each thread in the thread pool. - #[inline] - pub fn new T>(mut f: F) -> WorkerLocal { - WorkerLocal(OneThread::new(f(0))) - } - - /// Returns the worker-local value for each thread - #[inline] - pub fn into_inner(self) -> Vec { - vec![OneThread::into_inner(self.0)] - } - } - - impl Deref for WorkerLocal { - type Target = T; - - #[inline(always)] - fn deref(&self) -> &T { - &self.0 - } - } - pub type MTRef<'a, T> = &'a mut T; #[derive(Debug, Default)] @@ -328,8 +304,6 @@ cfg_if! { }; } - pub use rayon_core::WorkerLocal; - pub use rayon::iter::ParallelIterator; use rayon::iter::IntoParallelIterator; @@ -364,6 +338,49 @@ cfg_if! { } } +#[derive(Debug)] +pub struct WorkerLocal { + single_thread: bool, + inner: T, + mt_inner: Option>, +} + +impl WorkerLocal { + /// Creates a new worker local where the `initial` closure computes the + /// value this worker local should take for each thread in the thread pool. + #[inline] + pub fn new T>(mut f: F) -> WorkerLocal { + if !PARALLEL.load(Ordering::Relaxed) { + WorkerLocal { single_thread: true, inner: f(0), mt_inner: None } + } else { + // Safety: `inner` would never be accessed when multiple threads + WorkerLocal { + single_thread: false, + inner: unsafe { MaybeUninit::uninit().assume_init() }, + mt_inner: Some(rayon_core::WorkerLocal::new(f)), + } + } + } + + /// Returns the worker-local value for each thread + #[inline] + pub fn into_inner(self) -> Vec { + if self.single_thread { vec![self.inner] } else { self.mt_inner.unwrap().into_inner() } + } +} + +impl Deref for WorkerLocal { + type Target = T; + + #[inline(always)] + fn deref(&self) -> &T { + if self.single_thread { &self.inner } else { self.mt_inner.as_ref().unwrap().deref() } + } +} + +// Just for speed test +unsafe impl std::marker::Sync for WorkerLocal {} + pub fn assert_sync() {} pub fn assert_send() {} pub fn assert_send_val(_t: &T) {}