Skip to content

Cache deref chain #7213

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions crates/cairo-lang-semantic/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -926,6 +926,11 @@ pub trait SemanticGroup:
impl_type_def_id: ImplTypeDefId,
) -> Maybe<GenericParamsData>;

/// Returns the deref chain and diagnostics for a given type.
#[salsa::invoke(items::imp::deref_chain)]
#[salsa::cycle(items::imp::deref_chain_cycle)]
fn deref_chain(&self, ty: TypeId, try_deref_mut: bool) -> Maybe<items::imp::DerefChain>;

// Impl type.
// ================
/// Returns the implized impl type if the impl is concrete. Returns a TypeId that's not an impl
Expand Down
91 changes: 20 additions & 71 deletions crates/cairo-lang-semantic/src/expr/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ use cairo_lang_syntax::node::kind::SyntaxKind;
use cairo_lang_syntax::node::{Terminal, TypedStablePtr, TypedSyntaxNode, ast};
use cairo_lang_utils as utils;
use cairo_lang_utils::ordered_hash_map::{Entry, OrderedHashMap};
use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
use cairo_lang_utils::{Intern, LookupIntern, OptionHelper, extract_matches, try_extract_matches};
Expand All @@ -51,9 +50,9 @@ use super::pattern::{
PatternOtherwise, PatternTuple, PatternVariable,
};
use crate::corelib::{
CoreTraitContext, core_binary_operator, core_bool_ty, core_unary_operator, deref_mut_trait,
deref_trait, false_literal_expr, get_core_trait, get_usize_ty, never_ty, numeric_literal_trait,
true_literal_expr, try_get_core_ty_by_name, unit_expr, unit_ty, unwrap_error_propagation_type,
CoreTraitContext, core_binary_operator, core_bool_ty, core_unary_operator, false_literal_expr,
get_core_trait, get_usize_ty, never_ty, numeric_literal_trait, true_literal_expr,
try_get_core_ty_by_name, unit_expr, unit_ty, unwrap_error_propagation_type,
};
use crate::db::SemanticGroup;
use crate::diagnostic::SemanticDiagnosticKind::{self, *};
Expand Down Expand Up @@ -3061,7 +3060,7 @@ fn get_enriched_type_member_access(
match e.get_member(accessed_member_name) {
Some(value) => return Ok(Some(value)),
None => {
if e.exploration_tail.is_none() {
if e.deref_chain.len() == e.explored_derefs {
// There's no further exploration to be done, and member was not found.
return Ok(None);
}
Expand All @@ -3071,7 +3070,7 @@ fn get_enriched_type_member_access(
entry.swap_remove()
}
Entry::Vacant(_) => {
let (_, long_ty) = peel_snapshots(ctx.db, ty);
let (_, long_ty) = finalized_snapshot_peeled_ty(ctx, ty, stable_ptr)?;
let members =
if let TypeLongId::Concrete(ConcreteTypeId::Struct(concrete_struct_id)) = long_ty {
let members = ctx.db.concrete_struct_members(concrete_struct_id)?;
Expand All @@ -3086,10 +3085,15 @@ fn get_enriched_type_member_access(
} else {
Default::default()
};
EnrichedMembers { members, deref_functions: vec![], exploration_tail: Some(expr.id) }

EnrichedMembers {
members,
deref_chain: ctx.db.deref_chain(ty, is_mut_var)?.derefs,
explored_derefs: 0,
}
}
};
enrich_members(ctx, &mut enriched_members, is_mut_var, stable_ptr, accessed_member_name)?;
enrich_members(ctx, &mut enriched_members, stable_ptr, accessed_member_name)?;
let e = ctx.resolver.type_enriched_members.entry(key).or_insert(enriched_members);
Ok(e.get_member(accessed_member_name))
}
Expand All @@ -3101,84 +3105,29 @@ fn get_enriched_type_member_access(
fn enrich_members(
ctx: &mut ComputationContext<'_>,
enriched_members: &mut EnrichedMembers,
is_mut_var: bool,
stable_ptr: ast::ExprPtr,
accessed_member_name: &str,
) -> Maybe<()> {
let EnrichedMembers { members: enriched, deref_functions, exploration_tail } = enriched_members;
let mut visited_types: OrderedHashSet<TypeId> = OrderedHashSet::default();

let expr_id =
exploration_tail.expect("`enrich_members` should be called with a `calc_tail` value.");
let mut expr = ExprAndId { expr: ctx.arenas.exprs[expr_id].clone(), id: expr_id };

let deref_mut_trait_id = deref_mut_trait(ctx.db);
let deref_trait_id = deref_trait(ctx.db);

let compute_deref_method_function_call_data =
|ctx: &mut ComputationContext<'_>, expr: ExprAndId, use_deref_mut: bool| {
let deref_trait = if use_deref_mut { deref_mut_trait_id } else { deref_trait_id };
compute_method_function_call_data(
ctx,
&[deref_trait],
if use_deref_mut { "deref_mut".into() } else { "deref".into() },
expr.clone(),
stable_ptr.0,
None,
|_, _, _| None,
|_, _, _| None,
)
};

// If the variable is mutable, and implements DerefMut, we use DerefMut in the first iteration.
let mut use_deref_mut = deref_functions.is_empty()
&& is_mut_var
&& compute_deref_method_function_call_data(ctx, expr.clone(), true).is_ok();

// This function either finds a member and sets `exploration_tail` or finishes the exploration
// and leaves that exploration tail as `None`.
*exploration_tail = None;
let EnrichedMembers { members: enriched, deref_chain, explored_derefs } = enriched_members;

// Add members of derefed types.
while let Ok((function_id, _, cur_expr, mutability)) =
compute_deref_method_function_call_data(ctx, expr, use_deref_mut)
{
deref_functions.push((function_id, mutability));
use_deref_mut = false;
let n_deref = deref_functions.len();
expr = cur_expr;
let derefed_expr = expr_function_call(
ctx,
function_id,
vec![NamedArg(expr, None, mutability)],
stable_ptr,
stable_ptr,
)?;
let ty = ctx.reduce_ty(derefed_expr.ty());
let (_, long_ty) = finalized_snapshot_peeled_ty(ctx, ty, stable_ptr)?;
// If the type is still a variable we stop looking for derefed members.
if let TypeLongId::Var(_) = long_ty {
break;
}
expr = ExprAndId { expr: derefed_expr.clone(), id: ctx.arenas.exprs.alloc(derefed_expr) };
for deref_info in deref_chain.iter().skip(*explored_derefs).cloned() {
*explored_derefs += 1;
let (_, long_ty) = finalized_snapshot_peeled_ty(ctx, deref_info.target_ty, stable_ptr)?;
if let TypeLongId::Concrete(ConcreteTypeId::Struct(concrete_struct_id)) = long_ty {
let members = ctx.db.concrete_struct_members(concrete_struct_id)?;
for (member_name, member) in members.iter() {
// Insert member if there is not already a member with the same name.
enriched.entry(member_name.clone()).or_insert_with(|| (member.clone(), n_deref));
enriched
.entry(member_name.clone())
.or_insert_with(|| (member.clone(), *explored_derefs));
}
// If member is contained we can stop the calculation post the lookup.
if members.contains_key(accessed_member_name) {
// Found member, so exploration isn't done - setting up the tail.
*exploration_tail = Some(expr.id);
// Found member, so exploration isn't done.
break;
}
}
if !visited_types.insert(long_ty.intern(ctx.db)) {
// Break if we have a cycle. A diagnostic will be reported from the impl and not from
// member access.
break;
}
}
Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion crates/cairo-lang-semantic/src/expr/inference/solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ pub fn enrich_lookup_context(
}

/// Adds the defining module of the type to the lookup context.
fn enrich_lookup_context_with_ty(
pub fn enrich_lookup_context_with_ty(
db: &dyn SemanticGroup,
ty: TypeId,
lookup_context: &mut ImplLookupContext,
Expand Down
118 changes: 115 additions & 3 deletions crates/cairo-lang-semantic/src/items/imp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ use super::visibility::peek_visible_in;
use super::{TraitOrImplContext, resolve_trait_path};
use crate::corelib::{
CoreTraitContext, concrete_destruct_trait, concrete_drop_trait, copy_trait, core_crate,
deref_trait, destruct_trait, drop_trait, fn_once_trait, fn_trait, get_core_trait,
panic_destruct_trait,
deref_mut_trait, deref_trait, destruct_trait, drop_trait, fn_once_trait, fn_trait,
get_core_trait, panic_destruct_trait,
};
use crate::db::{SemanticGroup, get_resolver_data_options};
use crate::diagnostic::SemanticDiagnosticKind::{self, *};
Expand All @@ -71,7 +71,7 @@ use crate::expr::compute::{ComputationContext, ContextFunction, Environment, com
use crate::expr::inference::canonic::ResultNoErrEx;
use crate::expr::inference::conform::InferenceConform;
use crate::expr::inference::infers::InferenceEmbeddings;
use crate::expr::inference::solver::SolutionSet;
use crate::expr::inference::solver::{SolutionSet, enrich_lookup_context_with_ty};
use crate::expr::inference::{
ImplVarId, ImplVarTraitItemMappings, Inference, InferenceError, InferenceId,
};
Expand Down Expand Up @@ -768,6 +768,118 @@ pub fn impl_semantic_definition_diagnostics(
diagnostics.build()
}

/// Represents a chain of dereferences.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DerefChain {
pub derefs: Arc<[DerefInfo]>,
}

/// Represents a single steps in a deref chain.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DerefInfo {
/// The concrete `Deref::deref` or `MutDeref::mderef_mut` function.
pub function_id: FunctionId,
/// The mutability of the self argument of the deref function.
pub self_mutability: Mutability,
/// The target type of the deref function.
pub target_ty: TypeId,
}

/// Cycle handling for [crate::db::SemanticGroup::deref_chain].
pub fn deref_chain_cycle(
_db: &dyn SemanticGroup,
_cycle: &salsa::Cycle,
_ty: &TypeId,
_try_deref_mut: &bool,
) -> Maybe<DerefChain> {
// `SemanticDiagnosticKind::DerefCycle` will be reported by `deref_impl_diagnostics`.
Maybe::Err(skip_diagnostic())
}

/// Query implementation of [crate::db::SemanticGroup::deref_chain].
pub fn deref_chain(db: &dyn SemanticGroup, ty: TypeId, try_deref_mut: bool) -> Maybe<DerefChain> {
let mut opt_deref = None;
if try_deref_mut {
opt_deref = try_get_deref_func_and_target(db, ty, true)?;
}
let self_mutability = if opt_deref.is_some() {
Mutability::Reference
} else {
opt_deref = try_get_deref_func_and_target(db, ty, false)?;
Mutability::Immutable
};

let Some((function_id, target_ty)) = opt_deref else {
return Ok(DerefChain { derefs: Arc::new([]) });
};

let inner_chain = db.deref_chain(target_ty, false)?;

Ok(DerefChain {
derefs: chain!(
[DerefInfo { function_id, target_ty, self_mutability }],
inner_chain.derefs.iter().cloned()
)
.collect(),
})
}

/// Tries to find the deref function and the target type for a given type and deref trait.
fn try_get_deref_func_and_target(
db: &dyn SemanticGroup,
ty: TypeId,
is_mut_deref: bool,
) -> Result<Option<(FunctionId, TypeId)>, DiagnosticAdded> {
let (deref_trait_id, deref_method) = if is_mut_deref {
(deref_mut_trait(db), "deref_mut".into())
} else {
(deref_trait(db), "deref".into())
};

let defs_db = db.upcast();
let mut lookup_context =
ImplLookupContext::new(deref_trait_id.module_file_id(defs_db).0, vec![]);
enrich_lookup_context_with_ty(db, ty, &mut lookup_context);
let concrete_trait = ConcreteTraitLongId {
trait_id: deref_trait_id,
generic_args: vec![GenericArgumentId::Type(ty)],
}
.intern(db);
let Ok(deref_impl) = get_impl_at_context(db, lookup_context, concrete_trait, None) else {
return Ok(None);
};
let concrete_impl_id = match deref_impl.lookup_intern(db) {
ImplLongId::Concrete(concrete_impl_id) => concrete_impl_id,
_ => panic!("Expected concrete impl"),
};

let deref_trait_func =
db.trait_function_by_name(deref_trait_id, deref_method).unwrap().unwrap();
let function_id = FunctionLongId {
function: ConcreteFunction {
generic_function: GenericFunctionId::Impl(ImplGenericFunctionId {
impl_id: deref_impl,
function: deref_trait_func,
}),
generic_args: vec![],
},
}
.intern(db);

let data = db.priv_impl_definition_data(concrete_impl_id.impl_def_id(db)).unwrap();
let mut types_iter = data.item_type_asts.iter();
let (impl_item_type_id, _) = types_iter.next().unwrap();
if types_iter.next().is_some() {
panic!(
"get_impl_based_on_single_impl_type called with an impl that has more than one type"
);
}
let ty = db.impl_type_def_resolved_type(*impl_item_type_id).unwrap();
let ty = concrete_impl_id.substitution(db)?.substitute(db, ty).unwrap();

Ok(Some((function_id, ty)))
}

/// Reports diagnostic for a deref impl.
fn deref_impl_diagnostics(
db: &dyn SemanticGroup,
Expand Down
23 changes: 13 additions & 10 deletions crates/cairo-lang-semantic/src/resolve/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::iter::Peekable;
use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;

use cairo_lang_defs::ids::{
GenericKind, GenericParamId, GenericTypeId, ImplDefId, LanguageElementId, LookupItemId,
Expand Down Expand Up @@ -44,7 +45,7 @@ use crate::items::feature_kind::{FeatureConfig, FeatureKind, extract_feature_con
use crate::items::functions::{GenericFunctionId, ImplGenericFunctionId};
use crate::items::generics::generic_params_to_args;
use crate::items::imp::{
ConcreteImplId, ConcreteImplLongId, ImplImplId, ImplLongId, ImplLookupContext,
ConcreteImplId, ConcreteImplLongId, DerefInfo, ImplImplId, ImplLongId, ImplLookupContext,
};
use crate::items::module::ModuleItemInfo;
use crate::items::trt::{
Expand All @@ -55,7 +56,7 @@ use crate::items::{TraitOrImplContext, visibility};
use crate::substitution::{GenericSubstitution, SemanticRewriter};
use crate::types::{ConcreteEnumLongId, ImplTypeId, are_coupons_enabled, resolve_type};
use crate::{
ConcreteFunction, ConcreteTypeId, ConcreteVariant, ExprId, FunctionId, FunctionLongId,
ConcreteFunction, ConcreteTypeId, ConcreteVariant, FunctionId, FunctionLongId,
GenericArgumentId, GenericParam, Member, Mutability, TypeId, TypeLongId,
};

Expand Down Expand Up @@ -118,21 +119,23 @@ pub struct EnrichedMembers {
/// A map from member names to their semantic representation and the number of deref operations
/// needed to access them.
pub members: OrderedHashMap<SmolStr, (Member, usize)>,
/// The sequence of deref functions needed to access the members.
pub deref_functions: Vec<(FunctionId, Mutability)>,
/// The tail of deref chain explored so far. The search for additional members will continue
/// from this point.
/// Useful for partial computation of enriching members where a member was already previously
/// found.
pub exploration_tail: Option<ExprId>,
/// The sequence of deref needed to access the members.
pub deref_chain: Arc<[DerefInfo]>,
// The number of derefs that were explored.
pub explored_derefs: usize,
}
impl EnrichedMembers {
/// Returns `EnrichedTypeMemberAccess` for a single member if exists.
pub fn get_member(&self, name: &str) -> Option<EnrichedTypeMemberAccess> {
let (member, n_derefs) = self.members.get(name)?;
Some(EnrichedTypeMemberAccess {
member: member.clone(),
deref_functions: self.deref_functions[..*n_derefs].to_vec(),
deref_functions: self
.deref_chain
.iter()
.map(|deref_info| (deref_info.function_id, deref_info.self_mutability))
.take(*n_derefs)
.collect(),
})
}
}
Expand Down