Skip to content

Commit efbf694

Browse files
eddybLegNeato
authored andcommitted
WIP: mem2reg speedup
1 parent ea20ef3 commit efbf694

File tree

3 files changed

+63
-46
lines changed

3 files changed

+63
-46
lines changed

crates/rustc_codegen_spirv/src/linker/dce.rs

+6-5
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
//! *references* a rooted thing is also rooted, not the other way around - but that's the basic
88
//! concept.
99
10-
use rspirv::dr::{Function, Instruction, Module, Operand};
10+
use rspirv::dr::{Block, Function, Instruction, Module, Operand};
1111
use rspirv::spirv::{Decoration, LinkageType, Op, StorageClass, Word};
12-
use rustc_data_structures::fx::FxIndexSet;
12+
use rustc_data_structures::fx::{FxIndexMap, FxIndexSet};
13+
use std::hash::Hash;
1314

1415
pub fn dce(module: &mut Module) {
1516
let mut rooted = collect_roots(module);
@@ -137,11 +138,11 @@ fn kill_unrooted(module: &mut Module, rooted: &FxIndexSet<Word>) {
137138
}
138139
}
139140

140-
pub fn dce_phi(func: &mut Function) {
141+
pub fn dce_phi(blocks: &mut FxIndexMap<impl Eq + Hash, &mut Block>) {
141142
let mut used = FxIndexSet::default();
142143
loop {
143144
let mut changed = false;
144-
for inst in func.all_inst_iter() {
145+
for inst in blocks.values().flat_map(|block| &block.instructions) {
145146
if inst.class.opcode != Op::Phi || used.contains(&inst.result_id.unwrap()) {
146147
for op in &inst.operands {
147148
if let Some(id) = op.id_ref_any() {
@@ -154,7 +155,7 @@ pub fn dce_phi(func: &mut Function) {
154155
break;
155156
}
156157
}
157-
for block in &mut func.blocks {
158+
for block in blocks.values_mut() {
158159
block
159160
.instructions
160161
.retain(|inst| inst.class.opcode != Op::Phi || used.contains(&inst.result_id.unwrap()));

crates/rustc_codegen_spirv/src/linker/mem2reg.rs

+52-39
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,31 @@ use super::simple_passes::outgoing_edges;
1313
use super::{apply_rewrite_rules, id};
1414
use rspirv::dr::{Block, Function, Instruction, ModuleHeader, Operand};
1515
use rspirv::spirv::{Op, Word};
16-
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
16+
use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexMap};
1717
use rustc_middle::bug;
1818
use std::collections::hash_map;
1919

20+
// HACK(eddyb) newtype instead of type alias to avoid mistakes.
21+
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
22+
struct LabelId(Word);
23+
2024
pub fn mem2reg(
2125
header: &mut ModuleHeader,
2226
types_global_values: &mut Vec<Instruction>,
2327
pointer_to_pointee: &FxHashMap<Word, Word>,
2428
constants: &FxHashMap<Word, u32>,
2529
func: &mut Function,
2630
) {
27-
let reachable = compute_reachable(&func.blocks);
28-
let preds = compute_preds(&func.blocks, &reachable);
31+
// HACK(eddyb) this ad-hoc indexing might be useful elsewhere as well, but
32+
// it's made completely irrelevant by SPIR-T so only applies to legacy code.
33+
let mut blocks: FxIndexMap<_, _> = func
34+
.blocks
35+
.iter_mut()
36+
.map(|block| (LabelId(block.label_id().unwrap()), block))
37+
.collect();
38+
39+
let reachable = compute_reachable(&blocks);
40+
let preds = compute_preds(&blocks, &reachable);
2941
let idom = compute_idom(&preds, &reachable);
3042
let dominance_frontier = compute_dominance_frontier(&preds, &idom);
3143
loop {
@@ -34,31 +46,27 @@ pub fn mem2reg(
3446
types_global_values,
3547
pointer_to_pointee,
3648
constants,
37-
&mut func.blocks,
49+
&mut blocks,
3850
&dominance_frontier,
3951
);
4052
if !changed {
4153
break;
4254
}
4355
// mem2reg produces minimal SSA form, not pruned, so DCE the dead ones
44-
super::dce::dce_phi(func);
56+
super::dce::dce_phi(&mut blocks);
4557
}
4658
}
4759

48-
fn label_to_index(blocks: &[Block], id: Word) -> usize {
49-
blocks
50-
.iter()
51-
.position(|b| b.label_id().unwrap() == id)
52-
.unwrap()
53-
}
54-
55-
fn compute_reachable(blocks: &[Block]) -> Vec<bool> {
56-
fn recurse(blocks: &[Block], reachable: &mut [bool], block: usize) {
60+
fn compute_reachable(blocks: &FxIndexMap<LabelId, &mut Block>) -> Vec<bool> {
61+
fn recurse(blocks: &FxIndexMap<LabelId, &mut Block>, reachable: &mut [bool], block: usize) {
5762
if !reachable[block] {
5863
reachable[block] = true;
59-
for dest_id in outgoing_edges(&blocks[block]) {
60-
let dest_idx = label_to_index(blocks, dest_id);
61-
recurse(blocks, reachable, dest_idx);
64+
for dest_id in outgoing_edges(blocks[block]) {
65+
recurse(
66+
blocks,
67+
reachable,
68+
blocks.get_index_of(&LabelId(dest_id)).unwrap(),
69+
);
6270
}
6371
}
6472
}
@@ -67,17 +75,19 @@ fn compute_reachable(blocks: &[Block]) -> Vec<bool> {
6775
reachable
6876
}
6977

70-
fn compute_preds(blocks: &[Block], reachable_blocks: &[bool]) -> Vec<Vec<usize>> {
78+
fn compute_preds(
79+
blocks: &FxIndexMap<LabelId, &mut Block>,
80+
reachable_blocks: &[bool],
81+
) -> Vec<Vec<usize>> {
7182
let mut result = vec![vec![]; blocks.len()];
7283
// Do not count unreachable blocks as valid preds of blocks
7384
for (source_idx, source) in blocks
74-
.iter()
85+
.values()
7586
.enumerate()
7687
.filter(|&(b, _)| reachable_blocks[b])
7788
{
7889
for dest_id in outgoing_edges(source) {
79-
let dest_idx = label_to_index(blocks, dest_id);
80-
result[dest_idx].push(source_idx);
90+
result[blocks.get_index_of(&LabelId(dest_id)).unwrap()].push(source_idx);
8191
}
8292
}
8393
result
@@ -161,7 +171,7 @@ fn insert_phis_all(
161171
types_global_values: &mut Vec<Instruction>,
162172
pointer_to_pointee: &FxHashMap<Word, Word>,
163173
constants: &FxHashMap<Word, u32>,
164-
blocks: &mut [Block],
174+
blocks: &mut FxIndexMap<LabelId, &mut Block>,
165175
dominance_frontier: &[FxHashSet<usize>],
166176
) -> bool {
167177
let var_maps_and_types = blocks[0]
@@ -198,7 +208,11 @@ fn insert_phis_all(
198208
rewrite_rules: FxHashMap::default(),
199209
};
200210
renamer.rename(0, None);
201-
apply_rewrite_rules(&renamer.rewrite_rules, blocks);
211+
// FIXME(eddyb) shouldn't this full rescan of the function be done once?
212+
apply_rewrite_rules(
213+
&renamer.rewrite_rules,
214+
blocks.values_mut().map(|block| &mut **block),
215+
);
202216
remove_nops(blocks);
203217
}
204218
remove_old_variables(blocks, &var_maps_and_types);
@@ -216,7 +230,7 @@ struct VarInfo {
216230
fn collect_access_chains(
217231
pointer_to_pointee: &FxHashMap<Word, Word>,
218232
constants: &FxHashMap<Word, u32>,
219-
blocks: &[Block],
233+
blocks: &FxIndexMap<LabelId, &mut Block>,
220234
base_var: Word,
221235
base_var_ty: Word,
222236
) -> Option<FxHashMap<Word, VarInfo>> {
@@ -246,7 +260,7 @@ fn collect_access_chains(
246260
// Loop in case a previous block references a later AccessChain
247261
loop {
248262
let mut changed = false;
249-
for inst in blocks.iter().flat_map(|b| &b.instructions) {
263+
for inst in blocks.values().flat_map(|b| &b.instructions) {
250264
for (index, op) in inst.operands.iter().enumerate() {
251265
if let Operand::IdRef(id) = op {
252266
if variables.contains_key(id) {
@@ -304,10 +318,10 @@ fn collect_access_chains(
304318
// same var map (e.g. `s.x = s.y;`).
305319
fn split_copy_memory(
306320
header: &mut ModuleHeader,
307-
blocks: &mut [Block],
321+
blocks: &mut FxIndexMap<LabelId, &mut Block>,
308322
var_map: &FxHashMap<Word, VarInfo>,
309323
) {
310-
for block in blocks {
324+
for block in blocks.values_mut() {
311325
let mut inst_index = 0;
312326
while inst_index < block.instructions.len() {
313327
let inst = &block.instructions[inst_index];
@@ -362,7 +376,7 @@ fn has_store(block: &Block, var_map: &FxHashMap<Word, VarInfo>) -> bool {
362376
}
363377

364378
fn insert_phis(
365-
blocks: &[Block],
379+
blocks: &FxIndexMap<LabelId, &mut Block>,
366380
dominance_frontier: &[FxHashSet<usize>],
367381
var_map: &FxHashMap<Word, VarInfo>,
368382
) -> FxHashSet<usize> {
@@ -371,7 +385,7 @@ fn insert_phis(
371385
let mut ever_on_work_list = FxHashSet::default();
372386
let mut work_list = Vec::new();
373387
let mut blocks_with_phi = FxHashSet::default();
374-
for (block_idx, block) in blocks.iter().enumerate() {
388+
for (block_idx, block) in blocks.values().enumerate() {
375389
if has_store(block, var_map) {
376390
ever_on_work_list.insert(block_idx);
377391
work_list.push(block_idx);
@@ -416,10 +430,10 @@ fn top_stack_or_undef(
416430
}
417431
}
418432

419-
struct Renamer<'a> {
433+
struct Renamer<'a, 'b> {
420434
header: &'a mut ModuleHeader,
421435
types_global_values: &'a mut Vec<Instruction>,
422-
blocks: &'a mut [Block],
436+
blocks: &'a mut FxIndexMap<LabelId, &'b mut Block>,
423437
blocks_with_phi: FxHashSet<usize>,
424438
base_var_type: Word,
425439
var_map: &'a FxHashMap<Word, VarInfo>,
@@ -429,7 +443,7 @@ struct Renamer<'a> {
429443
rewrite_rules: FxHashMap<Word, Word>,
430444
}
431445

432-
impl Renamer<'_> {
446+
impl Renamer<'_, '_> {
433447
// Returns the phi definition.
434448
fn insert_phi_value(&mut self, block: usize, from_block: usize) -> Word {
435449
let from_block_label = self.blocks[from_block].label_id().unwrap();
@@ -549,9 +563,8 @@ impl Renamer<'_> {
549563
}
550564
}
551565

552-
for dest_id in outgoing_edges(&self.blocks[block]).collect::<Vec<_>>() {
553-
// TODO: Don't do this find
554-
let dest_idx = label_to_index(self.blocks, dest_id);
566+
for dest_id in outgoing_edges(self.blocks[block]).collect::<Vec<_>>() {
567+
let dest_idx = self.blocks.get_index_of(&LabelId(dest_id)).unwrap();
555568
self.rename(dest_idx, Some(block));
556569
}
557570

@@ -561,16 +574,16 @@ impl Renamer<'_> {
561574
}
562575
}
563576

564-
fn remove_nops(blocks: &mut [Block]) {
565-
for block in blocks {
577+
fn remove_nops(blocks: &mut FxIndexMap<LabelId, &mut Block>) {
578+
for block in blocks.values_mut() {
566579
block
567580
.instructions
568581
.retain(|inst| inst.class.opcode != Op::Nop);
569582
}
570583
}
571584

572585
fn remove_old_variables(
573-
blocks: &mut [Block],
586+
blocks: &mut FxIndexMap<LabelId, &mut Block>,
574587
var_maps_and_types: &[(FxHashMap<u32, VarInfo>, u32)],
575588
) {
576589
blocks[0].instructions.retain(|inst| {
@@ -581,7 +594,7 @@ fn remove_old_variables(
581594
.all(|(var_map, _)| !var_map.contains_key(&result_id))
582595
}
583596
});
584-
for block in blocks {
597+
for block in blocks.values_mut() {
585598
block.instructions.retain(|inst| {
586599
!matches!(inst.class.opcode, Op::AccessChain | Op::InBoundsAccessChain)
587600
|| inst.operands.iter().all(|op| {

crates/rustc_codegen_spirv/src/linker/mod.rs

+5-2
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,12 @@ fn id(header: &mut ModuleHeader) -> Word {
8585
result
8686
}
8787

88-
fn apply_rewrite_rules(rewrite_rules: &FxHashMap<Word, Word>, blocks: &mut [Block]) {
88+
fn apply_rewrite_rules<'a>(
89+
rewrite_rules: &FxHashMap<Word, Word>,
90+
blocks: impl IntoIterator<Item = &'a mut Block>,
91+
) {
8992
let all_ids_mut = blocks
90-
.iter_mut()
93+
.into_iter()
9194
.flat_map(|b| b.label.iter_mut().chain(b.instructions.iter_mut()))
9295
.flat_map(|inst| {
9396
inst.result_id

0 commit comments

Comments
 (0)