Skip to content

Commit 99c3935

Browse files
authored
Merge pull request #247 from bufferhe4d/generic-asserteq
Implement generic assert_eq
2 parents c275ef4 + 151ae9e commit 99c3935

File tree

6 files changed

+282
-54
lines changed

6 files changed

+282
-54
lines changed
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
@ noname.0.7.0
2+
@ public inputs: 3
3+
4+
DoubleGeneric<1>
5+
DoubleGeneric<1>
6+
DoubleGeneric<1>
7+
DoubleGeneric<1,0,0,0,-3>
8+
DoubleGeneric<1,0,0,0,-3>
9+
DoubleGeneric<1,0,0,0,-1>
10+
DoubleGeneric<1,0,0,0,-2>
11+
DoubleGeneric<1,0,-1,0,2>
12+
DoubleGeneric<1,0,0,0,-5>
13+
DoubleGeneric<1,0,0,0,-1>
14+
DoubleGeneric<1,0,0,0,-2>
15+
DoubleGeneric<1,0,-1,0,2>
16+
DoubleGeneric<0,0,-1,1>
17+
DoubleGeneric<1,1,-1>
18+
DoubleGeneric<1,1,-1>
19+
DoubleGeneric<1,0,0,0,-5>
20+
DoubleGeneric<1,0,0,0,-1>
21+
DoubleGeneric<1,0,0,0,-1>
22+
DoubleGeneric<1,0,0,0,-2>
23+
DoubleGeneric<1,0,0,0,-3>
24+
DoubleGeneric<1,0,0,0,-4>
25+
DoubleGeneric<1,0,0,0,-5>
26+
(0,0) -> (5,0) -> (9,0) -> (12,1) -> (13,0) -> (16,0) -> (17,0)
27+
(1,0) -> (6,0) -> (10,0) -> (14,0) -> (18,0)
28+
(2,0) -> (3,0) -> (4,0) -> (7,0) -> (11,0) -> (12,0) -> (14,1) -> (19,0)
29+
(7,2) -> (8,0)
30+
(11,2) -> (15,0)
31+
(12,2) -> (13,1)
32+
(13,2) -> (20,0)
33+
(14,2) -> (21,0)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
@ noname.0.7.0
2+
@ public inputs: 3
3+
4+
3 == (v_3) * (1)
5+
3 == (v_3) * (1)
6+
1 == (v_1) * (1)
7+
2 == (v_2) * (1)
8+
5 == (v_3 + 2) * (1)
9+
1 == (v_1) * (1)
10+
2 == (v_2) * (1)
11+
v_4 == (v_3) * (v_1)
12+
5 == (v_3 + 2) * (1)
13+
1 == (v_1) * (1)
14+
1 == (v_1) * (1)
15+
2 == (v_2) * (1)
16+
3 == (v_3) * (1)
17+
4 == (v_1 + v_4) * (1)
18+
5 == (v_2 + v_3) * (1)

