Skip to content

Commit

Permalink
Add els in StmtX::Decl && simplify the let-els in ast_simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
ziqiaozhou committed Jan 31, 2025
1 parent 0103a5d commit 68822c8
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 137 deletions.
166 changes: 43 additions & 123 deletions source/rust_verify/src/rust_to_vir_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,31 +72,6 @@ pub(crate) fn pat_to_var<'tcx>(pat: &Pat) -> Result<VarIdent, VirErr> {
Ok(name)
}

pub(crate) fn pat_extract_binds<'tcx>(pat: &Pat<'tcx>) -> Result<Vec<Pat<'tcx>>, VirErr> {
let Pat { hir_id: _, kind, span, default_binding_modes } = pat;
unsupported_err_unless!(default_binding_modes, *span, "default_binding_modes");
match *kind {
PatKind::Binding(..) => Ok(vec![pat.clone()]),
PatKind::Struct(_, fields, _) => {
let mut vars = Vec::new();
for f in fields {
vars.extend(pat_extract_binds(&f.pat)?);
}
Ok(vars)
}
PatKind::TupleStruct(_, pats, ..) => {
let mut vars = Vec::new();
for p in pats {
vars.extend(pat_extract_binds(p)?);
}
Ok(vars)
}
_ => {
unsupported_err!(*span, "limited support for pat_extract_binds")
}
}
}

pub(crate) fn extract_array<'tcx>(expr: &'tcx Expr<'tcx>) -> Vec<&'tcx Expr<'tcx>> {
match &expr.kind {
ExprKind::Array(fields) => fields.iter().collect(),
Expand Down Expand Up @@ -2049,13 +2024,36 @@ pub(crate) fn expr_to_vir_innermost<'tcx>(
let cond = cond.peel_drop_temps();
match cond.kind {
ExprKind::Let(LetExpr { pat, init: expr, ty: _, span: _, recovered: _ }) => {
let lhs = expr_to_vir(bctx, &lhs, modifier)?;
let rhs = if let Some(rhs) = rhs {
expr_to_vir(bctx, &rhs, modifier)?
} else {
mk_expr(ExprX::Block(Arc::new(Vec::new()), None))?
};
let_else_expr_vir(bctx, pat, expr, mk_expr, lhs, rhs, modifier)
// if let
let vir_expr = expr_to_vir(bctx, expr, modifier)?;
let mut vir_arms: Vec<vir::ast::Arm> = Vec::new();
/* lhs */
{
let pattern = pattern_to_vir(bctx, pat)?;
let guard = mk_expr(ExprX::Const(Constant::Bool(true)))?;
let body = expr_to_vir(bctx, &lhs, modifier)?;
let vir_arm = ArmX { pattern, guard, body };
vir_arms.push(bctx.spanned_new(lhs.span, vir_arm));
}
/* rhs */
{
let pat_typ = typ_of_node(bctx, pat.span, &pat.hir_id, false)?;
let pattern =
bctx.spanned_typed_new(cond.span, &pat_typ, PatternX::Wildcard(false));
{
let mut erasure_info = bctx.ctxt.erasure_info.borrow_mut();
erasure_info.hir_vir_ids.push((cond.hir_id, pattern.span.id));
}
let guard = mk_expr(ExprX::Const(Constant::Bool(true)))?;
let body = if let Some(rhs) = rhs {
expr_to_vir(bctx, &rhs, modifier)?
} else {
mk_expr(ExprX::Block(Arc::new(Vec::new()), None))?
};
let vir_arm = ArmX { pattern, guard, body };
vir_arms.push(bctx.spanned_new(lhs.span, vir_arm));
}
mk_expr(ExprX::Match(vir_expr, Arc::new(vir_arms)))
}
_ => {
let vir_cond = expr_to_vir(bctx, cond, modifier)?;
Expand Down Expand Up @@ -2487,83 +2485,6 @@ fn expr_assign_to_vir_innermost<'tcx>(
})
}

