Skip to content

Commit a9db3f5

Browse files
committed
wip
1 parent 95d9ae9 commit a9db3f5

13 files changed

+210
-54
lines changed

Diff for: aggregator/Cargo.toml

+3-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ ark-std = "0.4.0"
1414
env_logger = "0.10.0"
1515
ethers-core = "0.17.0"
1616
log = "0.4"
17+
itertools = "0.10.3"
1718
serde = { version = "1.0", features = ["derive"] }
1819
serde_json = "1.0"
1920
rand = "0.8"
@@ -31,5 +32,5 @@ halo2-base = { git = "https://github.com/scroll-tech/halo2-lib", branch = "halo2
3132

3233

3334
[features]
34-
default = []
35-
# default = [ "ark-std/print-trace" ]
35+
# default = []
36+
default = [ "ark-std/print-trace" ]

Diff for: aggregator/configs/compression_wide.config

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
{"strategy":"Simple","degree":21,"num_advice":[25],"num_lookup_advice":[1],"num_fixed":1,"lookup_bits":20,"limb_bits":88,"num_limbs":3}
1+
{"strategy":"Simple","degree":21,"num_advice":[10],"num_lookup_advice":[1],"num_fixed":1,"lookup_bits":20,"limb_bits":88,"num_limbs":3}

Diff for: aggregator/src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ mod chunk;
44
// This module implements `Batch` related data types.
55
// A batch is a list of chunk.
66
mod batch;
7+
/// Parameters for compression circuit
8+
pub(crate) mod param;
79
/// proof aggregation
810
mod proof_aggregation;
911
/// proof compression

Diff for: aggregator/src/proof_compression/param.rs renamed to aggregator/src/param.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ pub(crate) const LIMBS: usize = 3;
44
pub(crate) const BITS: usize = 88;
55

66
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
7-
/// Parameters for aggregation circuit configs.
8-
pub struct CompressionConfigParams {
7+
/// Parameters for aggregation circuit and compression circuit configs.
8+
pub struct ConfigParams {
99
pub strategy: FpStrategy,
1010
pub degree: u32,
1111
pub num_advice: Vec<usize>,

Diff for: aggregator/src/proof_aggregation.rs

+1
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ mod circuit;
33
/// public input aggregation
44
mod public_input_aggregation;
55

6+
pub use circuit::AggregationCircuit;
67
pub use public_input_aggregation::*;

Diff for: aggregator/src/proof_aggregation/circuit.rs

+83-6
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,27 @@
11
use halo2_proofs::{
22
circuit::Value,
3-
halo2curves::bn256::{Bn256, Fr, G1Affine},
4-
poly::kzg::commitment::ParamsKZG,
3+
halo2curves::bn256::{Bn256, Fq, Fr, G1Affine},
4+
poly::{commitment::ParamsProver, kzg::commitment::ParamsKZG},
55
};
6+
use itertools::Itertools;
67
use rand::Rng;
7-
use snark_verifier::pcs::kzg::KzgSuccinctVerifyingKey;
8-
use snark_verifier_sdk::{Snark, SnarkWitness};
8+
use snark_verifier::{
9+
pcs::{
10+
kzg::{Bdfg21, Kzg, KzgAccumulator, KzgAs, KzgSuccinctVerifyingKey},
11+
AccumulationSchemeProver,
12+
},
13+
util::arithmetic::fe_to_limbs,
14+
verifier::PlonkVerifier,
15+
};
16+
use snark_verifier_sdk::{
17+
halo2::{aggregation::Shplonk, PoseidonTranscript, POSEIDON_SPEC},
18+
NativeLoader, Snark, SnarkWitness,
19+
};
920

10-
use crate::{BatchHashCircuit, ChunkHash};
21+
use crate::{
22+
param::{BITS, LIMBS},
23+
BatchHashCircuit, ChunkHash,
24+
};
1125

1226
/// Aggregation circuit that does not re-expose any public inputs from aggregated snarks
1327
#[derive(Clone)]
@@ -56,6 +70,69 @@ impl AggregationCircuit {
5670
});
5771
}
5872

59-
todo!()
73+
let svk = params.get_g()[0].into();
74+
75+
// TODO: this is all redundant calculation to get the public output
76+
// Halo2 should just be able to expose public output to instance column directly
77+
let mut transcript_read =
78+
PoseidonTranscript::<NativeLoader, &[u8]>::from_spec(&[], POSEIDON_SPEC.clone());
79+
let accumulators = snarks
80+
.iter()
81+
.flat_map(|snark| {
82+
transcript_read.new_stream(snark.proof.as_slice());
83+
let proof = Shplonk::read_proof(
84+
&svk,
85+
&snark.protocol,
86+
&snark.instances,
87+
&mut transcript_read,
88+
);
89+
Shplonk::succinct_verify(&svk, &snark.protocol, &snark.instances, &proof)
90+
})
91+
.collect::<Vec<_>>();
92+
93+
let (accumulator, as_proof) = {
94+
let mut transcript_write = PoseidonTranscript::<NativeLoader, Vec<u8>>::from_spec(
95+
vec![],
96+
POSEIDON_SPEC.clone(),
97+
);
98+
// We always use SHPLONK for accumulation scheme when aggregating proofs
99+
let accumulator = KzgAs::<Kzg<Bn256, Bdfg21>>::create_proof::<
100+
PoseidonTranscript<NativeLoader, Vec<u8>>,
101+
_,
102+
>(
103+
&Default::default(),
104+
&accumulators,
105+
&mut transcript_write,
106+
rng,
107+
)
108+
.unwrap();
109+
(accumulator, transcript_write.finalize())
110+
};
111+
112+
let KzgAccumulator::<G1Affine, NativeLoader> { lhs, rhs } = accumulator;
113+
let acc_instances = [lhs.x, lhs.y, rhs.x, rhs.y]
114+
.map(fe_to_limbs::<Fq, Fr, LIMBS, BITS>)
115+
.concat();
116+
let snark_instance = snarks.iter().flat_map(|snark| {
117+
snark
118+
.instances
119+
.iter()
120+
.flat_map(|instance| instance.iter().skip(12))
121+
});
122+
let flattened_instances = acc_instances
123+
.iter()
124+
.chain(snark_instance)
125+
.cloned()
126+
.collect();
127+
128+
let batch_hash_circuit = BatchHashCircuit::construct(chunk_hashes);
129+
130+
Self {
131+
svk,
132+
snarks: snarks.into_iter().cloned().map_into().collect(),
133+
flattened_instances,
134+
as_proof: Value::known(as_proof),
135+
batch_hash_circuit,
136+
}
60137
}
61138
}

