Skip to content

Commit 5a8bcf9

Browse files
Relate type variables during predicate registration
This avoids select_obligations_where_possible having an effect, which future proofs the algorithm.
1 parent c4532c0 commit 5a8bcf9

File tree

9 files changed

+182
-72
lines changed

9 files changed

+182
-72
lines changed

compiler/rustc_infer/src/traits/engine.rs

+3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use crate::infer::InferCtxt;
22
use crate::traits::Obligation;
3+
use rustc_data_structures::fx::FxHashMap;
34
use rustc_hir as hir;
45
use rustc_hir::def_id::DefId;
56
use rustc_middle::ty::{self, ToPredicate, Ty, WithConstness};
@@ -73,6 +74,8 @@ pub trait TraitEngine<'tcx>: 'tcx {
7374
}
7475

7576
fn pending_obligations(&self) -> Vec<PredicateObligation<'tcx>>;
77+
78+
fn relationships(&mut self) -> &mut FxHashMap<ty::TyVid, ty::FoundRelationships>;
7679
}
7780

7881
pub trait TraitEngineExt<'tcx> {

compiler/rustc_middle/src/ty/mod.rs

+13
Original file line numberDiff line numberDiff line change
@@ -2084,3 +2084,16 @@ impl<'tcx> fmt::Debug for SymbolName<'tcx> {
20842084
fmt::Display::fmt(&self.name, fmt)
20852085
}
20862086
}
2087+
2088+
#[derive(Debug, Default, Copy, Clone)]
2089+
pub struct FoundRelationships {
2090+
/// This is true if we identified that this Ty (`?T`) is found in a `?T: Foo`
2091+
/// obligation, where:
2092+
///
2093+
/// * `Foo` is not `Sized`
2094+
/// * `(): Foo` may be satisfied
2095+
pub self_in_trait: bool,
2096+
/// This is true if we identified that this Ty (`?T`) is found in a `<_ as
2097+
/// _>::AssocType = ?T`
2098+
pub output: bool,
2099+
}

compiler/rustc_trait_selection/src/traits/chalk_fulfill.rs

+13-2
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,21 @@ use crate::traits::{
77
ChalkEnvironmentAndGoal, FulfillmentError, FulfillmentErrorCode, ObligationCause,
88
PredicateObligation, SelectionError, TraitEngine,
99
};
10-
use rustc_data_structures::fx::FxIndexSet;
10+
use rustc_data_structures::fx::{FxHashMap, FxIndexSet};
1111
use rustc_middle::ty::{self, Ty};
1212

1313
pub struct FulfillmentContext<'tcx> {
1414
obligations: FxIndexSet<PredicateObligation<'tcx>>,
15+
16+
relationships: FxHashMap<ty::TyVid, ty::FoundRelationships>,
1517
}
1618

1719
impl FulfillmentContext<'tcx> {
1820
crate fn new() -> Self {
19-
FulfillmentContext { obligations: FxIndexSet::default() }
21+
FulfillmentContext {
22+
obligations: FxIndexSet::default(),
23+
relationships: FxHashMap::default(),
24+
}
2025
}
2126
}
2227

@@ -39,6 +44,8 @@ impl TraitEngine<'tcx> for FulfillmentContext<'tcx> {
3944
assert!(!infcx.is_in_snapshot());
4045
let obligation = infcx.resolve_vars_if_possible(obligation);
4146

47+
super::relationships::update(self, infcx, &obligation);
48+
4249
self.obligations.insert(obligation);
4350
}
4451

@@ -149,4 +156,8 @@ impl TraitEngine<'tcx> for FulfillmentContext<'tcx> {
149156
fn pending_obligations(&self) -> Vec<PredicateObligation<'tcx>> {
150157
self.obligations.iter().cloned().collect()
151158
}
159+
160+
fn relationships(&mut self) -> &mut FxHashMap<ty::TyVid, ty::FoundRelationships> {
161+
&mut self.relationships
162+
}
152163
}

compiler/rustc_trait_selection/src/traits/fulfill.rs

