Skip to content

Commit a7fa4cb

Browse files
Implement projection and shim for AFIDT
1 parent 3b05779 commit a7fa4cb

File tree

17 files changed

+490
-22
lines changed

17 files changed

+490
-22
lines changed

compiler/rustc_middle/src/ty/instance.rs

+20-17
Original file line numberDiff line numberDiff line change
@@ -677,23 +677,26 @@ impl<'tcx> Instance<'tcx> {
677677
//
678678
// 1) The underlying method expects a caller location parameter
679679
// in the ABI
680-
if resolved.def.requires_caller_location(tcx)
681-
// 2) The caller location parameter comes from having `#[track_caller]`
682-
// on the implementation, and *not* on the trait method.
683-
&& !tcx.should_inherit_track_caller(def)
684-
// If the method implementation comes from the trait definition itself
685-
// (e.g. `trait Foo { #[track_caller] my_fn() { /* impl */ } }`),
686-
// then we don't need to generate a shim. This check is needed because
687-
// `should_inherit_track_caller` returns `false` if our method
688-
// implementation comes from the trait block, and not an impl block
689-
&& !matches!(
690-
tcx.opt_associated_item(def),
691-
Some(ty::AssocItem {
692-
container: ty::AssocItemContainer::Trait,
693-
..
694-
})
695-
)
696-
{
680+
let needs_track_caller_shim = resolved.def.requires_caller_location(tcx)
681+
// 2) The caller location parameter comes from having `#[track_caller]`
682+
// on the implementation, and *not* on the trait method.
683+
&& !tcx.should_inherit_track_caller(def)
684+
// If the method implementation comes from the trait definition itself
685+
// (e.g. `trait Foo { #[track_caller] my_fn() { /* impl */ } }`),
686+
// then we don't need to generate a shim. This check is needed because
687+
// `should_inherit_track_caller` returns `false` if our method
688+
// implementation comes from the trait block, and not an impl block
689+
&& !matches!(
690+
tcx.opt_associated_item(def),
691+
Some(ty::AssocItem {
692+
container: ty::AssocItemContainer::Trait,
693+
..
694+
})
695+
);
696+
// We also need to generate a shim if this is an AFIT.
697+
let needs_rpitit_shim =
698+
tcx.return_position_impl_trait_in_trait_shim_data(def).is_some();
699+
if needs_track_caller_shim || needs_rpitit_shim {
697700
if tcx.is_closure_like(def) {
698701
debug!(
699702
" => vtable fn pointer created for closure with #[track_caller]: {:?} for method {:?} {:?}",

compiler/rustc_middle/src/ty/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ mod opaque_types;
146146
mod parameterized;
147147
mod predicate;
148148
mod region;
149+
mod return_position_impl_trait_in_trait;
149150
mod rvalue_scopes;
150151
mod structural_impls;
151152
#[allow(hidden_glob_reexports)]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
use rustc_hir::def_id::DefId;
2+
3+
use crate::ty::{self, ExistentialPredicateStableCmpExt, TyCtxt};
4+
5+
impl<'tcx> TyCtxt<'tcx> {
6+
/// Given a `def_id` of a trait or impl method, compute whether that method needs to
7+
/// have an RPITIT shim applied to it for it to be object safe. If so, return the
8+
/// `def_id` of the RPITIT, and also the args of trait method that returns the RPITIT.
9+
///
10+
/// NOTE that these args are not, in general, the same as than the RPITIT's args. They
11+
/// are a subset of those args, since they do not include the late-bound lifetimes of
12+
/// the RPITIT. Depending on the context, these will need to be dealt with in different
13+
/// ways -- in codegen, it's okay to fill them with ReErased.
14+
pub fn return_position_impl_trait_in_trait_shim_data(
15+
self,
16+
def_id: DefId,
17+
) -> Option<(DefId, ty::EarlyBinder<'tcx, ty::GenericArgsRef<'tcx>>)> {
18+
let assoc_item = self.opt_associated_item(def_id)?;
19+
20+
let (trait_item_def_id, opt_impl_def_id) = match assoc_item.container {
21+
ty::AssocItemContainer::Impl => {
22+
(assoc_item.trait_item_def_id?, Some(self.parent(def_id)))
23+
}
24+
ty::AssocItemContainer::Trait => (def_id, None),
25+
};
26+
27+
let sig = self.fn_sig(trait_item_def_id);
28+
29+
// Check if the trait returns an RPITIT.
30+
let ty::Alias(ty::Projection, ty::AliasTy { def_id, .. }) =
31+
*sig.skip_binder().skip_binder().output().kind()
32+
else {
33+
return None;
34+
};
35+
if !self.is_impl_trait_in_trait(def_id) {
36+
return None;
37+
}
38+
39+
let args = if let Some(impl_def_id) = opt_impl_def_id {
40+
// Rebase the args from the RPITIT onto the impl trait ref, so we can later
41+
// substitute them with the method args of the *impl* method, since that's
42+
// the instance we're building a vtable shim for.
43+
ty::GenericArgs::identity_for_item(self, trait_item_def_id).rebase_onto(
44+
self,
45+
self.parent(trait_item_def_id),
46+
self.impl_trait_ref(impl_def_id)
47+
.expect("expected impl trait ref from parent of impl item")
48+
.instantiate_identity()
49+
.args,
50+
)
51+
} else {
52+
// This is when we have a default trait implementation.
53+
ty::GenericArgs::identity_for_item(self, trait_item_def_id)
54+
};
55+
56+
Some((def_id, ty::EarlyBinder::bind(args)))
57+
}
58+
59+
/// Given a `DefId` of an RPITIT and its args, return the existential predicates
60+
/// that corresponds to the RPITIT's bounds with the self type erased.
61+
pub fn item_bounds_to_existential_predicates(
62+
self,
63+
def_id: DefId,
64+
args: ty::GenericArgsRef<'tcx>,
65+
) -> &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>> {
66+
let mut bounds: Vec<_> = self
67+
.item_super_predicates(def_id)
68+
.iter_instantiated(self, args)
69+
.filter_map(|clause| {
70+
clause
71+
.kind()
72+
.map_bound(|clause| match clause {
73+
ty::ClauseKind::Trait(trait_pred) => Some(ty::ExistentialPredicate::Trait(
74+
ty::ExistentialTraitRef::erase_self_ty(self, trait_pred.trait_ref),
75+
)),
76+
ty::ClauseKind::Projection(projection_pred) => {
77+
Some(ty::ExistentialPredicate::Projection(
78+
ty::ExistentialProjection::erase_self_ty(self, projection_pred),
79+
))
80+
}
81+
ty::ClauseKind::TypeOutlives(_) => {
82+
// Type outlives bounds don't really turn into anything,
83+
// since we must use an intersection region for the `dyn*`'s
84+
// region anyways.
85+
None
86+
}
87+
_ => unreachable!("unexpected clause in item bounds: {clause:?}"),
88+
})
89+
.transpose()
90+
})
91+
.collect();
92+
bounds.sort_by(|a, b| a.skip_binder().stable_cmp(self, &b.skip_binder()));
93+
self.mk_poly_existential_predicates(&bounds)
94+
}
95+
}

compiler/rustc_mir_transform/src/shim.rs

+53-3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use rustc_index::{Idx, IndexVec};
99
use rustc_middle::mir::patch::MirPatch;
1010
use rustc_middle::mir::*;
1111
use rustc_middle::query::Providers;
12+
use rustc_middle::ty::adjustment::PointerCoercion;
1213
use rustc_middle::ty::{
1314
self, CoroutineArgs, CoroutineArgsExt, EarlyBinder, GenericArgs, Ty, TyCtxt,
1415
};
@@ -710,6 +711,13 @@ fn build_call_shim<'tcx>(
710711
};
711712

712713
let def_id = instance.def_id();
714+
715+
let rpitit_shim = if let ty::InstanceKind::ReifyShim(..) = instance {
716+
tcx.return_position_impl_trait_in_trait_shim_data(def_id)
717+
} else {
718+
None
719+
};
720+
713721
let sig = tcx.fn_sig(def_id);
714722
let sig = sig.map_bound(|sig| tcx.instantiate_bound_regions_with_erased(sig));
715723

@@ -765,9 +773,34 @@ fn build_call_shim<'tcx>(
765773
let mut local_decls = local_decls_for_sig(&sig, span);
766774
let source_info = SourceInfo::outermost(span);
767775

776+
let mut destination = Place::return_place();
777+
if let Some((rpitit_def_id, fn_args)) = rpitit_shim {
778+
let rpitit_args =
779+
fn_args.instantiate_identity().extend_to(tcx, rpitit_def_id, |param, _| {
780+
match param.kind {
781+
ty::GenericParamDefKind::Lifetime => tcx.lifetimes.re_erased.into(),
782+
ty::GenericParamDefKind::Type { .. }
783+
| ty::GenericParamDefKind::Const { .. } => {
784+
unreachable!("rpitit should have no addition ty/ct")
785+
}
786+
}
787+
});
788+
let dyn_star_ty = Ty::new_dynamic(
789+
tcx,
790+
tcx.item_bounds_to_existential_predicates(rpitit_def_id, rpitit_args),
791+
tcx.lifetimes.re_erased,
792+
ty::DynStar,
793+
);
794+
destination = local_decls.push(local_decls[RETURN_PLACE].clone()).into();
795+
local_decls[RETURN_PLACE].ty = dyn_star_ty;
796+
let mut inputs_and_output = sig.inputs_and_output.to_vec();
797+
*inputs_and_output.last_mut().unwrap() = dyn_star_ty;
798+
sig.inputs_and_output = tcx.mk_type_list(&inputs_and_output);
799+
}
800+
768801
let rcvr_place = || {
769802
assert!(rcvr_adjustment.is_some());
770-
Place::from(Local::new(1 + 0))
803+
Place::from(Local::new(1))
771804
};
772805
let mut statements = vec![];
773806

@@ -854,7 +887,7 @@ fn build_call_shim<'tcx>(
854887
TerminatorKind::Call {
855888
func: callee,
856889
args,
857-
destination: Place::return_place(),
890+
destination,
858891
target: Some(BasicBlock::new(1)),
859892
unwind: if let Some(Adjustment::RefMut) = rcvr_adjustment {
860893
UnwindAction::Cleanup(BasicBlock::new(3))
@@ -882,7 +915,24 @@ fn build_call_shim<'tcx>(
882915
);
883916
}
884917
// BB #1/#2 - return
885-
block(&mut blocks, vec![], TerminatorKind::Return, false);
918+
// NOTE: If this is an RPITIT in dyn, we also want to coerce
919+
// the return type of the function into a `dyn*`.
920+
let stmts = if rpitit_shim.is_some() {
921+
vec![Statement {
922+
source_info,
923+
kind: StatementKind::Assign(Box::new((
924+
Place::return_place(),
925+
Rvalue::Cast(
926+
CastKind::PointerCoercion(PointerCoercion::DynStar, CoercionSource::Implicit),
927+
Operand::Move(destination),
928+
sig.output(),
929+
),
930+
))),
931+
}]
932+
} else {
933+
vec![]
934+
};
935+
block(&mut blocks, stmts, TerminatorKind::Return, false);
886936
if let Some(Adjustment::RefMut) = rcvr_adjustment {
887937
// BB #3 - drop if closure panics
888938
block(

compiler/rustc_monomorphize/src/lib.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,10 @@ fn custom_coerce_unsize_info<'tcx>(
4242
..
4343
})) => Ok(tcx.coerce_unsized_info(impl_def_id)?.custom_kind.unwrap()),
4444
impl_source => {
45-
bug!("invalid `CoerceUnsized` impl_source: {:?}", impl_source);
45+
bug!(
46+
"invalid `CoerceUnsized` from {source_ty} to {target_ty}: impl_source: {:?}",
47+
impl_source
48+
);
4649
}
4750
}
4851
}

compiler/rustc_trait_selection/src/traits/project.rs

+57-1
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ use rustc_data_structures::stack::ensure_sufficient_stack;
77
use rustc_errors::ErrorGuaranteed;
88
use rustc_hir::def::DefKind;
99
use rustc_hir::lang_items::LangItem;
10-
use rustc_infer::infer::DefineOpaqueTypes;
1110
use rustc_infer::infer::resolve::OpportunisticRegionResolver;
11+
use rustc_infer::infer::{DefineOpaqueTypes, RegionVariableOrigin};
1212
use rustc_infer::traits::{ObligationCauseCode, PredicateObligations};
1313
use rustc_middle::traits::select::OverflowError;
1414
use rustc_middle::traits::{BuiltinImplSource, ImplSource, ImplSourceUserDefinedData};
@@ -18,6 +18,7 @@ use rustc_middle::ty::visit::TypeVisitableExt;
1818
use rustc_middle::ty::{self, Term, Ty, TyCtxt, TypingMode, Upcast};
1919
use rustc_middle::{bug, span_bug};
2020
use rustc_span::symbol::sym;
21+
use thin_vec::thin_vec;
2122
use tracing::{debug, instrument};
2223

2324
use super::{
@@ -61,6 +62,9 @@ enum ProjectionCandidate<'tcx> {
6162
/// Bounds specified on an object type
6263
Object(ty::PolyProjectionPredicate<'tcx>),
6364

65+
/// Built-in bound for a dyn async fn in trait
66+
ObjectRpitit,
67+
6468
/// From an "impl" (or a "pseudo-impl" returned by select)
6569
Select(Selection<'tcx>),
6670
}
@@ -827,6 +831,17 @@ fn assemble_candidates_from_object_ty<'cx, 'tcx>(
827831
env_predicates,
828832
false,
829833
);
834+
835+
// `dyn Trait` automagically project their AFITs to `dyn* Future`.
836+
if tcx.is_impl_trait_in_trait(obligation.predicate.def_id)
837+
&& let Some(out_trait_def_id) = data.principal_def_id()
838+
&& let rpitit_trait_def_id = tcx.parent(obligation.predicate.def_id)
839+
&& tcx
840+
.supertrait_def_ids(out_trait_def_id)
841+
.any(|trait_def_id| trait_def_id == rpitit_trait_def_id)
842+
{
843+
candidate_set.push_candidate(ProjectionCandidate::ObjectRpitit);
844+
}
830845
}
831846

832847
#[instrument(
@@ -1247,6 +1262,8 @@ fn confirm_candidate<'cx, 'tcx>(
12471262
ProjectionCandidate::Select(impl_source) => {
12481263
confirm_select_candidate(selcx, obligation, impl_source)
12491264
}
1265+
1266+
ProjectionCandidate::ObjectRpitit => confirm_object_rpitit_candidate(selcx, obligation),
12501267
};
12511268

12521269
// When checking for cycle during evaluation, we compare predicates with
@@ -2034,6 +2051,45 @@ fn confirm_impl_candidate<'cx, 'tcx>(
20342051
}
20352052
}
20362053

2054+
fn confirm_object_rpitit_candidate<'cx, 'tcx>(
2055+
selcx: &mut SelectionContext<'cx, 'tcx>,
2056+
obligation: &ProjectionTermObligation<'tcx>,
2057+
) -> Progress<'tcx> {
2058+
let tcx = selcx.tcx();
2059+
let mut obligations = thin_vec![];
2060+
2061+
// Compute an intersection lifetime for all the input components of this GAT.
2062+
let intersection =
2063+
selcx.infcx.next_region_var(RegionVariableOrigin::MiscVariable(obligation.cause.span));
2064+
for component in obligation.predicate.args {
2065+
match component.unpack() {
2066+
ty::GenericArgKind::Lifetime(lt) => {
2067+
obligations.push(obligation.with(tcx, ty::OutlivesPredicate(lt, intersection)));
2068+
}
2069+
ty::GenericArgKind::Type(ty) => {
2070+
obligations.push(obligation.with(tcx, ty::OutlivesPredicate(ty, intersection)));
2071+
}
2072+
ty::GenericArgKind::Const(_ct) => {
2073+
// Consts have no outlives...
2074+
}
2075+
}
2076+
}
2077+
2078+
Progress {
2079+
term: Ty::new_dynamic(
2080+
tcx,
2081+
tcx.item_bounds_to_existential_predicates(
2082+
obligation.predicate.def_id,
2083+
obligation.predicate.args,
2084+
),
2085+
intersection,
2086+
ty::DynStar,
2087+
)
2088+
.into(),
2089+
obligations,
2090+
}
2091+
}
2092+
20372093
// Get obligations corresponding to the predicates from the where-clause of the
20382094
// associated type itself.
20392095
fn assoc_ty_own_obligations<'cx, 'tcx>(

0 commit comments

Comments
 (0)