Skip to content

Commit 10e71df

Browse files
committed
Also validate types before inlining.
1 parent e4fd2bd commit 10e71df

File tree

2 files changed

+114
-16
lines changed

2 files changed

+114
-16
lines changed

compiler/rustc_const_eval/src/transform/validate.rs

+14-16
Original file line numberDiff line numberDiff line change
@@ -89,22 +89,20 @@ pub fn equal_up_to_regions<'tcx>(
8989

9090
// Normalize lifetimes away on both sides, then compare.
9191
let normalize = |ty: Ty<'tcx>| {
92-
tcx.normalize_erasing_regions(
93-
param_env,
94-
ty.fold_with(&mut BottomUpFolder {
95-
tcx,
96-
// FIXME: We erase all late-bound lifetimes, but this is not fully correct.
97-
// If you have a type like `<for<'a> fn(&'a u32) as SomeTrait>::Assoc`,
98-
// this is not necessarily equivalent to `<fn(&'static u32) as SomeTrait>::Assoc`,
99-
// since one may have an `impl SomeTrait for fn(&32)` and
100-
// `impl SomeTrait for fn(&'static u32)` at the same time which
101-
// specify distinct values for Assoc. (See also #56105)
102-
lt_op: |_| tcx.lifetimes.re_erased,
103-
// Leave consts and types unchanged.
104-
ct_op: |ct| ct,
105-
ty_op: |ty| ty,
106-
}),
107-
)
92+
let ty = ty.fold_with(&mut BottomUpFolder {
93+
tcx,
94+
// FIXME: We erase all late-bound lifetimes, but this is not fully correct.
95+
// If you have a type like `<for<'a> fn(&'a u32) as SomeTrait>::Assoc`,
96+
// this is not necessarily equivalent to `<fn(&'static u32) as SomeTrait>::Assoc`,
97+
// since one may have an `impl SomeTrait for fn(&32)` and
98+
// `impl SomeTrait for fn(&'static u32)` at the same time which
99+
// specify distinct values for Assoc. (See also #56105)
100+
lt_op: |_| tcx.lifetimes.re_erased,
101+
// Leave consts and types unchanged.
102+
ct_op: |ct| ct,
103+
ty_op: |ty| ty,
104+
});
105+
tcx.try_normalize_erasing_regions(param_env, ty).unwrap_or(ty)
108106
};
109107
tcx.infer_ctxt().enter(|infcx| infcx.can_eq(param_env, normalize(src), normalize(dest)).is_ok())
110108
}

compiler/rustc_mir_transform/src/inline.rs

+100
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use rustc_middle::ty::subst::Subst;
1111
use rustc_middle::ty::{self, ConstKind, Instance, InstanceDef, ParamEnv, Ty, TyCtxt};
1212
use rustc_session::config::OptLevel;
1313
use rustc_span::{hygiene::ExpnKind, ExpnData, LocalExpnId, Span};
14+
use rustc_target::abi::VariantIdx;
1415
use rustc_target::spec::abi::Abi;
1516

1617
use super::simplify::{remove_dead_blocks, CfgSimplifier};
@@ -423,6 +424,7 @@ impl<'tcx> Inliner<'tcx> {
423424
instance: callsite.callee,
424425
callee_body,
425426
cost: 0,
427+
validation: Ok(()),
426428
};
427429

428430
// Traverse the MIR manually so we can account for the effects of inlining on the CFG.
@@ -458,6 +460,9 @@ impl<'tcx> Inliner<'tcx> {
458460
checker.visit_local_decl(v, &callee_body.local_decls[v]);
459461
}
460462