+13
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use crate::infer::{InferCtxt, TyOrConstInferVar};
2+
use rustc_data_structures::fx::FxHashMap;
23
use rustc_data_structures::obligation_forest::ProcessResult;
34
use rustc_data_structures::obligation_forest::{Error, ForestObligation, Outcome};
45
use rustc_data_structures::obligation_forest::{ObligationForest, ObligationProcessor};
@@ -53,6 +54,9 @@ pub struct FulfillmentContext<'tcx> {
5354
// A list of all obligations that have been registered with this
5455
// fulfillment context.
5556
predicates: ObligationForest<PendingPredicateObligation<'tcx>>,
57+
58+
relationships: FxHashMap<ty::TyVid, ty::FoundRelationships>,
59+
5660
// Should this fulfillment context register type-lives-for-region
5761
// obligations on its parent infcx? In some cases, region
5862
// obligations are either already known to hold (normalization) or
@@ -97,6 +101,7 @@ impl<'a, 'tcx> FulfillmentContext<'tcx> {
97101
pub fn new() -> FulfillmentContext<'tcx> {
98102
FulfillmentContext {
99103
predicates: ObligationForest::new(),
104+
relationships: FxHashMap::default(),
100105
register_region_obligations: true,
101106
usable_in_snapshot: false,
102107
}
@@ -105,6 +110,7 @@ impl<'a, 'tcx> FulfillmentContext<'tcx> {
105110
pub fn new_in_snapshot() -> FulfillmentContext<'tcx> {
106111
FulfillmentContext {
107112
predicates: ObligationForest::new(),
113+
relationships: FxHashMap::default(),
108114
register_region_obligations: true,
109115
usable_in_snapshot: true,
110116
}
@@ -113,6 +119,7 @@ impl<'a, 'tcx> FulfillmentContext<'tcx> {
113119
pub fn new_ignoring_regions() -> FulfillmentContext<'tcx> {
114120
FulfillmentContext {
115121
predicates: ObligationForest::new(),
122+
relationships: FxHashMap::default(),
116123
register_region_obligations: false,
117124
usable_in_snapshot: false,
118125
}
@@ -210,6 +217,8 @@ impl<'tcx> TraitEngine<'tcx> for FulfillmentContext<'tcx> {
210217

211218
assert!(!infcx.is_in_snapshot() || self.usable_in_snapshot);
212219

220+
super::relationships::update(self, infcx, &obligation);
221+
213222
self.predicates
214223
.register_obligation(PendingPredicateObligation { obligation, stalled_on: vec![] });
215224
}
@@ -265,6 +274,10 @@ impl<'tcx> TraitEngine<'tcx> for FulfillmentContext<'tcx> {
265274
fn pending_obligations(&self) -> Vec<PredicateObligation<'tcx>> {
266275
self.predicates.map_pending_obligations(|o| o.obligation.clone())
267276
}
277+
278+
fn relationships(&mut self) -> &mut FxHashMap<ty::TyVid, ty::FoundRelationships> {
279+
&mut self.relationships
280+
}
268281
}
269282

270283
struct FulfillProcessor<'a, 'b, 'tcx> {

compiler/rustc_trait_selection/src/traits/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ mod object_safety;
1515
mod on_unimplemented;
1616
mod project;
1717
pub mod query;
18+
pub(crate) mod relationships;
1819
mod select;
1920
mod specialize;
2021
mod structural_match;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
use crate::infer::InferCtxt;
2+
use crate::traits::query::evaluate_obligation::InferCtxtExt;
3+
use crate::traits::{ObligationCause, PredicateObligation};
4+
use rustc_infer::traits::TraitEngine;
5+
use rustc_middle::ty::{self, ToPredicate, TypeFoldable};
6+
7+
pub(crate) fn update<'tcx, T>(
8+
engine: &mut T,
9+
infcx: &InferCtxt<'_, 'tcx>,
10+
obligation: &PredicateObligation<'tcx>,
11+
) where
12+
T: TraitEngine<'tcx>,
13+
{
14+
if let ty::PredicateKind::Trait(predicate) = obligation.predicate.kind().skip_binder() {
15+
if predicate.trait_ref.def_id
16+
!= infcx.tcx.require_lang_item(rustc_hir::LangItem::Sized, None)
17+
{
18+
// fixme: copy of mk_trait_obligation_with_new_self_ty
19+
let new_self_ty = infcx.tcx.types.unit;
20+
21+
let trait_ref = ty::TraitRef {
22+
substs: infcx.tcx.mk_substs_trait(new_self_ty, &predicate.trait_ref.substs[1..]),
23+
..predicate.trait_ref
24+
};
25+
26+
// Then contstruct a new obligation with Self = () added
27+
// to the ParamEnv, and see if it holds.
28+
let o = rustc_infer::traits::Obligation::new(
29+
ObligationCause::dummy(),
30+
obligation.param_env,
31+
obligation
32+
.predicate
33+
.kind()
34+
.map_bound(|_| {
35+
ty::PredicateKind::Trait(ty::TraitPredicate {
36+
trait_ref,
37+
constness: predicate.constness,
38+
})
39+
})
40+
.to_predicate(infcx.tcx),
41+
);
42+
// Don't report overflow errors. Otherwise equivalent to may_hold.
43+
if let Ok(result) = infcx.probe(|_| infcx.evaluate_obligation(&o)) {
44+
if result.may_apply() {
45+
if let Some(ty) = infcx
46+
.shallow_resolve(predicate.self_ty())
47+
.ty_vid()
48+
.map(|t| infcx.root_var(t))
49+
{
50+
debug!("relationship: {:?}.self_in_trait = true", ty);
51+
engine.relationships().entry(ty).or_default().self_in_trait = true;
52+
} else {
53+
debug!("relationship: did not find TyVid for self ty...");
54+
}
55+
}
56+
}
57+
}
58+
}
59+
60+
if let ty::PredicateKind::Projection(predicate) = obligation.predicate.kind().skip_binder() {
61+
// If the projection predicate (Foo::Bar == X) has X as a non-TyVid,
62+
// we need to make it into one.
63+
if let Some(vid) = predicate.ty.ty_vid() {
64+
debug!("relationship: {:?}.output = true", vid);
65+
engine.relationships().entry(vid).or_default().output = true;
66+
} else {
67+
// This will have registered a projection obligation that will hit
68+
// the Some(vid) branch above. So we don't need to do anything
69+
// further here.
70+
debug!(
71+
"skipping relationship for obligation {:?} -- would need to normalize",
72+
obligation
73+
);
74+
if !predicate.projection_ty.has_escaping_bound_vars() {
75+
// FIXME: We really *should* do this even with escaping bound
76+
// vars, but there's not much we can do here. In the worst case
77+
// (if this ends up being important) we just don't register a relationship and then end up falling back to !,
78+
// which is not terrible.
79+
80+
//engine.normalize_projection_type(
81+
// infcx,
82+
// obligation.param_env,
83+
// predicate.projection_ty,
84+
// obligation.cause.clone(),
85+
//);
86+
}
87+
}
88+
}
89+
}

