Skip to content

Commit

Permalink
function call param + type param substitution
Browse files Browse the repository at this point in the history
needed for namespace subset checking at function calls
  • Loading branch information
zeldovich committed Feb 25, 2025
1 parent 6bf7bc1 commit 85a6453
Showing 1 changed file with 68 additions and 27 deletions.
95 changes: 68 additions & 27 deletions source/vir/src/ast_to_sst.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::ast::{
ArithOp, AssertQueryMode, AutospecUsage, BinaryOp, BitwiseOp, CallTarget, ComputeMode,
Constant, Expr, ExprX, FieldOpr, Fun, Function, Ident, IntRange, InvAtomicity,
Constant, Expr, ExprX, FieldOpr, Fun, Function, FunctionKind, Ident, IntRange, InvAtomicity,
LoopInvariantKind, MaskSpec, Mode, PatternX, SpannedTyped, Stmt, StmtX, Typ, TypX, Typs,
UnaryOp, UnaryOpr, VarAt, VarBinder, VarBinderX, VarBinders, VarIdent, VarIdentDisambiguate,
VariantCheck, VirErr,
Expand All @@ -15,7 +15,7 @@ use crate::sst::{
Bnd, BndX, CallFun, Dest, Exp, ExpX, Exps, InternalFun, LocalDecl, LocalDeclKind, LocalDeclX,
ParPurpose, Pars, Stm, StmX, UniqueIdent,
};
use crate::sst_util::{sst_bitwidth, sst_conjoin, sst_int_literal, sst_le, sst_lt, sst_unit_value};
use crate::sst_util::{sst_bitwidth, sst_conjoin, sst_int_literal, sst_le, sst_lt, sst_unit_value, subst_exp};
use crate::sst_visitor::{map_exp_visitor, map_stm_exp_visitor};
use crate::util::vec_map_result;
use crate::visitor::VisitorControlFlow;
Expand Down Expand Up @@ -831,23 +831,63 @@ fn is_small_exp_or_loc(exp: &Exp) -> bool {
}
}

fn mask_set_for_call(ctx: &Ctx, state: &State, function: &Function) -> MaskSet {
let mask_spec = function.x.mask_spec_or_default();
let mut inv_exps = vec![];
match &mask_spec {
MaskSpec::InvariantOpens(es) | MaskSpec::InvariantOpensExcept(es) => {
for e in es.iter() {
let pars =
crate::ast_to_sst_func::params_to_pre_post_pars(&function.x.params, true);
let exp = expr_to_exp_skip_checks(ctx, state.diagnostics, &pars, e).unwrap();
inv_exps.push((e.span.clone(), exp));
fn mask_set_for_call(ctx: &Ctx, state: &State, function: &Function, typs: &Typs, args: &Vec<Exp>) -> MaskSet {
let (trait_typ_substs, req_ens_function) =
if let FunctionKind::TraitMethodImpl { method, trait_path, trait_typ_args, .. } =
&function.x.kind
{
// Inherit opens_invariants from trait method declaration
let tr = &ctx.trait_map[trait_path];
let mut typ_params = vec![crate::def::trait_self_type_param()];
for (x, _) in tr.x.typ_params.iter() {
typ_params.push(x.clone());
}
}
let mut trait_typ_substs: HashMap<Ident, Typ> = HashMap::new();
assert!(typ_params.len() == trait_typ_args.len());
for (x, t) in typ_params.iter().zip(trait_typ_args.iter()) {
trait_typ_substs.insert(x.clone(), t.clone());
}
(trait_typ_substs, &ctx.func_map[method])
} else {
(HashMap::new(), function)
};

let mut typ_substs = trait_typ_substs;
assert!(req_ens_function.x.typ_params.len() == typs.len());
for (n, typ_param) in req_ens_function.x.typ_params.iter().enumerate() {
let typ = &typs[n];
typ_substs.insert(typ_param.clone(), typ.clone());
}

let pars = crate::ast_to_sst_func::params_to_pars(&req_ens_function.x.params, true);
assert!(req_ens_function.x.params.len() == args.len());
let mut param_substs = HashMap::<VarIdent, Exp>::new();
for (n, param) in req_ens_function.x.params.iter().enumerate() {
let arg = state.finalize_exp(ctx, &args[n]).unwrap();
param_substs.insert(param.x.name.clone(), arg);
}
let expr_to_exp = |e| {
let exp = expr_to_exp_skip_checks(ctx, state.diagnostics, &pars, e).unwrap();
let exp = state.finalize_exp(ctx, &exp).unwrap();
let exp = subst_exp(&typ_substs, &param_substs, &exp);
exp
};

let mask_spec = req_ens_function.x.mask_spec_or_default();
match &mask_spec {
MaskSpec::InvariantOpens(_exprs) => MaskSet::from_list(&inv_exps, &function.span),
MaskSpec::InvariantOpensExcept(_exprs) => {
MaskSet::from_list_complement(&inv_exps, &function.span)
MaskSpec::InvariantOpens(es) => {
let mut inv_exps = vec![];
for e in es.iter() {
inv_exps.push((e.span.clone(), expr_to_exp(e)));
};
MaskSet::from_list(&inv_exps, &req_ens_function.span)
}
MaskSpec::InvariantOpensExcept(es) => {
let mut inv_exps = vec![];
for e in es.iter() {
inv_exps.push((e.span.clone(), expr_to_exp(e)));
};
MaskSet::from_list_complement(&inv_exps, &req_ens_function.span)
}
}
}
Expand Down Expand Up @@ -879,21 +919,11 @@ fn stm_call(
stms.push(init_var(&arg.span, &temp_id, arg));
}
}
let call = StmX::Call {
fun: name,
resolved_method,
mode: fun.x.mode,
typ_args: typs,
args: Arc::new(small_args),
split: None,
dest,
assert_id: state.next_assert_id(),
};

if !state.checking_recommends(ctx) {
match &state.mask {
Some(caller_mask) => {
let callee_mask = mask_set_for_call(ctx, state, &fun);
let callee_mask = mask_set_for_call(ctx, state, &fun, &typs, &small_args);
for assertion in callee_mask.subset_of(ctx, caller_mask, span) {
stms.push(Spanned::new(
span.clone(),
Expand All @@ -905,6 +935,17 @@ fn stm_call(
}
}

let call = StmX::Call {
fun: name,
resolved_method,
mode: fun.x.mode,
typ_args: typs,
args: Arc::new(small_args),
split: None,
dest,
assert_id: state.next_assert_id(),
};

stms.push(Spanned::new(span.clone(), call));
Ok(stms_to_one_stm(span, stms))
}
Expand Down

0 comments on commit 85a6453

Please sign in to comment.