Skip to content

Commit 8fc2eeb

Browse files
Use newtype to map from Local to GeneratorSavedLocal
1 parent 72417d8 commit 8fc2eeb

File tree

1 file changed

+64
-44
lines changed

1 file changed

+64
-44
lines changed

src/librustc_mir/transform/generator.rs

+64-44
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt};
7272
use rustc_target::abi::VariantIdx;
7373
use rustc_target::spec::PanicStrategy;
7474
use std::borrow::Cow;
75-
use std::iter;
75+
use std::{iter, ops};
7676

7777
pub struct StateTransform;
7878

@@ -417,11 +417,7 @@ fn replace_local<'tcx>(
417417

418418
struct LivenessInfo {
419419
/// Which locals are live across any suspension point.
420-
///
421-
/// GeneratorSavedLocal is indexed in terms of the elements in this set;
422-
/// i.e. GeneratorSavedLocal::new(1) corresponds to the second local
423-
/// included in this set.
424-
live_locals: BitSet<Local>,
420+
saved_locals: GeneratorSavedLocals,
425421

426422
/// The set of saved locals live at each suspension point.
427423
live_locals_at_suspension_points: Vec<BitSet<GeneratorSavedLocal>>,
@@ -524,49 +520,75 @@ fn locals_live_across_suspend_points(
524520
live_locals_at_suspension_points.push(live_locals);
525521
}
526522
}
523+
527524
debug!("live_locals_anywhere = {:?}", live_locals_at_any_suspension_point);
525+
let saved_locals = GeneratorSavedLocals(live_locals_at_any_suspension_point);
528526

529527
// Renumber our liveness_map bitsets to include only the locals we are
530528
// saving.
531529
let live_locals_at_suspension_points = live_locals_at_suspension_points
532530
.iter()
533-
.map(|live_here| renumber_bitset(&live_here, &live_locals_at_any_suspension_point))
531+
.map(|live_here| saved_locals.renumber_bitset(&live_here))
534532
.collect();
535533

536534
let storage_conflicts = compute_storage_conflicts(
537535
body_ref,
538-
&live_locals_at_any_suspension_point,
536+
&saved_locals,
539537
always_live_locals.clone(),
540538
requires_storage_results,
541539
);
542540

543541
LivenessInfo {
544-
live_locals: live_locals_at_any_suspension_point,
542+
saved_locals,
545543
live_locals_at_suspension_points,
546544
storage_conflicts,
547545
storage_liveness: storage_liveness_map,
548546
}
549547
}
550548

