Skip to content

Commit 5f8a953

Browse files
committed
Fallback {float} to f32 when f32: From<{float}>
1 parent 2a06022 commit 5f8a953

File tree

8 files changed

+125
-2
lines changed

8 files changed

+125
-2
lines changed

compiler/rustc_hir/src/lang_items.rs

+3
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,9 @@ language_item_table! {
433433
// Experimental lang items for implementing contract pre- and post-condition checking.
434434
ContractBuildCheckEnsures, sym::contract_build_check_ensures, contract_build_check_ensures_fn, Target::Fn, GenericRequirement::None;
435435
ContractCheckRequires, sym::contract_check_requires, contract_check_requires_fn, Target::Fn, GenericRequirement::None;
436+
437+
// Used to fallback `{float}` to `f32` when `f32: From<{float}>`
438+
From, sym::From, from_trait, Target::Trait, GenericRequirement::Exact(1);
436439
}
437440

438441
/// The requirement imposed on the generics of a lang item

compiler/rustc_hir_typeck/src/fallback.rs

+69-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@ use rustc_hir::HirId;
1111
use rustc_hir::def::{DefKind, Res};
1212
use rustc_hir::def_id::DefId;
1313
use rustc_hir::intravisit::{InferKind, Visitor};
14-
use rustc_middle::ty::{self, Ty, TyCtxt, TypeSuperVisitable, TypeVisitable};
14+
use rustc_middle::ty::{
15+
self, ClauseKind, FloatVid, PredicatePolarity, TraitPredicate, Ty, TyCtxt, TypeSuperVisitable,
16+
TypeVisitable,
17+
};
1518
use rustc_session::lint;
1619
use rustc_span::def_id::LocalDefId;
1720
use rustc_span::{DUMMY_SP, Span};
@@ -92,14 +95,16 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
9295

9396
let diverging_fallback = self
9497
.calculate_diverging_fallback(&unresolved_variables, self.diverging_fallback_behavior);
98+
let fallback_to_f32 = self.calculate_fallback_to_f32(&unresolved_variables);
9599

96100
// We do fallback in two passes, to try to generate
97101
// better error messages.
98102
// The first time, we do *not* replace opaque types.
99103
let mut fallback_occurred = false;
100104
for ty in unresolved_variables {
101105
debug!("unsolved_variable = {:?}", ty);
102-
fallback_occurred |= self.fallback_if_possible(ty, &diverging_fallback);
106+
fallback_occurred |=
107+
self.fallback_if_possible(ty, &diverging_fallback, &fallback_to_f32);
103108
}
104109

