Skip to content

Commit ccc1bf7

Browse files
authored
Rollup merge of #73757 - oli-obk:const_prop_hardening, r=wesleywiser
Const prop: erase all block-only locals at the end of every block I messed up this erasure in #73656 (comment). I think it is too fragile to have the previous scheme. Let's benchmark the new scheme and see what happens. r? @wesleywiser cc @felix91gr
2 parents ec48989 + b9f4e0d commit ccc1bf7

File tree

5 files changed

+90
-16
lines changed

5 files changed

+90
-16
lines changed

src/librustc_mir/interpret/eval_context.rs

+7
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,10 @@ pub enum LocalValue<Tag = ()> {
132132
}
133133

134134
impl<'tcx, Tag: Copy + 'static> LocalState<'tcx, Tag> {
135+
/// Read the local's value or error if the local is not yet live or not live anymore.
136+
///
137+
/// Note: This may only be invoked from the `Machine::access_local` hook and not from
138+
/// anywhere else. You may be invalidating machine invariants if you do!
135139
pub fn access(&self) -> InterpResult<'tcx, Operand<Tag>> {
136140
match self.value {
137141
LocalValue::Dead => throw_ub!(DeadLocal),
@@ -144,6 +148,9 @@ impl<'tcx, Tag: Copy + 'static> LocalState<'tcx, Tag> {
144148

145149
/// Overwrite the local. If the local can be overwritten in place, return a reference
146150
/// to do so; otherwise return the `MemPlace` to consult instead.
151+
///
152+
/// Note: This may only be invoked from the `Machine::access_local_mut` hook and not from
153+
/// anywhere else. You may be invalidating machine invariants if you do!
147154
pub fn access_mut(
148155
&mut self,
149156
) -> InterpResult<'tcx, Result<&mut LocalValue<Tag>, MemPlace<Tag>>> {

src/librustc_mir/interpret/machine.rs

+18-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use rustc_span::def_id::DefId;
1111

1212
use super::{
1313
AllocId, Allocation, AllocationExtra, CheckInAllocMsg, Frame, ImmTy, InterpCx, InterpResult,
14-
Memory, MemoryKind, OpTy, Operand, PlaceTy, Pointer, Scalar,
14+
LocalValue, MemPlace, Memory, MemoryKind, OpTy, Operand, PlaceTy, Pointer, Scalar,
1515
};
1616

1717
/// Data returned by Machine::stack_pop,
@@ -192,6 +192,8 @@ pub trait Machine<'mir, 'tcx>: Sized {
192192
) -> InterpResult<'tcx>;
193193

194194
/// Called to read the specified `local` from the `frame`.
195+
/// Since reading a ZST is not actually accessing memory or locals, this is never invoked
196+
/// for ZST reads.
195197
#[inline]
196198
fn access_local(
197199
_ecx: &InterpCx<'mir, 'tcx, Self>,
@@ -201,6 +203,21 @@ pub trait Machine<'mir, 'tcx>: Sized {
201203
frame.locals[local].access()
202204
}
203205

206+
/// Called to write the specified `local` from the `frame`.
207+
/// Since writing a ZST is not actually accessing memory or locals, this is never invoked
208+
/// for ZST reads.
209+
#[inline]
210+
fn access_local_mut<'a>(
211+
ecx: &'a mut InterpCx<'mir, 'tcx, Self>,
212+
frame: usize,
213+
local: mir::Local,
214+
) -> InterpResult<'tcx, Result<&'a mut LocalValue<Self::PointerTag>, MemPlace<Self::PointerTag>>>
215+
where
216+
'tcx: 'mir,
217+
{
218+
ecx.stack_mut()[frame].locals[local].access_mut()
219+
}
220+
204221
/// Called before a basic block terminator is executed.
205222
/// You can use this to detect endlessly running programs.
206223
#[inline]

src/librustc_mir/interpret/operand.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,11 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
432432
})
433433
}
434434

