Skip to content

Commit d2009f0

Browse files
authored
Rollup merge of rust-lang#80732 - spastorino:trait-inheritance-self2, r=nikomatsakis
Allow Trait inheritance with cycles on associated types take 2 This reverts the revert of rust-lang#79209 and fixes the ICEs that's occasioned by that PR exposing some problems that are addressed in rust-lang#80648 and rust-lang#79811. For easier review I'd say, check only the last commit, the first one is just a revert of the revert of rust-lang#79209 which was already approved. This also could be considered part or the actual fix of rust-lang#79560 but I guess for that to be closed and fixed completely we would need to land rust-lang#80648 and rust-lang#79811 too. r? `@nikomatsakis` cc `@Aaron1011`
2 parents 803d616 + 455a0e1 commit d2009f0

30 files changed

+550
-149
lines changed

compiler/rustc_infer/src/traits/util.rs

+33-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
use smallvec::smallvec;
22

33
use crate::traits::{Obligation, ObligationCause, PredicateObligation};
4-
use rustc_data_structures::fx::FxHashSet;
4+
use rustc_data_structures::fx::{FxHashSet, FxIndexSet};
55
use rustc_middle::ty::outlives::Component;
66
use rustc_middle::ty::{self, ToPredicate, TyCtxt, WithConstness};
7+
use rustc_span::symbol::Symbol;
78

89
pub fn anonymize_predicate<'tcx>(
910
tcx: TyCtxt<'tcx>,
@@ -287,6 +288,37 @@ pub fn transitive_bounds<'tcx>(
287288
elaborate_trait_refs(tcx, bounds).filter_to_traits()
288289
}
289290

291+
/// A specialized variant of `elaborate_trait_refs` that only elaborates trait references that may
292+
/// define the given associated type `assoc_name`. It uses the
293+
/// `super_predicates_that_define_assoc_type` query to avoid enumerating super-predicates that
294+
/// aren't related to `assoc_item`. This is used when resolving types like `Self::Item` or
295+
/// `T::Item` and helps to avoid cycle errors (see e.g. #35237).
296+
pub fn transitive_bounds_that_define_assoc_type<'tcx>(
297+
tcx: TyCtxt<'tcx>,
298+
bounds: impl Iterator<Item = ty::PolyTraitRef<'tcx>>,
299+
assoc_name: Symbol,
300+
) -> FxIndexSet<ty::PolyTraitRef<'tcx>> {
301+
let mut stack: Vec<_> = bounds.collect();
302+
let mut trait_refs = FxIndexSet::default();
303+
304+
while let Some(trait_ref) = stack.pop() {
305+
if trait_refs.insert(trait_ref) {
306+
let super_predicates =
307+
tcx.super_predicates_that_define_assoc_type((trait_ref.def_id(), Some(assoc_name)));
308+
for (super_predicate, _) in super_predicates.predicates {
309+
let bound_predicate = super_predicate.bound_atom();
310+
let subst_predicate = super_predicate
311+
.subst_supertrait(tcx, &bound_predicate.rebind(trait_ref.skip_binder()));
312+
if let Some(binder) = subst_predicate.to_opt_poly_trait_ref() {
313+
stack.push(binder.value);
314+
}
315+
}
316+
}
317+
}
318+
319+
trait_refs
320+
}
321+
290322
///////////////////////////////////////////////////////////////////////////
291323
// Other
292324
///////////////////////////////////////////////////////////////////////////

compiler/rustc_middle/src/query/mod.rs

+13-2
Original file line numberDiff line numberDiff line change
@@ -456,12 +456,23 @@ rustc_queries! {
456456
/// full predicates are available (note that supertraits have
457457
/// additional acyclicity requirements).
458458
query super_predicates_of(key: DefId) -> ty::GenericPredicates<'tcx> {
459-
desc { |tcx| "computing the supertraits of `{}`", tcx.def_path_str(key) }
459+
desc { |tcx| "computing the super predicates of `{}`", tcx.def_path_str(key) }
460+
}
461+
462+
/// The `Option<Symbol>` is the name of an associated type. If it is `None`, then this query
463+
/// returns the full set of predicates. If `Some<Symbol>`, then the query returns only the
464+
/// subset of super-predicates that reference traits that define the given associated type.
465+
/// This is used to avoid cycles in resolving types like `T::Item`.
466+
query super_predicates_that_define_assoc_type(key: (DefId, Option<rustc_span::symbol::Symbol>)) -> ty::GenericPredicates<'tcx> {
467+
desc { |tcx| "computing the super traits of `{}`{}",
468+
tcx.def_path_str(key.0),
469+
if let Some(assoc_name) = key.1 { format!(" with associated type name `{}`", assoc_name) } else { "".to_string() },
470+
}
460471
}
461472