examples/generic_assert_eq.no

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
const size = 2;
2+
struct Thing {
3+
xx: Field,
4+
yy: [Field; 2],
5+
}
6+
7+
struct Nestedthing {
8+
xx: Field,
9+
another: [Another; 2],
10+
}
11+
12+
struct Another {
13+
aa: Field,
14+
bb: [Field; 2],
15+
}
16+
17+
fn init_arr(element: Field, const LEN: Field) -> [Field; LEN] {
18+
let arr = [element; LEN];
19+
return arr;
20+
}
21+
22+
fn main(pub public_arr: [Field; 2], pub public_input: Field) {
23+
let generic_arr = init_arr(public_input, size);
24+
let arr = [3, 3];
25+
26+
assert_eq(generic_arr, arr);
27+
let mut concrete_arr = [1, 2];
28+
29+
// instead of the following:
30+
// assert_eq(public_arr[0], concrete_arr[0]);
31+
// assert_eq(public_arr[1], concrete_arr[1]);
32+
// we can write:
33+
assert_eq(public_arr, concrete_arr);
34+
35+
let thing = Thing { xx: 5, yy: [1, 2] };
36+
let other_thing = Thing { xx: generic_arr[0] + 2, yy: public_arr };
37+
38+
// instead of the following:
39+
// assert_eq(thing.xx, other_thing.xx);
40+
// assert_eq(thing.yy[0], other_thing.yy[0]);
41+
// assert_eq(thing.yy[1], other_thing.yy[1]);
42+
// we can write:
43+
assert_eq(thing, other_thing);
44+
45+
let nested_thing = Nestedthing { xx: 5, another: [
46+
Another { aa: public_arr[0], bb: [1, 2] },
47+
Another { aa: generic_arr[1], bb: [4, 5] }
48+
] };
49+
let other_nested_thing = Nestedthing { xx: generic_arr[0] + 2, another: [
50+
Another { aa: 1, bb: public_arr },
51+
Another { aa: 3, bb: [public_arr[0] + (public_input * public_arr[0]), public_arr[1] + public_input] }
52+
] };
53+
54+
// instead of the following:
55+
// assert_eq(nested_thing.xx, other_nested_thing.xx);
56+
// assert_eq(nested_thing.another[0].aa, other_nested_thing.another[0].aa);
57+
// assert_eq(nested_thing.another[0].bb[0], other_nested_thing.another[0].bb[0]);
58+
// assert_eq(nested_thing.another[0].bb[1], other_nested_thing.another[0].bb[1]);
59+
// assert_eq(nested_thing.another[1].aa, other_nested_thing.another[1].aa);
60+
// assert_eq(nested_thing.another[1].bb[0], other_nested_thing.another[1].bb[0]);
61+
// assert_eq(nested_thing.another[1].bb[1], other_nested_thing.another[1].bb[1]);
62+
// we can write:
63+
assert_eq(nested_thing, other_nested_thing);
64+
}

src/error.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ pub enum ErrorKind {
4545
AssignmentToImmutableVariable,
4646
#[error("the {0} of assert_eq must be of type Field or BigInt. It was of type {1}")]
4747
AssertTypeMismatch(&'static str, TyKind),
48+
#[error("the types in assert_eq don't match: expected {0} but got {1}")]
49+
AssertEqTypeMismatch(TyKind, TyKind),
4850
#[error(
4951
"the dependency `{0}` does not appear to be listed in your manifest file `Noname.toml`"
5052
)]

src/stdlib/builtins.rs

Lines changed: 146 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ use crate::{
1313
constants::Span,
1414
error::{Error, ErrorKind, Result},
1515
helpers::PrettyField,
16-
parser::types::{GenericParameters, TyKind},
16+
parser::types::{GenericParameters, ModulePath, TyKind},
17+
type_checker::FullyQualified,
1718
var::{ConstOrCell, Value, Var},
1819
};
1920

@@ -33,13 +34,118 @@ impl Module for BuiltinsLib {
3334
fn get_fns<B: Backend>() -> Vec<(&'static str, FnInfoType<B>, bool)> {
3435
vec![
3536
(ASSERT_FN, assert_fn, false),
36-
(ASSERT_EQ_FN, assert_eq_fn, false),
37+
(ASSERT_EQ_FN, assert_eq_fn, true),
3738
// true -> skip argument type checking for log
3839
(LOG_FN, log_fn, true),
3940
]
4041
}
4142
}
4243

