@@ -21,7 +21,7 @@ use air::ast_util::str_ident;
21
21
use rustc_ast:: LitKind ;
22
22
use rustc_hir:: def:: Res ;
23
23
use rustc_hir:: { Expr , ExprKind , Node , QPath } ;
24
- use rustc_middle:: ty:: { GenericArgKind , TyKind } ;
24
+ use rustc_middle:: ty:: { GenericArg , GenericArgKind , TyKind } ;
25
25
use rustc_span:: def_id:: DefId ;
26
26
use rustc_span:: source_map:: Spanned ;
27
27
use rustc_span:: Span ;
@@ -30,10 +30,10 @@ use vir::ast::{
30
30
ArithOp , AssertQueryMode , AutospecUsage , BinaryOp , BitwiseOp , BuiltinSpecFun , CallTarget ,
31
31
ChainedOp , ComputeMode , Constant , ExprX , FieldOpr , FunX , HeaderExpr , HeaderExprX , InequalityOp ,
32
32
IntRange , IntegerTypeBoundKind , Mode , ModeCoercion , MultiOp , Quant , Typ , TypX , UnaryOp ,
33
- UnaryOpr , VarAt , VirErr ,
33
+ UnaryOpr , VarAt , VariantCheck , VirErr ,
34
34
} ;
35
35
use vir:: ast_util:: { const_int_from_string, typ_to_diagnostic_str, types_equal, undecorate_typ} ;
36
- use vir:: def:: positional_field_ident ;
36
+ use vir:: def:: field_ident_from_rust ;
37
37
38
38
pub ( crate ) fn fn_call_to_vir < ' tcx > (
39
39
bctx : & BodyCtxt < ' tcx > ,
@@ -96,9 +96,6 @@ pub(crate) fn fn_call_to_vir<'tcx>(
96
96
) ,
97
97
) ;
98
98
}
99
- Some ( RustItem :: TryTraitBranch ) => {
100
- return err_span ( expr. span , "Verus does not yet support the ? operator" ) ;
101
- }
102
99
Some ( RustItem :: Clone ) => {
103
100
// Special case `clone` for standard Rc and Arc types
104
101
// (Could also handle it for other types where cloning is the identity
@@ -167,6 +164,8 @@ pub(crate) fn fn_call_to_vir<'tcx>(
167
164
// If the resolution is statically known, we record the resolved function for the
168
165
// to be used by lifetime_generate.
169
166
167
+ let node_substs = fix_node_substs ( tcx, bctx. types , node_substs, rust_item, & args, expr) ;
168
+
170
169
let target_kind = if tcx. trait_of_item ( f) . is_none ( ) {
171
170
vir:: ast:: CallTargetKind :: Static
172
171
} else {
@@ -740,6 +739,33 @@ fn verus_item_to_vir<'tcx, 'a>(
740
739
variant : str_ident ( & variant_name) ,
741
740
field : variant_field. unwrap ( ) ,
742
741
get_variant : true ,
742
+ check : VariantCheck :: None ,
743
+ } ) ,
744
+ adt_arg,
745
+ ) )
746
+ }
747
+ ExprItem :: GetUnionField => {
748
+ record_spec_fn_allow_proof_args ( bctx, expr) ;
749
+ assert ! ( args. len( ) == 2 ) ;
750
+ let adt_arg = expr_to_vir ( bctx, & args[ 0 ] , ExprModifier :: REGULAR ) ?;
751
+ let field_name = get_string_lit_arg ( & args[ 1 ] , & f_name) ?;
752
+
753
+ let adt_path = check_union_field (
754
+ bctx,
755
+ expr. span ,
756
+ args[ 0 ] ,
757
+ & field_name,
758
+ & bctx. types . expr_ty ( expr) ,
759
+ ) ?;
760
+
761
+ let field_ident = str_ident ( & field_name) ;
762
+ mk_expr ( ExprX :: UnaryOpr (
763
+ UnaryOpr :: Field ( FieldOpr {
764
+ datatype : adt_path,
765
+ variant : field_ident. clone ( ) ,
766
+ field : field_ident_from_rust ( & field_ident) ,
767
+ get_variant : true ,
768
+ check : VariantCheck :: None ,
743
769
} ) ,
744
770
adt_arg,
745
771
) )
@@ -1652,6 +1678,33 @@ fn mk_is_smaller_than<'tcx>(
1652
1678
return Ok ( dec_exp) ;
1653
1679
}
1654
1680
1681
+ pub ( crate ) fn fix_node_substs < ' tcx , ' a > (
1682
+ tcx : rustc_middle:: ty:: TyCtxt < ' tcx > ,
1683
+ types : & ' tcx rustc_middle:: ty:: TypeckResults < ' tcx > ,
1684
+ node_substs : & ' tcx rustc_middle:: ty:: List < rustc_middle:: ty:: GenericArg < ' tcx > > ,
1685
+ rust_item : Option < RustItem > ,
1686
+ args : & ' a [ & ' tcx Expr < ' tcx > ] ,
1687
+ expr : & ' a Expr < ' tcx > ,
1688
+ ) -> & ' tcx rustc_middle:: ty:: List < rustc_middle:: ty:: GenericArg < ' tcx > > {
1689
+ match rust_item {
1690
+ Some ( RustItem :: TryTraitBranch ) => {
1691
+ // I don't understand why, but in this case, node_substs is empty instead
1692
+ // of having the type argument. Let's fix it here.
1693
+ // `branch` has type `fn branch(self) -> ...`
1694
+ // so we can get the Self argument from the first argument.
1695
+ let generic_arg = GenericArg :: from ( types. expr_ty_adjusted ( & args[ 0 ] ) ) ;
1696
+ tcx. mk_args ( & [ generic_arg] )
1697
+ }
1698
+ Some ( RustItem :: ResidualTraitFromResidual ) => {
1699
+ // `fn from_residual(residual: R) -> Self;`
1700
+ let generic_arg0 = GenericArg :: from ( types. expr_ty ( expr) ) ;
1701
+ let generic_arg1 = GenericArg :: from ( types. expr_ty_adjusted ( & args[ 0 ] ) ) ;
1702
+ tcx. mk_args ( & [ generic_arg0, generic_arg1] )
1703
+ }
1704
+ _ => node_substs,
1705
+ }
1706
+ }
1707
+
1655
1708
fn mk_typ_args < ' tcx > (
1656
1709
bctx : & BodyCtxt < ' tcx > ,
1657
1710
substs : & rustc_middle:: ty:: List < rustc_middle:: ty:: GenericArg < ' tcx > > ,
@@ -1771,11 +1824,6 @@ fn check_variant_field<'tcx>(
1771
1824
}
1772
1825
} ;
1773
1826
1774
- let variant_opt = adt. variants ( ) . iter ( ) . find ( |v| v. ident ( tcx) . as_str ( ) == variant_name) ;
1775
- let Some ( variant) = variant_opt else {
1776
- return err_span ( span, format ! ( "no variant `{variant_name:}` for this datatype" ) ) ;
1777
- } ;
1778
-
1779
1827
let vir_adt_ty = mid_ty_to_vir ( tcx, & bctx. ctxt . verus_items , bctx. fun_id , span, & ty, false ) ?;
1780
1828
let adt_path = match & * vir_adt_ty {
1781
1829
TypX :: Datatype ( path, _, _) => path. clone ( ) ,
@@ -1784,9 +1832,34 @@ fn check_variant_field<'tcx>(
1784
1832
}
1785
1833
} ;
1786
1834
1835
+ if adt. is_union ( ) {
1836
+ if field_name_typ. is_some ( ) {
1837
+ // Don't use get_variant_field with unions
1838
+ return err_span (
1839
+ span,
1840
+ format ! ( "this datatype is a union; consider `get_union_field` instead" ) ,
1841
+ ) ;
1842
+ }
1843
+ let variant = adt. non_enum_variant ( ) ;
1844
+ let field_opt = variant. fields . iter ( ) . find ( |f| f. ident ( tcx) . as_str ( ) == variant_name) ;
1845
+ if field_opt. is_none ( ) {
1846
+ return err_span ( span, format ! ( "no field `{variant_name:}` for this union" ) ) ;
1847
+ }
1848
+
1849
+ return Ok ( ( adt_path, None ) ) ;
1850
+ }
1851
+
1852
+ // Enum case:
1853
+
1854
+ let variant_opt = adt. variants ( ) . iter ( ) . find ( |v| v. ident ( tcx) . as_str ( ) == variant_name) ;
1855
+ let Some ( variant) = variant_opt else {
1856
+ return err_span ( span, format ! ( "no variant `{variant_name:}` for this datatype" ) ) ;
1857
+ } ;
1858
+
1787
1859
match field_name_typ {
1788
1860
None => Ok ( ( adt_path, None ) ) ,
1789
1861
Some ( ( field_name, expected_field_typ) ) => {
1862
+ // The 'get_variant_field' case
1790
1863
let field_opt = variant. fields . iter ( ) . find ( |f| f. ident ( tcx) . as_str ( ) == field_name) ;
1791
1864
let Some ( field) = field_opt else {
1792
1865
return err_span ( span, format ! ( "no field `{field_name:}` for this variant" ) ) ;
@@ -1807,18 +1880,65 @@ fn check_variant_field<'tcx>(
1807
1880
return err_span ( span, "field has the wrong type" ) ;
1808
1881
}
1809
1882
1810
- let field_ident = if field_name. as_str ( ) . bytes ( ) . nth ( 0 ) . unwrap ( ) . is_ascii_digit ( ) {
1811
- let i = field_name. parse :: < usize > ( ) . unwrap ( ) ;
1812
- positional_field_ident ( i)
1813
- } else {
1814
- str_ident ( & field_name)
1815
- } ;
1883
+ let field_ident = field_ident_from_rust ( & field_name) ;
1816
1884
1817
1885
Ok ( ( adt_path, Some ( field_ident) ) )
1818
1886
}
1819
1887
}
1820
1888
}
1821
1889
1890
+ fn check_union_field < ' tcx > (
1891
+ bctx : & BodyCtxt < ' tcx > ,
1892
+ span : Span ,
1893
+ adt_arg : & ' tcx Expr < ' tcx > ,
1894
+ field_name : & String ,
1895
+ expected_field_typ : & rustc_middle:: ty:: Ty < ' tcx > ,
1896
+ ) -> Result < vir:: ast:: Path , VirErr > {
1897
+ let tcx = bctx. ctxt . tcx ;
1898
+
1899
+ let ty = bctx. types . expr_ty_adjusted ( adt_arg) ;
1900
+ let ty = match ty. kind ( ) {
1901
+ rustc_middle:: ty:: TyKind :: Ref ( _, t, rustc_ast:: Mutability :: Not ) => t,
1902
+ _ => & ty,
1903
+ } ;
1904
+ let ( adt, substs) = match ty. kind ( ) {
1905
+ rustc_middle:: ty:: TyKind :: Adt ( adt, substs) => ( adt, substs) ,
1906
+ _ => {
1907
+ return err_span ( span, format ! ( "expected type to be datatype" ) ) ;
1908
+ }
1909
+ } ;
1910
+
1911
+ if !adt. is_union ( ) {
1912
+ return err_span ( span, format ! ( "get_union_field expects a union type" ) ) ;
1913
+ }
1914
+
1915
+ let variant = adt. non_enum_variant ( ) ;
1916
+
1917
+ let field_opt = variant. fields . iter ( ) . find ( |f| f. ident ( tcx) . as_str ( ) == field_name) ;
1918
+ let Some ( field) = field_opt else {
1919
+ return err_span ( span, format ! ( "no field `{field_name:}` for this union" ) ) ;
1920
+ } ;
1921
+
1922
+ let field_ty = field. ty ( tcx, substs) ;
1923
+ let vir_field_ty =
1924
+ mid_ty_to_vir ( tcx, & bctx. ctxt . verus_items , bctx. fun_id , span, & field_ty, false ) ?;
1925
+ let vir_expected_field_ty =
1926
+ mid_ty_to_vir ( tcx, & bctx. ctxt . verus_items , bctx. fun_id , span, & expected_field_typ, false ) ?;
1927
+ if !types_equal ( & vir_field_ty, & vir_expected_field_ty) {
1928
+ return err_span ( span, "field has the wrong type" ) ;
1929
+ }
1930
+
1931
+ let vir_adt_ty = mid_ty_to_vir ( tcx, & bctx. ctxt . verus_items , bctx. fun_id , span, & ty, false ) ?;
1932
+ let adt_path = match & * vir_adt_ty {
1933
+ TypX :: Datatype ( path, _, _) => path. clone ( ) ,
1934
+ _ => {
1935
+ return err_span ( span, format ! ( "expected type to be datatype" ) ) ;
1936
+ }
1937
+ } ;
1938
+
1939
+ Ok ( adt_path)
1940
+ }
1941
+
1822
1942
fn record_compilable_operator < ' tcx > ( bctx : & BodyCtxt < ' tcx > , expr : & Expr , op : CompilableOperator ) {
1823
1943
let resolved_call = ResolvedCall :: CompilableOperator ( op) ;
1824
1944
let mut erasure_info = bctx. ctxt . erasure_info . borrow_mut ( ) ;
0 commit comments