462473
/// To avoid cycles within the predicates of a single item we compute
463474
/// per-type-parameter predicates for resolving `T::AssocTy`.
464-
query type_param_predicates(key: (DefId, LocalDefId)) -> ty::GenericPredicates<'tcx> {
475+
query type_param_predicates(key: (DefId, LocalDefId, rustc_span::symbol::Symbol)) -> ty::GenericPredicates<'tcx> {
465476
desc { |tcx| "computing the bounds for type parameter `{}`", {
466477
let id = tcx.hir().local_def_id_to_hir_id(key.1);
467478
tcx.hir().ty_param_name(id)

compiler/rustc_middle/src/ty/context.rs

+37
Original file line numberDiff line numberDiff line change
@@ -2059,6 +2059,43 @@ impl<'tcx> TyCtxt<'tcx> {
20592059
self.mk_fn_ptr(sig.map_bound(|sig| ty::FnSig { unsafety: hir::Unsafety::Unsafe, ..sig }))
20602060
}
20612061

2062+
/// Given the def_id of a Trait `trait_def_id` and the name of an associated item `assoc_name`
2063+
/// returns true if the `trait_def_id` defines an associated item of name `assoc_name`.
2064+
pub fn trait_may_define_assoc_type(self, trait_def_id: DefId, assoc_name: Symbol) -> bool {
2065+
self.super_traits_of(trait_def_id).any(|trait_did| {
2066+
self.associated_items(trait_did)
2067+
.find_by_name_and_kind_unhygienic(assoc_name, ty::AssocKind::Type)
2068+
.next()
2069+
.is_some()
2070+
})
2071+
}
2072+
2073+
/// Computes the def-ids of the transitive super-traits of `trait_def_id`. This (intentionally)
2074+
/// does not compute the full elaborated super-predicates but just the set of def-ids. It is used
2075+
/// to identify which traits may define a given associated type to help avoid cycle errors.
2076+
/// Returns a `DefId` iterator.
2077+
fn super_traits_of(self, trait_def_id: DefId) -> impl Iterator<Item = DefId> + 'tcx {
2078+
let mut set = FxHashSet::default();
2079+
let mut stack = vec![trait_def_id];
2080+
2081+
set.insert(trait_def_id);
2082+
2083+
iter::from_fn(move || -> Option<DefId> {
2084+
let trait_did = stack.pop()?;
2085+
let generic_predicates = self.super_predicates_of(trait_did);
2086+
2087+
for (predicate, _) in generic_predicates.predicates {
2088+
if let ty::PredicateAtom::Trait(data, _) = predicate.skip_binders() {
2089+
if set.insert(data.def_id()) {
2090+
stack.push(data.def_id());
2091+
}
2092+
}
2093+
}
2094+
2095+
Some(trait_did)
2096+
})
2097+
}
2098+
20622099
/// Given a closure signature, returns an equivalent fn signature. Detuples
20632100
/// and so forth -- so e.g., if we have a sig with `Fn<(u32, i32)>` then
20642101
/// you would get a `fn(u32, i32)`.

compiler/rustc_middle/src/ty/mod.rs

+10
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,16 @@ impl<'tcx> AssociatedItems<'tcx> {
320320
.find(|item| tcx.hygienic_eq(ident, item.ident, parent_def_id))
321321
}
322322

323+
/// Returns the associated item with the given name and `AssocKind`, if one exists, ignoring
324+
/// hygiene.
325+
pub fn find_by_name_and_kind_unhygienic(
326+
&self,
327+
name: Symbol,
328+
kind: AssocKind,
329+
) -> impl '_ + Iterator<Item = &ty::AssocItem> {
330+
self.filter_by_name_unhygienic(name).filter(move |item| item.kind == kind)
331+
}
332+
323333
/// Returns the associated item with the given name in the given `Namespace`, if one exists.
324334
pub fn find_by_name_and_namespace(
325335
&self,

compiler/rustc_middle/src/ty/query/keys.rs

+22
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,28 @@ impl Key for (LocalDefId, DefId) {
149149
}
150150
}
151151