fn let_else_expr_vir<'tcx>(
bctx: &BodyCtxt<'tcx>,
pat: &rustc_hir::Pat<'tcx>,
expr: &Expr<'tcx>,
mk_expr: impl Fn(ExprX) -> Result<vir::ast::Expr, vir::messages::Message>,
lhs: vir::ast::Expr,
rhs: vir::ast::Expr,
modifier: ExprModifier,
) -> Result<vir::ast::Expr, VirErr> {
// if let
let vir_expr = expr_to_vir(bctx, expr, modifier)?;
let mut vir_arms: Vec<vir::ast::Arm> = Vec::new();
/* lhs */
{
let pattern = pattern_to_vir(bctx, pat)?;
let guard = mk_expr(ExprX::Const(Constant::Bool(true)))?;
let vir_arm = ArmX { pattern, guard, body: lhs };
vir_arms.push(bctx.spanned_new(expr.span, vir_arm));
}
/* rhs */
{
let pat_typ = typ_of_node(bctx, pat.span, &pat.hir_id, false)?;
let pattern = bctx.spanned_typed_new(pat.span, &pat_typ, PatternX::Wildcard(false));
let guard = mk_expr(ExprX::Const(Constant::Bool(true)))?;
let mut erasure_info = bctx.ctxt.erasure_info.borrow_mut();
erasure_info.hir_vir_ids.push((pat.hir_id, pattern.span.id));
let vir_arm = ArmX { pattern, guard, body: rhs };
vir_arms.push(bctx.spanned_new(expr.span, vir_arm));
}
let ret = mk_expr(ExprX::Match(vir_expr, Arc::new(vir_arms)));
let mut erasure_info = bctx.ctxt.erasure_info.borrow_mut();
erasure_info.hir_vir_ids.push((expr.hir_id, ret.clone().unwrap().span.id));
ret
}

