From 833afb02035fde158d91881771f3536fcdae6bfe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Ber=C3=A1nek?= Date: Fri, 10 Dec 2021 16:28:13 +0100 Subject: [PATCH 1/2] Create a visitor that maps upvar fields to corresponding MIR locals --- compiler/rustc_mir_transform/src/generator.rs | 72 +++++++++++++++++-- 1 file changed, 65 insertions(+), 7 deletions(-) diff --git a/compiler/rustc_mir_transform/src/generator.rs b/compiler/rustc_mir_transform/src/generator.rs index 433a1c6ad67cc..cc13a036747bb 100644 --- a/compiler/rustc_mir_transform/src/generator.rs +++ b/compiler/rustc_mir_transform/src/generator.rs @@ -794,13 +794,11 @@ fn compute_layout<'tcx>( // (RETURNED, POISONED) of the function. const RESERVED_VARIANTS: usize = 3; let body_span = body.source_scopes[OUTERMOST_SOURCE_SCOPE].span; - let mut variant_source_info: IndexVec = [ + let mut variant_source_info: IndexVec = std::array::IntoIter::new([ SourceInfo::outermost(body_span.shrink_to_lo()), SourceInfo::outermost(body_span.shrink_to_hi()), SourceInfo::outermost(body_span.shrink_to_hi()), - ] - .iter() - .copied() + ]) .collect(); // Build the generator variant field list. @@ -1258,7 +1256,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform { ty::Generator(_, substs, movability) => { let substs = substs.as_generator(); ( - substs.upvar_tys().collect(), + substs.upvar_tys().collect::>(), substs.witness(), substs.discr_ty(tcx), movability == hir::Movability::Movable, @@ -1291,8 +1289,21 @@ impl<'tcx> MirPass<'tcx> for StateTransform { // When first entering the generator, move the resume argument into its new local. let source_info = SourceInfo::outermost(body.span); - let stmts = &mut body.basic_blocks_mut()[BasicBlock::new(0)].statements; - stmts.insert( + + let mut upvar_collector = ExtractGeneratorUpvarLocals::default(); + for (block, data) in body.basic_blocks().iter_enumerated() { + upvar_collector.visit_basic_block_data(block, data); + } + let upvar_locals = upvar_collector.finish(); + tracing::info!("Upvar locals: {:?}", upvar_locals); + tracing::info!("Upvar count: {:?}", upvars.len()); + if upvar_locals.len() != upvars.len() { + eprintln!("{:#?}", body); + assert_eq!(upvar_locals.len(), upvars.len()); + } + + let first_block = &mut body.basic_blocks_mut()[BasicBlock::new(0)]; + first_block.statements.insert( 0, Statement { source_info, @@ -1375,6 +1386,53 @@ impl<'tcx> MirPass<'tcx> for StateTransform { } } +/// Finds locals that are assigned from generator upvars. +#[derive(Default)] +struct ExtractGeneratorUpvarLocals { + upvar_locals: FxHashMap>, +} + +impl ExtractGeneratorUpvarLocals { + fn finish(self) -> FxHashMap> { + self.upvar_locals + } +} + +impl<'tcx> Visitor<'tcx> for ExtractGeneratorUpvarLocals { + fn visit_assign(&mut self, place: &Place<'tcx>, rvalue: &Rvalue<'tcx>, location: Location) { + let mut visitor = FindGeneratorFieldAccess { field_index: None }; + visitor.visit_rvalue(rvalue, location); + + if let Some(field_index) = visitor.field_index { + self.upvar_locals.entry(field_index).or_insert_with(|| vec![]).push(place.local); + } + } +} + +struct FindGeneratorFieldAccess { + field_index: Option, +} + +impl<'tcx> Visitor<'tcx> for FindGeneratorFieldAccess { + fn visit_projection( + &mut self, + place_ref: PlaceRef<'tcx>, + _context: PlaceContext, + _location: Location, + ) { + tracing::info!("visit_projection, place_ref={:#?}", place_ref); + + if place_ref.local.as_usize() == 1 { + if !place_ref.projection.is_empty() { + if let Some(ProjectionElem::Field(field, _)) = place_ref.projection.get(0) { + assert!(self.field_index.is_none()); + self.field_index = Some(*field); + } + } + } + } +} + /// Looks for any assignments between locals (e.g., `_4 = _5`) that will both be converted to fields /// in the generator state machine but whose storage is not marked as conflicting /// From 0af63de5fbefa5bf15a6f5a1fc38f29f2b1144c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Ber=C3=A1nek?= Date: Tue, 25 Jan 2022 16:41:06 +0100 Subject: [PATCH 2/2] WIP: replace upvar locals with generator state field accesses --- compiler/rustc_mir_transform/src/generator.rs | 64 +++++++++++++++---- 1 file changed, 52 insertions(+), 12 deletions(-) diff --git a/compiler/rustc_mir_transform/src/generator.rs b/compiler/rustc_mir_transform/src/generator.rs index cc13a036747bb..dc8e500ce3034 100644 --- a/compiler/rustc_mir_transform/src/generator.rs +++ b/compiler/rustc_mir_transform/src/generator.rs @@ -1296,10 +1296,11 @@ impl<'tcx> MirPass<'tcx> for StateTransform { } let upvar_locals = upvar_collector.finish(); tracing::info!("Upvar locals: {:?}", upvar_locals); - tracing::info!("Upvar count: {:?}", upvars.len()); - if upvar_locals.len() != upvars.len() { - eprintln!("{:#?}", body); - assert_eq!(upvar_locals.len(), upvars.len()); + tracing::info!("Expected upvar count: {:?}", upvars.len()); + + let mut replacer = ReplaceLocalWithGeneratorFieldAccess { tcx, upvar_locals }; + for (block, data) in body.basic_blocks_mut().iter_enumerated_mut() { + replacer.visit_basic_block_data(block, data); } let first_block = &mut body.basic_blocks_mut()[BasicBlock::new(0)]; @@ -1388,23 +1389,26 @@ impl<'tcx> MirPass<'tcx> for StateTransform { /// Finds locals that are assigned from generator upvars. #[derive(Default)] -struct ExtractGeneratorUpvarLocals { - upvar_locals: FxHashMap>, +struct ExtractGeneratorUpvarLocals<'tcx> { + upvar_locals: FxHashMap>, } -impl ExtractGeneratorUpvarLocals { - fn finish(self) -> FxHashMap> { +impl<'tcx> ExtractGeneratorUpvarLocals<'tcx> { + fn finish(self) -> FxHashMap> { self.upvar_locals } } -impl<'tcx> Visitor<'tcx> for ExtractGeneratorUpvarLocals { +impl<'tcx> Visitor<'tcx> for ExtractGeneratorUpvarLocals<'tcx> { fn visit_assign(&mut self, place: &Place<'tcx>, rvalue: &Rvalue<'tcx>, location: Location) { let mut visitor = FindGeneratorFieldAccess { field_index: None }; visitor.visit_rvalue(rvalue, location); - if let Some(field_index) = visitor.field_index { - self.upvar_locals.entry(field_index).or_insert_with(|| vec![]).push(place.local); + if let Some(_) = visitor.field_index { + if !place.projection.is_empty() { + panic!("Non-empty projectsion: {place:#?}"); + } + self.upvar_locals.insert(place.local, rvalue.clone()); } } } @@ -1420,7 +1424,7 @@ impl<'tcx> Visitor<'tcx> for FindGeneratorFieldAccess { _context: PlaceContext, _location: Location, ) { - tracing::info!("visit_projection, place_ref={:#?}", place_ref); + tracing::info!("visit_projection, place_ref={place_ref:#?}"); if place_ref.local.as_usize() == 1 { if !place_ref.projection.is_empty() { @@ -1433,6 +1437,42 @@ impl<'tcx> Visitor<'tcx> for FindGeneratorFieldAccess { } } +struct ReplaceLocalWithGeneratorFieldAccess<'tcx> { + tcx: TyCtxt<'tcx>, + upvar_locals: FxHashMap>, +} + +impl<'tcx> MutVisitor<'tcx> for ReplaceLocalWithGeneratorFieldAccess<'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_basic_block_data(&mut self, block: BasicBlock, data: &mut BasicBlockData<'tcx>) { + for (statement_index, statement) in data.statements.iter_mut().enumerate() { + if let StatementKind::Assign(box (place, _rvalue)) = &statement.kind { + // Upvar was stored into a local => turn into nop + if self.upvar_locals.contains_key(&place.local) { + *statement = Statement { source_info: statement.source_info, kind: StatementKind::Nop }; + } + } + self.visit_statement(statement, Location { block, statement_index }); + } + } + + fn visit_place(&mut self, source: &mut Place<'tcx>, _context: PlaceContext, _location: Location) { + if let Some(rvalue) = self.upvar_locals.get(&source.local) { + match rvalue { + Rvalue::Use(Operand::Copy(place) | Operand::Move(place)) => { + tracing::info!("Replacing {source:#?} with {place:#?}"); + *source = *place; + } + _ => {} + } + } + } +} + + /// Looks for any assignments between locals (e.g., `_4 = _5`) that will both be converted to fields /// in the generator state machine but whose storage is not marked as conflicting ///