Skip to content

Commit b528120

Browse files
Collect relevant item bounds from trait clauses for nested rigid projections, GATs
1 parent 2836482 commit b528120

File tree

4 files changed

+288
-10
lines changed

4 files changed

+288
-10
lines changed

compiler/rustc_hir_analysis/src/collect/item_bounds.rs

+218-10
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
use rustc_data_structures::fx::FxIndexSet;
1+
use rustc_data_structures::fx::{FxIndexMap, FxIndexSet};
22
use rustc_hir as hir;
33
use rustc_infer::traits::util;
4+
use rustc_middle::ty::fold::shift_vars;
45
use rustc_middle::ty::{
5-
self, GenericArgs, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable,
6+
self, GenericArgs, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitableExt,
67
};
78
use rustc_middle::{bug, span_bug};
89
use rustc_span::def_id::{DefId, LocalDefId};
@@ -41,14 +42,110 @@ fn associated_type_bounds<'tcx>(
4142
let trait_def_id = tcx.local_parent(assoc_item_def_id);
4243
let trait_predicates = tcx.trait_explicit_predicates_and_bounds(trait_def_id);
4344

44-
let bounds_from_parent = trait_predicates.predicates.iter().copied().filter(|(pred, _)| {
45-
match pred.kind().skip_binder() {
46-
ty::ClauseKind::Trait(tr) => tr.self_ty() == item_ty,
47-
ty::ClauseKind::Projection(proj) => proj.projection_term.self_ty() == item_ty,
48-
ty::ClauseKind::TypeOutlives(outlives) => outlives.0 == item_ty,
49-
_ => false,
50-
}
51-
});
45+
let item_trait_ref = ty::TraitRef::identity(tcx, tcx.parent(assoc_item_def_id.to_def_id()));
46+
let bounds_from_parent =
47+
trait_predicates.predicates.iter().copied().filter_map(|(pred, span)| {
48+
let mut clause_ty = match pred.kind().skip_binder() {
49+
ty::ClauseKind::Trait(tr) => tr.self_ty(),
50+
ty::ClauseKind::Projection(proj) => proj.projection_term.self_ty(),
51+
ty::ClauseKind::TypeOutlives(outlives) => outlives.0,
52+
_ => return None,
53+
};
54+
55+
// The code below is quite involved, so let me explain.
56+
//
57+
// We loop here, because we also want to collect vars for nested associated items as
58+
// well. For example, given a clause like `Self::A::B`, we want to add that to the
59+
// item bounds for `A`, so that we may use that bound in the case that `Self::A::B` is
60+
// rigid.
61+
//
62+
// Secondly, regarding bound vars, when we see a where clause that mentions a GAT
63+
// like `for<'a, ...> Self::Assoc<'a, ...>: Bound<'b, ...>`, we want to turn that into
64+
// an item bound on the GAT, where all of the GAT args are substituted with the GAT's
65+
// param regions, and then keep all of the other late-bound vars in the bound around.
66+
// We need to "compress" the binder so that it doesn't mention any of those vars that
67+
// were mapped to params.
68+
let gat_vars = loop {
69+
if let ty::Alias(ty::Projection, alias_ty) = *clause_ty.kind() {
70+
if alias_ty.trait_ref(tcx) == item_trait_ref
71+
&& alias_ty.def_id == assoc_item_def_id.to_def_id()
72+
{
73+
break &alias_ty.args[item_trait_ref.args.len()..];
74+
} else {
75+
// Only collect *self* type bounds if the filter is for self.
76+
match filter {
77+
PredicateFilter::SelfOnly | PredicateFilter::SelfThatDefines(_) => {
78+
return None;
79+
}
80+
PredicateFilter::All | PredicateFilter::SelfAndAssociatedTypeBounds => {
81+
}
82+
}
83+
84+
clause_ty = alias_ty.self_ty();
85+
continue;
86+
}
87+
}
88+
89+
return None;
90+
};
91+
// Special-case: No GAT vars, no mapping needed.
92+
if gat_vars.is_empty() {
93+
return Some((pred, span));
94+
}
95+
96+
// First, check that all of the GAT args are substituted with a unique late-bound arg.
97+
// If we find a duplicate, then it can't be mapped to the definition's params.
98+
let mut mapping = FxIndexMap::default();
99+
let generics = tcx.generics_of(assoc_item_def_id);
100+
for (param, var) in std::iter::zip(&generics.own_params, gat_vars) {
101+
let existing = match var.unpack() {
102+
ty::GenericArgKind::Lifetime(re) => {
103+
if let ty::RegionKind::ReBound(ty::INNERMOST, bv) = re.kind() {
104+
mapping.insert(bv.var, tcx.mk_param_from_def(param))
105+
} else {
106+
return None;
107+
}
108+
}
109+
ty::GenericArgKind::Type(ty) => {
110+
if let ty::Bound(ty::INNERMOST, bv) = *ty.kind() {
111+
mapping.insert(bv.var, tcx.mk_param_from_def(param))
112+
} else {
113+
return None;
114+
}
115+
}
116+
ty::GenericArgKind::Const(ct) => {
117+
if let ty::ConstKind::Bound(ty::INNERMOST, bv) = ct.kind() {
118+
mapping.insert(bv, tcx.mk_param_from_def(param))
119+
} else {
120+
return None;
121+
}
122+
}
123+
};
124+
125+
if existing.is_some() {
126+
return None;
127+
}
128+
}
129+
130+
// Finally, map all of the args in the GAT to the params we expect, and compress
131+
// the remaining late-bound vars so that they count up from var 0.
132+
let mut folder = MapAndCompressBoundVars {
133+
tcx,
134+
binder: ty::INNERMOST,
135+
still_bound_vars: vec![],
136+
mapping,
137+
};
138+
let pred = pred.kind().skip_binder().fold_with(&mut folder);
139+
140+
Some((
141+
ty::Binder::bind_with_vars(
142+
pred,
143+
tcx.mk_bound_variable_kinds(&folder.still_bound_vars),
144+
)
145+
.upcast(tcx),
146+
span,
147+
))
148+
});
52149

