Skip to content

Commit 984afb2

Browse files
committed
Dest prop: Support removing writes when this unblocks optimizations
1 parent 01ef4b2 commit 984afb2

7 files changed

+188
-114
lines changed

compiler/rustc_mir_transform/src/dest_prop.rs

+92-49
Original file line numberDiff line numberDiff line change
@@ -208,22 +208,23 @@ impl<'tcx> MirPass<'tcx> for DestinationPropagation {
208208
// This is the set of merges we will apply this round. It is a subset of the candidates.
209209
let mut merges = FxHashMap::default();
210210

211-
for (src, candidates) in candidates.c.iter() {
212-
if merged_locals.contains(*src) {
211+
for (src, candidates) in candidates.c.drain() {
212+
if merged_locals.contains(src) {
213213
continue;
214214
}
215215
let Some(dest) =
216-
candidates.iter().find(|dest| !merged_locals.contains(**dest)) else {
216+
candidates.into_iter().find(|(dest, _)| !merged_locals.contains(*dest)) else {
217217
continue;
218218
};
219219
if !tcx.consider_optimizing(|| {
220220
format!("{} round {}", tcx.def_path_str(def_id), round_count)
221221
}) {
222222
break;
223223
}
224-
merges.insert(*src, *dest);
225-
merged_locals.insert(*src);
226-
merged_locals.insert(*dest);
224+
merged_locals.insert(src);
225+
merged_locals.insert(dest.0);
226+
merges.insert(src, dest.clone());
227+
merges.insert(dest.0, dest);
227228
}
228229
trace!(merging = ?merges);
229230

@@ -245,7 +246,7 @@ impl<'tcx> MirPass<'tcx> for DestinationPropagation {
245246
/// frequently. Everything with a `&'alloc` lifetime points into here.
246247
#[derive(Default)]
247248
struct Allocations {
248-
candidates: FxHashMap<Local, Vec<Local>>,
249+
candidates: FxHashMap<Local, Vec<(Local, Vec<Location>)>>,
249250
candidates_reverse: FxHashMap<Local, Vec<Local>>,
250251
write_info: WriteInfo,
251252
// PERF: Do this for `MaybeLiveLocals` allocations too.
@@ -267,7 +268,11 @@ struct Candidates<'alloc> {
267268
///
268269
/// We will still report that we would like to merge `_1` and `_2` in an attempt to allow us to
269270
/// remove that assignment.
270-
c: &'alloc mut FxHashMap<Local, Vec<Local>>,
271+
///
272+
/// Each candidate pair is associated with a `Vec<Location>`. If the candidate pair is accepted,
273+
/// all writes to either local at these locations must be removed. The writes will always be
274+
/// removable.
275+
c: &'alloc mut FxHashMap<Local, Vec<(Local, Vec<Location>)>>,
271276
/// A reverse index of the `c` set; if the `c` set contains `a => Place { local: b, proj }`,
272277
/// then this contains `b => a`.
273278
// PERF: Possibly these should be `SmallVec`s?
@@ -282,7 +287,7 @@ struct Candidates<'alloc> {
282287
fn apply_merges<'tcx>(
283288
body: &mut Body<'tcx>,
284289
tcx: TyCtxt<'tcx>,
285-
merges: &FxHashMap<Local, Local>,
290+
merges: &FxHashMap<Local, (Local, Vec<Location>)>,
286291
merged_locals: &BitSet<Local>,
287292
) {
288293
let mut merger = Merger { tcx, merges, merged_locals };
@@ -291,18 +296,27 @@ fn apply_merges<'tcx>(
291296

292297
struct Merger<'a, 'tcx> {
293298
tcx: TyCtxt<'tcx>,
294-
merges: &'a FxHashMap<Local, Local>,
299+
merges: &'a FxHashMap<Local, (Local, Vec<Location>)>,
295300
merged_locals: &'a BitSet<Local>,
296301
}
297302

303+
impl<'a, 'tcx> Merger<'a, 'tcx> {
304+
fn should_remove_write_at(&self, local: Local, location: Location) -> bool {
305+
let Some((_, to_remove)) = self.merges.get(&local) else {
306+
return false;
307+
};
308+
to_remove.contains(&location)
309+
}
310+
}
311+
298312
impl<'a, 'tcx> MutVisitor<'tcx> for Merger<'a, 'tcx> {
299313
fn tcx(&self) -> TyCtxt<'tcx> {
300314
self.tcx
301315
}
302316

303317
fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _location: Location) {
304318
if let Some(dest) = self.merges.get(local) {
305-
*local = *dest;
319+
*local = dest.0;
306320
}
307321
}
308322

@@ -332,10 +346,27 @@ impl<'a, 'tcx> MutVisitor<'tcx> for Merger<'a, 'tcx> {
332346
_ => {}
333347
}
334348
}
349+
StatementKind::Deinit(place) => {
350+
if self.should_remove_write_at(place.local, location) {
351+
statement.make_nop();
352+
}
353+
}
335354

336355
_ => {}
337356
}
338357
}
358+
359+
fn visit_operand(&mut self, op: &mut Operand<'tcx>, location: Location) {
360+
self.super_operand(op, location);
361+
match op {
362+
Operand::Move(place) => {
363+
if self.should_remove_write_at(place.local, location) {
364+
*op = Operand::Copy(*place);
365+
}
366+
}
367+
_ => (),
368+
}
369+
}
339370
}
340371