463+
// Abort if type validation found anything fishy.
464+
checker.validation?;
465+
461466
let cost = checker.cost;
462467
if let InlineAttr::Always = callee_attrs.inline {
463468
debug!("INLINING {:?} because inline(always) [cost={}]", callsite, cost);
@@ -738,6 +743,7 @@ struct CostChecker<'b, 'tcx> {
738743
cost: usize,
739744
callee_body: &'b Body<'tcx>,
740745
instance: ty::Instance<'tcx>,
746+
validation: Result<(), &'static str>,
741747
}
742748

743749
impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
@@ -818,6 +824,100 @@ impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
818824

819825
self.super_local_decl(local, local_decl)
820826
}
827+
828+
/// This method duplicates code from MIR validation in an attempt to detect type mismatches due
829+
/// to normalization failure.
830+
fn visit_projection_elem(
831+
&mut self,
832+
local: Local,
833+
proj_base: &[PlaceElem<'tcx>],
834+
elem: PlaceElem<'tcx>,
835+
context: PlaceContext,
836+
location: Location,
837+
) {
838+
if let ProjectionElem::Field(f, ty) = elem {
839+
let parent = Place { local, projection: self.tcx.intern_place_elems(proj_base) };
840+
let parent_ty = parent.ty(&self.callee_body.local_decls, self.tcx);
841+
let check_equal = |this: &mut Self, f_ty| {
842+
if !equal_up_to_regions(this.tcx, this.param_env, ty, f_ty) {
843+
trace!(?ty, ?f_ty);
844+
this.validation = Err("failed to normalize projection type");
845+
return;
846+
}
847+
};
848+
849+
let kind = match parent_ty.ty.kind() {
850+
&ty::Opaque(def_id, substs) => {
851+
self.tcx.bound_type_of(def_id).subst(self.tcx, substs).kind()
852+
}
853+
kind => kind,
854+
};
855+
856+
match kind {
857+
ty::Tuple(fields) => {
858+
let Some(f_ty) = fields.get(f.as_usize()) else {
859+
self.validation = Err("malformed MIR");
860+
return;
861+
};
862+
check_equal(self, *f_ty);
863+
}
864+
ty::Adt(adt_def, substs) => {
865+
let var = parent_ty.variant_index.unwrap_or(VariantIdx::from_u32(0));
866+
let Some(field) = adt_def.variant(var).fields.get(f.as_usize()) else {
867+
self.validation = Err("malformed MIR");
868+
return;
869+
};
870+
check_equal(self, field.ty(self.tcx, substs));
871+
}
872+
ty::Closure(_, substs) => {
873+
let substs = substs.as_closure();
874+
let Some(f_ty) = substs.upvar_tys().nth(f.as_usize()) else {
875+
self.validation = Err("malformed MIR");
876+
return;
877+
};
878+
check_equal(self, f_ty);
879+
}
880+
&ty::Generator(def_id, substs, _) => {
881+
let f_ty = if let Some(var) = parent_ty.variant_index {
882+
let gen_body = if def_id == self.callee_body.source.def_id() {
883+
self.callee_body
884+
} else {
885+
self.tcx.optimized_mir(def_id)
886+
};
887+
888+
let Some(layout) = gen_body.generator_layout() else {
889+
self.validation = Err("malformed MIR");
890+
return;
891+
};
892+
893+
let Some(&local) = layout.variant_fields[var].get(f) else {
894+
self.validation = Err("malformed MIR");
895+
return;
896+
};
897+
898+
let Some(&f_ty) = layout.field_tys.get(local) else {
899+
self.validation = Err("malformed MIR");
900+
return;
901+
};
902+
903+
f_ty
904+
} else {
905+
let Some(f_ty) = substs.as_generator().prefix_tys().nth(f.index()) else {
906+
self.validation = Err("malformed MIR");
907+
return;
908+
};
909+
910+
f_ty
911+
};
912+
913+
check_equal(self, f_ty);
914+
}
915+
_ => self.validation = Err("malformed MIR"),
916+
}
917+
}
918+
919+
self.super_projection_elem(local, proj_base, elem, context, location);
920+
}
821921
}
822922

823923
/**

0 commit comments

Comments
 (0)