Skip to content

Commit 02eb2d7

Browse files
committed
Distinguish between expected and final type in CoerceMany
1 parent c229a83 commit 02eb2d7

File tree

3 files changed

+187
-88
lines changed

3 files changed

+187
-88
lines changed

crates/hir-ty/src/infer.rs

+19-4
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,10 @@ pub(crate) fn infer_query(db: &dyn HirDatabase, def: DefWithBodyId) -> Arc<Infer
6666
let mut ctx = InferenceContext::new(db, def, &body, resolver);
6767

6868
match def {
69+
DefWithBodyId::FunctionId(f) => {
70+
ctx.collect_fn(f);
71+
}
6972
DefWithBodyId::ConstId(c) => ctx.collect_const(&db.const_data(c)),
70-
DefWithBodyId::FunctionId(f) => ctx.collect_fn(f),
7173
DefWithBodyId::StaticId(s) => ctx.collect_static(&db.static_data(s)),
7274
DefWithBodyId::VariantId(v) => {
7375
ctx.return_ty = TyBuilder::builtin(match db.enum_data(v.parent).variant_body_type() {
@@ -392,9 +394,12 @@ pub(crate) struct InferenceContext<'a> {
392394
/// currently within one.
393395
///
394396
/// We might consider using a nested inference context for checking
395-
/// closures, but currently this is the only field that will change there,
396-
/// so it doesn't make sense.
397+
/// closures so we can swap all shared things out at once.
397398
return_ty: Ty,
399+
/// If `Some`, this stores coercion information for returned
400+
/// expressions. If `None`, this is in a context where return is
401+
/// inappropriate, such as a const expression.
402+
return_coercion: Option<CoerceMany>,
398403
/// The resume type and the yield type, respectively, of the generator being inferred.
399404
resume_yield_tys: Option<(Ty, Ty)>,
400405
diverges: Diverges,
@@ -462,6 +467,7 @@ impl<'a> InferenceContext<'a> {
462467
trait_env,
463468
return_ty: TyKind::Error.intern(Interner), // set in collect_* calls
464469
resume_yield_tys: None,
470+
return_coercion: None,
465471
db,
466472
owner,
467473
body,
@@ -595,10 +601,19 @@ impl<'a> InferenceContext<'a> {
595601
};
596602

597603
self.return_ty = self.normalize_associated_types_in(return_ty);
604+
self.return_coercion = Some(CoerceMany::new(self.return_ty.clone()));
598605
}
599606

600607
fn infer_body(&mut self) {
601-
self.infer_expr_coerce(self.body.body_expr, &Expectation::has_type(self.return_ty.clone()));
608+
match self.return_coercion {
609+
Some(_) => self.infer_return(self.body.body_expr),
610+
None => {
611+
_ = self.infer_expr_coerce(
612+
self.body.body_expr,
613+
&Expectation::has_type(self.return_ty.clone()),
614+
)
615+
}
616+
}
602617
}
603618

604619
fn write_expr_ty(&mut self, expr: ExprId, ty: Ty) {

crates/hir-ty/src/infer/coerce.rs

+43-15
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,44 @@ fn success(
5050
#[derive(Clone, Debug)]
5151
pub(super) struct CoerceMany {
5252
expected_ty: Ty,
53+
final_ty: Option<Ty>,
5354
}
5455

5556
impl CoerceMany {
5657
pub(super) fn new(expected: Ty) -> Self {
57-
CoerceMany { expected_ty: expected }
58+
CoerceMany { expected_ty: expected, final_ty: None }
59+
}
60+
61+
/// Returns the "expected type" with which this coercion was
62+
/// constructed. This represents the "downward propagated" type
63+
/// that was given to us at the start of typing whatever construct
64+
/// we are typing (e.g., the match expression).
65+
///
66+
/// Typically, this is used as the expected type when
67+
/// type-checking each of the alternative expressions whose types
68+
/// we are trying to merge.
69+
pub(super) fn expected_ty(&self) -> Ty {
70+
self.expected_ty.clone()
71+
}
72+
73+
/// Returns the current "merged type", representing our best-guess
74+
/// at the LUB of the expressions we've seen so far (if any). This
75+
/// isn't *final* until you call `self.complete()`, which will return
76+
/// the merged type.
77+
pub(super) fn merged_ty(&self) -> Ty {
78+
self.final_ty.clone().unwrap_or_else(|| self.expected_ty.clone())
79+
}
80+
81+
pub(super) fn complete(self, ctx: &mut InferenceContext<'_>) -> Ty {
82+
if let Some(final_ty) = self.final_ty {
83+
final_ty
84+
} else {
85+
ctx.result.standard_types.never.clone()
86+
}
87+
}
88+
89+
pub(super) fn coerce_forced_unit(&mut self, ctx: &mut InferenceContext<'_>) {
90+
self.coerce(ctx, None, &ctx.result.standard_types.unit.clone())
5891
}
5992

6093
/// Merge two types from different branches, with possible coercion.
@@ -76,51 +109,46 @@ impl CoerceMany {
76109
// Special case: two function types. Try to coerce both to
77110
// pointers to have a chance at getting a match. See
78111
// https://github.com/rust-lang/rust/blob/7b805396bf46dce972692a6846ce2ad8481c5f85/src/librustc_typeck/check/coercion.rs#L877-L916
79-
let sig = match (self.expected_ty.kind(Interner), expr_ty.kind(Interner)) {
112+
let sig = match (self.merged_ty().kind(Interner), expr_ty.kind(Interner)) {
80113
(TyKind::FnDef(..) | TyKind::Closure(..), TyKind::FnDef(..) | TyKind::Closure(..)) => {
81114
// FIXME: we're ignoring safety here. To be more correct, if we have one FnDef and one Closure,
82115
// we should be coercing the closure to a fn pointer of the safety of the FnDef
83116
cov_mark::hit!(coerce_fn_reification);
84117
let sig =
85-
self.expected_ty.callable_sig(ctx.db).expect("FnDef without callable sig");
118+
self.merged_ty().callable_sig(ctx.db).expect("FnDef without callable sig");
86119
Some(sig)
87120
}
88121
_ => None,
89122
};
90123
if let Some(sig) = sig {
91124
let target_ty = TyKind::Function(sig.to_fn_ptr()).intern(Interner);
92-
let result1 = ctx.table.coerce_inner(self.expected_ty.clone(), &target_ty);
125+
let result1 = ctx.table.coerce_inner(self.merged_ty(), &target_ty);
93126
let result2 = ctx.table.coerce_inner(expr_ty.clone(), &target_ty);
94127
if let (Ok(result1), Ok(result2)) = (result1, result2) {
95128
ctx.table.register_infer_ok(result1);
96129
ctx.table.register_infer_ok(result2);
97-
return self.expected_ty = target_ty;
130+
return self.final_ty = Some(target_ty);
98131
}
99132
}
100133

101134
// It might not seem like it, but order is important here: If the expected
102135
// type is a type variable and the new one is `!`, trying it the other
103136
// way around first would mean we make the type variable `!`, instead of
104137
// just marking it as possibly diverging.
105-
if ctx.coerce(expr, &expr_ty, &self.expected_ty).is_ok() {
106-
/* self.expected_ty is already correct */
107-
} else if ctx.coerce(expr, &self.expected_ty, &expr_ty).is_ok() {
108-
self.expected_ty = expr_ty;
138+
if let Ok(res) = ctx.coerce(expr, &expr_ty, &self.merged_ty()) {
139+
self.final_ty = Some(res);
140+
} else if let Ok(res) = ctx.coerce(expr, &self.merged_ty(), &expr_ty) {
141+
self.final_ty = Some(res);
109142
} else {
110143
if let Some(id) = expr {
111144
ctx.result.type_mismatches.insert(
112145
id.into(),
113-
TypeMismatch { expected: self.expected_ty.clone(), actual: expr_ty },
146+
TypeMismatch { expected: self.merged_ty().clone(), actual: expr_ty.clone() },
114147
);
115148
}
116149
cov_mark::hit!(coerce_merge_fail_fallback);
117-
/* self.expected_ty is already correct */
118150
}
119151
}
120-
121-
pub(super) fn complete(self) -> Ty {
122-
self.expected_ty
123-
}
124152
}
125153

126154
pub fn could_coerce(

0 commit comments

Comments
 (0)