Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

When needed, add type parameters and variant check to field access #1411

Merged
merged 2 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 49 additions & 13 deletions source/vir/src/datatype_to_air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,16 @@ fn field_to_par(span: &Span, f: &Field) -> Par {
)
}

// For soundness, FieldOpr needs type arguments when there are >= 2 variants
// (see https://github.com/verus-lang/verus/issues/1366 )
fn has_field_typ_args(num_variants: usize) -> bool {
num_variants >= 2
}

pub(crate) fn field_typ_args<A: Default>(num_variants: usize, f: impl Fn() -> A) -> A {
if has_field_typ_args(num_variants) { f() } else { A::default() }
}

fn uses_ext_equal(ctx: &Ctx, typ: &Typ) -> bool {
match &**typ {
TypX::Int(_) => false,
Expand Down Expand Up @@ -298,6 +308,8 @@ fn datatype_or_fun_to_air_commands(
}

// constructor and field axioms
let tparams_opt = field_typ_args(variants.len(), || tparams.clone());
let typ_args_opt = field_typ_args(variants.len(), || typ_args.clone());
for variant in variants.iter() {
if let EncodedDtKind::Dt(dt) = &kind {
if ctx.datatypes_with_invariant.contains(dt) {
Expand Down Expand Up @@ -330,30 +342,45 @@ fn datatype_or_fun_to_air_commands(
}
}
for field in variant.fields.iter() {
let mut xfield_params: Vec<air::ast::Typ> = Vec::new();
let mut xfield_args: Vec<air::ast::Expr> = Vec::new();
let mut xfield_unbox_args: Vec<air::ast::Expr> = Vec::new();
for _ in tparams_opt.iter() {
xfield_params.extend(crate::def::types().iter().map(|s| str_typ(s)));
}
for t in typ_args_opt.iter() {
xfield_args.extend(crate::sst_to_air::typ_to_ids(t));
xfield_unbox_args.extend(crate::sst_to_air::typ_to_ids(t));
}
xfield_params.push(dtyp.clone());
xfield_args.push(x_var.clone());
xfield_unbox_args.push(unbox_x.clone());
let id = variant_field_ident(dpath, &variant.name, &field.name);
let internal_id = variant_field_ident_internal(dpath, &variant.name, &field.name, true);
let (typ, _, _) = &field.a;
let xfield = ident_apply(&id, &vec![x_var.clone()]);
let xfield = ident_apply(&id, &xfield_args);
let xfield_internal = ident_apply(&internal_id, &vec![x_var.clone()]);
let xfield_unbox = ident_apply(&id, &vec![unbox_x.clone()]);
let xfield_unbox = ident_apply(&id, &xfield_unbox_args);

// Create a wrapper function to access the field,
// because it seems to be dangerous to trigger directly on e.f,
// because Z3 seems to introduce e.f internally,
// which can unexpectedly trigger matching loops creating e.f.f.f.f...
// function f(x:datatyp):typ
// axiom forall x. f(x) = x.f
let decl_field = Arc::new(DeclX::Fun(
id.clone(),
Arc::new(vec![dtyp.clone()]),
typ_to_air(ctx, typ),
));
// function get_f(x:datatyp):typ
// axiom forall x. get_f(x) = x.f
// Also, see https://github.com/verus-lang/verus/issues/1366 : for 2 or more variants,
// when there are type parameters, we need to guard the axiom with a variant check:
// axiom forall a, x. x is f's variant ==> get_f(a, x) = x.f
let decl_field =
Arc::new(DeclX::Fun(id.clone(), Arc::new(xfield_params), typ_to_air(ctx, typ)));
field_commands.push(Arc::new(CommandX::Global(decl_field)));
let trigs = vec![xfield.clone()];
let name = format!("{}_{}", id, QID_ACCESSOR);
let bind =
func_bind_trig(ctx, name, &Arc::new(vec![]), &x_params(&datatyp), &trigs, false);
let bind = func_bind_trig(ctx, name, &tparams_opt, &x_params(&datatyp), &trigs, false);
let eq = mk_eq(&xfield, &xfield_internal);
let vid = is_variant_ident(&Dt::Path(dpath.clone()), &*variant.name);
let is_variant = ident_apply(&vid, &vec![x_var.clone()]);
let eq = if tparams_opt.len() > 0 { mk_implies(&is_variant, &eq) } else { eq };
let forall = mk_bind_expr(&bind, &eq);
let axiom = mk_unnamed_axiom(forall);
axiom_commands.push(Arc::new(CommandX::Global(axiom)));
Expand Down Expand Up @@ -488,6 +515,7 @@ fn datatype_or_fun_to_air_commands(
dpath,
&field_box_path,
&is_variant_ident(my_dt, &*variant.name),
&tparams_opt,
&variant_field_ident(dpath, &variant.name, &field.name),
recursive_function_field,
);
Expand Down Expand Up @@ -565,8 +593,16 @@ fn datatype_or_fun_to_air_commands(
// to avoid trigger matching loops, use ==, not ext_equal, for recursive fields:
&& !crate::ast_visitor::typ_visitor_check(typ, &mut is_recursive).is_err();
let fid = variant_field_ident(dpath, &variant.name, &field.name);
let xfield = ident_apply(&fid, &vec![unbox_x.clone()]);
let yfield = ident_apply(&fid, &vec![unbox_y.clone()]);
let mut xfield_args: Vec<air::ast::Expr> = Vec::new();
let mut yfield_args: Vec<air::ast::Expr> = Vec::new();
for t in typ_args_opt.iter() {
xfield_args.extend(crate::sst_to_air::typ_to_ids(t));
yfield_args.extend(crate::sst_to_air::typ_to_ids(t));
}
xfield_args.push(unbox_x.clone());
yfield_args.push(unbox_y.clone());
let xfield = ident_apply(&fid, &xfield_args);
let yfield = ident_apply(&fid, &yfield_args);
let eq = if uses_ext {
let xfield = crate::sst_to_air::as_box(ctx, xfield, typ);
let yfield = crate::sst_to_air::as_box(ctx, yfield, typ);
Expand Down
34 changes: 27 additions & 7 deletions source/vir/src/prelude.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::ast::Path;
use crate::ast::{Idents, Path};
use crate::def::*;
use crate::sst_to_air::path_to_air_ident;
use air::ast::Ident;
Expand Down Expand Up @@ -833,10 +833,11 @@ pub(crate) fn strslice_functions(strslice_name: &str) -> Vec<Node> {
)
}

pub(crate) fn datatype_height_axiom(
fn datatype_height_axiom(
typ_name1: &Path,
typ_name2: &Option<Path>,
is_variant_ident: &Ident,
tparams: &Idents,
field: &Ident,
recursive_function_field: bool,
) -> Node {
Expand All @@ -849,18 +850,34 @@ pub(crate) fn datatype_height_axiom(
let is_variant = str_to_node(is_variant_ident.as_str());
let typ1 = str_to_node(path_to_air_ident(typ_name1).as_str());
let box_t1 = str_to_node(prefix_box(typ_name1).as_str());
let mut forall_params: Vec<Node> = Vec::new();
let mut field_x: Vec<Node> = Vec::new();
field_x.push(field);
for typ_param in tparams.iter() {
for (x, t) in crate::def::suffix_typ_param_ids_types(&typ_param) {
use crate::ast_util::LowerUniqueVar;
let x = str_to_node(&x.lower());
let t = str_to_node(t);
forall_params.push(node!(([x][t])));
field_x.push(x);
}
}
forall_params.push(node!((x[typ1])));
field_x.push(node!(x));
let forall_params = Node::List(forall_params);
let field_x = Node::List(field_x);
let field_of_x = match typ_name2 {
Some(typ2) => {
let box_t2 = str_to_node(prefix_box(&typ2).as_str());
node!(([box_t2] ([field] x)))
node!(([box_t2][field_x]))
}
// for a field with generic type, [field]'s return type is already "Poly"
None => node!(([field] x)),
None => field_x,
};
let field_of_x =
if recursive_function_field { node!(([height_rec_fun][field_of_x])) } else { field_of_x };
node!(
(axiom (forall ((x [typ1])) (!
(axiom (forall [forall_params] (!
(=>
([is_variant] x)
([height_lt]
Expand All @@ -879,12 +896,15 @@ pub(crate) fn datatype_height_axioms(
typ_name1: &Path,
typ_name2: &Option<Path>,
is_variant_ident: &Ident,
tparams: &Idents,
field: &Ident,
recursive_function_field: bool,
) -> Vec<Node> {
let axiom1 = datatype_height_axiom(typ_name1, typ_name2, is_variant_ident, field, false);
let axiom1 =
datatype_height_axiom(typ_name1, typ_name2, is_variant_ident, tparams, field, false);
if recursive_function_field {
let axiom2 = datatype_height_axiom(typ_name1, typ_name2, is_variant_ident, field, true);
let axiom2 =
datatype_height_axiom(typ_name1, typ_name2, is_variant_ident, tparams, field, true);
vec![axiom1, axiom2]
} else {
vec![axiom1]
Expand Down
15 changes: 14 additions & 1 deletion source/vir/src/sst_to_air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -965,9 +965,22 @@ pub(crate) fn exp_to_expr(ctx: &Ctx, exp: &Exp, expr_ctxt: &ExprCtxt) -> Result<
}
UnaryOpr::Field(FieldOpr { datatype, variant, field, get_variant: _, check: _ }) => {
let expr = exp_to_expr(ctx, exp, expr_ctxt)?;
let (ts, num_variants) = match &*undecorate_typ(&exp.typ) {
TypX::Datatype(Dt::Path(p), ts, _) => {
let (_, variants) = &ctx.global.datatypes[p];
(ts.clone(), variants.len())
}
TypX::Datatype(Dt::Tuple(_), ts, _) => (ts.clone(), 1),
_ => panic!("internal error: expected datatype in field op"),
};
let mut exprs: Vec<Expr> =
crate::datatype_to_air::field_typ_args(num_variants, || {
ts.iter().map(typ_to_ids).flatten().collect()
});
exprs.push(expr);
Arc::new(ExprX::Apply(
variant_field_ident(&encode_dt_as_path(datatype), variant, field),
Arc::new(vec![expr]),
Arc::new(exprs),
))
}
UnaryOpr::CustomErr(_) => {
Expand Down