From 85a6453fcc07a2af88bd2963ecc954e46bab7236 Mon Sep 17 00:00:00 2001 From: Nickolai Zeldovich Date: Tue, 25 Feb 2025 13:34:55 -0800 Subject: [PATCH] function call param + type param substitution needed for namespace subset checking at function calls --- source/vir/src/ast_to_sst.rs | 95 ++++++++++++++++++++++++++---------- 1 file changed, 68 insertions(+), 27 deletions(-) diff --git a/source/vir/src/ast_to_sst.rs b/source/vir/src/ast_to_sst.rs index 96af79077..46a3e75a8 100644 --- a/source/vir/src/ast_to_sst.rs +++ b/source/vir/src/ast_to_sst.rs @@ -1,6 +1,6 @@ use crate::ast::{ ArithOp, AssertQueryMode, AutospecUsage, BinaryOp, BitwiseOp, CallTarget, ComputeMode, - Constant, Expr, ExprX, FieldOpr, Fun, Function, Ident, IntRange, InvAtomicity, + Constant, Expr, ExprX, FieldOpr, Fun, Function, FunctionKind, Ident, IntRange, InvAtomicity, LoopInvariantKind, MaskSpec, Mode, PatternX, SpannedTyped, Stmt, StmtX, Typ, TypX, Typs, UnaryOp, UnaryOpr, VarAt, VarBinder, VarBinderX, VarBinders, VarIdent, VarIdentDisambiguate, VariantCheck, VirErr, @@ -15,7 +15,7 @@ use crate::sst::{ Bnd, BndX, CallFun, Dest, Exp, ExpX, Exps, InternalFun, LocalDecl, LocalDeclKind, LocalDeclX, ParPurpose, Pars, Stm, StmX, UniqueIdent, }; -use crate::sst_util::{sst_bitwidth, sst_conjoin, sst_int_literal, sst_le, sst_lt, sst_unit_value}; +use crate::sst_util::{sst_bitwidth, sst_conjoin, sst_int_literal, sst_le, sst_lt, sst_unit_value, subst_exp}; use crate::sst_visitor::{map_exp_visitor, map_stm_exp_visitor}; use crate::util::vec_map_result; use crate::visitor::VisitorControlFlow; @@ -831,23 +831,63 @@ fn is_small_exp_or_loc(exp: &Exp) -> bool { } } -fn mask_set_for_call(ctx: &Ctx, state: &State, function: &Function) -> MaskSet { - let mask_spec = function.x.mask_spec_or_default(); - let mut inv_exps = vec![]; - match &mask_spec { - MaskSpec::InvariantOpens(es) | MaskSpec::InvariantOpensExcept(es) => { - for e in es.iter() { - let pars = - crate::ast_to_sst_func::params_to_pre_post_pars(&function.x.params, true); - let exp = expr_to_exp_skip_checks(ctx, state.diagnostics, &pars, e).unwrap(); - inv_exps.push((e.span.clone(), exp)); +fn mask_set_for_call(ctx: &Ctx, state: &State, function: &Function, typs: &Typs, args: &Vec) -> MaskSet { + let (trait_typ_substs, req_ens_function) = + if let FunctionKind::TraitMethodImpl { method, trait_path, trait_typ_args, .. } = + &function.x.kind + { + // Inherit opens_invariants from trait method declaration + let tr = &ctx.trait_map[trait_path]; + let mut typ_params = vec![crate::def::trait_self_type_param()]; + for (x, _) in tr.x.typ_params.iter() { + typ_params.push(x.clone()); } - } + let mut trait_typ_substs: HashMap = HashMap::new(); + assert!(typ_params.len() == trait_typ_args.len()); + for (x, t) in typ_params.iter().zip(trait_typ_args.iter()) { + trait_typ_substs.insert(x.clone(), t.clone()); + } + (trait_typ_substs, &ctx.func_map[method]) + } else { + (HashMap::new(), function) + }; + + let mut typ_substs = trait_typ_substs; + assert!(req_ens_function.x.typ_params.len() == typs.len()); + for (n, typ_param) in req_ens_function.x.typ_params.iter().enumerate() { + let typ = &typs[n]; + typ_substs.insert(typ_param.clone(), typ.clone()); + } + + let pars = crate::ast_to_sst_func::params_to_pars(&req_ens_function.x.params, true); + assert!(req_ens_function.x.params.len() == args.len()); + let mut param_substs = HashMap::::new(); + for (n, param) in req_ens_function.x.params.iter().enumerate() { + let arg = state.finalize_exp(ctx, &args[n]).unwrap(); + param_substs.insert(param.x.name.clone(), arg); + } + let expr_to_exp = |e| { + let exp = expr_to_exp_skip_checks(ctx, state.diagnostics, &pars, e).unwrap(); + let exp = state.finalize_exp(ctx, &exp).unwrap(); + let exp = subst_exp(&typ_substs, ¶m_substs, &exp); + exp }; + + let mask_spec = req_ens_function.x.mask_spec_or_default(); match &mask_spec { - MaskSpec::InvariantOpens(_exprs) => MaskSet::from_list(&inv_exps, &function.span), - MaskSpec::InvariantOpensExcept(_exprs) => { - MaskSet::from_list_complement(&inv_exps, &function.span) + MaskSpec::InvariantOpens(es) => { + let mut inv_exps = vec![]; + for e in es.iter() { + inv_exps.push((e.span.clone(), expr_to_exp(e))); + }; + MaskSet::from_list(&inv_exps, &req_ens_function.span) + } + MaskSpec::InvariantOpensExcept(es) => { + let mut inv_exps = vec![]; + for e in es.iter() { + inv_exps.push((e.span.clone(), expr_to_exp(e))); + }; + MaskSet::from_list_complement(&inv_exps, &req_ens_function.span) } } } @@ -879,21 +919,11 @@ fn stm_call( stms.push(init_var(&arg.span, &temp_id, arg)); } } - let call = StmX::Call { - fun: name, - resolved_method, - mode: fun.x.mode, - typ_args: typs, - args: Arc::new(small_args), - split: None, - dest, - assert_id: state.next_assert_id(), - }; if !state.checking_recommends(ctx) { match &state.mask { Some(caller_mask) => { - let callee_mask = mask_set_for_call(ctx, state, &fun); + let callee_mask = mask_set_for_call(ctx, state, &fun, &typs, &small_args); for assertion in callee_mask.subset_of(ctx, caller_mask, span) { stms.push(Spanned::new( span.clone(), @@ -905,6 +935,17 @@ fn stm_call( } } + let call = StmX::Call { + fun: name, + resolved_method, + mode: fun.x.mode, + typ_args: typs, + args: Arc::new(small_args), + split: None, + dest, + assert_id: state.next_assert_id(), + }; + stms.push(Spanned::new(span.clone(), call)); Ok(stms_to_one_stm(span, stms)) }