152+
impl Key for (DefId, Option<Symbol>) {
153+
type CacheSelector = DefaultCacheSelector;
154+
155+
fn query_crate(&self) -> CrateNum {
156+
self.0.krate
157+
}
158+
fn default_span(&self, tcx: TyCtxt<'_>) -> Span {
159+
tcx.def_span(self.0)
160+
}
161+
}
162+
163+
impl Key for (DefId, LocalDefId, Symbol) {
164+
type CacheSelector = DefaultCacheSelector;
165+
166+
fn query_crate(&self) -> CrateNum {
167+
self.0.krate
168+
}
169+
fn default_span(&self, tcx: TyCtxt<'_>) -> Span {
170+
self.1.default_span(tcx)
171+
}
172+
}
173+
152174
impl Key for (CrateNum, DefId) {
153175
type CacheSelector = DefaultCacheSelector;
154176

compiler/rustc_trait_selection/src/traits/mod.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ pub use self::util::{
6565
get_vtable_index_of_object_method, impl_item_is_final, predicate_for_trait_def, upcast_choices,
6666
};
6767
pub use self::util::{
68-
supertrait_def_ids, supertraits, transitive_bounds, SupertraitDefIds, Supertraits,
68+
supertrait_def_ids, supertraits, transitive_bounds, transitive_bounds_that_define_assoc_type,
69+
SupertraitDefIds, Supertraits,
6970
};
7071

7172
pub use self::chalk_fulfill::FulfillmentContext as ChalkFulfillmentContext;

compiler/rustc_typeck/src/astconv/mod.rs

+57-11
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,10 @@ pub trait AstConv<'tcx> {
4949

5050
fn default_constness_for_trait_bounds(&self) -> Constness;
5151

52-
/// Returns predicates in scope of the form `X: Foo`, where `X` is
53-
/// a type parameter `X` with the given id `def_id`. This is a
54-
/// subset of the full set of predicates.
52+
/// Returns predicates in scope of the form `X: Foo<T>`, where `X`
53+
/// is a type parameter `X` with the given id `def_id` and T
54+
/// matches `assoc_name`. This is a subset of the full set of
55+
/// predicates.
5556
///
5657
/// This is used for one specific purpose: resolving "short-hand"
5758
/// associated type references like `T::Item`. In principle, we
@@ -60,7 +61,12 @@ pub trait AstConv<'tcx> {
6061
/// but this can lead to cycle errors. The problem is that we have
6162
/// to do this resolution *in order to create the predicates in
6263
/// the first place*. Hence, we have this "special pass".
63-
fn get_type_parameter_bounds(&self, span: Span, def_id: DefId) -> ty::GenericPredicates<'tcx>;
64+
fn get_type_parameter_bounds(
65+
&self,
66+
span: Span,
67+
def_id: DefId,
68+
assoc_name: Symbol,
69+
) -> ty::GenericPredicates<'tcx>;
6470

6571
/// Returns the lifetime to use when a lifetime is omitted (and not elided).
6672
fn re_infer(&self, param: Option<&ty::GenericParamDef>, span: Span)
@@ -783,7 +789,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
783789
}
784790

