@@ -132,18 +132,29 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
132
132
133
133
let mut patch = MirPatch :: new ( body) ;
134
134
135
- // create temp to store second discriminant in, `_s` in example above
136
- let second_discriminant_temp =
137
- patch. new_temp ( opt_data. child_ty , opt_data. child_source . span ) ;
135
+ let ( second_discriminant_temp, second_operand) = if opt_data. need_hoist_discriminant {
136
+ // create temp to store second discriminant in, `_s` in example above
137
+ let second_discriminant_temp =
138
+ patch. new_temp ( opt_data. child_ty , opt_data. child_source . span ) ;
138
139
139
- patch. add_statement ( parent_end, StatementKind :: StorageLive ( second_discriminant_temp) ) ;
140
+ patch. add_statement (
141
+ parent_end,
142
+ StatementKind :: StorageLive ( second_discriminant_temp) ,
143
+ ) ;
140
144
141
- // create assignment of discriminant
142
- patch. add_assign (
143
- parent_end,
144
- Place :: from ( second_discriminant_temp) ,
145
- Rvalue :: Discriminant ( opt_data. child_place ) ,
146
- ) ;
145
+ // create assignment of discriminant
146
+ patch. add_assign (
147
+ parent_end,
148
+ Place :: from ( second_discriminant_temp) ,
149
+ Rvalue :: Discriminant ( opt_data. child_place ) ,
150
+ ) ;
151
+ (
152
+ Some ( second_discriminant_temp) ,
153
+ Operand :: Move ( Place :: from ( second_discriminant_temp) ) ,
154
+ )
155
+ } else {
156
+ ( None , Operand :: Copy ( opt_data. child_place ) )
157
+ } ;
147
158
148
159
// create temp to store inequality comparison between the two discriminants, `_t` in
149
160
// example above
@@ -152,11 +163,9 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
152
163
let comp_temp = patch. new_temp ( comp_res_type, opt_data. child_source . span ) ;
153
164
patch. add_statement ( parent_end, StatementKind :: StorageLive ( comp_temp) ) ;
154
165
155
- // create inequality comparison between the two discriminants
156
- let comp_rvalue = Rvalue :: BinaryOp (
157
- nequal,
158
- Box :: new ( ( parent_op. clone ( ) , Operand :: Move ( Place :: from ( second_discriminant_temp) ) ) ) ,
159
- ) ;
166
+ // create inequality comparison
167
+ let comp_rvalue =
168
+ Rvalue :: BinaryOp ( nequal, Box :: new ( ( parent_op. clone ( ) , second_operand) ) ) ;
160
169
patch. add_statement (
161
170
parent_end,
162
171
StatementKind :: Assign ( Box :: new ( ( Place :: from ( comp_temp) , comp_rvalue) ) ) ,
@@ -192,8 +201,13 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
192
201
TerminatorKind :: if_ ( Operand :: Move ( Place :: from ( comp_temp) ) , true_case, false_case) ,
193
202
) ;
194
203
195
- // generate StorageDead for the second_discriminant_temp not in use anymore
196
- patch. add_statement ( parent_end, StatementKind :: StorageDead ( second_discriminant_temp) ) ;
204
+ if let Some ( second_discriminant_temp) = second_discriminant_temp {
205
+ // generate StorageDead for the second_discriminant_temp not in use anymore
206
+ patch. add_statement (
207
+ parent_end,
208
+ StatementKind :: StorageDead ( second_discriminant_temp) ,
209
+ ) ;
210
+ }
197
211
198
212
// Generate a StorageDead for comp_temp in each of the targets, since we moved it into
199
213
// the switch
@@ -221,6 +235,7 @@ struct OptimizationData<'tcx> {
221
235
child_place : Place < ' tcx > ,
222
236
child_ty : Ty < ' tcx > ,
223
237
child_source : SourceInfo ,
238
+ need_hoist_discriminant : bool ,
224
239
}
225
240
226
241
fn evaluate_candidate < ' tcx > (
@@ -234,70 +249,128 @@ fn evaluate_candidate<'tcx>(
234
249
return None ;
235
250
} ;
236
251
let parent_ty = parent_discr. ty ( body. local_decls ( ) , tcx) ;
237
- if !bbs[ targets. otherwise ( ) ] . is_empty_unreachable ( ) {
238
- // Someone could write code like this:
239
- // ```rust
240
- // let Q = val;
241
- // if discriminant(P) == otherwise {
242
- // let ptr = &mut Q as *mut _ as *mut u8;
243
- // // It may be difficult for us to effectively determine whether values are valid.
244
- // // Invalid values can come from all sorts of corners.
245
- // unsafe { *ptr = 10; }
246
- // }
247
- //
248
- // match P {
249
- // A => match Q {
250
- // A => {
251
- // // code
252
- // }
253
- // _ => {
254
- // // don't use Q
255
- // }
256
- // }
257
- // _ => {
258
- // // don't use Q
259
- // }
260
- // };
261
- // ```
262
- //
263
- // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
264
- // invalid value, which is UB.
265
- // In order to fix this, **we would either need to show that the discriminant computation of
266
- // `place` is computed in all branches**.
267
- // FIXME(#95162) For the moment, we adopt a conservative approach and
268
- // consider only the `otherwise` branch has no statements and an unreachable terminator.
269
- return None ;
270
- }
271
252
let ( _, child) = targets. iter ( ) . next ( ) ?;
272
- let child_terminator = & bbs[ child] . terminator ( ) ;
273
- let TerminatorKind :: SwitchInt { targets : child_targets, discr : child_discr } =
274
- & child_terminator. kind
253
+
254
+ let Terminator {
255
+ kind : TerminatorKind :: SwitchInt { targets : child_targets, discr : child_discr } ,
256
+ source_info,
257
+ } = bbs[ child] . terminator ( )
275
258
else {
276
259
return None ;
277
260
} ;
278
261
let child_ty = child_discr. ty ( body. local_decls ( ) , tcx) ;
279
262
if child_ty != parent_ty {
280
263
return None ;
281
264
}
282
- let Some ( StatementKind :: Assign ( boxed) ) = & bbs[ child] . statements . first ( ) . map ( |x| & x. kind ) else {
265
+
266
+ // We only handle:
267
+ // ```
268
+ // bb4: {
269
+ // _8 = discriminant((_3.1: Enum1));
270
+ // switchInt(move _8) -> [2: bb7, otherwise: bb1];
271
+ // }
272
+ // ```
273
+ // and
274
+ // ```
275
+ // bb2: {
276
+ // switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
277
+ // }
278
+ // ```
279
+ if bbs[ child] . statements . len ( ) > 1 {
283
280
return None ;
281
+ }
282
+
283
+ // When thie BB has exactly one statement, this statement should be discriminant.
284
+ let need_hoist_discriminant = bbs[ child] . statements . len ( ) == 1 ;
285
+ let child_place = if need_hoist_discriminant {
286
+ if !bbs[ targets. otherwise ( ) ] . is_empty_unreachable ( ) {
287
+ // Someone could write code like this:
288
+ // ```rust
289
+ // let Q = val;
290
+ // if discriminant(P) == otherwise {
291
+ // let ptr = &mut Q as *mut _ as *mut u8;
292
+ // // It may be difficult for us to effectively determine whether values are valid.
293
+ // // Invalid values can come from all sorts of corners.
294
+ // unsafe { *ptr = 10; }
295
+ // }
296
+ //
297
+ // match P {
298
+ // A => match Q {
299
+ // A => {
300
+ // // code
301
+ // }
302
+ // _ => {
303
+ // // don't use Q
304
+ // }
305
+ // }
306
+ // _ => {
307
+ // // don't use Q
308
+ // }
309
+ // };
310
+ // ```
311
+ //
312
+ // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
313
+ // invalid value, which is UB.
314
+ // In order to fix this, **we would either need to show that the discriminant computation of
315
+ // `place` is computed in all branches**.
316
+ // FIXME(#95162) For the moment, we adopt a conservative approach and
317
+ // consider only the `otherwise` branch has no statements and an unreachable terminator.
318
+ return None ;
319
+ }
320
+ // Handle:
321
+ // ```
322
+ // bb4: {
323
+ // _8 = discriminant((_3.1: Enum1));
324
+ // switchInt(move _8) -> [2: bb7, otherwise: bb1];
325
+ // }
326
+ // ```
327
+ let [
328
+ Statement {
329
+ kind : StatementKind :: Assign ( box ( _, Rvalue :: Discriminant ( child_place) ) ) ,
330
+ ..
331
+ } ,
332
+ ] = bbs[ child] . statements . as_slice ( )
333
+ else {
334
+ return None ;
335
+ } ;
336
+ * child_place
337
+ } else {
338
+ // Handle:
339
+ // ```
340
+ // bb2: {
341
+ // switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
342
+ // }
343
+ // ```
344
+ let Operand :: Copy ( child_place) = child_discr else {
345
+ return None ;
346
+ } ;
347
+ * child_place
284
348
} ;
285
- let ( _, Rvalue :: Discriminant ( child_place) ) = & * * boxed else {
286
- return None ;
349
+ let destination = if need_hoist_discriminant || bbs[ targets. otherwise ( ) ] . is_empty_unreachable ( )
350
+ {
351
+ child_targets. otherwise ( )
352
+ } else {
353
+ targets. otherwise ( )
287
354
} ;
288
- let destination = child_targets. otherwise ( ) ;
289
355
290
356
// Verify that the optimization is legal for each branch
291
357
for ( value, child) in targets. iter ( ) {
292
- if !verify_candidate_branch ( & bbs[ child] , value, * child_place, destination) {
358
+ if !verify_candidate_branch (
359
+ & bbs[ child] ,
360
+ value,
361
+ child_place,
362
+ destination,
363
+ need_hoist_discriminant,
364
+ ) {
293
365
return None ;
294
366
}
295
367
}
296
368
Some ( OptimizationData {
297
369
destination,
298
- child_place : * child_place ,
370
+ child_place,
299
371
child_ty,
300
- child_source : child_terminator. source_info ,
372
+ child_source : * source_info,
373
+ need_hoist_discriminant,
301
374
} )
302
375
}
303
376
@@ -306,45 +379,48 @@ fn verify_candidate_branch<'tcx>(
306
379
value : u128 ,
307
380
place : Place < ' tcx > ,
308
381
destination : BasicBlock ,
382
+ need_hoist_discriminant : bool ,
309
383
) -> bool {
310
- // In order for the optimization to be correct, the branch must...
311
- // ...have exactly one statement
312
- let [ statement] = branch. statements . as_slice ( ) else {
313
- return false ;
314
- } ;
315
- // ...assign the discriminant of `place` in that statement
316
- let StatementKind :: Assign ( boxed) = & statement. kind else { return false } ;
317
- let ( discr_place, Rvalue :: Discriminant ( from_place) ) = & * * boxed else { return false } ;
318
- if * from_place != place {
319
- return false ;
320
- }
321
- // ...make that assignment to a local
322
- if discr_place. projection . len ( ) != 0 {
323
- return false ;
324
- }
325
- // ...terminate on a `SwitchInt` that invalidates that local
326
- let TerminatorKind :: SwitchInt { discr : switch_op, targets, .. } = & branch. terminator ( ) . kind
327
- else {
384
+ // In order for the optimization to be correct, the terminator must be a `SwitchInt`.
385
+ let TerminatorKind :: SwitchInt { discr : switch_op, targets } = & branch. terminator ( ) . kind else {
328
386
return false ;
329
387
} ;
330
- if * switch_op != Operand :: Move ( * discr_place) {
331
- return false ;
388
+ if need_hoist_discriminant {
389
+ // If we need hoist discriminant, the branch must have exactly one statement.
390
+ let [ statement] = branch. statements . as_slice ( ) else {
391
+ return false ;
392
+ } ;
393
+ // The statement must assign the discriminant of `place`.
394
+ let StatementKind :: Assign ( box ( discr_place, Rvalue :: Discriminant ( from_place) ) ) =
395
+ statement. kind
396
+ else {
397
+ return false ;
398
+ } ;
399
+ if from_place != place {
400
+ return false ;
401
+ }
402
+ // The assignment must invalidate a local that terminate on a `SwitchInt`.
403
+ if !discr_place. projection . is_empty ( ) || * switch_op != Operand :: Move ( discr_place) {
404
+ return false ;
405
+ }
406
+ } else {
407
+ // If we don't need hoist discriminant, the branch must not have any statements.
408
+ if !branch. statements . is_empty ( ) {
409
+ return false ;
410
+ }
411
+ // The place on `SwitchInt` must be the same.
412
+ if * switch_op != Operand :: Copy ( place) {
413
+ return false ;
414
+ }
332
415
}
333
- // ... fall through to `destination` if the switch misses
416
+ // It must fall through to `destination` if the switch misses.
334
417
if destination != targets. otherwise ( ) {
335
418
return false ;
336
419
}
337
- // ... have a branch for value `value`
420
+ // It must have exactly one branch for value `value` and have no more branches.
338
421
let mut iter = targets. iter ( ) ;
339
- let Some ( ( target_value, _) ) = iter. next ( ) else {
422
+ let ( Some ( ( target_value, _) ) , None ) = ( iter. next ( ) , iter . next ( ) ) else {
340
423
return false ;
341
424
} ;
342
- if target_value != value {
343
- return false ;
344
- }
345
- // ...and have no more branches
346
- if let Some ( _) = iter. next ( ) {
347
- return false ;
348
- }
349
- true
425
+ target_value == value
350
426
}
0 commit comments