Skip to content

Commit f90e617

Browse files
committed
Auto merge of #77908 - bugadani:obl-forest, r=nnethercote
Try to make ObligationForest more efficient This PR tries to decrease the number of allocations in ObligationForest, as well as moves some cold path code to an uninlined function.
2 parents cb2462c + 8c7a8a6 commit f90e617

File tree

3 files changed

+329
-356
lines changed

3 files changed

+329
-356
lines changed

compiler/rustc_data_structures/src/obligation_forest/mod.rs

+64-52
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,8 @@ pub struct ObligationForest<O: ForestObligation> {
149149
/// comments in `process_obligation` for details.
150150
active_cache: FxHashMap<O::CacheKey, usize>,
151151

152-
/// A vector reused in compress(), to avoid allocating new vectors.
153-
node_rewrites: Vec<usize>,
152+
/// A vector reused in compress() and find_cycles_from_node(), to avoid allocating new vectors.
153+
reused_node_vec: Vec<usize>,
154154

155155
obligation_tree_id_generator: ObligationTreeIdGenerator,
156156

@@ -251,12 +251,22 @@ enum NodeState {
251251
Error,
252252
}
253253

254+
/// This trait allows us to have two different Outcome types:
255+
/// - the normal one that does as little as possible
256+
/// - one for tests that does some additional work and checking
257+
pub trait OutcomeTrait {
258+
type Error;
259+
type Obligation;
260+
261+
fn new() -> Self;
262+
fn mark_not_stalled(&mut self);
263+
fn is_stalled(&self) -> bool;
264+
fn record_completed(&mut self, outcome: &Self::Obligation);
265+
fn record_error(&mut self, error: Self::Error);
266+
}
267+
254268
#[derive(Debug)]
255269
pub struct Outcome<O, E> {
256-
/// Obligations that were completely evaluated, including all
257-
/// (transitive) subobligations. Only computed if requested.
258-
pub completed: Option<Vec<O>>,
259-
260270
/// Backtrace of obligations that were found to be in error.
261271
pub errors: Vec<Error<O, E>>,
262272

@@ -269,12 +279,29 @@ pub struct Outcome<O, E> {
269279
pub stalled: bool,
270280
}
271281

272-
/// Should `process_obligations` compute the `Outcome::completed` field of its
273-
/// result?
274-
#[derive(PartialEq)]
275-
pub enum DoCompleted {
276-
No,
277-
Yes,
282+
impl<O, E> OutcomeTrait for Outcome<O, E> {
283+
type Error = Error<O, E>;
284+
type Obligation = O;
285+
286+
fn new() -> Self {
287+
Self { stalled: true, errors: vec![] }
288+
}
289+
290+
fn mark_not_stalled(&mut self) {
291+
self.stalled = false;
292+
}
293+
294+
fn is_stalled(&self) -> bool {
295+
self.stalled
296+
}
297+
298+
fn record_completed(&mut self, _outcome: &Self::Obligation) {
299+
// do nothing
300+
}
301+
302+
fn record_error(&mut self, error: Self::Error) {
303+
self.errors.push(error)
304+
}
278305
}
279306

280307
#[derive(Debug, PartialEq, Eq)]
@@ -289,7 +316,7 @@ impl<O: ForestObligation> ObligationForest<O> {
289316
nodes: vec![],
290317
done_cache: Default::default(),
291318
active_cache: Default::default(),
292-
node_rewrites: vec![],
319+
reused_node_vec: vec![],
293320
obligation_tree_id_generator: (0..).map(ObligationTreeId),
294321
error_cache: Default::default(),
295322
}
@@ -363,8 +390,7 @@ impl<O: ForestObligation> ObligationForest<O> {
363390
.map(|(index, _node)| Error { error: error.clone(), backtrace: self.error_at(index) })
364391
.collect();
365392

366-
let successful_obligations = self.compress(DoCompleted::Yes);
367-
assert!(successful_obligations.unwrap().is_empty());
393+
self.compress(|_| assert!(false));
368394
errors
369395
}
370396

@@ -392,16 +418,12 @@ impl<O: ForestObligation> ObligationForest<O> {
392418
/// be called in a loop until `outcome.stalled` is false.
393419
///
394420
/// This _cannot_ be unrolled (presently, at least).
395-
pub fn process_obligations<P>(
396-
&mut self,
397-
processor: &mut P,
398-
do_completed: DoCompleted,
399-
) -> Outcome<O, P::Error>
421+
pub fn process_obligations<P, OUT>(&mut self, processor: &mut P) -> OUT
400422
where
401423
P: ObligationProcessor<Obligation = O>,
424+
OUT: OutcomeTrait<Obligation = O, Error = Error<O, P::Error>>,
402425
{
403-
let mut errors = vec![];
404-
let mut stalled = true;
426+
let mut outcome = OUT::new();
405427

406428
// Note that the loop body can append new nodes, and those new nodes
407429
// will then be processed by subsequent iterations of the loop.
@@ -429,7 +451,7 @@ impl<O: ForestObligation> ObligationForest<O> {
429451
}
430452
ProcessResult::Changed(children) => {
431453
// We are not (yet) stalled.
432-
stalled = false;
454+
outcome.mark_not_stalled();
433455
node.state.set(NodeState::Success);
434456

435457
for child in children {
@@ -442,28 +464,22 @@ impl<O: ForestObligation> ObligationForest<O> {
442464
}
443465
}
444466
ProcessResult::Error(err) => {
445-
stalled = false;
446-
errors.push(Error { error: err, backtrace: self.error_at(index) });
467+
outcome.mark_not_stalled();
468+
outcome.record_error(Error { error: err, backtrace: self.error_at(index) });
447469
}
448470
}
449471
index += 1;
450472
}
451473

452-
if stalled {
453-
// There's no need to perform marking, cycle processing and compression when nothing
454-
// changed.
455-
return Outcome {
456-
completed: if do_completed == DoCompleted::Yes { Some(vec![]) } else { None },
457-
errors,
458-
stalled,
459-
};
474+
// There's no need to perform marking, cycle processing and compression when nothing
475+
// changed.
476+
if !outcome.is_stalled() {
477+
self.mark_successes();
478+
self.process_cycles(processor);
479+
self.compress(|obl| outcome.record_completed(obl));
460480
}
461481

462-
self.mark_successes();
463-
self.process_cycles(processor);
464-
let completed = self.compress(do_completed);
465-
466-
Outcome { completed, errors, stalled }
482+
outcome
467483
}
468484

469485
/// Returns a vector of obligations for `p` and all of its
@@ -526,7 +542,6 @@ impl<O: ForestObligation> ObligationForest<O> {
526542
let node = &self.nodes[index];
527543
let state = node.state.get();
528544
if state == NodeState::Success {
529-
node.state.set(NodeState::Waiting);
530545
// This call site is cold.
531546
self.uninlined_mark_dependents_as_waiting(node);
532547
} else {
@@ -538,17 +553,18 @@ impl<O: ForestObligation> ObligationForest<O> {
538553
// This never-inlined function is for the cold call site.
539554
#[inline(never)]
540555
fn uninlined_mark_dependents_as_waiting(&self, node: &Node<O>) {
556+
// Mark node Waiting in the cold uninlined code instead of the hot inlined
557+
node.state.set(NodeState::Waiting);
541558
self.inlined_mark_dependents_as_waiting(node)
542559
}
543560

544561
/// Report cycles between all `Success` nodes, and convert all `Success`
545562
/// nodes to `Done`. This must be called after `mark_successes`.
546-
fn process_cycles<P>(&self, processor: &mut P)
563+
fn process_cycles<P>(&mut self, processor: &mut P)
547564
where
548565
P: ObligationProcessor<Obligation = O>,
549566
{
550-
let mut stack = vec![];
551-
567+
let mut stack = std::mem::take(&mut self.reused_node_vec);
552568
for (index, node) in self.nodes.iter().enumerate() {
553569
// For some benchmarks this state test is extremely hot. It's a win
554570
// to handle the no-op cases immediately to avoid the cost of the
@@ -559,6 +575,7 @@ impl<O: ForestObligation> ObligationForest<O> {
559575
}
560576

561577
debug_assert!(stack.is_empty());
578+
self.reused_node_vec = stack;
562579
}
563580

564581
fn find_cycles_from_node<P>(&self, stack: &mut Vec<usize>, processor: &mut P, index: usize)
@@ -591,13 +608,12 @@ impl<O: ForestObligation> ObligationForest<O> {
591608
/// indices and hence invalidates any outstanding indices. `process_cycles`
592609
/// must be run beforehand to remove any cycles on `Success` nodes.
593610
#[inline(never)]
594-
fn compress(&mut self, do_completed: DoCompleted) -> Option<Vec<O>> {
611+
fn compress(&mut self, mut outcome_cb: impl FnMut(&O)) {
595612
let orig_nodes_len = self.nodes.len();
596-
let mut node_rewrites: Vec<_> = std::mem::take(&mut self.node_rewrites);
613+
let mut node_rewrites: Vec<_> = std::mem::take(&mut self.reused_node_vec);
597614
debug_assert!(node_rewrites.is_empty());
598615
node_rewrites.extend(0..orig_nodes_len);
599616
let mut dead_nodes = 0;
600-
let mut removed_done_obligations: Vec<O> = vec![];
601617

602618
// Move removable nodes to the end, preserving the order of the
603619
// remaining nodes.
@@ -627,10 +643,8 @@ impl<O: ForestObligation> ObligationForest<O> {
627643
} else {
628644
self.done_cache.insert(node.obligation.as_cache_key().clone());
629645
}
630-
if do_completed == DoCompleted::Yes {
631-
// Extract the success stories.
632-
removed_done_obligations.push(node.obligation.clone());
633-
}
646+
// Extract the success stories.
647+
outcome_cb(&node.obligation);
634648
node_rewrites[index] = orig_nodes_len;
635649
dead_nodes += 1;
636650
}
@@ -654,9 +668,7 @@ impl<O: ForestObligation> ObligationForest<O> {
654668
}
655669

656670
node_rewrites.truncate(0);
657-
self.node_rewrites = node_rewrites;
658-
659-
if do_completed == DoCompleted::Yes { Some(removed_done_obligations) } else { None }
671+
self.reused_node_vec = node_rewrites;
660672
}
661673

662674
fn apply_rewrites(&mut self, node_rewrites: &[usize]) {

0 commit comments

Comments
 (0)