Skip to content

Commit f8ebe31

Browse files
committed
Apply EarlyOtherwiseBranch to scalar value
1 parent 59a74db commit f8ebe31

5 files changed

+415
-92
lines changed

compiler/rustc_mir_transform/src/early_otherwise_branch.rs

+168-92
Original file line numberDiff line numberDiff line change
@@ -132,18 +132,29 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
132132

133133
let mut patch = MirPatch::new(body);
134134

135-
// create temp to store second discriminant in, `_s` in example above
136-
let second_discriminant_temp =
137-
patch.new_temp(opt_data.child_ty, opt_data.child_source.span);
135+
let (second_discriminant_temp, second_operand) = if opt_data.need_hoist_discriminant {
136+
// create temp to store second discriminant in, `_s` in example above
137+
let second_discriminant_temp =
138+
patch.new_temp(opt_data.child_ty, opt_data.child_source.span);
138139

139-
patch.add_statement(parent_end, StatementKind::StorageLive(second_discriminant_temp));
140+
patch.add_statement(
141+
parent_end,
142+
StatementKind::StorageLive(second_discriminant_temp),
143+
);
140144

141-
// create assignment of discriminant
142-
patch.add_assign(
143-
parent_end,
144-
Place::from(second_discriminant_temp),
145-
Rvalue::Discriminant(opt_data.child_place),
146-
);
145+
// create assignment of discriminant
146+
patch.add_assign(
147+
parent_end,
148+
Place::from(second_discriminant_temp),
149+
Rvalue::Discriminant(opt_data.child_place),
150+
);
151+
(
152+
Some(second_discriminant_temp),
153+
Operand::Move(Place::from(second_discriminant_temp)),
154+
)
155+
} else {
156+
(None, Operand::Copy(opt_data.child_place))
157+
};
147158

148159
// create temp to store inequality comparison between the two discriminants, `_t` in
149160
// example above
@@ -152,11 +163,9 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
152163
let comp_temp = patch.new_temp(comp_res_type, opt_data.child_source.span);
153164
patch.add_statement(parent_end, StatementKind::StorageLive(comp_temp));
154165

155-
// create inequality comparison between the two discriminants
156-
let comp_rvalue = Rvalue::BinaryOp(
157-
nequal,
158-
Box::new((parent_op.clone(), Operand::Move(Place::from(second_discriminant_temp)))),
159-
);
166+
// create inequality comparison
167+
let comp_rvalue =
168+
Rvalue::BinaryOp(nequal, Box::new((parent_op.clone(), second_operand)));
160169
patch.add_statement(
161170
parent_end,
162171
StatementKind::Assign(Box::new((Place::from(comp_temp), comp_rvalue))),
@@ -192,8 +201,13 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
192201
TerminatorKind::if_(Operand::Move(Place::from(comp_temp)), true_case, false_case),
193202
);
194203

195-
// generate StorageDead for the second_discriminant_temp not in use anymore
196-
patch.add_statement(parent_end, StatementKind::StorageDead(second_discriminant_temp));
204+
if let Some(second_discriminant_temp) = second_discriminant_temp {
205+
// generate StorageDead for the second_discriminant_temp not in use anymore
206+
patch.add_statement(
207+
parent_end,
208+
StatementKind::StorageDead(second_discriminant_temp),
209+
);
210+
}
197211

198212
// Generate a StorageDead for comp_temp in each of the targets, since we moved it into
199213
// the switch
@@ -221,6 +235,7 @@ struct OptimizationData<'tcx> {
221235
child_place: Place<'tcx>,
222236
child_ty: Ty<'tcx>,
223237
child_source: SourceInfo,
238+
need_hoist_discriminant: bool,
224239
}
225240

