1
1
use rustc_index:: IndexVec ;
2
2
use rustc_middle:: mir:: * ;
3
3
use rustc_middle:: ty:: { ParamEnv , ScalarInt , Ty , TyCtxt } ;
4
+ use rustc_target:: abi:: Size ;
4
5
use std:: iter;
5
6
6
7
use super :: simplify:: simplify_cfg;
@@ -67,13 +68,13 @@ trait SimplifyMatch<'tcx> {
67
68
_ => unreachable ! ( ) ,
68
69
} ;
69
70
70
- if !self . can_simplify ( tcx, targets, param_env, bbs) {
71
+ let discr_ty = discr. ty ( local_decls, tcx) ;
72
+ if !self . can_simplify ( tcx, targets, param_env, bbs, discr_ty) {
71
73
return false ;
72
74
}
73
75
74
76
// Take ownership of items now that we know we can optimize.
75
77
let discr = discr. clone ( ) ;
76
- let discr_ty = discr. ty ( local_decls, tcx) ;
77
78
78
79
// Introduce a temporary for the discriminant value.
79
80
let source_info = bbs[ switch_bb_idx] . terminator ( ) . source_info ;
@@ -104,6 +105,7 @@ trait SimplifyMatch<'tcx> {
104
105
targets : & SwitchTargets ,
105
106
param_env : ParamEnv < ' tcx > ,
106
107
bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
108
+ discr_ty : Ty < ' tcx > ,
107
109
) -> bool ;
108
110
109
111
fn new_stmts (
@@ -157,6 +159,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
157
159
targets : & SwitchTargets ,
158
160
param_env : ParamEnv < ' tcx > ,
159
161
bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
162
+ _discr_ty : Ty < ' tcx > ,
160
163
) -> bool {
161
164
if targets. iter ( ) . len ( ) != 1 {
162
165
return false ;
@@ -268,7 +271,7 @@ struct SimplifyToExp {
268
271
enum CompareType < ' tcx , ' a > {
269
272
Same ( & ' a StatementKind < ' tcx > ) ,
270
273
Eq ( & ' a Place < ' tcx > , Ty < ' tcx > , ScalarInt ) ,
271
- Discr ( & ' a Place < ' tcx > , Ty < ' tcx > ) ,
274
+ Discr ( & ' a Place < ' tcx > , Ty < ' tcx > , bool ) ,
272
275
}
273
276
274
277
enum TransfromType {
@@ -282,7 +285,7 @@ impl From<CompareType<'_, '_>> for TransfromType {
282
285
match compare_type {
283
286
CompareType :: Same ( _) => TransfromType :: Same ,
284
287
CompareType :: Eq ( _, _, _) => TransfromType :: Eq ,
285
- CompareType :: Discr ( _, _) => TransfromType :: Discr ,
288
+ CompareType :: Discr ( _, _, _ ) => TransfromType :: Discr ,
286
289
}
287
290
}
288
291
}
@@ -333,6 +336,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
333
336
targets : & SwitchTargets ,
334
337
param_env : ParamEnv < ' tcx > ,
335
338
bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
339
+ discr_ty : Ty < ' tcx > ,
336
340
) -> bool {
337
341
if targets. iter ( ) . len ( ) < 2 || targets. iter ( ) . len ( ) > 64 {
338
342
return false ;
@@ -355,13 +359,19 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
355
359
return false ;
356
360
}
357
361
362
+ let discr_size = tcx. layout_of ( param_env. and ( discr_ty) ) . unwrap ( ) . size ;
358
363
let first_stmts = & bbs[ first_target] . statements ;
359
364
let ( second_val, second_target) = target_iter. next ( ) . unwrap ( ) ;
360
365
let second_stmts = & bbs[ second_target] . statements ;
361
366
if first_stmts. len ( ) != second_stmts. len ( ) {
362
367
return false ;
363
368
}
364
369
370
+ fn int_equal ( l : ScalarInt , r : impl Into < u128 > , size : Size ) -> bool {
371
+ l. try_to_int ( l. size ( ) ) . unwrap ( )
372
+ == ScalarInt :: try_from_uint ( r, size) . unwrap ( ) . try_to_int ( size) . unwrap ( )
373
+ }
374
+
365
375
let mut compare_types = Vec :: new ( ) ;
366
376
for ( f, s) in iter:: zip ( first_stmts, second_stmts) {
367
377
let compare_type = match ( & f. kind , & s. kind ) {
@@ -382,12 +392,22 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
382
392
) {
383
393
( Some ( f) , Some ( s) ) if f == s => CompareType :: Eq ( lhs_f, f_c. const_ . ty ( ) , f) ,
384
394
( Some ( f) , Some ( s) )
385
- if Some ( f) == ScalarInt :: try_from_uint ( first_val, f. size ( ) )
386
- && Some ( s) == ScalarInt :: try_from_uint ( second_val, s. size ( ) ) =>
395
+ if ( ( f_c. const_ . ty ( ) . is_signed ( ) || discr_ty. is_signed ( ) )
396
+ && int_equal ( f, first_val, discr_size)
397
+ && int_equal ( s, second_val, discr_size) )
398
+ || ( Some ( f) == ScalarInt :: try_from_uint ( first_val, f. size ( ) )
399
+ && Some ( s)
400
+ == ScalarInt :: try_from_uint ( second_val, s. size ( ) ) ) =>
387
401
{
388
- CompareType :: Discr ( lhs_f, f_c. const_ . ty ( ) )
402
+ CompareType :: Discr (
403
+ lhs_f,
404
+ f_c. const_ . ty ( ) ,
405
+ f_c. const_ . ty ( ) . is_signed ( ) || discr_ty. is_signed ( ) ,
406
+ )
407
+ }
408
+ _ => {
409
+ return false ;
389
410
}
390
- _ => return false ,
391
411
}
392
412
}
393
413
@@ -413,15 +433,22 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
413
433
&& s_c. const_ . ty ( ) == f_ty
414
434
&& s_c. const_ . try_eval_scalar_int ( tcx, param_env) == Some ( val) => { }
415
435
(
416
- CompareType :: Discr ( lhs_f, f_ty) ,
436
+ CompareType :: Discr ( lhs_f, f_ty, is_signed ) ,
417
437
StatementKind :: Assign ( box ( lhs_s, Rvalue :: Use ( Operand :: Constant ( s_c) ) ) ) ,
418
438
) if lhs_f == lhs_s && s_c. const_ . ty ( ) == f_ty => {
419
439
let Some ( f) = s_c. const_ . try_eval_scalar_int ( tcx, param_env) else {
420
440
return false ;
421
441
} ;
422
- if Some ( f) != ScalarInt :: try_from_uint ( other_val, f. size ( ) ) {
423
- return false ;
442
+ if is_signed
443
+ && s_c. const_ . ty ( ) . is_signed ( )
444
+ && int_equal ( f, other_val, discr_size)
445
+ {
446
+ continue ;
447
+ }
448
+ if Some ( f) == ScalarInt :: try_from_uint ( other_val, f. size ( ) ) {
449
+ continue ;
424
450
}
451
+ return false ;
425
452
}
426
453
_ => return false ,
427
454
}
0 commit comments