@@ -13,7 +13,8 @@ use crate::{
13
13
constants:: Span ,
14
14
error:: { Error , ErrorKind , Result } ,
15
15
helpers:: PrettyField ,
16
- parser:: types:: { GenericParameters , TyKind } ,
16
+ parser:: types:: { GenericParameters , ModulePath , TyKind } ,
17
+ type_checker:: FullyQualified ,
17
18
var:: { ConstOrCell , Value , Var } ,
18
19
} ;
19
20
@@ -33,13 +34,118 @@ impl Module for BuiltinsLib {
33
34
fn get_fns < B : Backend > ( ) -> Vec < ( & ' static str , FnInfoType < B > , bool ) > {
34
35
vec ! [
35
36
( ASSERT_FN , assert_fn, false ) ,
36
- ( ASSERT_EQ_FN , assert_eq_fn, false ) ,
37
+ ( ASSERT_EQ_FN , assert_eq_fn, true ) ,
37
38
// true -> skip argument type checking for log
38
39
( LOG_FN , log_fn, true ) ,
39
40
]
40
41
}
41
42
}
42
43
44
+ /// Represents a comparison that needs to be made
45
+ enum Comparison < B : Backend > {
46
+ /// Compare two variables
47
+ Vars ( B :: Var , B :: Var ) ,
48
+ /// Compare a variable with a constant
49
+ VarConst ( B :: Var , B :: Field ) ,
50
+ /// Compare two constants
51
+ Constants ( B :: Field , B :: Field ) ,
52
+ }
53
+
54
+ /// Helper function to generate all comparisons
55
+ fn assert_eq_values < B : Backend > (
56
+ compiler : & CircuitWriter < B > ,
57
+ lhs_info : & VarInfo < B :: Field , B :: Var > ,
58
+ rhs_info : & VarInfo < B :: Field , B :: Var > ,
59
+ typ : & TyKind ,
60
+ span : Span ,
61
+ ) -> Vec < Comparison < B > > {
62
+ let mut comparisons = Vec :: new ( ) ;
63
+
64
+ match typ {
65
+ // Field and Bool has the same logic
66
+ TyKind :: Field { .. } | TyKind :: Bool => {
67
+ let lhs_var = & lhs_info. var [ 0 ] ;
68
+ let rhs_var = & rhs_info. var [ 0 ] ;
69
+ match ( lhs_var, rhs_var) {
70
+ ( ConstOrCell :: Const ( a) , ConstOrCell :: Const ( b) ) => {
71
+ comparisons. push ( Comparison :: Constants ( a. clone ( ) , b. clone ( ) ) ) ;
72
+ }
73
+ ( ConstOrCell :: Const ( cst) , ConstOrCell :: Cell ( cvar) )
74
+ | ( ConstOrCell :: Cell ( cvar) , ConstOrCell :: Const ( cst) ) => {
75
+ comparisons. push ( Comparison :: VarConst ( cvar. clone ( ) , cst. clone ( ) ) ) ;
76
+ }
77
+ ( ConstOrCell :: Cell ( lhs) , ConstOrCell :: Cell ( rhs) ) => {
78
+ comparisons. push ( Comparison :: Vars ( lhs. clone ( ) , rhs. clone ( ) ) ) ;
79
+ }
80
+ }
81
+ }
82
+
83
+ // Arrays (fixed size)
84
+ TyKind :: Array ( element_type, size) => {
85
+ let size = * size as usize ;
86
+ let element_size = compiler. size_of ( element_type) ;
87
+
88
+ // compare each element recursively
89
+ for i in 0 ..size {
90
+ let start = i * element_size;
91
+ let mut element_comparisons = assert_eq_values (
92
+ compiler,
93
+ & VarInfo :: new (
94
+ Var :: new ( lhs_info. var . range ( start, element_size) . to_vec ( ) , span) ,
95
+ false ,
96
+ Some ( * element_type. clone ( ) ) ,
97
+ ) ,
98
+ & VarInfo :: new (
99
+ Var :: new ( rhs_info. var . range ( start, element_size) . to_vec ( ) , span) ,
100
+ false ,
101
+ Some ( * element_type. clone ( ) ) ,
102
+ ) ,
103
+ element_type,
104
+ span,
105
+ ) ;
106
+ comparisons. append ( & mut element_comparisons) ;
107
+ }
108
+ }
109
+
110
+ // Custom types (structs)
111
+ TyKind :: Custom { module, name } => {
112
+ let qualified = FullyQualified :: new ( module, name) ;
113
+ let struct_info = compiler. struct_info ( & qualified) . expect ( "struct not found" ) ;
114
+
115
+ // compare each field recursively
116
+ let mut offset = 0 ;
117
+ for ( _, field_type) in & struct_info. fields {
118
+ let field_size = compiler. size_of ( field_type) ;
119
+ let mut field_comparisons = assert_eq_values (
120
+ compiler,
121
+ & VarInfo :: new (
122
+ Var :: new ( lhs_info. var . range ( offset, field_size) . to_vec ( ) , span) ,
123
+ false ,
124
+ Some ( field_type. clone ( ) ) ,
125
+ ) ,
126
+ & VarInfo :: new (
127
+ Var :: new ( rhs_info. var . range ( offset, field_size) . to_vec ( ) , span) ,
128
+ false ,
129
+ Some ( field_type. clone ( ) ) ,
130
+ ) ,
131
+ field_type,
132
+ span,
133
+ ) ;
134
+ comparisons. append ( & mut field_comparisons) ;
135
+ offset += field_size;
136
+ }
137
+ }
138
+
139
+ // GenericSizedArray should be monomorphized to Array before reaching here
140
+ // no need to handle it seperately
141
+ TyKind :: GenericSizedArray ( _, _) => {
142
+ unreachable ! ( "GenericSizedArray should be monomorphized" )
143
+ }
144
+ }
145
+
146
+ comparisons
147
+ }
148
+
43
149
/// Asserts that two vars are equal.
44
150
fn assert_eq_fn < B : Backend > (
45
151
compiler : & mut CircuitWriter < B > ,
@@ -52,67 +158,53 @@ fn assert_eq_fn<B: Backend>(
52
158
let lhs_info = & vars[ 0 ] ;
53
159
let rhs_info = & vars[ 1 ] ;
54
160
55
- // they are both of type field
56
- if !matches ! ( lhs_info. typ, Some ( TyKind :: Field { .. } ) ) {
57
- let lhs = lhs_info. typ . clone ( ) . ok_or_else ( || {
58
- Error :: new (
59
- "constraint-generation" ,
60
- ErrorKind :: UnexpectedError ( "No type info for lhs of assertion" ) ,
61
- span,
62
- )
63
- } ) ?;
64
-
65
- Err ( Error :: new (
161
+ // get types of both arguments
162
+ let lhs_type = lhs_info. typ . as_ref ( ) . ok_or_else ( || {
163
+ Error :: new (
66
164
"constraint-generation" ,
67
- ErrorKind :: AssertTypeMismatch ( "rhs" , lhs) ,
165
+ ErrorKind :: UnexpectedError ( "No type info for lhs of assertion" ) ,
68
166
span,
69
- ) ) ?
70
- }
167
+ )
168
+ } ) ? ;
71
169
72
- if !matches ! ( rhs_info. typ, Some ( TyKind :: Field { .. } ) ) {
73
- let rhs = rhs_info. typ . clone ( ) . ok_or_else ( || {
74
- Error :: new (
75
- "constraint-generation" ,
76
- ErrorKind :: UnexpectedError ( "No type info for rhs of assertion" ) ,
77
- span,
78
- )
79
- } ) ?;
170
+ let rhs_type = rhs_info. typ . as_ref ( ) . ok_or_else ( || {
171
+ Error :: new (
172
+ "constraint-generation" ,
173
+ ErrorKind :: UnexpectedError ( "No type info for rhs of assertion" ) ,
174
+ span,
175
+ )
176
+ } ) ?;
80
177
81
- Err ( Error :: new (
178
+ // they have the same type
179
+ if !lhs_type. match_expected ( rhs_type, false ) {
180
+ return Err ( Error :: new (
82
181
"constraint-generation" ,
83
- ErrorKind :: AssertTypeMismatch ( "rhs" , rhs ) ,
182
+ ErrorKind :: AssertEqTypeMismatch ( lhs_type . clone ( ) , rhs_type . clone ( ) ) ,
84
183
span,
85
- ) ) ?
184
+ ) ) ;
86
185
}
87
186
88
- // retrieve the values
89
- let lhs_var = & lhs_info. var ;
90
- assert_eq ! ( lhs_var. len( ) , 1 ) ;
91
- let lhs_cvar = & lhs_var[ 0 ] ;
92
-
93
- let rhs_var = & rhs_info. var ;
94
- assert_eq ! ( rhs_var. len( ) , 1 ) ;
95
- let rhs_cvar = & rhs_var[ 0 ] ;
96
-
97
- match ( lhs_cvar, rhs_cvar) {
98
- // two constants
99
- ( ConstOrCell :: Const ( a) , ConstOrCell :: Const ( b) ) => {
100
- if a != b {
101
- Err ( Error :: new (
102
- "constraint-generation" ,
103
- ErrorKind :: AssertionFailed ,
104
- span,
105
- ) ) ?
106
- }
107
- }
187
+ // first collect all comparisons needed
188
+ let comparisons = assert_eq_values ( compiler, lhs_info, rhs_info, lhs_type, span) ;
108
189
109
- // a const and a var
110
- ( ConstOrCell :: Const ( cst) , ConstOrCell :: Cell ( cvar) )
111
- | ( ConstOrCell :: Cell ( cvar) , ConstOrCell :: Const ( cst) ) => {
112
- compiler. backend . assert_eq_const ( cvar, * cst, span)
113
- }
114
- ( ConstOrCell :: Cell ( lhs) , ConstOrCell :: Cell ( rhs) ) => {
115
- compiler. backend . assert_eq_var ( lhs, rhs, span)
190
+ // then add all the constraints
191
+ for comparison in comparisons {
192
+ match comparison {
193
+ Comparison :: Vars ( lhs, rhs) => {
194
+ compiler. backend . assert_eq_var ( & lhs, & rhs, span) ;
195
+ }
196
+ Comparison :: VarConst ( var, constant) => {
197
+ compiler. backend . assert_eq_const ( & var, constant, span) ;
198
+ }
199
+ Comparison :: Constants ( a, b) => {
200
+ if a != b {
201
+ return Err ( Error :: new (
202
+ "constraint-generation" ,
203
+ ErrorKind :: AssertionFailed ,
204
+ span,
205
+ ) ) ;
206
+ }
207
+ }
116
208
}
117
209
}
118
210
0 commit comments