105110
fallback_occurred
@@ -124,6 +129,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
124129
&self,
125130
ty: Ty<'tcx>,
126131
diverging_fallback: &UnordMap<Ty<'tcx>, Ty<'tcx>>,
132+
fallback_to_f32: &UnordSet<FloatVid>,
127133
) -> bool {
128134
// Careful: we do NOT shallow-resolve `ty`. We know that `ty`
129135
// is an unsolved variable, and we determine its fallback
@@ -146,6 +152,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
146152
let fallback = match ty.kind() {
147153
_ if let Some(e) = self.tainted_by_errors() => Ty::new_error(self.tcx, e),
148154
ty::Infer(ty::IntVar(_)) => self.tcx.types.i32,
155+
ty::Infer(ty::FloatVar(vid)) if fallback_to_f32.contains(&vid) => self.tcx.types.f32,
149156
ty::Infer(ty::FloatVar(_)) => self.tcx.types.f64,
150157
_ => match diverging_fallback.get(&ty) {
151158
Some(&fallback_ty) => fallback_ty,
@@ -160,6 +167,61 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
160167
true
161168
}
162169

170+
/// Existing code relies on `f32: From<T>` (usually written as `T: Into<f32>`) resolving `T` to
171+
/// `f32` when the type of `T` is inferred from an unsuffixed float literal. Using the default
172+
/// fallback of `f64`, this would break when adding `impl From<f16> for f32`, as there are now
173+
/// two float type which could be `T`, meaning that the fallback of `f64` would be used and
174+
/// compilation error would occur as `f32` does not implement `From<f64>`. To avoid breaking
175+
/// existing code, we instead fallback `T` to `f32` when there is a trait predicate
176+
/// `f32: From<T>`. This means code like the following will continue to compile:
177+
///
178+
/// ```rust
179+
/// fn foo<T: Into<f32>>(_: T) {}
180+
///
181+
/// foo(1.0);
182+
/// ```
183+
fn calculate_fallback_to_f32(&self, unresolved_variables: &[Ty<'tcx>]) -> UnordSet<FloatVid> {
184+
let Some(from_trait) = self.tcx.lang_items().from_trait() else {
185+
return UnordSet::new();
186+
};
187+
let pending_obligations = self.fulfillment_cx.borrow_mut().pending_obligations();
188+
debug!("calculate_fallback_to_f32: pending_obligations={:?}", pending_obligations);
189+
let roots: UnordSet<ty::FloatVid> = pending_obligations
190+
.into_iter()
191+
.filter_map(|obligation| {
192+
// The predicates we are looking for look like
193+
// `TraitPredicate(<f32 as std::convert::From<{float}>>, polarity:Positive)`.
194+
// They will have no bound variables.
195+
obligation.predicate.kind().no_bound_vars()
196+
})
197+
.filter_map(|predicate| match predicate {
198+
ty::PredicateKind::Clause(ClauseKind::Trait(TraitPredicate {
199+
polarity: PredicatePolarity::Positive,
200+
trait_ref,
201+
})) if trait_ref.def_id == from_trait
202+
&& self.shallow_resolve(trait_ref.self_ty()).kind()
203+
== &ty::Float(ty::FloatTy::F32) =>
204+
{
205+
self.root_float_vid(trait_ref.args.type_at(1))
206+
}
207+
_ => None,
208+
})
209+
.collect();
210+
debug!("calculate_fallback_to_f32: roots={:?}", roots);
211+
if roots.is_empty() {
212+
// Most functions have no `f32: From<{float}>` predicates, so short-circuit and return
213+
// an empty set when this is the case.
214+
return UnordSet::new();
215+
}
216+
let fallback_to_f32 = unresolved_variables
217+
.iter()
218+
.flat_map(|ty| ty.float_vid())
219+
.filter(|vid| roots.contains(&self.root_float_var(*vid)))
220+
.collect();
221+
debug!("calculate_fallback_to_f32: fallback_to_f32={:?}", fallback_to_f32);
222+
fallback_to_f32
223+
}
224+
163225
/// The "diverging fallback" system is rather complicated. This is
164226
/// a result of our need to balance 'do the right thing' with
165227
/// backwards compatibility.
@@ -565,6 +627,11 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
565627
Some(self.root_var(self.shallow_resolve(ty).ty_vid()?))
566628
}
567629

630+
/// If `ty` is an unresolved float type variable, returns its root vid.
631+
fn root_float_vid(&self, ty: Ty<'tcx>) -> Option<ty::FloatVid> {
632+
Some(self.root_float_var(self.shallow_resolve(ty).float_vid()?))
633+
}
634+
568635
/// Given a set of diverging vids and coercions, walk the HIR to gather a
569636
/// set of suggestions which can be applied to preserve fallback to unit.
570637
fn try_to_suggest_annotations(

compiler/rustc_infer/src/infer/mod.rs

+4
Original file line numberDiff line numberDiff line change
@@ -1061,6 +1061,10 @@ impl<'tcx> InferCtxt<'tcx> {
10611061
self.inner.borrow_mut().type_variables().root_var(var)
10621062
}
10631063

1064+
pub fn root_float_var(&self, var: ty::FloatVid) -> ty::FloatVid {
1065+
self.inner.borrow_mut().float_unification_table().find(var)
1066+
}
1067+
10641068
pub fn root_const_var(&self, var: ty::ConstVid) -> ty::ConstVid {
10651069
self.inner.borrow_mut().const_unification_table().find(var).vid
10661070
}

compiler/rustc_middle/src/ty/sty.rs

+8
Original file line numberDiff line numberDiff line change
@@ -1124,6 +1124,14 @@ impl<'tcx> Ty<'tcx> {
11241124
}
11251125
}
11261126

1127+
#[inline]
1128+
pub fn float_vid(self) -> Option<ty::FloatVid> {
1129+
match self.kind() {
1130+
&Infer(FloatVar(vid)) => Some(vid),
1131+
_ => None,
1132+
}
1133+
}
1134+
11271135
#[inline]
11281136
pub fn is_ty_or_numeric_infer(self) -> bool {
11291137
matches!(self.kind(), Infer(_))

library/core/src/convert/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,7 @@ pub trait Into<T>: Sized {
573573
/// [`from`]: From::from
574574
/// [book]: ../../book/ch09-00-error-handling.html
575575
#[rustc_diagnostic_item = "From"]
576+
#[cfg_attr(not(bootstrap), lang = "From")]
576577
#[stable(feature = "rust1", since = "1.0.0")]
577578
#[rustc_on_unimplemented(on(
578579
all(_Self = "&str", T = "alloc::string::String"),

tests/ui/float/f32-into-f32.rs

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
//@ run-pass
2+
3+
fn foo(_: impl Into<f32>) {}
4+
5+
fn main() {
6+
foo(1.0);
7+
}

tests/ui/float/trait-f16-or-f32.rs

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
//@ check-fail
2+
3+
#![feature(f16)]
4+
5+
trait Trait {}
6+
impl Trait for f16 {}
7+
impl Trait for f32 {}
8+
9+
fn foo(_: impl Trait) {}
10+
11+
fn main() {
12+
foo(1.0); //~ ERROR the trait bound `f64: Trait` is not satisfied
13+
}
+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
error[E0277]: the trait bound `f64: Trait` is not satisfied
2+
--> $DIR/trait-f16-or-f32.rs:12:9
3+
|
4+
LL | foo(1.0);
5+
| --- ^^^ the trait `Trait` is not implemented for `f64`
6+
| |
7+
| required by a bound introduced by this call
8+
|
9+
= help: the following other types implement trait `Trait`:
10+
f16
11+
f32
12+
note: required by a bound in `foo`
13+
--> $DIR/trait-f16-or-f32.rs:9:16
14+
|
15+
LL | fn foo(_: impl Trait) {}
16+
| ^^^^^ required by this bound in `foo`
17+
18+
error: aborting due to 1 previous error
19+
20+
For more information about this error, try `rustc --explain E0277`.

0 commit comments

Comments
 (0)