diff --git a/compiler/rustc_mir_transform/src/generator.rs b/compiler/rustc_mir_transform/src/generator.rs index 433a1c6ad67cc..dc8e500ce3034 100644 --- a/compiler/rustc_mir_transform/src/generator.rs +++ b/compiler/rustc_mir_transform/src/generator.rs @@ -794,13 +794,11 @@ fn compute_layout<'tcx>( // (RETURNED, POISONED) of the function. const RESERVED_VARIANTS: usize = 3; let body_span = body.source_scopes[OUTERMOST_SOURCE_SCOPE].span; - let mut variant_source_info: IndexVec = [ + let mut variant_source_info: IndexVec = std::array::IntoIter::new([ SourceInfo::outermost(body_span.shrink_to_lo()), SourceInfo::outermost(body_span.shrink_to_hi()), SourceInfo::outermost(body_span.shrink_to_hi()), - ] - .iter() - .copied() + ]) .collect(); // Build the generator variant field list. @@ -1258,7 +1256,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform { ty::Generator(_, substs, movability) => { let substs = substs.as_generator(); ( - substs.upvar_tys().collect(), + substs.upvar_tys().collect::>(), substs.witness(), substs.discr_ty(tcx), movability == hir::Movability::Movable, @@ -1291,8 +1289,22 @@ impl<'tcx> MirPass<'tcx> for StateTransform { // When first entering the generator, move the resume argument into its new local. let source_info = SourceInfo::outermost(body.span); - let stmts = &mut body.basic_blocks_mut()[BasicBlock::new(0)].statements; - stmts.insert( + + let mut upvar_collector = ExtractGeneratorUpvarLocals::default(); + for (block, data) in body.basic_blocks().iter_enumerated() { + upvar_collector.visit_basic_block_data(block, data); + } + let upvar_locals = upvar_collector.finish(); + tracing::info!("Upvar locals: {:?}", upvar_locals); + tracing::info!("Expected upvar count: {:?}", upvars.len()); + + let mut replacer = ReplaceLocalWithGeneratorFieldAccess { tcx, upvar_locals }; + for (block, data) in body.basic_blocks_mut().iter_enumerated_mut() { + replacer.visit_basic_block_data(block, data); + } + + let first_block = &mut body.basic_blocks_mut()[BasicBlock::new(0)]; + first_block.statements.insert( 0, Statement { source_info, @@ -1375,6 +1387,92 @@ impl<'tcx> MirPass<'tcx> for StateTransform { } } +/// Finds locals that are assigned from generator upvars. +#[derive(Default)] +struct ExtractGeneratorUpvarLocals<'tcx> { + upvar_locals: FxHashMap>, +} + +impl<'tcx> ExtractGeneratorUpvarLocals<'tcx> { + fn finish(self) -> FxHashMap> { + self.upvar_locals + } +} + +impl<'tcx> Visitor<'tcx> for ExtractGeneratorUpvarLocals<'tcx> { + fn visit_assign(&mut self, place: &Place<'tcx>, rvalue: &Rvalue<'tcx>, location: Location) { + let mut visitor = FindGeneratorFieldAccess { field_index: None }; + visitor.visit_rvalue(rvalue, location); + + if let Some(_) = visitor.field_index { + if !place.projection.is_empty() { + panic!("Non-empty projectsion: {place:#?}"); + } + self.upvar_locals.insert(place.local, rvalue.clone()); + } + } +} + +struct FindGeneratorFieldAccess { + field_index: Option, +} + +impl<'tcx> Visitor<'tcx> for FindGeneratorFieldAccess { + fn visit_projection( + &mut self, + place_ref: PlaceRef<'tcx>, + _context: PlaceContext, + _location: Location, + ) { + tracing::info!("visit_projection, place_ref={place_ref:#?}"); + + if place_ref.local.as_usize() == 1 { + if !place_ref.projection.is_empty() { + if let Some(ProjectionElem::Field(field, _)) = place_ref.projection.get(0) { + assert!(self.field_index.is_none()); + self.field_index = Some(*field); + } + } + } + } +} + +struct ReplaceLocalWithGeneratorFieldAccess<'tcx> { + tcx: TyCtxt<'tcx>, + upvar_locals: FxHashMap>, +} + +impl<'tcx> MutVisitor<'tcx> for ReplaceLocalWithGeneratorFieldAccess<'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_basic_block_data(&mut self, block: BasicBlock, data: &mut BasicBlockData<'tcx>) { + for (statement_index, statement) in data.statements.iter_mut().enumerate() { + if let StatementKind::Assign(box (place, _rvalue)) = &statement.kind { + // Upvar was stored into a local => turn into nop + if self.upvar_locals.contains_key(&place.local) { + *statement = Statement { source_info: statement.source_info, kind: StatementKind::Nop }; + } + } + self.visit_statement(statement, Location { block, statement_index }); + } + } + + fn visit_place(&mut self, source: &mut Place<'tcx>, _context: PlaceContext, _location: Location) { + if let Some(rvalue) = self.upvar_locals.get(&source.local) { + match rvalue { + Rvalue::Use(Operand::Copy(place) | Operand::Move(place)) => { + tracing::info!("Replacing {source:#?} with {place:#?}"); + *source = *place; + } + _ => {} + } + } + } +} + + /// Looks for any assignments between locals (e.g., `_4 = _5`) that will both be converted to fields /// in the generator state machine but whose storage is not marked as conflicting ///