Skip to content

Commit 7383ab7

Browse files
committed
Auto merge of #113154 - lcnr:better-probe-check, r=compiler-errors
change snapshot tracking in fulfillment contexts use the exact snapshot number to prevent misuse even when created inside of a snapshot
2 parents e013d8f + d04775d commit 7383ab7

File tree

23 files changed

+80
-122
lines changed

23 files changed

+80
-122
lines changed

compiler/rustc_hir_analysis/src/astconv/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1984,7 +1984,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
19841984
.copied()
19851985
.filter(|&(impl_, _)| {
19861986
infcx.probe(|_| {
1987-
let ocx = ObligationCtxt::new_in_snapshot(&infcx);
1987+
let ocx = ObligationCtxt::new(&infcx);
19881988
ocx.register_obligations(obligations.clone());
19891989

19901990
let impl_substs = infcx.fresh_substs_for_item(span, impl_);

compiler/rustc_hir_analysis/src/autoderef.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ impl<'a, 'tcx> Autoderef<'a, 'tcx> {
161161
&self,
162162
ty: Ty<'tcx>,
163163
) -> Option<(Ty<'tcx>, Vec<traits::PredicateObligation<'tcx>>)> {
164-
let mut fulfill_cx = <dyn TraitEngine<'tcx>>::new_in_snapshot(self.infcx);
164+
let mut fulfill_cx = <dyn TraitEngine<'tcx>>::new(self.infcx);
165165

166166
let cause = traits::ObligationCause::misc(self.span, self.body_id);
167167
let normalized_ty = match self

compiler/rustc_hir_analysis/src/collect.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1328,7 +1328,7 @@ fn suggest_impl_trait<'tcx>(
13281328
{
13291329
continue;
13301330
}
1331-
let ocx = ObligationCtxt::new_in_snapshot(&infcx);
1331+
let ocx = ObligationCtxt::new(&infcx);
13321332
let item_ty = ocx.normalize(
13331333
&ObligationCause::misc(span, def_id),
13341334
param_env,

compiler/rustc_hir_typeck/src/coercion.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1033,7 +1033,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
10331033
let Ok(ok) = coerce.coerce(source, target) else {
10341034
return false;
10351035
};
1036-
let ocx = ObligationCtxt::new_in_snapshot(self);
1036+
let ocx = ObligationCtxt::new(self);
10371037
ocx.register_obligations(ok.obligations);
10381038
ocx.select_where_possible().is_empty()
10391039
})

compiler/rustc_hir_typeck/src/expr.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -2962,7 +2962,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
29622962
};
29632963

29642964
self.commit_if_ok(|_| {
2965-
let ocx = ObligationCtxt::new_in_snapshot(self);
2965+
let ocx = ObligationCtxt::new(self);
29662966
let impl_substs = self.fresh_substs_for_item(base_expr.span, impl_def_id);
29672967
let impl_trait_ref =
29682968
self.tcx.impl_trait_ref(impl_def_id).unwrap().subst(self.tcx, impl_substs);

compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -746,7 +746,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
746746

747747
let expect_args = self
748748
.fudge_inference_if_ok(|| {
749-
let ocx = ObligationCtxt::new_in_snapshot(self);
749+
let ocx = ObligationCtxt::new(self);
750750

751751
// Attempt to apply a subtyping relationship between the formal
752752
// return type (likely containing type variables if the function

compiler/rustc_hir_typeck/src/fn_ctxt/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
163163
return fn_sig;
164164
}
165165
self.probe(|_| {
166-
let ocx = ObligationCtxt::new_in_snapshot(self);
166+
let ocx = ObligationCtxt::new(self);
167167
let normalized_fn_sig =
168168
ocx.normalize(&ObligationCause::dummy(), self.param_env, fn_sig);
169169
if ocx.select_all_or_error().is_empty() {

compiler/rustc_infer/src/infer/at.rs

-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@ impl<'tcx> InferCtxt<'tcx> {
7979
reported_closure_mismatch: self.reported_closure_mismatch.clone(),
8080
tainted_by_errors: self.tainted_by_errors.clone(),
8181
err_count_on_creation: self.err_count_on_creation,
82-
in_snapshot: self.in_snapshot.clone(),
8382
universe: self.universe.clone(),
8483
intercrate: self.intercrate,
8584
next_trait_solver: self.next_trait_solver,

compiler/rustc_infer/src/infer/mod.rs

+13-28
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ pub use self::RegionVariableOrigin::*;
66
pub use self::SubregionOrigin::*;
77
pub use self::ValuePairs::*;
88
pub use combine::ObligationEmittingRelation;
9+
use rustc_data_structures::undo_log::UndoLogs;
910

1011
use self::opaque_types::OpaqueTypeStorage;
1112
pub(crate) use self::undo_log::{InferCtxtUndoLogs, Snapshot, UndoLog};
@@ -297,9 +298,6 @@ pub struct InferCtxt<'tcx> {
297298
// FIXME(matthewjasper) Merge into `tainted_by_errors`
298299
err_count_on_creation: usize,
299300

300-
/// This flag is true while there is an active snapshot.
301-
in_snapshot: Cell<bool>,
302-
303301
/// What is the innermost universe we have created? Starts out as
304302
/// `UniverseIndex::root()` but grows from there as we enter
305303
/// universal quantifiers.
@@ -643,7 +641,6 @@ impl<'tcx> InferCtxtBuilder<'tcx> {
643641
reported_closure_mismatch: Default::default(),
644642
tainted_by_errors: Cell::new(None),
645643
err_count_on_creation: tcx.sess.err_count(),
646-
in_snapshot: Cell::new(false),
647644
universe: Cell::new(ty::UniverseIndex::ROOT),
648645
intercrate,
649646
next_trait_solver,
@@ -679,7 +676,6 @@ pub struct CombinedSnapshot<'tcx> {
679676
undo_snapshot: Snapshot<'tcx>,
680677
region_constraints_snapshot: RegionSnapshot,
681678
universe: ty::UniverseIndex,
682-
was_in_snapshot: bool,
683679
}
684680

685681
impl<'tcx> InferCtxt<'tcx> {
@@ -702,10 +698,6 @@ impl<'tcx> InferCtxt<'tcx> {
702698
}
703699
}
704700

705-
pub fn is_in_snapshot(&self) -> bool {
706-
self.in_snapshot.get()
707-
}
708-
709701
pub fn freshen<T: TypeFoldable<TyCtxt<'tcx>>>(&self, t: T) -> T {
710702
t.fold_with(&mut self.freshener())
711703
}
@@ -766,31 +758,30 @@ impl<'tcx> InferCtxt<'tcx> {
766758
}
767759
}
768760

761+
pub fn in_snapshot(&self) -> bool {
762+
UndoLogs::<UndoLog<'tcx>>::in_snapshot(&self.inner.borrow_mut().undo_log)
763+
}
764+
765+
pub fn num_open_snapshots(&self) -> usize {
766+
UndoLogs::<UndoLog<'tcx>>::num_open_snapshots(&self.inner.borrow_mut().undo_log)
767+
}
768+
769769
fn start_snapshot(&self) -> CombinedSnapshot<'tcx> {
770770
debug!("start_snapshot()");
771771

772-
let in_snapshot = self.in_snapshot.replace(true);
773-
774772
let mut inner = self.inner.borrow_mut();
775773

776774
CombinedSnapshot {
777775
undo_snapshot: inner.undo_log.start_snapshot(),
778776
region_constraints_snapshot: inner.unwrap_region_constraints().start_snapshot(),
779777
universe: self.universe(),
780-
was_in_snapshot: in_snapshot,
781778
}
782779
}
783780

784781
#[instrument(skip(self, snapshot), level = "debug")]
785782
fn rollback_to(&self, cause: &str, snapshot: CombinedSnapshot<'tcx>) {
786-
let CombinedSnapshot {
787-
undo_snapshot,
788-
region_constraints_snapshot,
789-
universe,
790-
was_in_snapshot,
791-
} = snapshot;
792-
793-
self.in_snapshot.set(was_in_snapshot);
783+
let CombinedSnapshot { undo_snapshot, region_constraints_snapshot, universe } = snapshot;
784+
794785
self.universe.set(universe);
795786

796787
let mut inner = self.inner.borrow_mut();
@@ -800,14 +791,8 @@ impl<'tcx> InferCtxt<'tcx> {
800791

801792
#[instrument(skip(self, snapshot), level = "debug")]
802793
fn commit_from(&self, snapshot: CombinedSnapshot<'tcx>) {
803-
let CombinedSnapshot {
804-
undo_snapshot,
805-
region_constraints_snapshot: _,
806-
universe: _,
807-
was_in_snapshot,
808-
} = snapshot;
809-
810-
self.in_snapshot.set(was_in_snapshot);
794+
let CombinedSnapshot { undo_snapshot, region_constraints_snapshot: _, universe: _ } =
795+
snapshot;
811796

812797
self.inner.borrow_mut().commit(undo_snapshot);
813798
}

compiler/rustc_infer/src/infer/outlives/obligations.rs

+1-4
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,7 @@ impl<'tcx> InferCtxt<'tcx> {
125125
/// right before lexical region resolution.
126126
#[instrument(level = "debug", skip(self, outlives_env))]
127127
pub fn process_registered_region_obligations(&self, outlives_env: &OutlivesEnvironment<'tcx>) {
128-
assert!(
129-
!self.in_snapshot.get(),
130-
"cannot process registered region obligations in a snapshot"
131-
);
128+
assert!(!self.in_snapshot(), "cannot process registered region obligations in a snapshot");
132129

133130
let my_region_obligations = self.take_registered_region_obligations();
134131

compiler/rustc_trait_selection/src/solve/fulfill.rs

+11-3
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,27 @@ use super::{Certainty, InferCtxtEvalExt};
2626
/// here as this will have to deal with far more root goals than `evaluate_all`.
2727
pub struct FulfillmentCtxt<'tcx> {
2828
obligations: Vec<PredicateObligation<'tcx>>,
29+
30+
/// The snapshot in which this context was created. Using the context
31+
/// outside of this snapshot leads to subtle bugs if the snapshot
32+
/// gets rolled back. Because of this we explicitly check that we only
33+
/// use the context in exactly this snapshot.
34+
usable_in_snapshot: usize,
2935
}
3036

3137
impl<'tcx> FulfillmentCtxt<'tcx> {
32-
pub fn new() -> FulfillmentCtxt<'tcx> {
33-
FulfillmentCtxt { obligations: Vec::new() }
38+
pub fn new(infcx: &InferCtxt<'tcx>) -> FulfillmentCtxt<'tcx> {
39+
FulfillmentCtxt { obligations: Vec::new(), usable_in_snapshot: infcx.num_open_snapshots() }
3440
}
3541
}
3642

3743
impl<'tcx> TraitEngine<'tcx> for FulfillmentCtxt<'tcx> {
3844
fn register_predicate_obligation(
3945
&mut self,
40-
_infcx: &InferCtxt<'tcx>,
46+
infcx: &InferCtxt<'tcx>,
4147
obligation: PredicateObligation<'tcx>,
4248
) {
49+
assert_eq!(self.usable_in_snapshot, infcx.num_open_snapshots());
4350
self.obligations.push(obligation);
4451
}
4552

@@ -77,6 +84,7 @@ impl<'tcx> TraitEngine<'tcx> for FulfillmentCtxt<'tcx> {
7784
}
7885

7986
fn select_where_possible(&mut self, infcx: &InferCtxt<'tcx>) -> Vec<FulfillmentError<'tcx>> {
87+
assert_eq!(self.usable_in_snapshot, infcx.num_open_snapshots());
8088
let mut errors = Vec::new();
8189
for i in 0.. {
8290
if !infcx.tcx.recursion_limit().value_within_limit(i) {

compiler/rustc_trait_selection/src/traits/chalk_fulfill.rs

+18-20
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,19 @@ use rustc_middle::ty::TypeVisitableExt;
1313
pub struct FulfillmentContext<'tcx> {
1414
obligations: FxIndexSet<PredicateObligation<'tcx>>,
1515

16-
usable_in_snapshot: bool,
16+
/// The snapshot in which this context was created. Using the context
17+
/// outside of this snapshot leads to subtle bugs if the snapshot
18+
/// gets rolled back. Because of this we explicitly check that we only
19+
/// use the context in exactly this snapshot.
20+
usable_in_snapshot: usize,
1721
}
1822

19-
impl FulfillmentContext<'_> {
20-
pub(super) fn new() -> Self {
21-
FulfillmentContext { obligations: FxIndexSet::default(), usable_in_snapshot: false }
22-
}
23-
24-
pub(crate) fn new_in_snapshot() -> Self {
25-
FulfillmentContext { usable_in_snapshot: true, ..Self::new() }
23+
impl<'tcx> FulfillmentContext<'tcx> {
24+
pub(super) fn new(infcx: &InferCtxt<'tcx>) -> Self {
25+
FulfillmentContext {
26+
obligations: FxIndexSet::default(),
27+
usable_in_snapshot: infcx.num_open_snapshots(),
28+
}
2629
}
2730
}
2831

@@ -32,9 +35,7 @@ impl<'tcx> TraitEngine<'tcx> for FulfillmentContext<'tcx> {
3235
infcx: &InferCtxt<'tcx>,
3336
obligation: PredicateObligation<'tcx>,
3437
) {
35-
if !self.usable_in_snapshot {
36-
assert!(!infcx.is_in_snapshot());
37-
}
38+
assert_eq!(self.usable_in_snapshot, infcx.num_open_snapshots());
3839
let obligation = infcx.resolve_vars_if_possible(obligation);
3940

4041
self.obligations.insert(obligation);
@@ -58,9 +59,7 @@ impl<'tcx> TraitEngine<'tcx> for FulfillmentContext<'tcx> {
5859
}
5960

6061
fn select_where_possible(&mut self, infcx: &InferCtxt<'tcx>) -> Vec<FulfillmentError<'tcx>> {
61-
if !self.usable_in_snapshot {
62-
assert!(!infcx.is_in_snapshot());
63-
}
62+
assert_eq!(self.usable_in_snapshot, infcx.num_open_snapshots());
6463

6564
let mut errors = Vec::new();
6665
let mut next_round = FxIndexSet::default();
@@ -94,12 +93,11 @@ impl<'tcx> TraitEngine<'tcx> for FulfillmentContext<'tcx> {
9493
&orig_values,
9594
&response,
9695
) {
97-
Ok(infer_ok) => next_round.extend(
98-
infer_ok.obligations.into_iter().map(|obligation| {
99-
assert!(!infcx.is_in_snapshot());
100-
infcx.resolve_vars_if_possible(obligation)
101-
}),
102-
),
96+
Ok(infer_ok) => {
97+
next_round.extend(infer_ok.obligations.into_iter().map(
98+
|obligation| infcx.resolve_vars_if_possible(obligation),
99+
))
100+
}
103101

104102
Err(_err) => errors.push(FulfillmentError {
105103
obligation: obligation.clone(),

compiler/rustc_trait_selection/src/traits/const_evaluatable.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ fn satisfied_from_param_env<'tcx>(
176176
fn visit_const(&mut self, c: ty::Const<'tcx>) -> ControlFlow<Self::BreakTy> {
177177
debug!("is_const_evaluatable: candidate={:?}", c);
178178
if self.infcx.probe(|_| {
179-
let ocx = ObligationCtxt::new_in_snapshot(self.infcx);
179+
let ocx = ObligationCtxt::new(self.infcx);
180180
ocx.eq(&ObligationCause::dummy(), self.param_env, c.ty(), self.ct.ty()).is_ok()
181181
&& ocx.eq(&ObligationCause::dummy(), self.param_env, c, self.ct).is_ok()
182182
&& ocx.select_all_or_error().is_empty()
@@ -219,7 +219,7 @@ fn satisfied_from_param_env<'tcx>(
219219
}
220220

221221
if let Some(Ok(c)) = single_match {
222-
let ocx = ObligationCtxt::new_in_snapshot(infcx);
222+
let ocx = ObligationCtxt::new(infcx);
223223
assert!(ocx.eq(&ObligationCause::dummy(), param_env, c.ty(), ct.ty()).is_ok());
224224
assert!(ocx.eq(&ObligationCause::dummy(), param_env, c, ct).is_ok());
225225
assert!(ocx.select_all_or_error().is_empty());

compiler/rustc_trait_selection/src/traits/engine.rs

+3-25
Original file line numberDiff line numberDiff line change
@@ -28,36 +28,18 @@ use rustc_span::Span;
2828

2929
pub trait TraitEngineExt<'tcx> {
3030
fn new(infcx: &InferCtxt<'tcx>) -> Box<Self>;
31-
fn new_in_snapshot(infcx: &InferCtxt<'tcx>) -> Box<Self>;
3231
}
3332

3433
impl<'tcx> TraitEngineExt<'tcx> for dyn TraitEngine<'tcx> {
3534
fn new(infcx: &InferCtxt<'tcx>) -> Box<Self> {
3635
match (infcx.tcx.sess.opts.unstable_opts.trait_solver, infcx.next_trait_solver()) {
3736
(TraitSolver::Classic, false) | (TraitSolver::NextCoherence, false) => {
38-
Box::new(FulfillmentContext::new())
37+
Box::new(FulfillmentContext::new(infcx))
3938
}
4039
(TraitSolver::Next | TraitSolver::NextCoherence, true) => {
41-
Box::new(NextFulfillmentCtxt::new())
40+
Box::new(NextFulfillmentCtxt::new(infcx))
4241
}
43-
(TraitSolver::Chalk, false) => Box::new(ChalkFulfillmentContext::new()),
44-
_ => bug!(
45-
"incompatible combination of -Ztrait-solver flag ({:?}) and InferCtxt::next_trait_solver ({:?})",
46-
infcx.tcx.sess.opts.unstable_opts.trait_solver,
47-
infcx.next_trait_solver()
48-
),
49-
}
50-
}
51-
52-
fn new_in_snapshot(infcx: &InferCtxt<'tcx>) -> Box<Self> {
53-
match (infcx.tcx.sess.opts.unstable_opts.trait_solver, infcx.next_trait_solver()) {
54-
(TraitSolver::Classic, false) | (TraitSolver::NextCoherence, false) => {
55-
Box::new(FulfillmentContext::new_in_snapshot())
56-
}
57-
(TraitSolver::Next | TraitSolver::NextCoherence, true) => {
58-
Box::new(NextFulfillmentCtxt::new())
59-
}
60-
(TraitSolver::Chalk, false) => Box::new(ChalkFulfillmentContext::new_in_snapshot()),
42+
(TraitSolver::Chalk, false) => Box::new(ChalkFulfillmentContext::new(infcx)),
6143
_ => bug!(
6244
"incompatible combination of -Ztrait-solver flag ({:?}) and InferCtxt::next_trait_solver ({:?})",
6345
infcx.tcx.sess.opts.unstable_opts.trait_solver,
@@ -79,10 +61,6 @@ impl<'a, 'tcx> ObligationCtxt<'a, 'tcx> {
7961
Self { infcx, engine: RefCell::new(<dyn TraitEngine<'_>>::new(infcx)) }
8062
}
8163

82-
pub fn new_in_snapshot(infcx: &'a InferCtxt<'tcx>) -> Self {
83-
Self { infcx, engine: RefCell::new(<dyn TraitEngine<'_>>::new_in_snapshot(infcx)) }
84-
}
85-
8664
pub fn register_obligation(&self, obligation: PredicateObligation<'tcx>) {
8765
self.engine.borrow_mut().register_predicate_obligation(self.infcx, obligation);
8866
}

compiler/rustc_trait_selection/src/traits/error_reporting/ambiguity.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ pub fn recompute_applicable_impls<'tcx>(
2020
let param_env = obligation.param_env;
2121

2222
let impl_may_apply = |impl_def_id| {
23-
let ocx = ObligationCtxt::new_in_snapshot(infcx);
23+
let ocx = ObligationCtxt::new(infcx);
2424
let placeholder_obligation =
2525
infcx.instantiate_binder_with_placeholders(obligation.predicate);
2626
let obligation_trait_ref =
@@ -45,7 +45,7 @@ pub fn recompute_applicable_impls<'tcx>(
4545
};
4646

4747
let param_env_candidate_may_apply = |poly_trait_predicate: ty::PolyTraitPredicate<'tcx>| {
48-
let ocx = ObligationCtxt::new_in_snapshot(infcx);
48+
let ocx = ObligationCtxt::new(infcx);
4949
let placeholder_obligation =
5050
infcx.instantiate_binder_with_placeholders(obligation.predicate);
5151
let obligation_trait_ref =

compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ impl<'tcx> InferCtxtExt<'tcx> for InferCtxt<'tcx> {
377377
param_env,
378378
ty.rebind(ty::TraitPredicate { trait_ref, constness, polarity }),
379379
);
380-
let ocx = ObligationCtxt::new_in_snapshot(self);
380+
let ocx = ObligationCtxt::new(self);
381381
ocx.register_obligation(obligation);
382382
if ocx.select_all_or_error().is_empty() {
383383
return Ok((
@@ -1599,7 +1599,7 @@ impl<'tcx> InferCtxtPrivExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
15991599
}
16001600

16011601
self.probe(|_| {
1602-
let ocx = ObligationCtxt::new_in_snapshot(self);
1602+
let ocx = ObligationCtxt::new(self);
16031603

16041604
// try to find the mismatched types to report the error with.
16051605
//

0 commit comments

Comments
 (0)