@@ -21,7 +21,7 @@ use air::ast_util::str_ident;
2121use rustc_ast:: LitKind ;
2222use rustc_hir:: def:: Res ;
2323use rustc_hir:: { Expr , ExprKind , Node , QPath } ;
24- use rustc_middle:: ty:: { GenericArgKind , TyKind } ;
24+ use rustc_middle:: ty:: { GenericArg , GenericArgKind , TyKind } ;
2525use rustc_span:: def_id:: DefId ;
2626use rustc_span:: source_map:: Spanned ;
2727use rustc_span:: Span ;
@@ -30,10 +30,10 @@ use vir::ast::{
3030 ArithOp , AssertQueryMode , AutospecUsage , BinaryOp , BitwiseOp , BuiltinSpecFun , CallTarget ,
3131 ChainedOp , ComputeMode , Constant , ExprX , FieldOpr , FunX , HeaderExpr , HeaderExprX , InequalityOp ,
3232 IntRange , IntegerTypeBoundKind , Mode , ModeCoercion , MultiOp , Quant , Typ , TypX , UnaryOp ,
33- UnaryOpr , VarAt , VirErr ,
33+ UnaryOpr , VarAt , VariantCheck , VirErr ,
3434} ;
3535use vir:: ast_util:: { const_int_from_string, typ_to_diagnostic_str, types_equal, undecorate_typ} ;
36- use vir:: def:: positional_field_ident ;
36+ use vir:: def:: field_ident_from_rust ;
3737
3838pub ( crate ) fn fn_call_to_vir < ' tcx > (
3939 bctx : & BodyCtxt < ' tcx > ,
@@ -96,9 +96,6 @@ pub(crate) fn fn_call_to_vir<'tcx>(
9696 ) ,
9797 ) ;
9898 }
99- Some ( RustItem :: TryTraitBranch ) => {
100- return err_span ( expr. span , "Verus does not yet support the ? operator" ) ;
101- }
10299 Some ( RustItem :: Clone ) => {
103100 // Special case `clone` for standard Rc and Arc types
104101 // (Could also handle it for other types where cloning is the identity
@@ -167,6 +164,8 @@ pub(crate) fn fn_call_to_vir<'tcx>(
167164 // If the resolution is statically known, we record the resolved function for the
168165 // to be used by lifetime_generate.
169166
167+ let node_substs = fix_node_substs ( tcx, bctx. types , node_substs, rust_item, & args, expr) ;
168+
170169 let target_kind = if tcx. trait_of_item ( f) . is_none ( ) {
171170 vir:: ast:: CallTargetKind :: Static
172171 } else {
@@ -740,6 +739,33 @@ fn verus_item_to_vir<'tcx, 'a>(
740739 variant : str_ident ( & variant_name) ,
741740 field : variant_field. unwrap ( ) ,
742741 get_variant : true ,
742+ check : VariantCheck :: None ,
743+ } ) ,
744+ adt_arg,
745+ ) )
746+ }
747+ ExprItem :: GetUnionField => {
748+ record_spec_fn_allow_proof_args ( bctx, expr) ;
749+ assert ! ( args. len( ) == 2 ) ;
750+ let adt_arg = expr_to_vir ( bctx, & args[ 0 ] , ExprModifier :: REGULAR ) ?;
751+ let field_name = get_string_lit_arg ( & args[ 1 ] , & f_name) ?;
752+
753+ let adt_path = check_union_field (
754+ bctx,
755+ expr. span ,
756+ args[ 0 ] ,
757+ & field_name,
758+ & bctx. types . expr_ty ( expr) ,
759+ ) ?;
760+
761+ let field_ident = str_ident ( & field_name) ;
762+ mk_expr ( ExprX :: UnaryOpr (
763+ UnaryOpr :: Field ( FieldOpr {
764+ datatype : adt_path,
765+ variant : field_ident. clone ( ) ,
766+ field : field_ident_from_rust ( & field_ident) ,
767+ get_variant : true ,
768+ check : VariantCheck :: None ,
743769 } ) ,
744770 adt_arg,
745771 ) )
@@ -1652,6 +1678,33 @@ fn mk_is_smaller_than<'tcx>(
16521678 return Ok ( dec_exp) ;
16531679}
16541680
1681+ pub ( crate ) fn fix_node_substs < ' tcx , ' a > (
1682+ tcx : rustc_middle:: ty:: TyCtxt < ' tcx > ,
1683+ types : & ' tcx rustc_middle:: ty:: TypeckResults < ' tcx > ,
1684+ node_substs : & ' tcx rustc_middle:: ty:: List < rustc_middle:: ty:: GenericArg < ' tcx > > ,
1685+ rust_item : Option < RustItem > ,
1686+ args : & ' a [ & ' tcx Expr < ' tcx > ] ,
1687+ expr : & ' a Expr < ' tcx > ,
1688+ ) -> & ' tcx rustc_middle:: ty:: List < rustc_middle:: ty:: GenericArg < ' tcx > > {
1689+ match rust_item {
1690+ Some ( RustItem :: TryTraitBranch ) => {
1691+ // I don't understand why, but in this case, node_substs is empty instead
1692+ // of having the type argument. Let's fix it here.
1693+ // `branch` has type `fn branch(self) -> ...`
1694+ // so we can get the Self argument from the first argument.
1695+ let generic_arg = GenericArg :: from ( types. expr_ty_adjusted ( & args[ 0 ] ) ) ;
1696+ tcx. mk_args ( & [ generic_arg] )
1697+ }
1698+ Some ( RustItem :: ResidualTraitFromResidual ) => {
1699+ // `fn from_residual(residual: R) -> Self;`
1700+ let generic_arg0 = GenericArg :: from ( types. expr_ty ( expr) ) ;
1701+ let generic_arg1 = GenericArg :: from ( types. expr_ty_adjusted ( & args[ 0 ] ) ) ;
1702+ tcx. mk_args ( & [ generic_arg0, generic_arg1] )
1703+ }
1704+ _ => node_substs,
1705+ }
1706+ }
1707+
16551708fn mk_typ_args < ' tcx > (
16561709 bctx : & BodyCtxt < ' tcx > ,
16571710 substs : & rustc_middle:: ty:: List < rustc_middle:: ty:: GenericArg < ' tcx > > ,
@@ -1771,11 +1824,6 @@ fn check_variant_field<'tcx>(
17711824 }
17721825 } ;
17731826
1774- let variant_opt = adt. variants ( ) . iter ( ) . find ( |v| v. ident ( tcx) . as_str ( ) == variant_name) ;
1775- let Some ( variant) = variant_opt else {
1776- return err_span ( span, format ! ( "no variant `{variant_name:}` for this datatype" ) ) ;
1777- } ;
1778-
17791827 let vir_adt_ty = mid_ty_to_vir ( tcx, & bctx. ctxt . verus_items , bctx. fun_id , span, & ty, false ) ?;
17801828 let adt_path = match & * vir_adt_ty {
17811829 TypX :: Datatype ( path, _, _) => path. clone ( ) ,
@@ -1784,9 +1832,34 @@ fn check_variant_field<'tcx>(
17841832 }
17851833 } ;
17861834
1835+ if adt. is_union ( ) {
1836+ if field_name_typ. is_some ( ) {
1837+ // Don't use get_variant_field with unions
1838+ return err_span (
1839+ span,
1840+ format ! ( "this datatype is a union; consider `get_union_field` instead" ) ,
1841+ ) ;
1842+ }
1843+ let variant = adt. non_enum_variant ( ) ;
1844+ let field_opt = variant. fields . iter ( ) . find ( |f| f. ident ( tcx) . as_str ( ) == variant_name) ;
1845+ if field_opt. is_none ( ) {
1846+ return err_span ( span, format ! ( "no field `{variant_name:}` for this union" ) ) ;
1847+ }
1848+
1849+ return Ok ( ( adt_path, None ) ) ;
1850+ }
1851+
1852+ // Enum case:
1853+
1854+ let variant_opt = adt. variants ( ) . iter ( ) . find ( |v| v. ident ( tcx) . as_str ( ) == variant_name) ;
1855+ let Some ( variant) = variant_opt else {
1856+ return err_span ( span, format ! ( "no variant `{variant_name:}` for this datatype" ) ) ;
1857+ } ;
1858+
17871859 match field_name_typ {
17881860 None => Ok ( ( adt_path, None ) ) ,
17891861 Some ( ( field_name, expected_field_typ) ) => {
1862+ // The 'get_variant_field' case
17901863 let field_opt = variant. fields . iter ( ) . find ( |f| f. ident ( tcx) . as_str ( ) == field_name) ;
17911864 let Some ( field) = field_opt else {
17921865 return err_span ( span, format ! ( "no field `{field_name:}` for this variant" ) ) ;
@@ -1807,18 +1880,65 @@ fn check_variant_field<'tcx>(
18071880 return err_span ( span, "field has the wrong type" ) ;
18081881 }
18091882
1810- let field_ident = if field_name. as_str ( ) . bytes ( ) . nth ( 0 ) . unwrap ( ) . is_ascii_digit ( ) {
1811- let i = field_name. parse :: < usize > ( ) . unwrap ( ) ;
1812- positional_field_ident ( i)
1813- } else {
1814- str_ident ( & field_name)
1815- } ;
1883+ let field_ident = field_ident_from_rust ( & field_name) ;
18161884
18171885 Ok ( ( adt_path, Some ( field_ident) ) )
18181886 }
18191887 }
18201888}
18211889
1890+ fn check_union_field < ' tcx > (
1891+ bctx : & BodyCtxt < ' tcx > ,
1892+ span : Span ,
1893+ adt_arg : & ' tcx Expr < ' tcx > ,
1894+ field_name : & String ,
1895+ expected_field_typ : & rustc_middle:: ty:: Ty < ' tcx > ,
1896+ ) -> Result < vir:: ast:: Path , VirErr > {
1897+ let tcx = bctx. ctxt . tcx ;
1898+
1899+ let ty = bctx. types . expr_ty_adjusted ( adt_arg) ;
1900+ let ty = match ty. kind ( ) {
1901+ rustc_middle:: ty:: TyKind :: Ref ( _, t, rustc_ast:: Mutability :: Not ) => t,
1902+ _ => & ty,
1903+ } ;
1904+ let ( adt, substs) = match ty. kind ( ) {
1905+ rustc_middle:: ty:: TyKind :: Adt ( adt, substs) => ( adt, substs) ,
1906+ _ => {
1907+ return err_span ( span, format ! ( "expected type to be datatype" ) ) ;
1908+ }
1909+ } ;
1910+
1911+ if !adt. is_union ( ) {
1912+ return err_span ( span, format ! ( "get_union_field expects a union type" ) ) ;
1913+ }
1914+
1915+ let variant = adt. non_enum_variant ( ) ;
1916+
1917+ let field_opt = variant. fields . iter ( ) . find ( |f| f. ident ( tcx) . as_str ( ) == field_name) ;
1918+ let Some ( field) = field_opt else {
1919+ return err_span ( span, format ! ( "no field `{field_name:}` for this union" ) ) ;
1920+ } ;
1921+
1922+ let field_ty = field. ty ( tcx, substs) ;
1923+ let vir_field_ty =
1924+ mid_ty_to_vir ( tcx, & bctx. ctxt . verus_items , bctx. fun_id , span, & field_ty, false ) ?;
1925+ let vir_expected_field_ty =
1926+ mid_ty_to_vir ( tcx, & bctx. ctxt . verus_items , bctx. fun_id , span, & expected_field_typ, false ) ?;
1927+ if !types_equal ( & vir_field_ty, & vir_expected_field_ty) {
1928+ return err_span ( span, "field has the wrong type" ) ;
1929+ }
1930+
1931+ let vir_adt_ty = mid_ty_to_vir ( tcx, & bctx. ctxt . verus_items , bctx. fun_id , span, & ty, false ) ?;
1932+ let adt_path = match & * vir_adt_ty {
1933+ TypX :: Datatype ( path, _, _) => path. clone ( ) ,
1934+ _ => {
1935+ return err_span ( span, format ! ( "expected type to be datatype" ) ) ;
1936+ }
1937+ } ;
1938+
1939+ Ok ( adt_path)
1940+ }
1941+
18221942fn record_compilable_operator < ' tcx > ( bctx : & BodyCtxt < ' tcx > , expr : & Expr , op : CompilableOperator ) {
18231943 let resolved_call = ResolvedCall :: CompilableOperator ( op) ;
18241944 let mut erasure_info = bctx. ctxt . erasure_info . borrow_mut ( ) ;
0 commit comments