Skip to content
This repository was archived by the owner on Jul 5, 2024. It is now read-only.

Commit 01c5133

Browse files
committed
feat: use peak memory estimation from Han
1 parent 3d0d97d commit 01c5133

File tree

2 files changed

+394
-172
lines changed

2 files changed

+394
-172
lines changed
+373
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,373 @@
1+
//! Utility functions and types to get circuit stats from any halo2 circuit
2+
3+
use eth_types::Field;
4+
use halo2_proofs::plonk::ConstraintSystem;
5+
use std::collections::BTreeSet;
6+
7+
// From Scroll https://github.com/scroll-tech/zkevm-circuits/blob/7d9bc181953cfc6e7baf82ff0ce651281fd70a8a/zkevm-circuits/src/util.rs#L275
/// Summary numbers describing a halo2 `ConstraintSystem`, as filled in by `circuit_stats`.
///
/// Most counters are deltas relative to a previously aggregated record; `degree` and
/// `blinding_factors` are always the values reported by the constraint system itself.
// NOTE(review): presumably some fields are recorded for reporting only and never read back,
// which is why `dead_code` is silenced for the whole struct — confirm against the callers.
#[allow(dead_code)]
#[derive(Debug, Default)]
pub(crate) struct CircuitStats {
    /// Total number of polynomials over all counted gates.
    pub num_constraints: usize,
    /// Number of fixed columns.
    pub num_fixed_columns: usize,
    /// Number of lookup arguments.
    pub num_lookups: usize,
    /// Number of shuffle arguments.
    pub num_shuffles: usize,
    /// Number of advice columns.
    pub num_advice_columns: usize,
    /// Number of instance columns.
    pub num_instance_columns: usize,
    /// Number of selectors.
    pub num_selectors: usize,
    // NOTE(review): a `num_simple_selectors: usize` field is disabled here and is not populated
    // by `circuit_stats` either.
    /// Number of columns taking part in the permutation argument.
    pub num_permutation_columns: usize,
    /// Degree reported by the constraint system.
    pub degree: usize,
    /// Blinding factors reported by the constraint system.
    pub blinding_factors: usize,
    /// Number of challenges.
    pub num_challenges: usize,
    /// Highest phase among the counted advice columns.
    pub max_phase: u8,
    /// Number of distinct rotations used by the counted advice queries.
    pub num_rotation: usize,
    /// Smallest rotation among the counted advice queries (0 when there are none).
    pub min_rotation: i32,
    /// Largest rotation among the counted advice queries (0 when there are none).
    pub max_rotation: i32,
    /// Sum of the column/argument counters plus the number of distinct rotations; judging by
    /// the name, an estimate of the EC multiplications needed for verification.
    pub num_verification_ecmul: usize,
    // Aux data to diff between records
    /// Number of advice queries; used as the skip offset when computing the next delta.
    num_advice_queries: usize,
    /// Number of gates; used as the skip offset when computing the next delta.
    num_gates: usize,
}
32+
33+
impl CircuitStats {
34+
// Peak memory analysis by Han:
35+
//
36+
// fn create_proof(params: &KZGParams, pk: &ProvingKey, circuit: &Circuit, instances: &[&[F]]) {
37+
// Let:
38+
//
39+
// - k: log 2 of number of rows
40+
// - n: `1 << k`
41+
// - d: Degree of circuit
42+
// - e: Extension magnitude, equal to `(d - 1).next_power_of_two()`
43+
// - c_f: number of fixed columns
44+
// - c_a: number of advice columns
45+
// - c_i: number of instance columns
46+
// - c_p: number of columns enabled with copy constraint
47+
// - c_pg: number of grand product in permutation argument, equal to `div_ceil(c_p, d - 2)`
48+
// - c_l: number of lookup argument
49+
//
50+
// The memory usage M.C and M.S stands for:
51+
//
52+
// - M.C: number of "n elliptic curve points" (with symbol ◯)
53+
// - M.S: number of "n field elements" (with symbol △)
54+
// - M.E: number of "e * n field elements" (with symbol ⬡)
55+
//
56+
// So the actual memory usage in terms of bytes will be:
57+
//
58+
// M = 32 * n * (2 * M.C + M.S + e * M.E)
59+
//
60+
// We'll ignore other values with sublinear amount to n.
61+
//
62+
//
63+
// 0. In the beginning:
64+
//
65+
// `params` has:
66+
// ◯ powers_of_tau
67+
// ◯ ifft(powers_of_tau)
68+
//
69+
// M.C = 2 (+= 2)
70+
// M.S = 0
71+
// M.E = 0
72+
//
73+
// `pk` has:
74+
// ⬡ l0
75+
// ⬡ l_last
76+
// ⬡ l_active_row
77+
// △ fixed_lagranges (c_f)
78+
// △ fixed_monomials (c_f)
79+
// ⬡ fixed_extended_lagranges (c_f)
80+
// △ permutation_lagranges (c_p)
81+
// △ permutation_monomials (c_p)
82+
// ⬡ permutation_extended_lagranges (c_p)
83+
//
84+
// M.C = 2
85+
// M.S = 2 * c_f + 2 * c_p (+= 2 * c_f + 2 * c_p)
86+
// M.E = 3 + c_f + c_p (+= 3 + c_f + c_p)
87+
//
88+
// And let's ignore `circuit`
89+
//
90+
//
91+
// ### 1. Pad instances as lagrange form and compute its monomial form.
92+
//
93+
// M.C = 2
94+
// M.S = 2 * c_f + 2 * c_p + 2 * c_i (+= 2 * c_i)
95+
// M.E = 3 + c_f + c_p
96+
// ```
97+
// let instance_lagranges = instances.to_lagranges();
98+
// let instance_monomials = instance_lagranges.to_monomials();
99+
// ```
100+
//
101+
//
102+
// ### 2. Synthesize circuit and collect advice column values.
103+
//
104+
// M.C = 2
105+
// M.S = 2 * c_f + 2 * c_p + 2 * c_i + c_a (+= c_a)
106+
// M.E = 3 + c_f + c_p
107+
// ```
108+
// let advice_lagranges = circuit.synthesize_all_phases();
109+
// ```
110+
//
111+
//
112+
// ### 3. Generate permuted input and table of lookup argument.
113+
// For each lookup argument, we have:
114+
//
115+
// △ compressed_input_lagranges - cached for later computation
116+
// △ permuted_input_lagranges
117+
// △ permuted_input_monomials
118+
// △ compressed_table_lagranges - cached for later computation
119+
// △ permuted_table_lagranges
120+
// △ permuted_table_monomials
121+
//
122+
// M.C = 2
123+
// M.S = 2 * c_f + 2 * c_p + 2 * c_i + c_a + 6 * c_l (+= 6 * c_l)
124+
// M.E = 3 + c_f + c_p
125+
// ```
126+
// let (
127+
// compressed_input_lagranges,
128+
// permuted_input_lagranges,
129+
// permuted_input_monomials,
130+
// compressed_table_lagranges,
131+
// permuted_table_lagranges,
132+
// permuted_table_monomials,
133+
// ) = lookup_permuted()
134+
// ```
135+
//
136+
//
137+
// ### 4. Generate grand products of permutation argument.
138+
//
139+
// M.C = 2
140+
// M.S = 2 * c_f + 2 * c_p + 2 * c_i + c_a + 6 * c_l + c_pg (+= c_pg)
141+
// M.E = 3 + c_f + c_p + c_pg (+= c_pg)
142+
// ```
143+
// let (
144+
// perm_grand_product_monomials,
145+
// perm_grand_product_extended_lagranges,
146+
// ) = permutation_grand_products();
147+
// ```
148+
//
149+
//
150+
// ### 5. Generate grand products of lookup argument.
151+
// And then drops unnecessary lagranges values.
152+
//
153+
// M.C = 2
154+
// M.S = 2 * c_f + 2 * c_p + 2 * c_i + c_a + 3 * c_l + c_pg (-= 3 * c_l)
155+
// M.E = 3 + c_f + c_p + c_pg
156+
// > let lookup_product_monomials = lookup_grand_products();
157+
// > drop(compressed_input_lagranges);
158+
// > drop(permuted_input_lagranges);
159+
// > drop(compressed_table_lagranges);
160+
// > drop(permuted_table_lagranges);
161+
//
162+
//
163+
// ### 6. Generate random polynomial.
164+
//
165+
// M.C = 2
166+
// M.S = 1 + 2 * c_f + 2 * c_p + 2 * c_i + c_a + 3 * c_l + c_pg (+= 1)
167+
// M.E = 3 + c_f + c_p + c_pg
168+
// ```
169+
// let random_monomial = random();
170+
// ```
171+
//
172+
//
173+
// ### 7. Turn advice_lagranges into advice_monomials.
174+
// ```
175+
// let advice_monomials = advice_lagranges.to_monomials();
176+
// drop(advice_lagranges);
177+
// ```
178+
//
179+
//
180+
// ### 8. Generate necessary extended lagranges.
181+
//
182+
// M.C = 2
183+
// M.S = 1 + 2 * c_f + 2 * c_p + 2 * c_i + c_a + 3 * c_l + c_pg
184+
// M.E = 3 + c_f + c_p + c_pg + c_i + c_a (+= c_i + c_a)
185+
// ```
186+
// let instances_extended_lagrnages = instances_monomials.to_extended_lagranges();
187+
// let advice_extended_lagrnages = advice_monomials.to_extended_lagranges();
188+
// ```
189+
//
190+
//
191+
// ### 9. While computing the quotient, these extended lagranges:
192+
//
193+
// ⬡ permuted_input_extended_lagranges
194+
// ⬡ permuted_table_extended_lagranges
195+
// ⬡ lookup_product_extended_lagranges
196+
//
197+
// of each lookup argument are generated on the fly and drop before next.
198+
//
199+
// And 1 extra quotient_extended_lagrange is created. So the peak memory:
200+
//
201+
// M.C = 2
202+
// M.S = 1 + 2 * c_f + 2 * c_p + 2 * c_i + c_a + 3 * c_l + c_pg
203+
// M.E = 4 + c_f + c_p + c_pg + c_i + c_a + 3 * (c_l > 0) (+= 3 * (c_l > 0) + 1)
204+
// ```
205+
// let quotient_extended_lagrange = quotient_extended_lagrange();
206+
// ```
207+
//
208+
//
209+
// ### 10. After quotient is comuputed, drop all the other extended lagranges.
210+
//
211+
// M.C = 2
212+
// M.S = 1 + 2 * c_f + 2 * c_p + 2 * c_i + c_a + 3 * c_l + c_pg
213+
// M.E = 4 + c_f + c_p (-= c_pg + c_i + c_a + 3 * (c_l > 0))
214+
// drop(instances_extended_lagrnages)
215+
// drop(advice_extended_lagrnages)
216+
// drop(perm_grand_product_extended_lagranges)
217+
//
218+
//
219+
// ### 11. Turn quotient_extended_lagrange into monomial form.
220+
// And then cut int `d - 1` pieces.
221+
//
222+
// M.C = 2
223+
// M.S = 2 * c_f + 2 * c_p + 2 * c_i + c_a + 3 * c_l + c_pg + d (+= d - 1)
224+
// M.E = 3 + c_f + c_p (-= 1)
225+
// ```
226+
// let quotient_monomials = quotient_monomials()
227+
// drop(quotient_extended_lagrange)
228+
// ```
229+
//
230+
//
231+
// ### 12. Evaluate and open all polynomial except instance ones.
232+
// }
233+
pub(crate) fn estimate_peak_mem(&self, k: u32) -> usize {
234+
let field_bytes = 32;
235+
let c_f = self.num_fixed_columns;
236+
let c_a = self.num_advice_columns;
237+
let c_i = self.num_instance_columns;
238+
let c_p = self.num_permutation_columns;
239+
let c_l = self.num_lookups;
240+
let c_pg = (c_p + self.degree - 3) / (self.degree - 2);
241+
let e = (self.degree - 1).next_power_of_two();
242+
// number of "n elliptic curve points"
243+
let m_c = 2;
244+
// number of "n field elements"
245+
let m_s = 1 + 2 * c_f + 2 * c_p + 2 * c_i + c_a + 3 * c_l + c_pg;
246+
// number of "e * n field elements"
247+
let m_e = 4 + c_f + c_p + c_pg + c_i + c_a + 3 * (c_l > 0) as usize;
248+
let unit = m_c + m_s + e * m_e;
249+
unit * 2usize.pow(k) * field_bytes
250+
}
251+
}
252+
253+
// Return the stats in `meta`, accounting only for the circuit delta from the last aggregated stats
254+
// in `agg`.
255+
// Adaptaed from Scroll https://github.com/scroll-tech/zkevm-circuits/blob/7d9bc181953cfc6e7baf82ff0ce651281fd70a8a/zkevm-circuits/src/util.rs#L294
256+
pub(crate) fn circuit_stats<F: Field>(
257+
agg: &CircuitStats,
258+
meta: &ConstraintSystem<F>,
259+
) -> CircuitStats {
260+
let max_phase = meta
261+
.advice_column_phase()
262+
.iter()
263+
.skip(agg.num_advice_columns)
264+
.max()
265+
.copied()
266+
.unwrap_or_default();
267+
268+
let rotations = meta
269+
.advice_queries()
270+
.iter()
271+
.skip(agg.num_advice_queries)
272+
.map(|(_, q)| q.0)
273+
.collect::<BTreeSet<i32>>();
274+
275+
let num_fixed_columns = meta.num_fixed_columns() - agg.num_fixed_columns;
276+
let num_lookups = meta.lookups().len() - agg.num_lookups;
277+
let num_shuffles = meta.shuffles().len() - agg.num_shuffles;
278+
let num_advice_columns = meta.num_advice_columns() - agg.num_advice_columns;
279+
let num_instance_columns = meta.num_instance_columns() - agg.num_instance_columns;
280+
let num_selectors = meta.num_selectors() - agg.num_selectors;
281+
let num_permutation_columns =
282+
meta.permutation().get_columns().len() - agg.num_permutation_columns;
283+
284+
CircuitStats {
285+
num_constraints: meta
286+
.gates()
287+
.iter()
288+
.skip(agg.num_gates)
289+
.map(|g| g.polynomials().len())
290+
.sum::<usize>(),
291+
num_fixed_columns,
292+
num_lookups,
293+
num_shuffles,
294+
num_advice_columns,
295+
num_instance_columns,
296+
num_selectors,
297+
// num_simple_selectors: meta.num_simple_selectors(),
298+
num_permutation_columns,
299+
degree: meta.degree(),
300+
blinding_factors: meta.blinding_factors(),
301+
num_challenges: meta.num_challenges() - agg.num_challenges,
302+
max_phase,
303+
num_rotation: rotations.len(),
304+
min_rotation: rotations.first().cloned().unwrap_or_default(),
305+
max_rotation: rotations.last().cloned().unwrap_or_default(),
306+
num_verification_ecmul: num_advice_columns
307+
+ num_instance_columns
308+
+ num_permutation_columns
309+
+ num_shuffles
310+
+ num_selectors
311+
+ num_fixed_columns
312+
+ 3 * num_lookups
313+
+ rotations.len(),
314+
num_advice_queries: meta.advice_queries().len() - agg.num_advice_queries,
315+
num_gates: meta.gates().len() - agg.num_gates,
316+
}
317+
}
318+
319+
pub(crate) struct StatsCollection<F: Field> {
320+
aggregate: bool,
321+
shared_cs: ConstraintSystem<F>,
322+
pub(crate) agg: CircuitStats,
323+
pub(crate) list: Vec<(String, CircuitStats)>,
324+
}
325+
326+
impl<F: Field> StatsCollection<F> {
327+
// With aggregate=true, all records are overwritten each time, leading to a single
328+
// aggregate stats that represents the final circuit.
329+
// With aggregate=false, each record is stored in a different entry with a name, and the
330+
// ConstraintSystem is reset so that each entry is independent.
331+
pub(crate) fn new(aggregate: bool) -> Self {
332+
Self {
333+
aggregate,
334+
shared_cs: ConstraintSystem::default(),
335+
agg: CircuitStats::default(),
336+
list: Vec::new(),
337+
}
338+
}
339+
340+
// Record a shared table
341+
pub(crate) fn record_shared(&mut self, name: &str, meta: &mut ConstraintSystem<F>) {
342+
// Shared tables should only add columns, and nothing more
343+
assert_eq!(meta.lookups().len(), 0);
344+
assert_eq!(meta.shuffles().len(), 0);
345+
assert_eq!(meta.permutation().get_columns().len(), 0);
346+
assert_eq!(meta.degree(), 3); // 3 comes from the permutation argument
347+
assert_eq!(meta.blinding_factors(), 5); // 5 is the minimum blinding factor
348+
assert_eq!(meta.advice_queries().len(), 0);
349+
assert_eq!(meta.gates().len(), 0);
350+
351+
if self.aggregate {
352+
self.agg = circuit_stats(&CircuitStats::default(), meta);
353+
} else {
354+
let stats = circuit_stats(&self.agg, meta);
355+
self.agg = circuit_stats(&CircuitStats::default(), meta);
356+
self.list.push((name.to_string(), stats));
357+
// Keep the ConstraintSystem with all the tables
358+
self.shared_cs = meta.clone();
359+
}
360+
}
361+
362+
// Record a subcircuit
363+
pub(crate) fn record(&mut self, name: &str, meta: &mut ConstraintSystem<F>) {
364+
if self.aggregate {
365+
self.agg = circuit_stats(&CircuitStats::default(), meta);
366+
} else {
367+
let stats = circuit_stats(&self.agg, meta);
368+
self.list.push((name.to_string(), stats));
369+
// Revert meta to the ConstraintSystem just with the tables
370+
*meta = self.shared_cs.clone();
371+
}
372+
}
373+
}

0 commit comments

Comments
 (0)