Diff for: aggregator/src/proof_compression.rs

-3
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,5 @@ mod circuit;
99
mod circuit_ext;
1010
/// Config for compression circuit
1111
mod config;
12-
/// Parameters for compression circuit
13-
mod param;
1412

1513
pub use circuit::CompressionCircuit;
16-
pub use param::CompressionConfigParams;

Diff for: aggregator/src/proof_compression/circuit.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ use snark_verifier_sdk::{
3232
NativeLoader, Snark, SnarkWitness,
3333
};
3434

35-
use crate::proof_compression::param::{BITS, LIMBS};
35+
use crate::param::{ConfigParams, BITS, LIMBS};
3636

37-
use super::{config::CompressionConfig, param::CompressionConfigParams};
37+
use super::config::CompressionConfig;
3838

3939
/// Input a proof, this compression circuit generates a new proof that may have smaller size.
4040
///
@@ -79,7 +79,7 @@ impl Circuit<Fr> for CompressionCircuit {
7979
fn configure(meta: &mut ConstraintSystem<Fr>) -> Self::Config {
8080
let path = std::env::var("VERIFY_CONFIG")
8181
.unwrap_or_else(|_| "configs/verify_circuit.config".to_owned());
82-
let params: CompressionConfigParams = serde_json::from_reader(
82+
let params: ConfigParams = serde_json::from_reader(
8383
File::open(path.as_str()).unwrap_or_else(|_| panic!("{path:?} does not exist")),
8484
)
8585
.unwrap();

Diff for: aggregator/src/proof_compression/circuit_ext.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
use halo2_proofs::{halo2curves::bn256::Fr, plonk::Selector};
44
use snark_verifier_sdk::CircuitExt;
55

6-
use crate::proof_compression::param::LIMBS;
6+
use crate::param::LIMBS;
77

88
use super::circuit::CompressionCircuit;
99

Diff for: aggregator/src/proof_compression/config.rs

+2-4
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,7 @@ use halo2_proofs::{
99
};
1010
use snark_verifier::loader::halo2::halo2_ecc::fields::fp::FpConfig;
1111

12-
use crate::proof_compression::param::{BITS, LIMBS};
13-
14-
use super::param::CompressionConfigParams;
12+
use crate::param::{ConfigParams, BITS, LIMBS};
1513

1614
#[derive(Clone, Debug)]
1715
/// Configurations for compression circuit
@@ -25,7 +23,7 @@ pub struct CompressionConfig {
2523

2624
impl CompressionConfig {
2725
/// Build a configuration from parameters.
28-
pub fn configure(meta: &mut ConstraintSystem<Fr>, params: CompressionConfigParams) -> Self {
26+
pub fn configure(meta: &mut ConstraintSystem<Fr>, params: ConfigParams) -> Self {
2927
assert!(
3028
params.limb_bits == BITS && params.num_limbs == LIMBS,
3129
"For now we fix limb_bits = {}, otherwise change code",

Diff for: aggregator/src/tests/mock_chunk.rs

+25-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
1-
use ark_std::test_rng;
1+
use ark_std::{test_rng, start_timer, end_timer};
2+
use halo2_base::utils::fs::gen_srs;
23
use halo2_proofs::{dev::MockProver, halo2curves::bn256::Fr};
34
use snark_verifier_sdk::CircuitExt;
5+
use snark_verifier_sdk::{
6+
gen_pk,
7+
halo2::{gen_snark_shplonk, verify_snark_shplonk},
8+
};
49

510
use crate::{ChunkHash, LOG_DEGREE};
611

@@ -11,7 +16,7 @@ mod config;
1116
#[derive(Debug, Default, Clone, Copy)]
1217
/// A mock chunk circuit
1318
pub struct MockChunkCircuit {
14-
chunk: ChunkHash,
19+
pub(crate) chunk: ChunkHash,
1520
}
1621

1722
#[test]
@@ -20,10 +25,27 @@ fn test_mock_chunk_prover() {
2025

2126
let mut rng = test_rng();
2227

28+
let param = gen_srs(LOG_DEGREE);
2329
let circuit = MockChunkCircuit::random(&mut rng);
2430
let instance = circuit.instances();
2531

2632
let mock_prover = MockProver::<Fr>::run(LOG_DEGREE, &circuit, instance).unwrap();
2733

28-
mock_prover.assert_satisfied_par()
34+
mock_prover.assert_satisfied_par();
35+
36+
let timer = start_timer!(|| format!("key generation for k = {}", LOG_DEGREE));
37+
let pk = gen_pk(&param, &circuit, None);
38+
end_timer!(timer);
39+
40+
let timer = start_timer!(|| "proving");
41+
let snark = gen_snark_shplonk(&param, &pk, circuit, &mut rng, None::<String>);
42+
end_timer!(timer);
43+
44+
let timer = start_timer!(|| "verifying");
45+
assert!(verify_snark_shplonk::<MockChunkCircuit>(
46+
&param,
47+
snark,
48+
pk.get_vk()
49+
));
50+
end_timer!(timer);
2951
}

Diff for: aggregator/src/tests/proof_aggregation.rs

+85-27
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,89 @@
1+
use std::{fs, path::Path, process};
2+
3+
use ark_std::test_rng;
14
use eth_types::H256;
2-
use halo2_proofs::halo2curves::bn256::Fr;
5+
use halo2_base::utils::fs::gen_srs;
6+
use halo2_proofs::{halo2curves::bn256::Fr, poly::commitment::Params};
7+
use itertools::Itertools;
8+
use snark_verifier_sdk::{
9+
gen_pk,
10+
halo2::{gen_snark_shplonk, verify_snark_shplonk},
11+
};
312

4-
use crate::{BatchHash, ChunkHash};
13+
use crate::{AggregationCircuit, BatchHash, ChunkHash};
514

6-
#[derive(Clone, Debug, Default)]
7-
// A test circuit that constraints
8-
// pi = keccak( chain id || prev state root || post state root
9-
// || withdraw root || data hash )
10-
pub(crate) struct TestCircuit {
11-
chunk_hashes: Vec<ChunkHash>,
12-
batch_hash: BatchHash,
13-
}
15+
use super::mock_chunk::MockChunkCircuit;
16+
17+
const CHUNKS_PER_BATCH: usize = 4;
18+
19+
#[test]
20+
fn test_aggregation_circuit() {
21+
env_logger::init();
22+
23+
let dir = format!("data/{}", process::id());
24+
let path = Path::new(dir.as_str());
25+
fs::create_dir(path).unwrap();
26+
27+
let k0 = 19;
28+
let k1 = 22;
29+
let layer_1_params = gen_srs(k1);
30+
31+
let mut rng = test_rng();
32+
let chunks = (0..CHUNKS_PER_BATCH)
33+
.map(|_| ChunkHash::mock_chunk_hash(&mut rng))
34+
.collect_vec();
1435

15-
// impl TestCircuit {
16-
// fn random<R: rand::RngCore>(r: &mut R, num_chunks: usize) -> Self {
17-
// let chunk_hashes = (0..num_chunks)
18-
// .map(|_| ChunkHash::mock_chunk_hash(r))
19-
// .collect::<Vec<_>>();
20-
// let batch_hash = BatchHash::construct(chunk_hashes.as_ref());
21-
22-
// Self {
23-
// chunk_hashes,
24-
// batch_hash,
25-
// }
26-
// }
27-
28-
// fn instances(&self) -> Vec<Vec<Fr>> {
29-
// self.batch_hash.data_hash
30-
// }
31-
// }
36+
// build layer 0 snarks
37+
let layer_0_snarks = {
38+
let layer_0_params = {
39+
let mut params = layer_1_params.clone();
40+
params.downsize(k0);
41+
params
42+
};
43+
44+
let circuits = chunks
45+
.iter()
46+
.map(|&chunk| MockChunkCircuit { chunk })
47+
.collect_vec();
48+
log::trace!("finished layer 0 pk generation for circuit");
49+
let layer_0_pk = gen_pk(
50+
&layer_0_params,
51+
&circuits[0],
52+
Some(&path.join(Path::new("layer_0.pkey"))),
53+
);
54+
log::trace!("finished layer 0 pk generation for circuit");
55+
56+
let layer_0_snarks = circuits
57+
.iter()
58+
.enumerate()
59+
.map(|(i, circuit)| {
60+
gen_snark_shplonk(
61+
&layer_0_params,
62+
&layer_0_pk,
63+
circuit.clone(),
64+
&mut rng,
65+
Some(&path.join(Path::new(format!("layer_0_{}.snark", i).as_str()))),
66+
)
67+
})
68+
.collect_vec();
69+
log::trace!("finished layer 0 snark generation for circuit");
70+
71+
// sanity checks
72+
layer_0_snarks.iter().for_each(|snark| {
73+
assert!(verify_snark_shplonk::<MockChunkCircuit>(
74+
&layer_0_params,
75+
snark.clone(),
76+
layer_0_pk.get_vk()
77+
))
78+
});
79+
log::trace!("finished layer 0 snark verification");
80+
81+
layer_0_snarks
82+
};
83+
84+
// build layer 1 the aggregation circuit
85+
{
86+
let aggregation_circuit =
87+
AggregationCircuit::new(&layer_1_params, &layer_0_snarks, rng, &chunks);
88+
}
89+
}

Diff for: aggregator/src/tests/proof_compression.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ fn test_proof_compression() {
137137
fs::create_dir(path).unwrap();
138138

139139
let k0 = 19;
140-
let k1 = 21;
140+
let k1 = 22;
141141

142142
let mut rng = test_rng();
143143
let layer_1_params = gen_srs(k1);
@@ -167,7 +167,7 @@ fn test_proof_compression() {
167167
);
168168
log::trace!("finished layer 0 snark generation for circuit");
169169

170-
assert!(verify_snark_shplonk::<TestCircuit>(
170+
assert!(verify_snark_shplonk::<MockChunkCircuit>(
171171
&layer_0_params,
172172
layer_0_snark.clone(),
173173
layer_0_pk.get_vk()

0 commit comments

Comments
 (0)