341372
//////////////////////////////////////////////////////////
@@ -356,30 +387,35 @@ struct FilterInformation<'a, 'body, 'alloc, 'tcx> {
356387
// through these methods, and not directly.
357388
impl<'alloc> Candidates<'alloc> {
358389
/// Just `Vec::retain`, but the condition is inverted and we add debugging output
359-
fn vec_filter_candidates(
390+
fn vec_modify_candidates(
360391
src: Local,
361-
v: &mut Vec<Local>,
362-
mut f: impl FnMut(Local) -> CandidateFilter,
392+
v: &mut Vec<(Local, Vec<Location>)>,
393+
mut f: impl FnMut(Local) -> CandidateModification,
363394
at: Location,
364395
) {
365-
v.retain(|dest| {
366-
let remove = f(*dest);
367-
if remove == CandidateFilter::Remove {
396+
v.retain_mut(|(dest, remove_writes)| match f(*dest) {
397+
CandidateModification::Remove => {
368398
trace!("eliminating {:?} => {:?} due to conflict at {:?}", src, dest, at);
399+
false
400+
}
401+
CandidateModification::RemoveWrite => {
402+
trace!("marking write for {:?} => {:?} as needing removing at {:?}", src, dest, at);
403+
remove_writes.push(at);
404+
true
369405
}
370-
remove == CandidateFilter::Keep
406+
CandidateModification::Keep => true,
371407
});
372408
}
373409

374410
/// `vec_filter_candidates` but for an `Entry`
375411
fn entry_filter_candidates(
376-
mut entry: OccupiedEntry<'_, Local, Vec<Local>>,
412+
mut entry: OccupiedEntry<'_, Local, Vec<(Local, Vec<Location>)>>,
377413
p: Local,
378-
f: impl FnMut(Local) -> CandidateFilter,
414+
f: impl FnMut(Local) -> CandidateModification,
379415
at: Location,
380416
) {
381417
let candidates = entry.get_mut();
382-
Self::vec_filter_candidates(p, candidates, f, at);
418+
Self::vec_modify_candidates(p, candidates, f, at);
383419
if candidates.len() == 0 {
384420
entry.remove();
385421
}
@@ -389,7 +425,7 @@ impl<'alloc> Candidates<'alloc> {
389425
fn filter_candidates_by(
390426
&mut self,
391427
p: Local,
392-
mut f: impl FnMut(Local) -> CandidateFilter,
428+
mut f: impl FnMut(Local) -> CandidateModification,
393429
at: Location,
394430
) {
395431
// Cover the cases where `p` appears as a `src`
@@ -403,7 +439,8 @@ impl<'alloc> Candidates<'alloc> {
403439
// We use `retain` here to remove the elements from the reverse set if we've removed the
404440
// matching candidate in the forward set.
405441
srcs.retain(|src| {
406-
if f(*src) == CandidateFilter::Keep {
442+
let modification = f(*src);
443+
if modification == CandidateModification::Keep {
407444
return true;
408445
}
409446
let Entry::Occupied(entry) = self.c.entry(*src) else {
@@ -413,18 +450,20 @@ impl<'alloc> Candidates<'alloc> {
413450
entry,
414451
*src,
415452
|dest| {
416-
if dest == p { CandidateFilter::Remove } else { CandidateFilter::Keep }
453+
if dest == p { modification } else { CandidateModification::Keep }
417454
},
418455
at,
419456
);
420-
false
457+
// Remove the src from the reverse set if we removed the candidate pair
458+
modification == CandidateModification::RemoveWrite
421459
});
422460
}
423461
}
424462

425463
#[derive(Copy, Clone, PartialEq, Eq)]
426-
enum CandidateFilter {
464+
enum CandidateModification {
427465
Keep,
466+
RemoveWrite,
428467
Remove,
429468
}
430469

@@ -483,31 +522,36 @@ impl<'a, 'body, 'alloc, 'tcx> FilterInformation<'a, 'body, 'alloc, 'tcx> {
483522

484523
fn apply_conflicts(&mut self) {
485524
let writes = &self.write_info.writes;
486-
for p in writes {
525+
for &(p, is_removable) in writes {
526+
let modification = if is_removable {
527+
CandidateModification::RemoveWrite
528+
} else {
529+
CandidateModification::Remove
530+
};
487531
let other_skip = self.write_info.skip_pair.and_then(|(a, b)| {
488-
if a == *p {
532+
if a == p {
489533
Some(b)
490-
} else if b == *p {
534+
} else if b == p {
491535
Some(a)
492536
} else {
493537
None
494538
}
495539
});
496540
self.candidates.filter_candidates_by(
497-
*p,
541+
p,
498542
|q| {
499543
if Some(q) == other_skip {
500-
return CandidateFilter::Keep;
544+
return CandidateModification::Keep;
501545
}
502546
// It is possible that a local may be live for less than the
503547
// duration of a statement This happens in the case of function
504548
// calls or inline asm. Because of this, we also mark locals as
505549
// conflicting when both of them are written to in the same
506550
// statement.
507-
if self.live.contains(q) || writes.contains(&q) {
508-
CandidateFilter::Remove
551+
if self.live.contains(q) || writes.iter().any(|&(x, _)| x == q) {
552+
modification
509553
} else {
510-
CandidateFilter::Keep
554+
CandidateModification::Keep
511555
}
512556
},
513557
self.at,
@@ -519,7 +563,9 @@ impl<'a, 'body, 'alloc, 'tcx> FilterInformation<'a, 'body, 'alloc, 'tcx> {
519563
/// Describes where a statement/terminator writes to
520564
#[derive(Default, Debug)]
521565
struct WriteInfo {
522-
writes: Vec<Local>,
566+
/// Which locals are written to. The `bool` is true if the write is "removable," ie if it comes
567+
/// from a `Operand::Move` or `Deinit`.
568+
writes: Vec<(Local, bool)>,
523569
/// If this pair of locals is a candidate pair, completely skip processing it during this
524570
/// statement. All other candidates are unaffected.
525571
skip_pair: Option<(Local, Local)>,
@@ -563,10 +609,11 @@ impl WriteInfo {
563609
| Rvalue::CopyForDeref(_) => (),
564610
}
565611
}
612+
StatementKind::Deinit(p) => {
613+
self.writes.push((p.local, true));
614+
}
566615
// Retags are technically also reads, but reporting them as a write suffices
567-
StatementKind::SetDiscriminant { place, .. }
568-
| StatementKind::Deinit(place)
569-
| StatementKind::Retag(_, place) => {
616+
StatementKind::SetDiscriminant { place, .. } | StatementKind::Retag(_, place) => {
570617
self.add_place(**place);
571618
}
572619
StatementKind::Intrinsic(_)
@@ -652,16 +699,12 @@ impl WriteInfo {
652699
}
653700

654701
fn add_place<'tcx>(&mut self, place: Place<'tcx>) {
655-
self.writes.push(place.local);
702+
self.writes.push((place.local, false));
656703
}
657704

658705
fn add_operand<'tcx>(&mut self, op: &Operand<'tcx>) {
659706
match op {
660-
// FIXME(JakobDegen): In a previous version, the `Move` case was incorrectly treated as
661-
// being a read only. This was unsound, however we cannot add a regression test because
662-
// it is not possible to set this off with current MIR. Once we have that ability, a
663-
// regression test should be added.
664-
Operand::Move(p) => self.add_place(*p),
707+
Operand::Move(p) => self.writes.push((p.local, true)),
665708
Operand::Copy(_) | Operand::Constant(_) => (),
666709
}
667710
}
@@ -716,7 +759,7 @@ fn places_to_candidate_pair<'tcx>(
716759
fn find_candidates<'alloc, 'tcx>(
717760
body: &Body<'tcx>,
718761
borrowed: &BitSet<Local>,
719-
candidates: &'alloc mut FxHashMap<Local, Vec<Local>>,
762+
candidates: &'alloc mut FxHashMap<Local, Vec<(Local, Vec<Location>)>>,
720763
candidates_reverse: &'alloc mut FxHashMap<Local, Vec<Local>>,
721764
) -> Candidates<'alloc> {
722765
candidates.clear();
@@ -730,16 +773,16 @@ fn find_candidates<'alloc, 'tcx>(
730773
}
731774
// Generate the reverse map
732775
for (src, cands) in candidates.iter() {
733-
for dest in cands.iter().copied() {
734-
candidates_reverse.entry(dest).or_default().push(*src);
776+
for (dest, _) in cands.iter() {
777+
candidates_reverse.entry(*dest).or_default().push(*src);
735778
}
736779
}
737780
Candidates { c: candidates, reverse: candidates_reverse }
738781
}
739782

740783
struct FindAssignments<'a, 'alloc, 'tcx> {
741784
body: &'a Body<'tcx>,
742-
candidates: &'alloc mut FxHashMap<Local, Vec<Local>>,
785+
candidates: &'alloc mut FxHashMap<Local, Vec<(Local, Vec<Location>)>>,
743786
borrowed: &'a BitSet<Local>,
744787
}
745788

@@ -766,7 +809,7 @@ impl<'tcx> Visitor<'tcx> for FindAssignments<'_, '_, 'tcx> {
766809
}
767810

768811
// We may insert duplicates here, but that's fine
769-
self.candidates.entry(src).or_default().push(dest);
812+
self.candidates.entry(src).or_default().push((dest, Vec::new()));
770813
}
771814
}
772815
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
- // MIR for `move_simple` before DestinationPropagation
2+
+ // MIR for `move_simple` after DestinationPropagation
3+
4+
fn move_simple(_1: i32) -> () {
5+
debug x => _1; // in scope 0 at $DIR/move.rs:+0:16: +0:17
6+
let mut _0: (); // return place in scope 0 at $DIR/move.rs:+0:24: +0:24
7+
let _2: (); // in scope 0 at $DIR/move.rs:+1:5: +1:19
8+
let mut _3: i32; // in scope 0 at $DIR/move.rs:+1:14: +1:15
9+
let mut _4: i32; // in scope 0 at $DIR/move.rs:+1:17: +1:18
10+
11+
bb0: {
12+
StorageLive(_2); // scope 0 at $DIR/move.rs:+1:5: +1:19
13+
- StorageLive(_3); // scope 0 at $DIR/move.rs:+1:14: +1:15
14+
- _3 = _1; // scope 0 at $DIR/move.rs:+1:14: +1:15
15+
- StorageLive(_4); // scope 0 at $DIR/move.rs:+1:17: +1:18
16+
- _4 = _1; // scope 0 at $DIR/move.rs:+1:17: +1:18
17+
- _2 = use_both(move _3, move _4) -> bb1; // scope 0 at $DIR/move.rs:+1:5: +1:19
18+
+ nop; // scope 0 at $DIR/move.rs:+1:14: +1:15
19+
+ nop; // scope 0 at $DIR/move.rs:+1:14: +1:15
20+
+ nop; // scope 0 at $DIR/move.rs:+1:17: +1:18
21+
+ nop; // scope 0 at $DIR/move.rs:+1:17: +1:18
22+
+ _2 = use_both(_1, _1) -> bb1; // scope 0 at $DIR/move.rs:+1:5: +1:19
23+
// mir::Constant
24+
// + span: $DIR/move.rs:8:5: 8:13
25+
// + literal: Const { ty: fn(i32, i32) {use_both}, val: Value(<ZST>) }
26+
}
27+
28+
bb1: {
29+
- StorageDead(_4); // scope 0 at $DIR/move.rs:+1:18: +1:19
30+
- StorageDead(_3); // scope 0 at $DIR/move.rs:+1:18: +1:19
31+
+ nop; // scope 0 at $DIR/move.rs:+1:18: +1:19
32+
+ nop; // scope 0 at $DIR/move.rs:+1:18: +1:19
33+
StorageDead(_2); // scope 0 at $DIR/move.rs:+1:19: +1:20
34+
_0 = const (); // scope 0 at $DIR/move.rs:+0:24: +2:2
35+
return; // scope 0 at $DIR/move.rs:+2:2: +2:2
36+
}
37+
}
38+

src/test/mir-opt/dest-prop/move.rs

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// unit-test: DestinationPropagation
2+
3+
#[inline(never)]
4+
fn use_both(_: i32, _: i32) {}
5+
6+
// EMIT_MIR move.move_simple.DestinationPropagation.diff
7+
fn move_simple(x: i32) {
8+
use_both(x, x);
9+
}
10+
11+
fn main() {
12+
move_simple(1);
13+
}

0 commit comments

Comments
 (0)