compiler/rustc_typeck/src/check/fallback.rs

+17-70
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,8 @@ use rustc_data_structures::{
55
graph::{iterate::DepthFirstSearch, vec_graph::VecGraph},
66
stable_set::FxHashSet,
77
};
8-
use rustc_middle::traits;
9-
use rustc_middle::ty::{self, ToPredicate, Ty};
10-
use rustc_trait_selection::traits::query::evaluate_obligation::InferCtxtExt;
11-
12-
#[derive(Default, Copy, Clone)]
13-
struct FoundRelationships {
14-
/// This is true if we identified that this Ty (`?T`) is found in a `?T: Foo`
15-
/// obligation, where:
16-
///
17-
/// * `Foo` is not `Sized`
18-
/// * `(): Foo` may be satisfied
19-
self_in_trait: bool,
20-
/// This is true if we identified that this Ty (`?T`) is found in a `<_ as
21-
/// _>::AssocType = ?T`
22-
output: bool,
23-
}
8+
use rustc_middle::ty::{self, Ty};
9+
2410
impl<'tcx> FnCtxt<'_, 'tcx> {
2511
/// Performs type inference fallback, returning true if any fallback
2612
/// occurs.
@@ -30,60 +16,12 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
3016
self.fulfillment_cx.borrow_mut().pending_obligations()
3117
);
3218

33-
let mut relationships: FxHashMap<ty::TyVid, FoundRelationships> = FxHashMap::default();
34-
for obligation in self.fulfillment_cx.borrow_mut().pending_obligations() {
35-
if let ty::PredicateKind::Trait(predicate, constness) =
36-
obligation.predicate.kind().skip_binder()
37-
{
38-
if predicate.trait_ref.def_id
39-
!= self.infcx.tcx.require_lang_item(rustc_hir::LangItem::Sized, None)
40-
{
41-
// fixme: copy of mk_trait_obligation_with_new_self_ty
42-
let new_self_ty = self.infcx.tcx.types.unit;
43-
44-
let trait_ref = ty::TraitRef {
45-
substs: self
46-
.infcx
47-
.tcx
48-
.mk_substs_trait(new_self_ty, &predicate.trait_ref.substs[1..]),
49-
..predicate.trait_ref
50-
};
51-
52-
// Then contstruct a new obligation with Self = () added
53-
// to the ParamEnv, and see if it holds.
54-
let o = rustc_infer::traits::Obligation::new(
55-
traits::ObligationCause::dummy(),
56-
obligation.param_env,
57-
obligation
58-
.predicate
59-
.kind()
60-
.map_bound(|_| {
61-
ty::PredicateKind::Trait(
62-
ty::TraitPredicate { trait_ref },
63-
constness,
64-
)
65-
})
66-
.to_predicate(self.infcx.tcx),
67-
);
68-
if self.infcx.predicate_may_hold(&o) {
69-
if let Some(ty) = self.root_vid(predicate.self_ty()) {
70-
relationships.entry(ty).or_default().self_in_trait = true;
71-
}
72-
}
73-
}
74-
}
75-
if let ty::PredicateKind::Projection(predicate) =
76-
obligation.predicate.kind().skip_binder()
77-
{
78-
if let Some(ty) = self.root_vid(predicate.ty) {
79-
relationships.entry(ty).or_default().output = true;
80-
}
81-
}
82-
}
83-
8419
// All type checking constraints were added, try to fallback unsolved variables.
8520
self.select_obligations_where_possible(false, |_| {});
8621