53150
let all_bounds = tcx.arena.alloc_from_iter(bounds.clauses(tcx).chain(bounds_from_parent));
54151
debug!(
@@ -59,6 +156,117 @@ fn associated_type_bounds<'tcx>(
59156
all_bounds
60157
}
61158

159+
struct MapAndCompressBoundVars<'tcx> {
160+
tcx: TyCtxt<'tcx>,
161+
/// How deep are we? Makes sure we don't touch the vars of nested binders.
162+
binder: ty::DebruijnIndex,
163+
/// List of bound vars that remain unsubstituted because they were not
164+
/// mentioned in the GAT's args.
165+
still_bound_vars: Vec<ty::BoundVariableKind>,
166+
/// Subtle invariant: If the `GenericArg` is bound, then it should be
167+
/// stored with the debruijn index of `INNERMOST` so it can be shifted
168+
/// correctly during substitution.
169+
mapping: FxIndexMap<ty::BoundVar, ty::GenericArg<'tcx>>,
170+
}
171+
172+
impl<'tcx> TypeFolder<TyCtxt<'tcx>> for MapAndCompressBoundVars<'tcx> {
173+
fn cx(&self) -> TyCtxt<'tcx> {
174+
self.tcx
175+
}
176+
177+
fn fold_binder<T>(&mut self, t: ty::Binder<'tcx, T>) -> ty::Binder<'tcx, T>
178+
where
179+
ty::Binder<'tcx, T>: TypeSuperFoldable<TyCtxt<'tcx>>,
180+
{
181+
self.binder.shift_in(1);
182+
let out = t.super_fold_with(self);
183+
self.binder.shift_out(1);
184+
out
185+
}
186+
187+
fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
188+
if !ty.has_bound_vars() {
189+
return ty;
190+
}
191+
192+
if let ty::Bound(binder, old_bound) = *ty.kind()
193+
&& self.binder == binder
194+
{
195+
let mapped = if let Some(mapped) = self.mapping.get(&old_bound.var) {
196+
mapped.expect_ty()
197+
} else {
198+
// If we didn't find a mapped generic, then make a new one.
199+
// Allocate a new var idx, and insert a new bound ty.
200+
let var = ty::BoundVar::from_usize(self.still_bound_vars.len());
201+
self.still_bound_vars.push(ty::BoundVariableKind::Ty(old_bound.kind));
202+
let mapped = Ty::new_bound(
203+
self.tcx,
204+
ty::INNERMOST,
205+
ty::BoundTy { var, kind: old_bound.kind },
206+
);
207+
self.mapping.insert(old_bound.var, mapped.into());
208+
mapped
209+
};
210+
211+
shift_vars(self.tcx, mapped, self.binder.as_u32())
212+
} else {
213+
ty.super_fold_with(self)
214+
}
215+
}
216+
217+
fn fold_region(&mut self, re: ty::Region<'tcx>) -> ty::Region<'tcx> {
218+
if let ty::ReBound(binder, old_bound) = re.kind()
219+
&& self.binder == binder
220+
{
221+
let mapped = if let Some(mapped) = self.mapping.get(&old_bound.var) {
222+
mapped.expect_region()
223+
} else {
224+
let var = ty::BoundVar::from_usize(self.still_bound_vars.len());
225+
self.still_bound_vars.push(ty::BoundVariableKind::Region(old_bound.kind));
226+
let mapped = ty::Region::new_bound(
227+
self.tcx,
228+
ty::INNERMOST,
229+
ty::BoundRegion { var, kind: old_bound.kind },
230+
);
231+
self.mapping.insert(old_bound.var, mapped.into());
232+
mapped
233+
};
234+
235+
shift_vars(self.tcx, mapped, self.binder.as_u32())
236+
} else {
237+
re
238+
}
239+
}
240+
241+
fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
242+
if !ct.has_bound_vars() {
243+
return ct;
244+
}
245+
246+
if let ty::ConstKind::Bound(binder, old_var) = ct.kind()
247+
&& self.binder == binder
248+
{
249+
let mapped = if let Some(mapped) = self.mapping.get(&old_var) {
250+
mapped.expect_const()
251+
} else {
252+
let var = ty::BoundVar::from_usize(self.still_bound_vars.len());
253+
self.still_bound_vars.push(ty::BoundVariableKind::Const);
254+
let mapped = ty::Const::new_bound(self.tcx, ty::INNERMOST, var);
255+
self.mapping.insert(old_var, mapped.into());
256+
mapped
257+
};
258+
259+
shift_vars(self.tcx, mapped, self.binder.as_u32())
260+
} else {
261+
ct.super_fold_with(self)
262+
}
263+
}
264+
265+
fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
266+
if !p.has_bound_vars() { p } else { p.super_fold_with(self) }
267+
}
268+
}
269+
62270
/// Opaque types don't inherit bounds from their parent: for return position
63271
/// impl trait it isn't possible to write a suitable predicate on the
64272
/// containing function and for type-alias impl trait we don't have a backwards
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//@ check-pass
2+
//@ revisions: current next
3+
//@[next] compile-flags: -Znext-solver
4+
5+
trait Trait
6+
where
7+
Self::Assoc: Clone,
8+
{
9+
type Assoc;
10+
}
11+
12+
fn foo<T: Trait>(x: &T::Assoc) -> T::Assoc {
13+
x.clone()
14+
}
15+
16+
trait Trait2
17+
where
18+
Self::Assoc: Iterator,
19+
<Self::Assoc as Iterator>::Item: Clone,
20+
{
21+
type Assoc;
22+
}
23+
24+
fn foo2<T: Trait2>(x: &<T::Assoc as Iterator>::Item) -> <T::Assoc as Iterator>::Item {
25+
x.clone()
26+
}
27+
28+
fn main() {}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
//@ check-pass
2+
3+
// Test that `for<'a> Self::Gat<'a>: Debug` is implied in the definition of `Foo`,
4+
// just as it would be if it weren't a GAT but just a regular associated type.
5+
6+
use std::fmt::Debug;
7+
8+
trait Foo
9+
where
10+
for<'a> Self::Gat<'a>: Debug,
11+
{
12+
type Gat<'a>;
13+
}
14+
15+
fn test<T: Foo>(x: T::Gat<'static>) {
16+
println!("{:?}", x);
17+
}
18+
19+
fn main() {}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//@ check-pass
2+
//@ revisions: current next
3+
//@[next] compile-flags: -Znext-solver
4+
5+
trait Foo
6+
where
7+
Self::Iterator: Iterator,
8+
<Self::Iterator as Iterator>::Item: Bar,
9+
{
10+
type Iterator;
11+
12+
fn iter() -> Self::Iterator;
13+
}
14+
15+
trait Bar {
16+
fn bar(&self);
17+
}
18+
19+
fn x<T: Foo>() {
20+
T::iter().next().unwrap().bar();
21+
}
22+
23+
fn main() {}

0 commit comments

Comments
 (0)