From a913f4e14dc8cf85eff3988ddd11bdf6a01f5e15 Mon Sep 17 00:00:00 2001 From: Travis Hance Date: Fri, 12 Jan 2024 12:55:26 -0500 Subject: [PATCH 01/11] fix type inconsistency in ast_simplify, fixes #956 --- source/rust_verify_test/tests/regression.rs | 13 ++++++ source/vir/src/ast_simplify.rs | 21 ++++++--- source/vir/src/ast_util.rs | 10 +++++ source/vir/src/check_ast_flavor.rs | 49 +++++++++++++++------ source/vir/src/context.rs | 11 +++-- source/vir/src/sst_to_air.rs | 4 +- source/vir/src/sst_util.rs | 15 ++++++- 7 files changed, 95 insertions(+), 28 deletions(-) diff --git a/source/rust_verify_test/tests/regression.rs b/source/rust_verify_test/tests/regression.rs index 53e4078d7e..967583e803 100644 --- a/source/rust_verify_test/tests/regression.rs +++ b/source/rust_verify_test/tests/regression.rs @@ -1231,3 +1231,16 @@ test_verify_one_file! { } } => Ok(()) } + +test_verify_one_file! { + #[test] struct_with_updater_and_typ_subst_in_field_issue956 verus_code! { + struct X { + a: int, + b: T, + } + + proof fn stuff(x: X) { + let y = X { a: x.a + 1, .. x }; + } + } => Ok(()) +} diff --git a/source/vir/src/ast_simplify.rs b/source/vir/src/ast_simplify.rs index b4114f0cb4..550c626843 100644 --- a/source/vir/src/ast_simplify.rs +++ b/source/vir/src/ast_simplify.rs @@ -11,7 +11,7 @@ use crate::ast::{ }; use crate::ast_util::int_range_from_type; use crate::ast_util::is_integer_type; -use crate::ast_util::{conjoin, disjoin, if_then_else, wrap_in_trigger}; +use crate::ast_util::{conjoin, disjoin, if_then_else, typ_args_for_datatype_typ, wrap_in_trigger}; use crate::ast_visitor::VisitorScopeMap; use crate::context::GlobalCtx; use crate::def::{ @@ -19,6 +19,7 @@ use crate::def::{ }; use crate::messages::error; use crate::messages::Span; +use crate::sst_util::subst_typ_for_datatype; use crate::util::vec_map_result; use air::ast::BinderX; use air::ast::Binders; @@ -330,9 +331,10 @@ fn simplify_one_expr( } decls.extend(temp_decl.into_iter()); } - let datatype = &ctx.datatypes[path]; - assert_eq!(datatype.len(), 1); - let fields = &datatype[0].a; + let (typ_positives, variants) = &ctx.datatypes[path]; + assert_eq!(variants.len(), 1); + let fields = &variants[0].a; + let typ_args = typ_args_for_datatype_typ(&expr.typ); // replace ..update // with f1: update.f1, f2: update.f2, ... for field in fields.iter() { @@ -344,7 +346,8 @@ fn simplify_one_expr( get_variant: false, }); let exprx = ExprX::UnaryOpr(op, update.clone()); - let field_exp = SpannedTyped::new(&expr.span, &field.a.0, exprx); + let ty = subst_typ_for_datatype(&typ_positives, typ_args, &field.a.0); + let field_exp = SpannedTyped::new(&expr.span, &ty, exprx); binders.push(ident_binder(&field.name, &field_exp)); } } @@ -903,8 +906,12 @@ pub fn simplify_krate(ctx: &mut GlobalCtx, krate: &Krate) -> Result bool { typs1.len() == typs2.len() && typs1.iter().zip(typs2.iter()).all(|(t1, t2)| types_equal(t1, t2)) } +pub fn typ_args_for_datatype_typ(typ: &Typ) -> &Typs { + match &**typ { + TypX::Decorate(_, t) => typ_args_for_datatype_typ(t), + TypX::Datatype(_, args, _) => args, + _ => { + panic!("typ_args_for_datatype_typ expected datatype type"); + } + } +} + pub const QUANT_FORALL: Quant = Quant { quant: air::ast::Quant::Forall }; pub fn params_equal_opt( diff --git a/source/vir/src/check_ast_flavor.rs b/source/vir/src/check_ast_flavor.rs index b0784014c7..b6bbbd0f21 100644 --- a/source/vir/src/check_ast_flavor.rs +++ b/source/vir/src/check_ast_flavor.rs @@ -1,10 +1,15 @@ -use crate::ast::{Expr, ExprX, FunctionX, GenericBoundX, Krate, KrateX, Typ, TypX, UnaryOpr}; +use crate::ast::{ + Expr, ExprX, Function, FunctionX, GenericBoundX, Idents, Krate, KrateX, MaskSpec, Typ, TypX, + UnaryOpr, +}; use crate::ast_visitor::{ expr_visitor_check, expr_visitor_dfs, typ_visitor_check, VisitorControlFlow, VisitorScopeMap, }; pub use air::ast_util::{ident_binder, str_ident}; +use std::sync::Arc; -fn check_expr_simplified(_scope_map: &VisitorScopeMap, expr: &Expr) -> Result<(), ()> { +fn check_expr_simplified(expr: &Expr, function: &Function) -> Result<(), ()> { + check_typ_simplified(&expr.typ, &function.x.typ_params)?; match expr.x { ExprX::ConstVar(..) | ExprX::UnaryOpr(UnaryOpr::TupleField { .. }, _) @@ -14,9 +19,10 @@ fn check_expr_simplified(_scope_map: &VisitorScopeMap, expr: &Expr) -> Result<() } } -fn check_typ_simplified(typ: &Typ) -> Result<(), ()> { +fn check_typ_simplified(typ: &Typ, typ_params: &Idents) -> Result<(), ()> { match &**typ { TypX::Tuple(..) => Err(()), + TypX::TypParam(id) if !typ_params.contains(id) => Err(()), _ => Ok(()), } } @@ -39,13 +45,24 @@ pub fn check_krate_simplified(krate: &Krate) { } = &**krate; for function in functions { - let FunctionX { require, ensure, decrease, body, typ_bounds, params, ret, .. } = - &function.x; + let FunctionX { + require, ensure, decrease, body, typ_bounds, params, ret, mask_spec, .. + } = &function.x; - let all_exprs = - require.iter().chain(ensure.iter()).chain(decrease.iter()).chain(body.iter()); + let mask_exprs = match mask_spec { + MaskSpec::InvariantOpens(es) => es.clone(), + MaskSpec::InvariantOpensExcept(es) => es.clone(), + MaskSpec::NoSpec => Arc::new(vec![]), + }; + + let all_exprs = require + .iter() + .chain(ensure.iter()) + .chain(decrease.iter()) + .chain(body.iter()) + .chain(mask_exprs.iter()); for expr in all_exprs { - expr_visitor_check(expr, &mut check_expr_simplified) + expr_visitor_check(expr, &mut |_, e| check_expr_simplified(e, function)) .expect("function AST expression uses node that should have been simplified"); } @@ -53,25 +70,29 @@ pub fn check_krate_simplified(krate: &Krate) { match &**bound { GenericBoundX::Trait(_, ts) => { for t in ts.iter() { - typ_visitor_check(t, &mut check_typ_simplified).expect( - "function param bound uses node that should have been simplified", - ); + typ_visitor_check(t, &mut |t| { + check_typ_simplified(t, &function.x.typ_params) + }) + .expect("function param bound uses node that should have been simplified"); } } } } for param in params.iter().chain(std::iter::once(ret)) { - typ_visitor_check(¶m.x.typ, &mut check_typ_simplified) - .expect("function param typ uses node that should have been simplified"); + typ_visitor_check(¶m.x.typ, &mut |t| { + check_typ_simplified(t, &function.x.typ_params) + }) + .expect("function param typ uses node that should have been simplified"); } } for datatype in datatypes { + let typ_params = Arc::new(datatype.x.typ_params.iter().map(|(id, _)| id.clone()).collect()); for variant in datatype.x.variants.iter() { for field in variant.a.iter() { let (typ, _, _) = &field.a; - typ_visitor_check(typ, &mut check_typ_simplified) + typ_visitor_check(typ, &mut |t| check_typ_simplified(t, &typ_params)) .expect("datatype field typ uses node that should have been simplified"); } } diff --git a/source/vir/src/context.rs b/source/vir/src/context.rs index 8c9d12c5b4..5c0cc35340 100644 --- a/source/vir/src/context.rs +++ b/source/vir/src/context.rs @@ -1,6 +1,6 @@ use crate::ast::{ ArchWordBits, Datatype, Fun, Function, GenericBounds, Ident, IntRange, Krate, Mode, Path, - Primitive, Trait, TypX, Variants, VirErr, + Primitive, Trait, TypPositives, TypX, Variants, VirErr, }; use crate::datatype_to_air::is_datatype_transparent; use crate::def::FUEL_ID; @@ -34,7 +34,7 @@ pub struct ChosenTriggers { /// Context for across all modules pub struct GlobalCtx { pub(crate) chosen_triggers: std::cell::RefCell>, // diagnostics - pub(crate) datatypes: Arc>, + pub(crate) datatypes: Arc>, pub(crate) fun_bounds: Arc>, /// Used for synthesized AST nodes that have no relation to any location in the original code: pub(crate) no_span: Span, @@ -198,8 +198,11 @@ impl GlobalCtx { let chosen_triggers: std::cell::RefCell> = std::cell::RefCell::new(Vec::new()); - let datatypes: HashMap = - krate.datatypes.iter().map(|d| (d.x.path.clone(), d.x.variants.clone())).collect(); + let datatypes: HashMap = krate + .datatypes + .iter() + .map(|d| (d.x.path.clone(), (d.x.typ_params.clone(), d.x.variants.clone()))) + .collect(); let mut func_map: HashMap = HashMap::new(); for function in krate.functions.iter() { assert!(!func_map.contains_key(&function.x.name)); diff --git a/source/vir/src/sst_to_air.rs b/source/vir/src/sst_to_air.rs index c550102bf1..bb17b33132 100644 --- a/source/vir/src/sst_to_air.rs +++ b/source/vir/src/sst_to_air.rs @@ -535,7 +535,7 @@ pub(crate) fn ctor_to_apply<'a>( variant: &Ident, binders: &'a Binders, ) -> (Ident, impl Iterator>>) { - let fields = &get_variant(&ctx.global.datatypes[path], variant).a; + let fields = &get_variant(&ctx.global.datatypes[path].1, variant).a; (variant_ident(path, &variant), fields.iter().map(move |f| get_field(binders, &f.name))) } @@ -1328,7 +1328,7 @@ fn assume_other_fields_unchanged_inner( assert!(u[0].datatype == *datatype && u[0].variant == *variant); updated_fields.entry(&u[0].field).or_insert(Vec::new()).push(u[1..].to_vec()); } - let datatype_fields = &get_variant(&ctx.global.datatypes[datatype], variant).a; + let datatype_fields = &get_variant(&ctx.global.datatypes[datatype].1, variant).a; let dt = vec_map_result(&**datatype_fields, |field: &Binder<(Typ, Mode, Visibility)>| { let base_exp = if let TypX::Boxed(base_typ) = &*undecorate_typ(&base.typ) { diff --git a/source/vir/src/sst_util.rs b/source/vir/src/sst_util.rs index b606751f1d..f2cb0afb8c 100644 --- a/source/vir/src/sst_util.rs +++ b/source/vir/src/sst_util.rs @@ -1,6 +1,6 @@ use crate::ast::{ ArithOp, BinaryOp, BitwiseOp, Constant, InequalityOp, IntRange, IntegerTypeBoundKind, Mode, - Quant, SpannedTyped, Typ, TypX, UnaryOp, UnaryOpr, + Quant, SpannedTyped, Typ, TypX, Typs, UnaryOp, UnaryOpr, }; use crate::def::{unique_bound, user_local_name, Spanned}; use crate::interpreter::InterpExp; @@ -53,6 +53,19 @@ fn subst_typ(typ_substs: &HashMap, typ: &Typ) -> Typ { .expect("subst_typ") } +pub fn subst_typ_for_datatype( + typ_params: &crate::ast::TypPositives, + args: &Typs, + typ: &Typ, +) -> Typ { + assert!(typ_params.len() == args.len()); + let mut typ_substs: HashMap = HashMap::new(); + for (typ_param, arg) in typ_params.iter().zip(args.iter()) { + typ_substs.insert(typ_param.0.clone(), arg.clone()); + } + subst_typ(&typ_substs, typ) +} + fn subst_rename_binders A, FT: Fn(&A) -> Typ>( span: &Span, substs: &mut ScopeMap, From 8fdfdb8a2ff8dbd131791b178c62e4287287b0b5 Mon Sep 17 00:00:00 2001 From: tjhance Date: Sun, 14 Jan 2024 09:43:49 -0500 Subject: [PATCH 02/11] handle numeric fields in struct syntax (#949) * handle numeric fields in struct syntax * handle all idents the same way * add comment in lifetime_emit --- source/rust_verify/src/fn_call_to_vir.rs | 9 +-- source/rust_verify/src/lifetime_ast.rs | 1 + source/rust_verify/src/lifetime_emit.rs | 6 ++ source/rust_verify/src/lifetime_generate.rs | 27 +++---- source/rust_verify/src/rust_to_vir_adts.rs | 11 +-- source/rust_verify/src/rust_to_vir_expr.rs | 24 +++---- source/rust_verify_test/tests/adts.rs | 79 +++++++++++++++++++++ source/rust_verify_test/tests/modes.rs | 23 +++++- source/vir/src/ast.rs | 8 +-- source/vir/src/def.rs | 4 ++ 10 files changed, 143 insertions(+), 49 deletions(-) diff --git a/source/rust_verify/src/fn_call_to_vir.rs b/source/rust_verify/src/fn_call_to_vir.rs index 984695a08f..04bc1a954a 100644 --- a/source/rust_verify/src/fn_call_to_vir.rs +++ b/source/rust_verify/src/fn_call_to_vir.rs @@ -33,7 +33,7 @@ use vir::ast::{ UnaryOpr, VarAt, VirErr, }; use vir::ast_util::{const_int_from_string, typ_to_diagnostic_str, types_equal, undecorate_typ}; -use vir::def::positional_field_ident; +use vir::def::field_ident_from_rust; pub(crate) fn fn_call_to_vir<'tcx>( bctx: &BodyCtxt<'tcx>, @@ -1801,12 +1801,7 @@ fn check_variant_field<'tcx>( return err_span(span, "field has the wrong type"); } - let field_ident = if field_name.as_str().bytes().nth(0).unwrap().is_ascii_digit() { - let i = field_name.parse::().unwrap(); - positional_field_ident(i) - } else { - str_ident(&field_name) - }; + let field_ident = field_ident_from_rust(&field_name); Ok((adt_path, Some(field_ident))) } diff --git a/source/rust_verify/src/lifetime_ast.rs b/source/rust_verify/src/lifetime_ast.rs index cf2bf5b873..30fa4fdd55 100644 --- a/source/rust_verify/src/lifetime_ast.rs +++ b/source/rust_verify/src/lifetime_ast.rs @@ -11,6 +11,7 @@ pub(crate) enum IdKind { Fun, Local, Builtin, + Field, } #[derive(Debug, Clone, PartialEq, Eq, Hash)] diff --git a/source/rust_verify/src/lifetime_emit.rs b/source/rust_verify/src/lifetime_emit.rs index c460ad2de0..33467b09bd 100644 --- a/source/rust_verify/src/lifetime_emit.rs +++ b/source/rust_verify/src/lifetime_emit.rs @@ -12,6 +12,12 @@ pub(crate) fn encode_id(kind: IdKind, rename_count: usize, raw_id: &String) -> S IdKind::Fun => format!("f{}_{}", rename_count, raw_id), IdKind::Local => format!("x{}_{}", rename_count, vir::def::user_local_name(raw_id)), IdKind::Builtin => raw_id.clone(), + + // Numeric fields need to be emitted as numeric fields. + // Non-numeric fields need to be unique-ified to avoid conflict with method names. + // Therefore, we only use the rename_count for non-numeric fields. + IdKind::Field if raw_id.bytes().nth(0).unwrap().is_ascii_digit() => raw_id.clone(), + IdKind::Field => format!("y{}_{}", rename_count, vir::def::user_local_name(raw_id)), } } diff --git a/source/rust_verify/src/lifetime_generate.rs b/source/rust_verify/src/lifetime_generate.rs index 11eb8cd9c5..b5dcb25255 100644 --- a/source/rust_verify/src/lifetime_generate.rs +++ b/source/rust_verify/src/lifetime_generate.rs @@ -26,7 +26,7 @@ use std::collections::{HashMap, HashSet}; use std::sync::Arc; use vir::ast::{AutospecUsage, DatatypeTransparency, Fun, FunX, Function, Mode, Path}; use vir::ast_util::get_field; -use vir::def::VERUS_SPEC; +use vir::def::{field_ident_from_rust, VERUS_SPEC}; use vir::messages::AstId; impl TypX { @@ -77,6 +77,7 @@ pub(crate) struct State { datatype_worklist: Vec, imported_fun_worklist: Vec, id_to_name: HashMap, + field_to_name: HashMap, typ_param_to_name: HashMap<(String, Option), Id>, lifetime_to_name: HashMap<(String, Option), Id>, fun_to_name: HashMap, @@ -106,6 +107,7 @@ impl State { datatype_worklist: Vec::new(), imported_fun_worklist: Vec::new(), id_to_name: HashMap::new(), + field_to_name: HashMap::new(), typ_param_to_name: HashMap::new(), lifetime_to_name: HashMap::new(), fun_to_name: HashMap::new(), @@ -146,6 +148,12 @@ impl State { Self::id(&mut self.rename_count, &mut self.id_to_name, IdKind::Local, &raw_id, f) } + fn field>(&mut self, raw_id: S) -> Id { + let raw_id = raw_id.into(); + let f = || raw_id.clone(); + Self::id(&mut self.rename_count, &mut self.field_to_name, IdKind::Field, &raw_id, f) + } + fn typ_param>(&mut self, raw_id: S, maybe_impl_index: Option) -> Id { let raw_id = raw_id.into(); let (is_impl, impl_index) = match (raw_id.starts_with("impl "), maybe_impl_index) { @@ -555,7 +563,7 @@ fn erase_pat<'tcx>(ctxt: &Context<'tcx>, state: &mut State, pat: &Pat) -> Patter let name = state.datatype_name(&vir_path); let mut binders: Vec<(Id, Pattern)> = Vec::new(); for pat in pats.iter() { - let field = state.local(pat.ident.to_string()); + let field = state.field(pat.ident.to_string()); let pattern = erase_pat(ctxt, state, &pat.pat); binders.push((field, pattern)); } @@ -1174,9 +1182,9 @@ fn erase_expr<'tcx>( let variant = datatype.x.get_variant(&variant_name); let mut fs: Vec<(Id, Exp)> = Vec::new(); for f in fields.iter() { - let field_name = Arc::new(f.ident.as_str().to_string()); - let (_, field_mode, _) = get_field(&variant.a, &field_name).a; - let name = state.local(f.ident.to_string()); + let vir_field_name = field_ident_from_rust(f.ident.as_str()); + let (_, field_mode, _) = get_field(&variant.a, &vir_field_name).a; + let name = state.field(f.ident.to_string()); let e = if field_mode == Mode::Spec { phantom_data_expr(ctxt, state, &f.expr) } else { @@ -1264,12 +1272,7 @@ fn erase_expr<'tcx>( if expect_spec { erase_spec_exps(ctxt, state, expr, vec![exp1]) } else { - let field_name = field.to_string(); - let field_id = if field_name.chars().all(char::is_numeric) { - Id::new(IdKind::Builtin, 0, field_name) - } else { - state.local(field.to_string()) - }; + let field_id = state.field(field.to_string()); mk_exp(ExpX::Field(exp1.expect("expr"), field_id)) } } @@ -2148,7 +2151,7 @@ fn erase_variant_data<'tcx>( None => { let mut fields: Vec = Vec::new(); for field in &variant.fields { - let name = state.local(field.ident(ctxt.tcx).to_string()); + let name = state.field(field.ident(ctxt.tcx).to_string()); let typ = erase_ty(ctxt, state, &ctxt.tcx.type_of(field.did).skip_binder()); fields.push(Field { name, typ: revise_typ(field.did, typ) }); } diff --git a/source/rust_verify/src/rust_to_vir_adts.rs b/source/rust_verify/src/rust_to_vir_adts.rs index afcac7d401..550637683e 100644 --- a/source/rust_verify/src/rust_to_vir_adts.rs +++ b/source/rust_verify/src/rust_to_vir_adts.rs @@ -14,7 +14,7 @@ use rustc_span::Span; use std::sync::Arc; use vir::ast::{DatatypeTransparency, DatatypeX, Ident, KrateX, Mode, Path, Variant, VirErr}; use vir::ast_util::ident_binder; -use vir::def::positional_field_ident; +use vir::def::field_ident_from_rust; // The `rustc_hir::VariantData` is optional here because we won't have it available // when handling external datatype definitions. @@ -72,14 +72,7 @@ where } }; - // Only way I can see to determine if the field is positional using rustc_middle - let use_positional = field_def.name.as_str().bytes().nth(0).unwrap().is_ascii_digit(); - - let ident = if use_positional { - positional_field_ident(idx) - } else { - str_ident(&field_def_ident.as_str()) - }; + let ident = field_ident_from_rust(&field_def_ident.as_str()); let typ = mid_ty_to_vir( ctxt.tcx, diff --git a/source/rust_verify/src/rust_to_vir_expr.rs b/source/rust_verify/src/rust_to_vir_expr.rs index 283f60c73b..8db90d2b78 100644 --- a/source/rust_verify/src/rust_to_vir_expr.rs +++ b/source/rust_verify/src/rust_to_vir_expr.rs @@ -40,7 +40,7 @@ use vir::ast::{ Typ, TypX, UnaryOp, UnaryOpr, VirErr, }; use vir::ast_util::{ident_binder, typ_to_diagnostic_str, types_equal, undecorate_typ}; -use vir::def::positional_field_ident; +use vir::def::{field_ident_from_rust, positional_field_ident}; pub(crate) fn pat_to_mut_var<'tcx>(pat: &Pat) -> Result<(bool, String), VirErr> { let Pat { hir_id: _, kind, span, default_binding_modes } = pat; @@ -481,7 +481,8 @@ pub(crate) fn pattern_to_vir_inner<'tcx>( let mut binders: Vec> = Vec::new(); for fpat in pats.iter() { let pattern = pattern_to_vir(bctx, &fpat.pat)?; - let binder = ident_binder(&str_ident(&fpat.ident.as_str()), &pattern); + let ident = field_ident_from_rust(fpat.ident.as_str()); + let binder = ident_binder(&ident, &pattern); binders.push(binder); } PatternX::Constructor(vir_path, variant_name, Arc::new(binders)) @@ -1455,18 +1456,12 @@ pub(crate) fn expr_to_vir_innermost<'tcx>( let hir_def = bctx.ctxt.tcx.adt_def(adt_def.did()); let variant = hir_def.variants().iter().next().unwrap(); let variant_name = str_ident(&variant.ident(tcx).as_str()); - let field_name = match variant.ctor_kind() { - Some(rustc_hir::def::CtorKind::Fn) => { - let field_idx = variant - .fields - .iter() - .position(|f| f.ident(tcx).as_str() == name.as_str()) - .expect("positional field not found"); - positional_field_ident(field_idx) - } - None => str_ident(&name.as_str()), + let field_name = field_ident_from_rust(&name.as_str()); + match variant.ctor_kind() { + Some(rustc_hir::def::CtorKind::Fn) => {} + None => {} Some(rustc_hir::def::CtorKind::Const) => panic!("unexpected tuple constructor"), - }; + } (datatype_path, variant_name, field_name) } else { let lhs_typ = typ_of_node(bctx, lhs.span, &lhs.hir_id, false)?; @@ -1650,7 +1645,8 @@ pub(crate) fn expr_to_vir_innermost<'tcx>( .iter() .map(|f| -> Result<_, VirErr> { let vir = expr_to_vir(bctx, f.expr, modifier)?; - Ok(ident_binder(&str_ident(&f.ident.as_str()), &vir)) + let ident = field_ident_from_rust(f.ident.as_str()); + Ok(ident_binder(&ident, &vir)) }) .collect::, _>>()?, ); diff --git a/source/rust_verify_test/tests/adts.rs b/source/rust_verify_test/tests/adts.rs index ebcd99f40c..71a9822e60 100644 --- a/source/rust_verify_test/tests/adts.rs +++ b/source/rust_verify_test/tests/adts.rs @@ -1088,3 +1088,82 @@ test_verify_one_file! { } } => Ok(()) } + +test_verify_one_file! { + #[test] struct_syntax_with_numeric_field_names verus_code! { + #[is_variant] + enum Foo { + Bar(u32, u32), + Qux, + } + + fn test() { + let b = Foo::Bar { 1: 30, 0: 20 }; + assert(b.get_Bar_0() == 20); + assert(b.get_Bar_1() == 30); + } + + fn test2() { + let b = Foo::Bar { 1: 30, 0: 20 }; + assert(b.get_Bar_1() == 20); // FAILS + } + + + fn test_pat(foo: Foo) { + let foo = Foo::Bar(20, 40); + match foo { + Foo::Bar { 1: a, 0: b } => { + assert(b == 20); + assert(a == 40); + } + Foo::Qux => { assert(false); } + } + } + + fn test_pat2(foo: Foo) { + let foo = Foo::Bar(20, 40); + match foo { + Foo::Bar { 1: a, 0: b } => { + assert(b == 40); // FAILS + } + Foo::Qux => { } + } + } + + fn test_pat_not_all_fields(foo: Foo) { + let foo = Foo::Bar(20, 40); + match foo { + Foo::Bar { 1: a, .. } => { + assert(a == 40); + } + Foo::Qux => { assert(false); } + } + } + + fn test_pat_not_all_fields2(foo: Foo) { + let foo = Foo::Bar(20, 40); + match foo { + Foo::Bar { 1: a, .. } => { + assert(a == 20); // FAILS + } + Foo::Qux => { } + } + } + + spec fn sfn(foo: Foo) -> (u32, u32) { + match foo { + Foo::Bar { 1: a, 0: b } => (b, a), + Foo::Qux => (0, 0), + } + } + + proof fn test_sfn(foo: Foo) { + assert(sfn(Foo::Bar(20, 30)) == (20u32, 30u32)); + assert(sfn(Foo::Qux) == (0u32, 0u32)); + } + + proof fn test_sfn2(foo: Foo) { + assert(sfn(Foo::Bar(20, 30)).0 == 30); // FAILS + } + } => Err(err) => assert_fails(err, 4) +} diff --git a/source/rust_verify_test/tests/modes.rs b/source/rust_verify_test/tests/modes.rs index ade9a75299..a855d883f0 100644 --- a/source/rust_verify_test/tests/modes.rs +++ b/source/rust_verify_test/tests/modes.rs @@ -950,9 +950,8 @@ test_verify_one_file! { } test_verify_one_file! { - // TODO(utaal) issue with tracked rewrite, I believe - #[ignore] #[test] test_struct_pattern_fields_out_of_order_fail_issue_348 verus_code! { - struct Foo { + #[test] test_struct_pattern_fields_out_of_order_fail_issue_348 verus_code! { + tracked struct Foo { ghost a: u64, tracked b: u64, } @@ -992,6 +991,24 @@ test_verify_one_file! { } => Ok(()) } +test_verify_one_file! { + #[test] test_struct_pattern_fields_numeric_out_of_order_fail verus_code! { + tracked struct Foo(ghost u64, tracked u64); + + proof fn some_call(tracked y: u64) { } + + proof fn t() { + let tracked foo = Foo(5, 6); + let tracked Foo { 1: b, 0: a } = foo; + + // Variable 'a' has the mode of field '0' (that is, spec) + // some_call requires 'proof' + // So this should fail + some_call(a); + } + } => Err(err) => assert_vir_error_msg(err, "expression has mode spec, expected mode proof") +} + test_verify_one_file! { #[test] test_use_exec_var_in_forall verus_code! { spec fn some_fn(j: int) -> bool { diff --git a/source/vir/src/ast.rs b/source/vir/src/ast.rs index c68fb6f455..c2a0d4065d 100644 --- a/source/vir/src/ast.rs +++ b/source/vir/src/ast.rs @@ -480,8 +480,8 @@ pub enum PatternX { /// Note: ast_simplify replaces this with Constructor Tuple(Patterns), /// Match constructor of datatype Path, variant Ident - /// For tuple-style variants, the patterns appear in order and are named "0", "1", etc. - /// For struct-style variants, the patterns may appear in any order. + /// For tuple-style variants, the fields are named "_0", "_1", etc. + /// Fields can appear **in any order** even for tuple variants. Constructor(Path, Ident, Binders), Or(Pattern, Pattern), } @@ -611,8 +611,8 @@ pub enum ExprX { Tuple(Exprs), /// Construct datatype value of type Path and variant Ident, /// with field initializers Binders and an optional ".." update expression. - /// For tuple-style variants, the field initializers appear in order and are named "_0", "_1", etc. - /// For struct-style variants, the field initializers may appear in any order. + /// For tuple-style variants, the fields are named "_0", "_1", etc. + /// Fields can appear **in any order** even for tuple variants. Ctor(Path, Ident, Binders, Option), /// Primitive 0-argument operation NullaryOpr(NullaryOpr), diff --git a/source/vir/src/def.rs b/source/vir/src/def.rs index e340bd1bbe..3b98d23601 100644 --- a/source/vir/src/def.rs +++ b/source/vir/src/def.rs @@ -474,6 +474,10 @@ pub fn positional_field_ident(idx: usize) -> Ident { Arc::new(format!("_{}", idx)) } +pub fn field_ident_from_rust(s: &str) -> Ident { + Arc::new(format!("_{}", s)) +} + pub fn monotyp_apply(datatype: &Path, args: &Vec) -> Path { if args.len() == 0 { datatype.clone() From 49cd0a599b79e28180e8a18dd6541596366ac437 Mon Sep 17 00:00:00 2001 From: tjhance Date: Tue, 9 Jan 2024 13:43:30 -0500 Subject: [PATCH 03/11] Support ? for Option and Result (#950) * support ? for Option * support Result based on Chris's suggestion --- source/pervasive/std_specs/control_flow.rs | 72 +++++++++++++++++ source/pervasive/std_specs/mod.rs | 1 + source/rust_verify/src/config.rs | 1 + source/rust_verify/src/fn_call_to_vir.rs | 34 +++++++- source/rust_verify/src/lifetime_generate.rs | 37 ++++++++- source/rust_verify/src/verus_items.rs | 4 + source/rust_verify_test/tests/std.rs | 90 +++++++++++++++++++++ 7 files changed, 234 insertions(+), 5 deletions(-) create mode 100644 source/pervasive/std_specs/control_flow.rs diff --git a/source/pervasive/std_specs/control_flow.rs b/source/pervasive/std_specs/control_flow.rs new file mode 100644 index 0000000000..9f6438f725 --- /dev/null +++ b/source/pervasive/std_specs/control_flow.rs @@ -0,0 +1,72 @@ +use crate::prelude::*; +use core::ops::Try; +use core::ops::ControlFlow; +use core::ops::FromResidual; +use core::convert::Infallible; + +verus!{ + +#[verifier(external_type_specification)] +#[verifier::accept_recursive_types(B)] +#[verifier::reject_recursive_types_in_ground_variants(C)] +pub struct ExControlFlow(ControlFlow); + +#[verifier(external_type_specification)] +#[verifier(external_body)] +pub struct ExInfallible(Infallible); + + +#[verifier::external_fn_specification] +pub fn ex_result_branch(result: Result) -> (cf: ControlFlow< as Try>::Residual, as Try>::Output>) + ensures + cf === match result { + Ok(v) => ControlFlow::Continue(v), + Err(e) => ControlFlow::Break(Err(e)), + }, +{ + result.branch() +} + +#[verifier::external_fn_specification] +pub fn ex_option_branch(option: Option) -> (cf: ControlFlow< as Try>::Residual, as Try>::Output>) + ensures + cf === match option { + Some(v) => ControlFlow::Continue(v), + None => ControlFlow::Break(None), + }, +{ + option.branch() +} + +#[verifier::external_fn_specification] +pub fn ex_option_from_residual(option: Option) -> (option2: Option) + ensures + option.is_none(), + option2.is_none(), +{ + Option::from_residual(option) +} + +pub spec fn spec_from(value: T, ret: S) -> bool; + +#[verifier::broadcast_forall] +#[verifier::external_body] +pub proof fn spec_from_blanket_identity(t: T, s: T) + ensures + spec_from::(t, s) ==> t == s +{ +} + +#[verifier::external_fn_specification] +pub fn ex_result_from_residual>(result: Result) + -> (result2: Result) + ensures + match (result, result2) { + (Err(e), Err(e2)) => spec_from::(e, e2), + _ => false, + }, +{ + Result::from_residual(result) +} + +} diff --git a/source/pervasive/std_specs/mod.rs b/source/pervasive/std_specs/mod.rs index 7cde07157e..4f144afd3b 100644 --- a/source/pervasive/std_specs/mod.rs +++ b/source/pervasive/std_specs/mod.rs @@ -5,6 +5,7 @@ pub mod result; pub mod option; pub mod atomic; pub mod bits; +pub mod control_flow; #[cfg(feature = "alloc")] pub mod vec; diff --git a/source/rust_verify/src/config.rs b/source/rust_verify/src/config.rs index aca4066750..81ac190587 100644 --- a/source/rust_verify/src/config.rs +++ b/source/rust_verify/src/config.rs @@ -114,6 +114,7 @@ pub fn enable_default_features_and_verus_attr( "register_tool", "tuple_trait", "custom_inner_attributes", + "try_trait_v2", ] { rustc_args.push("-Z".to_string()); rustc_args.push(format!("crate-attr=feature({})", feature)); diff --git a/source/rust_verify/src/fn_call_to_vir.rs b/source/rust_verify/src/fn_call_to_vir.rs index 04bc1a954a..73b505a3d2 100644 --- a/source/rust_verify/src/fn_call_to_vir.rs +++ b/source/rust_verify/src/fn_call_to_vir.rs @@ -21,7 +21,7 @@ use air::ast_util::str_ident; use rustc_ast::LitKind; use rustc_hir::def::Res; use rustc_hir::{Expr, ExprKind, Node, QPath}; -use rustc_middle::ty::{GenericArgKind, TyKind}; +use rustc_middle::ty::{GenericArg, GenericArgKind, TyKind}; use rustc_span::def_id::DefId; use rustc_span::source_map::Spanned; use rustc_span::Span; @@ -96,9 +96,6 @@ pub(crate) fn fn_call_to_vir<'tcx>( ), ); } - Some(RustItem::TryTraitBranch) => { - return err_span(expr.span, "Verus does not yet support the ? operator"); - } Some(RustItem::Clone) => { // Special case `clone` for standard Rc and Arc types // (Could also handle it for other types where cloning is the identity @@ -167,6 +164,8 @@ pub(crate) fn fn_call_to_vir<'tcx>( // If the resolution is statically known, we record the resolved function for the // to be used by lifetime_generate. + let node_substs = fix_node_substs(tcx, bctx.types, node_substs, rust_item, &args, expr); + let target_kind = if tcx.trait_of_item(f).is_none() { vir::ast::CallTargetKind::Static } else { @@ -1646,6 +1645,33 @@ fn mk_is_smaller_than<'tcx>( return Ok(dec_exp); } +pub(crate) fn fix_node_substs<'tcx, 'a>( + tcx: rustc_middle::ty::TyCtxt<'tcx>, + types: &'tcx rustc_middle::ty::TypeckResults<'tcx>, + node_substs: &'tcx rustc_middle::ty::List>, + rust_item: Option, + args: &'a [&'tcx Expr<'tcx>], + expr: &'a Expr<'tcx>, +) -> &'tcx rustc_middle::ty::List> { + match rust_item { + Some(RustItem::TryTraitBranch) => { + // I don't understand why, but in this case, node_substs is empty instead + // of having the type argument. Let's fix it here. + // `branch` has type `fn branch(self) -> ...` + // so we can get the Self argument from the first argument. + let generic_arg = GenericArg::from(types.expr_ty_adjusted(&args[0])); + tcx.mk_args(&[generic_arg]) + } + Some(RustItem::ResidualTraitFromResidual) => { + // `fn from_residual(residual: R) -> Self;` + let generic_arg0 = GenericArg::from(types.expr_ty(expr)); + let generic_arg1 = GenericArg::from(types.expr_ty_adjusted(&args[0])); + tcx.mk_args(&[generic_arg0, generic_arg1]) + } + _ => node_substs, + } +} + fn mk_typ_args<'tcx>( bctx: &BodyCtxt<'tcx>, substs: &rustc_middle::ty::List>, diff --git a/source/rust_verify/src/lifetime_generate.rs b/source/rust_verify/src/lifetime_generate.rs index b5dcb25255..4cabfaecb6 100644 --- a/source/rust_verify/src/lifetime_generate.rs +++ b/source/rust_verify/src/lifetime_generate.rs @@ -469,6 +469,27 @@ fn erase_ty<'tcx>(ctxt: &Context<'tcx>, state: &mut State, ty: &Ty<'tcx>) -> Typ TyKind::Alias(rustc_middle::ty::AliasKind::Projection, t) => { // Note: even if rust_to_vir_base decides to normalize t, // we don't have to normalize t here, since we're generating Rust code, not VIR. + // However, normalizing means we might reach less stuff so it's + // still useful. + + // Try normalization: + use crate::rustc_trait_selection::infer::TyCtxtInferExt; + use crate::rustc_trait_selection::traits::NormalizeExt; + if let Some(fun_id) = state.enclosing_fun_id { + let param_env = ctxt.tcx.param_env(fun_id); + let infcx = ctxt.tcx.infer_ctxt().ignoring_regions().build(); + let cause = rustc_infer::traits::ObligationCause::dummy(); + let at = infcx.at(&cause, param_env); + let resolved_ty = infcx.resolve_vars_if_possible(*ty); + if !rustc_middle::ty::TypeVisitableExt::has_escaping_bound_vars(&resolved_ty) { + let norm = at.normalize(*ty); + if norm.value != *ty { + return erase_ty(ctxt, state, &norm.value); + } + } + } + + // If normalization isn't possible: let assoc_item = ctxt.tcx.associated_item(t.def_id); let name = state.typ_param(assoc_item.name.to_string(), None); let trait_def = ctxt.tcx.generics_of(t.def_id).parent; @@ -782,10 +803,21 @@ fn erase_call<'tcx>( // Maybe resolve from trait function to a specific implementation - let mut node_substs = node_substs; + let node_substs = node_substs; let mut fn_def_id = fn_def_id.expect("call id"); let param_env = ctxt.tcx.param_env(state.enclosing_fun_id.expect("enclosing_fun_id")); + + let rust_item = crate::verus_items::get_rust_item(ctxt.tcx, fn_def_id); + let mut node_substs = crate::fn_call_to_vir::fix_node_substs( + ctxt.tcx, + ctxt.types(), + node_substs, + rust_item, + &args_slice.iter().collect::>(), + expr, + ); + let normalized_substs = ctxt.tcx.normalize_erasing_regions(param_env, node_substs); let inst = rustc_middle::ty::Instance::resolve( ctxt.tcx, @@ -1798,6 +1830,8 @@ fn erase_fn_common<'tcx>( &mut typ_params, &mut generic_bounds, ); + + state.enclosing_fun_id = Some(id); let mut params: Vec = Vec::new(); for ((input, param), param_info) in inputs.iter().zip(f_vir.x.params.iter()).zip(params_info.iter()) @@ -1832,6 +1866,7 @@ fn erase_fn_common<'tcx>( } else { Some((None, erase_ty(ctxt, state, &fn_sig.output().skip_binder()))) }; + state.enclosing_fun_id = None; let decl = FunDecl { sig_span: sig_span, name_span, diff --git a/source/rust_verify/src/verus_items.rs b/source/rust_verify/src/verus_items.rs index 0c1911a4f6..a7d9238965 100644 --- a/source/rust_verify/src/verus_items.rs +++ b/source/rust_verify/src/verus_items.rs @@ -546,6 +546,7 @@ pub(crate) enum RustItem { IntIntrinsic(RustIntIntrinsicItem), AllocGlobal, TryTraitBranch, + ResidualTraitFromResidual, IntoIterFn, Destruct, } @@ -582,6 +583,9 @@ pub(crate) fn get_rust_item<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> Option Err(err) => assert_one_fails(err) } + +test_verify_one_file! { + #[test] question_mark_option verus_code! { + use vstd::*; + + fn test() -> (res: Option) + ensures res.is_none() + { + let x: Option = None; + let y = x?; + + assert(false); + return None; + } + + fn test2() -> (res: Option) + { + let x: Option = Some(5); + let y = x?; + + assert(false); // FAILS + return None; + } + + fn test3() -> (res: Option) + ensures res.is_some(), + { + let x: Option = None; + let y = x?; // FAILS + + return Some(13); + } + + fn test4() -> (res: Option) + ensures false, + { + let x: Option = Some(12); + let y = x?; + + assert(y == 12); + + loop { } + } + } => Err(err) => assert_fails(err, 2) +} + +test_verify_one_file! { + #[test] question_mark_result verus_code! { + use vstd::*; + + fn test() -> (res: Result) + ensures res === Err(false), + { + let x: Result = Err(false); + let y = x?; + + assert(false); + return Err(true); + } + + fn test2() -> (res: Result) + { + let x: Result = Ok(5); + let y = x?; + + assert(false); // FAILS + return Err(false); + } + + fn test3() -> (res: Result) + ensures res.is_ok(), + { + let x: Result = Err(false); + let y = x?; // FAILS + + return Ok(13); + } + + fn test4() -> (res: Result) + ensures false, + { + let x: Result = Ok(12); + let y = x?; + + assert(y == 12); + + loop { } + } + } => Err(err) => assert_fails(err, 2) +} From 52bf4e17c4ed3713f8cd31ba9934fe707f8f94b6 Mon Sep 17 00:00:00 2001 From: tjhance Date: Mon, 15 Jan 2024 11:01:02 -0500 Subject: [PATCH 04/11] support unions (#951) * support unions * fix rebasing issues related to field VIR naming changes --- dependencies/syn/src/item.rs | 6 + source/builtin/src/lib.rs | 7 + source/builtin_macros/src/syntax.rs | 12 +- source/pervasive/std_specs/core.rs | 5 + source/rust_verify/src/fn_call_to_vir.rs | 111 +++++- source/rust_verify/src/lifetime_ast.rs | 1 + source/rust_verify/src/lifetime_emit.rs | 3 +- source/rust_verify/src/lifetime_generate.rs | 31 +- source/rust_verify/src/rust_to_vir.rs | 27 +- source/rust_verify/src/rust_to_vir_adts.rs | 91 +++++ source/rust_verify/src/rust_to_vir_expr.rs | 115 ++++-- source/rust_verify/src/verus_items.rs | 2 + source/rust_verify_test/tests/adts.rs | 9 + source/rust_verify_test/tests/unions.rs | 407 ++++++++++++++++++++ source/vir/src/ast.rs | 8 + source/vir/src/ast_simplify.rs | 15 +- source/vir/src/ast_to_sst.rs | 34 +- source/vir/src/modes.rs | 7 +- source/vir/src/poly.rs | 8 +- source/vir/src/sst_to_air.rs | 8 +- source/vir/src/triggers_auto.rs | 2 +- source/vir/src/well_formed.rs | 2 +- 22 files changed, 854 insertions(+), 57 deletions(-) create mode 100644 source/rust_verify_test/tests/unions.rs diff --git a/dependencies/syn/src/item.rs b/dependencies/syn/src/item.rs index 94e82a6283..ea42733e28 100644 --- a/dependencies/syn/src/item.rs +++ b/dependencies/syn/src/item.rs @@ -2237,7 +2237,13 @@ pub mod parsing { impl Parse for ItemUnion { fn parse(input: ParseStream) -> Result { let attrs = input.call(Attribute::parse_outer)?; + let vis = input.parse::()?; + + if input.peek(Token![tracked]) || input.peek(Token![ghost]) { + return Err(input.error("a 'union' can only be exec-mode")); + } + let union_token = input.parse::()?; let ident = input.parse::()?; let generics = input.parse::()?; diff --git a/source/builtin/src/lib.rs b/source/builtin/src/lib.rs index 0ca50b9da7..f85c7693b0 100644 --- a/source/builtin/src/lib.rs +++ b/source/builtin/src/lib.rs @@ -255,6 +255,13 @@ pub fn get_variant_field(_a: Adt, _variant: &str, _field: &str) -> F unimplemented!(); } +#[cfg(verus_keep_ghost)] +#[rustc_diagnostic_item = "verus::builtin::get_union_field"] +#[verifier::spec] +pub fn get_union_field(_a: Adt, _field: &str) -> Field { + unimplemented!(); +} + #[cfg(verus_keep_ghost)] #[rustc_diagnostic_item = "verus::builtin::assume_"] #[verifier::proof] diff --git a/source/builtin_macros/src/syntax.rs b/source/builtin_macros/src/syntax.rs index f92907c8cd..8211468ba3 100644 --- a/source/builtin_macros/src/syntax.rs +++ b/source/builtin_macros/src/syntax.rs @@ -14,15 +14,15 @@ use syn_verus::token::{Brace, Bracket, Paren, Semi}; use syn_verus::visit_mut::{ visit_block_mut, visit_expr_loop_mut, visit_expr_mut, visit_expr_while_mut, visit_field_mut, visit_impl_item_method_mut, visit_item_const_mut, visit_item_enum_mut, visit_item_fn_mut, - visit_item_static_mut, visit_item_struct_mut, visit_local_mut, visit_trait_item_method_mut, - VisitMut, + visit_item_static_mut, visit_item_struct_mut, visit_item_union_mut, visit_local_mut, + visit_trait_item_method_mut, VisitMut, }; use syn_verus::{ braced, bracketed, parenthesized, parse_macro_input, AttrStyle, Attribute, BareFnArg, BinOp, Block, DataMode, Decreases, Ensures, Expr, ExprBinary, ExprCall, ExprLit, ExprLoop, ExprTuple, ExprUnary, ExprWhile, Field, FnArgKind, FnMode, Global, Ident, ImplItem, ImplItemMethod, Invariant, InvariantEnsures, InvariantNameSet, InvariantNameSetList, Item, ItemConst, ItemEnum, - ItemFn, ItemImpl, ItemMod, ItemStatic, ItemStruct, ItemTrait, Lit, Local, ModeSpec, + ItemFn, ItemImpl, ItemMod, ItemStatic, ItemStruct, ItemTrait, ItemUnion, Lit, Local, ModeSpec, ModeSpecChecked, Pat, Path, PathArguments, PathSegment, Publish, Recommends, Requires, ReturnType, Signature, SignatureDecreases, SignatureInvariants, Stmt, Token, TraitItem, TraitItemMethod, Type, TypeFnSpec, UnOp, Visibility, @@ -2407,6 +2407,12 @@ impl VisitMut for Visitor { self.filter_attrs(&mut item.attrs); } + fn visit_item_union_mut(&mut self, item: &mut ItemUnion) { + item.attrs.push(mk_verus_attr(item.span(), quote! { verus_macro })); + visit_item_union_mut(self, item); + self.filter_attrs(&mut item.attrs); + } + fn visit_item_struct_mut(&mut self, item: &mut ItemStruct) { item.attrs.push(mk_verus_attr(item.span(), quote! { verus_macro })); visit_item_struct_mut(self, item); diff --git a/source/pervasive/std_specs/core.rs b/source/pervasive/std_specs/core.rs index 029d51fadb..c049ee8ae9 100644 --- a/source/pervasive/std_specs/core.rs +++ b/source/pervasive/std_specs/core.rs @@ -54,4 +54,9 @@ pub fn ex_intrinsics_unlikely(b: bool) -> (c: bool) core::intrinsics::unlikely(b) } +#[verifier::external_type_specification] +#[verifier::external_body] +#[verifier::reject_recursive_types_in_ground_variants(V)] +pub struct ExManuallyDrop(core::mem::ManuallyDrop); + } diff --git a/source/rust_verify/src/fn_call_to_vir.rs b/source/rust_verify/src/fn_call_to_vir.rs index 73b505a3d2..c1c79b40ce 100644 --- a/source/rust_verify/src/fn_call_to_vir.rs +++ b/source/rust_verify/src/fn_call_to_vir.rs @@ -30,7 +30,7 @@ use vir::ast::{ ArithOp, AssertQueryMode, AutospecUsage, BinaryOp, BitwiseOp, BuiltinSpecFun, CallTarget, ChainedOp, ComputeMode, Constant, ExprX, FieldOpr, FunX, HeaderExpr, HeaderExprX, InequalityOp, IntRange, IntegerTypeBoundKind, Mode, ModeCoercion, MultiOp, Quant, Typ, TypX, UnaryOp, - UnaryOpr, VarAt, VirErr, + UnaryOpr, VarAt, VariantCheck, VirErr, }; use vir::ast_util::{const_int_from_string, typ_to_diagnostic_str, types_equal, undecorate_typ}; use vir::def::field_ident_from_rust; @@ -733,6 +733,33 @@ fn verus_item_to_vir<'tcx, 'a>( variant: str_ident(&variant_name), field: variant_field.unwrap(), get_variant: true, + check: VariantCheck::None, + }), + adt_arg, + )) + } + ExprItem::GetUnionField => { + record_spec_fn_allow_proof_args(bctx, expr); + assert!(args.len() == 2); + let adt_arg = expr_to_vir(bctx, &args[0], ExprModifier::REGULAR)?; + let field_name = get_string_lit_arg(&args[1], &f_name)?; + + let adt_path = check_union_field( + bctx, + expr.span, + args[0], + &field_name, + &bctx.types.expr_ty(expr), + )?; + + let field_ident = str_ident(&field_name); + mk_expr(ExprX::UnaryOpr( + UnaryOpr::Field(FieldOpr { + datatype: adt_path, + variant: field_ident.clone(), + field: field_ident_from_rust(&field_ident), + get_variant: true, + check: VariantCheck::None, }), adt_arg, )) @@ -1791,11 +1818,6 @@ fn check_variant_field<'tcx>( } }; - let variant_opt = adt.variants().iter().find(|v| v.ident(tcx).as_str() == variant_name); - let Some(variant) = variant_opt else { - return err_span(span, format!("no variant `{variant_name:}` for this datatype")); - }; - let vir_adt_ty = mid_ty_to_vir(tcx, &bctx.ctxt.verus_items, bctx.fun_id, span, &ty, false)?; let adt_path = match &*vir_adt_ty { TypX::Datatype(path, _, _) => path.clone(), @@ -1804,9 +1826,34 @@ fn check_variant_field<'tcx>( } }; + if adt.is_union() { + if field_name_typ.is_some() { + // Don't use get_variant_field with unions + return err_span( + span, + format!("this datatype is a union; consider `get_union_field` instead"), + ); + } + let variant = adt.non_enum_variant(); + let field_opt = variant.fields.iter().find(|f| f.ident(tcx).as_str() == variant_name); + if field_opt.is_none() { + return err_span(span, format!("no field `{variant_name:}` for this union")); + } + + return Ok((adt_path, None)); + } + + // Enum case: + + let variant_opt = adt.variants().iter().find(|v| v.ident(tcx).as_str() == variant_name); + let Some(variant) = variant_opt else { + return err_span(span, format!("no variant `{variant_name:}` for this datatype")); + }; + match field_name_typ { None => Ok((adt_path, None)), Some((field_name, expected_field_typ)) => { + // The 'get_variant_field' case let field_opt = variant.fields.iter().find(|f| f.ident(tcx).as_str() == field_name); let Some(field) = field_opt else { return err_span(span, format!("no field `{field_name:}` for this variant")); @@ -1834,6 +1881,58 @@ fn check_variant_field<'tcx>( } } +fn check_union_field<'tcx>( + bctx: &BodyCtxt<'tcx>, + span: Span, + adt_arg: &'tcx Expr<'tcx>, + field_name: &String, + expected_field_typ: &rustc_middle::ty::Ty<'tcx>, +) -> Result { + let tcx = bctx.ctxt.tcx; + + let ty = bctx.types.expr_ty_adjusted(adt_arg); + let ty = match ty.kind() { + rustc_middle::ty::TyKind::Ref(_, t, rustc_ast::Mutability::Not) => t, + _ => &ty, + }; + let (adt, substs) = match ty.kind() { + rustc_middle::ty::TyKind::Adt(adt, substs) => (adt, substs), + _ => { + return err_span(span, format!("expected type to be datatype")); + } + }; + + if !adt.is_union() { + return err_span(span, format!("get_union_field expects a union type")); + } + + let variant = adt.non_enum_variant(); + + let field_opt = variant.fields.iter().find(|f| f.ident(tcx).as_str() == field_name); + let Some(field) = field_opt else { + return err_span(span, format!("no field `{field_name:}` for this union")); + }; + + let field_ty = field.ty(tcx, substs); + let vir_field_ty = + mid_ty_to_vir(tcx, &bctx.ctxt.verus_items, bctx.fun_id, span, &field_ty, false)?; + let vir_expected_field_ty = + mid_ty_to_vir(tcx, &bctx.ctxt.verus_items, bctx.fun_id, span, &expected_field_typ, false)?; + if !types_equal(&vir_field_ty, &vir_expected_field_ty) { + return err_span(span, "field has the wrong type"); + } + + let vir_adt_ty = mid_ty_to_vir(tcx, &bctx.ctxt.verus_items, bctx.fun_id, span, &ty, false)?; + let adt_path = match &*vir_adt_ty { + TypX::Datatype(path, _, _) => path.clone(), + _ => { + return err_span(span, format!("expected type to be datatype")); + } + }; + + Ok(adt_path) +} + fn record_compilable_operator<'tcx>(bctx: &BodyCtxt<'tcx>, expr: &Expr, op: CompilableOperator) { let resolved_call = ResolvedCall::CompilableOperator(op); let mut erasure_info = bctx.ctxt.erasure_info.borrow_mut(); diff --git a/source/rust_verify/src/lifetime_ast.rs b/source/rust_verify/src/lifetime_ast.rs index 30fa4fdd55..ff109f6b7c 100644 --- a/source/rust_verify/src/lifetime_ast.rs +++ b/source/rust_verify/src/lifetime_ast.rs @@ -119,6 +119,7 @@ pub(crate) enum Fields { pub(crate) enum Datatype { Struct(Fields), Enum(Vec<(Id, Fields)>), + Union(Fields), } #[derive(Debug, Clone, PartialEq, Eq, Hash)] diff --git a/source/rust_verify/src/lifetime_emit.rs b/source/rust_verify/src/lifetime_emit.rs index 33467b09bd..d0d4f4058b 100644 --- a/source/rust_verify/src/lifetime_emit.rs +++ b/source/rust_verify/src/lifetime_emit.rs @@ -852,6 +852,7 @@ pub(crate) fn emit_datatype_decl(state: &mut EmitState, d: &DatatypeDecl) { let d_keyword = match &*d.datatype { Datatype::Struct(..) => "struct ", Datatype::Enum(..) => "enum ", + Datatype::Union(..) => "union ", }; state.newline(); state.write_spanned(d_keyword, d.span); @@ -865,7 +866,7 @@ pub(crate) fn emit_datatype_decl(state: &mut EmitState, d: &DatatypeDecl) { } }; match &*d.datatype { - Datatype::Struct(fields) => { + Datatype::Struct(fields) | Datatype::Union(fields) => { let suffix = if suffix_where { "" } else { ";" }; emit_fields(state, fields, suffix); if suffix_where { diff --git a/source/rust_verify/src/lifetime_generate.rs b/source/rust_verify/src/lifetime_generate.rs index 4cabfaecb6..2418ca63ce 100644 --- a/source/rust_verify/src/lifetime_generate.rs +++ b/source/rust_verify/src/lifetime_generate.rs @@ -3,7 +3,7 @@ use crate::erase::{ErasureHints, ResolvedCall}; use crate::rust_to_vir_base::{ def_id_to_vir_path, local_to_var, mid_ty_const_to_vir, mid_ty_to_vir_datatype, }; -use crate::rust_to_vir_expr::get_adt_res; +use crate::rust_to_vir_expr::{get_adt_res_struct_enum, get_adt_res_struct_enum_union}; use crate::verus_items::{PervasiveItem, RustItem, VerusItem, VerusItems}; use crate::{lifetime_ast::*, verus_items}; use air::ast_util::str_ident; @@ -533,7 +533,8 @@ fn erase_pat<'tcx>(ctxt: &Context<'tcx>, state: &mut State, pat: &Pat) -> Patter } PatKind::Path(qpath) => { let res = ctxt.types().qpath_res(qpath, pat.hir_id); - let (adt_def_id, variant_def, is_enum) = get_adt_res(ctxt.tcx, res, pat.span).unwrap(); + let (adt_def_id, variant_def, is_enum) = + get_adt_res_struct_enum(ctxt.tcx, res, pat.span).unwrap(); let variant_name = str_ident(&variant_def.ident(ctxt.tcx).as_str()); let vir_path = def_id_to_vir_path(ctxt.tcx, &ctxt.verus_items, adt_def_id); @@ -559,7 +560,8 @@ fn erase_pat<'tcx>(ctxt: &Context<'tcx>, state: &mut State, pat: &Pat) -> Patter } PatKind::TupleStruct(qpath, pats, dot_dot_pos) => { let res = ctxt.types().qpath_res(qpath, pat.hir_id); - let (adt_def_id, variant_def, is_enum) = get_adt_res(ctxt.tcx, res, pat.span).unwrap(); + let (adt_def_id, variant_def, is_enum) = + get_adt_res_struct_enum(ctxt.tcx, res, pat.span).unwrap(); let variant_name = str_ident(&variant_def.ident(ctxt.tcx).as_str()); let vir_path = def_id_to_vir_path(ctxt.tcx, &ctxt.verus_items, adt_def_id); @@ -574,7 +576,8 @@ fn erase_pat<'tcx>(ctxt: &Context<'tcx>, state: &mut State, pat: &Pat) -> Patter } PatKind::Struct(qpath, pats, has_omitted) => { let res = ctxt.types().qpath_res(qpath, pat.hir_id); - let (adt_def_id, variant_def, is_enum) = get_adt_res(ctxt.tcx, res, pat.span).unwrap(); + let (adt_def_id, variant_def, is_enum) = + get_adt_res_struct_enum(ctxt.tcx, res, pat.span).unwrap(); let variant_name = str_ident(&variant_def.ident(ctxt.tcx).as_str()); let vir_path = def_id_to_vir_path(ctxt.tcx, &ctxt.verus_items, adt_def_id); @@ -1051,7 +1054,7 @@ fn erase_expr<'tcx>( None } else { let (adt_def_id, variant_def, is_enum) = - get_adt_res(ctxt.tcx, res, expr.span).unwrap(); + get_adt_res_struct_enum(ctxt.tcx, res, expr.span).unwrap(); let variant_name = str_ident(&variant_def.ident(ctxt.tcx).as_str()); let vir_path = def_id_to_vir_path(ctxt.tcx, &ctxt.verus_items, adt_def_id); @@ -1205,9 +1208,9 @@ fn erase_expr<'tcx>( erase_spec_exps(ctxt, state, expr, exps) } else { let res = ctxt.types().qpath_res(qpath, expr.hir_id); - let (adt_def_id, variant_def, is_enum) = - get_adt_res(ctxt.tcx, res, expr.span).unwrap(); - let variant_name = str_ident(&variant_def.ident(ctxt.tcx).as_str()); + + let (adt_def_id, variant_name, is_enum) = + get_adt_res_struct_enum_union(ctxt.tcx, res, expr.span, fields).unwrap(); let vir_path = def_id_to_vir_path(ctxt.tcx, &ctxt.verus_items, adt_def_id); let datatype = &ctxt.datatypes[&vir_path]; @@ -1465,7 +1468,7 @@ fn erase_block<'tcx>( block: &Block<'tcx>, ) -> Option { let mk_exp = |e: ExpX| Some(Box::new((block.span, e))); - assert!(block.rules == BlockCheckMode::DefaultBlock); + assert!(matches!(block.rules, BlockCheckMode::DefaultBlock | BlockCheckMode::UnsafeBlock(_))); assert!(!block.targeted_by_break); let mut stms: Vec = Vec::new(); for stmt in block.stmts { @@ -2286,6 +2289,10 @@ fn erase_mir_datatype<'tcx>(ctxt: &Context<'tcx>, state: &mut State, id: DefId) } let datatype = Datatype::Enum(variants); erase_datatype(ctxt, state, span, id, datatype); + } else if adt_def.is_union() { + let fields = erase_variant_data(ctxt, state, adt_def.non_enum_variant()); + let datatype = Datatype::Union(fields); + erase_datatype(ctxt, state, span, id, datatype); } else { panic!("unexpected datatype {:?}", id); } @@ -2442,6 +2449,12 @@ pub(crate) fn gen_check_tracked_lifetimes<'tcx>( } state.reach_datatype(&ctxt, id); } + ItemKind::Union(_e, _generics) => { + if vattrs.is_external(&ctxt.cmd_line_args) { + continue; + } + state.reach_datatype(&ctxt, id); + } ItemKind::Const(_ty, _, body_id) | ItemKind::Static(_ty, _, body_id) => { if vattrs.size_of_global || vattrs.is_external(&ctxt.cmd_line_args) { continue; diff --git a/source/rust_verify/src/rust_to_vir.rs b/source/rust_verify/src/rust_to_vir.rs index 6a8439a4b3..de614dca1e 100644 --- a/source/rust_verify/src/rust_to_vir.rs +++ b/source/rust_verify/src/rust_to_vir.rs @@ -8,7 +8,7 @@ For soundness's sake, be as defensive as possible: use crate::attributes::get_verifier_attrs; use crate::context::Context; -use crate::rust_to_vir_adts::{check_item_enum, check_item_struct}; +use crate::rust_to_vir_adts::{check_item_enum, check_item_struct, check_item_union}; use crate::rust_to_vir_base::{ check_generics_bounds, def_id_to_vir_path, mid_ty_to_vir, mk_visibility, process_predicate_bounds, typ_path_and_ident_to_vir_path, @@ -195,6 +195,31 @@ fn check_item<'tcx>( &adt_def, )?; } + ItemKind::Union(variant_data, generics) => { + if vattrs.is_external(&ctxt.cmd_line_args) { + let def_id = id.owner_id.to_def_id(); + let path = def_id_to_vir_path(ctxt.tcx, &ctxt.verus_items, def_id); + vir.external_types.push(path); + + return Ok(()); + } + + let tyof = ctxt.tcx.type_of(item.owner_id.to_def_id()).skip_binder(); + let adt_def = tyof.ty_adt_def().expect("adt_def"); + + check_item_union( + ctxt, + vir, + &module_path(), + item.span, + id, + visibility(), + ctxt.tcx.hir().attrs(item.hir_id()), + variant_data, + generics, + &adt_def, + )?; + } ItemKind::Impl(impll) => { let impl_def_id = item.owner_id.to_def_id(); let impl_path = def_id_to_vir_path(ctxt.tcx, &ctxt.verus_items, impl_def_id); diff --git a/source/rust_verify/src/rust_to_vir_adts.rs b/source/rust_verify/src/rust_to_vir_adts.rs index 550637683e..47d001587c 100644 --- a/source/rust_verify/src/rust_to_vir_adts.rs +++ b/source/rust_verify/src/rust_to_vir_adts.rs @@ -284,6 +284,97 @@ pub fn check_item_enum<'tcx>( Ok(()) } +pub fn check_item_union<'tcx>( + ctxt: &Context<'tcx>, + vir: &mut KrateX, + module_path: &Path, + span: Span, + id: &ItemId, + visibility: vir::ast::Visibility, + attrs: &[Attribute], + variant_data: &'tcx VariantData<'tcx>, + generics: &'tcx Generics<'tcx>, + adt_def: &rustc_middle::ty::AdtDef, +) -> Result<(), VirErr> { + assert!(adt_def.is_union()); + + let vattrs = get_verifier_attrs(attrs, Some(&mut *ctxt.diagnostics.borrow_mut()))?; + + if vattrs.external_fn_specification { + return err_span(span, "`external_fn_specification` attribute not supported here"); + } + + let mode = get_mode(Mode::Exec, attrs); + if mode != Mode::Exec { + return err_span(span, "a 'union' can only be exec-mode"); + } + let VariantData::Struct(hir_fields, _) = variant_data else { + return err_span(span, "check_item_union: wrong VariantData"); + }; + for hir_field_def in hir_fields.iter() { + let mode = get_mode(Mode::Exec, ctxt.tcx.hir().attrs(hir_field_def.hir_id)); + if mode != Mode::Exec { + return err_span(span, "a union field can only be exec-mode"); + } + } + + let def_id = id.owner_id.to_def_id(); + let (typ_params, typ_bounds) = check_generics_bounds( + ctxt.tcx, + &ctxt.verus_items, + generics, + vattrs.external_body, + def_id, + Some(&vattrs), + Some(&mut *ctxt.diagnostics.borrow_mut()), + )?; + let path = def_id_to_vir_path(ctxt.tcx, &ctxt.verus_items, def_id); + + let (variants, transparency) = if vattrs.external_body { + let name = path.segments.last().expect("unexpected struct path"); + let variant_name = Arc::new(name.clone()); + (vec![ident_binder(&variant_name, &Arc::new(vec![]))], DatatypeTransparency::Never) + } else { + let mut variants: Vec<_> = vec![]; + let mut total_vis = visibility.clone(); + assert!(adt_def.variants().len() == 1); + let variant_def = adt_def.variants().iter().next().unwrap(); + for field_def in variant_def.fields.iter() { + let variant_name = str_ident(field_def.ident(ctxt.tcx).as_str()); + let field_name = field_ident_from_rust(&variant_name); + + let vis = mk_visibility_from_vis(ctxt, field_def.vis); + total_vis = total_vis.join(&vis); + + let field_ty = ctxt.tcx.type_of(field_def.did).skip_binder(); + let typ = mid_ty_to_vir(ctxt.tcx, &ctxt.verus_items, def_id, span, &field_ty, false)?; + + let field = (typ, Mode::Exec, vis); + let variant = + ident_binder(&variant_name, &Arc::new(vec![ident_binder(&field_name, &field)])); + + variants.push(variant); + } + (variants, DatatypeTransparency::WhenVisible(total_vis)) + }; + vir.datatypes.push(ctxt.spanned_new( + span, + DatatypeX { + path, + proxy: None, + visibility, + owning_module: Some(module_path.clone()), + transparency, + typ_params, + typ_bounds, + variants: Arc::new(variants), + mode: get_mode(Mode::Exec, attrs), + ext_equal: vattrs.ext_equal, + }, + )); + Ok(()) +} + pub(crate) fn check_item_external<'tcx>( ctxt: &Context<'tcx>, vir: &mut KrateX, diff --git a/source/rust_verify/src/rust_to_vir_expr.rs b/source/rust_verify/src/rust_to_vir_expr.rs index 8db90d2b78..6d7cf0eaeb 100644 --- a/source/rust_verify/src/rust_to_vir_expr.rs +++ b/source/rust_verify/src/rust_to_vir_expr.rs @@ -37,7 +37,7 @@ use std::sync::Arc; use vir::ast::{ ArithOp, ArmX, AutospecUsage, BinaryOp, BitwiseOp, CallTarget, Constant, ExprX, FieldOpr, FunX, HeaderExprX, InequalityOp, IntRange, InvAtomicity, Mode, PatternX, SpannedTyped, StmtX, Stmts, - Typ, TypX, UnaryOp, UnaryOpr, VirErr, + Typ, TypX, UnaryOp, UnaryOpr, VariantCheck, VirErr, }; use vir::ast_util::{ident_binder, typ_to_diagnostic_str, types_equal, undecorate_typ}; use vir::def::{field_ident_from_rust, positional_field_ident}; @@ -249,13 +249,56 @@ pub(crate) fn get_fn_path<'tcx>( } } +/// Handle a struct expression `Struct { ... }` +/// Returns the DefId for the ADT and the variant name used in VIR. +/// Getting the variant name requires a special case for unions. +pub(crate) fn get_adt_res_struct_enum_union<'tcx>( + tcx: TyCtxt<'tcx>, + res: Res, + span: Span, + fields: &'tcx [rustc_hir::ExprField<'tcx>], +) -> Result<(DefId, vir::ast::Ident, bool), VirErr> { + let (adt_def_id, variant_def, is_enum, is_union) = get_adt_res(tcx, res, span)?; + if is_union { + // For a union, rustc has one "variant" with all the fields, while our + // VIR representation has one variant per field. + // Use the name of the VIR variant (same as the field name). + assert!(fields.len() == 1); + let variant_name = str_ident(fields[0].ident.as_str()); + Ok((adt_def_id, variant_name, is_enum)) + } else { + // Structs, enums: VIR variant corresponds to rustc's variant. + let variant_name = str_ident(&variant_def.ident(tcx).as_str()); + Ok((adt_def_id, variant_name, is_enum)) + } +} + /// Gets the DefId of the AdtDef for this Res and the Variant -/// The bool is true if it's an enum, false for struct -pub(crate) fn get_adt_res<'tcx>( +/// The bool return value: is_enum +/// Doesn't support unions, use this where union is unexpected or unsupported +pub(crate) fn get_adt_res_struct_enum<'tcx>( tcx: TyCtxt<'tcx>, res: Res, span: Span, ) -> Result<(DefId, &'tcx VariantDef, bool), VirErr> { + let (adt_def_id, variant_def, is_enum, is_union) = get_adt_res(tcx, res, span)?; + if is_union { + unsupported_err!(span, "using a union here") + } else { + Ok((adt_def_id, variant_def, is_enum)) + } +} + +/// Gets the DefId of the AdtDef for this Res and the Variant +/// The bool return values: (is_enum, is_union) +/// (As a caller, you probably want to use `get_adt_res_struct_enum` or +/// `get_adt_res_struct_enum_union` instead, +/// depending on whether you need to handle unions or not.) +fn get_adt_res<'tcx>( + tcx: TyCtxt<'tcx>, + res: Res, + span: Span, +) -> Result<(DefId, &'tcx VariantDef, bool, bool), VirErr> { // Based off of implementation of rustc_middle's TyCtxt::expect_variant_res // But with a few more cases it didn't handle // Also, returns the adt DefId instead of just the VariantDef @@ -265,22 +308,30 @@ pub(crate) fn get_adt_res<'tcx>( Res::Def(DefKind::Variant, did) => { let enum_did = tcx.parent(did); let variant_def = tcx.adt_def(enum_did).variant_with_id(did); - Ok((enum_did, variant_def, true)) + Ok((enum_did, variant_def, true, false)) } Res::Def(DefKind::Struct, did) => { let variant_def = tcx.adt_def(did).non_enum_variant(); - Ok((did, variant_def, false)) + Ok((did, variant_def, false, false)) + } + Res::Def(DefKind::Union, did) => { + let variant_def = tcx.adt_def(did).non_enum_variant(); + Ok((did, variant_def, false, true)) } Res::Def(DefKind::Ctor(CtorOf::Variant, ..), variant_ctor_did) => { let variant_did = tcx.parent(variant_ctor_did); let enum_did = tcx.parent(variant_did); - let variant_def = tcx.adt_def(enum_did).variant_with_ctor_id(variant_ctor_did); - Ok((enum_did, variant_def, true)) + let adt_def = tcx.adt_def(enum_did); + assert!(adt_def.is_enum()); + let variant_def = adt_def.variant_with_ctor_id(variant_ctor_did); + Ok((enum_did, variant_def, true, false)) } Res::Def(DefKind::Ctor(CtorOf::Struct, ..), ctor_did) => { let struct_did = tcx.parent(ctor_did); - let variant_def = tcx.adt_def(struct_did).non_enum_variant(); - Ok((struct_did, variant_def, false)) + let adt_def = tcx.adt_def(struct_did); + assert!(adt_def.is_struct()); + let variant_def = adt_def.non_enum_variant(); + Ok((struct_did, variant_def, false, false)) } Res::Def(DefKind::TyAlias { lazy }, alias_did) => { unsupported_err_unless!(!lazy, span, "lazy type alias"); @@ -297,8 +348,11 @@ pub(crate) fn get_adt_res<'tcx>( } }; - let variant_def = tcx.adt_def(struct_did).non_enum_variant(); - Ok((struct_did, variant_def, false)) + let adt_def = tcx.adt_def(struct_did); + assert!(adt_def.is_struct() || adt_def.is_union()); + + let variant_def = adt_def.non_enum_variant(); + Ok((struct_did, variant_def, false, adt_def.is_union())) } Res::SelfCtor(impl_id) | Res::SelfTyAlias { alias_to: impl_id, .. } => { let self_ty = tcx.type_of(impl_id).skip_binder(); @@ -312,8 +366,11 @@ pub(crate) fn get_adt_res<'tcx>( } }; - let variant_def = tcx.adt_def(struct_did).non_enum_variant(); - Ok((struct_did, variant_def, false)) + let adt_def = tcx.adt_def(struct_did); + assert!(adt_def.is_struct() || adt_def.is_union()); + + let variant_def = adt_def.non_enum_variant(); + Ok((struct_did, variant_def, false, adt_def.is_union())) } _ => { println!("res: {:#?}", res); @@ -333,7 +390,7 @@ pub(crate) fn expr_tuple_datatype_ctor_to_vir<'tcx>( let tcx = bctx.ctxt.tcx; let expr_typ = typ_of_node(bctx, expr.span, &expr.hir_id, false)?; - let (adt_def_id, variant_def, _is_enum) = get_adt_res(tcx, *res, fun_span)?; + let (adt_def_id, variant_def, _is_enum) = get_adt_res_struct_enum(tcx, *res, fun_span)?; let variant_name = str_ident(&variant_def.ident(tcx).as_str()); let vir_path = def_id_to_vir_path(bctx.ctxt.tcx, &bctx.ctxt.verus_items, adt_def_id); @@ -390,7 +447,7 @@ pub(crate) fn pattern_to_vir_inner<'tcx>( } PatKind::Path(qpath) => { let res = bctx.types.qpath_res(qpath, pat.hir_id); - let (adt_def_id, variant_def, _is_enum) = get_adt_res(tcx, res, pat.span)?; + let (adt_def_id, variant_def, _is_enum) = get_adt_res_struct_enum(tcx, res, pat.span)?; let variant_name = str_ident(&variant_def.ident(tcx).as_str()); let vir_path = def_id_to_vir_path(bctx.ctxt.tcx, &bctx.ctxt.verus_items, adt_def_id); PatternX::Constructor(vir_path, variant_name, Arc::new(vec![])) @@ -421,7 +478,7 @@ pub(crate) fn pattern_to_vir_inner<'tcx>( } PatKind::TupleStruct(qpath, pats, dot_dot_pos) => { let res = bctx.types.qpath_res(qpath, pat.hir_id); - let (adt_def_id, variant_def, _is_enum) = get_adt_res(tcx, res, pat.span)?; + let (adt_def_id, variant_def, _is_enum) = get_adt_res_struct_enum(tcx, res, pat.span)?; let variant_name = str_ident(&variant_def.ident(tcx).as_str()); let vir_path = def_id_to_vir_path(bctx.ctxt.tcx, &bctx.ctxt.verus_items, adt_def_id); @@ -474,7 +531,7 @@ pub(crate) fn pattern_to_vir_inner<'tcx>( } PatKind::Struct(qpath, pats, _) => { let res = bctx.types.qpath_res(qpath, pat.hir_id); - let (adt_def_id, variant_def, _is_enum) = get_adt_res(tcx, res, pat.span)?; + let (adt_def_id, variant_def, _is_enum) = get_adt_res_struct_enum(tcx, res, pat.span)?; let variant_name = str_ident(&variant_def.ident(tcx).as_str()); let vir_path = def_id_to_vir_path(bctx.ctxt.tcx, &bctx.ctxt.verus_items, adt_def_id); @@ -1445,7 +1502,15 @@ pub(crate) fn expr_to_vir_innermost<'tcx>( let vir_lhs = expr_to_vir(bctx, lhs, lhs_modifier)?; let lhs_ty = tc.expr_ty_adjusted(lhs); let lhs_ty = mid_ty_simplify(tcx, &bctx.ctxt.verus_items, &lhs_ty, true); - let (datatype, variant_name, field_name) = if let Some(adt_def) = lhs_ty.ty_adt_def() { + let (datatype, variant_name, field_name, check) = if let Some(adt_def) = + lhs_ty.ty_adt_def() + { + unsupported_err_unless!( + current_modifier == ExprModifier::REGULAR || !adt_def.is_union(), + expr.span, + "assigning to or taking &mut of a union field", + expr + ); unsupported_err_unless!( adt_def.variants().len() == 1, expr.span, @@ -1455,14 +1520,19 @@ pub(crate) fn expr_to_vir_innermost<'tcx>( let datatype_path = def_id_to_vir_path(tcx, &bctx.ctxt.verus_items, adt_def.did()); let hir_def = bctx.ctxt.tcx.adt_def(adt_def.did()); let variant = hir_def.variants().iter().next().unwrap(); - let variant_name = str_ident(&variant.ident(tcx).as_str()); let field_name = field_ident_from_rust(&name.as_str()); match variant.ctor_kind() { Some(rustc_hir::def::CtorKind::Fn) => {} None => {} Some(rustc_hir::def::CtorKind::Const) => panic!("unexpected tuple constructor"), } - (datatype_path, variant_name, field_name) + let variant_name = if adt_def.is_union() { + str_ident(name.as_str()) + } else { + str_ident(&variant.ident(tcx).as_str()) + }; + let check = if adt_def.is_union() { VariantCheck::Yes } else { VariantCheck::None }; + (datatype_path, variant_name, field_name, check) } else { let lhs_typ = typ_of_node(bctx, lhs.span, &lhs.hir_id, false)?; let lhs_typ = undecorate_typ(&lhs_typ); @@ -1487,6 +1557,7 @@ pub(crate) fn expr_to_vir_innermost<'tcx>( variant: variant_name, field: field_name, get_variant: false, + check, }), vir_lhs, ), @@ -1636,8 +1707,8 @@ pub(crate) fn expr_to_vir_innermost<'tcx>( }; let res = bctx.types.qpath_res(qpath, expr.hir_id); - let (adt_def_id, variant_def, _is_enum) = get_adt_res(tcx, res, expr.span)?; - let variant_name = str_ident(&variant_def.ident(tcx).as_str()); + let (adt_def_id, variant_name, _is_enum) = + get_adt_res_struct_enum_union(tcx, res, expr.span, fields)?; let path = def_id_to_vir_path(bctx.ctxt.tcx, &bctx.ctxt.verus_items, adt_def_id); let vir_fields = Arc::new( diff --git a/source/rust_verify/src/verus_items.rs b/source/rust_verify/src/verus_items.rs index a7d9238965..ef369abfab 100644 --- a/source/rust_verify/src/verus_items.rs +++ b/source/rust_verify/src/verus_items.rs @@ -126,6 +126,7 @@ pub(crate) enum ExprItem { ChooseTuple, Old, GetVariantField, + GetUnionField, IsVariant, StrSliceLen, StrSliceGetChar, @@ -353,6 +354,7 @@ fn verus_items_map() -> Vec<(&'static str, VerusItem)> { ("verus::builtin::choose_tuple", VerusItem::Expr(ExprItem::ChooseTuple)), ("verus::builtin::old", VerusItem::Expr(ExprItem::Old)), ("verus::builtin::get_variant_field", VerusItem::Expr(ExprItem::GetVariantField)), + ("verus::builtin::get_union_field", VerusItem::Expr(ExprItem::GetUnionField)), ("verus::builtin::is_variant", VerusItem::Expr(ExprItem::IsVariant)), ("verus::builtin::strslice_len", VerusItem::Expr(ExprItem::StrSliceLen)), ("verus::builtin::strslice_get_char", VerusItem::Expr(ExprItem::StrSliceGetChar)), diff --git a/source/rust_verify_test/tests/adts.rs b/source/rust_verify_test/tests/adts.rs index 71a9822e60..e331bb65f5 100644 --- a/source/rust_verify_test/tests/adts.rs +++ b/source/rust_verify_test/tests/adts.rs @@ -345,6 +345,15 @@ test_verify_one_file! { } => Err(err) => assert_vir_error_msg(err, "no field `1` for this variant") } +test_verify_one_file! { + #[test] test_builtin_get_variant_field_invalid_3 IS_VARIANT_MAYBE.to_string() + verus_code_str! { + struct T { } + proof fn test_fail(tracked u: Maybe) { + let tracked j = get_variant_field::<_, T>(u, "Some", "0"); + } + } => Err(err) => assert_vir_error_msg(err, "expression has mode spec, expected mode proof") +} + test_verify_one_file! { #[test] test_is_variant_not_enum verus_code! { #[is_variant] diff --git a/source/rust_verify_test/tests/unions.rs b/source/rust_verify_test/tests/unions.rs new file mode 100644 index 0000000000..1a4fcb26f0 --- /dev/null +++ b/source/rust_verify_test/tests/unions.rs @@ -0,0 +1,407 @@ +#![feature(rustc_private)] +#[macro_use] +mod common; +use common::*; + +test_verify_one_file! { + #[test] union_basic verus_code! { + union U { x: u8, y: bool } + + fn test_ok() { + let u = U { x: 3 }; + + assert(is_variant(u, "x")); + assert(!is_variant(u, "y")); + assert(get_union_field::<_, u8>(u, "x") == 3); + + unsafe { + let j = u.x; + assert(j == 3); + } + } + + fn test_fail() { + let u = U { x: 3 }; + + unsafe { + let j = u.y; // FAILS + } + } + + fn test_fail2() { + let u = U { x: 3 }; + + unsafe { + proof { + let tracked j = &u.y; // FAILS + } + } + } + + fn test_fail3() { + let u = U { x: 3 }; + + unsafe { + proof { + let j = &u.y; // FAILS + } + } + } + + impl U { + fn test_self_ctor() { + let u = Self { x: 3 }; + assert(is_variant(u, "x")); + } + } + + type U2 = U; + + fn test_type_alias() { + let u = U2 { x: 3 }; + assert(is_variant(u, "x")); + } + } => Err(err) => assert_fails(err, 3) +} + +test_verify_one_file! { + #[test] union_pattern verus_code! { + union U { x: u8, y: bool } + + fn test_fail() { + let u = U { x: 3 }; + unsafe { + let U { x } = u; + } + } + } => Err(err) => assert_vir_error_msg(err, "The verifier does not yet support the following Rust feature: using a union here") +} + +test_verify_one_file! { + #[test] union_mut_assign verus_code! { + union U { x: u8, y: bool } + + fn test_fail() { + let mut u = U { x: 3 }; + unsafe { + u.x = 7; + } + } + } => Err(err) => assert_vir_error_msg(err, "The verifier does not yet support the following Rust feature: assigning to or taking &mut of a union field") +} + +test_verify_one_file! { + #[test] union_mut_ref verus_code! { + union U { x: u8, y: bool } + + fn take_mut_ref(x: &mut u8) { } + + fn test_fail() { + let mut u = U { x: 3 }; + unsafe { + take_mut_ref(&mut u.x); + } + } + } => Err(err) => assert_vir_error_msg(err, "The verifier does not yet support the following Rust feature: assigning to or taking &mut of a union field") +} + +test_verify_one_file! { + #[test] get_union_field_non_union verus_code! { + enum X { + Foo(u8), + Stuff(bool), + } + + fn test_fail(x: X) { + assert(get_union_field::<_, u8>(x, "Foo") == 5); + } + } => Err(err) => assert_vir_error_msg(err, "get_union_field expects a union type") +} + +test_verify_one_file! { + #[test] get_union_field_bad_field_name verus_code! { + union U { x: u8, y: bool } + + fn test_fail(u: U) { + assert(get_union_field::<_, u8>(u, "z") == 5); + } + } => Err(err) => assert_vir_error_msg(err, "no field `z` for this union") +} + +test_verify_one_file! { + #[test] get_union_field_bad_field_type verus_code! { + union U { x: u8, y: bool } + + fn test_fail(u: U) { + assert(get_union_field::<_, u16>(u, "x") == 5); + } + } => Err(err) => assert_vir_error_msg(err, "field has the wrong type") +} + +test_verify_one_file! { + #[test] get_union_field_exec_mode_fail verus_code! { + union U { x: u8, y: bool } + + fn test_fail(u: U) { + let j = get_union_field::<_, u8>(u, "x"); + } + } => Err(err) => assert_vir_error_msg(err, "cannot get variant in exec mode") +} + +test_verify_one_file! { + #[test] get_union_field_tracked_mode_fail verus_code! { + union U { x: u8, y: bool } + + proof fn test_fail(u: U) { + let tracked j = get_union_field::<_, u8>(u, "x"); + } + } => Err(err) => assert_vir_error_msg(err, "expression has mode spec, expected mode proof") +} + +test_verify_one_file! { + #[test] get_union_field_tracked_mode_fail2 verus_code! { + union U { x: u8, y: bool } + + proof fn test_fail(tracked u: U) { + let tracked j = get_union_field::<_, u8>(u, "x"); + } + } => Err(err) => assert_vir_error_msg(err, "expression has mode spec, expected mode proof") +} + +test_verify_one_file! { + #[test] union_generics verus_code! { + union U { x: A, y: B } + + fn test_ok() { + let u = U:: { x: 3 }; + + assert(is_variant(u, "x")); + assert(!is_variant(u, "y")); + assert(get_union_field::<_, u8>(u, "x") == 3); + + unsafe { + let j = u.x; + assert(j == 3); + } + } + + fn test_fail() { + let u = U:: { x: 3 }; + + unsafe { + let j = u.y; // FAILS + } + } + } => Err(err) => assert_fails(err, 1) +} + +test_verify_one_file! { + #[test] tracked_union_not_supported verus_code! { + tracked union U { x: A, y: B } + } => Err(err) => assert_vir_error_msg(err, "a 'union' can only be exec-mode") +} + +test_verify_one_file! { + #[test] ghost_union_not_supported verus_code! { + tracked union U { x: A, y: B } + } => Err(err) => assert_vir_error_msg(err, "a 'union' can only be exec-mode") +} + +test_verify_one_file! { + #[test] tracked_union_field_not_supported verus_code! { + union U { tracked x: A, y: B } + } => Err(err) => assert_vir_error_msg(err, "a union field can only be exec-mode") +} + +test_verify_one_file! { + #[test] ghost_union_field_not_supported verus_code! { + union U { ghost x: A, y: B } + } => Err(err) => assert_vir_error_msg(err, "a union field can only be exec-mode") +} + +test_verify_one_file! { + #[test] tracked_union_not_supported_attr verus_code! { + #[verifier::spec] union U { x: A, y: B } + } => Err(err) => assert_vir_error_msg(err, "a 'union' can only be exec-mode") +} + +test_verify_one_file! { + #[test] ghost_union_not_supported_attr verus_code! { + #[verifier::proof] union U { x: A, y: B } + } => Err(err) => assert_vir_error_msg(err, "a 'union' can only be exec-mode") +} + +test_verify_one_file! { + #[test] tracked_union_field_not_supported_attr verus_code! { + union U { #[verifier::proof] x: A, y: B } + } => Err(err) => assert_vir_error_msg(err, "a union field can only be exec-mode") +} + +test_verify_one_file! { + #[test] ghost_union_field_not_supported_attr verus_code! { + union U { #[verifier::spec] x: A, y: B } + } => Err(err) => assert_vir_error_msg(err, "a union field can only be exec-mode") +} + +test_verify_one_file! { + #[test] lifetime_union verus_code! { + use vstd::*; + use core::mem::ManuallyDrop; + struct X { } + struct Y { } + + union U { + x: ManuallyDrop, + y: ManuallyDrop, + } + + fn test(u: U) { + unsafe { + let t = u.x; + let s = u.x; + } + } + } => Err(err) => assert_rust_error_msg(err, "use of moved value: `u`") +} + +test_verify_one_file! { + #[test] lifetime_union2 verus_code! { + use vstd::*; + use core::mem::ManuallyDrop; + struct X { } + struct Y { } + + union U { + x: ManuallyDrop, + y: ManuallyDrop, + } + + fn test(u: U) { + unsafe { + let t = u.x; + + proof { + let tracked z = &u.x; + } + } + } + } => Err(err) => assert_vir_error_msg(err, "borrow of moved value: `u`") +} + +test_verify_one_file! { + #[test] union_proof_mode verus_code! { + union U { x: u8, y: bool } + + proof fn test_ok() { + let u = U { x: 3 }; + + assert(is_variant(u, "x")); + assert(!is_variant(u, "y")); + assert(get_union_field::<_, u8>(u, "x") == 3); + + unsafe { + let j = u.x; + assert(j == 3); + } + } + + proof fn test_fail(u: U) { + unsafe { + let j = u.y; // FAILS + } + } + + proof fn test_fail2(tracked u: U) { + unsafe { + let tracked j = &u.y; // FAILS + } + } + + proof fn test_fail3(u: U) { + unsafe { + let j = &u.y; // FAILS + } + } + } => Err(err) => assert_fails(err, 3) +} + +test_verify_one_file! { + #[test] union_mode_error verus_code! { + union U { x: u8, y: bool } + + proof fn test(u: U) { + let tracked x = &u.x; + } + } => Err(err) => assert_vir_error_msg(err, "expression has mode spec, expected mode proof") +} + +test_verify_one_file! { + #[test] union_field_access_in_spec_func verus_code! { + union U { x: u8, y: bool } + + spec fn test(u: U) -> u8 { + u.x + } + // This error messages could be more specific + } => Err(err) => assert_vir_error_msg(err, "expected pure mathematical expression") +} + +test_verify_one_file! { + #[ignore] #[test] is_variant verus_code! { + // TODO support is_variant for unions + + #[is_variant] + union U { x: u8, y: bool } + + fn test_ok() { + let u = U { x: 3 }; + + assert(u.is_x()); + assert(!u.is_y()); + assert(u.get_x() == 3); + + unsafe { + let j = u.x; + assert(j == 3); + } + } + } => Ok(()) +} + +test_verify_one_file! { + #[test] rec_types verus_code! { + use vstd::*; + use core::mem::ManuallyDrop; + + #[verifier::reject_recursive_types(T)] + struct X { + r: u64, + g: Ghost bool>, + } + + union U { + x: u8, + y: ManuallyDrop>, + } + } => Err(err) => assert_vir_error_msg(err, "non-positive position") +} + +test_verify_one_file! { + #[test] visibility verus_code! { + pub union U { x: u8, y: bool } + + pub open spec fn f(u: U) { + get_union_field::<_, u8>(u, "x"); + } + } => Err(err) => assert_vir_error_msg(err, "cannot access any field of a datatype where one or more fields are private") +} + +test_verify_one_file! { + #[test] visibility2 verus_code! { + pub union U { x: u8, pub y: bool } + + pub open spec fn f(b: bool) -> U { + U { y: b } + } + } => Err(err) => assert_vir_error_msg(err, "cannot use constructor of private datatype or datatype whose fields are private") +} diff --git a/source/vir/src/ast.rs b/source/vir/src/ast.rs index c2a0d4065d..9181af5a4c 100644 --- a/source/vir/src/ast.rs +++ b/source/vir/src/ast.rs @@ -256,12 +256,20 @@ pub enum UnaryOp { CharToInt, } +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord, ToDebugSNode)] +pub enum VariantCheck { + None, + //Recommends, + Yes, +} + #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord, ToDebugSNode)] pub struct FieldOpr { pub datatype: Path, pub variant: Ident, pub field: Ident, pub get_variant: bool, + pub check: VariantCheck, } #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord, ToDebugSNode)] diff --git a/source/vir/src/ast_simplify.rs b/source/vir/src/ast_simplify.rs index 550c626843..4fe01201c9 100644 --- a/source/vir/src/ast_simplify.rs +++ b/source/vir/src/ast_simplify.rs @@ -6,8 +6,8 @@ use crate::ast::{ AssocTypeImpl, AutospecUsage, BinaryOp, Binder, BuiltinSpecFun, CallTarget, ChainedOp, Constant, Datatype, DatatypeTransparency, DatatypeX, Expr, ExprX, Exprs, Field, FieldOpr, Function, FunctionKind, Ident, IntRange, ItemKind, Krate, KrateX, Mode, MultiOp, Path, Pattern, - PatternX, SpannedTyped, Stmt, StmtX, TraitImpl, Typ, TypX, UnaryOp, UnaryOpr, VirErr, - Visibility, + PatternX, SpannedTyped, Stmt, StmtX, TraitImpl, Typ, TypX, UnaryOp, UnaryOpr, VariantCheck, + VirErr, Visibility, }; use crate::ast_util::int_range_from_type; use crate::ast_util::is_integer_type; @@ -165,6 +165,7 @@ fn pattern_to_exprs_rec( variant: variant.clone(), field: prefix_tuple_field(i), get_variant: false, + check: VariantCheck::None, }); let field_exp = pattern_field_expr(&pattern.span, expr, &pat.typ, field_op); let pattern_test = pattern_to_exprs_rec(ctx, state, &field_exp, pat, decls)?; @@ -184,6 +185,7 @@ fn pattern_to_exprs_rec( variant: variant.clone(), field: binder.name.clone(), get_variant: false, + check: VariantCheck::None, }); let field_exp = pattern_field_expr(&pattern.span, expr, &binder.a.typ, field_op); let pattern_test = pattern_to_exprs_rec(ctx, state, &field_exp, &binder.a, decls)?; @@ -344,6 +346,7 @@ fn simplify_one_expr( variant: variant.clone(), field: field.name.clone(), get_variant: false, + check: VariantCheck::None, }); let exprx = ExprX::UnaryOpr(op, update.clone()); let ty = subst_typ_for_datatype(&typ_positives, typ_args, &field.a.0); @@ -533,7 +536,13 @@ fn tuple_get_field_expr( let datatype = state.tuple_type_name(tuple_arity); let variant = prefix_tuple_variant(tuple_arity); let field = prefix_tuple_field(field); - let op = UnaryOpr::Field(FieldOpr { datatype, variant, field, get_variant: false }); + let op = UnaryOpr::Field(FieldOpr { + datatype, + variant, + field, + get_variant: false, + check: VariantCheck::None, + }); let field_expr = SpannedTyped::new(span, typ, ExprX::UnaryOpr(op, tuple_expr.clone())); field_expr } diff --git a/source/vir/src/ast_to_sst.rs b/source/vir/src/ast_to_sst.rs index 3fb1cd2f09..e8a1fa9396 100644 --- a/source/vir/src/ast_to_sst.rs +++ b/source/vir/src/ast_to_sst.rs @@ -1,7 +1,7 @@ use crate::ast::{ ArithOp, AssertQueryMode, AutospecUsage, BinaryOp, BitwiseOp, CallTarget, ComputeMode, - Constant, Expr, ExprX, Fun, Function, Ident, LoopInvariantKind, Mode, PatternX, SpannedTyped, - Stmt, StmtX, Typ, TypX, Typs, UnaryOp, UnaryOpr, VarAt, VirErr, + Constant, Expr, ExprX, FieldOpr, Fun, Function, Ident, LoopInvariantKind, Mode, PatternX, + SpannedTyped, Stmt, StmtX, Typ, TypX, Typs, UnaryOp, UnaryOpr, VarAt, VariantCheck, VirErr, }; use crate::ast::{BuiltinSpecFun, Exprs}; use crate::ast_util::{types_equal, undecorate_typ, QUANT_FORALL}; @@ -1224,8 +1224,36 @@ pub(crate) fn expr_to_stm_opt( Ok((stms, ReturnValue::Some(mk_exp(ExpX::Unary(*op, exp))))) } ExprX::UnaryOpr(op, expr) => { - let (stms, exp) = expr_to_stm_opt(ctx, state, expr)?; + let (mut stms, exp) = expr_to_stm_opt(ctx, state, expr)?; let exp = unwrap_or_return_never!(exp, stms); + match (op, state.checking_recommends(ctx)) { + ( + UnaryOpr::Field(FieldOpr { + datatype, + variant, + field: _, + get_variant: _, + check: VariantCheck::Yes, + }), + false, + ) => { + let unary = UnaryOpr::IsVariant { + datatype: datatype.clone(), + variant: variant.clone(), + }; + let is_variant = ExpX::UnaryOpr(unary, exp.clone()); + let is_variant = + SpannedTyped::new(&expr.span, &Arc::new(TypX::Bool), is_variant); + let error = crate::messages::error( + &expr.span, + "requirement not met: to access this field, the union must be in the correct variant", + ); + let assert = StmX::Assert(Some(error), is_variant); + let assert = Spanned::new(expr.span.clone(), assert); + stms.push(assert); + } + _ => {} + } Ok((stms, ReturnValue::Some(mk_exp(ExpX::UnaryOpr(op.clone(), exp))))) } ExprX::Binary(op, e1, e2) => { diff --git a/source/vir/src/modes.rs b/source/vir/src/modes.rs index db4d0dfe2f..24ac52a941 100644 --- a/source/vir/src/modes.rs +++ b/source/vir/src/modes.rs @@ -325,7 +325,7 @@ fn get_var_loc_mode( *to_mode } ExprX::UnaryOpr( - UnaryOpr::Field(FieldOpr { datatype, variant: _, field, get_variant }), + UnaryOpr::Field(FieldOpr { datatype, variant: _, field, get_variant, check: _ }), rcvr, ) => { let rcvr_mode = @@ -637,7 +637,7 @@ fn check_expr_handle_mut_arg( return check_expr_handle_mut_arg(typing, outer_mode, e1); } ExprX::UnaryOpr( - UnaryOpr::Field(FieldOpr { datatype, variant, field, get_variant }), + UnaryOpr::Field(FieldOpr { datatype, variant, field, get_variant, check: _ }), e1, ) => { if *get_variant && typing.check_ghost_blocks && typing.block_ghostness == Ghost::Exec { @@ -647,7 +647,8 @@ fn check_expr_handle_mut_arg( let datatype = &typing.datatypes[datatype]; let field = get_field(&datatype.x.get_variant(variant).a, field); let field_mode = field.a.1; - let mode_read = mode_join(e1_mode_read, field_mode); + let mode_read = + if *get_variant { Mode::Spec } else { mode_join(e1_mode_read, field_mode) }; if let Some(e1_mode_write) = e1_mode_write { return Ok((mode_read, Some(mode_join(e1_mode_write, field_mode)))); } else { diff --git a/source/vir/src/poly.rs b/source/vir/src/poly.rs index 51c14517d6..ae2efc6fda 100644 --- a/source/vir/src/poly.rs +++ b/source/vir/src/poly.rs @@ -422,7 +422,13 @@ fn poly_expr(ctx: &Ctx, state: &mut State, expr: &Expr) -> Expr { let exprx = ExprX::UnaryOpr(op.clone(), e1.clone()); SpannedTyped::new(&e1.span, &e1.typ, exprx) } - UnaryOpr::Field(FieldOpr { datatype, variant, field, get_variant: _ }) => { + UnaryOpr::Field(FieldOpr { + datatype, + variant, + field, + get_variant: _, + check: _, + }) => { let fields = &ctx.datatype_map[datatype].x.get_variant(variant).a; let field = crate::ast_util::get_field(fields, field); diff --git a/source/vir/src/sst_to_air.rs b/source/vir/src/sst_to_air.rs index bb17b33132..b3705e2a23 100644 --- a/source/vir/src/sst_to_air.rs +++ b/source/vir/src/sst_to_air.rs @@ -1,7 +1,8 @@ use crate::ast::{ ArithOp, AssertQueryMode, BinaryOp, BitwiseOp, FieldOpr, Fun, Ident, Idents, InequalityOp, IntRange, IntegerTypeBoundKind, InvAtomicity, MaskSpec, Mode, Params, Path, PathX, Primitive, - SpannedTyped, Typ, TypDecoration, TypX, Typs, UnaryOp, UnaryOpr, VarAt, VirErr, Visibility, + SpannedTyped, Typ, TypDecoration, TypX, Typs, UnaryOp, UnaryOpr, VarAt, VariantCheck, VirErr, + Visibility, }; use crate::ast_util::{ bitwidth_from_type, fun_as_friendly_rust_name, get_field, get_variant, undecorate_typ, @@ -862,7 +863,7 @@ pub(crate) fn exp_to_expr(ctx: &Ctx, exp: &Exp, expr_ctxt: &ExprCtxt) -> Result< let name = Arc::new(ARCH_SIZE.to_string()); Arc::new(ExprX::Var(name)) } - UnaryOpr::Field(FieldOpr { datatype, variant, field, get_variant: _ }) => { + UnaryOpr::Field(FieldOpr { datatype, variant, field, get_variant: _, check: _ }) => { let expr = exp_to_expr(ctx, exp, expr_ctxt)?; Arc::new(ExprX::Apply( variant_field_ident(datatype, variant, field), @@ -1323,7 +1324,7 @@ fn assume_other_fields_unchanged_inner( [f] if f.len() == 0 => Ok(vec![]), _ => { let mut updated_fields: BTreeMap<_, Vec<_>> = BTreeMap::new(); - let FieldOpr { datatype, variant, field: _, get_variant: _ } = &updates[0][0]; + let FieldOpr { datatype, variant, field: _, get_variant: _, check: _ } = &updates[0][0]; for u in updates { assert!(u[0].datatype == *datatype && u[0].variant == *variant); updated_fields.entry(&u[0].field).or_insert(Vec::new()).push(u[1..].to_vec()); @@ -1353,6 +1354,7 @@ fn assume_other_fields_unchanged_inner( variant: variant.clone(), field: field.name.clone(), get_variant: false, + check: VariantCheck::None, }), base_exp, ), diff --git a/source/vir/src/triggers_auto.rs b/source/vir/src/triggers_auto.rs index 2f18766d28..3c3d1b6b92 100644 --- a/source/vir/src/triggers_auto.rs +++ b/source/vir/src/triggers_auto.rs @@ -359,7 +359,7 @@ fn gather_terms(ctxt: &mut Ctxt, ctx: &Ctx, exp: &Exp, depth: u64) -> (bool, Ter panic!("internal error: TupleField should have been removed before here") } ExpX::UnaryOpr( - UnaryOpr::Field(FieldOpr { datatype, variant, field, get_variant: _ }), + UnaryOpr::Field(FieldOpr { datatype, variant, field, get_variant: _, check: _ }), lhs, ) => { let (is_pure, arg) = gather_terms(ctxt, ctx, lhs, depth + 1); diff --git a/source/vir/src/well_formed.rs b/source/vir/src/well_formed.rs index 1f7102f412..5dc9f35716 100644 --- a/source/vir/src/well_formed.rs +++ b/source/vir/src/well_formed.rs @@ -268,7 +268,7 @@ fn check_one_expr( } } ExprX::UnaryOpr( - UnaryOpr::Field(FieldOpr { datatype: path, variant, field, get_variant: _ }), + UnaryOpr::Field(FieldOpr { datatype: path, variant, field, get_variant: _, check: _ }), _, ) => { if let Some(dt) = ctxt.dts.get(path) { From ede3736dd9dfa8a89a7977ed6a92f3d810801e7d Mon Sep 17 00:00:00 2001 From: Andrea Lattuada Date: Tue, 9 Jan 2024 12:20:38 +0100 Subject: [PATCH 05/11] Fix #955, fix #958: in lifetime_generate, associated type declarations need bounds --- source/rust_verify/src/lifetime_ast.rs | 2 +- source/rust_verify/src/lifetime_emit.rs | 18 ++++-- source/rust_verify/src/lifetime_generate.rs | 68 +++++++++++++++++---- source/rust_verify_test/tests/traits.rs | 40 ++++++++++++ 4 files changed, 111 insertions(+), 17 deletions(-) diff --git a/source/rust_verify/src/lifetime_ast.rs b/source/rust_verify/src/lifetime_ast.rs index ff109f6b7c..a88be390e5 100644 --- a/source/rust_verify/src/lifetime_ast.rs +++ b/source/rust_verify/src/lifetime_ast.rs @@ -158,7 +158,7 @@ pub(crate) struct TraitDecl { pub(crate) name: Id, pub(crate) generic_params: Vec, pub(crate) generic_bounds: Vec, - pub(crate) assoc_typs: Vec, + pub(crate) assoc_typs: Vec<(Id, Vec)>, } #[derive(Debug, PartialEq, Eq, Hash)] diff --git a/source/rust_verify/src/lifetime_emit.rs b/source/rust_verify/src/lifetime_emit.rs index d0d4f4058b..4bf6cfcb2e 100644 --- a/source/rust_verify/src/lifetime_emit.rs +++ b/source/rust_verify/src/lifetime_emit.rs @@ -670,10 +670,12 @@ fn emit_generic_params(state: &mut EmitState, generics: &Vec) { } } -fn emit_generic_bound(bound: &GenericBound) -> String { +fn emit_generic_bound(bound: &GenericBound, bare: bool) -> String { let mut buf = String::new(); - buf += &bound.typ.to_string(); - buf += ": "; + if !bare { + buf += &bound.typ.to_string(); + buf += ": "; + } if !bound.bound_vars.is_empty() { buf += "for<"; for b in bound.bound_vars.iter() { @@ -713,7 +715,7 @@ fn emit_generic_bounds(state: &mut EmitState, bounds: &Vec) { if bounds.len() > 0 { state.write(" where "); for bound in bounds.iter() { - state.write(emit_generic_bound(bound)); + state.write(emit_generic_bound(bound, false)); state.write(", "); } } @@ -837,10 +839,16 @@ pub(crate) fn emit_trait_decl(state: &mut EmitState, t: &TraitDecl) { emit_generic_bounds(state, &t.generic_bounds); state.write(" {"); state.push_indent(); - for a in &t.assoc_typs { + for (a, bounds) in &t.assoc_typs { state.newline(); state.write("type "); state.write(a.to_string()); + if bounds.len() > 0 { + state.write(" : "); + let bounds_strs: Vec<_> = + bounds.iter().map(|bound| emit_generic_bound(bound, true)).collect(); + state.write(bounds_strs.join("+")); + } state.write(";"); } state.newline_unindent(); diff --git a/source/rust_verify/src/lifetime_generate.rs b/source/rust_verify/src/lifetime_generate.rs index 2418ca63ce..9827d6f47f 100644 --- a/source/rust_verify/src/lifetime_generate.rs +++ b/source/rust_verify/src/lifetime_generate.rs @@ -97,6 +97,7 @@ pub(crate) struct State { // impl -> (t1, ..., tn) and process impl when t1...tn is empty remaining_typs_needed_for_each_impl: HashMap)>, enclosing_fun_id: Option, + enclosing_trait_ids: Vec, } impl State { @@ -122,6 +123,7 @@ impl State { typs_used_in_trait_impls_reverse_map: HashMap::new(), remaining_typs_needed_for_each_impl: HashMap::new(), enclosing_fun_id: None, + enclosing_trait_ids: Vec::new(), } } @@ -493,11 +495,19 @@ fn erase_ty<'tcx>(ctxt: &Context<'tcx>, state: &mut State, ty: &Ty<'tcx>) -> Typ let assoc_item = ctxt.tcx.associated_item(t.def_id); let name = state.typ_param(assoc_item.name.to_string(), None); let trait_def = ctxt.tcx.generics_of(t.def_id).parent; - match (trait_def, t.args.into_type_list(ctxt.tcx)) { - (Some(trait_def), typs) if typs.len() >= 1 => { + if let Some(trait_def) = trait_def { + // TODO: ignoring non-type arguments is probably not okay here + let typs = t.args.iter().filter_map(|ta| ta.as_type()).collect::>(); + if typs.len() >= 1 { let trait_path_vir = def_id_to_vir_path(ctxt.tcx, &ctxt.verus_items, trait_def); erase_trait(ctxt, state, trait_def); - assert!(state.trait_decl_set.contains(&trait_path_vir)); + // If the type being erased is in one of the definitions of the trait it references, + // do not expect it to be in the `trait_decl_set`: we are in the process of erasing + // this very trait. + assert!( + state.enclosing_trait_ids.contains(&trait_def) + || state.trait_decl_set.contains(&trait_path_vir) + ); let trait_path = state.trait_name(&trait_path_vir); let mut typs_iter = typs.iter(); let self_ty = typs_iter.next().unwrap(); @@ -510,8 +520,11 @@ fn erase_ty<'tcx>(ctxt: &Context<'tcx>, state: &mut State, ty: &Ty<'tcx>) -> Typ let trait_as_datatype = Box::new(TypX::Datatype(trait_path, Vec::new(), trait_typ_args)); Box::new(TypX::Projection { self_typ, trait_as_datatype, name }) + } else { + panic!("unexpected TyKind::Alias"); } - _ => panic!("unexpected TyKind::Alias"), + } else { + panic!("unexpected TyKind::Alias"); } } TyKind::Closure(..) => Box::new(TypX::Closure), @@ -1637,9 +1650,25 @@ fn erase_mir_generics<'tcx>( } } } + erase_mir_predicates( + ctxt, + state, + mir_predicates.predicates.iter().map(|(c, _)| *c), + generic_bounds, + ); +} + +fn erase_mir_predicates<'a, 'tcx>( + ctxt: &Context<'tcx>, + state: &'a mut State, + mir_predicates: impl Iterator>, + generic_bounds: &mut Vec, +) where + 'tcx: 'a, +{ let mut fn_traits: Vec<(Typ, Vec, ClosureKind)> = Vec::new(); let mut fn_projections: HashMap = HashMap::new(); - for (pred, _) in mir_predicates.predicates.iter() { + for pred in mir_predicates { match (pred.kind().skip_binder(), &pred.kind().bound_vars()[..]) { (ClauseKind::RegionOutlives(pred), &[]) => { let x = erase_hir_region(ctxt, state, &pred.0).expect("bound"); @@ -1996,20 +2025,35 @@ fn erase_trait<'tcx>(ctxt: &Context<'tcx>, state: &mut State, trait_id: DefId) { return; } } - let mut assoc_typs: Vec = Vec::new(); + + state.enclosing_trait_ids.push(trait_id); + + let path = def_id_to_vir_path(ctxt.tcx, &ctxt.verus_items, trait_id); + + let mut assoc_typs: Vec<(Id, Vec)> = Vec::new(); let assoc_items = ctxt.tcx.associated_items(trait_id); for assoc_item in assoc_items.in_definition_order() { match assoc_item.kind { rustc_middle::ty::AssocKind::Const => {} rustc_middle::ty::AssocKind::Fn => {} rustc_middle::ty::AssocKind::Type => { - assoc_typs.push(state.typ_param(assoc_item.name.to_ident_string(), None)); + let mir_predicates = ctxt.tcx.item_bounds(assoc_item.def_id); + let mut generic_bounds = Vec::new(); + erase_mir_predicates( + ctxt, + state, + mir_predicates.skip_binder().iter(), + &mut generic_bounds, + ); + assoc_typs.push(( + state.typ_param(assoc_item.name.to_ident_string(), None), + generic_bounds, + )); } } } // We only need traits with associated type declarations. if assoc_typs.len() > 0 { - let path = def_id_to_vir_path(ctxt.tcx, &ctxt.verus_items, trait_id); let name = state.trait_name(&path); let mut lifetimes: Vec = Vec::new(); let mut typ_params: Vec = Vec::new(); @@ -2050,9 +2094,11 @@ fn erase_trait<'tcx>(ctxt: &Context<'tcx>, state: &mut State, trait_id: DefId) { erase_impl_assocs(ctxt, state, impl_id); } } + + assert!(state.enclosing_trait_ids.pop().is_some()); } -fn erase_trait_methods<'tcx>( +fn erase_trait_item<'tcx>( krate: &'tcx Crate<'tcx>, ctxt: &mut Context<'tcx>, state: &mut State, @@ -2082,7 +2128,7 @@ fn erase_trait_methods<'tcx>( body_id, ); } - TraitItemKind::Type(_, None) => {} + TraitItemKind::Type(_bounds, None) => {} _ => panic!("unexpected trait item"), } } @@ -2516,7 +2562,7 @@ pub(crate) fn gen_check_tracked_lifetimes<'tcx>( if vattrs.is_external(&ctxt.cmd_line_args) { continue; } - erase_trait_methods(krate, &mut ctxt, &mut state, id, trait_items); + erase_trait_item(krate, &mut ctxt, &mut state, id, trait_items); } ItemKind::Impl(impll) => { if vattrs.is_external(&ctxt.cmd_line_args) { diff --git a/source/rust_verify_test/tests/traits.rs b/source/rust_verify_test/tests/traits.rs index 8870cac08b..cffe85e519 100644 --- a/source/rust_verify_test/tests/traits.rs +++ b/source/rust_verify_test/tests/traits.rs @@ -882,6 +882,46 @@ test_verify_one_file! { } => Err(err) => assert_vir_error_msg(err, "found a cyclic self-reference in a trait definition") } +test_verify_one_file! { + #[test] test_assoc_bounds_2_pass verus_code! { + trait Z { type Y; } + trait T { + type X: Z; + + fn val() -> ::Y; + } + struct ZZ { } + impl Z for ZZ { + type Y = u64; + } + struct TT { } + impl T for TT { + type X = ZZ; + + fn val() -> ::Y { + 3 + } + } + } => Ok(()) +} + +test_verify_one_file! { + #[test] test_assoc_bounds_regression_955 verus_code! { + use vstd::prelude::View; + + pub trait A { + type Input: View; + type Output: View; + } + + pub trait B { + type MyA: A; + + fn foo(input: ::Input) -> ::Output; + } + } => Ok(()) +} + test_verify_one_file! { #[test] test_termination_assoc_bounds_fail_3 verus_code! { trait Z { type Y; } From 1e62c4b74551b31d9a06831ff3c2e14fd40cf5dc Mon Sep 17 00:00:00 2001 From: Travis Hance Date: Thu, 18 Jan 2024 10:59:04 -0500 Subject: [PATCH 06/11] error for 'old' in exec-code, fixes #922 --- source/rust_verify_test/tests/modes.rs | 8 ++++++++ source/vir/src/modes.rs | 7 +++++++ 2 files changed, 15 insertions(+) diff --git a/source/rust_verify_test/tests/modes.rs b/source/rust_verify_test/tests/modes.rs index a855d883f0..d3048a6c19 100644 --- a/source/rust_verify_test/tests/modes.rs +++ b/source/rust_verify_test/tests/modes.rs @@ -1447,3 +1447,11 @@ test_verify_one_file! { } } => Ok(()) } + +test_verify_one_file! { + #[test] old_in_exec_mode_issue922 verus_code! { + fn stuff(x: &mut u8) { + let y = *old(x); + } + } => Err(err) => assert_vir_error_msg(err, "cannot use `old` in exec-code") +} diff --git a/source/vir/src/modes.rs b/source/vir/src/modes.rs index 24ac52a941..b6aca149de 100644 --- a/source/vir/src/modes.rs +++ b/source/vir/src/modes.rs @@ -412,6 +412,13 @@ fn check_expr_handle_mut_arg( let x_mode = typing.get(x).1; + if typing.check_ghost_blocks + && typing.block_ghostness == Ghost::Exec + && matches!(&expr.x, ExprX::VarAt(..)) + { + return Err(error(&expr.span, &format!("cannot use `old` in exec-code"))); + } + if typing.check_ghost_blocks && typing.block_ghostness == Ghost::Exec && x_mode != Mode::Exec From 5b406c2a704bd211eae470dfa311a1ca7db31c1e Mon Sep 17 00:00:00 2001 From: Travis Hance Date: Thu, 18 Jan 2024 11:26:19 -0500 Subject: [PATCH 07/11] nicer error message on crate deserialization failure --- source/rust_verify/src/import_export.rs | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/source/rust_verify/src/import_export.rs b/source/rust_verify/src/import_export.rs index f2aadd611c..55a8bb196e 100644 --- a/source/rust_verify/src/import_export.rs +++ b/source/rust_verify/src/import_export.rs @@ -39,8 +39,14 @@ pub(crate) fn import_crates(args: &Args) -> Result { )); } }); - let CrateWithMetadata { krate, metadata } = - bincode::deserialize_from(file).expect("read crate from file"); + let CrateWithMetadata { krate, metadata } = match bincode::deserialize_from(file) { + Ok(crate_with_metadata) => crate_with_metadata, + Err(_e) => { + return Err(crate::util::error(format!( + "failed to deserialize imported library file `{file_path}` - it may need to be rebuilt by Verus" + ))); + } + }; // let libcrate: Krate = serde_json::from_reader(file).expect("read crate from file"); // We can also look at other packages: https://github.com/djkoloski/rust_serialization_benchmark vir_crates.push(krate); From 5a2358895bed13b79f37a488a688726106f6461a Mon Sep 17 00:00:00 2001 From: Chris Hawblitzel Date: Fri, 19 Jan 2024 18:55:33 -0800 Subject: [PATCH 08/11] Improve error messages for detected cycles (print one cycle, not whole SCC) --- source/rust_verify_test/tests/traits.rs | 13 ++++++++++ source/vir/src/func_to_air.rs | 2 +- source/vir/src/recursive_types.rs | 32 ++++++++++++++----------- source/vir/src/scc.rs | 22 ++++++++--------- 4 files changed, 43 insertions(+), 26 deletions(-) diff --git a/source/rust_verify_test/tests/traits.rs b/source/rust_verify_test/tests/traits.rs index cffe85e519..2fee78b8f2 100644 --- a/source/rust_verify_test/tests/traits.rs +++ b/source/rust_verify_test/tests/traits.rs @@ -667,6 +667,19 @@ test_verify_one_file! { } => Err(err) => assert_vir_error_msg(err, "found a cyclic self-reference in a trait definition") } +test_verify_one_file! { + #[test] test_termination_5_fail_8 verus_code! { + trait T { type A: T; } + } => Err(err) => assert_vir_error_msg(err, "found a cyclic self-reference in a trait definition") +} + +test_verify_one_file! { + #[test] test_termination_5_fail_9 verus_code! { + trait T1 { type A: T2; } + trait T2 { type A: T1; } + } => Err(err) => assert_vir_error_msg(err, "found a cyclic self-reference in a trait definition") +} + test_verify_one_file! { #[ignore] #[test] test_termination_bounds_1 verus_code! { trait T { diff --git a/source/vir/src/func_to_air.rs b/source/vir/src/func_to_air.rs index d54e241f53..eb4495e91b 100644 --- a/source/vir/src/func_to_air.rs +++ b/source/vir/src/func_to_air.rs @@ -298,7 +298,7 @@ fn func_body_to_air( // Example: f calls g, g calls f, so shortest cycle f --> g --> f has len 2 // We use this as the minimum default fuel for f let fun_node = crate::recursion::Node::Fun(function.x.name.clone()); - let cycle_len = ctx.global.func_call_graph.shortest_cycle_back_to_self(&fun_node); + let cycle_len = ctx.global.func_call_graph.shortest_cycle_back_to_self(&fun_node).len(); assert!(cycle_len >= 1); let rec_f = suffix_global_id(&fun_to_air_ident(&prefix_recursive_fun(&name))); diff --git a/source/vir/src/recursive_types.rs b/source/vir/src/recursive_types.rs index e58df2366c..d00c8f65af 100644 --- a/source/vir/src/recursive_types.rs +++ b/source/vir/src/recursive_types.rs @@ -207,7 +207,7 @@ fn check_positive_uses( let impl_node = TypNode::TraitImpl(impl_path.clone()); if global.type_graph.in_same_scc(&impl_node, &my_node) { let scc_rep = global.type_graph.get_scc_rep(&my_node); - let scc_nodes = global.type_graph.get_scc_nodes(&scc_rep); + let scc_nodes = global.type_graph.shortest_cycle_back_to_self(&scc_rep); return Err(type_scc_error(&global.krate, &my_node, &scc_nodes)); } } @@ -341,7 +341,11 @@ pub(crate) fn check_recursive_types(krate: &Krate) -> Result<(), VirErr> { for node in &global.type_graph.get_scc_nodes(scc) { match node { TypNode::TraitImpl(_) if global.type_graph.node_is_in_cycle(node) => { - return Err(type_scc_error(krate, node, &global.type_graph.get_scc_nodes(scc))); + return Err(type_scc_error( + krate, + node, + &global.type_graph.shortest_cycle_back_to_self(scc), + )); } _ => {} } @@ -378,14 +382,13 @@ fn type_scc_error(krate: &Krate, head: &TypNode, nodes: &Vec) -> VirErr "found a cyclic self-reference in a trait definition, which may result in nontermination" .to_string(); let mut err = crate::messages::error_bare(msg); - for node in nodes { + for (i, node) in nodes.iter().enumerate() { let mut push = |node: &TypNode, span: Span| { if node == head { err = err.primary_span(&span); - } else { - let msg = "may be part of cycle".to_string(); - err = err.secondary_label(&span, msg); } + let msg = format!("may be part of cycle (node {} of {} in cycle)", i + 1, nodes.len()); + err = err.secondary_label(&span, msg); }; match node { TypNode::Datatype(path) => { @@ -410,14 +413,13 @@ fn scc_error(krate: &Krate, head: &Node, nodes: &Vec) -> VirErr { "found a cyclic self-reference in a trait definition, which may result in nontermination" .to_string(); let mut err = crate::messages::error_bare(msg); - for node in nodes { + for (i, node) in nodes.iter().enumerate() { let mut push = |node: &Node, span: Span| { if node == head { err = err.primary_span(&span); - } else { - let msg = "may be part of cycle".to_string(); - err = err.secondary_label(&span, msg); } + let msg = format!("may be part of cycle (node {} of {} in cycle)", i + 1, nodes.len()); + err = err.secondary_label(&span, msg); }; match node { Node::Fun(fun) => { @@ -488,7 +490,6 @@ pub(crate) fn add_trait_to_graph(call_graph: &mut Graph, trt: &Trait) { let t_node = Node::Trait(t_path.clone()); for bound in trt.x.typ_bounds.iter().chain(trt.x.assoc_typs_bounds.iter()) { let GenericBoundX::Trait(u_path, _) = &**bound; - assert_ne!(t_path, u_path); let u_node = Node::Trait(u_path.clone()); call_graph.add_edge(t_node.clone(), u_node); } @@ -640,13 +641,16 @@ pub fn check_traits(krate: &Krate, ctx: &GlobalCtx) -> Result<(), VirErr> { // 2) Check function definitions using value dictionaries for scc in ctx.func_call_sccs.iter() { - let scc_nodes = ctx.func_call_graph.get_scc_nodes(scc); - for node in scc_nodes.iter() { + for node in ctx.func_call_graph.get_scc_nodes(scc).iter() { match node { // handled by decreases checking: Node::Fun(_) => {} _ if ctx.func_call_graph.node_is_in_cycle(node) => { - return Err(scc_error(krate, node, &scc_nodes)); + return Err(scc_error( + krate, + node, + &ctx.func_call_graph.shortest_cycle_back_to_self(node), + )); } _ => {} } diff --git a/source/vir/src/scc.rs b/source/vir/src/scc.rs index 5b2dadf959..a93c2e40ef 100644 --- a/source/vir/src/scc.rs +++ b/source/vir/src/scc.rs @@ -223,31 +223,31 @@ impl Graph { scc.nodes.iter().map(|i| self.nodes[*i].t.clone()).collect() } - pub fn shortest_cycle_back_to_self(&self, t: &T) -> usize { + pub fn shortest_cycle_back_to_self(&self, t: &T) -> Vec { assert!(self.has_run); assert!(self.h.contains_key(&t)); let root: NodeIndex = *self.h.get(t).expect("key not present"); - let mut at_depth: Vec = vec![root]; + let mut paths_at_depth: Vec> = vec![vec![root]]; let mut reached: HashSet = HashSet::new(); reached.insert(root); - let mut depth: usize = 1; loop { - let mut at_next_depth: Vec = Vec::new(); - assert!(at_depth.len() != 0); - for i in at_depth.into_iter() { - for edge in self.nodes[i].edges.iter() { + let mut paths_at_next_depth: Vec> = Vec::new(); + assert!(paths_at_depth.len() != 0); + for p in paths_at_depth.into_iter() { + for edge in self.nodes[*p.last().expect("path")].edges.iter() { if *edge == root { // reached the root, found cycle - return depth; + return p.into_iter().map(|i| self.nodes[i].t.clone()).collect(); } if !reached.contains(edge) { reached.insert(*edge); - at_next_depth.push(*edge); + let mut p_edge = p.clone(); + p_edge.push(*edge); + paths_at_next_depth.push(p_edge); } } } - depth += 1; - at_depth = at_next_depth; + paths_at_depth = paths_at_next_depth; } } } From f74c5d43a99039d830d9dce404c04ae1958eb3cc Mon Sep 17 00:00:00 2001 From: tjhance Date: Sun, 21 Jan 2024 01:31:51 -0500 Subject: [PATCH 09/11] tweak broadcast heuristic (#965) --- .../tests/broadcast_forall.rs | 46 +++++++++++++++++++ source/vir/src/context.rs | 2 +- 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/source/rust_verify_test/tests/broadcast_forall.rs b/source/rust_verify_test/tests/broadcast_forall.rs index a1ca01b704..0dcdfdc5be 100644 --- a/source/rust_verify_test/tests/broadcast_forall.rs +++ b/source/rust_verify_test/tests/broadcast_forall.rs @@ -78,3 +78,49 @@ test_verify_one_file! { } } => Err(err) => assert_vir_error_msg(err, "cannot recursively reveal broadcast_forall") } + +test_verify_one_file! { + #[test] test_sm verus_code! { + // This tests the fix for an issue with the heuristic for pushing broadcast_forall + // functions to the front. + // Specifically, the state_machine! macro generates some external_body functions + // which got pushed to the front by those heurstics. But those external_body functions + // depended on the the proof fn `stuff_inductive` (via the extra_dependencies mechanism) + // This caused the `stuff_inductive` to end up BEFORE the broadcast_forall function + // it needed. + + use vstd::*; + use state_machines_macros::*; + + pub spec fn f() -> bool; + + #[verifier::external_body] + #[verifier::broadcast_forall] + proof fn f_is_true() + ensures f(), + { + } + + state_machine!{ X { + fields { + pub a: u8, + } + + transition!{ + stuff() { + update a = 5; + } + } + + #[invariant] + pub spec fn inv(&self) -> bool { + true + } + + #[inductive(stuff)] + fn stuff_inductive(pre: Self, post: Self) { + assert(f()); + } + }} + } => Ok(()) +} diff --git a/source/vir/src/context.rs b/source/vir/src/context.rs index 5c0cc35340..9bae901d58 100644 --- a/source/vir/src/context.rs +++ b/source/vir/src/context.rs @@ -218,7 +218,7 @@ impl GlobalCtx { // This is currently needed because external_body broadcast_forall functions // are currently implicitly imported. // In the future, this might become less important; we could remove this heuristic. - if f.x.body.is_none() { + if f.x.body.is_none() && f.x.extra_dependencies.len() == 0 { func_call_graph.add_node(Node::Fun(f.x.name.clone())); } } From 52ae96fd1234a5c6a3841e147d4196911e547b2f Mon Sep 17 00:00:00 2001 From: Travis Hance Date: Sun, 21 Jan 2024 01:43:00 -0500 Subject: [PATCH 10/11] cleanup body_stm_to_air parameters - package post-condition-related stuff into a struct - apply trait_typ_substs before calling body_stm_to_air --- source/vir/src/func_to_air.rs | 41 +++++++++++++++------------ source/vir/src/recursion.rs | 21 +++++++------- source/vir/src/sst_to_air.rs | 52 ++++++++++++++--------------------- 3 files changed, 56 insertions(+), 58 deletions(-) diff --git a/source/vir/src/func_to_air.rs b/source/vir/src/func_to_air.rs index eb4495e91b..31c2f0c2f9 100644 --- a/source/vir/src/func_to_air.rs +++ b/source/vir/src/func_to_air.rs @@ -15,9 +15,11 @@ use crate::def::{ use crate::inv_masks::MaskSet; use crate::messages::{error, Message, MessageLabel, Span}; use crate::sst::{BndX, Exp, ExpX, Par, ParPurpose, ParX, Pars, Stm, StmX}; +use crate::sst_to_air::PostConditionSst; use crate::sst_to_air::{ exp_to_expr, fun_to_air_ident, typ_invariant, typ_to_air, typ_to_ids, ExprCtxt, ExprMode, }; +use crate::sst_util::{subst_exp, subst_stm}; use crate::update_cell::UpdateCell; use crate::util::vec_map; use air::ast::{ @@ -885,19 +887,22 @@ pub fn func_def_to_air( for e in req_ens_function.x.require.iter() { let e_with_req_ens_params = map_expr_rename_vars(e, &req_ens_e_rename)?; if ctx.checking_spec_preconditions() { + // TODO: apply trait_typs_substs here? let (stms, exp) = crate::ast_to_sst::expr_to_pure_exp_check(ctx, &mut state, &e_with_req_ens_params)?; req_stms.extend(stms); req_stms.push(Spanned::new(exp.span.clone(), StmX::Assume(exp))); } else { // skip checks because we call expr_to_pure_exp_check above - reqs.push(crate::ast_to_sst::expr_to_exp_skip_checks( + let exp = crate::ast_to_sst::expr_to_exp_skip_checks( ctx, diagnostics, &state.fun_ssts, &req_pars, &e_with_req_ens_params, - )?); + )?; + let exp = subst_exp(&trait_typ_substs, &HashMap::new(), &exp); + reqs.push(exp); } } @@ -944,25 +949,28 @@ pub fn func_def_to_air( } let mut ens_spec_precondition_stms: Vec = Vec::new(); let mut enss: Vec = Vec::new(); - let mut enss_inherit: Vec = Vec::new(); if inherit { for e in req_ens_function.x.ensure.iter() { let e_with_req_ens_params = map_expr_rename_vars(e, &req_ens_e_rename)?; if ctx.checking_spec_preconditions() { - ens_spec_precondition_stms.extend(crate::ast_to_sst::check_pure_expr( - ctx, - &mut state, - &e_with_req_ens_params, - )?); + let stms = + crate::ast_to_sst::check_pure_expr(ctx, &mut state, &e_with_req_ens_params)?; + let stms: Vec<_> = stms + .iter() + .map(|stm| subst_stm(&trait_typ_substs, &HashMap::new(), &stm)) + .collect(); + ens_spec_precondition_stms.extend(stms); } else { // skip checks because we call expr_to_pure_exp_check above - enss_inherit.push(crate::ast_to_sst::expr_to_exp_skip_checks( + let exp = crate::ast_to_sst::expr_to_exp_skip_checks( ctx, diagnostics, &state.fun_ssts, &ens_pars, &e_with_req_ens_params, - )?); + )?; + let exp = subst_exp(&trait_typ_substs, &HashMap::new(), &exp); + enss.push(exp); } } } @@ -981,7 +989,6 @@ pub fn func_def_to_air( )?); } } - let enss = Arc::new(enss); // AST --> SST let mut stm = crate::ast_to_sst::expr_to_one_stm_with_post(&ctx, &mut state, &body)?; @@ -1042,22 +1049,22 @@ pub fn func_def_to_air( let (commands, snap_map) = crate::sst_to_air::body_stm_to_air( ctx, &function.span, - &trait_typ_substs, &function.x.typ_params, &function.x.params, &state.local_decls, &function.x.attrs.hidden, &reqs, - &enss, - &enss_inherit, - &ens_spec_precondition_stms, + &PostConditionSst { + dest, + ens_exps: enss, + ens_spec_precondition_stms, + kind: PostConditionKind::Ensures, + }, &mask_set, &stm, function.x.attrs.integer_ring, function.x.attrs.bit_vector, function.x.attrs.nonlinear, - dest, - PostConditionKind::Ensures, &state.statics.iter().cloned().collect(), )?; diff --git a/source/vir/src/recursion.rs b/source/vir/src/recursion.rs index fdbe192cbb..8db7636dee 100644 --- a/source/vir/src/recursion.rs +++ b/source/vir/src/recursion.rs @@ -18,6 +18,7 @@ use crate::sst::{ UniqueIdent, }; use crate::sst_to_air::PostConditionKind; +use crate::sst_to_air::PostConditionSst; use crate::sst_visitor::{exp_rename_vars, map_exp_visitor, map_stm_visitor}; use crate::util::vec_map_result; use air::ast::Binder; @@ -286,26 +287,26 @@ pub(crate) fn check_termination_commands( let (commands, _snap_map) = crate::sst_to_air::body_stm_to_air( ctx, &function.span, - &HashMap::new(), &function.x.typ_params, &function.x.params, &Arc::new(local_decls), &Arc::new(vec![]), &Arc::new(vec![]), - &Arc::new(vec![]), - &Arc::new(vec![]), - &Arc::new(vec![]), + &PostConditionSst { + dest: None, + kind: if uses_decreases_by { + PostConditionKind::DecreasesBy + } else { + PostConditionKind::DecreasesImplicitLemma + }, + ens_exps: vec![], + ens_spec_precondition_stms: vec![], + }, &MaskSet::empty(), &stm_block, false, false, false, - None, - if uses_decreases_by { - PostConditionKind::DecreasesBy - } else { - PostConditionKind::DecreasesImplicitLemma - }, &vec![], )?; diff --git a/source/vir/src/sst_to_air.rs b/source/vir/src/sst_to_air.rs index b3705e2a23..8182a7b3a7 100644 --- a/source/vir/src/sst_to_air.rs +++ b/source/vir/src/sst_to_air.rs @@ -29,7 +29,6 @@ use crate::sst::{ BndInfo, BndInfoUser, BndX, CallFun, Dest, Exp, ExpX, InternalFun, LocalDecl, Stm, StmX, UniqueIdent, }; -use crate::sst_util::{subst_exp, subst_stm}; use crate::sst_vars::{get_loc_var, AssignMap}; use crate::util::{vec_map, vec_map_result}; use air::ast::{ @@ -42,7 +41,7 @@ use air::ast_util::{ mk_option_command, mk_or, mk_sub, mk_xor, str_apply, str_ident, str_typ, str_var, string_var, }; use num_bigint::BigInt; -use std::collections::{BTreeMap, HashMap, HashSet}; +use std::collections::{BTreeMap, HashSet}; use std::mem::swap; use std::sync::Arc; @@ -1137,12 +1136,25 @@ pub(crate) fn exp_to_expr(ctx: &Ctx, exp: &Exp, expr_ctxt: &ExprCtxt) -> Result< Ok(result) } -pub(crate) enum PostConditionKind { +#[derive(Clone, Copy)] +pub enum PostConditionKind { Ensures, DecreasesImplicitLemma, DecreasesBy, } +pub struct PostConditionSst { + /// Identifier that holds the return value. + /// May be referenced by `ens_exprs` or `ens_spec_precondition_stms`. + pub dest: Option, + /// Post-conditions (only used in non-recommends-checking mode) + pub ens_exps: Vec, + /// Recommends checks (only used in recommends-checking mode) + pub ens_spec_precondition_stms: Vec, + /// Extra info about PostCondition for error reporting + pub kind: PostConditionKind, +} + struct PostConditionInfo { /// Identifier that holds the return value. /// May be referenced by `ens_exprs` or `ens_spec_precondition_stms`. @@ -2218,22 +2230,17 @@ fn mk_static_prelude(ctx: &Ctx, statics: &Vec) -> Vec { pub(crate) fn body_stm_to_air( ctx: &Ctx, func_span: &Span, - trait_typ_substs: &HashMap, typ_params: &Idents, params: &Params, local_decls: &Vec, hidden: &Vec, reqs: &Vec, - enss: &Vec, - inherit_enss: &Vec, - ens_spec_precondition_stms: &Vec, + post_condition: &PostConditionSst, mask_set: &MaskSet, stm: &Stm, is_integer_ring: bool, is_bit_vector_mode: bool, is_nonlinear: bool, - dest: Option, - post_condition_kind: PostConditionKind, statics: &Vec, ) -> Result<(Vec, Vec<(Span, SnapPos)>), VirErr> { // Verifying a single function can generate multiple SMT queries. @@ -2287,7 +2294,7 @@ pub(crate) fn body_stm_to_air( let initial_sid = Arc::new("0_entry".to_string()); let mut ens_exprs: Vec<(Span, Expr)> = Vec::new(); - for ens in enss { + for ens in post_condition.ens_exps.iter() { let e = if is_bit_vector_mode { let bv_expr_ctxt = &BvExprCtxt::new(); bv_exp_to_expr(ctx, &ens, bv_expr_ctxt)? @@ -2297,22 +2304,6 @@ pub(crate) fn body_stm_to_air( }; ens_exprs.push((ens.span.clone(), e)); } - for ens in inherit_enss { - let ens = subst_exp(&trait_typ_substs, &HashMap::new(), ens); - let e = if is_bit_vector_mode { - let bv_expr_ctxt = &BvExprCtxt::new(); - bv_exp_to_expr(ctx, &ens, bv_expr_ctxt)? - } else { - let expr_ctxt = &ExprCtxt::new_mode(ExprMode::Body); - exp_to_expr(ctx, &ens, expr_ctxt)? - }; - ens_exprs.push((ens.span.clone(), e)); - } - - let ens_spec_precondition_stms: Vec<_> = ens_spec_precondition_stms - .iter() - .map(|ens_recommend_stm| subst_stm(&trait_typ_substs, &HashMap::new(), ens_recommend_stm)) - .collect(); let mut may_be_used_in_old = HashSet::::new(); for param in params.iter() { @@ -2332,10 +2323,10 @@ pub(crate) fn body_stm_to_air( assign_map: indexmap::IndexMap::new(), mask: mask_set.clone(), post_condition_info: PostConditionInfo { - dest, + dest: post_condition.dest.clone(), ens_exprs, - ens_spec_precondition_stms: ens_spec_precondition_stms.clone(), - kind: post_condition_kind, + ens_spec_precondition_stms: post_condition.ens_spec_precondition_stms.clone(), + kind: post_condition.kind, }, loop_infos: Vec::new(), static_prelude: mk_static_prelude(ctx, statics), @@ -2386,7 +2377,6 @@ pub(crate) fn body_stm_to_air( } for req in reqs { - let req = subst_exp(&trait_typ_substs, &HashMap::new(), req); let e = if is_bit_vector_mode { let bv_expr_ctxt = &BvExprCtxt::new(); bv_exp_to_expr(ctx, &req, bv_expr_ctxt)? @@ -2419,7 +2409,7 @@ pub(crate) fn body_stm_to_air( let assert_stm = Arc::new(StmtX::Assert(error, air_expr)); singular_stmts.push(assert_stm); } - for ens in enss { + for ens in post_condition.ens_exps.iter() { let error = error_with_label( &ens.span, "Failed to translate this expression into a singular query".to_string(), From 9e351c871e99a8df091d869ca0fa27311d715884 Mon Sep 17 00:00:00 2001 From: Chris Hawblitzel Date: Tue, 23 Jan 2024 12:48:13 -0800 Subject: [PATCH 11/11] Add more information to cycle detection error messages --- source/rust_verify/src/rust_to_vir.rs | 5 +- source/vir/src/ast.rs | 2 +- source/vir/src/context.rs | 19 ++++- source/vir/src/recursion.rs | 14 +++- source/vir/src/recursive_types.rs | 109 +++++++++++++++++++++----- 5 files changed, 121 insertions(+), 28 deletions(-) diff --git a/source/rust_verify/src/rust_to_vir.rs b/source/rust_verify/src/rust_to_vir.rs index de614dca1e..5460899c41 100644 --- a/source/rust_verify/src/rust_to_vir.rs +++ b/source/rust_verify/src/rust_to_vir.rs @@ -344,6 +344,7 @@ fn check_item<'tcx>( )?); } let types = Arc::new(types); + let path_span = path.span.to(impll.self_ty.span); let path = def_id_to_vir_path(ctxt.tcx, &ctxt.verus_items, path.res.def_id()); let (typ_params, typ_bounds) = crate::rust_to_vir_base::check_generics_bounds_fun( ctxt.tcx, @@ -358,7 +359,7 @@ fn check_item<'tcx>( typ_bounds, trait_path: path.clone(), trait_typ_args: types.clone(), - trait_typ_arg_impls: impl_paths, + trait_typ_arg_impls: ctxt.spanned_new(path_span, impl_paths), }; vir.trait_impls.push(ctxt.spanned_new(item.span, trait_impl)); Some((trait_ref, path, types)) @@ -498,7 +499,7 @@ fn check_item<'tcx>( typ, impl_paths: Arc::new(impl_paths), }; - vir.assoc_type_impls.push(ctxt.spanned_new(item.span, assocx)); + vir.assoc_type_impls.push(ctxt.spanned_new(impl_item.span, assocx)); } else { unsupported_err!( item.span, diff --git a/source/vir/src/ast.rs b/source/vir/src/ast.rs index 9181af5a4c..b2c8c72e4b 100644 --- a/source/vir/src/ast.rs +++ b/source/vir/src/ast.rs @@ -988,7 +988,7 @@ pub struct TraitImplX { pub typ_bounds: GenericBounds, pub trait_path: Path, pub trait_typ_args: Typs, - pub trait_typ_arg_impls: ImplPaths, + pub trait_typ_arg_impls: Arc>, } #[derive(Clone, Debug, Hash, Serialize, Deserialize, ToDebugSNode, PartialEq, Eq)] diff --git a/source/vir/src/context.rs b/source/vir/src/context.rs index 9bae901d58..f8e682e9de 100644 --- a/source/vir/src/context.rs +++ b/source/vir/src/context.rs @@ -41,6 +41,7 @@ pub struct GlobalCtx { pub func_call_graph: Arc>, pub func_call_sccs: Arc>, pub(crate) datatype_graph: Arc>, + pub(crate) datatype_graph_span_infos: Vec, /// Connects quantifier identifiers to the original expression pub qid_map: RefCell>, pub(crate) rlimit: f32, @@ -256,14 +257,24 @@ impl GlobalCtx { func_call_graph.add_node(Node::TraitImpl(t.x.impl_path.clone())); } + let mut span_infos: Vec = Vec::new(); for t in &krate.trait_impls { - crate::recursive_types::add_trait_impl_to_graph(&mut func_call_graph, t); + crate::recursive_types::add_trait_impl_to_graph( + &mut span_infos, + &mut func_call_graph, + t, + ); } for f in &krate.functions { fun_bounds.insert(f.x.name.clone(), f.x.typ_bounds.clone()); func_call_graph.add_node(Node::Fun(f.x.name.clone())); - crate::recursion::expand_call_graph(&func_map, &mut func_call_graph, f)?; + crate::recursion::expand_call_graph( + &func_map, + &mut func_call_graph, + &mut span_infos, + f, + )?; } func_call_graph.compute_sccs(); @@ -292,7 +303,7 @@ impl GlobalCtx { } let qid_map = RefCell::new(HashMap::new()); - let datatype_graph = crate::recursive_types::build_datatype_graph(krate); + let datatype_graph = crate::recursive_types::build_datatype_graph(krate, &mut span_infos); Ok(GlobalCtx { chosen_triggers, @@ -302,6 +313,7 @@ impl GlobalCtx { func_call_graph: Arc::new(func_call_graph), func_call_sccs: Arc::new(func_call_sccs), datatype_graph: Arc::new(datatype_graph), + datatype_graph_span_infos: span_infos, qid_map, rlimit, interpreter_log, @@ -322,6 +334,7 @@ impl GlobalCtx { no_span: self.no_span.clone(), func_call_graph: self.func_call_graph.clone(), datatype_graph: self.datatype_graph.clone(), + datatype_graph_span_infos: self.datatype_graph_span_infos.clone(), func_call_sccs: self.func_call_sccs.clone(), qid_map, rlimit: self.rlimit, diff --git a/source/vir/src/recursion.rs b/source/vir/src/recursion.rs index 8db7636dee..f30cd93949 100644 --- a/source/vir/src/recursion.rs +++ b/source/vir/src/recursion.rs @@ -33,6 +33,9 @@ pub enum Node { Datatype(Path), Trait(Path), TraitImpl(Path), + // This is used to replace an X --> Y edge with X --> SpanInfo --> Y edges + // to give more precise span information than X or Y alone provide + SpanInfo { span_infos_index: usize, text: String }, } #[derive(Clone)] @@ -420,6 +423,7 @@ pub(crate) fn check_termination_stm( pub(crate) fn expand_call_graph( func_map: &HashMap, call_graph: &mut Graph, + span_infos: &mut Vec, function: &Function, ) -> Result<(), VirErr> { // See recursive_types::check_traits for more documentation @@ -488,7 +492,15 @@ pub(crate) fn expand_call_graph( continue; } } - call_graph.add_edge(f_node.clone(), Node::TraitImpl(impl_path.clone())); + let expr_node = crate::recursive_types::new_span_info_node( + span_infos, + expr.span.clone(), + ": application of a function to some type arguments, which may depend on \ + other trait implementations to satisfy trait bounds" + .to_string(), + ); + call_graph.add_edge(f_node.clone(), expr_node.clone()); + call_graph.add_edge(expr_node.clone(), Node::TraitImpl(impl_path.clone())); } // f --> f2 diff --git a/source/vir/src/recursive_types.rs b/source/vir/src/recursive_types.rs index d00c8f65af..5df9a39db5 100644 --- a/source/vir/src/recursive_types.rs +++ b/source/vir/src/recursive_types.rs @@ -126,12 +126,16 @@ fn check_well_founded_typ( pub(crate) enum TypNode { Datatype(Path), TraitImpl(Path), + // This is used to replace an X --> Y edge with X --> SpanInfo --> Y edges + // to give more precise span information than X or Y alone provide + SpanInfo { span_infos_index: usize, text: String }, } struct CheckPositiveGlobal { krate: Krate, datatypes: HashMap, type_graph: Graph, + span_infos: Vec, } struct CheckPositiveLocal { @@ -140,6 +144,18 @@ struct CheckPositiveLocal { tparams: HashMap, } +pub(crate) fn new_span_info_node(span_infos: &mut Vec, span: Span, text: String) -> Node { + let node = Node::SpanInfo { span_infos_index: span_infos.len(), text }; + span_infos.push(span); + node +} + +fn new_span_info_typ_node(span_infos: &mut Vec, span: Span, text: String) -> TypNode { + let node = TypNode::SpanInfo { span_infos_index: span_infos.len(), text }; + span_infos.push(span); + node +} + // polarity = Some(true) for positive, Some(false) for negative, None for neither fn check_positive_uses( global: &CheckPositiveGlobal, @@ -208,7 +224,12 @@ fn check_positive_uses( if global.type_graph.in_same_scc(&impl_node, &my_node) { let scc_rep = global.type_graph.get_scc_rep(&my_node); let scc_nodes = global.type_graph.shortest_cycle_back_to_self(&scc_rep); - return Err(type_scc_error(&global.krate, &my_node, &scc_nodes)); + return Err(type_scc_error( + &global.krate, + &global.span_infos, + &my_node, + &scc_nodes, + )); } } Ok(()) @@ -256,7 +277,7 @@ fn add_one_type_to_graph(type_graph: &mut Graph, src: &TypNode, typ: &T } } -pub(crate) fn build_datatype_graph(krate: &Krate) -> Graph { +pub(crate) fn build_datatype_graph(krate: &Krate, span_infos: &mut Vec) -> Graph { let mut type_graph: Graph = Graph::new(); // If datatype D1 has a field whose type mentions datatype D2, create a graph edge D1 --> D2 @@ -270,7 +291,15 @@ pub(crate) fn build_datatype_graph(krate: &Krate) -> Graph { } for a in &krate.assoc_type_impls { - let src = TypNode::TraitImpl(a.x.impl_path.clone()); + let trait_impl_src = TypNode::TraitImpl(a.x.impl_path.clone()); + let src = new_span_info_typ_node( + span_infos, + a.span.clone(), + ": associated type definition, which may depend on other trait implementations \ + to satisfy type bounds" + .to_string(), + ); + type_graph.add_edge(trait_impl_src, src.clone()); for impl_path in a.x.impl_paths.iter() { let dst = TypNode::TraitImpl(impl_path.clone()); type_graph.add_edge(src.clone(), dst); @@ -284,11 +313,13 @@ pub(crate) fn build_datatype_graph(krate: &Krate) -> Graph { } type_graph.compute_sccs(); - return type_graph; + + type_graph } pub(crate) fn check_recursive_types(krate: &Krate) -> Result<(), VirErr> { - let type_graph = build_datatype_graph(krate); + let mut span_infos: Vec = Vec::new(); + let type_graph = build_datatype_graph(krate, &mut span_infos); let mut datatypes: HashMap = HashMap::new(); let mut datatypes_well_founded: HashSet = HashSet::new(); @@ -296,7 +327,7 @@ pub(crate) fn check_recursive_types(krate: &Krate) -> Result<(), VirErr> { datatypes.insert(datatype.x.path.clone(), datatype.clone()); } - let global = CheckPositiveGlobal { krate: krate.clone(), datatypes, type_graph }; + let global = CheckPositiveGlobal { krate: krate.clone(), datatypes, type_graph, span_infos }; for function in &krate.functions { if let FunctionKind::TraitMethodDecl { .. } = function.x.kind { @@ -343,6 +374,7 @@ pub(crate) fn check_recursive_types(krate: &Krate) -> Result<(), VirErr> { TypNode::TraitImpl(_) if global.type_graph.node_is_in_cycle(node) => { return Err(type_scc_error( krate, + &global.span_infos, node, &global.type_graph.shortest_cycle_back_to_self(scc), )); @@ -377,75 +409,96 @@ pub(crate) fn check_recursive_types(krate: &Krate) -> Result<(), VirErr> { Ok(()) } -fn type_scc_error(krate: &Krate, head: &TypNode, nodes: &Vec) -> VirErr { +fn type_scc_error( + krate: &Krate, + span_infos: &Vec, + head: &TypNode, + nodes: &Vec, +) -> VirErr { let msg = "found a cyclic self-reference in a trait definition, which may result in nontermination" .to_string(); let mut err = crate::messages::error_bare(msg); for (i, node) in nodes.iter().enumerate() { - let mut push = |node: &TypNode, span: Span| { + let mut push = |node: &TypNode, span: Span, text: &str| { if node == head { err = err.primary_span(&span); } - let msg = format!("may be part of cycle (node {} of {} in cycle)", i + 1, nodes.len()); + let msg = format!( + "may be part of cycle (node {} of {} in cycle){}", + i + 1, + nodes.len(), + text + ); err = err.secondary_label(&span, msg); }; match node { TypNode::Datatype(path) => { if let Some(d) = krate.datatypes.iter().find(|t| t.x.path == *path) { let span = d.span.clone(); - push(node, span); + push(node, span, ": type definition"); } } TypNode::TraitImpl(impl_path) => { if let Some(t) = krate.trait_impls.iter().find(|t| t.x.impl_path == *impl_path) { let span = t.span.clone(); - push(node, span); + push(node, span, ": implementation of trait for a type"); } } + TypNode::SpanInfo { span_infos_index, text } => { + push(node, span_infos[*span_infos_index].clone(), text); + } } } err } -fn scc_error(krate: &Krate, head: &Node, nodes: &Vec) -> VirErr { +fn scc_error(krate: &Krate, span_infos: &Vec, head: &Node, nodes: &Vec) -> VirErr { let msg = "found a cyclic self-reference in a trait definition, which may result in nontermination" .to_string(); let mut err = crate::messages::error_bare(msg); for (i, node) in nodes.iter().enumerate() { - let mut push = |node: &Node, span: Span| { + let mut push = |node: &Node, span: Span, text: &str| { if node == head { err = err.primary_span(&span); } - let msg = format!("may be part of cycle (node {} of {} in cycle)", i + 1, nodes.len()); + let msg = format!( + "may be part of cycle (node {} of {} in cycle){}", + i + 1, + nodes.len(), + text + ); err = err.secondary_label(&span, msg); }; match node { Node::Fun(fun) => { if let Some(f) = krate.functions.iter().find(|f| f.x.name == *fun) { let span = f.span.clone(); - push(node, span); + push(node, span, ": function definition, whose body may have dependencies"); } } Node::Datatype(path) => { if let Some(d) = krate.datatypes.iter().find(|t| t.x.path == *path) { let span = d.span.clone(); - push(node, span); + push(node, span, ": type definition"); } } Node::Trait(trait_path) => { if let Some(t) = krate.traits.iter().find(|t| t.x.name == *trait_path) { let span = t.span.clone(); - push(node, span); + push(node, span, ": declaration of trait"); } } Node::TraitImpl(impl_path) => { if let Some(t) = krate.trait_impls.iter().find(|t| t.x.impl_path == *impl_path) { let span = t.span.clone(); - push(node, span); + push(node, span, ": implementation of trait for a type"); } } + Node::SpanInfo { span_infos_index, text } => { + push(node, span_infos[*span_infos_index].clone(), text); + } } } err @@ -495,14 +548,27 @@ pub(crate) fn add_trait_to_graph(call_graph: &mut Graph, trt: &Trait) { } } -pub(crate) fn add_trait_impl_to_graph(call_graph: &mut Graph, t: &crate::ast::TraitImpl) { +pub(crate) fn add_trait_impl_to_graph( + span_infos: &mut Vec, + call_graph: &mut Graph, + t: &crate::ast::TraitImpl, +) { // For // trait T<...> where ...: U1(...), ..., ...: Un(...) // impl T for ... { ... } // Add necessary impl_T_for_* --> impl_Ui_for_* edges // This corresponds to instantiating the a: Dictionary_U field in the comments below - let src_node = Node::TraitImpl(t.x.impl_path.clone()); - for imp in t.x.trait_typ_arg_impls.iter() { + let trait_impl_src_node = Node::TraitImpl(t.x.impl_path.clone()); + let src_node = new_span_info_node( + span_infos, + t.x.trait_typ_arg_impls.span.clone(), + ": an implementation of a trait, applying the trait to some type arguments, \ + for some `Self` type, where applying the trait to type arguments and declaring \ + the `Self` type may depend on other trait implementations to satisfy type bounds" + .to_string(), + ); + call_graph.add_edge(trait_impl_src_node, src_node.clone()); + for imp in t.x.trait_typ_arg_impls.x.iter() { if &t.x.impl_path != imp { call_graph.add_edge(src_node.clone(), Node::TraitImpl(imp.clone())); } @@ -648,6 +714,7 @@ pub fn check_traits(krate: &Krate, ctx: &GlobalCtx) -> Result<(), VirErr> { _ if ctx.func_call_graph.node_is_in_cycle(node) => { return Err(scc_error( krate, + &ctx.datatype_graph_span_infos, node, &ctx.func_call_graph.shortest_cycle_back_to_self(node), ));