fn let_stmt_with_else_to_vir<'tcx>(
bctx: &BodyCtxt<'tcx>,
pattern: &rustc_hir::Pat<'tcx>,
expr: &Expr<'tcx>,
els: &Block<'tcx>,
) -> Result<(vir::ast::Pattern, vir::ast::Expr), VirErr> {
let binds = pat_extract_binds(pattern)?;
let mut vars = Vec::new();
for p in &binds {
let typ = typ_of_node(bctx, p.span, &p.hir_id, false).expect("pat type");
let var = bctx.spanned_typed_new(p.span, &typ, ExprX::Var(pat_to_var(&p).unwrap()));
let mut erasure_info = bctx.ctxt.erasure_info.borrow_mut();
erasure_info.hir_vir_ids.push((p.hir_id, var.span.id));
vars.push(var);
}
let n = binds.len();
let mut binders = Vec::new();
let mut typs = Vec::new();

for (i, p) in binds.iter().enumerate() {
let pat = pattern_to_vir(bctx, p)?;
binders.push(ident_binder(&positional_field_ident(i), &pat));
typs.push(typ_of_node(bctx, p.span, &p.hir_id, false)?);
}

let typ = mk_tuple_typ(&Arc::new(typs));
let variant_name = vir::def::prefix_tuple_variant(n);
let px = PatternX::Constructor(Dt::Tuple(n), variant_name, Arc::new(binders));
let pat = bctx.spanned_typed_new(expr.span, &typ, px);

let lhs = vir::ast_util::mk_tuple(&pattern_to_vir(bctx, pattern)?.span, &Arc::new(vars));
let els_typ = typ_of_node(bctx, els.span, &els.hir_id, false)?;
let rhs_tmp = block_to_vir(bctx, els, &els.span, &els_typ, ExprModifier::REGULAR)?;
let rhs = bctx.spanned_typed_new(els.span, &typ, ExprX::NeverToAny(rhs_tmp));
let mk_expr = move |x: ExprX| Ok(bctx.spanned_typed_new(expr.span, &typ, x));
let init = let_else_expr_vir(bctx, pattern, expr, mk_expr, lhs, rhs, ExprModifier::REGULAR)?;

let mut erasure_info = bctx.ctxt.erasure_info.borrow_mut();
erasure_info.hir_vir_ids.push((pattern.hir_id, pat.span.id));

Ok((pat, init))
}
pub(crate) fn let_stmt_to_vir<'tcx>(
bctx: &BodyCtxt<'tcx>,
pattern: &rustc_hir::Pat<'tcx>,
Expand All @@ -2573,21 +2494,20 @@ pub(crate) fn let_stmt_to_vir<'tcx>(
) -> Result<Vec<vir::ast::Stmt>, VirErr> {
let mode = get_var_mode(bctx.mode, attrs);
let infer_mode = parse_attrs_opt(attrs, None).contains(&Attr::InferMode);
let mut new_vir_pat = None;
let init = if let Some(expr) = initializer {
if let Some(els) = els {
if matches!(mode, Mode::Spec | Mode::Proof) {
unsupported_err!(els.span, "let-else in spec/proof", els);
}
let (pat, init) = let_stmt_with_else_to_vir(bctx, pattern, expr, els)?;
new_vir_pat = Some(pat);
Some(init)
} else {
Some(expr_to_vir(bctx, expr, ExprModifier::REGULAR)?)
}
let els = if let Some(els) = els {
if matches!(mode, Mode::Spec | Mode::Proof) {
unsupported_err!(els.span, "let-else in spec/proof", els);
}
let init = initializer.unwrap();
let init_type = typ_of_node(bctx, els.span, &init.hir_id, false)?;
let els_typ = typ_of_node(bctx, els.span, &els.hir_id, false)?;
let els_block = block_to_vir(bctx, els, &els.span, &els_typ, ExprModifier::REGULAR)?;
Some(bctx.spanned_typed_new(els.span, &init_type, ExprX::NeverToAny(els_block)))
} else {
None
};
let init = initializer.map(|e| expr_to_vir(bctx, e, ExprModifier::REGULAR)).transpose()?;

if parse_attrs_opt(attrs, Some(&mut *bctx.ctxt.diagnostics.borrow_mut()))
.contains(&Attr::UnwrappedBinding)
{
Expand All @@ -2614,9 +2534,9 @@ pub(crate) fn let_stmt_to_vir<'tcx>(
}
}

let vir_pattern = if let Some(p) = new_vir_pat { p } else { pattern_to_vir(bctx, pattern)? };
let vir_pattern = pattern_to_vir(bctx, pattern)?;
let mode = if infer_mode { None } else { Some(mode) };
Ok(vec![bctx.spanned_new(pattern.span, StmtX::Decl { pattern: vir_pattern, mode, init })])
Ok(vec![bctx.spanned_new(pattern.span, StmtX::Decl { pattern: vir_pattern, mode, init, els })])
}

fn unwrap_parameter_to_vir<'tcx>(
Expand Down
1 change: 1 addition & 0 deletions source/rust_verify/src/rust_to_vir_func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,7 @@ pub(crate) fn check_item_fn<'tcx>(
pattern: new_binding_pat,
mode: Some(mode),
init: Some(new_init_expr),
els: None,
},
);
mut_params_redecl.push(redecl);
Expand Down
2 changes: 1 addition & 1 deletion source/vir/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -871,7 +871,7 @@ pub enum StmtX {
/// The declaration may contain a pattern;
/// however, ast_simplify replaces all patterns with PatternX::Var
/// (The mode is only allowed to be None for one special case; see modes.rs)
Decl { pattern: Pattern, mode: Option<Mode>, init: Option<Expr> },
Decl { pattern: Pattern, mode: Option<Mode>, init: Option<Expr>, els: Option<Expr> },
}

