Skip to content

Commit 9600e7d

Browse files
Implement projection and shim for AFIDT
1 parent bffa448 commit 9600e7d

File tree

14 files changed

+417
-22
lines changed

14 files changed

+417
-22
lines changed

compiler/rustc_middle/src/ty/instance.rs

+20-17
Original file line numberDiff line numberDiff line change
@@ -678,23 +678,26 @@ impl<'tcx> Instance<'tcx> {
678678
//
679679
// 1) The underlying method expects a caller location parameter
680680
// in the ABI
681-
if resolved.def.requires_caller_location(tcx)
682-
// 2) The caller location parameter comes from having `#[track_caller]`
683-
// on the implementation, and *not* on the trait method.
684-
&& !tcx.should_inherit_track_caller(def)
685-
// If the method implementation comes from the trait definition itself
686-
// (e.g. `trait Foo { #[track_caller] my_fn() { /* impl */ } }`),
687-
// then we don't need to generate a shim. This check is needed because
688-
// `should_inherit_track_caller` returns `false` if our method
689-
// implementation comes from the trait block, and not an impl block
690-
&& !matches!(
691-
tcx.opt_associated_item(def),
692-
Some(ty::AssocItem {
693-
container: ty::AssocItemContainer::Trait,
694-
..
695-
})
696-
)
697-
{
681+
let needs_track_caller_shim = resolved.def.requires_caller_location(tcx)
682+
// 2) The caller location parameter comes from having `#[track_caller]`
683+
// on the implementation, and *not* on the trait method.
684+
&& !tcx.should_inherit_track_caller(def)
685+
// If the method implementation comes from the trait definition itself
686+
// (e.g. `trait Foo { #[track_caller] my_fn() { /* impl */ } }`),
687+
// then we don't need to generate a shim. This check is needed because
688+
// `should_inherit_track_caller` returns `false` if our method
689+
// implementation comes from the trait block, and not an impl block
690+
&& !matches!(
691+
tcx.opt_associated_item(def),
692+
Some(ty::AssocItem {
693+
container: ty::AssocItemContainer::Trait,
694+
..
695+
})
696+
);
697+
// We also need to generate a shim if this is an AFIT.
698+
let needs_rpitit_shim =
699+
tcx.return_position_impl_trait_in_trait_shim_data(def).is_some();
700+
if needs_track_caller_shim || needs_rpitit_shim {
698701
if tcx.is_closure_like(def) {
699702
debug!(
700703
" => vtable fn pointer created for closure with #[track_caller]: {:?} for method {:?} {:?}",

compiler/rustc_middle/src/ty/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ mod opaque_types;
148148
mod parameterized;
149149
mod predicate;
150150
mod region;
151+
mod return_position_impl_trait_in_trait;
151152
mod rvalue_scopes;
152153
mod structural_impls;
153154
#[allow(hidden_glob_reexports)]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
use rustc_hir::def_id::DefId;
2+
3+
use crate::ty::{self, ExistentialPredicateStableCmpExt, TyCtxt};
4+
5+
impl<'tcx> TyCtxt<'tcx> {
6+
/// Given a `def_id` of a trait or impl method, compute whether that method needs to
7+
/// have an RPITIT shim applied to it for it to be object safe. If so, return the
8+
/// `def_id` of the RPITIT, and also the args of trait method that returns the RPITIT.
9+
///
10+
/// NOTE that these args are not, in general, the same as than the RPITIT's args. They
11+
/// are a subset of those args, since they do not include the late-bound lifetimes of
12+
/// the RPITIT. Depending on the context, these will need to be dealt with in different
13+
/// ways -- in codegen, it's okay to fill them with ReErased.
14+
pub fn return_position_impl_trait_in_trait_shim_data(
15+
self,
16+
def_id: DefId,
17+
) -> Option<(DefId, ty::EarlyBinder<'tcx, ty::GenericArgsRef<'tcx>>)> {
18+
let assoc_item = self.opt_associated_item(def_id)?;
19+
20+
let (trait_item_def_id, opt_impl_def_id) = match assoc_item.container {
21+
ty::AssocItemContainer::Impl => {
22+
(assoc_item.trait_item_def_id?, Some(self.parent(def_id)))
23+
}
24+
ty::AssocItemContainer::Trait => (def_id, None),
25+
};
26+
27+
let sig = self.fn_sig(trait_item_def_id);
28+
29+
let ty::Alias(ty::Projection, alias_ty) = *sig.skip_binder().skip_binder().output().kind()
30+
else {
31+
return None;
32+
};
33+
34+
if !self.is_impl_trait_in_trait(alias_ty.def_id) {
35+
return None;
36+
}
37+
38+
let args = if let Some(impl_def_id) = opt_impl_def_id {
39+
// Rebase the args from the RPITIT onto the impl trait ref, so we can later
40+
// substitute them with the method args of the *impl* method, since that's
41+
// the instance we're building a vtable shim for.
42+
ty::GenericArgs::identity_for_item(self, trait_item_def_id).rebase_onto(
43+
self,
44+
self.parent(trait_item_def_id),
45+
self.impl_trait_ref(impl_def_id)
46+
.expect("expected impl trait ref from parent of impl item")
47+
.instantiate_identity()
48+
.args,
49+
)
50+
} else {
51+
// This is when we have a default trait implementation.
52+
ty::GenericArgs::identity_for_item(self, trait_item_def_id)
53+
};
54+
55+
Some((alias_ty.def_id, ty::EarlyBinder::bind(args)))
56+
}
57+
58+
/// Given a `DefId` of an RPITIT and its args, return the existential predicates
59+
/// that corresponds to the RPITIT's bounds with the self type erased.
60+
pub fn item_bounds_to_existential_predicates(
61+
self,
62+
def_id: DefId,
63+
args: ty::GenericArgsRef<'tcx>,
64+
) -> &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>> {
65+
let mut bounds: Vec<_> = self
66+
.item_super_predicates(def_id)
67+
.iter_instantiated(self, args)
68+
.filter_map(|clause| {
69+
clause
70+
.kind()
71+
.map_bound(|clause| match clause {
72+
ty::ClauseKind::Trait(trait_pred) => Some(ty::ExistentialPredicate::Trait(
73+
ty::ExistentialTraitRef::erase_self_ty(self, trait_pred.trait_ref),
74+
)),
75+
ty::ClauseKind::Projection(projection_pred) => {
76+
Some(ty::ExistentialPredicate::Projection(
77+
ty::ExistentialProjection::erase_self_ty(self, projection_pred),
78+
))
79+
}
80+
_ => None,
81+
})
82+
.transpose()
83+
})
84+
.collect();
85+
bounds.sort_by(|a, b| a.skip_binder().stable_cmp(self, &b.skip_binder()));
86+
self.mk_poly_existential_predicates(&bounds)
87+
}
88+
}

compiler/rustc_mir_transform/src/shim.rs

+53-3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use rustc_index::{Idx, IndexVec};
99
use rustc_middle::mir::patch::MirPatch;
1010
use rustc_middle::mir::*;
1111
use rustc_middle::query::Providers;
12+
use rustc_middle::ty::adjustment::PointerCoercion;
1213
use rustc_middle::ty::{
1314
self, CoroutineArgs, CoroutineArgsExt, EarlyBinder, GenericArgs, Ty, TyCtxt,
1415
};
@@ -710,6 +711,13 @@ fn build_call_shim<'tcx>(
710711
};
711712

712713
let def_id = instance.def_id();
714+
715+
let rpitit_shim = if let ty::InstanceKind::ReifyShim(..) = instance {
716+
tcx.return_position_impl_trait_in_trait_shim_data(def_id)
717+
} else {
718+
None
719+
};
720+
713721
let sig = tcx.fn_sig(def_id);
714722
let sig = sig.map_bound(|sig| tcx.instantiate_bound_regions_with_erased(sig));
715723

@@ -765,9 +773,34 @@ fn build_call_shim<'tcx>(
765773
let mut local_decls = local_decls_for_sig(&sig, span);
766774
let source_info = SourceInfo::outermost(span);
767775

776+
let mut destination = Place::return_place();
777+
if let Some((rpitit_def_id, fn_args)) = rpitit_shim {
778+
let rpitit_args =
779+
fn_args.instantiate_identity().extend_to(tcx, rpitit_def_id, |param, _| {
780+
match param.kind {
781+
ty::GenericParamDefKind::Lifetime => tcx.lifetimes.re_erased.into(),
782+
ty::GenericParamDefKind::Type { .. }
783+
| ty::GenericParamDefKind::Const { .. } => {
784+
unreachable!("rpitit should have no addition ty/ct")
785+
}
786+
}
787+
});
788+
let dyn_star_ty = Ty::new_dynamic(
789+
tcx,
790+
tcx.item_bounds_to_existential_predicates(rpitit_def_id, rpitit_args),
791+
tcx.lifetimes.re_erased,
792+
ty::DynStar,
793+
);
794+
destination = local_decls.push(local_decls[RETURN_PLACE].clone()).into();
795+
local_decls[RETURN_PLACE].ty = dyn_star_ty;
796+
let mut inputs_and_output = sig.inputs_and_output.to_vec();
797+
*inputs_and_output.last_mut().unwrap() = dyn_star_ty;
798+
sig.inputs_and_output = tcx.mk_type_list(&inputs_and_output);
799+
}
800+
768801
let rcvr_place = || {
769802
assert!(rcvr_adjustment.is_some());
770-
Place::from(Local::new(1 + 0))
803+
Place::from(Local::new(1))
771804
};
772805
let mut statements = vec![];
773806

@@ -854,7 +887,7 @@ fn build_call_shim<'tcx>(
854887
TerminatorKind::Call {
855888
func: callee,
856889
args,
857-
destination: Place::return_place(),
890+
destination,
858891
target: Some(BasicBlock::new(1)),
859892
unwind: if let Some(Adjustment::RefMut) = rcvr_adjustment {
860893
UnwindAction::Cleanup(BasicBlock::new(3))
@@ -882,7 +915,24 @@ fn build_call_shim<'tcx>(
882915
);
883916
}
884917
// BB #1/#2 - return
885-
block(&mut blocks, vec![], TerminatorKind::Return, false);
918+
// NOTE: If this is an RPITIT in dyn, we also want to coerce
919+
// the return type of the function into a `dyn*`.
920+
let stmts = if rpitit_shim.is_some() {
921+
vec![Statement {
922+
source_info,
923+
kind: StatementKind::Assign(Box::new((
924+
Place::return_place(),
925+
Rvalue::Cast(
926+
CastKind::PointerCoercion(PointerCoercion::DynStar, CoercionSource::Implicit),
927+
Operand::Move(destination),
928+
sig.output(),
929+
),
930+
))),
931+
}]
932+
} else {
933+
vec![]
934+
};
935+
block(&mut blocks, stmts, TerminatorKind::Return, false);
886936
if let Some(Adjustment::RefMut) = rcvr_adjustment {
887937
// BB #3 - drop if closure panics
888938
block(

compiler/rustc_monomorphize/src/lib.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,10 @@ fn custom_coerce_unsize_info<'tcx>(
4040
..
4141
})) => Ok(tcx.coerce_unsized_info(impl_def_id)?.custom_kind.unwrap()),
4242
impl_source => {
43-
bug!("invalid `CoerceUnsized` impl_source: {:?}", impl_source);
43+
bug!(
44+
"invalid `CoerceUnsized` from {source_ty} to {target_ty}: impl_source: {:?}",
45+
impl_source
46+
);
4447
}
4548
}
4649
}

compiler/rustc_trait_selection/src/traits/project.rs

+57-1
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ use rustc_data_structures::stack::ensure_sufficient_stack;
77
use rustc_errors::ErrorGuaranteed;
88
use rustc_hir::def::DefKind;
99
use rustc_hir::lang_items::LangItem;
10-
use rustc_infer::infer::DefineOpaqueTypes;
1110
use rustc_infer::infer::resolve::OpportunisticRegionResolver;
11+
use rustc_infer::infer::{DefineOpaqueTypes, RegionVariableOrigin};
1212
use rustc_infer::traits::{ObligationCauseCode, PredicateObligations};
1313
pub use rustc_middle::traits::Reveal;
1414
use rustc_middle::traits::select::OverflowError;
@@ -19,6 +19,7 @@ use rustc_middle::ty::visit::{MaxUniverse, TypeVisitable, TypeVisitableExt};
1919
use rustc_middle::ty::{self, Term, Ty, TyCtxt, TypingMode, Upcast};
2020
use rustc_middle::{bug, span_bug};
2121
use rustc_span::symbol::sym;
22+
use thin_vec::thin_vec;
2223
use tracing::{debug, instrument};
2324

2425
use super::{
@@ -62,6 +63,9 @@ enum ProjectionCandidate<'tcx> {
6263
/// Bounds specified on an object type
6364
Object(ty::PolyProjectionPredicate<'tcx>),
6465

66+
/// Built-in bound for a dyn async fn in trait
67+
ObjectRpitit,
68+
6569
/// From an "impl" (or a "pseudo-impl" returned by select)
6670
Select(Selection<'tcx>),
6771
}
@@ -852,6 +856,17 @@ fn assemble_candidates_from_object_ty<'cx, 'tcx>(
852856
env_predicates,
853857
false,
854858
);
859+
860+
// `dyn Trait` automagically project their AFITs to `dyn* Future`.
861+
if tcx.is_impl_trait_in_trait(obligation.predicate.def_id)
862+
&& let Some(out_trait_def_id) = data.principal_def_id()
863+
&& let rpitit_trait_def_id = tcx.parent(obligation.predicate.def_id)
864+
&& tcx
865+
.supertrait_def_ids(out_trait_def_id)
866+
.any(|trait_def_id| trait_def_id == rpitit_trait_def_id)
867+
{
868+
candidate_set.push_candidate(ProjectionCandidate::ObjectRpitit);
869+
}
855870
}
856871

857872
#[instrument(
@@ -1270,6 +1285,8 @@ fn confirm_candidate<'cx, 'tcx>(
12701285
ProjectionCandidate::Select(impl_source) => {
12711286
confirm_select_candidate(selcx, obligation, impl_source)
12721287
}
1288+
1289+
ProjectionCandidate::ObjectRpitit => confirm_object_rpitit_candidate(selcx, obligation),
12731290
};
12741291

12751292
// When checking for cycle during evaluation, we compare predicates with
@@ -2057,6 +2074,45 @@ fn confirm_impl_candidate<'cx, 'tcx>(
20572074
}
20582075
}
20592076

2077+
fn confirm_object_rpitit_candidate<'cx, 'tcx>(
2078+
selcx: &mut SelectionContext<'cx, 'tcx>,
2079+
obligation: &ProjectionTermObligation<'tcx>,
2080+
) -> Progress<'tcx> {
2081+
let tcx = selcx.tcx();
2082+
let mut obligations = thin_vec![];
2083+
2084+
// Compute an intersection lifetime for all the input components of this GAT.
2085+
let intersection =
2086+
selcx.infcx.next_region_var(RegionVariableOrigin::MiscVariable(obligation.cause.span));
2087+
for component in obligation.predicate.args {
2088+
match component.unpack() {
2089+
ty::GenericArgKind::Lifetime(lt) => {
2090+
obligations.push(obligation.with(tcx, ty::OutlivesPredicate(lt, intersection)));
2091+
}
2092+
ty::GenericArgKind::Type(ty) => {
2093+
obligations.push(obligation.with(tcx, ty::OutlivesPredicate(ty, intersection)));
2094+
}
2095+
ty::GenericArgKind::Const(_ct) => {
2096+
// Consts have no outlives...
2097+
}
2098+
}
2099+
}
2100+
2101+
Progress {
2102+
term: Ty::new_dynamic(
2103+
tcx,
2104+
tcx.item_bounds_to_existential_predicates(
2105+
obligation.predicate.def_id,
2106+
obligation.predicate.args,
2107+
),
2108+
intersection,
2109+
ty::DynStar,
2110+
)
2111+
.into(),
2112+
obligations,
2113+
}
2114+
}
2115+
20602116
// Get obligations corresponding to the predicates from the where-clause of the
20612117
// associated type itself.
20622118
fn assoc_ty_own_obligations<'cx, 'tcx>(

0 commit comments

Comments
 (0)