Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions constraint-solver/src/solver/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,10 @@ where
progress |= self.exhaustive_search()?;
}

if !progress {
progress |= self.inline_affine();
}

if !progress {
break;
}
Expand Down Expand Up @@ -478,6 +482,27 @@ where
exprs
}

fn inline_affine(&mut self) -> bool {
let mut progress = false;
let affine_equivalences = self
.constraint_system
.system()
.algebraic_constraints()
.iter()
.filter(|constr| {
constr.expression.is_affine() && constr.referenced_unknown_variables().count() == 2
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should filter out constraints that contain new variables, although internally it could make sense.

})
.flat_map(|constr| {
let var = constr.referenced_unknown_variables().last().unwrap();
Some((var.clone(), constr.as_ref().try_solve_for(var)?))
})
.collect_vec();
for (v, expr) in affine_equivalences {
progress |= self.apply_assignment(&v, &expr);
}
progress
}

fn apply_effect(&mut self, effect: Effect<T, V>) -> bool {
match effect {
Effect::Assignment(v, expr) => {
Expand Down
22 changes: 13 additions & 9 deletions constraint-solver/src/solver/linearizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,11 @@ mod tests {
let assignments = solver.solve().unwrap();
expect!([r#"
lin_4 = 0
lin_7 = 0"#])
lin_7 = 0
lin_1 = z + 1
lin_3 = x - 1
lin_6 = c - 2
lin_8 = -(a)"#])
.assert_eq(
&assignments
.iter()
Expand All @@ -275,18 +279,18 @@ mod tests {
expect!([r#"
((x + y) * (z + 1)) * (x - 1) = 0
x + y - lin_0 = 0
z - lin_1 + 1 = 0
(lin_0) * (lin_1) - lin_2 = 0
x - lin_3 - 1 = 0
(lin_2) * (lin_3) = 0
0 = 0
(lin_0) * (z + 1) - lin_2 = 0
0 = 0
(lin_2) * (x - 1) = 0
0 = 0
(a + b) * (c - 2) = 0
a + b - lin_5 = 0
c - lin_6 - 2 = 0
(lin_5) * (lin_6) = 0
0 = 0
-(a + lin_8) = 0
BusInteraction { bus_id: 1, multiplicity: lin_1, payload: lin_0, lin_8, a }"#])
(lin_5) * (c - 2) = 0
0 = 0
0 = 0
BusInteraction { bus_id: 1, multiplicity: z + 1, payload: lin_0, -(a), a }"#])
.assert_eq(&solver.to_string());
}
}
18 changes: 9 additions & 9 deletions openvm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1921,10 +1921,10 @@ mod tests {
AirMetrics {
widths: AirWidths {
preprocessed: 0,
main: 17300,
main: 17280,
log_up: 27896,
},
constraints: 8834,
constraints: 8792,
bus_interactions: 11925,
}
"#]],
Expand All @@ -1943,7 +1943,7 @@ mod tests {
},
after: AirWidths {
preprocessed: 0,
main: 17300,
main: 17280,
log_up: 27896,
},
}
Expand All @@ -1970,10 +1970,10 @@ mod tests {
AirMetrics {
widths: AirWidths {
preprocessed: 0,
main: 19928,
main: 19908,
log_up: 30924,
},
constraints: 11103,
constraints: 11061,
bus_interactions: 13442,
}
"#]],
Expand All @@ -1992,7 +1992,7 @@ mod tests {
},
after: AirWidths {
preprocessed: 0,
main: 19928,
main: 19908,
log_up: 30924,
},
}
Expand Down Expand Up @@ -2125,10 +2125,10 @@ mod tests {
AirMetrics {
widths: AirWidths {
preprocessed: 0,
main: 3246,
main: 3243,
log_up: 5264,
},
constraints: 598,
constraints: 591,
bus_interactions: 2562,
}
"#]],
Expand All @@ -2147,7 +2147,7 @@ mod tests {
},
after: AirWidths {
preprocessed: 0,
main: 3246,
main: 3243,
log_up: 5264,
},
}
Expand Down
24 changes: 7 additions & 17 deletions openvm/tests/apc_snapshots/complex/memcpy_block.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ Instructions:
BNE 52 0 248 1 1

APC advantage:
- Main columns: 172 -> 37 (4.65x reduction)
- Main columns: 172 -> 34 (5.06x reduction)
- Bus interactions: 87 -> 21 (4.14x reduction)
- Constraints: 111 -> 33 (3.36x reduction)
- Constraints: 111 -> 26 (4.27x reduction)

Symbolic machine using 37 unique main columns:
Symbolic machine using 34 unique main columns:
from_state__timestamp_0
reads_aux__0__base__prev_timestamp_0
reads_aux__0__base__timestamp_lt_aux__lower_decomp__0_0
Expand All @@ -27,9 +27,6 @@ Symbolic machine using 37 unique main columns:
b__3_0
cmp_result_1
diff_marker__0_1
diff_marker__1_1
diff_marker__2_1
diff_marker__3_1
diff_val_1
reads_aux__0__base__prev_timestamp_2
reads_aux__0__base__timestamp_lt_aux__lower_decomp__0_2
Expand Down Expand Up @@ -75,22 +72,15 @@ mult=is_valid * 1, args=[15360 * reads_aux__1__base__prev_timestamp_4 + 15360 *

// Bus 6 (BITWISE_LOOKUP):
mult=is_valid * 1, args=[b__0_0, 3, b__0_0 + 3 - 2 * a__0_0, 1]
mult=diff_marker__0_1 + diff_marker__1_1 + diff_marker__2_1 + diff_marker__3_1, args=[diff_val_1 - 1, 0, 0, 0]
mult=diff_marker__0_1, args=[diff_val_1 - 1, 0, 0, 0]
mult=diff_marker__0_2 + diff_marker__1_2 + diff_marker__2_2 + diff_marker__3_2, args=[diff_val_2 - 1, 0, 0, 0]

// Algebraic constraints:
cmp_result_1 * (cmp_result_1 - 1) = 0
diff_marker__3_1 * (diff_marker__3_1 - 1) = 0
diff_marker__3_1 * diff_val_1 = 0
diff_marker__2_1 * (diff_marker__2_1 - 1) = 0
diff_marker__2_1 * diff_val_1 = 0
diff_marker__1_1 * (diff_marker__1_1 - 1) = 0
diff_marker__1_1 * diff_val_1 = 0
diff_marker__0_1 * (diff_marker__0_1 - 1) = 0
(1 * is_valid - (diff_marker__0_1 + diff_marker__1_1 + diff_marker__2_1 + diff_marker__3_1)) * ((1 - a__0_0) * (2 * cmp_result_1 - 1)) = 0
(1 * is_valid - diff_marker__0_1) * ((1 - a__0_0) * (2 * cmp_result_1 - 1)) = 0
diff_marker__0_1 * ((a__0_0 - 1) * (2 * cmp_result_1 - 1) + diff_val_1) = 0
(diff_marker__0_1 + diff_marker__1_1 + diff_marker__2_1 + diff_marker__3_1) * (diff_marker__0_1 + diff_marker__1_1 + diff_marker__2_1 + diff_marker__3_1 - 1) = 0
(1 - (diff_marker__0_1 + diff_marker__1_1 + diff_marker__2_1 + diff_marker__3_1)) * cmp_result_1 = 0
(1 - diff_marker__0_1) * cmp_result_1 = 0
cmp_result_2 * (cmp_result_2 - 1) = 0
diff_marker__3_2 * (diff_marker__3_2 - 1) = 0
(1 - diff_marker__3_2) * (writes_aux__prev_data__3_2 * (2 * cmp_result_2 - 1)) = 0
Expand All @@ -109,6 +99,6 @@ diff_marker__0_2 * ((writes_aux__prev_data__0_2 - 1) * (2 * cmp_result_2 - 1) +
cmp_result_4 * (cmp_result_4 - 1) = 0
(1 - cmp_result_4) * (cmp_result_1 + cmp_result_2 - cmp_result_1 * cmp_result_2) = 0
(cmp_result_1 + cmp_result_2 - cmp_result_1 * cmp_result_2) * diff_inv_marker__0_4 - cmp_result_4 = 0
(1 - is_valid) * (diff_marker__0_1 + diff_marker__1_1 + diff_marker__2_1 + diff_marker__3_1) = 0
(1 - is_valid) * diff_marker__0_1 = 0
(1 - is_valid) * (diff_marker__0_2 + diff_marker__1_2 + diff_marker__2_2 + diff_marker__3_2) = 0
is_valid * (is_valid - 1) = 0
24 changes: 7 additions & 17 deletions openvm/tests/apc_snapshots/complex/unaligned_memcpy.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ Instructions:
BNE 68 0 -48 1 1

APC advantage:
- Main columns: 465 -> 105 (4.43x reduction)
- Main columns: 465 -> 102 (4.56x reduction)
- Bus interactions: 242 -> 58 (4.17x reduction)
- Constraints: 286 -> 69 (4.14x reduction)
- Constraints: 286 -> 62 (4.61x reduction)

Symbolic machine using 105 unique main columns:
Symbolic machine using 102 unique main columns:
from_state__timestamp_0
rs1_data__0_0
rs1_data__1_0
Expand Down Expand Up @@ -107,9 +107,6 @@ Symbolic machine using 105 unique main columns:
reads_aux__0__base__timestamp_lt_aux__lower_decomp__0_6
cmp_result_6
diff_marker__0_6
diff_marker__1_6
diff_marker__2_6
diff_marker__3_6
diff_val_6
cmp_result_7
diff_marker__0_7
Expand Down Expand Up @@ -180,7 +177,7 @@ mult=is_valid * 1, args=[15360 * reads_aux__0__base__prev_timestamp_6 + 15360 *

// Bus 6 (BITWISE_LOOKUP):
mult=is_valid * 1, args=[b__0_5, 3, b__0_5 + 3 - 2 * a__0_5, 1]
mult=diff_marker__0_6 + diff_marker__1_6 + diff_marker__2_6 + diff_marker__3_6, args=[diff_val_6 - 1, 0, 0, 0]
mult=diff_marker__0_6, args=[diff_val_6 - 1, 0, 0, 0]
mult=diff_marker__0_7 + diff_marker__1_7 + diff_marker__2_7 + diff_marker__3_7, args=[diff_val_7 - 1, 0, 0, 0]
mult=is_valid * 1, args=[a__0_1, a__1_1, 0, 0]
mult=is_valid * 1, args=[a__2_1, a__3_1, 0, 0]
Expand Down Expand Up @@ -223,17 +220,10 @@ flags__1_3 * (flags__1_3 - 1) + flags__2_3 * (flags__2_3 - 1) + 4 * flags__0_3 *
(120 * a__0_4 + 30720 * a__1_4 + 7864320 * a__2_4 + 121 * is_valid - (120 * writes_aux__prev_data__0_4 + 30720 * writes_aux__prev_data__1_4 + 7864320 * writes_aux__prev_data__2_4)) * (120 * a__0_4 + 30720 * a__1_4 + 7864320 * a__2_4 + 120 - (120 * writes_aux__prev_data__0_4 + 30720 * writes_aux__prev_data__1_4 + 7864320 * writes_aux__prev_data__2_4)) = 0
(943718400 * writes_aux__prev_data__0_4 + 120 * a__1_4 + 30720 * a__2_4 + 7864320 * a__3_4 - (120 * writes_aux__prev_data__1_4 + 30720 * writes_aux__prev_data__2_4 + 7864320 * writes_aux__prev_data__3_4 + 943718400 * a__0_4 + 943718399 * is_valid)) * (943718400 * writes_aux__prev_data__0_4 + 120 * a__1_4 + 30720 * a__2_4 + 7864320 * a__3_4 - (120 * writes_aux__prev_data__1_4 + 30720 * writes_aux__prev_data__2_4 + 7864320 * writes_aux__prev_data__3_4 + 943718400 * a__0_4 + 943718400)) = 0
cmp_result_6 * (cmp_result_6 - 1) = 0
diff_marker__3_6 * (diff_marker__3_6 - 1) = 0
diff_marker__3_6 * diff_val_6 = 0
diff_marker__2_6 * (diff_marker__2_6 - 1) = 0
diff_marker__2_6 * diff_val_6 = 0
diff_marker__1_6 * (diff_marker__1_6 - 1) = 0
diff_marker__1_6 * diff_val_6 = 0
diff_marker__0_6 * (diff_marker__0_6 - 1) = 0
(1 - (diff_marker__0_6 + diff_marker__1_6 + diff_marker__2_6 + diff_marker__3_6)) * (a__0_5 * (2 * cmp_result_6 - 1)) = 0
(1 - diff_marker__0_6) * (a__0_5 * (2 * cmp_result_6 - 1)) = 0
diff_marker__0_6 * (diff_val_6 - a__0_5 * (2 * cmp_result_6 - 1)) = 0
(diff_marker__0_6 + diff_marker__1_6 + diff_marker__2_6 + diff_marker__3_6) * (diff_marker__0_6 + diff_marker__1_6 + diff_marker__2_6 + diff_marker__3_6 - 1) = 0
(1 - (diff_marker__0_6 + diff_marker__1_6 + diff_marker__2_6 + diff_marker__3_6)) * cmp_result_6 = 0
(1 - diff_marker__0_6) * cmp_result_6 = 0
cmp_result_7 * (cmp_result_7 - 1) = 0
diff_marker__3_7 * (diff_marker__3_7 - 1) = 0
(1 - diff_marker__3_7) * (a__3_4 * (2 * cmp_result_7 - 1)) = 0
Expand All @@ -258,6 +248,6 @@ cmp_result_12 * (cmp_result_12 - 1) = 0
cmp_result_6 * cmp_result_7 * diff_inv_marker__0_12 - cmp_result_12 = 0
flags__2_3 * (flags__2_3 - 1) - (flags__0_3 * (flags__0_3 + flags__1_3 + flags__2_3 + flags__3_3 - 2) + 2 * flags__1_3 * (flags__0_3 + flags__1_3 + flags__2_3 + flags__3_3 - 2) + 3 * flags__2_3 * (flags__0_3 + flags__1_3 + flags__2_3 + flags__3_3 - 2)) = 0
opcode_loadb_flag0_0 * shifted_read_data__0_0 + (1 - opcode_loadb_flag0_0) * shifted_read_data__1_0 - read_data__0_3 = 0
(1 - is_valid) * (diff_marker__0_6 + diff_marker__1_6 + diff_marker__2_6 + diff_marker__3_6) = 0
(1 - is_valid) * diff_marker__0_6 = 0
(1 - is_valid) * (diff_marker__0_7 + diff_marker__1_7 + diff_marker__2_7 + diff_marker__3_7) = 0
is_valid * (is_valid - 1) = 0
22 changes: 9 additions & 13 deletions openvm/tests/apc_snapshots/single_instructions/single_loadhu.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ Instructions:
LOADHU rd_rs2_ptr = 0, rs1_ptr = 2, imm = 22, mem_as = 2, needs_write = 0, imm_sign = 0

APC advantage:
- Main columns: 41 -> 18 (2.28x reduction)
- Main columns: 41 -> 17 (2.41x reduction)
- Bus interactions: 17 -> 12 (1.42x reduction)
- Constraints: 25 -> 9 (2.78x reduction)
- Constraints: 25 -> 6 (4.17x reduction)

Symbolic machine using 18 unique main columns:
Symbolic machine using 17 unique main columns:
from_state__timestamp_0
rs1_data__0_0
rs1_data__1_0
Expand All @@ -19,7 +19,6 @@ Symbolic machine using 18 unique main columns:
mem_ptr_limbs__0_0
mem_ptr_limbs__1_0
flags__1_0
flags__2_0
read_data__0_0
read_data__1_0
read_data__2_0
Expand All @@ -33,24 +32,21 @@ mult=is_valid * 1, args=[4, from_state__timestamp_0 + 3]
// Bus 1 (MEMORY):
mult=is_valid * -1, args=[1, 2, rs1_data__0_0, rs1_data__1_0, rs1_data__2_0, rs1_data__3_0, rs1_aux_cols__base__prev_timestamp_0]
mult=is_valid * 1, args=[1, 2, rs1_data__0_0, rs1_data__1_0, rs1_data__2_0, rs1_data__3_0, from_state__timestamp_0]
mult=is_valid * -1, args=[2, 2 * flags__1_0 * (flags__1_0 + flags__2_0 - 2) + 3 * flags__2_0 * (flags__1_0 + flags__2_0 - 2) + mem_ptr_limbs__0_0 + 65536 * mem_ptr_limbs__1_0 - flags__2_0 * (flags__2_0 - 1), read_data__0_0, read_data__1_0, read_data__2_0, read_data__3_0, read_data_aux__base__prev_timestamp_0]
mult=is_valid * 1, args=[2, 2 * flags__1_0 * (flags__1_0 + flags__2_0 - 2) + 3 * flags__2_0 * (flags__1_0 + flags__2_0 - 2) + mem_ptr_limbs__0_0 + 65536 * mem_ptr_limbs__1_0 - flags__2_0 * (flags__2_0 - 1), read_data__0_0, read_data__1_0, read_data__2_0, read_data__3_0, from_state__timestamp_0 + 1]
mult=is_valid * -1, args=[2, (flags__1_0 - 2) * (1 - flags__1_0) + mem_ptr_limbs__0_0 + 65536 * mem_ptr_limbs__1_0, read_data__0_0, read_data__1_0, read_data__2_0, read_data__3_0, read_data_aux__base__prev_timestamp_0]
mult=is_valid * 1, args=[2, (flags__1_0 - 2) * (1 - flags__1_0) + mem_ptr_limbs__0_0 + 65536 * mem_ptr_limbs__1_0, read_data__0_0, read_data__1_0, read_data__2_0, read_data__3_0, from_state__timestamp_0 + 1]

// Bus 3 (VARIABLE_RANGE_CHECKER):
mult=is_valid * 1, args=[rs1_aux_cols__base__timestamp_lt_aux__lower_decomp__0_0, 17]
mult=is_valid * 1, args=[15360 * rs1_aux_cols__base__prev_timestamp_0 + 15360 * rs1_aux_cols__base__timestamp_lt_aux__lower_decomp__0_0 + 15360 - 15360 * from_state__timestamp_0, 12]
mult=is_valid * 1, args=[503316480 * flags__2_0 * (flags__2_0 - 1) + 503316481 * flags__2_0 * (flags__1_0 + flags__2_0 - 2) + 503316480 * flags__1_0 * flags__2_0 - (1006632960 * flags__1_0 * (flags__1_0 + flags__2_0 - 2) + 503316480 * mem_ptr_limbs__0_0), 14]
mult=is_valid * 1, args=[(1006632960 - 503316480 * flags__1_0) * (1 - flags__1_0) + 503316480 * flags__1_0 * (2 - flags__1_0) - 503316480 * mem_ptr_limbs__0_0, 14]
mult=is_valid * 1, args=[mem_ptr_limbs__1_0, 13]
mult=is_valid * 1, args=[read_data_aux__base__timestamp_lt_aux__lower_decomp__0_0, 17]
mult=is_valid * 1, args=[15360 * read_data_aux__base__prev_timestamp_0 + 15360 * read_data_aux__base__timestamp_lt_aux__lower_decomp__0_0 - 15360 * from_state__timestamp_0, 12]

// Algebraic constraints:
flags__1_0 * ((flags__1_0 - 1) * (flags__1_0 - 2)) = 0
flags__2_0 * ((flags__2_0 - 1) * (flags__2_0 - 2)) = 0
(flags__1_0 + flags__2_0 - 1 * is_valid) * (flags__1_0 + flags__2_0 - 2) = 0
1006632960 * flags__1_0 * (flags__1_0 - 1) + 1006632960 * flags__2_0 * (flags__2_0 - 1) + flags__1_0 * (flags__1_0 + flags__2_0 - 2) + flags__2_0 * (flags__1_0 + flags__2_0 - 2) + 1 * is_valid = 0
1006632960 * flags__1_0 * (flags__1_0 - 1) + 1 - (1006632960 * flags__1_0 + 1) * (1 - flags__1_0) = 0
(30720 * mem_ptr_limbs__0_0 - (30720 * rs1_data__0_0 + 7864320 * rs1_data__1_0 + 675840 * is_valid)) * (30720 * mem_ptr_limbs__0_0 - (30720 * rs1_data__0_0 + 7864320 * rs1_data__1_0 + 675841)) = 0
(943718400 * rs1_data__0_0 + 30720 * mem_ptr_limbs__1_0 + 629145590 * is_valid - (120 * rs1_data__1_0 + 30720 * rs1_data__2_0 + 7864320 * rs1_data__3_0 + 943718400 * mem_ptr_limbs__0_0)) * (943718400 * rs1_data__0_0 + 30720 * mem_ptr_limbs__1_0 + 629145589 - (120 * rs1_data__1_0 + 30720 * rs1_data__2_0 + 7864320 * rs1_data__3_0 + 943718400 * mem_ptr_limbs__0_0)) = 0
flags__1_0 * (flags__1_0 - 1) + flags__2_0 * (flags__2_0 - 1) + 5 * flags__1_0 * flags__2_0 - (flags__1_0 * (flags__1_0 + flags__2_0 - 2) + flags__2_0 * (flags__1_0 + flags__2_0 - 2) + 2 * is_valid) = 0
flags__1_0 * flags__2_0 = 0
flags__1_0 * (flags__1_0 - 1) + (2 - flags__1_0) * (1 - flags__1_0) + 5 * flags__1_0 * (2 - flags__1_0) - 2 = 0
flags__1_0 * (2 - flags__1_0) = 0
is_valid * (is_valid - 1) = 0
Loading
Loading