@@ -13,19 +13,31 @@ use super::simple_passes::outgoing_edges;
13
13
use super :: { apply_rewrite_rules, id} ;
14
14
use rspirv:: dr:: { Block , Function , Instruction , ModuleHeader , Operand } ;
15
15
use rspirv:: spirv:: { Op , Word } ;
16
- use rustc_data_structures:: fx:: { FxHashMap , FxHashSet } ;
16
+ use rustc_data_structures:: fx:: { FxHashMap , FxHashSet , FxIndexMap } ;
17
17
use rustc_middle:: bug;
18
18
use std:: collections:: hash_map;
19
19
20
+ // HACK(eddyb) newtype instead of type alias to avoid mistakes.
21
+ #[ derive( Copy , Clone , PartialEq , Eq , Hash ) ]
22
+ struct LabelId ( Word ) ;
23
+
20
24
pub fn mem2reg (
21
25
header : & mut ModuleHeader ,
22
26
types_global_values : & mut Vec < Instruction > ,
23
27
pointer_to_pointee : & FxHashMap < Word , Word > ,
24
28
constants : & FxHashMap < Word , u32 > ,
25
29
func : & mut Function ,
26
30
) {
27
- let reachable = compute_reachable ( & func. blocks ) ;
28
- let preds = compute_preds ( & func. blocks , & reachable) ;
31
+ // HACK(eddyb) this ad-hoc indexing might be useful elsewhere as well, but
32
+ // it's made completely irrelevant by SPIR-T so only applies to legacy code.
33
+ let mut blocks: FxIndexMap < _ , _ > = func
34
+ . blocks
35
+ . iter_mut ( )
36
+ . map ( |block| ( LabelId ( block. label_id ( ) . unwrap ( ) ) , block) )
37
+ . collect ( ) ;
38
+
39
+ let reachable = compute_reachable ( & blocks) ;
40
+ let preds = compute_preds ( & blocks, & reachable) ;
29
41
let idom = compute_idom ( & preds, & reachable) ;
30
42
let dominance_frontier = compute_dominance_frontier ( & preds, & idom) ;
31
43
loop {
@@ -34,31 +46,27 @@ pub fn mem2reg(
34
46
types_global_values,
35
47
pointer_to_pointee,
36
48
constants,
37
- & mut func . blocks ,
49
+ & mut blocks,
38
50
& dominance_frontier,
39
51
) ;
40
52
if !changed {
41
53
break ;
42
54
}
43
55
// mem2reg produces minimal SSA form, not pruned, so DCE the dead ones
44
- super :: dce:: dce_phi ( func ) ;
56
+ super :: dce:: dce_phi ( & mut blocks ) ;
45
57
}
46
58
}
47
59
48
- fn label_to_index ( blocks : & [ Block ] , id : Word ) -> usize {
49
- blocks
50
- . iter ( )
51
- . position ( |b| b. label_id ( ) . unwrap ( ) == id)
52
- . unwrap ( )
53
- }
54
-
55
- fn compute_reachable ( blocks : & [ Block ] ) -> Vec < bool > {
56
- fn recurse ( blocks : & [ Block ] , reachable : & mut [ bool ] , block : usize ) {
60
+ fn compute_reachable ( blocks : & FxIndexMap < LabelId , & mut Block > ) -> Vec < bool > {
61
+ fn recurse ( blocks : & FxIndexMap < LabelId , & mut Block > , reachable : & mut [ bool ] , block : usize ) {
57
62
if !reachable[ block] {
58
63
reachable[ block] = true ;
59
- for dest_id in outgoing_edges ( & blocks[ block] ) {
60
- let dest_idx = label_to_index ( blocks, dest_id) ;
61
- recurse ( blocks, reachable, dest_idx) ;
64
+ for dest_id in outgoing_edges ( blocks[ block] ) {
65
+ recurse (
66
+ blocks,
67
+ reachable,
68
+ blocks. get_index_of ( & LabelId ( dest_id) ) . unwrap ( ) ,
69
+ ) ;
62
70
}
63
71
}
64
72
}
@@ -67,17 +75,19 @@ fn compute_reachable(blocks: &[Block]) -> Vec<bool> {
67
75
reachable
68
76
}
69
77
70
- fn compute_preds ( blocks : & [ Block ] , reachable_blocks : & [ bool ] ) -> Vec < Vec < usize > > {
78
+ fn compute_preds (
79
+ blocks : & FxIndexMap < LabelId , & mut Block > ,
80
+ reachable_blocks : & [ bool ] ,
81
+ ) -> Vec < Vec < usize > > {
71
82
let mut result = vec ! [ vec![ ] ; blocks. len( ) ] ;
72
83
// Do not count unreachable blocks as valid preds of blocks
73
84
for ( source_idx, source) in blocks
74
- . iter ( )
85
+ . values ( )
75
86
. enumerate ( )
76
87
. filter ( |& ( b, _) | reachable_blocks[ b] )
77
88
{
78
89
for dest_id in outgoing_edges ( source) {
79
- let dest_idx = label_to_index ( blocks, dest_id) ;
80
- result[ dest_idx] . push ( source_idx) ;
90
+ result[ blocks. get_index_of ( & LabelId ( dest_id) ) . unwrap ( ) ] . push ( source_idx) ;
81
91
}
82
92
}
83
93
result
@@ -161,7 +171,7 @@ fn insert_phis_all(
161
171
types_global_values : & mut Vec < Instruction > ,
162
172
pointer_to_pointee : & FxHashMap < Word , Word > ,
163
173
constants : & FxHashMap < Word , u32 > ,
164
- blocks : & mut [ Block ] ,
174
+ blocks : & mut FxIndexMap < LabelId , & mut Block > ,
165
175
dominance_frontier : & [ FxHashSet < usize > ] ,
166
176
) -> bool {
167
177
let var_maps_and_types = blocks[ 0 ]
@@ -198,7 +208,11 @@ fn insert_phis_all(
198
208
rewrite_rules : FxHashMap :: default ( ) ,
199
209
} ;
200
210
renamer. rename ( 0 , None ) ;
201
- apply_rewrite_rules ( & renamer. rewrite_rules , blocks) ;
211
+ // FIXME(eddyb) shouldn't this full rescan of the function be done once?
212
+ apply_rewrite_rules (
213
+ & renamer. rewrite_rules ,
214
+ blocks. values_mut ( ) . map ( |block| & mut * * block) ,
215
+ ) ;
202
216
remove_nops ( blocks) ;
203
217
}
204
218
remove_old_variables ( blocks, & var_maps_and_types) ;
@@ -216,7 +230,7 @@ struct VarInfo {
216
230
fn collect_access_chains (
217
231
pointer_to_pointee : & FxHashMap < Word , Word > ,
218
232
constants : & FxHashMap < Word , u32 > ,
219
- blocks : & [ Block ] ,
233
+ blocks : & FxIndexMap < LabelId , & mut Block > ,
220
234
base_var : Word ,
221
235
base_var_ty : Word ,
222
236
) -> Option < FxHashMap < Word , VarInfo > > {
@@ -246,7 +260,7 @@ fn collect_access_chains(
246
260
// Loop in case a previous block references a later AccessChain
247
261
loop {
248
262
let mut changed = false ;
249
- for inst in blocks. iter ( ) . flat_map ( |b| & b. instructions ) {
263
+ for inst in blocks. values ( ) . flat_map ( |b| & b. instructions ) {
250
264
for ( index, op) in inst. operands . iter ( ) . enumerate ( ) {
251
265
if let Operand :: IdRef ( id) = op {
252
266
if variables. contains_key ( id) {
@@ -304,10 +318,10 @@ fn collect_access_chains(
304
318
// same var map (e.g. `s.x = s.y;`).
305
319
fn split_copy_memory (
306
320
header : & mut ModuleHeader ,
307
- blocks : & mut [ Block ] ,
321
+ blocks : & mut FxIndexMap < LabelId , & mut Block > ,
308
322
var_map : & FxHashMap < Word , VarInfo > ,
309
323
) {
310
- for block in blocks {
324
+ for block in blocks. values_mut ( ) {
311
325
let mut inst_index = 0 ;
312
326
while inst_index < block. instructions . len ( ) {
313
327
let inst = & block. instructions [ inst_index] ;
@@ -362,7 +376,7 @@ fn has_store(block: &Block, var_map: &FxHashMap<Word, VarInfo>) -> bool {
362
376
}
363
377
364
378
fn insert_phis (
365
- blocks : & [ Block ] ,
379
+ blocks : & FxIndexMap < LabelId , & mut Block > ,
366
380
dominance_frontier : & [ FxHashSet < usize > ] ,
367
381
var_map : & FxHashMap < Word , VarInfo > ,
368
382
) -> FxHashSet < usize > {
@@ -371,7 +385,7 @@ fn insert_phis(
371
385
let mut ever_on_work_list = FxHashSet :: default ( ) ;
372
386
let mut work_list = Vec :: new ( ) ;
373
387
let mut blocks_with_phi = FxHashSet :: default ( ) ;
374
- for ( block_idx, block) in blocks. iter ( ) . enumerate ( ) {
388
+ for ( block_idx, block) in blocks. values ( ) . enumerate ( ) {
375
389
if has_store ( block, var_map) {
376
390
ever_on_work_list. insert ( block_idx) ;
377
391
work_list. push ( block_idx) ;
@@ -416,10 +430,10 @@ fn top_stack_or_undef(
416
430
}
417
431
}
418
432
419
- struct Renamer < ' a > {
433
+ struct Renamer < ' a , ' b > {
420
434
header : & ' a mut ModuleHeader ,
421
435
types_global_values : & ' a mut Vec < Instruction > ,
422
- blocks : & ' a mut [ Block ] ,
436
+ blocks : & ' a mut FxIndexMap < LabelId , & ' b mut Block > ,
423
437
blocks_with_phi : FxHashSet < usize > ,
424
438
base_var_type : Word ,
425
439
var_map : & ' a FxHashMap < Word , VarInfo > ,
@@ -429,7 +443,7 @@ struct Renamer<'a> {
429
443
rewrite_rules : FxHashMap < Word , Word > ,
430
444
}
431
445
432
- impl Renamer < ' _ > {
446
+ impl Renamer < ' _ , ' _ > {
433
447
// Returns the phi definition.
434
448
fn insert_phi_value ( & mut self , block : usize , from_block : usize ) -> Word {
435
449
let from_block_label = self . blocks [ from_block] . label_id ( ) . unwrap ( ) ;
@@ -549,9 +563,8 @@ impl Renamer<'_> {
549
563
}
550
564
}
551
565
552
- for dest_id in outgoing_edges ( & self . blocks [ block] ) . collect :: < Vec < _ > > ( ) {
553
- // TODO: Don't do this find
554
- let dest_idx = label_to_index ( self . blocks , dest_id) ;
566
+ for dest_id in outgoing_edges ( self . blocks [ block] ) . collect :: < Vec < _ > > ( ) {
567
+ let dest_idx = self . blocks . get_index_of ( & LabelId ( dest_id) ) . unwrap ( ) ;
555
568
self . rename ( dest_idx, Some ( block) ) ;
556
569
}
557
570
@@ -561,16 +574,16 @@ impl Renamer<'_> {
561
574
}
562
575
}
563
576
564
- fn remove_nops ( blocks : & mut [ Block ] ) {
565
- for block in blocks {
577
+ fn remove_nops ( blocks : & mut FxIndexMap < LabelId , & mut Block > ) {
578
+ for block in blocks. values_mut ( ) {
566
579
block
567
580
. instructions
568
581
. retain ( |inst| inst. class . opcode != Op :: Nop ) ;
569
582
}
570
583
}
571
584
572
585
fn remove_old_variables (
573
- blocks : & mut [ Block ] ,
586
+ blocks : & mut FxIndexMap < LabelId , & mut Block > ,
574
587
var_maps_and_types : & [ ( FxHashMap < u32 , VarInfo > , u32 ) ] ,
575
588
) {
576
589
blocks[ 0 ] . instructions . retain ( |inst| {
@@ -581,7 +594,7 @@ fn remove_old_variables(
581
594
. all ( |( var_map, _) | !var_map. contains_key ( & result_id) )
582
595
}
583
596
} ) ;
584
- for block in blocks {
597
+ for block in blocks. values_mut ( ) {
585
598
block. instructions . retain ( |inst| {
586
599
!matches ! ( inst. class. opcode, Op :: AccessChain | Op :: InBoundsAccessChain )
587
600
|| inst. operands . iter ( ) . all ( |op| {
0 commit comments