Skip to content

Commit

Permalink
Fix issue with impl selection
Browse files Browse the repository at this point in the history
  • Loading branch information
makspll committed Apr 3, 2024
1 parent 9961a9e commit 3ff539d
Show file tree
Hide file tree
Showing 13 changed files with 17,398 additions and 14,318 deletions.
2 changes: 1 addition & 1 deletion .cargo/config.toml
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
[env]
TARGET_DIR={ value = "target", relative = true }
TARGET_DIR = { value = "target", relative = true }
2 changes: 2 additions & 0 deletions crates/bevy_api_gen/config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[rust]
debug-logging = false
24 changes: 16 additions & 8 deletions crates/bevy_api_gen/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ impl<'tcx> BevyCtxt<'tcx> {
#[derive(Clone, Default, Debug)]
pub(crate) struct ReflectType<'tcx> {
/// Map from traits to their implementations for the reflect type (from a selection)
pub(crate) trait_impls: Option<HashMap<DefId, DefId>>,
pub(crate) trait_impls: Option<HashMap<DefId, Vec<DefId>>>,
/// Information about the ADT structure, fields, and variants
pub(crate) variant_data: Option<AdtDef<'tcx>>,
/// Functions passing criteria to be proxied
Expand Down Expand Up @@ -91,7 +91,7 @@ pub(crate) const DEF_PATHS_GET_TYPE_REGISTRATION: [&str; 2] = [

/// A collection of traits which we search for in the codebase, some of these are necessary to figure out if a type
/// is Clone or Debug for the purposes of the macro code generation
pub(crate) const FN_SOURCE_TRAITS: [&str; 13] = [
pub(crate) const STD_SOURCE_TRAITS: [&str; 13] = [
// PRINTING
"std::fmt::Debug",
"std::string::ToString",
Expand All @@ -106,8 +106,8 @@ pub(crate) const FN_SOURCE_TRAITS: [&str; 13] = [
"std::ops::Rem",
"std::cmp::Eq",
"std::cmp::PartialEq",
"std::ord::Ord", // we don't use these fully cuz of the output types not being lua primitives, but keeping it for the future
"std::ord::PartialOrd",
"std::cmp::Ord", // we don't use these fully cuz of the output types not being lua primitives, but keeping it for the future
"std::cmp::PartialOrd",
];

/// A collection of common traits stored for quick access.
Expand All @@ -118,7 +118,7 @@ pub(crate) struct CachedTraits {
pub(crate) bevy_reflect_reflect: Option<DefId>,
pub(crate) bevy_reflect_get_type_registration: Option<DefId>,
/// Traits whose methods can be included in the generated code
pub(crate) fn_source_traits: HashMap<String, DefId>,
pub(crate) std_source_traits: HashMap<String, DefId>,
}

impl CachedTraits {
Expand All @@ -130,10 +130,18 @@ impl CachedTraits {
self.bevy_reflect_reflect.is_some() && self.bevy_reflect_get_type_registration.is_some()
}

pub(crate) fn has_all_fn_source_traits(&self) -> bool {
self.fn_source_traits
pub(crate) fn has_all_std_source_traits(&self) -> bool {
STD_SOURCE_TRAITS
.iter()
.all(|(k, _)| FN_SOURCE_TRAITS.contains(&k.as_str()))
.all(|t| self.std_source_traits.contains_key(*t))
}

pub(crate) fn missing_std_source_traits(&self) -> Vec<String> {
STD_SOURCE_TRAITS
.iter()
.filter(|t| !self.std_source_traits.contains_key(**t))
.map(|s| (*s).to_owned())
.collect()
}
}

Expand Down
33 changes: 26 additions & 7 deletions crates/bevy_api_gen/src/passes/cache_traits.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use log::trace;
use rustc_hir::def_id::LOCAL_CRATE;
use rustc_span::Symbol;

use crate::{
Args, BevyCtxt, DEF_PATHS_FROM_LUA, DEF_PATHS_GET_TYPE_REGISTRATION, DEF_PATHS_INTO_LUA,
DEF_PATHS_REFLECT, FN_SOURCE_TRAITS,
DEF_PATHS_REFLECT, STD_SOURCE_TRAITS,
};

/// Finds and caches relevant traits, if they cannot be found throws an ICE
Expand All @@ -25,15 +26,15 @@ pub(crate) fn cache_traits(ctxt: &mut BevyCtxt<'_>, _args: &Args) -> bool {
} else if DEF_PATHS_GET_TYPE_REGISTRATION.contains(&def_path_str.as_str()) {
trace!("found GetTypeRegistration trait def id: {trait_did:?}");
ctxt.cached_traits.bevy_reflect_get_type_registration = Some(trait_did);
} else if FN_SOURCE_TRAITS.contains(&def_path_str.as_str()) {
} else if STD_SOURCE_TRAITS.contains(&def_path_str.as_str()) {
trace!("found misc trait def id: {trait_did:?}");
ctxt.cached_traits
.fn_source_traits
.std_source_traits
.insert(def_path_str.to_string(), trait_did);
} else if FN_SOURCE_TRAITS.contains(&def_path_str.as_str()) {
} else if STD_SOURCE_TRAITS.contains(&def_path_str.as_str()) {
trace!("found misc trait def id: {trait_did:?}");
ctxt.cached_traits
.fn_source_traits
.std_source_traits
.insert(def_path_str.to_string(), trait_did);
}
}
Expand All @@ -52,9 +53,27 @@ pub(crate) fn cache_traits(ctxt: &mut BevyCtxt<'_>, _args: &Args) -> bool {
)
}

if !ctxt.cached_traits.has_all_fn_source_traits() {
// some crates specifically do not have std in scope via `#![no_std]` which means we do not care about these traits
let has_std = tcx
.get_attrs_by_path(LOCAL_CRATE.as_def_id(), &[Symbol::intern("no_std")])
.map(|_| ())
.next()
.is_none();

log::trace!("has_std: {}", has_std);

if has_std && !ctxt.cached_traits.has_all_std_source_traits() {
log::debug!(
"all traits: {}",
tcx.all_traits()
.map(|t| tcx.def_path_str(t).to_string())
.collect::<Vec<_>>()
.join(", ")
);

panic!(
"Could not find all fn source traits in crate: {}, did bootstrapping go wrong?",
"Could not find traits: [{}] in crate: {}, did bootstrapping go wrong?",
ctxt.cached_traits.missing_std_source_traits().join(", "),
tcx.crate_name(LOCAL_CRATE)
)
}
Expand Down
4 changes: 1 addition & 3 deletions crates/bevy_api_gen/src/passes/find_methods_and_fields.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ pub(crate) fn find_methods_and_fields(ctxt: &mut BevyCtxt<'_>, _args: &Args) ->
.inherent_impls(def_id)
.unwrap()
.iter()
.chain(trait_impls_for_ty.iter())
.chain(trait_impls_for_ty.iter().flatten())
.collect::<Vec<_>>();

// sort them to avoid unnecessary diffs, we can use hashes here as they are forever stable (touch wood)
Expand Down Expand Up @@ -126,8 +126,6 @@ pub(crate) fn find_methods_and_fields(ctxt: &mut BevyCtxt<'_>, _args: &Args) ->
if unstability.is_unstable() {
log::debug!("Skipping unstable function: `{}` on type: `{}` feature: {:?}", ctxt.tcx.item_name(fn_did), ctxt.tcx.item_name(def_id), unstability.feature.as_str());
return None;
} else {
log::debug!("Allowing possibly unstable function: `{}` on type: `{}`, stability: {:?}", ctxt.tcx.item_name(fn_did), ctxt.tcx.item_name(def_id), unstability)
}
};

Expand Down
105 changes: 85 additions & 20 deletions crates/bevy_api_gen/src/passes/find_trait_impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,16 @@ use std::collections::HashMap;

use log::trace;
use rustc_hir::def_id::DefId;
use rustc_infer::infer::TyCtxtInferExt;
use rustc_trait_selection::infer::InferCtxtExt;
use rustc_infer::{
infer::{BoundRegionConversionTime, DefineOpaqueTypes, InferCtxt, TyCtxtInferExt},
traits::{Obligation, ObligationCause, PolyTraitObligation},
};
use rustc_middle::ty::{Binder, ClauseKind, ImplPolarity, PolyTraitPredicate, TraitPredicate, Ty};
use rustc_span::DUMMY_SP;
use rustc_trait_selection::{
infer::InferCtxtExt,
traits::{elaborate, ObligationCtxt},
};

use crate::{Args, BevyCtxt};

Expand Down Expand Up @@ -31,13 +39,13 @@ pub(crate) fn find_trait_impls(ctxt: &mut BevyCtxt<'_>, _args: &Args) -> bool {
ctxt.cached_traits.mlua_from_lua_multi.unwrap(),
reflect_ty_did,
)
.is_none()
.is_empty()
|| type_impl_of_trait(
tcx,
ctxt.cached_traits.mlua_into_lua_multi.unwrap(),
reflect_ty_did,
)
.is_none();
.is_empty();

if !retaining {
trace!(
Expand All @@ -48,13 +56,23 @@ pub(crate) fn find_trait_impls(ctxt: &mut BevyCtxt<'_>, _args: &Args) -> bool {
retaining
});

log::trace!(
"Looking for impls of the traits: [{}]",
ctxt.cached_traits
.std_source_traits
.values()
.map(|d| tcx.def_path_str(*d))
.collect::<Vec<_>>()
.join(", ")
);

for (reflect_ty_did, type_ctxt) in ctxt.reflect_types.iter_mut() {
let mut impls = Vec::default();

for trait_did in ctxt.cached_traits.fn_source_traits.values() {
let impl_ = type_impl_of_trait(tcx, *trait_did, reflect_ty_did);
if let Some(impl_did) = impl_ {
impls.push((*trait_did, impl_did));
for trait_did in ctxt.cached_traits.std_source_traits.values() {
let matching_impls = type_impl_of_trait(tcx, *trait_did, reflect_ty_did);
if !matching_impls.is_empty() {
impls.push((*trait_did, matching_impls));
}
}

Expand All @@ -64,17 +82,19 @@ pub(crate) fn find_trait_impls(ctxt: &mut BevyCtxt<'_>, _args: &Args) -> bool {
true
}

/// Checks if a type implements a trait, returns all implementations with the generic args required
fn type_impl_of_trait(
tcx: &rustc_middle::ty::TyCtxt<'_>,
trait_did: DefId,
reflect_ty_did: &rustc_hir::def_id::DefId,
) -> Option<DefId> {
) -> Vec<DefId> {
log::trace!(
"Finding impl for trait: {:?} on type: {:?}",
tcx.def_path_str(trait_did),
tcx.def_path_str(*reflect_ty_did)
);
let mut out = None;
let mut out = Vec::default();

tcx.for_each_relevant_impl(
trait_did,
tcx.type_of(reflect_ty_did).instantiate_identity(),
Expand All @@ -83,20 +103,65 @@ fn type_impl_of_trait(
"Possible impl for trait: {:?} on type: {:?} found: {:?}",
tcx.def_path_str(trait_did),
tcx.def_path_str(reflect_ty_did),
impl_did
impl_did,
);
//TODO: false negatives coming from this inference

let ty = tcx.type_of(reflect_ty_did).instantiate_identity();
let param_env = tcx.param_env(impl_did);
let applies = tcx
.infer_ctxt()
.build()
.type_implements_trait(trait_did, [ty], param_env)
.must_apply_modulo_regions();
if applies {
trace!("Applies with: {param_env:?}, type: {ty}",);
out = Some(impl_did);
let infcx = tcx.infer_ctxt().build();
let result = impl_matches(&infcx, ty, trait_did, impl_did);
log::trace!("Result: {:#?}", result);
if result {
trace!(
"Type: `{}` implements trait: `{}`",
ty,
tcx.item_name(trait_did)
);
out.push(impl_did)
} else {
trace!(
"Type: `{}` does not implement trait: `{}`",
ty,
tcx.item_name(trait_did)
);
}
},
);
out
}

/// this is the same logic as in rustc_trait_selection::...::recompute_applicable_impls, i.e. we need to go through all
/// impls that may match and perform full on matching on them
/// If this goes out of date with rustc, we can just copy the function here
fn impl_matches<'tcx>(
infcx: &InferCtxt<'tcx>,
ty: Ty<'tcx>,
trait_def_id: DefId,
impl_def_id: DefId,
) -> bool {
let tcx = infcx.tcx;

let impl_may_apply = |impl_def_id| {
let ocx = ObligationCtxt::new(infcx);
let param_env = tcx.param_env_reveal_all_normalized(impl_def_id);
let impl_args = infcx.fresh_args_for_item(DUMMY_SP, impl_def_id);
let impl_trait_ref = tcx
.impl_trait_ref(impl_def_id)
.unwrap()
.instantiate(tcx, impl_args);
let impl_trait_ref = ocx.normalize(&ObligationCause::dummy(), param_env, impl_trait_ref);
let impl_trait_ref_ty = impl_trait_ref.self_ty();
if let Err(_) = ocx.eq(&ObligationCause::dummy(), param_env, impl_trait_ref_ty, ty) {
return false;
}

let impl_predicates = tcx.predicates_of(impl_def_id).instantiate(tcx, impl_args);
ocx.register_obligations(impl_predicates.predicates.iter().map(|&predicate| {
Obligation::new(tcx, ObligationCause::dummy(), param_env, predicate)
}));

ocx.select_where_possible().is_empty()
};

infcx.probe(|_| impl_may_apply(impl_def_id))
}
Loading

0 comments on commit 3ff539d

Please sign in to comment.