226241
fn evaluate_candidate<'tcx>(
@@ -234,70 +249,128 @@ fn evaluate_candidate<'tcx>(
234249
return None;
235250
};
236251
let parent_ty = parent_discr.ty(body.local_decls(), tcx);
237-
if !bbs[targets.otherwise()].is_empty_unreachable() {
238-
// Someone could write code like this:
239-
// ```rust
240-
// let Q = val;
241-
// if discriminant(P) == otherwise {
242-
// let ptr = &mut Q as *mut _ as *mut u8;
243-
// // It may be difficult for us to effectively determine whether values are valid.
244-
// // Invalid values can come from all sorts of corners.
245-
// unsafe { *ptr = 10; }
246-
// }
247-
//
248-
// match P {
249-
// A => match Q {
250-
// A => {
251-
// // code
252-
// }
253-
// _ => {
254-
// // don't use Q
255-
// }
256-
// }
257-
// _ => {
258-
// // don't use Q
259-
// }
260-
// };
261-
// ```
262-
//
263-
// Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
264-
// invalid value, which is UB.
265-
// In order to fix this, **we would either need to show that the discriminant computation of
266-
// `place` is computed in all branches**.
267-
// FIXME(#95162) For the moment, we adopt a conservative approach and
268-
// consider only the `otherwise` branch has no statements and an unreachable terminator.
269-
return None;
270-
}
271252
let (_, child) = targets.iter().next()?;
272-
let child_terminator = &bbs[child].terminator();
273-
let TerminatorKind::SwitchInt { targets: child_targets, discr: child_discr } =
274-
&child_terminator.kind
253+
254+
let Terminator {
255+
kind: TerminatorKind::SwitchInt { targets: child_targets, discr: child_discr },
256+
source_info,
257+
} = bbs[child].terminator()
275258
else {
276259
return None;
277260
};
278261
let child_ty = child_discr.ty(body.local_decls(), tcx);
279262
if child_ty != parent_ty {
280263
return None;
281264
}
282-
let Some(StatementKind::Assign(boxed)) = &bbs[child].statements.first().map(|x| &x.kind) else {
265+
266+
// We only handle:
267+
// ```
268+
// bb4: {
269+
// _8 = discriminant((_3.1: Enum1));
270+
// switchInt(move _8) -> [2: bb7, otherwise: bb1];
271+
// }
272+
// ```
273+
// and
274+
// ```
275+
// bb2: {
276+
// switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
277+
// }
278+
// ```
279+
if bbs[child].statements.len() > 1 {
283280
return None;
281+
}
282+
283+
// When thie BB has exactly one statement, this statement should be discriminant.
284+
let need_hoist_discriminant = bbs[child].statements.len() == 1;
285+
let child_place = if need_hoist_discriminant {
286+
if !bbs[targets.otherwise()].is_empty_unreachable() {
287+
// Someone could write code like this:
288+
// ```rust
289+
// let Q = val;
290+
// if discriminant(P) == otherwise {
291+
// let ptr = &mut Q as *mut _ as *mut u8;
292+
// // It may be difficult for us to effectively determine whether values are valid.
293+
// // Invalid values can come from all sorts of corners.
294+
// unsafe { *ptr = 10; }
295+
// }
296+
//
297+
// match P {
298+
// A => match Q {
299+
// A => {
300+
// // code
301+
// }
302+
// _ => {
303+
// // don't use Q
304+
// }
305+
// }
306+
// _ => {
307+
// // don't use Q
308+
// }
309+
// };
310+
// ```
311+
//
312+
// Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
313+
// invalid value, which is UB.
314+
// In order to fix this, **we would either need to show that the discriminant computation of
315+
// `place` is computed in all branches**.
316+
// FIXME(#95162) For the moment, we adopt a conservative approach and
317+
// consider only the `otherwise` branch has no statements and an unreachable terminator.
318+
return None;
319+
}
320+
// Handle:
321+
// ```
322+
// bb4: {
323+
// _8 = discriminant((_3.1: Enum1));
324+
// switchInt(move _8) -> [2: bb7, otherwise: bb1];
325+
// }
326+
// ```
327+
let [
328+
Statement {
329+
kind: StatementKind::Assign(box (_, Rvalue::Discriminant(child_place))),
330+
..
331+
},
332+
] = bbs[child].statements.as_slice()
333+
else {
334+
return None;
335+
};
336+
*child_place
337+
} else {
338+
// Handle:
339+
// ```
340+
// bb2: {
341+
// switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
342+
// }
343+
// ```
344+
let Operand::Copy(child_place) = child_discr else {
345+
return None;
346+
};
347+
*child_place
284348
};
285-
let (_, Rvalue::Discriminant(child_place)) = &**boxed else {
286-
return None;
349+
let destination = if need_hoist_discriminant || bbs[targets.otherwise()].is_empty_unreachable()
350+
{
351+
child_targets.otherwise()
352+
} else {
353+
targets.otherwise()
287354
};
288-
let destination = child_targets.otherwise();
289355

290356
// Verify that the optimization is legal for each branch
291357
for (value, child) in targets.iter() {
292-
if !verify_candidate_branch(&bbs[child], value, *child_place, destination) {
358+
if !verify_candidate_branch(
359+
&bbs[child],
360+
value,
361+
child_place,
362+
destination,
363+
need_hoist_discriminant,
364+
) {
293365
return None;
294366
}
295367
}
296368
Some(OptimizationData {
297369
destination,
298-
child_place: *child_place,
370+
child_place,
299371
child_ty,
300-
child_source: child_terminator.source_info,
372+
child_source: *source_info,
373+
need_hoist_discriminant,
301374
})
302375
}
303376

@@ -306,45 +379,48 @@ fn verify_candidate_branch<'tcx>(
306379
value: u128,
307380
place: Place<'tcx>,
308381
destination: BasicBlock,
382+
need_hoist_discriminant: bool,
309383
) -> bool {
310-
// In order for the optimization to be correct, the branch must...
311-
// ...have exactly one statement
312-
let [statement] = branch.statements.as_slice() else {
313-
return false;
314-
};
315-
// ...assign the discriminant of `place` in that statement
316-
let StatementKind::Assign(boxed) = &statement.kind else { return false };
317-
let (discr_place, Rvalue::Discriminant(from_place)) = &**boxed else { return false };
318-
if *from_place != place {
319-
return false;
320-
}
321-
// ...make that assignment to a local
322-
if discr_place.projection.len() != 0 {
323-
return false;
324-
}
325-
// ...terminate on a `SwitchInt` that invalidates that local
326-
let TerminatorKind::SwitchInt { discr: switch_op, targets, .. } = &branch.terminator().kind
327-
else {
384+
// In order for the optimization to be correct, the terminator must be a `SwitchInt`.
385+
let TerminatorKind::SwitchInt { discr: switch_op, targets } = &branch.terminator().kind else {
328386
return false;
329387
};
330-
if *switch_op != Operand::Move(*discr_place) {
331-
return false;
388+
if need_hoist_discriminant {
389+
// If we need hoist discriminant, the branch must have exactly one statement.
390+
let [statement] = branch.statements.as_slice() else {
391+
return false;
392+
};
393+
// The statement must assign the discriminant of `place`.
394+
let StatementKind::Assign(box (discr_place, Rvalue::Discriminant(from_place))) =
395+
statement.kind
396+
else {
397+
return false;
398+
};
399+
if from_place != place {
400+
return false;
401+
}
402+
// The assignment must invalidate a local that terminate on a `SwitchInt`.
403+
if !discr_place.projection.is_empty() || *switch_op != Operand::Move(discr_place) {
404+
return false;
405+
}
406+
} else {
407+
// If we don't need hoist discriminant, the branch must not have any statements.
408+
if !branch.statements.is_empty() {
409+
return false;
410+
}
411+
// The place on `SwitchInt` must be the same.
412+
if *switch_op != Operand::Copy(place) {
413+
return false;
414+
}
332415
}
333-
// ...fall through to `destination` if the switch misses
416+
// It must fall through to `destination` if the switch misses.
334417
if destination != targets.otherwise() {
335418
return false;
336419
}
337-
// ...have a branch for value `value`
420+
// It must have exactly one branch for value `value` and have no more branches.
338421
let mut iter = targets.iter();
339-
let Some((target_value, _)) = iter.next() else {
422+
let (Some((target_value, _)), None) = (iter.next(), iter.next()) else {
340423
return false;
341424
};
342-
if target_value != value {
343-
return false;
344-
}
345-
// ...and have no more branches
346-
if let Some(_) = iter.next() {
347-
return false;
348-
}
349-
true
425+
target_value == value
350426
}

0 commit comments

Comments
 (0)