diff --git a/compiler/stable_mir/src/mir.rs b/compiler/stable_mir/src/mir.rs index 82555461d644e..413b5152bb3b3 100644 --- a/compiler/stable_mir/src/mir.rs +++ b/compiler/stable_mir/src/mir.rs @@ -5,4 +5,4 @@ pub mod pretty; pub mod visit; pub use body::*; -pub use visit::MirVisitor; +pub use visit::{MirVisitor, MutMirVisitor}; diff --git a/compiler/stable_mir/src/mir/body.rs b/compiler/stable_mir/src/mir/body.rs index f8b46f50a5298..b1bf7bce828c3 100644 --- a/compiler/stable_mir/src/mir/body.rs +++ b/compiler/stable_mir/src/mir/body.rs @@ -77,6 +77,22 @@ impl Body { &self.locals[self.arg_count + 1..] } + /// Returns a mutable reference to the local that holds this function's return value. + pub(crate) fn ret_local_mut(&mut self) -> &mut LocalDecl { + &mut self.locals[RETURN_LOCAL] + } + + /// Returns a mutable slice of locals corresponding to this function's arguments. + pub(crate) fn arg_locals_mut(&mut self) -> &mut [LocalDecl] { + &mut self.locals[1..][..self.arg_count] + } + + /// Returns a mutable slice of inner locals for this function. + /// Inner locals are those that are neither the return local nor the argument locals. + pub(crate) fn inner_locals_mut(&mut self) -> &mut [LocalDecl] { + &mut self.locals[self.arg_count + 1..] + } + /// Convenience function to get all the locals in this function. /// /// Locals are typically accessed via the more specific methods `ret_local`, diff --git a/compiler/stable_mir/src/mir/visit.rs b/compiler/stable_mir/src/mir/visit.rs index d985a98fcbf01..09447e70dfb18 100644 --- a/compiler/stable_mir/src/mir/visit.rs +++ b/compiler/stable_mir/src/mir/visit.rs @@ -39,406 +39,486 @@ use crate::mir::*; use crate::ty::{GenericArgs, MirConst, Region, Ty, TyConst}; use crate::{Error, Opaque, Span}; -pub trait MirVisitor { - fn visit_body(&mut self, body: &Body) { - self.super_body(body) - } - - fn visit_basic_block(&mut self, bb: &BasicBlock) { - self.super_basic_block(bb) - } - - fn visit_ret_decl(&mut self, local: Local, decl: &LocalDecl) { - self.super_ret_decl(local, decl) - } - - fn visit_arg_decl(&mut self, local: Local, decl: &LocalDecl) { - self.super_arg_decl(local, decl) - } - - fn visit_local_decl(&mut self, local: Local, decl: &LocalDecl) { - self.super_local_decl(local, decl) - } - - fn visit_statement(&mut self, stmt: &Statement, location: Location) { - self.super_statement(stmt, location) - } - - fn visit_terminator(&mut self, term: &Terminator, location: Location) { - self.super_terminator(term, location) - } - - fn visit_span(&mut self, span: &Span) { - self.super_span(span) - } +macro_rules! make_mir_visitor { + ($visitor_trait_name:ident, $($mutability:ident)?) => { + pub trait $visitor_trait_name { + fn visit_body(&mut self, body: &$($mutability)? Body) { + self.super_body(body) + } - fn visit_place(&mut self, place: &Place, ptx: PlaceContext, location: Location) { - self.super_place(place, ptx, location) - } + fn visit_basic_block(&mut self, bb: &$($mutability)? BasicBlock) { + self.super_basic_block(bb) + } - fn visit_projection_elem( - &mut self, - place_ref: PlaceRef<'_>, - elem: &ProjectionElem, - ptx: PlaceContext, - location: Location, - ) { - let _ = place_ref; - self.super_projection_elem(elem, ptx, location); - } + fn visit_ret_decl(&mut self, local: Local, decl: &$($mutability)? LocalDecl) { + self.super_ret_decl(local, decl) + } - fn visit_local(&mut self, local: &Local, ptx: PlaceContext, location: Location) { - let _ = (local, ptx, location); - } + fn visit_arg_decl(&mut self, local: Local, decl: &$($mutability)? LocalDecl) { + self.super_arg_decl(local, decl) + } - fn visit_rvalue(&mut self, rvalue: &Rvalue, location: Location) { - self.super_rvalue(rvalue, location) - } + fn visit_local_decl(&mut self, local: Local, decl: &$($mutability)? LocalDecl) { + self.super_local_decl(local, decl) + } - fn visit_operand(&mut self, operand: &Operand, location: Location) { - self.super_operand(operand, location) - } + fn visit_statement(&mut self, stmt: &$($mutability)? Statement, location: Location) { + self.super_statement(stmt, location) + } - fn visit_user_type_projection(&mut self, projection: &UserTypeProjection) { - self.super_user_type_projection(projection) - } + fn visit_terminator(&mut self, term: &$($mutability)? Terminator, location: Location) { + self.super_terminator(term, location) + } - fn visit_ty(&mut self, ty: &Ty, location: Location) { - let _ = location; - self.super_ty(ty) - } + fn visit_span(&mut self, span: &$($mutability)? Span) { + self.super_span(span) + } - fn visit_const_operand(&mut self, constant: &ConstOperand, location: Location) { - self.super_const_operand(constant, location) - } + fn visit_place(&mut self, place: &$($mutability)? Place, ptx: PlaceContext, location: Location) { + self.super_place(place, ptx, location) + } - fn visit_mir_const(&mut self, constant: &MirConst, location: Location) { - self.super_mir_const(constant, location) - } + visit_place_fns!($($mutability)?); - fn visit_ty_const(&mut self, constant: &TyConst, location: Location) { - let _ = location; - self.super_ty_const(constant) - } + fn visit_local(&mut self, local: &$($mutability)? Local, ptx: PlaceContext, location: Location) { + let _ = (local, ptx, location); + } - fn visit_region(&mut self, region: &Region, location: Location) { - let _ = location; - self.super_region(region) - } + fn visit_rvalue(&mut self, rvalue: &$($mutability)? Rvalue, location: Location) { + self.super_rvalue(rvalue, location) + } - fn visit_args(&mut self, args: &GenericArgs, location: Location) { - let _ = location; - self.super_args(args) - } + fn visit_operand(&mut self, operand: &$($mutability)? Operand, location: Location) { + self.super_operand(operand, location) + } - fn visit_assert_msg(&mut self, msg: &AssertMessage, location: Location) { - self.super_assert_msg(msg, location) - } + fn visit_user_type_projection(&mut self, projection: &$($mutability)? UserTypeProjection) { + self.super_user_type_projection(projection) + } - fn visit_var_debug_info(&mut self, var_debug_info: &VarDebugInfo) { - self.super_var_debug_info(var_debug_info); - } + fn visit_ty(&mut self, ty: &$($mutability)? Ty, location: Location) { + let _ = location; + self.super_ty(ty) + } - fn super_body(&mut self, body: &Body) { - let Body { blocks, locals: _, arg_count, var_debug_info, spread_arg: _, span } = body; + fn visit_const_operand(&mut self, constant: &$($mutability)? ConstOperand, location: Location) { + self.super_const_operand(constant, location) + } - for bb in blocks { - self.visit_basic_block(bb); - } + fn visit_mir_const(&mut self, constant: &$($mutability)? MirConst, location: Location) { + self.super_mir_const(constant, location) + } - self.visit_ret_decl(RETURN_LOCAL, body.ret_local()); + fn visit_ty_const(&mut self, constant: &$($mutability)? TyConst, location: Location) { + let _ = location; + self.super_ty_const(constant) + } - for (idx, arg) in body.arg_locals().iter().enumerate() { - self.visit_arg_decl(idx + 1, arg) - } + fn visit_region(&mut self, region: &$($mutability)? Region, location: Location) { + let _ = location; + self.super_region(region) + } - let local_start = arg_count + 1; - for (idx, arg) in body.inner_locals().iter().enumerate() { - self.visit_local_decl(idx + local_start, arg) - } + fn visit_args(&mut self, args: &$($mutability)? GenericArgs, location: Location) { + let _ = location; + self.super_args(args) + } - for info in var_debug_info.iter() { - self.visit_var_debug_info(info); - } + fn visit_assert_msg(&mut self, msg: &$($mutability)? AssertMessage, location: Location) { + self.super_assert_msg(msg, location) + } - self.visit_span(span) - } + fn visit_var_debug_info(&mut self, var_debug_info: &$($mutability)? VarDebugInfo) { + self.super_var_debug_info(var_debug_info); + } - fn super_basic_block(&mut self, bb: &BasicBlock) { - let BasicBlock { statements, terminator } = bb; - for stmt in statements { - self.visit_statement(stmt, Location(stmt.span)); - } - self.visit_terminator(terminator, Location(terminator.span)); - } + fn super_body(&mut self, body: &$($mutability)? Body) { + super_body!(self, body, $($mutability)?); + } - fn super_local_decl(&mut self, local: Local, decl: &LocalDecl) { - let _ = local; - let LocalDecl { ty, span, .. } = decl; - self.visit_ty(ty, Location(*span)); - } + fn super_basic_block(&mut self, bb: &$($mutability)? BasicBlock) { + let BasicBlock { statements, terminator } = bb; + for stmt in statements { + self.visit_statement(stmt, Location(stmt.span)); + } + self.visit_terminator(terminator, Location(terminator.span)); + } - fn super_ret_decl(&mut self, local: Local, decl: &LocalDecl) { - self.super_local_decl(local, decl) - } + fn super_local_decl(&mut self, local: Local, decl: &$($mutability)? LocalDecl) { + let _ = local; + let LocalDecl { ty, span, .. } = decl; + self.visit_ty(ty, Location(*span)); + } - fn super_arg_decl(&mut self, local: Local, decl: &LocalDecl) { - self.super_local_decl(local, decl) - } + fn super_ret_decl(&mut self, local: Local, decl: &$($mutability)? LocalDecl) { + self.super_local_decl(local, decl) + } - fn super_statement(&mut self, stmt: &Statement, location: Location) { - let Statement { kind, span } = stmt; - self.visit_span(span); - match kind { - StatementKind::Assign(place, rvalue) => { - self.visit_place(place, PlaceContext::MUTATING, location); - self.visit_rvalue(rvalue, location); - } - StatementKind::FakeRead(_, place) | StatementKind::PlaceMention(place) => { - self.visit_place(place, PlaceContext::NON_MUTATING, location); - } - StatementKind::SetDiscriminant { place, .. } - | StatementKind::Deinit(place) - | StatementKind::Retag(_, place) => { - self.visit_place(place, PlaceContext::MUTATING, location); - } - StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => { - self.visit_local(local, PlaceContext::NON_USE, location); - } - StatementKind::AscribeUserType { place, projections, variance: _ } => { - self.visit_place(place, PlaceContext::NON_USE, location); - self.visit_user_type_projection(projections); - } - StatementKind::Coverage(coverage) => visit_opaque(coverage), - StatementKind::Intrinsic(intrisic) => match intrisic { - NonDivergingIntrinsic::Assume(operand) => { - self.visit_operand(operand, location); - } - NonDivergingIntrinsic::CopyNonOverlapping(CopyNonOverlapping { - src, - dst, - count, - }) => { - self.visit_operand(src, location); - self.visit_operand(dst, location); - self.visit_operand(count, location); - } - }, - StatementKind::ConstEvalCounter | StatementKind::Nop => {} - } - } + fn super_arg_decl(&mut self, local: Local, decl: &$($mutability)? LocalDecl) { + self.super_local_decl(local, decl) + } - fn super_terminator(&mut self, term: &Terminator, location: Location) { - let Terminator { kind, span } = term; - self.visit_span(span); - match kind { - TerminatorKind::Goto { .. } - | TerminatorKind::Resume - | TerminatorKind::Abort - | TerminatorKind::Unreachable => {} - TerminatorKind::Assert { cond, expected: _, msg, target: _, unwind: _ } => { - self.visit_operand(cond, location); - self.visit_assert_msg(msg, location); - } - TerminatorKind::Drop { place, target: _, unwind: _ } => { - self.visit_place(place, PlaceContext::MUTATING, location); - } - TerminatorKind::Call { func, args, destination, target: _, unwind: _ } => { - self.visit_operand(func, location); - for arg in args { - self.visit_operand(arg, location); + fn super_statement(&mut self, stmt: &$($mutability)? Statement, location: Location) { + let Statement { kind, span } = stmt; + self.visit_span(span); + match kind { + StatementKind::Assign(place, rvalue) => { + self.visit_place(place, PlaceContext::MUTATING, location); + self.visit_rvalue(rvalue, location); + } + StatementKind::FakeRead(_, place) | StatementKind::PlaceMention(place) => { + self.visit_place(place, PlaceContext::NON_MUTATING, location); + } + StatementKind::SetDiscriminant { place, .. } + | StatementKind::Deinit(place) + | StatementKind::Retag(_, place) => { + self.visit_place(place, PlaceContext::MUTATING, location); + } + StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => { + self.visit_local(local, PlaceContext::NON_USE, location); + } + StatementKind::AscribeUserType { place, projections, variance: _ } => { + self.visit_place(place, PlaceContext::NON_USE, location); + self.visit_user_type_projection(projections); + } + StatementKind::Coverage(coverage) => visit_opaque(coverage), + StatementKind::Intrinsic(intrisic) => match intrisic { + NonDivergingIntrinsic::Assume(operand) => { + self.visit_operand(operand, location); + } + NonDivergingIntrinsic::CopyNonOverlapping(CopyNonOverlapping { + src, + dst, + count, + }) => { + self.visit_operand(src, location); + self.visit_operand(dst, location); + self.visit_operand(count, location); + } + }, + StatementKind::ConstEvalCounter | StatementKind::Nop => {} } - self.visit_place(destination, PlaceContext::MUTATING, location); } - TerminatorKind::InlineAsm { operands, .. } => { - for op in operands { - let InlineAsmOperand { in_value, out_place, raw_rpr: _ } = op; - if let Some(input) = in_value { - self.visit_operand(input, location); + + fn super_terminator(&mut self, term: &$($mutability)? Terminator, location: Location) { + let Terminator { kind, span } = term; + self.visit_span(span); + match kind { + TerminatorKind::Goto { .. } + | TerminatorKind::Resume + | TerminatorKind::Abort + | TerminatorKind::Unreachable => {} + TerminatorKind::Assert { cond, expected: _, msg, target: _, unwind: _ } => { + self.visit_operand(cond, location); + self.visit_assert_msg(msg, location); + } + TerminatorKind::Drop { place, target: _, unwind: _ } => { + self.visit_place(place, PlaceContext::MUTATING, location); + } + TerminatorKind::Call { func, args, destination, target: _, unwind: _ } => { + self.visit_operand(func, location); + for arg in args { + self.visit_operand(arg, location); + } + self.visit_place(destination, PlaceContext::MUTATING, location); + } + TerminatorKind::InlineAsm { operands, .. } => { + for op in operands { + let InlineAsmOperand { in_value, out_place, raw_rpr: _ } = op; + if let Some(input) = in_value { + self.visit_operand(input, location); + } + if let Some(output) = out_place { + self.visit_place(output, PlaceContext::MUTATING, location); + } + } } - if let Some(output) = out_place { - self.visit_place(output, PlaceContext::MUTATING, location); + TerminatorKind::Return => { + let $($mutability)? local = RETURN_LOCAL; + self.visit_local(&$($mutability)? local, PlaceContext::NON_MUTATING, location); + } + TerminatorKind::SwitchInt { discr, targets: _ } => { + self.visit_operand(discr, location); } } } - TerminatorKind::Return => { - let local = RETURN_LOCAL; - self.visit_local(&local, PlaceContext::NON_MUTATING, location); - } - TerminatorKind::SwitchInt { discr, targets: _ } => { - self.visit_operand(discr, location); - } - } - } - fn super_span(&mut self, span: &Span) { - let _ = span; - } - - fn super_place(&mut self, place: &Place, ptx: PlaceContext, location: Location) { - let _ = location; - let _ = ptx; - self.visit_local(&place.local, ptx, location); + fn super_span(&mut self, span: &$($mutability)? Span) { + let _ = span; + } - for (idx, elem) in place.projection.iter().enumerate() { - let place_ref = PlaceRef { local: place.local, projection: &place.projection[..idx] }; - self.visit_projection_elem(place_ref, elem, ptx, location); - } - } + fn super_rvalue(&mut self, rvalue: &$($mutability)? Rvalue, location: Location) { + match rvalue { + Rvalue::AddressOf(mutability, place) => { + let pcx = PlaceContext { is_mut: *mutability == RawPtrKind::Mut }; + self.visit_place(place, pcx, location); + } + Rvalue::Aggregate(_, operands) => { + for op in operands { + self.visit_operand(op, location); + } + } + Rvalue::BinaryOp(_, lhs, rhs) | Rvalue::CheckedBinaryOp(_, lhs, rhs) => { + self.visit_operand(lhs, location); + self.visit_operand(rhs, location); + } + Rvalue::Cast(_, op, ty) => { + self.visit_operand(op, location); + self.visit_ty(ty, location); + } + Rvalue::CopyForDeref(place) | Rvalue::Discriminant(place) | Rvalue::Len(place) => { + self.visit_place(place, PlaceContext::NON_MUTATING, location); + } + Rvalue::Ref(region, kind, place) => { + self.visit_region(region, location); + let pcx = PlaceContext { is_mut: matches!(kind, BorrowKind::Mut { .. }) }; + self.visit_place(place, pcx, location); + } + Rvalue::Repeat(op, constant) => { + self.visit_operand(op, location); + self.visit_ty_const(constant, location); + } + Rvalue::ShallowInitBox(op, ty) => { + self.visit_ty(ty, location); + self.visit_operand(op, location) + } + Rvalue::ThreadLocalRef(_) => {} + Rvalue::NullaryOp(_, ty) => { + self.visit_ty(ty, location); + } + Rvalue::UnaryOp(_, op) | Rvalue::Use(op) => { + self.visit_operand(op, location); + } + } + } - fn super_projection_elem( - &mut self, - elem: &ProjectionElem, - ptx: PlaceContext, - location: Location, - ) { - match elem { - ProjectionElem::Downcast(_idx) => {} - ProjectionElem::ConstantIndex { offset: _, min_length: _, from_end: _ } - | ProjectionElem::Deref - | ProjectionElem::Subslice { from: _, to: _, from_end: _ } => {} - ProjectionElem::Field(_idx, ty) => self.visit_ty(ty, location), - ProjectionElem::Index(local) => self.visit_local(local, ptx, location), - ProjectionElem::OpaqueCast(ty) | ProjectionElem::Subtype(ty) => { - self.visit_ty(ty, location) + fn super_operand(&mut self, operand: &$($mutability)? Operand, location: Location) { + match operand { + Operand::Copy(place) | Operand::Move(place) => { + self.visit_place(place, PlaceContext::NON_MUTATING, location) + } + Operand::Constant(constant) => { + self.visit_const_operand(constant, location); + } + } } - } - } - fn super_rvalue(&mut self, rvalue: &Rvalue, location: Location) { - match rvalue { - Rvalue::AddressOf(mutability, place) => { - let pcx = PlaceContext { is_mut: *mutability == RawPtrKind::Mut }; - self.visit_place(place, pcx, location); + fn super_user_type_projection(&mut self, projection: &$($mutability)? UserTypeProjection) { + // This is a no-op on mir::Visitor. + let _ = projection; } - Rvalue::Aggregate(_, operands) => { - for op in operands { - self.visit_operand(op, location); - } + + fn super_ty(&mut self, ty: &$($mutability)? Ty) { + let _ = ty; } - Rvalue::BinaryOp(_, lhs, rhs) | Rvalue::CheckedBinaryOp(_, lhs, rhs) => { - self.visit_operand(lhs, location); - self.visit_operand(rhs, location); + + fn super_const_operand(&mut self, constant: &$($mutability)? ConstOperand, location: Location) { + let ConstOperand { span, user_ty: _, const_ } = constant; + self.visit_span(span); + self.visit_mir_const(const_, location); } - Rvalue::Cast(_, op, ty) => { - self.visit_operand(op, location); + + fn super_mir_const(&mut self, constant: &$($mutability)? MirConst, location: Location) { + let MirConst { kind: _, ty, id: _ } = constant; self.visit_ty(ty, location); } - Rvalue::CopyForDeref(place) | Rvalue::Discriminant(place) | Rvalue::Len(place) => { - self.visit_place(place, PlaceContext::NON_MUTATING, location); - } - Rvalue::Ref(region, kind, place) => { - self.visit_region(region, location); - let pcx = PlaceContext { is_mut: matches!(kind, BorrowKind::Mut { .. }) }; - self.visit_place(place, pcx, location); + + fn super_ty_const(&mut self, constant: &$($mutability)? TyConst) { + let _ = constant; } - Rvalue::Repeat(op, constant) => { - self.visit_operand(op, location); - self.visit_ty_const(constant, location); + + fn super_region(&mut self, region: &$($mutability)? Region) { + let _ = region; } - Rvalue::ShallowInitBox(op, ty) => { - self.visit_ty(ty, location); - self.visit_operand(op, location) + + fn super_args(&mut self, args: &$($mutability)? GenericArgs) { + let _ = args; } - Rvalue::ThreadLocalRef(_) => {} - Rvalue::NullaryOp(_, ty) => { - self.visit_ty(ty, location); + + fn super_var_debug_info(&mut self, var_debug_info: &$($mutability)? VarDebugInfo) { + let VarDebugInfo { source_info, composite, value, name: _, argument_index: _ } = + var_debug_info; + self.visit_span(&$($mutability)? source_info.span); + let location = Location(source_info.span); + if let Some(composite) = composite { + self.visit_ty(&$($mutability)? composite.ty, location); + } + match value { + VarDebugInfoContents::Place(place) => { + self.visit_place(place, PlaceContext::NON_USE, location); + } + VarDebugInfoContents::Const(constant) => { + self.visit_mir_const(&$($mutability)? constant.const_, location); + } + } } - Rvalue::UnaryOp(_, op) | Rvalue::Use(op) => { - self.visit_operand(op, location); + + fn super_assert_msg(&mut self, msg: &$($mutability)? AssertMessage, location: Location) { + match msg { + AssertMessage::BoundsCheck { len, index } => { + self.visit_operand(len, location); + self.visit_operand(index, location); + } + AssertMessage::Overflow(_, left, right) => { + self.visit_operand(left, location); + self.visit_operand(right, location); + } + AssertMessage::OverflowNeg(op) + | AssertMessage::DivisionByZero(op) + | AssertMessage::RemainderByZero(op) => { + self.visit_operand(op, location); + } + AssertMessage::ResumedAfterReturn(_) + | AssertMessage::ResumedAfterPanic(_) + | AssertMessage::NullPointerDereference => { + //nothing to visit + } + AssertMessage::MisalignedPointerDereference { required, found } => { + self.visit_operand(required, location); + self.visit_operand(found, location); + } + } } } - } + }; +} - fn super_operand(&mut self, operand: &Operand, location: Location) { - match operand { - Operand::Copy(place) | Operand::Move(place) => { - self.visit_place(place, PlaceContext::NON_MUTATING, location) - } - Operand::Constant(constant) => { - self.visit_const_operand(constant, location); - } +macro_rules! super_body { + ($self:ident, $body:ident, mut) => { + for bb in $body.blocks.iter_mut() { + $self.visit_basic_block(bb); } - } - fn super_user_type_projection(&mut self, projection: &UserTypeProjection) { - // This is a no-op on mir::Visitor. - let _ = projection; - } + $self.visit_ret_decl(RETURN_LOCAL, $body.ret_local_mut()); - fn super_ty(&mut self, ty: &Ty) { - let _ = ty; - } + for (idx, arg) in $body.arg_locals_mut().iter_mut().enumerate() { + $self.visit_arg_decl(idx + 1, arg) + } - fn super_const_operand(&mut self, constant: &ConstOperand, location: Location) { - let ConstOperand { span, user_ty: _, const_ } = constant; - self.visit_span(span); - self.visit_mir_const(const_, location); - } + let local_start = $body.arg_count + 1; + for (idx, arg) in $body.inner_locals_mut().iter_mut().enumerate() { + $self.visit_local_decl(idx + local_start, arg) + } - fn super_mir_const(&mut self, constant: &MirConst, location: Location) { - let MirConst { kind: _, ty, id: _ } = constant; - self.visit_ty(ty, location); - } + for info in $body.var_debug_info.iter_mut() { + $self.visit_var_debug_info(info); + } - fn super_ty_const(&mut self, constant: &TyConst) { - let _ = constant; - } + $self.visit_span(&mut $body.span) + }; - fn super_region(&mut self, region: &Region) { - let _ = region; - } + ($self:ident, $body:ident, ) => { + let Body { blocks, locals: _, arg_count, var_debug_info, spread_arg: _, span } = $body; - fn super_args(&mut self, args: &GenericArgs) { - let _ = args; - } + for bb in blocks { + $self.visit_basic_block(bb); + } - fn super_var_debug_info(&mut self, var_debug_info: &VarDebugInfo) { - let VarDebugInfo { source_info, composite, value, name: _, argument_index: _ } = - var_debug_info; - self.visit_span(&source_info.span); - let location = Location(source_info.span); - if let Some(composite) = composite { - self.visit_ty(&composite.ty, location); + $self.visit_ret_decl(RETURN_LOCAL, $body.ret_local()); + + for (idx, arg) in $body.arg_locals().iter().enumerate() { + $self.visit_arg_decl(idx + 1, arg) } - match value { - VarDebugInfoContents::Place(place) => { - self.visit_place(place, PlaceContext::NON_USE, location); - } - VarDebugInfoContents::Const(constant) => { - self.visit_mir_const(&constant.const_, location); - } + + let local_start = arg_count + 1; + for (idx, arg) in $body.inner_locals().iter().enumerate() { + $self.visit_local_decl(idx + local_start, arg) } - } - fn super_assert_msg(&mut self, msg: &AssertMessage, location: Location) { - match msg { - AssertMessage::BoundsCheck { len, index } => { - self.visit_operand(len, location); - self.visit_operand(index, location); - } - AssertMessage::Overflow(_, left, right) => { - self.visit_operand(left, location); - self.visit_operand(right, location); + for info in var_debug_info.iter() { + $self.visit_var_debug_info(info); + } + + $self.visit_span(span) + }; +} + +macro_rules! visit_place_fns { + (mut) => { + fn super_place(&mut self, place: &mut Place, ptx: PlaceContext, location: Location) { + self.visit_local(&mut place.local, ptx, location); + + for elem in place.projection.iter_mut() { + self.visit_projection_elem(elem, ptx, location); } - AssertMessage::OverflowNeg(op) - | AssertMessage::DivisionByZero(op) - | AssertMessage::RemainderByZero(op) => { - self.visit_operand(op, location); + } + + // We don't have to replicate the `process_projection()` like we did in + // `rustc_middle::mir::visit.rs` here because the `projection` field in `Place` + // of Stable-MIR is not an immutable borrow, unlike in `Place` of MIR. + fn visit_projection_elem( + &mut self, + elem: &mut ProjectionElem, + ptx: PlaceContext, + location: Location, + ) { + self.super_projection_elem(elem, ptx, location) + } + + fn super_projection_elem( + &mut self, + elem: &mut ProjectionElem, + ptx: PlaceContext, + location: Location, + ) { + match elem { + ProjectionElem::Deref => {} + ProjectionElem::Field(_idx, ty) => self.visit_ty(ty, location), + ProjectionElem::Index(local) => self.visit_local(local, ptx, location), + ProjectionElem::ConstantIndex { offset: _, min_length: _, from_end: _ } => {} + ProjectionElem::Subslice { from: _, to: _, from_end: _ } => {} + ProjectionElem::Downcast(_idx) => {} + ProjectionElem::OpaqueCast(ty) => self.visit_ty(ty, location), + ProjectionElem::Subtype(ty) => self.visit_ty(ty, location), } - AssertMessage::ResumedAfterReturn(_) - | AssertMessage::ResumedAfterPanic(_) - | AssertMessage::NullPointerDereference => { - //nothing to visit + } + }; + + () => { + fn super_place(&mut self, place: &Place, ptx: PlaceContext, location: Location) { + self.visit_local(&place.local, ptx, location); + + for (idx, elem) in place.projection.iter().enumerate() { + let place_ref = + PlaceRef { local: place.local, projection: &place.projection[..idx] }; + self.visit_projection_elem(place_ref, elem, ptx, location); } - AssertMessage::MisalignedPointerDereference { required, found } => { - self.visit_operand(required, location); - self.visit_operand(found, location); + } + + fn visit_projection_elem<'a>( + &mut self, + place_ref: PlaceRef<'a>, + elem: &ProjectionElem, + ptx: PlaceContext, + location: Location, + ) { + let _ = place_ref; + self.super_projection_elem(elem, ptx, location); + } + + fn super_projection_elem( + &mut self, + elem: &ProjectionElem, + ptx: PlaceContext, + location: Location, + ) { + match elem { + ProjectionElem::Deref => {} + ProjectionElem::Field(_idx, ty) => self.visit_ty(ty, location), + ProjectionElem::Index(local) => self.visit_local(local, ptx, location), + ProjectionElem::ConstantIndex { offset: _, min_length: _, from_end: _ } => {} + ProjectionElem::Subslice { from: _, to: _, from_end: _ } => {} + ProjectionElem::Downcast(_idx) => {} + ProjectionElem::OpaqueCast(ty) => self.visit_ty(ty, location), + ProjectionElem::Subtype(ty) => self.visit_ty(ty, location), } } - } + }; } +make_mir_visitor!(MirVisitor,); +make_mir_visitor!(MutMirVisitor, mut); + /// This function is a no-op that gets used to ensure this visitor is kept up-to-date. /// /// The idea is that whenever we replace an Opaque type by a real type, the compiler will fail diff --git a/tests/ui-fulldeps/stable-mir/smir_visitor.rs b/tests/ui-fulldeps/stable-mir/smir_visitor.rs index cffb41742b4e4..0a579a07cef16 100644 --- a/tests/ui-fulldeps/stable-mir/smir_visitor.rs +++ b/tests/ui-fulldeps/stable-mir/smir_visitor.rs @@ -18,6 +18,7 @@ extern crate stable_mir; use rustc_smir::rustc_internal; use stable_mir::mir::MirVisitor; +use stable_mir::mir::MutMirVisitor; use stable_mir::*; use std::collections::HashSet; use std::io::Write; @@ -99,6 +100,83 @@ impl<'a> mir::MirVisitor for TestVisitor<'a> { } } +fn test_mut_visitor() -> ControlFlow<()> { + let main_fn = stable_mir::entry_fn(); + let mut main_body = main_fn.unwrap().expect_body(); + let locals = main_body.locals().to_vec(); + let mut main_visitor = TestMutVisitor::collect(locals); + main_visitor.visit_body(&mut main_body); + assert!(main_visitor.ret_val.is_some()); + assert!(main_visitor.args.is_empty()); + assert!(main_visitor.tys.contains(&main_visitor.ret_val.unwrap().ty)); + assert!(!main_visitor.calls.is_empty()); + + let exit_fn = main_visitor.calls.last().unwrap(); + assert!(exit_fn.mangled_name().contains("exit_fn"), "Unexpected last function: {exit_fn:?}"); + + let mut exit_body = exit_fn.body().unwrap(); + let locals = exit_body.locals().to_vec(); + let mut exit_visitor = TestMutVisitor::collect(locals); + exit_visitor.visit_body(&mut exit_body); + assert!(exit_visitor.ret_val.is_some()); + assert_eq!(exit_visitor.args.len(), 1); + assert!(exit_visitor.tys.contains(&exit_visitor.ret_val.unwrap().ty)); + assert!(exit_visitor.tys.contains(&exit_visitor.args[0].ty)); + ControlFlow::Continue(()) +} + +struct TestMutVisitor { + locals: Vec, + pub tys: HashSet, + pub ret_val: Option, + pub args: Vec, + pub calls: Vec, +} + +impl TestMutVisitor { + fn collect(locals: Vec) -> TestMutVisitor { + let visitor = TestMutVisitor { + locals: locals, + tys: Default::default(), + ret_val: None, + args: vec![], + calls: vec![], + }; + visitor + } +} + +impl mir::MutMirVisitor for TestMutVisitor { + fn visit_ty(&mut self, ty: &mut ty::Ty, _location: mir::visit::Location) { + self.tys.insert(*ty); + self.super_ty(ty) + } + + fn visit_ret_decl(&mut self, local: mir::Local, decl: &mut mir::LocalDecl) { + assert!(local == mir::RETURN_LOCAL); + assert!(self.ret_val.is_none()); + self.ret_val = Some(decl.clone()); + self.super_ret_decl(local, decl); + } + + fn visit_arg_decl(&mut self, local: mir::Local, decl: &mut mir::LocalDecl) { + self.args.push(decl.clone()); + assert_eq!(local, self.args.len()); + self.super_arg_decl(local, decl); + } + + fn visit_terminator(&mut self, term: &mut mir::Terminator, location: mir::visit::Location) { + if let mir::TerminatorKind::Call { func, .. } = &mut term.kind { + let ty::TyKind::RigidTy(ty) = func.ty(&self.locals).unwrap().kind() else { + unreachable!() + }; + let ty::RigidTy::FnDef(def, args) = ty else { unreachable!() }; + self.calls.push(mir::mono::Instance::resolve(def, &args).unwrap()); + } + self.super_terminator(term, location); + } +} + /// This test will generate and analyze a dummy crate using the stable mir. /// For that, it will first write the dummy crate into a file. /// Then it will create a `StableMir` using custom arguments and then @@ -113,7 +191,8 @@ fn main() { CRATE_NAME.to_string(), path.to_string(), ]; - run!(args, test_visitor).unwrap(); + run!(args.clone(), test_visitor).unwrap(); + run!(args, test_mut_visitor).unwrap(); } fn generate_input(path: &str) -> std::io::Result<()> {