1
1
use rustc_index:: IndexVec ;
2
2
use rustc_middle:: mir:: * ;
3
- use rustc_middle:: ty:: { ParamEnv , Ty , TyCtxt } ;
3
+ use rustc_middle:: ty:: { ParamEnv , ScalarInt , Ty , TyCtxt } ;
4
4
use std:: iter;
5
5
6
6
use super :: simplify:: simplify_cfg;
@@ -38,6 +38,11 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
38
38
should_cleanup = true ;
39
39
continue ;
40
40
}
41
+ if SimplifyToExp :: default ( ) . simplify ( tcx, & mut body. local_decls , bbs, bb_idx, param_env)
42
+ {
43
+ should_cleanup = true ;
44
+ continue ;
45
+ }
41
46
}
42
47
43
48
if should_cleanup {
@@ -48,7 +53,7 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
48
53
49
54
trait SimplifyMatch < ' tcx > {
50
55
fn simplify (
51
- & self ,
56
+ & mut self ,
52
57
tcx : TyCtxt < ' tcx > ,
53
58
local_decls : & mut IndexVec < Local , LocalDecl < ' tcx > > ,
54
59
bbs : & mut IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
@@ -72,7 +77,7 @@ trait SimplifyMatch<'tcx> {
72
77
let source_info = bbs[ switch_bb_idx] . terminator ( ) . source_info ;
73
78
let discr_local = local_decls. push ( LocalDecl :: new ( discr_ty, source_info. span ) ) ;
74
79
75
- // We already checked that first and second are different blocks,
80
+ // We already checked that targets are different blocks,
76
81
// and bb_idx has a different terminator from both of them.
77
82
let new_stmts = self . new_stmts ( tcx, targets, param_env, bbs, discr_local. clone ( ) , discr_ty) ;
78
83
let ( _, first) = targets. iter ( ) . next ( ) . unwrap ( ) ;
@@ -91,7 +96,7 @@ trait SimplifyMatch<'tcx> {
91
96
}
92
97
93
98
fn can_simplify (
94
- & self ,
99
+ & mut self ,
95
100
tcx : TyCtxt < ' tcx > ,
96
101
targets : & SwitchTargets ,
97
102
param_env : ParamEnv < ' tcx > ,
@@ -144,7 +149,7 @@ struct SimplifyToIf;
144
149
/// ```
145
150
impl < ' tcx > SimplifyMatch < ' tcx > for SimplifyToIf {
146
151
fn can_simplify (
147
- & self ,
152
+ & mut self ,
148
153
tcx : TyCtxt < ' tcx > ,
149
154
targets : & SwitchTargets ,
150
155
param_env : ParamEnv < ' tcx > ,
@@ -250,3 +255,207 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
250
255
new_stmts. collect ( )
251
256
}
252
257
}
258
+
259
+ #[ derive( Default ) ]
260
+ struct SimplifyToExp {
261
+ transfrom_types : Vec < TransfromType > ,
262
+ }
263
+
264
+ #[ derive( Clone , Copy ) ]
265
+ enum CompareType < ' tcx , ' a > {
266
+ Same ( & ' a StatementKind < ' tcx > ) ,
267
+ Eq ( & ' a Place < ' tcx > , Ty < ' tcx > , ScalarInt ) ,
268
+ Discr ( & ' a Place < ' tcx > , Ty < ' tcx > ) ,
269
+ }
270
+
271
+ enum TransfromType {
272
+ Same ,
273
+ Eq ,
274
+ Discr ,
275
+ }
276
+
277
+ impl From < CompareType < ' _ , ' _ > > for TransfromType {
278
+ fn from ( compare_type : CompareType < ' _ , ' _ > ) -> Self {
279
+ match compare_type {
280
+ CompareType :: Same ( _) => TransfromType :: Same ,
281
+ CompareType :: Eq ( _, _, _) => TransfromType :: Eq ,
282
+ CompareType :: Discr ( _, _) => TransfromType :: Discr ,
283
+ }
284
+ }
285
+ }
286
+
287
+ /// If we find that the value of match is the same as the assignment,
288
+ /// merge a target block statements into the source block,
289
+ /// using cast to transform different integer types.
290
+ ///
291
+ /// For example:
292
+ ///
293
+ /// ```ignore (MIR)
294
+ /// bb0: {
295
+ /// switchInt(_1) -> [1: bb2, 2: bb3, 3: bb4, otherwise: bb1];
296
+ /// }
297
+ ///
298
+ /// bb1: {
299
+ /// unreachable;
300
+ /// }
301
+ ///
302
+ /// bb2: {
303
+ /// _0 = const 1_i16;
304
+ /// goto -> bb5;
305
+ /// }
306
+ ///
307
+ /// bb3: {
308
+ /// _0 = const 2_i16;
309
+ /// goto -> bb5;
310
+ /// }
311
+ ///
312
+ /// bb4: {
313
+ /// _0 = const 3_i16;
314
+ /// goto -> bb5;
315
+ /// }
316
+ /// ```
317
+ ///
318
+ /// into:
319
+ ///
320
+ /// ```ignore (MIR)
321
+ /// bb0: {
322
+ /// _0 = _3 as i16 (IntToInt);
323
+ /// goto -> bb5;
324
+ /// }
325
+ /// ```
326
+ impl < ' tcx > SimplifyMatch < ' tcx > for SimplifyToExp {
327
+ fn can_simplify (
328
+ & mut self ,
329
+ tcx : TyCtxt < ' tcx > ,
330
+ targets : & SwitchTargets ,
331
+ param_env : ParamEnv < ' tcx > ,
332
+ bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
333
+ ) -> bool {
334
+ if targets. iter ( ) . len ( ) < 2 || targets. iter ( ) . len ( ) > 64 {
335
+ return false ;
336
+ }
337
+ // We require that the possible target blocks all be distinct.
338
+ if !targets. is_distinct ( ) {
339
+ return false ;
340
+ }
341
+ if !bbs[ targets. otherwise ( ) ] . is_empty_unreachable ( ) {
342
+ return false ;
343
+ }
344
+ let mut iter = targets. iter ( ) ;
345
+ let ( first_val, first_target) = iter. next ( ) . unwrap ( ) ;
346
+ let first_terminator_kind = & bbs[ first_target] . terminator ( ) . kind ;
347
+ // Check that destinations are identical, and if not, then don't optimize this block
348
+ if !targets
349
+ . iter ( )
350
+ . all ( |( _, other_target) | first_terminator_kind == & bbs[ other_target] . terminator ( ) . kind )
351
+ {
352
+ return false ;
353
+ }
354
+
355
+ let first_stmts = & bbs[ first_target] . statements ;
356
+ let ( second_val, second_target) = iter. next ( ) . unwrap ( ) ;
357
+ let second_stmts = & bbs[ second_target] . statements ;
358
+
359
+ let mut compare_types = Vec :: new ( ) ;
360
+ for ( f, s) in iter:: zip ( first_stmts, second_stmts) {
361
+ let compare_type = match ( & f. kind , & s. kind ) {
362
+ // If two statements are exactly the same, we can optimize.
363
+ ( f_s, s_s) if f_s == s_s => CompareType :: Same ( f_s) ,
364
+
365
+ // If two statements are assignments with the match values to the same place, we can optimize.
366
+ (
367
+ StatementKind :: Assign ( box ( lhs_f, Rvalue :: Use ( Operand :: Constant ( f_c) ) ) ) ,
368
+ StatementKind :: Assign ( box ( lhs_s, Rvalue :: Use ( Operand :: Constant ( s_c) ) ) ) ,
369
+ ) if lhs_f == lhs_s
370
+ && f_c. const_ . ty ( ) == s_c. const_ . ty ( )
371
+ && f_c. const_ . ty ( ) . is_integral ( ) =>
372
+ {
373
+ match (
374
+ f_c. const_ . try_eval_scalar_int ( tcx, param_env) ,
375
+ s_c. const_ . try_eval_scalar_int ( tcx, param_env) ,
376
+ ) {
377
+ ( Some ( f) , Some ( s) ) if f == s => CompareType :: Eq ( lhs_f, f_c. const_ . ty ( ) , f) ,
378
+ ( Some ( f) , Some ( s) )
379
+ if Some ( f) == ScalarInt :: try_from_uint ( first_val, f. size ( ) )
380
+ && Some ( s) == ScalarInt :: try_from_uint ( second_val, s. size ( ) ) =>
381
+ {
382
+ CompareType :: Discr ( lhs_f, f_c. const_ . ty ( ) )
383
+ }
384
+ _ => return false ,
385
+ }
386
+ }
387
+
388
+ // Otherwise we cannot optimize. Try another block.
389
+ _ => return false ,
390
+ } ;
391
+ compare_types. push ( compare_type) ;
392
+ }
393
+
394
+ for ( other_val, other_target) in iter {
395
+ let other_stmts = & bbs[ other_target] . statements ;
396
+ if compare_types. len ( ) != other_stmts. len ( ) {
397
+ return false ;
398
+ }
399
+ for ( f, s) in iter:: zip ( & compare_types, other_stmts) {
400
+ match ( * f, & s. kind ) {
401
+ ( CompareType :: Same ( f_s) , s_s) if f_s == s_s => { }
402
+ (
403
+ CompareType :: Eq ( lhs_f, f_ty, val) ,
404
+ StatementKind :: Assign ( box ( lhs_s, Rvalue :: Use ( Operand :: Constant ( s_c) ) ) ) ,
405
+ ) if lhs_f == lhs_s
406
+ && s_c. const_ . ty ( ) == f_ty
407
+ && s_c. const_ . try_eval_scalar_int ( tcx, param_env) == Some ( val) => { }
408
+ (
409
+ CompareType :: Discr ( lhs_f, f_ty) ,
410
+ StatementKind :: Assign ( box ( lhs_s, Rvalue :: Use ( Operand :: Constant ( s_c) ) ) ) ,
411
+ ) if lhs_f == lhs_s && s_c. const_ . ty ( ) == f_ty => {
412
+ let Some ( f) = s_c. const_ . try_eval_scalar_int ( tcx, param_env) else {
413
+ return false ;
414
+ } ;
415
+ if Some ( f) != ScalarInt :: try_from_uint ( other_val, f. size ( ) ) {
416
+ return false ;
417
+ }
418
+ }
419
+ _ => return false ,
420
+ }
421
+ }
422
+ }
423
+ self . transfrom_types = compare_types. into_iter ( ) . map ( |c| c. into ( ) ) . collect ( ) ;
424
+ true
425
+ }
426
+
427
+ fn new_stmts (
428
+ & self ,
429
+ _tcx : TyCtxt < ' tcx > ,
430
+ targets : & SwitchTargets ,
431
+ _param_env : ParamEnv < ' tcx > ,
432
+ bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
433
+ discr_local : Local ,
434
+ discr_ty : Ty < ' tcx > ,
435
+ ) -> Vec < Statement < ' tcx > > {
436
+ let ( _, first) = targets. iter ( ) . next ( ) . unwrap ( ) ;
437
+ let first = & bbs[ first] ;
438
+
439
+ let new_stmts =
440
+ iter:: zip ( & self . transfrom_types , & first. statements ) . map ( |( t, s) | match ( t, & s. kind ) {
441
+ ( TransfromType :: Same , _) | ( TransfromType :: Eq , _) => ( * s) . clone ( ) ,
442
+ (
443
+ TransfromType :: Discr ,
444
+ StatementKind :: Assign ( box ( lhs, Rvalue :: Use ( Operand :: Constant ( f_c) ) ) ) ,
445
+ ) => {
446
+ let operand = Operand :: Copy ( Place :: from ( discr_local) ) ;
447
+ let r_val = if f_c. const_ . ty ( ) == discr_ty {
448
+ Rvalue :: Use ( operand)
449
+ } else {
450
+ Rvalue :: Cast ( CastKind :: IntToInt , operand, f_c. const_ . ty ( ) )
451
+ } ;
452
+ Statement {
453
+ source_info : s. source_info ,
454
+ kind : StatementKind :: Assign ( Box :: new ( ( * lhs, r_val) ) ) ,
455
+ }
456
+ }
457
+ _ => unreachable ! ( ) ,
458
+ } ) ;
459
+ new_stmts. collect ( )
460
+ }
461
+ }
0 commit comments