22+
let relationships = self.fulfillment_cx.borrow_mut().relationships().clone();
23+
24+
debug!("relationships: {:#?}", relationships);
8725
debug!(
8826
"type-inference-fallback post selection obligations: {:#?}",
8927
self.fulfillment_cx.borrow_mut().pending_obligations()
@@ -94,7 +32,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
9432
// Check if we have any unsolved varibales. If not, no need for fallback.
9533
let unsolved_variables = self.unsolved_variables();
9634
if unsolved_variables.is_empty() {
97-
return;
35+
return false;
9836
}
9937

10038
let diverging_fallback =
@@ -324,7 +262,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
324262
fn calculate_diverging_fallback(
325263
&self,
326264
unsolved_variables: &[Ty<'tcx>],
327-
relationships: &FxHashMap<ty::TyVid, FoundRelationships>,
265+
relationships: &FxHashMap<ty::TyVid, ty::FoundRelationships>,
328266
) -> FxHashMap<Ty<'tcx>, Ty<'tcx>> {
329267
debug!("calculate_diverging_fallback({:?})", unsolved_variables);
330268

@@ -413,6 +351,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
413351

414352
debug!("inherited: {:#?}", self.inh.fulfillment_cx.borrow_mut().pending_obligations());
415353
debug!("obligations: {:#?}", self.fulfillment_cx.borrow_mut().pending_obligations());
354+
debug!("relationships: {:#?}", relationships);
416355

417356
// For each diverging variable, figure out whether it can
418357
// reach a member of N. If so, it falls back to `()`. Else
@@ -426,7 +365,15 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
426365
.depth_first_search(root_vid)
427366
.any(|n| roots_reachable_from_non_diverging.visited(n));
428367

429-
let relationship = relationships.get(&root_vid).copied().unwrap_or_default();
368+
let mut relationship = ty::FoundRelationships { self_in_trait: false, output: false };
369+
370+
for (vid, rel) in relationships.iter() {
371+
//if self.infcx.shallow_resolve(*ty).ty_vid().map(|t| self.infcx.root_var(t))
372+
if self.infcx.root_var(*vid) == root_vid {
373+
relationship.self_in_trait |= rel.self_in_trait;
374+
relationship.output |= rel.output;
375+
}
376+
}
430377

431378
if relationship.self_in_trait && relationship.output {
432379
debug!("fallback to () - found trait and projection: {:?}", diverging_vid);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
use std::marker::PhantomData;
2+
3+
fn main() {
4+
let error = Closure::wrap(Box::new(move || {
5+
//~^ ERROR type mismatch
6+
panic!("Can't connect to server.");
7+
}) as Box<dyn FnMut()>);
8+
}
9+
10+
struct Closure<T: ?Sized>(PhantomData<T>);
11+
12+
impl<T: ?Sized> Closure<T> {
13+
fn wrap(data: Box<T>) -> Closure<T> {
14+
todo!()
15+
}
16+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
error[E0271]: type mismatch resolving `<[closure@$DIR/fallback-closure-wrap.rs:4:40: 7:6] as FnOnce<()>>::Output == ()`
2+
--> $DIR/fallback-closure-wrap.rs:4:31
3+
|
4+
LL | let error = Closure::wrap(Box::new(move || {
5+
| _______________________________^
6+
LL | |
7+
LL | | panic!("Can't connect to server.");
8+
LL | | }) as Box<dyn FnMut()>);
9+
| |______^ expected `()`, found `!`
10+
|
11+
= note: expected unit type `()`
12+
found type `!`
13+
= note: required for the cast to the object type `dyn FnMut()`
14+
15+
error: aborting due to previous error
16+
17+
For more information about this error, try `rustc --explain E0271`.

0 commit comments

Comments
 (0)