435-
/// This is used by [priroda](https://github.com/oli-obk/priroda) to get an OpTy from a local
435+
/// Read from a local. Will not actually access the local if reading from a ZST.
436+
/// Will not access memory, instead an indirect `Operand` is returned.
437+
///
438+
/// This is public because it is used by [priroda](https://github.com/oli-obk/priroda) to get an
439+
/// OpTy from a local
436440
pub fn access_local(
437441
&self,
438442
frame: &super::Frame<'mir, 'tcx, M::PointerTag, M::FrameExtra>,

src/librustc_mir/interpret/place.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -741,7 +741,7 @@ where
741741
// but not factored as a separate function.
742742
let mplace = match dest.place {
743743
Place::Local { frame, local } => {
744-
match self.stack_mut()[frame].locals[local].access_mut()? {
744+
match M::access_local_mut(self, frame, local)? {
745745
Ok(local) => {
746746
// Local can be updated in-place.
747747
*local = LocalValue::Live(Operand::Immediate(src));
@@ -974,7 +974,7 @@ where
974974
) -> InterpResult<'tcx, (MPlaceTy<'tcx, M::PointerTag>, Option<Size>)> {
975975
let (mplace, size) = match place.place {
976976
Place::Local { frame, local } => {
977-
match self.stack_mut()[frame].locals[local].access_mut()? {
977+
match M::access_local_mut(self, frame, local)? {
978978
Ok(&mut local_val) => {
979979
// We need to make an allocation.
980980

@@ -998,7 +998,7 @@ where
998998
}
999999
// Now we can call `access_mut` again, asserting it goes well,
10001000
// and actually overwrite things.
1001-
*self.stack_mut()[frame].locals[local].access_mut().unwrap().unwrap() =
1001+
*M::access_local_mut(self, frame, local).unwrap().unwrap() =
10021002
LocalValue::Live(Operand::Indirect(mplace));
10031003
(mplace, Some(size))
10041004
}

src/librustc_mir/transform/const_prop.rs

+57-11
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
use std::cell::Cell;
55

66
use rustc_ast::ast::Mutability;
7+
use rustc_data_structures::fx::FxHashSet;
78
use rustc_hir::def::DefKind;
89
use rustc_hir::HirId;
910
use rustc_index::bit_set::BitSet;
@@ -28,7 +29,7 @@ use rustc_trait_selection::traits;
2829
use crate::const_eval::error_to_const_error;
2930
use crate::interpret::{
3031
self, compile_time_machine, AllocId, Allocation, Frame, ImmTy, Immediate, InterpCx, LocalState,
31-
LocalValue, Memory, MemoryKind, OpTy, Operand as InterpOperand, PlaceTy, Pointer,
32+
LocalValue, MemPlace, Memory, MemoryKind, OpTy, Operand as InterpOperand, PlaceTy, Pointer,
3233
ScalarMaybeUninit, StackPopCleanup,
3334
};
3435
use crate::transform::{MirPass, MirSource};
@@ -151,11 +152,19 @@ impl<'tcx> MirPass<'tcx> for ConstProp {
151152
struct ConstPropMachine<'mir, 'tcx> {
152153
/// The virtual call stack.
153154
stack: Vec<Frame<'mir, 'tcx, (), ()>>,
155+
/// `OnlyInsideOwnBlock` locals that were written in the current block get erased at the end.
156+
written_only_inside_own_block_locals: FxHashSet<Local>,
157+
/// Locals that need to be cleared after every block terminates.
158+
only_propagate_inside_block_locals: BitSet<Local>,
154159
}
155160

156161
impl<'mir, 'tcx> ConstPropMachine<'mir, 'tcx> {
157-
fn new() -> Self {
158-
Self { stack: Vec::new() }
162+
fn new(only_propagate_inside_block_locals: BitSet<Local>) -> Self {
163+
Self {
164+
stack: Vec::new(),
165+
written_only_inside_own_block_locals: Default::default(),
166+
only_propagate_inside_block_locals,
167+
}
159168
}
160169
}
161170

@@ -227,6 +236,18 @@ impl<'mir, 'tcx> interpret::Machine<'mir, 'tcx> for ConstPropMachine<'mir, 'tcx>
227236
l.access()
228237
}
229238

239+
fn access_local_mut<'a>(
240+
ecx: &'a mut InterpCx<'mir, 'tcx, Self>,
241+
frame: usize,
242+
local: Local,
243+
) -> InterpResult<'tcx, Result<&'a mut LocalValue<Self::PointerTag>, MemPlace<Self::PointerTag>>>
244+
{
245+
if frame == 0 && ecx.machine.only_propagate_inside_block_locals.contains(local) {
246+
ecx.machine.written_only_inside_own_block_locals.insert(local);
247+
}
248+
ecx.machine.stack[frame].locals[local].access_mut()
249+
}
250+
230251
fn before_access_global(
231252
_memory_extra: &(),
232253
_alloc_id: AllocId,
@@ -274,8 +295,6 @@ struct ConstPropagator<'mir, 'tcx> {
274295
// Because we have `MutVisitor` we can't obtain the `SourceInfo` from a `Location`. So we store
275296
// the last known `SourceInfo` here and just keep revisiting it.
276297
source_info: Option<SourceInfo>,
277-
// Locals we need to forget at the end of the current block
278-
locals_of_current_block: BitSet<Local>,
279298
}
280299

281300
impl<'mir, 'tcx> LayoutOf for ConstPropagator<'mir, 'tcx> {
@@ -313,8 +332,20 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
313332
let param_env = tcx.param_env(def_id).with_reveal_all();
314333

315334
let span = tcx.def_span(def_id);
316-
let mut ecx = InterpCx::new(tcx, span, param_env, ConstPropMachine::new(), ());
317335
let can_const_prop = CanConstProp::check(body);
336+
let mut only_propagate_inside_block_locals = BitSet::new_empty(can_const_prop.len());
337+
for (l, mode) in can_const_prop.iter_enumerated() {
338+
if *mode == ConstPropMode::OnlyInsideOwnBlock {
339+
only_propagate_inside_block_locals.insert(l);
340+
}
341+
}
342+
let mut ecx = InterpCx::new(
343+
tcx,
344+
span,
345+
param_env,
346+
ConstPropMachine::new(only_propagate_inside_block_locals),
347+
(),
348+
);
318349

319350
let ret = ecx
320351
.layout_of(body.return_ty().subst(tcx, substs))
@@ -345,7 +376,6 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
345376
//FIXME(wesleywiser) we can't steal this because `Visitor::super_visit_body()` needs it
346377
local_decls: body.local_decls.clone(),
347378
source_info: None,
348-
locals_of_current_block: BitSet::new_empty(body.local_decls.len()),
349379
}
350380
}
351381

@@ -900,7 +930,6 @@ impl<'mir, 'tcx> MutVisitor<'tcx> for ConstPropagator<'mir, 'tcx> {
900930
Will remove it from const-prop after block is finished. Local: {:?}",
901931
place.local
902932
);
903-
self.locals_of_current_block.insert(place.local);
904933
}
905934
ConstPropMode::OnlyPropagateInto | ConstPropMode::NoPropagation => {
906935
trace!("can't propagate into {:?}", place);
@@ -1089,10 +1118,27 @@ impl<'mir, 'tcx> MutVisitor<'tcx> for ConstPropagator<'mir, 'tcx> {
10891118
}
10901119
}
10911120
}
1092-
// We remove all Locals which are restricted in propagation to their containing blocks.
1093-
for local in self.locals_of_current_block.iter() {
1121+
1122+
// We remove all Locals which are restricted in propagation to their containing blocks and
1123+
// which were modified in the current block.
1124+
// Take it out of the ecx so we can get a mutable reference to the ecx for `remove_const`
1125+
let mut locals = std::mem::take(&mut self.ecx.machine.written_only_inside_own_block_locals);
1126+
for &local in locals.iter() {
10941127
Self::remove_const(&mut self.ecx, local);
10951128
}
1096-
self.locals_of_current_block.clear();
1129+
locals.clear();
1130+
// Put it back so we reuse the heap of the storage
1131+
self.ecx.machine.written_only_inside_own_block_locals = locals;
1132+
if cfg!(debug_assertions) {
1133+
// Ensure we are correctly erasing locals with the non-debug-assert logic.
1134+
for local in self.ecx.machine.only_propagate_inside_block_locals.iter() {
1135+
assert!(
1136+
self.get_const(local.into()).is_none()
1137+
|| self
1138+
.layout_of(self.local_decls[local].ty)
1139+
.map_or(true, |layout| layout.is_zst())
1140+
)
1141+
}
1142+
}
10971143
}
10981144
}

0 commit comments

Comments
 (0)