/// Function parameter
Expand Down
27 changes: 20 additions & 7 deletions source/vir/src/ast_simplify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ fn temp_expr(state: &mut State, expr: &Expr) -> (Stmt, Expr) {
let name = temp.clone();
let patternx = PatternX::Var { name, mutable: false };
let pattern = SpannedTyped::new(&expr.span, &expr.typ, patternx);
let decl = StmtX::Decl { pattern, mode: Some(Mode::Exec), init: Some(expr.clone()) };
let decl = StmtX::Decl { pattern, mode: Some(Mode::Exec), init: Some(expr.clone()), els: None };
let temp_decl = Spanned::new(expr.span.clone(), decl);
(temp_decl, SpannedTyped::new(&expr.span, &expr.typ, ExprX::Var(temp)))
}
Expand Down Expand Up @@ -139,7 +139,8 @@ fn pattern_to_exprs(
let patternx = PatternX::Var { name, mutable };
let pattern = SpannedTyped::new(&expr.span, &expr.typ, patternx);
// Mode doesn't matter at this stage; arbitrarily set it to 'exec'
let decl = StmtX::Decl { pattern, mode: Some(Mode::Exec), init: Some(expr.clone()) };
let decl =
StmtX::Decl { pattern, mode: Some(Mode::Exec), init: Some(expr.clone()), els: None };
decls.push(Spanned::new(expr.span.clone(), decl));
}

