Skip to content

Commit e4fd2bd

Browse files
committed
Refactor cost computation as a visitor.
1 parent bb99e6f commit e4fd2bd

File tree

1 file changed

+122
-92
lines changed

1 file changed

+122
-92
lines changed

compiler/rustc_mir_transform/src/inline.rs

+122-92
Original file line numberDiff line numberDiff line change
@@ -409,118 +409,56 @@ impl<'tcx> Inliner<'tcx> {
409409
debug!(" final inline threshold = {}", threshold);
410410

411411
// FIXME: Give a bonus to functions with only a single caller
412-
let mut first_block = true;
413-
let mut cost = 0;
412+
let diverges = matches!(
413+
callee_body.basic_blocks()[START_BLOCK].terminator().kind,
414+
TerminatorKind::Unreachable | TerminatorKind::Call { target: None, .. }
415+
);
416+
if diverges && !matches!(callee_attrs.inline, InlineAttr::Always) {
417+
return Err("callee diverges unconditionally");
418+
}
419+
420+
let mut checker = CostChecker {
421+
tcx: self.tcx,
422+
param_env: self.param_env,
423+
instance: callsite.callee,
424+
callee_body,
425+
cost: 0,
426+
};
414427

415-
// Traverse the MIR manually so we can account for the effects of
416-
// inlining on the CFG.
428+
// Traverse the MIR manually so we can account for the effects of inlining on the CFG.
417429
let mut work_list = vec![START_BLOCK];
418430
let mut visited = BitSet::new_empty(callee_body.basic_blocks().len());
419431
while let Some(bb) = work_list.pop() {
420432
if !visited.insert(bb.index()) {
421433
continue;
422434
}
435+
423436
let blk = &callee_body.basic_blocks()[bb];
437+
checker.visit_basic_block_data(bb, blk);
424438

425-
for stmt in &blk.statements {
426-
// Don't count StorageLive/StorageDead in the inlining cost.
427-
match stmt.kind {
428-
StatementKind::StorageLive(_)
429-
| StatementKind::StorageDead(_)
430-
| StatementKind::Deinit(_)
431-
| StatementKind::Nop => {}
432-
_ => cost += INSTR_COST,
433-
}
434-
}
435439
let term = blk.terminator();
436-
let mut is_drop = false;
437-
match term.kind {
438-
TerminatorKind::Drop { ref place, target, unwind }
439-
| TerminatorKind::DropAndReplace { ref place, target, unwind, .. } => {
440-
is_drop = true;
441-
work_list.push(target);
442-
// If the place doesn't actually need dropping, treat it like
443-
// a regular goto.
444-
let ty = callsite.callee.subst_mir(self.tcx, &place.ty(callee_body, tcx).ty);
445-
if ty.needs_drop(tcx, self.param_env) {
446-
cost += CALL_PENALTY;
447-
if let Some(unwind) = unwind {
448-
cost += LANDINGPAD_PENALTY;
449-
work_list.push(unwind);
450-
}
451-
} else {
452-
cost += INSTR_COST;
453-
}
454-
}
455-
456-
TerminatorKind::Unreachable | TerminatorKind::Call { target: None, .. }
457-
if first_block =>
458-
{
459-
// If the function always diverges, don't inline
460-
// unless the cost is zero
461-
threshold = 0;
462-
}
463-
464-
TerminatorKind::Call { func: Operand::Constant(ref f), cleanup, .. } => {
465-
if let ty::FnDef(def_id, _) =
466-
*callsite.callee.subst_mir(self.tcx, &f.literal.ty()).kind()
467-
{
468-
// Don't give intrinsics the extra penalty for calls
469-
if tcx.is_intrinsic(def_id) {
470-
cost += INSTR_COST;
471-
} else {
472-
cost += CALL_PENALTY;
473-
}
474-
} else {
475-
cost += CALL_PENALTY;
476-
}
477-
if cleanup.is_some() {
478-
cost += LANDINGPAD_PENALTY;
479-
}
480-
}
481-
TerminatorKind::Assert { cleanup, .. } => {
482-
cost += CALL_PENALTY;
483-
484-
if cleanup.is_some() {
485-
cost += LANDINGPAD_PENALTY;
486-
}
487-
}
488-
TerminatorKind::Resume => cost += RESUME_PENALTY,
489-
TerminatorKind::InlineAsm { cleanup, .. } => {
490-
cost += INSTR_COST;
440+
if let TerminatorKind::Drop { ref place, target, unwind }
441+
| TerminatorKind::DropAndReplace { ref place, target, unwind, .. } = term.kind
442+
{
443+
work_list.push(target);
491444

492-
if cleanup.is_some() {
493-
cost += LANDINGPAD_PENALTY;
445+
// If the place doesn't actually need dropping, treat it like a regular goto.
446+
let ty = callsite.callee.subst_mir(self.tcx, &place.ty(callee_body, tcx).ty);
447+
if ty.needs_drop(tcx, self.param_env) && let Some(unwind) = unwind {
448+
work_list.push(unwind);
494449
}
495-
}
496-
_ => cost += INSTR_COST,
497-
}
498-
499-
if !is_drop {
500-
for succ in term.successors() {
501-
work_list.push(succ);
502-
}
450+
} else {
451+
work_list.extend(term.successors())
503452
}
504-
505-
first_block = false;
506453
}
507454

508455
// Count up the cost of local variables and temps, if we know the size
509456
// use that, otherwise we use a moderately-large dummy cost.
510-
511-
let ptr_size = tcx.data_layout.pointer_size.bytes();
512-
513457
for v in callee_body.vars_and_temps_iter() {
514-
let ty = callsite.callee.subst_mir(self.tcx, &callee_body.local_decls[v].ty);
515-
// Cost of the var is the size in machine-words, if we know
516-
// it.
517-
if let Some(size) = type_size_of(tcx, self.param_env, ty) {
518-
cost += ((size + ptr_size - 1) / ptr_size) as usize;
519-
} else {
520-
cost += UNKNOWN_SIZE_COST;
521-
}
458+
checker.visit_local_decl(v, &callee_body.local_decls[v]);
522459
}
523460

461+
let cost = checker.cost;
524462
if let InlineAttr::Always = callee_attrs.inline {
525463
debug!("INLINING {:?} because inline(always) [cost={}]", callsite, cost);
526464
Ok(())
@@ -790,6 +728,98 @@ fn type_size_of<'tcx>(
790728
tcx.layout_of(param_env.and(ty)).ok().map(|layout| layout.size.bytes())
791729
}
792730

731+
/// Verify that the callee body is compatible with the caller.
732+
///
733+
/// This visitor mostly computes the inlining cost,
734+
/// but also needs to verify that types match because of normalization failure.
735+
struct CostChecker<'b, 'tcx> {
736+
tcx: TyCtxt<'tcx>,
737+
param_env: ParamEnv<'tcx>,
738+
cost: usize,
739+
callee_body: &'b Body<'tcx>,
740+
instance: ty::Instance<'tcx>,
741+
}
742+
743+
impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
744+
fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
745+
// Don't count StorageLive/StorageDead in the inlining cost.
746+
match statement.kind {
747+
StatementKind::StorageLive(_)
748+
| StatementKind::StorageDead(_)
749+
| StatementKind::Deinit(_)
750+
| StatementKind::Nop => {}
751+
_ => self.cost += INSTR_COST,
752+
}
753+
754+
self.super_statement(statement, location);
755+
}
756+
757+
fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
758+
let tcx = self.tcx;
759+
match terminator.kind {
760+
TerminatorKind::Drop { ref place, unwind, .. }
761+
| TerminatorKind::DropAndReplace { ref place, unwind, .. } => {
762+
// If the place doesn't actually need dropping, treat it like a regular goto.
763+
let ty = self.instance.subst_mir(tcx, &place.ty(self.callee_body, tcx).ty);
764+
if ty.needs_drop(tcx, self.param_env) {
765+
self.cost += CALL_PENALTY;
766+
if unwind.is_some() {
767+
self.cost += LANDINGPAD_PENALTY;
768+
}
769+
} else {
770+
self.cost += INSTR_COST;
771+
}
772+
}
773+
TerminatorKind::Call { func: Operand::Constant(ref f), cleanup, .. } => {
774+
let fn_ty = self.instance.subst_mir(tcx, &f.literal.ty());
775+
self.cost += if let ty::FnDef(def_id, _) = *fn_ty.kind() && tcx.is_intrinsic(def_id) {
776+
// Don't give intrinsics the extra penalty for calls
777+
INSTR_COST
778+
} else {
779+
CALL_PENALTY
780+
};
781+
if cleanup.is_some() {
782+
self.cost += LANDINGPAD_PENALTY;
783+
}
784+
}
785+
TerminatorKind::Assert { cleanup, .. } => {
786+
self.cost += CALL_PENALTY;
787+
if cleanup.is_some() {
788+
self.cost += LANDINGPAD_PENALTY;
789+
}
790+
}
791+
TerminatorKind::Resume => self.cost += RESUME_PENALTY,
792+
TerminatorKind::InlineAsm { cleanup, .. } => {
793+
self.cost += INSTR_COST;
794+
if cleanup.is_some() {
795+
self.cost += LANDINGPAD_PENALTY;
796+
}
797+
}
798+
_ => self.cost += INSTR_COST,
799+
}
800+
801+
self.super_terminator(terminator, location);
802+
}
803+
804+
/// Count up the cost of local variables and temps, if we know the size
805+
/// use that, otherwise we use a moderately-large dummy cost.
806+
fn visit_local_decl(&mut self, local: Local, local_decl: &LocalDecl<'tcx>) {
807+
let tcx = self.tcx;
808+
let ptr_size = tcx.data_layout.pointer_size.bytes();
809+
810+
let ty = self.instance.subst_mir(tcx, &local_decl.ty);
811+
// Cost of the var is the size in machine-words, if we know
812+
// it.
813+
if let Some(size) = type_size_of(tcx, self.param_env, ty) {
814+
self.cost += ((size + ptr_size - 1) / ptr_size) as usize;
815+
} else {
816+
self.cost += UNKNOWN_SIZE_COST;
817+
}
818+
819+
self.super_local_decl(local, local_decl)
820+
}
821+
}
822+
793823
/**
794824
* Integrator.
795825
*

0 commit comments

Comments
 (0)