Skip to content

Commit 4e49b2f

Browse files
committed
Implement binary operator overloading type inference
1 parent 0fb069c commit 4e49b2f

File tree

3 files changed

+120
-5
lines changed

3 files changed

+120
-5
lines changed

crates/hir_ty/src/infer.rs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use arena::map::ArenaMap;
2222
use hir_def::{
2323
body::Body,
2424
data::{ConstData, FunctionData, StaticData},
25-
expr::{BindingAnnotation, ExprId, PatId},
25+
expr::{ArithOp, BinaryOp, BindingAnnotation, ExprId, PatId},
2626
lang_item::LangItemTarget,
2727
path::{path, Path},
2828
resolver::{HasResolver, Resolver, TypeNs},
@@ -586,6 +586,28 @@ impl<'a> InferenceContext<'a> {
586586
self.db.trait_data(trait_).associated_type_by_name(&name![Output])
587587
}
588588

589+
fn resolve_binary_op_output(&self, bop: &BinaryOp) -> Option<TypeAliasId> {
590+
let lang_item = match bop {
591+
BinaryOp::ArithOp(aop) => match aop {
592+
ArithOp::Add => "add",
593+
ArithOp::Sub => "sub",
594+
ArithOp::Mul => "mul",
595+
ArithOp::Div => "div",
596+
ArithOp::Shl => "shl",
597+
ArithOp::Shr => "shr",
598+
ArithOp::Rem => "rem",
599+
ArithOp::BitXor => "bitxor",
600+
ArithOp::BitOr => "bitor",
601+
ArithOp::BitAnd => "bitand",
602+
},
603+
_ => return None,
604+
};
605+
606+
let trait_ = self.resolve_lang_item(lang_item)?.as_trait();
607+
608+
self.db.trait_data(trait_?).associated_type_by_name(&name![Output])
609+
}
610+
589611
fn resolve_boxed_box(&self) -> Option<AdtId> {
590612
let struct_ = self.resolve_lang_item("owned_box")?.as_struct()?;
591613
Some(struct_.into())

crates/hir_ty/src/infer/expr.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -531,13 +531,20 @@ impl<'a> InferenceContext<'a> {
531531
_ => Expectation::none(),
532532
};
533533
let lhs_ty = self.infer_expr(*lhs, &lhs_expectation);
534-
// FIXME: find implementation of trait corresponding to operation
535-
// symbol and resolve associated `Output` type
536534
let rhs_expectation = op::binary_op_rhs_expectation(*op, lhs_ty.clone());
537535
let rhs_ty = self.infer_expr(*rhs, &Expectation::has_type(rhs_expectation));
538536

539-
// FIXME: similar as above, return ty is often associated trait type
540-
op::binary_op_return_ty(*op, lhs_ty, rhs_ty)
537+
let ret = op::binary_op_return_ty(*op, lhs_ty.clone(), rhs_ty.clone());
538+
539+
if ret == Ty::Unknown {
540+
self.resolve_associated_type_with_params(
541+
lhs_ty,
542+
self.resolve_binary_op_output(op),
543+
&[rhs_ty],
544+
)
545+
} else {
546+
ret
547+
}
541548
}
542549
_ => Ty::Unknown,
543550
},

crates/hir_ty/src/tests/simple.rs

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2225,3 +2225,89 @@ fn generic_default_depending_on_other_type_arg_forward() {
22252225
"#]],
22262226
);
22272227
}
2228+
2229+
#[test]
2230+
fn infer_operator_overload() {
2231+
check_infer(
2232+
r#"
2233+
struct V2([f32; 2]);
2234+
2235+
#[lang = "add"]
2236+
pub trait Add<Rhs = Self> {
2237+
/// The resulting type after applying the `+` operator.
2238+
type Output;
2239+
2240+
/// Performs the `+` operation.
2241+
#[must_use]
2242+
fn add(self, rhs: Rhs) -> Self::Output;
2243+
}
2244+
2245+
impl Add<V2> for V2 {
2246+
type Output = V2;
2247+
2248+
fn add(self, rhs: V2) -> V2 {
2249+
let x = self.0[0] + rhs.0[0];
2250+
let y = self.0[1] + rhs.0[1];
2251+
V2([x, y])
2252+
}
2253+
}
2254+
2255+
fn test() {
2256+
let va = V2([0.0, 1.0]);
2257+
let vb = V2([0.0, 1.0]);
2258+
2259+
let r = va + vb;
2260+
}
2261+
2262+
"#,
2263+
expect![[r#"
2264+
207..211 'self': Self
2265+
213..216 'rhs': Rhs
2266+
299..303 'self': V2
2267+
305..308 'rhs': V2
2268+
320..422 '{ ... }': V2
2269+
334..335 'x': f32
2270+
338..342 'self': V2
2271+
338..344 'self.0': [f32; _]
2272+
338..347 'self.0[0]': {unknown}
2273+
338..358 'self.0...s.0[0]': f32
2274+
345..346 '0': i32
2275+
350..353 'rhs': V2
2276+
350..355 'rhs.0': [f32; _]
2277+
350..358 'rhs.0[0]': {unknown}
2278+
356..357 '0': i32
2279+
372..373 'y': f32
2280+
376..380 'self': V2
2281+
376..382 'self.0': [f32; _]
2282+
376..385 'self.0[1]': {unknown}
2283+
376..396 'self.0...s.0[1]': f32
2284+
383..384 '1': i32
2285+
388..391 'rhs': V2
2286+
388..393 'rhs.0': [f32; _]
2287+
388..396 'rhs.0[1]': {unknown}
2288+
394..395 '1': i32
2289+
406..408 'V2': V2([f32; _]) -> V2
2290+
406..416 'V2([x, y])': V2
2291+
409..415 '[x, y]': [f32; _]
2292+
410..411 'x': f32
2293+
413..414 'y': f32
2294+
436..519 '{ ... vb; }': ()
2295+
446..448 'va': V2
2296+
451..453 'V2': V2([f32; _]) -> V2
2297+
451..465 'V2([0.0, 1.0])': V2
2298+
454..464 '[0.0, 1.0]': [f32; _]
2299+
455..458 '0.0': f32
2300+
460..463 '1.0': f32
2301+
475..477 'vb': V2
2302+
480..482 'V2': V2([f32; _]) -> V2
2303+
480..494 'V2([0.0, 1.0])': V2
2304+
483..493 '[0.0, 1.0]': [f32; _]
2305+
484..487 '0.0': f32
2306+
489..492 '1.0': f32
2307+
505..506 'r': V2
2308+
509..511 'va': V2
2309+
509..516 'va + vb': V2
2310+
514..516 'vb': V2
2311+
"#]],
2312+
);
2313+
}

0 commit comments

Comments
 (0)