Expand Down Expand Up @@ -593,17 +594,27 @@ fn tuple_get_field_expr(

fn simplify_one_stmt(ctx: &GlobalCtx, state: &mut State, stmt: &Stmt) -> Result<Vec<Stmt>, VirErr> {
match &stmt.x {
StmtX::Decl { pattern, mode: _, init: None } => match &pattern.x {
StmtX::Decl { pattern, mode: _, init: None, els: None } => match &pattern.x {
PatternX::Var { .. } => Ok(vec![stmt.clone()]),
_ => Err(error(&stmt.span, "let-pattern declaration must have an initializer")),
},
StmtX::Decl { pattern, mode: _, init: Some(init) }
StmtX::Decl { pattern, mode: _, init: Some(init), els }
if !matches!(pattern.x, PatternX::Var { .. }) =>
{
let mut decls: Vec<Stmt> = Vec::new();
let (temp_decl, init) = small_or_temp(state, init);
decls.extend(temp_decl.into_iter());
let _ = pattern_to_exprs(ctx, state, &init, &pattern, &mut decls)?;
let mut decls2: Vec<Stmt> = Vec::new();
let pattern_check = pattern_to_exprs(ctx, state, &init, &pattern, &mut decls2)?;
if let Some(els) = &els {
let e = ExprX::Unary(UnaryOp::Not, pattern_check.clone());
let check = SpannedTyped::new(&pattern_check.span, &pattern_check.typ, e);
let ifx = ExprX::If(check.clone(), els.clone(), Some(init.clone()));
let init = SpannedTyped::new(&els.span, &init.typ, ifx);
let (temp_decl, _) = temp_expr(state, &init);
decls.push(temp_decl);
}
decls.extend(decls2);
Ok(decls)
}
_ => Ok(vec![stmt.clone()]),
Expand Down Expand Up @@ -732,7 +743,8 @@ fn exec_closure_spec_requires(
let patternx = PatternX::Var { name: p.name.clone(), mutable: false };
let pattern = SpannedTyped::new(span, &p.a, patternx);
let tuple_field = tuple_get_field_expr(state, span, &p.a, &tuple_var, params.len(), i);
let decl = StmtX::Decl { pattern, mode: Some(Mode::Spec), init: Some(tuple_field) };
let decl =
StmtX::Decl { pattern, mode: Some(Mode::Spec), init: Some(tuple_field), els: None };
decls.push(Spanned::new(span.clone(), decl));
}

Expand Down Expand Up @@ -792,7 +804,8 @@ fn exec_closure_spec_ensures(
let patternx = PatternX::Var { name: p.name.clone(), mutable: false };
let pattern = SpannedTyped::new(span, &p.a, patternx);
let tuple_field = tuple_get_field_expr(state, span, &p.a, &tuple_var, params.len(), i);
let decl = StmtX::Decl { pattern, mode: Some(Mode::Spec), init: Some(tuple_field) };
let decl =
StmtX::Decl { pattern, mode: Some(Mode::Spec), init: Some(tuple_field), els: None };
decls.push(Spanned::new(span.clone(), decl));
}

Expand Down
5 changes: 4 additions & 1 deletion source/vir/src/ast_to_sst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2257,7 +2257,10 @@ fn stmt_to_stm(
let (stms, exp) = expr_to_stm_opt(ctx, state, expr)?;
Ok((stms, exp, None))
}
StmtX::Decl { pattern, mode: _, init } => {
StmtX::Decl { pattern, mode: _, init, els } => {
if els.is_some() {
panic!("let-else should be simplified in ast_simpllify {:?}.", stmt)
}
let (name, mutable) = match &pattern.x {
PatternX::Var { name, mutable } => (name, mutable),
_ => panic!("internal error: Decl should have been simplified by ast_simplify"),
Expand Down
13 changes: 10 additions & 3 deletions source/vir/src/ast_visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -546,12 +546,16 @@ where
StmtX::Expr(e) => {
expr_visitor_control_flow!(expr_visitor_dfs(e, map, mf));
}
StmtX::Decl { pattern, mode: _, init } => {
StmtX::Decl { pattern, mode: _, init, els } => {
map.push_scope(true);
if let Some(init) = init {
expr_visitor_control_flow!(expr_visitor_dfs(init, map, mf));
}
insert_pattern_vars(map, &pattern, init.is_some());
if let Some(els) = els {
expr_visitor_control_flow!(expr_visitor_dfs(els, map, mf));
}
insert_pattern_vars(map, &pattern, els.is_some());
expr_visitor_control_flow!(pat_visitor_dfs(&pattern, map, mf));
}
}
Expand Down Expand Up @@ -1111,12 +1115,15 @@ where
let expr = map_expr_visitor_env(e, map, env, fe, fs, ft)?;
fs(env, map, &Spanned::new(stmt.span.clone(), StmtX::Expr(expr)))
}
StmtX::Decl { pattern, mode, init } => {
StmtX::Decl { pattern, mode, init, els } => {
let pattern = map_pattern_visitor_env(pattern, map, env, fe, fs, ft)?;
let init =
init.as_ref().map(|e| map_expr_visitor_env(e, map, env, fe, fs, ft)).transpose()?;
insert_pattern_vars(map, &pattern, init.is_some());
let decl = StmtX::Decl { pattern, mode: *mode, init };
let els =
els.as_ref().map(|e| map_expr_visitor_env(e, map, env, fe, fs, ft)).transpose()?;
insert_pattern_vars(map, &pattern, els.is_some());
let decl = StmtX::Decl { pattern, mode: *mode, init, els };
fs(env, map, &Spanned::new(stmt.span.clone(), decl))
}
}
Expand Down
4 changes: 2 additions & 2 deletions source/vir/src/modes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1539,7 +1539,7 @@ fn check_stmt(
let _ = check_expr(ctxt, record, typing, outer_mode, e)?;
Ok(())
}
StmtX::Decl { pattern, mode: None, init } => {
StmtX::Decl { pattern, mode: None, init, els: _ } => {
// Special case mode inference just for our encoding of "let tracked pat = ..."
// in Rust as "let xl; ... { let pat ... xl = xr; }".
match (&pattern.x, init) {
Expand All @@ -1550,7 +1550,7 @@ fn check_stmt(
}
Ok(())
}
StmtX::Decl { pattern, mode: Some(mode), init } => {
StmtX::Decl { pattern, mode: Some(mode), init, els: _ } => {
let mode = if typing.block_ghostness != Ghost::Exec && *mode == Mode::Exec {
Mode::Spec
} else {
Expand Down
1 change: 1 addition & 0 deletions source/vir/src/user_defined_type_invariants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ fn expr_followed_by_stmts(expr: &Expr, stmts: Vec<Stmt>, id_cell: &Cell<u64>) ->
),
mode: None,
init: Some(expr.clone()),
els: None,
};
stmts.insert(0, Spanned::new(expr.span.clone(), decl));
SpannedTyped::new(
Expand Down

0 comments on commit 68822c8

Please sign in to comment.