785791
// Returns `true` if a bounds list includes `?Sized`.
786-
pub fn is_unsized(&self, ast_bounds: &[hir::GenericBound<'_>], span: Span) -> bool {
792+
pub fn is_unsized(&self, ast_bounds: &[&hir::GenericBound<'_>], span: Span) -> bool {
787793
let tcx = self.tcx();
788794

789795
// Try to find an unbound in bounds.
@@ -841,7 +847,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
841847
fn add_bounds(
842848
&self,
843849
param_ty: Ty<'tcx>,
844-
ast_bounds: &[hir::GenericBound<'_>],
850+
ast_bounds: &[&hir::GenericBound<'_>],
845851
bounds: &mut Bounds<'tcx>,
846852
) {
847853
let constness = self.default_constness_for_trait_bounds();
@@ -856,7 +862,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
856862
hir::GenericBound::Trait(_, hir::TraitBoundModifier::Maybe) => {}
857863
hir::GenericBound::LangItemTrait(lang_item, span, hir_id, args) => self
858864
.instantiate_lang_item_trait_ref(
859-
lang_item, span, hir_id, args, param_ty, bounds,
865+
*lang_item, *span, *hir_id, args, param_ty, bounds,
860866
),
861867
hir::GenericBound::Outlives(ref l) => bounds
862868
.region_bounds
@@ -887,6 +893,42 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
887893
ast_bounds: &[hir::GenericBound<'_>],
888894
sized_by_default: SizedByDefault,
889895
span: Span,
896+
) -> Bounds<'tcx> {
897+
let ast_bounds: Vec<_> = ast_bounds.iter().collect();
898+
self.compute_bounds_inner(param_ty, &ast_bounds, sized_by_default, span)
899+
}
900+
901+
/// Convert the bounds in `ast_bounds` that refer to traits which define an associated type
902+
/// named `assoc_name` into ty::Bounds. Ignore the rest.
903+
pub fn compute_bounds_that_match_assoc_type(
904+
&self,
905+
param_ty: Ty<'tcx>,
906+
ast_bounds: &[hir::GenericBound<'_>],
907+
sized_by_default: SizedByDefault,
908+
span: Span,
909+
assoc_name: Symbol,
910+
) -> Bounds<'tcx> {
911+
let mut result = Vec::new();
912+
913+
for ast_bound in ast_bounds {
914+
if let Some(trait_ref) = ast_bound.trait_ref() {
915+
if let Some(trait_did) = trait_ref.trait_def_id() {
916+
if self.tcx().trait_may_define_assoc_type(trait_did, assoc_name) {
917+
result.push(ast_bound);
918+
}
919+
}
920+
}
921+
}
922+
923+
self.compute_bounds_inner(param_ty, &result, sized_by_default, span)
924+
}
925+
926+
fn compute_bounds_inner(
927+
&self,
928+
param_ty: Ty<'tcx>,
929+
ast_bounds: &[&hir::GenericBound<'_>],
930+
sized_by_default: SizedByDefault,
931+
span: Span,
890932
) -> Bounds<'tcx> {
891933
let mut bounds = Bounds::default();
892934

@@ -1056,7 +1098,8 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
10561098
// Calling `skip_binder` is okay, because `add_bounds` expects the `param_ty`
10571099
// parameter to have a skipped binder.
10581100
let param_ty = tcx.mk_projection(assoc_ty.def_id, candidate.skip_binder().substs);
1059-
self.add_bounds(param_ty, ast_bounds, bounds);
1101+
let ast_bounds: Vec<_> = ast_bounds.iter().collect();
1102+
self.add_bounds(param_ty, &ast_bounds, bounds);
10601103
}
10611104
}
10621105
Ok(())
@@ -1371,21 +1414,24 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
13711414
ty_param_def_id, assoc_name, span,
13721415
);
13731416

1374-
let predicates =
1375-
&self.get_type_parameter_bounds(span, ty_param_def_id.to_def_id()).predicates;
1417+
let predicates = &self
1418+
.get_type_parameter_bounds(span, ty_param_def_id.to_def_id(), assoc_name.name)
1419+
.predicates;
13761420

13771421
debug!("find_bound_for_assoc_item: predicates={:#?}", predicates);
13781422

13791423
let param_hir_id = tcx.hir().local_def_id_to_hir_id(ty_param_def_id);
13801424
let param_name = tcx.hir().ty_param_name(param_hir_id);
13811425
self.one_bound_for_assoc_type(
13821426
|| {
1383-
traits::transitive_bounds(
1427+
traits::transitive_bounds_that_define_assoc_type(
13841428
tcx,
13851429
predicates.iter().filter_map(|(p, _)| {
13861430
p.to_opt_poly_trait_ref().map(|trait_ref| trait_ref.value)
13871431
}),
1432+
assoc_name.name,
13881433
)
1434+
.into_iter()
13891435
},
13901436
|| param_name.to_string(),
13911437
assoc_name,

compiler/rustc_typeck/src/check/fn_ctxt/mod.rs

+7-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use rustc_middle::ty::fold::TypeFoldable;
2020
use rustc_middle::ty::subst::GenericArgKind;
2121
use rustc_middle::ty::{self, Const, Ty, TyCtxt};
2222
use rustc_session::Session;
23+
use rustc_span::symbol::Symbol;
2324
use rustc_span::{self, Span};
2425
use rustc_trait_selection::traits::{ObligationCause, ObligationCauseCode};
2526

@@ -183,7 +184,12 @@ impl<'a, 'tcx> AstConv<'tcx> for FnCtxt<'a, 'tcx> {
183184
}
184185
}
185186

186-
fn get_type_parameter_bounds(&self, _: Span, def_id: DefId) -> ty::GenericPredicates<'tcx> {
187+
fn get_type_parameter_bounds(
188+
&self,
189+
_: Span,
190+
def_id: DefId,
191+
_: Symbol,
192+
) -> ty::GenericPredicates<'tcx> {
187193
let tcx = self.tcx;
188194
let hir_id = tcx.hir().local_def_id_to_hir_id(def_id.expect_local());
189195
let item_id = tcx.hir().ty_param_owner(hir_id);

0 commit comments

Comments
 (0)