Skip to content

Commit 1a55ab3

Browse files
committed
Auto merge of #16769 - ShoyuVanilla:issue-15412, r=Veykril
fix: Function argument type inference with associated type impl trait Fixes #15412
2 parents 52d8ae7 + a8f5611 commit 1a55ab3

File tree

2 files changed

+132
-11
lines changed

2 files changed

+132
-11
lines changed

crates/hir-ty/src/lower.rs

+85-11
Original file line numberDiff line numberDiff line change
@@ -995,12 +995,12 @@ impl<'a> TyLoweringContext<'a> {
995995

996996
pub(crate) fn lower_type_bound(
997997
&'a self,
998-
bound: &'a TypeBound,
998+
bound: &'a Interned<TypeBound>,
999999
self_ty: Ty,
10001000
ignore_bindings: bool,
10011001
) -> impl Iterator<Item = QuantifiedWhereClause> + 'a {
10021002
let mut bindings = None;
1003-
let trait_ref = match bound {
1003+
let trait_ref = match bound.as_ref() {
10041004
TypeBound::Path(path, TraitBoundModifier::None) => {
10051005
bindings = self.lower_trait_ref_from_path(path, Some(self_ty));
10061006
bindings
@@ -1055,10 +1055,10 @@ impl<'a> TyLoweringContext<'a> {
10551055

10561056
fn assoc_type_bindings_from_type_bound(
10571057
&'a self,
1058-
bound: &'a TypeBound,
1058+
bound: &'a Interned<TypeBound>,
10591059
trait_ref: TraitRef,
10601060
) -> impl Iterator<Item = QuantifiedWhereClause> + 'a {
1061-
let last_segment = match bound {
1061+
let last_segment = match bound.as_ref() {
10621062
TypeBound::Path(path, TraitBoundModifier::None) | TypeBound::ForLifetime(_, path) => {
10631063
path.segments().last()
10641064
}
@@ -1121,7 +1121,63 @@ impl<'a> TyLoweringContext<'a> {
11211121
);
11221122
}
11231123
} else {
1124-
let ty = self.lower_ty(type_ref);
1124+
let ty = 'ty: {
1125+
if matches!(
1126+
self.impl_trait_mode,
1127+
ImplTraitLoweringState::Param(_)
1128+
| ImplTraitLoweringState::Variable(_)
1129+
) {
1130+
// Find the generic index for the target of our `bound`
1131+
let target_param_idx = self
1132+
.resolver
1133+
.where_predicates_in_scope()
1134+
.find_map(|p| match p {
1135+
WherePredicate::TypeBound {
1136+
target: WherePredicateTypeTarget::TypeOrConstParam(idx),
1137+
bound: b,
1138+
} if b == bound => Some(idx),
1139+
_ => None,
1140+
});
1141+
if let Some(target_param_idx) = target_param_idx {
1142+
let mut counter = 0;
1143+
for (idx, data) in self.generics().params.type_or_consts.iter()
1144+
{
1145+
// Count the number of `impl Trait` things that appear before
1146+
// the target of our `bound`.
1147+
// Our counter within `impl_trait_mode` should be that number
1148+
// to properly lower each types within `type_ref`
1149+
if data.type_param().is_some_and(|p| {
1150+
p.provenance == TypeParamProvenance::ArgumentImplTrait
1151+
}) {
1152+
counter += 1;
1153+
}
1154+
if idx == *target_param_idx {
1155+
break;
1156+
}
1157+
}
1158+
let mut ext = TyLoweringContext::new_maybe_unowned(
1159+
self.db,
1160+
self.resolver,
1161+
self.owner,
1162+
)
1163+
.with_type_param_mode(self.type_param_mode);
1164+
match &self.impl_trait_mode {
1165+
ImplTraitLoweringState::Param(_) => {
1166+
ext.impl_trait_mode =
1167+
ImplTraitLoweringState::Param(Cell::new(counter));
1168+
}
1169+
ImplTraitLoweringState::Variable(_) => {
1170+
ext.impl_trait_mode = ImplTraitLoweringState::Variable(
1171+
Cell::new(counter),
1172+
);
1173+
}
1174+
_ => unreachable!(),
1175+
}
1176+
break 'ty ext.lower_ty(type_ref);
1177+
}
1178+
}
1179+
self.lower_ty(type_ref)
1180+
};
11251181
let alias_eq =
11261182
AliasEq { alias: AliasTy::Projection(projection_ty.clone()), ty };
11271183
predicates.push(crate::wrap_empty_binders(WhereClause::AliasEq(alias_eq)));
@@ -1403,8 +1459,14 @@ pub(crate) fn generic_predicates_for_param_query(
14031459
assoc_name: Option<Name>,
14041460
) -> Arc<[Binders<QuantifiedWhereClause>]> {
14051461
let resolver = def.resolver(db.upcast());
1406-
let ctx = TyLoweringContext::new(db, &resolver, def.into())
1407-
.with_type_param_mode(ParamLoweringMode::Variable);
1462+
let ctx = if let GenericDefId::FunctionId(_) = def {
1463+
TyLoweringContext::new(db, &resolver, def.into())
1464+
.with_impl_trait_mode(ImplTraitLoweringMode::Variable)
1465+
.with_type_param_mode(ParamLoweringMode::Variable)
1466+
} else {
1467+
TyLoweringContext::new(db, &resolver, def.into())
1468+
.with_type_param_mode(ParamLoweringMode::Variable)
1469+
};
14081470
let generics = generics(db.upcast(), def);
14091471

14101472
// we have to filter out all other predicates *first*, before attempting to lower them
@@ -1490,8 +1552,14 @@ pub(crate) fn trait_environment_query(
14901552
def: GenericDefId,
14911553
) -> Arc<TraitEnvironment> {
14921554
let resolver = def.resolver(db.upcast());
1493-
let ctx = TyLoweringContext::new(db, &resolver, def.into())
1494-
.with_type_param_mode(ParamLoweringMode::Placeholder);
1555+
let ctx = if let GenericDefId::FunctionId(_) = def {
1556+
TyLoweringContext::new(db, &resolver, def.into())
1557+
.with_impl_trait_mode(ImplTraitLoweringMode::Param)
1558+
.with_type_param_mode(ParamLoweringMode::Placeholder)
1559+
} else {
1560+
TyLoweringContext::new(db, &resolver, def.into())
1561+
.with_type_param_mode(ParamLoweringMode::Placeholder)
1562+
};
14951563
let mut traits_in_scope = Vec::new();
14961564
let mut clauses = Vec::new();
14971565
for pred in resolver.where_predicates_in_scope() {
@@ -1549,8 +1617,14 @@ pub(crate) fn generic_predicates_query(
15491617
def: GenericDefId,
15501618
) -> Arc<[Binders<QuantifiedWhereClause>]> {
15511619
let resolver = def.resolver(db.upcast());
1552-
let ctx = TyLoweringContext::new(db, &resolver, def.into())
1553-
.with_type_param_mode(ParamLoweringMode::Variable);
1620+
let ctx = if let GenericDefId::FunctionId(_) = def {
1621+
TyLoweringContext::new(db, &resolver, def.into())
1622+
.with_impl_trait_mode(ImplTraitLoweringMode::Variable)
1623+
.with_type_param_mode(ParamLoweringMode::Variable)
1624+
} else {
1625+
TyLoweringContext::new(db, &resolver, def.into())
1626+
.with_type_param_mode(ParamLoweringMode::Variable)
1627+
};
15541628
let generics = generics(db.upcast(), def);
15551629

15561630
let mut predicates = resolver

crates/hir-ty/src/tests/traits.rs

+47
Original file line numberDiff line numberDiff line change
@@ -1231,6 +1231,53 @@ fn test(x: impl Trait<u64>, y: &impl Trait<u64>) {
12311231
);
12321232
}
12331233

1234+
#[test]
1235+
fn argument_impl_trait_with_projection() {
1236+
check_infer(
1237+
r#"
1238+
trait X {
1239+
type Item;
1240+
}
1241+
1242+
impl<T> X for [T; 2] {
1243+
type Item = T;
1244+
}
1245+
1246+
trait Y {}
1247+
1248+
impl<T> Y for T {}
1249+
1250+
enum R<T, U> {
1251+
A(T),
1252+
B(U),
1253+
}
1254+
1255+
fn foo<T>(x: impl X<Item = R<impl Y, T>>) -> T { loop {} }
1256+
1257+
fn bar() {
1258+
let a = foo([R::A(()), R::B(7)]);
1259+
}
1260+
"#,
1261+
expect![[r#"
1262+
153..154 'x': impl X<Item = R<impl Y + ?Sized, T>> + ?Sized
1263+
190..201 '{ loop {} }': T
1264+
192..199 'loop {}': !
1265+
197..199 '{}': ()
1266+
212..253 '{ ...)]); }': ()
1267+
222..223 'a': i32
1268+
226..229 'foo': fn foo<i32>([R<(), i32>; 2]) -> i32
1269+
226..250 'foo([R...B(7)])': i32
1270+
230..249 '[R::A(...:B(7)]': [R<(), i32>; 2]
1271+
231..235 'R::A': extern "rust-call" A<(), i32>(()) -> R<(), i32>
1272+
231..239 'R::A(())': R<(), i32>
1273+
236..238 '()': ()
1274+
241..245 'R::B': extern "rust-call" B<(), i32>(i32) -> R<(), i32>
1275+
241..248 'R::B(7)': R<(), i32>
1276+
246..247 '7': i32
1277+
"#]],
1278+
);
1279+
}
1280+
12341281
#[test]
12351282
fn simple_return_pos_impl_trait() {
12361283
cov_mark::check!(lower_rpit);

0 commit comments

Comments
 (0)