551-
/// Renumbers the items present in `stored_locals` and applies the renumbering
552-
/// to 'input`.
549+
/// The set of `Local`s that must be saved across yield points.
553550
///
554-
/// For example, if `stored_locals = [1, 3, 5]`, this would be renumbered to
555-
/// `[0, 1, 2]`. Thus, if `input = [3, 5]` we would return `[1, 2]`.
556-
fn renumber_bitset(
557-
input: &BitSet<Local>,
558-
stored_locals: &BitSet<Local>,
559-
) -> BitSet<GeneratorSavedLocal> {
560-
assert!(stored_locals.superset(&input), "{:?} not a superset of {:?}", stored_locals, input);
561-
let mut out = BitSet::new_empty(stored_locals.count());
562-
for (idx, local) in stored_locals.iter().enumerate() {
563-
let saved_local = GeneratorSavedLocal::from(idx);
564-
if input.contains(local) {
565-
out.insert(saved_local);
551+
/// `GeneratorSavedLocal` is indexed in terms of the elements in this set;
552+
/// i.e. `GeneratorSavedLocal::new(1)` corresponds to the second local
553+
/// included in this set.
554+
struct GeneratorSavedLocals(BitSet<Local>);
555+
556+
impl GeneratorSavedLocals {
557+
/// Returns an iterator over each `GeneratorSavedLocal` along with the `Local` it corresponds
558+
/// to.
559+
fn iter_enumerated(&self) -> impl '_ + Iterator<Item = (GeneratorSavedLocal, Local)> {
560+
self.iter().enumerate().map(|(i, l)| (GeneratorSavedLocal::from(i), l))
561+
}
562+
563+
/// Transforms a `BitSet<Local>` that contains only locals saved across yield points to the
564+
/// equivalent `BitSet<GeneratorSavedLocal>`.
565+
fn renumber_bitset(&self, input: &BitSet<Local>) -> BitSet<GeneratorSavedLocal> {
566+
assert!(self.superset(&input), "{:?} not a superset of {:?}", self.0, input);
567+
let mut out = BitSet::new_empty(self.count());
568+
for (saved_local, local) in self.iter_enumerated() {
569+
if input.contains(local) {
570+
out.insert(saved_local);
571+
}
572+
}
573+
out
574+
}
575+
576+
fn get(&self, local: Local) -> Option<GeneratorSavedLocal> {
577+
if !self.contains(local) {
578+
return None;
566579
}
580+
581+
let idx = self.iter().take_while(|&l| l < local).count();
582+
Some(GeneratorSavedLocal::new(idx))
583+
}
584+
}
585+
586+
impl ops::Deref for GeneratorSavedLocals {
587+
type Target = BitSet<Local>;
588+
589+
fn deref(&self) -> &Self::Target {
590+
&self.0
567591
}
568-
debug!("renumber_bitset({:?}, {:?}) => {:?}", input, stored_locals, out);
569-
out
570592
}
571593

572594
/// For every saved local, looks for which locals are StorageLive at the same
@@ -575,24 +597,24 @@ fn renumber_bitset(
575597
/// computation; see `GeneratorLayout` for more.
576598
fn compute_storage_conflicts(
577599
body: &'mir Body<'tcx>,
578-
stored_locals: &BitSet<Local>,
600+
saved_locals: &GeneratorSavedLocals,
579601
always_live_locals: storage::AlwaysLiveLocals,
580602
requires_storage: dataflow::Results<'tcx, MaybeRequiresStorage<'mir, 'tcx>>,
581603
) -> BitMatrix<GeneratorSavedLocal, GeneratorSavedLocal> {
582-
assert_eq!(body.local_decls.len(), stored_locals.domain_size());
604+
assert_eq!(body.local_decls.len(), saved_locals.domain_size());
583605

584606
debug!("compute_storage_conflicts({:?})", body.span);
585607
debug!("always_live = {:?}", always_live_locals);
586608

587609
// Locals that are always live or ones that need to be stored across
588610
// suspension points are not eligible for overlap.
589611
let mut ineligible_locals = always_live_locals.into_inner();
590-
ineligible_locals.intersect(stored_locals);
612+
ineligible_locals.intersect(saved_locals);
591613

592614
// Compute the storage conflicts for all eligible locals.
593615
let mut visitor = StorageConflictVisitor {
594616
body,
595-
stored_locals: &stored_locals,
617+
saved_locals: &saved_locals,
596618
local_conflicts: BitMatrix::from_row_n(&ineligible_locals, body.local_decls.len()),
597619
};
598620

@@ -609,16 +631,14 @@ fn compute_storage_conflicts(
609631
// However, in practice these bitsets are not usually large. The layout code
610632
// also needs to keep track of how many conflicts each local has, so it's
611633
// simpler to keep it this way for now.
612-
let mut storage_conflicts = BitMatrix::new(stored_locals.count(), stored_locals.count());
613-
for (idx_a, local_a) in stored_locals.iter().enumerate() {
614-
let saved_local_a = GeneratorSavedLocal::new(idx_a);
634+
let mut storage_conflicts = BitMatrix::new(saved_locals.count(), saved_locals.count());
635+
for (saved_local_a, local_a) in saved_locals.iter_enumerated() {
615636
if ineligible_locals.contains(local_a) {
616637
// Conflicts with everything.
617638
storage_conflicts.insert_all_into_row(saved_local_a);
618639
} else {
619640
// Keep overlap information only for stored locals.
620-
for (idx_b, local_b) in stored_locals.iter().enumerate() {
621-
let saved_local_b = GeneratorSavedLocal::new(idx_b);
641+
for (saved_local_b, local_b) in saved_locals.iter_enumerated() {
622642
if local_conflicts.contains(local_a, local_b) {
623643
storage_conflicts.insert(saved_local_a, saved_local_b);
624644
}
@@ -630,7 +650,7 @@ fn compute_storage_conflicts(
630650

631651
struct StorageConflictVisitor<'mir, 'tcx, 's> {
632652
body: &'mir Body<'tcx>,
633-
stored_locals: &'s BitSet<Local>,
653+
saved_locals: &'s GeneratorSavedLocals,
634654
// FIXME(tmandry): Consider using sparse bitsets here once we have good
635655
// benchmarks for generators.
636656
local_conflicts: BitMatrix<Local, Local>,
@@ -666,7 +686,7 @@ impl<'body, 'tcx, 's> StorageConflictVisitor<'body, 'tcx, 's> {
666686
}
667687

668688
let mut eligible_storage_live = flow_state.clone();
669-
eligible_storage_live.intersect(&self.stored_locals);
689+
eligible_storage_live.intersect(&self.saved_locals);
670690

671691
for local in eligible_storage_live.iter() {
672692
self.local_conflicts.union_row_with(&eligible_storage_live, local);
@@ -678,15 +698,15 @@ impl<'body, 'tcx, 's> StorageConflictVisitor<'body, 'tcx, 's> {
678698
}
679699
}
680700

681-
/// Validates the typeck view of the generator against the actual set of types retained between
701+
/// Validates the typeck view of the generator against the actual set of types saved between
682702
/// yield points.
683703
fn sanitize_witness<'tcx>(
684704
tcx: TyCtxt<'tcx>,
685705
body: &Body<'tcx>,
686706
did: DefId,
687707
witness: Ty<'tcx>,
688708
upvars: &Vec<Ty<'tcx>>,
689-
retained: &BitSet<Local>,
709+
saved_locals: &GeneratorSavedLocals,
690710
) {
691711
let allowed_upvars = tcx.erase_regions(upvars);
692712
let allowed = match witness.kind {
@@ -703,8 +723,8 @@ fn sanitize_witness<'tcx>(
703723
let param_env = tcx.param_env(did);
704724

705725
for (local, decl) in body.local_decls.iter_enumerated() {
706-
// Ignore locals which are internal or not retained between yields.
707-
if !retained.contains(local) || decl.internal {
726+
// Ignore locals which are internal or not saved between yields.
727+
if !saved_locals.contains(local) || decl.internal {
708728
continue;
709729
}
710730
let decl_ty = tcx.normalize_erasing_regions(param_env, decl.ty);
@@ -738,21 +758,21 @@ fn compute_layout<'tcx>(
738758
) {
739759
// Use a liveness analysis to compute locals which are live across a suspension point
740760
let LivenessInfo {
741-
live_locals,
761+
saved_locals,
742762
live_locals_at_suspension_points,
743763
storage_conflicts,
744764
storage_liveness,
745765
} = locals_live_across_suspend_points(tcx, body, source, always_live_locals, movable);
746766

747-
sanitize_witness(tcx, body, source.def_id(), interior, upvars, &live_locals);
767+
sanitize_witness(tcx, body, source.def_id(), interior, upvars, &saved_locals);
748768

749769
// Gather live local types and their indices.
750770
let mut locals = IndexVec::<GeneratorSavedLocal, _>::new();
751771
let mut tys = IndexVec::<GeneratorSavedLocal, _>::new();
752-
for (idx, local) in live_locals.iter().enumerate() {
772+
for (saved_local, local) in saved_locals.iter_enumerated() {
753773
locals.push(local);
754774
tys.push(body.local_decls[local].ty);
755-
debug!("generator saved local {:?} => {:?}", GeneratorSavedLocal::from(idx), local);
775+
debug!("generator saved local {:?} => {:?}", saved_local, local);
756776
}
757777

758778
// Leave empty variants for the UNRESUMED, RETURNED, and POISONED states.

0 commit comments

Comments
 (0)