44+
/// Represents a comparison that needs to be made
45+
enum Comparison<B: Backend> {
46+
/// Compare two variables
47+
Vars(B::Var, B::Var),
48+
/// Compare a variable with a constant
49+
VarConst(B::Var, B::Field),
50+
/// Compare two constants
51+
Constants(B::Field, B::Field),
52+
}
53+
54+
/// Helper function to generate all comparisons
55+
fn assert_eq_values<B: Backend>(
56+
compiler: &CircuitWriter<B>,
57+
lhs_info: &VarInfo<B::Field, B::Var>,
58+
rhs_info: &VarInfo<B::Field, B::Var>,
59+
typ: &TyKind,
60+
span: Span,
61+
) -> Vec<Comparison<B>> {
62+
let mut comparisons = Vec::new();
63+
64+
match typ {
65+
// Field and Bool has the same logic
66+
TyKind::Field { .. } | TyKind::Bool => {
67+
let lhs_var = &lhs_info.var[0];
68+
let rhs_var = &rhs_info.var[0];
69+
match (lhs_var, rhs_var) {
70+
(ConstOrCell::Const(a), ConstOrCell::Const(b)) => {
71+
comparisons.push(Comparison::Constants(a.clone(), b.clone()));
72+
}
73+
(ConstOrCell::Const(cst), ConstOrCell::Cell(cvar))
74+
| (ConstOrCell::Cell(cvar), ConstOrCell::Const(cst)) => {
75+
comparisons.push(Comparison::VarConst(cvar.clone(), cst.clone()));
76+
}
77+
(ConstOrCell::Cell(lhs), ConstOrCell::Cell(rhs)) => {
78+
comparisons.push(Comparison::Vars(lhs.clone(), rhs.clone()));
79+
}
80+
}
81+
}
82+
83+
// Arrays (fixed size)
84+
TyKind::Array(element_type, size) => {
85+
let size = *size as usize;
86+
let element_size = compiler.size_of(element_type);
87+
88+
// compare each element recursively
89+
for i in 0..size {
90+
let start = i * element_size;
91+
let mut element_comparisons = assert_eq_values(
92+
compiler,
93+
&VarInfo::new(
94+
Var::new(lhs_info.var.range(start, element_size).to_vec(), span),
95+
false,
96+
Some(*element_type.clone()),
97+
),
98+
&VarInfo::new(
99+
Var::new(rhs_info.var.range(start, element_size).to_vec(), span),
100+
false,
101+
Some(*element_type.clone()),
102+
),
103+
element_type,
104+
span,
105+
);
106+
comparisons.append(&mut element_comparisons);
107+
}
108+
}
109+
110+
// Custom types (structs)
111+
TyKind::Custom { module, name } => {
112+
let qualified = FullyQualified::new(module, name);
113+
let struct_info = compiler.struct_info(&qualified).expect("struct not found");
114+
115+
// compare each field recursively
116+
let mut offset = 0;
117+
for (_, field_type) in &struct_info.fields {
118+
let field_size = compiler.size_of(field_type);
119+
let mut field_comparisons = assert_eq_values(
120+
compiler,
121+
&VarInfo::new(
122+
Var::new(lhs_info.var.range(offset, field_size).to_vec(), span),
123+
false,
124+
Some(field_type.clone()),
125+
),
126+
&VarInfo::new(
127+
Var::new(rhs_info.var.range(offset, field_size).to_vec(), span),
128+
false,
129+
Some(field_type.clone()),
130+
),
131+
field_type,
132+
span,
133+
);
134+
comparisons.append(&mut field_comparisons);
135+
offset += field_size;
136+
}
137+
}
138+
139+
// GenericSizedArray should be monomorphized to Array before reaching here
140+
// no need to handle it seperately
141+
TyKind::GenericSizedArray(_, _) => {
142+
unreachable!("GenericSizedArray should be monomorphized")
143+
}
144+
}
145+
146+
comparisons
147+
}
148+
43149
/// Asserts that two vars are equal.
44150
fn assert_eq_fn<B: Backend>(
45151
compiler: &mut CircuitWriter<B>,
@@ -52,67 +158,53 @@ fn assert_eq_fn<B: Backend>(
52158
let lhs_info = &vars[0];
53159
let rhs_info = &vars[1];
54160

55-
// they are both of type field
56-
if !matches!(lhs_info.typ, Some(TyKind::Field { .. })) {
57-
let lhs = lhs_info.typ.clone().ok_or_else(|| {
58-
Error::new(
59-
"constraint-generation",
60-
ErrorKind::UnexpectedError("No type info for lhs of assertion"),
61-
span,
62-
)
63-
})?;
64-
65-
Err(Error::new(
161+
// get types of both arguments
162+
let lhs_type = lhs_info.typ.as_ref().ok_or_else(|| {
163+
Error::new(
66164
"constraint-generation",
67-
ErrorKind::AssertTypeMismatch("rhs", lhs),
165+
ErrorKind::UnexpectedError("No type info for lhs of assertion"),
68166
span,
69-
))?
70-
}
167+
)
168+
})?;
71169

72-
if !matches!(rhs_info.typ, Some(TyKind::Field { .. })) {
73-
let rhs = rhs_info.typ.clone().ok_or_else(|| {
74-
Error::new(
75-
"constraint-generation",
76-
ErrorKind::UnexpectedError("No type info for rhs of assertion"),
77-
span,
78-
)
79-
})?;
170+
let rhs_type = rhs_info.typ.as_ref().ok_or_else(|| {
171+
Error::new(
172+
"constraint-generation",
173+
ErrorKind::UnexpectedError("No type info for rhs of assertion"),
174+
span,
175+
)
176+
})?;
80177

81-
Err(Error::new(
178+
// they have the same type
179+
if !lhs_type.match_expected(rhs_type, false) {
180+
return Err(Error::new(
82181
"constraint-generation",
83-
ErrorKind::AssertTypeMismatch("rhs", rhs),
182+
ErrorKind::AssertEqTypeMismatch(lhs_type.clone(), rhs_type.clone()),
84183
span,
85-
))?
184+
));
86185
}
87186

88-
// retrieve the values
89-
let lhs_var = &lhs_info.var;
90-
assert_eq!(lhs_var.len(), 1);
91-
let lhs_cvar = &lhs_var[0];
92-
93-
let rhs_var = &rhs_info.var;
94-
assert_eq!(rhs_var.len(), 1);
95-
let rhs_cvar = &rhs_var[0];
96-
97-
match (lhs_cvar, rhs_cvar) {
98-
// two constants
99-
(ConstOrCell::Const(a), ConstOrCell::Const(b)) => {
100-
if a != b {
101-
Err(Error::new(
102-
"constraint-generation",
103-
ErrorKind::AssertionFailed,
104-
span,
105-
))?
106-
}
107-
}
187+
// first collect all comparisons needed
188+
let comparisons = assert_eq_values(compiler, lhs_info, rhs_info, lhs_type, span);
108189

109-
// a const and a var
110-
(ConstOrCell::Const(cst), ConstOrCell::Cell(cvar))
111-
| (ConstOrCell::Cell(cvar), ConstOrCell::Const(cst)) => {
112-
compiler.backend.assert_eq_const(cvar, *cst, span)
113-
}
114-
(ConstOrCell::Cell(lhs), ConstOrCell::Cell(rhs)) => {
115-
compiler.backend.assert_eq_var(lhs, rhs, span)
190+
// then add all the constraints
191+
for comparison in comparisons {
192+
match comparison {
193+
Comparison::Vars(lhs, rhs) => {
194+
compiler.backend.assert_eq_var(&lhs, &rhs, span);
195+
}
196+
Comparison::VarConst(var, constant) => {
197+
compiler.backend.assert_eq_const(&var, constant, span);
198+
}
199+
Comparison::Constants(a, b) => {
200+
if a != b {
201+
return Err(Error::new(
202+
"constraint-generation",
203+
ErrorKind::AssertionFailed,
204+
span,
205+
));
206+
}
207+
}
116208
}
117209
}
118210

src/tests/examples.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -839,6 +839,25 @@ fn test_generic_array_nested(#[case] backend: BackendKind) -> miette::Result<()>
839839
Ok(())
840840
}
841841

842+
#[rstest]
843+
#[case::kimchi_vesta(BackendKind::KimchiVesta(KimchiVesta::new(false)))]
844+
#[case::r1cs(BackendKind::R1csBls12_381(R1CS::new()))]
845+
fn test_generic_assert_eq(#[case] backend: BackendKind) -> miette::Result<()> {
846+
let public_inputs = r#"{"public_arr": ["1", "2"], "public_input": "3"}"#;
847+
let private_inputs = r#"{}"#;
848+
849+
test_file(
850+
"generic_assert_eq",
851+
public_inputs,
852+
private_inputs,
853+
vec![],
854+
backend,
855+
DEFAULT_OPTIONS,
856+
)?;
857+
858+
Ok(())
859+
}
860+
842861
#[rstest]
843862
#[case::kimchi_vesta(BackendKind::KimchiVesta(KimchiVesta::new(false)))]
844863
#[case::r1cs(BackendKind::R1csBls12_381(R1CS::new()))]

0 commit comments

Comments
 (0)