diff --git a/Cargo.lock b/Cargo.lock index 0e7394dd37..699d69ef03 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -40,7 +40,7 @@ dependencies = [ [[package]] name = "aggregator" -version = "0.11.0" +version = "0.12.0" dependencies = [ "ark-std 0.3.0", "bitstream-io", @@ -594,7 +594,7 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "bus-mapping" -version = "0.11.0" +version = "0.12.0" dependencies = [ "ctor", "env_logger", @@ -760,7 +760,7 @@ dependencies = [ [[package]] name = "circuit-benchmarks" -version = "0.11.0" +version = "0.12.0" dependencies = [ "ark-std 0.3.0", "bus-mapping", @@ -1409,7 +1409,7 @@ dependencies = [ [[package]] name = "eth-types" -version = "0.11.0" +version = "0.12.0" dependencies = [ "base64 0.13.1", "ethers-core", @@ -1720,7 +1720,7 @@ dependencies = [ [[package]] name = "external-tracer" -version = "0.11.0" +version = "0.12.0" dependencies = [ "eth-types", "geth-utils", @@ -1947,7 +1947,7 @@ dependencies = [ [[package]] name = "gadgets" -version = "0.11.0" +version = "0.12.0" dependencies = [ "eth-types", "halo2_proofs", @@ -1971,7 +1971,7 @@ dependencies = [ [[package]] name = "geth-utils" -version = "0.11.0" +version = "0.12.0" dependencies = [ "env_logger", "gobuild", @@ -2490,7 +2490,7 @@ dependencies = [ [[package]] name = "integration-tests" -version = "0.11.0" +version = "0.12.0" dependencies = [ "bus-mapping", "env_logger", @@ -2798,7 +2798,7 @@ dependencies = [ [[package]] name = "mock" -version = "0.11.0" +version = "0.12.0" dependencies = [ "eth-types", "ethers-core", @@ -2812,7 +2812,7 @@ dependencies = [ [[package]] name = "mpt-zktrie" -version = "0.11.0" +version = "0.12.0" dependencies = [ "env_logger", "eth-types", @@ -3461,7 +3461,7 @@ dependencies = [ [[package]] name = "prover" -version = "0.11.0" +version = "0.12.0" dependencies = [ "aggregator", "anyhow", @@ -4607,7 +4607,7 @@ dependencies = [ [[package]] name = "testool" -version = "0.11.0" +version = "0.12.0" dependencies = [ "anyhow", "bus-mapping", @@ -5484,7 +5484,7 @@ dependencies = [ [[package]] name = "zkevm-circuits" -version = "0.11.0" +version = "0.12.0" dependencies = [ "array-init", "bus-mapping", diff --git a/Cargo.toml b/Cargo.toml index 76e6357800..19554d12ec 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ members = [ resolver = "2" [workspace.package] -version = "0.11.0" +version = "0.12.0" edition = "2021" license = "MIT OR Apache-2.0" diff --git a/aggregator/Cargo.toml b/aggregator/Cargo.toml index 25f7a3ce58..8f0535458b 100644 --- a/aggregator/Cargo.toml +++ b/aggregator/Cargo.toml @@ -46,5 +46,5 @@ csv = "1.1" [features] default = ["revm-precompile/c-kzg"] print-trace = ["ark-std/print-trace"] -# This feature is useful for unit tests where we check the SAT of pi aggregation circuit +# This feature is useful for unit tests where we check the SAT of pi batch circuit disable_proof_aggregation = [] diff --git a/aggregator/configs/bundle_circuit.config b/aggregator/configs/bundle_circuit.config new file mode 100644 index 0000000000..0e4b61516b --- /dev/null +++ b/aggregator/configs/bundle_circuit.config @@ -0,0 +1 @@ +{"strategy":"Simple","degree":21,"num_advice":[5],"num_lookup_advice":[1],"num_fixed":1,"lookup_bits":20,"limb_bits":88,"num_limbs":3} \ No newline at end of file diff --git a/aggregator/src/aggregation.rs b/aggregator/src/aggregation.rs index a27aba3e0a..25f7ddafec 100644 --- a/aggregator/src/aggregation.rs +++ b/aggregator/src/aggregation.rs @@ -23,5 +23,5 @@ pub(crate) use blob_data::BlobDataConfig; 
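// [Sketch] The new aggregator/configs/bundle_circuit.config above sets the halo2-ecc
// parameters for the bundle circuit: FpStrategy "Simple", k = 21 (matching LOG_DEGREE),
// 5 advice columns, 1 lookup-advice column, 1 fixed column, 20 lookup bits, and 3 x 88-bit
// limbs for non-native field arithmetic. Assuming the JSON keys mirror the crate's
// `ConfigParams` (aggregator/src/param.rs), it could be loaded roughly as follows; the
// struct shown here is illustrative, not the crate's actual definition:
//
//     use serde::Deserialize;
//
//     #[derive(Debug, Deserialize)]
//     struct CircuitParams {
//         strategy: String,
//         degree: u32,
//         num_advice: Vec<usize>,
//         num_lookup_advice: Vec<usize>,
//         num_fixed: usize,
//         lookup_bits: usize,
//         limb_bits: usize,
//         num_limbs: usize,
//     }
//
//     fn load_params(path: &str) -> CircuitParams {
//         serde_json::from_str(&std::fs::read_to_string(path).expect("read config"))
//             .expect("valid JSON config")
//     }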
pub(crate) use decoder::{witgen, DecoderConfig, DecoderConfigArgs};
 pub(crate) use rlc::RlcConfig;
 
-pub use circuit::AggregationCircuit;
-pub use config::AggregationConfig;
+pub use circuit::BatchCircuit;
+pub use config::BatchCircuitConfig;
diff --git a/aggregator/src/aggregation/circuit.rs b/aggregator/src/aggregation/circuit.rs
index 63de57ca0b..c8a4fbbaba 100644
--- a/aggregator/src/aggregation/circuit.rs
+++ b/aggregator/src/aggregation/circuit.rs
@@ -1,6 +1,14 @@
-use crate::{blob::BatchData, witgen::MultiBlockProcessResult, LOG_DEGREE};
+use crate::{
+    blob::BatchData, witgen::MultiBlockProcessResult, LOG_DEGREE, PI_CHAIN_ID,
+    PI_CURRENT_BATCH_HASH, PI_CURRENT_STATE_ROOT, PI_CURRENT_WITHDRAW_ROOT, PI_PARENT_BATCH_HASH,
+    PI_PARENT_STATE_ROOT,
+};
 use ark_std::{end_timer, start_timer};
 use halo2_base::{Context, ContextParams};
+
+#[cfg(not(feature = "disable_proof_aggregation"))]
+use halo2_ecc::{ecc::EccChip, fields::fp::FpConfig};
+
 use halo2_proofs::{
     circuit::{Layouter, SimpleFloorPlanner, Value},
     halo2curves::bn256::{Bn256, Fr, G1Affine},
@@ -14,11 +22,11 @@ use std::rc::Rc;
 use std::{env, fs::File};
 
 #[cfg(not(feature = "disable_proof_aggregation"))]
-use snark_verifier::loader::halo2::halo2_ecc::halo2_base;
+use snark_verifier::loader::halo2::{halo2_ecc::halo2_base::AssignedValue, Halo2Loader};
 use snark_verifier::pcs::kzg::KzgSuccinctVerifyingKey;
 #[cfg(not(feature = "disable_proof_aggregation"))]
 use snark_verifier::{
-    loader::halo2::{halo2_ecc::halo2_base::AssignedValue, Halo2Loader},
+    loader::halo2::halo2_ecc::halo2_base,
     pcs::kzg::{Bdfg21, Kzg},
 };
 #[cfg(not(feature = "disable_proof_aggregation"))]
@@ -35,18 +43,23 @@ use crate::{
     AssignedBarycentricEvaluationConfig, ConfigParams,
 };
 
-use super::AggregationConfig;
+use super::BatchCircuitConfig;
 
-/// Aggregation circuit that does not re-expose any public inputs from aggregated snarks
+/// Batch circuit: the chunk aggregation routine that sits below the recursion circuit
 #[derive(Clone)]
-pub struct AggregationCircuit {
+pub struct BatchCircuit {
     pub svk: KzgSuccinctVerifyingKey,
     // the input snarks for the aggregation circuit
     // it is padded already so it will have a fixed length of N_SNARKS
     pub snarks_with_padding: Vec,
     // the public instance for this circuit consists of
     // - an accumulator (12 elements)
-    // - the batch's public_input_hash (32 elements)
+    // - parent_state_root (2 elements, split hi_lo)
+    // - parent_batch_hash (2 elements)
+    // - current_state_root (2 elements)
+    // - current_batch_hash (2 elements)
+
// - chain id (1 element) + // - current_withdraw_root (2 elements) + let flattened_instances: Vec = [ + acc_instances.as_slice(), + batch_hash.instances_exclude_acc::()[0] + .clone() + .as_slice(), + ] + .concat(); end_timer!(timer); Ok(Self { @@ -119,8 +139,8 @@ impl AggregationCircuit { } } -impl Circuit for AggregationCircuit { - type Config = (AggregationConfig, Challenges); +impl Circuit for BatchCircuit { + type Config = (BatchCircuitConfig, Challenges); type FloorPlanner = SimpleFloorPlanner; fn without_witnesses(&self) -> Self { unimplemented!() @@ -138,7 +158,7 @@ impl Circuit for AggregationCircuit { ); let challenges = Challenges::construct_p1(meta); - let config = AggregationConfig::configure(meta, ¶ms, challenges); + let config = BatchCircuitConfig::configure(meta, ¶ms, challenges); log::info!( "aggregation circuit configured with k = {} and {:?} advice columns", params.degree, @@ -206,6 +226,7 @@ impl Circuit for AggregationCircuit { #[cfg(not(feature = "disable_proof_aggregation"))] let (accumulator_instances, snark_inputs, barycentric) = { + use halo2_proofs::halo2curves::bn256::Fq; let mut first_pass = halo2_base::SKIP_FIRST_PASS; let (accumulator_instances, snark_inputs, barycentric) = layouter.assign_region( @@ -224,6 +245,7 @@ impl Circuit for AggregationCircuit { let mut accumulator_instances: Vec> = vec![]; // stores public inputs for all snarks, including the padded ones let mut snark_inputs: Vec> = vec![]; + let ctx = Context::new( region, ContextParams { @@ -234,15 +256,16 @@ impl Circuit for AggregationCircuit { ); let ecc_chip = config.ecc_chip(); - let loader = Halo2Loader::new(ecc_chip, ctx); + let loader: Rc>>> = + Halo2Loader::new(ecc_chip, ctx); // // extract the assigned values for // - instances which are the public inputs of each chunk (prefixed with 12 // instances from previous accumulators) - // - new accumulator to be verified on chain + // - new accumulator // - log::debug!("aggregation: assigning aggregation"); + log::debug!("aggregation: chunk aggregation"); let (assigned_aggregation_instances, acc) = aggregate::>( &self.svk, &loader, @@ -266,10 +289,12 @@ impl Circuit for AggregationCircuit { .flat_map(|instance_column| instance_column.iter().skip(ACC_LEN)), ); - loader.ctx_mut().print_stats(&["snark aggregation"]); + loader + .ctx_mut() + .print_stats(&["snark aggregation [chunks -> batch]"]); let mut ctx = Rc::into_inner(loader).unwrap().into_ctx(); - log::debug!("aggregation: assigning barycentric"); + log::debug!("batching: assigning barycentric"); let barycentric = config.barycentric.assign( &mut ctx, &self.batch_hash.point_evaluation_assignments.coefficients, @@ -287,12 +312,12 @@ impl Circuit for AggregationCircuit { }, )?; - assert_eq!(snark_inputs.len(), N_SNARKS * DIGEST_LEN); (accumulator_instances, snark_inputs, barycentric) }; end_timer!(timer); + // ============================================== - // step 2: public input aggregation circuit + // step 2: public input batch circuit // ============================================== // extract all the hashes and load them to the hash table let challenges = challenge.values(&layouter); @@ -307,7 +332,7 @@ impl Circuit for AggregationCircuit { let timer = start_timer!(|| "extract hash"); // orders: - // - batch_public_input_hash + // - batch_hash // - chunk\[i\].piHash for i in \[0, N_SNARKS) // - batch_data_hash_preimage // - preimage for blob metadata @@ -346,26 +371,25 @@ impl Circuit for AggregationCircuit { assigned_batch_hash }; - // digests - let (batch_pi_hash_digest, 
chunk_pi_hash_digests, _potential_batch_data_hash_digest) = + + // Extract digests + #[cfg(feature = "disable_proof_aggregation")] + let (_batch_hash_digest, _chunk_pi_hash_digests, _potential_batch_data_hash_digest) = parse_hash_digest_cells::(&assigned_batch_hash.hash_output); - // ============================================== - // step 3: assert public inputs to the snarks are correct - // ============================================== - for (i, chunk) in chunk_pi_hash_digests.iter().enumerate() { - let hash = self.batch_hash.chunks_with_padding[i].public_input_hash(); - for j in 0..DIGEST_LEN { - log::trace!("pi {:02x} {:?}", hash[j], chunk[j].value()); - } - } + #[cfg(not(feature = "disable_proof_aggregation"))] + let (_batch_hash_digest, chunk_pi_hash_digests, _potential_batch_data_hash_digest) = + parse_hash_digest_cells::(&assigned_batch_hash.hash_output); + // ======================================================================== + // step 2.a: check accumulator including public inputs to the snarks + // ======================================================================== #[cfg(not(feature = "disable_proof_aggregation"))] let mut first_pass = halo2_base::SKIP_FIRST_PASS; #[cfg(not(feature = "disable_proof_aggregation"))] layouter.assign_region( - || "pi checks", + || "BatchCircuit: Chunk PI", |mut region| -> Result<(), Error> { if first_pass { // this region only use copy constraints and do not affect the shape of the @@ -398,10 +422,6 @@ impl Circuit for AggregationCircuit { }, )?; - // ============================================== - // step 4: assert public inputs to the aggregator circuit are correct - // ============================================== - // accumulator #[cfg(not(feature = "disable_proof_aggregation"))] { assert!(accumulator_instances.len() == ACC_LEN); @@ -410,19 +430,28 @@ impl Circuit for AggregationCircuit { } } - // public input hash - for (index, batch_pi_hash_digest_cell) in batch_pi_hash_digest.iter().enumerate() { - log::trace!( - "pi (circuit vs real): {:?} {:?}", - batch_pi_hash_digest_cell.value(), - self.instances()[0][index + ACC_LEN] - ); - - layouter.constrain_instance( - batch_pi_hash_digest_cell.cell(), - config.instance, - index + ACC_LEN, - )?; + // ======================================================================== + // step 2.b: constrain extracted public input cells against actual instance + // ======================================================================== + let hash_derived_public_input_cells = assigned_batch_hash.hash_derived_public_input_cells; + let instance_offsets: Vec = vec![ + PI_PARENT_BATCH_HASH, + PI_PARENT_BATCH_HASH + 1, + PI_CURRENT_BATCH_HASH, + PI_CURRENT_BATCH_HASH + 1, + PI_PARENT_STATE_ROOT, + PI_PARENT_STATE_ROOT + 1, + PI_CURRENT_STATE_ROOT, + PI_CURRENT_STATE_ROOT + 1, + PI_CURRENT_WITHDRAW_ROOT, + PI_CURRENT_WITHDRAW_ROOT + 1, + PI_CHAIN_ID, + ]; + for (c, inst_offset) in hash_derived_public_input_cells + .into_iter() + .zip(instance_offsets.into_iter()) + { + layouter.constrain_instance(c.cell(), config.instance, inst_offset)?; } // blob data config @@ -489,7 +518,7 @@ impl Circuit for AggregationCircuit { address_table_arr, sequence_exec_info_arr, &challenges, - LOG_DEGREE, // TODO: configure k for aggregation circuit instead of hard-coded here. + LOG_DEGREE, // TODO: configure k for batch circuit instead of hard-coded here. 
)?; layouter.assign_region( @@ -571,15 +600,18 @@ impl Circuit for AggregationCircuit { } } -impl CircuitExt for AggregationCircuit { +impl CircuitExt for BatchCircuit { fn num_instance(&self) -> Vec { - // 12 elements from accumulator - // 32 elements from batch's public_input_hash - vec![ACC_LEN + DIGEST_LEN] + // - 12 elements from accumulator + // - parent_state_root (2 elements, split hi_lo) + // - parent_batch_hash (2 elements) + // - current_state_root (2 elements) + // - current_batch_hash (2 elements) + // - chain id (1 element) + // - current_withdraw_root (2 elements) + vec![ACC_LEN + 11] } - // 12 elements from accumulator - // 32 elements from batch's public_input_hash fn instances(&self) -> Vec> { vec![self.flattened_instances.clone()] } diff --git a/aggregator/src/aggregation/config.rs b/aggregator/src/aggregation/config.rs index 1db10fd8b1..0d744ec060 100644 --- a/aggregator/src/aggregation/config.rs +++ b/aggregator/src/aggregation/config.rs @@ -25,9 +25,9 @@ use crate::{ #[derive(Debug, Clone)] #[rustfmt::skip] -/// Configurations for aggregation circuit. +/// Configurations for batch circuit. /// This config is hard coded for BN256 curve. -pub struct AggregationConfig { +pub struct BatchCircuitConfig { /// Non-native field chip configurations pub base_field_config: FpConfig, /// Keccak circuit configurations @@ -44,12 +44,16 @@ pub struct AggregationConfig { pub barycentric: BarycentricEvaluationConfig, /// Instance for public input; stores /// - accumulator from aggregation (12 elements) - /// - batch_public_input_hash (32 elements) - /// - the number of valid SNARKs (1 element) + /// - chain id (1 element) + /// - parent_state_root (2 elements, split hi_lo) + /// - parent_batch_hash (2 elements) + /// - current_state_root (2 elements) + /// - current_batch_hash (2 elements) + /// - current_withdraw_root (2 elements) pub instance: Column, } -impl AggregationConfig { +impl BatchCircuitConfig { /// Build a configuration from parameters. pub fn configure( meta: &mut ConstraintSystem, @@ -61,7 +65,7 @@ impl AggregationConfig { "For now we fix limb_bits = {BITS}, otherwise change code", ); - // hash configuration for aggregation circuit + // hash configuration for batch circuit let (keccak_table, keccak_circuit_config) = { let keccak_table = KeccakTable::construct(meta); @@ -80,7 +84,7 @@ impl AggregationConfig { // RLC configuration let rlc_config = RlcConfig::configure(meta, &keccak_table, challenges); - // base field configuration for aggregation circuit + // base field configuration for batch circuit let base_field_config = FpConfig::configure( meta, params.strategy.clone(), @@ -127,12 +131,14 @@ impl AggregationConfig { // Zstd decoder. 
let pow_rand_table = PowOfRandTable::construct(meta, &challenges_expr); + let pow2_table = Pow2Table::construct(meta); let range8 = RangeTable::construct(meta); let range16 = RangeTable::construct(meta); let range512 = RangeTable::construct(meta); let range_block_len = RangeTable::construct(meta); let bitwise_op_table = BitwiseOpTable::construct(meta); + let decoder_config = DecoderConfig::configure( meta, &challenges_expr, @@ -149,9 +155,14 @@ impl AggregationConfig { ); // Instance column stores public input column + // the public instance for this circuit consists of // - the accumulator - // - the batch public input hash - // - the number of valid SNARKs + // - chain id (1 element) + // - parent_state_root (2 elements, split hi_lo) + // - parent_batch_hash (2 elements) + // - current_state_root (2 elements) + // - current_batch_hash (2 elements) + // - current_withdraw_root (2 elements) let instance = meta.instance_column(); meta.enable_equality(instance); @@ -192,7 +203,7 @@ impl AggregationConfig { } #[test] -fn aggregation_circuit_degree() { +fn batch_circuit_degree() { use halo2_ecc::fields::fp::FpStrategy; let mut cs = ConstraintSystem::::default(); let param = ConfigParams { @@ -206,7 +217,7 @@ fn aggregation_circuit_degree() { num_limbs: 3, }; let challenges = Challenges::construct_p1(&mut cs); - AggregationConfig::<{ crate::constants::MAX_AGG_SNARKS }>::configure( + BatchCircuitConfig::<{ crate::constants::MAX_AGG_SNARKS }>::configure( &mut cs, ¶m, challenges, ); cs = cs.chunk_lookups(); diff --git a/aggregator/src/batch.rs b/aggregator/src/batch.rs index 30df76ba25..bcfebe39d2 100644 --- a/aggregator/src/batch.rs +++ b/aggregator/src/batch.rs @@ -3,13 +3,129 @@ use eth_types::{ToBigEndian, H256}; use ethers_core::utils::keccak256; -use gadgets::Field; +use gadgets::{util::split_h256, Field}; +use serde::{Deserialize, Serialize}; use crate::{ blob::{BatchData, PointEvaluationAssignments}, chunk::ChunkInfo, }; +/// Batch header provides additional fields from the context (within recursion) +/// for constructing the preimage of the batch hash. 
+#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize)]
+pub struct BatchHeader {
+    /// the batch version
+    pub version: u8,
+    /// the index of the batch
+    pub batch_index: u64,
+    /// Number of L1 messages popped in the batch
+    pub l1_message_popped: u64,
+    /// Number of total L1 messages popped after the batch
+    pub total_l1_message_popped: u64,
+    /// The parent batch hash
+    pub parent_batch_hash: H256,
+    /// The timestamp of the last block in this batch
+    pub last_block_timestamp: u64,
+    /// The data hash of the batch
+    pub data_hash: H256,
+    /// The versioned hash of the blob with this batch's data
+    pub blob_versioned_hash: H256,
+    /// The blob data proof: z (32), y (32)
+    pub blob_data_proof: [H256; 2],
+}
+
+impl BatchHeader {
+    /// Constructs the correct batch header from chunk data and context variables
+    pub fn construct_from_chunks(
+        version: u8,
+        batch_index: u64,
+        l1_message_popped: u64,
+        total_l1_message_popped: u64,
+        parent_batch_hash: H256,
+        last_block_timestamp: u64,
+        chunks: &[ChunkInfo],
+    ) -> Self {
+        assert_ne!(chunks.len(), 0);
+        assert!(chunks.len() <= N_SNARKS);
+
+        let mut chunks_with_padding = chunks.to_vec();
+        if chunks.len() < N_SNARKS {
+            let last_chunk = chunks.last().unwrap();
+            let mut padding_chunk = last_chunk.clone();
+            padding_chunk.is_padding = true;
+            chunks_with_padding
+                .extend(std::iter::repeat(padding_chunk).take(N_SNARKS - chunks.len()));
+        }
+
+        let number_of_valid_chunks = match chunks_with_padding
+            .iter()
+            .enumerate()
+            .find(|(_index, chunk)| chunk.is_padding)
+        {
+            Some((index, _)) => index,
+            None => N_SNARKS,
+        };
+
+        let batch_data_hash_preimage = chunks_with_padding
+            .iter()
+            .take(number_of_valid_chunks)
+            .flat_map(|chunk_info| chunk_info.data_hash.0.iter())
+            .cloned()
+            .collect::>();
+        let batch_data_hash = keccak256(batch_data_hash_preimage);
+
+        let batch_data = BatchData::::new(number_of_valid_chunks, &chunks_with_padding);
+        let point_evaluation_assignments = PointEvaluationAssignments::from(&batch_data);
+
+        Self {
+            version,
+            batch_index,
+            l1_message_popped,
+            total_l1_message_popped,
+            parent_batch_hash,
+            last_block_timestamp,
+            data_hash: batch_data_hash.into(),
+            blob_versioned_hash: batch_data.get_versioned_hash(),
+            blob_data_proof: [
+                H256::from_slice(&point_evaluation_assignments.challenge.to_be_bytes()),
+                H256::from_slice(&point_evaluation_assignments.evaluation.to_be_bytes()),
+            ],
+        }
+    }
+
+    /// Returns the batch hash as per BatchHeaderV3.
+    pub fn batch_hash(&self) -> H256 {
+        // the current batch hash is built as
+        // keccak256(
+        //     version ||
+        //     batch_index ||
+        //     l1_message_popped ||
+        //     total_l1_message_popped ||
+        //     batch_data_hash ||
+        //     versioned_hash ||
+        //     parent_batch_hash ||
+        //     last_block_timestamp ||
+        //     z ||
+        //     y
+        // )
+        let batch_hash_preimage = [
+            vec![self.version].as_slice(),
+            self.batch_index.to_be_bytes().as_ref(),
+            self.l1_message_popped.to_be_bytes().as_ref(),
+            self.total_l1_message_popped.to_be_bytes().as_ref(),
+            self.data_hash.as_bytes(),
+            self.blob_versioned_hash.as_bytes(),
+            self.parent_batch_hash.as_bytes(),
+            self.last_block_timestamp.to_be_bytes().as_ref(),
+            self.blob_data_proof[0].to_fixed_bytes().as_ref(),
+            self.blob_data_proof[1].to_fixed_bytes().as_ref(),
+        ]
+        .concat();
+        keccak256(batch_hash_preimage).into()
+    }
+}
+
 #[derive(Default, Debug, Clone)]
 /// A batch is a set of N_SNARKS num of continuous chunks
 /// - the first k chunks are from real traces
@@ -17,6 +133,9 @@ use crate::{
 /// A BatchHash consists of 2 hashes.
/// - batch_pi_hash := keccak(chain_id || chunk_0.prev_state_root || chunk_k-1.post_state_root || /// chunk_k-1.withdraw_root || batch_data_hash || z || y || versioned_hash) +/// +/// - batchHash := keccak256(version || batch_index || l1_message_popped || total_l1_message_popped || +/// batch_data_hash || versioned_hash || parent_batch_hash || last_block_timestamp || z || y) /// - batch_data_hash := keccak(chunk_0.data_hash || ... || chunk_k-1.data_hash) pub struct BatchHash { /// Chain ID of the network. @@ -25,24 +144,36 @@ pub struct BatchHash { /// - the first [0..number_of_valid_chunks) are real ones /// - the last [number_of_valid_chunks, N_SNARKS) are padding pub(crate) chunks_with_padding: Vec, + /// the state root of the parent batch + pub(crate) parent_state_root: H256, + /// the state root of the current batch + pub(crate) current_state_root: H256, + /// the withdraw root of the current batch + pub(crate) current_withdraw_root: H256, /// The batch data hash: /// - keccak256([chunk.hash for chunk in batch]) pub(crate) data_hash: H256, - /// The public input hash, as calculated on-chain: - /// - keccak256( chain_id || prev_state_root || next_state_root || withdraw_trie_root || - /// batch_data_hash || z || y || versioned_hash) - pub(crate) public_input_hash: H256, + /// the current batch hash is calculated as: + /// - keccak256( version || batch_index || l1_message_popped || total_l1_message_popped || + /// batch_data_hash || versioned_hash || parent_batch_hash || last_block_timestamp || + /// z || y) + pub(crate) current_batch_hash: H256, /// The number of chunks that contain meaningful data, i.e. not padded chunks. pub(crate) number_of_valid_chunks: usize, /// 4844 point evaluation check related assignments. pub(crate) point_evaluation_assignments: PointEvaluationAssignments, /// The 4844 versioned hash for the blob. pub(crate) versioned_hash: H256, + /// The context batch header + pub(crate) batch_header: BatchHeader, } impl BatchHash { /// Build Batch hash from an ordered list of chunks. Will pad if needed - pub fn construct_with_unpadded(chunks: &[ChunkInfo]) -> Self { + pub fn construct_with_unpadded( + chunks: &[ChunkInfo], + batch_header: BatchHeader, + ) -> Self { assert_ne!(chunks.len(), 0); assert!(chunks.len() <= N_SNARKS); let mut chunks_with_padding = chunks.to_vec(); @@ -58,11 +189,14 @@ impl BatchHash { chunks_with_padding .extend(std::iter::repeat(padding_chunk).take(N_SNARKS - chunks.len())); } - Self::construct(&chunks_with_padding) + Self::construct(&chunks_with_padding, batch_header) } /// Build Batch hash from an ordered list of #N_SNARKS of chunks. 
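// [Sketch] Cross-checking the BatchHeaderV3 preimage produced by `batch_hash()` above
// against the byte offsets introduced in aggregator/src/constants.rs further below;
// every value follows directly from the field widths:
//
//     version                   1 byte    offset   0
//     batch_index               8 bytes   offset   1
//     l1_message_popped         8 bytes   offset   9
//     total_l1_message_popped   8 bytes   offset  17
//     data_hash                32 bytes   offset  25   (BATCH_DATA_HASH_OFFSET)
//     blob_versioned_hash      32 bytes   offset  57   (BATCH_BLOB_VERSIONED_HASH_OFFSET)
//     parent_batch_hash        32 bytes   offset  89   (BATCH_PARENT_BATCH_HASH)
//     last_block_timestamp      8 bytes   offset 121
//     z                        32 bytes   offset 129   (BATCH_Z_OFFSET)
//     y                        32 bytes   offset 161   (BATCH_Y_OFFSET)
//
// for a total preimage length of 193 bytes.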
- pub fn construct(chunks_with_padding: &[ChunkInfo]) -> Self { + pub fn construct( + chunks_with_padding: &[ChunkInfo], + batch_header: BatchHeader, + ) -> Self { assert_eq!( chunks_with_padding.len(), N_SNARKS, @@ -137,43 +271,38 @@ impl BatchHash { .collect::>(); let batch_data_hash = keccak256(preimage); + assert_eq!( + batch_header.data_hash, + H256::from_slice(&batch_data_hash), + "Expect provided BatchHeader's data_hash field to be correct" + ); + let batch_data = BatchData::::new(number_of_valid_chunks, chunks_with_padding); let point_evaluation_assignments = PointEvaluationAssignments::from(&batch_data); + + assert_eq!( + batch_header.blob_data_proof[0], + H256::from_slice(&point_evaluation_assignments.challenge.to_be_bytes()), + "Expect provided BatchHeader's blob_data_proof field 0 to be correct" + ); + assert_eq!( + batch_header.blob_data_proof[1], + H256::from_slice(&point_evaluation_assignments.evaluation.to_be_bytes()), + "Expect provided BatchHeader's blob_data_proof field 1 to be correct" + ); + let versioned_hash = batch_data.get_versioned_hash(); - // public input hash is build as - // keccak( - // chain_id || - // chunk[0].prev_state_root || - // chunk[k-1].post_state_root || - // chunk[k-1].withdraw_root || - // batch_data_hash || - // z || - // y || - // versioned_hash - // ) - let preimage = [ - chunks_with_padding[0].chain_id.to_be_bytes().as_ref(), - chunks_with_padding[0].prev_state_root.as_bytes(), - chunks_with_padding[N_SNARKS - 1].post_state_root.as_bytes(), - chunks_with_padding[N_SNARKS - 1].withdraw_root.as_bytes(), - batch_data_hash.as_slice(), - point_evaluation_assignments - .challenge - .to_be_bytes() - .as_ref(), - point_evaluation_assignments - .evaluation - .to_be_bytes() - .as_ref(), - versioned_hash.as_bytes(), - ] - .concat(); - let public_input_hash: H256 = keccak256(preimage).into(); + assert_eq!( + batch_header.blob_versioned_hash, versioned_hash, + "Expect provided BatchHeader's blob_versioned_hash field to be correct" + ); + + let current_batch_hash = batch_header.batch_hash(); log::info!( - "batch pi hash {:?}, datahash {}, z {}, y {}, versioned hash {:x}", - public_input_hash, + "batch hash {:?}, datahash {}, z {}, y {}, versioned hash {:x}", + current_batch_hash, hex::encode(batch_data_hash), hex::encode(point_evaluation_assignments.challenge.to_be_bytes()), hex::encode(point_evaluation_assignments.evaluation.to_be_bytes()), @@ -183,11 +312,15 @@ impl BatchHash { Self { chain_id: chunks_with_padding[0].chain_id, chunks_with_padding: chunks_with_padding.to_vec(), + parent_state_root: chunks_with_padding[0].prev_state_root, + current_state_root: chunks_with_padding[N_SNARKS - 1].post_state_root, + current_withdraw_root: chunks_with_padding[N_SNARKS - 1].withdraw_root, data_hash: batch_data_hash.into(), - public_input_hash, + current_batch_hash, number_of_valid_chunks, point_evaluation_assignments, versioned_hash, + batch_header, } } @@ -209,33 +342,45 @@ impl BatchHash { pub(crate) fn extract_hash_preimages(&self) -> Vec> { let mut res = vec![]; - // batchPiHash = - // keccak( - // chain_id || - // chunk[0].prev_state_root || - // chunk[k-1].post_state_root || - // chunk[k-1].withdraw_root || - // batch_data_hash || - // z || - // y || - // blob_versioned_hash - // ) - let batch_public_input_hash_preimage = [ - self.chain_id.to_be_bytes().as_ref(), - self.chunks_with_padding[0].prev_state_root.as_bytes(), - self.chunks_with_padding[N_SNARKS - 1] - .post_state_root - .as_bytes(), - self.chunks_with_padding[N_SNARKS - 1] - .withdraw_root - 
.as_bytes(),
+        // batchHash =
+        // keccak256(
+        //     version ||
+        //     batch_index ||
+        //     l1_message_popped ||
+        //     total_l1_message_popped ||
+        //     batch_data_hash ||
+        //     versioned_hash ||
+        //     parent_batch_hash ||
+        //     last_block_timestamp ||
+        //     z ||
+        //     y
+        // )
+        let batch_hash_preimage = [
+            [self.batch_header.version].as_ref(),
+            self.batch_header.batch_index.to_be_bytes().as_ref(),
+            self.batch_header.l1_message_popped.to_be_bytes().as_ref(),
+            self.batch_header
+                .total_l1_message_popped
+                .to_be_bytes()
+                .as_ref(),
             self.data_hash.as_bytes(),
-            &self.point_evaluation_assignments.challenge.to_be_bytes(),
-            &self.point_evaluation_assignments.evaluation.to_be_bytes(),
             self.versioned_hash.as_bytes(),
+            self.batch_header.parent_batch_hash.as_bytes(),
+            self.batch_header
+                .last_block_timestamp
+                .to_be_bytes()
+                .as_ref(),
+            self.point_evaluation_assignments
+                .challenge
+                .to_be_bytes()
+                .as_ref(),
+            self.point_evaluation_assignments
+                .evaluation
+                .to_be_bytes()
+                .as_ref(),
         ]
         .concat();
-        res.push(batch_public_input_hash_preimage);
+        res.push(batch_hash_preimage);
 
         // compute piHash for each chunk for i in [0..N_SNARKS)
         // chunk[i].piHash =
@@ -277,14 +422,35 @@ impl BatchHash {
         res
     }
 
-    /// Compute the public inputs for this circuit, excluding the accumulator.
-    /// Content: the public_input_hash
+    /// Compute the public inputs for this circuit:
+    /// parent_state_root
+    /// parent_batch_hash
+    /// current_state_root
+    /// current_batch_hash
+    /// chain_id
+    /// current_withdraw_root
    pub(crate) fn instances_exclude_acc(&self) -> Vec> {
-        vec![self
-            .public_input_hash
-            .as_bytes()
-            .iter()
-            .map(|&x| F::from(x as u64))
-            .collect()]
+        let mut res: Vec = [
+            self.parent_state_root,
+            self.batch_header.parent_batch_hash,
+            self.current_state_root,
+            self.current_batch_hash,
+        ]
+        .map(|h| {
+            let (hi, lo) = split_h256(h);
+            vec![hi, lo]
+        })
+        .concat();
+
+        res.push(F::from(self.chain_id));
+        let (withdraw_hi, withdraw_lo) = split_h256(self.current_withdraw_root);
+        res.extend_from_slice(vec![withdraw_hi, withdraw_lo].as_slice());
+
+        vec![res]
+    }
+
+    /// Returns the batch header.
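// [Sketch] With ACC_LEN = 12 accumulator cells in front, the 11 cells returned by
// `instances_exclude_acc()` above line up one-to-one with the PI_* offsets introduced
// in aggregator/src/constants.rs:
//
//     res[0..2]   parent_state_root     (hi, lo)   PI_PARENT_STATE_ROOT     = ACC_LEN
//     res[2..4]   parent_batch_hash     (hi, lo)   PI_PARENT_BATCH_HASH     = ACC_LEN + 2
//     res[4..6]   current_state_root    (hi, lo)   PI_CURRENT_STATE_ROOT    = ACC_LEN + 4
//     res[6..8]   current_batch_hash    (hi, lo)   PI_CURRENT_BATCH_HASH    = ACC_LEN + 6
//     res[8]      chain_id                         PI_CHAIN_ID              = ACC_LEN + 8
//     res[9..11]  current_withdraw_root (hi, lo)   PI_CURRENT_WITHDRAW_ROOT = ACC_LEN + 9
//
// which also explains `num_instance()` returning `vec![ACC_LEN + 11]` in circuit.rs.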
+ pub fn batch_header(&self) -> BatchHeader { + self.batch_header } } diff --git a/aggregator/src/blob.rs b/aggregator/src/blob.rs index aa6871c90d..e6003f429b 100644 --- a/aggregator/src/blob.rs +++ b/aggregator/src/blob.rs @@ -708,7 +708,7 @@ mod tests { vec![vec![]; MAX_AGG_SNARKS], ), ( - "max number of chunkks all non-empty", + "max number of chunks all non-empty", (0..MAX_AGG_SNARKS) .map(|i| (10u8..11 + u8::try_from(i).unwrap()).collect()) .collect(), @@ -742,15 +742,33 @@ mod tests { ] .iter() { + // batch header + let batch_header = crate::batch::BatchHeader { + version: 3, + batch_index: 6789, + l1_message_popped: 101, + total_l1_message_popped: 10101, + parent_batch_hash: H256::repeat_byte(1), + last_block_timestamp: 192837, + ..Default::default() + }; + let chunks_without_padding = crate::chunk::ChunkInfo::mock_chunk_infos(tcase); + let batch_hash = BatchHash::::construct_with_unpadded( + &chunks_without_padding, + batch_header, + ); + + // blob data let batch_data: BatchData = tcase.into(); let point_evaluation_assignments = PointEvaluationAssignments::from(&batch_data); let versioned_hash = batch_data.get_versioned_hash(); println!( - "[[ {:60} ]]\nchallenge (z) = {:0>64x}, evaluation (y) = {:0>64x}, versioned hash = {:0>64x}\n\n", + "[[ {:60} ]]\nchallenge (z) = {:0>64x}, evaluation (y) = {:0>64x}, versioned hash = {:0>64x}, batch_hash = {:0>64x}\n\n", annotation, point_evaluation_assignments.challenge, point_evaluation_assignments.evaluation, versioned_hash, + batch_hash.current_batch_hash, ); } } diff --git a/aggregator/src/chunk.rs b/aggregator/src/chunk.rs index 697d35237c..4aad73da65 100644 --- a/aggregator/src/chunk.rs +++ b/aggregator/src/chunk.rs @@ -175,6 +175,49 @@ impl ChunkInfo { H256(keccak256(&self.tx_bytes)) } + #[cfg(test)] + pub(crate) fn mock_chunk_infos(txs_data: &[Vec]) -> Vec { + use crate::MAX_AGG_SNARKS; + + assert!(txs_data.len() <= MAX_AGG_SNARKS); + let state_roots: [H256; MAX_AGG_SNARKS + 1] = (0..=MAX_AGG_SNARKS) + .map(|i| { + let i = i as u8; + let mut state_root = [0u8; 32]; + state_root[31] = i; + state_root.into() + }) + .collect::>() + .try_into() + .expect("should not fail"); + + txs_data + .iter() + .enumerate() + .map(|(i, tx_data)| { + let withdraw_root = { + let mut root = [0u8; 32]; + root[31] = 255 - (i as u8); + root.into() + }; + let data_hash = { + let mut root = [0u8; 32]; + root[0] = 255 - (i as u8); + root.into() + }; + ChunkInfo { + chain_id: 123456, + prev_state_root: state_roots[i], + post_state_root: state_roots[i + 1], + withdraw_root, + data_hash, + tx_bytes: tx_data.to_vec(), + is_padding: false, + } + }) + .collect::>() + } + /// Sample a chunk info from random (for testing) #[cfg(test)] pub(crate) fn mock_random_chunk_info_for_testing(r: &mut R) -> Self { diff --git a/aggregator/src/constants.rs b/aggregator/src/constants.rs index 7eab49c8ff..f7cdc73692 100644 --- a/aggregator/src/constants.rs +++ b/aggregator/src/constants.rs @@ -23,6 +23,7 @@ pub(crate) const LOG_DEGREE: u32 = 21; // - chunk_data_hash 32 bytes // - chunk_tx_data_hash 32 bytes +pub(crate) const CHUNK_CHAIN_ID_INDEX: usize = 0; pub(crate) const PREV_STATE_ROOT_INDEX: usize = 8; pub(crate) const POST_STATE_ROOT_INDEX: usize = 40; pub(crate) const WITHDRAW_ROOT_INDEX: usize = 72; @@ -30,22 +31,43 @@ pub(crate) const CHUNK_DATA_HASH_INDEX: usize = 104; pub(crate) const CHUNK_TX_DATA_HASH_INDEX: usize = 136; // ================================ -// indices for batch pi hash table +// indices for batch hash table // ================================ // // the 
preimages are arranged as -// - chain_id: 8 bytes -// - prev_state_root 32 bytes -// - post_state_root 32 bytes -// - withdraw_root 32 bytes -// - chunk_data_hash 32 bytes -// - z 32 bytes -// - y 32 bytes -// - versioned_hash 32 bytes +// - version: 1 byte +// - batch_index: 8 bytes +// - l1_message_popped 8 bytes +// - total_l1_message_popped 8 bytes +// - data_hash 32 bytes +// - blob_versioned_hash 32 bytes +// - parent_batch_hash 32 bytes +// - last_block_timestamp 8 bytes +// - z 32 bytes +// - y 32 bytes -pub(crate) const BATCH_Z_OFFSET: usize = 136; -pub(crate) const BATCH_Y_OFFSET: usize = 168; -pub(crate) const BATCH_VH_OFFSET: usize = 200; +pub(crate) const BATCH_DATA_HASH_OFFSET: usize = 25; +pub(crate) const BATCH_BLOB_VERSIONED_HASH_OFFSET: usize = 57; +pub(crate) const BATCH_PARENT_BATCH_HASH: usize = 89; +pub(crate) const BATCH_Z_OFFSET: usize = 129; +pub(crate) const BATCH_Y_OFFSET: usize = 161; + +// ================================ +// indices for public inputs +// ================================ +// +// - parent state root (2 cells: hi, lo) +// - parent batch hash .. +// - current state root .. +// - current batch hash .. +// - chain id (1 Fr cell) +// - current withdraw root .. +pub(crate) const PI_PARENT_STATE_ROOT: usize = ACC_LEN; +pub(crate) const PI_PARENT_BATCH_HASH: usize = ACC_LEN + 2; +pub(crate) const PI_CURRENT_STATE_ROOT: usize = ACC_LEN + 4; +pub(crate) const PI_CURRENT_BATCH_HASH: usize = ACC_LEN + 6; +pub(crate) const PI_CHAIN_ID: usize = ACC_LEN + 8; +pub(crate) const PI_CURRENT_WITHDRAW_ROOT: usize = ACC_LEN + 9; // ================================ // aggregator parameters diff --git a/aggregator/src/core.rs b/aggregator/src/core.rs index 509e229508..6e05c66216 100644 --- a/aggregator/src/core.rs +++ b/aggregator/src/core.rs @@ -33,10 +33,12 @@ use zkevm_circuits::{ use crate::{ constants::{ - BATCH_VH_OFFSET, BATCH_Y_OFFSET, BATCH_Z_OFFSET, CHAIN_ID_LEN, DIGEST_LEN, LOG_DEGREE, + BATCH_BLOB_VERSIONED_HASH_OFFSET, BATCH_Y_OFFSET, BATCH_Z_OFFSET, CHAIN_ID_LEN, DIGEST_LEN, + LOG_DEGREE, }, - util::{assert_conditional_equal, assert_equal, parse_hash_preimage_cells}, - RlcConfig, BITS, CHUNK_DATA_HASH_INDEX, CHUNK_TX_DATA_HASH_INDEX, LIMBS, POST_STATE_ROOT_INDEX, + util::{assert_conditional_equal, parse_hash_preimage_cells}, + RlcConfig, BATCH_DATA_HASH_OFFSET, BATCH_PARENT_BATCH_HASH, BITS, CHUNK_CHAIN_ID_INDEX, + CHUNK_DATA_HASH_INDEX, CHUNK_TX_DATA_HASH_INDEX, LIMBS, POST_STATE_ROOT_INDEX, PREV_STATE_ROOT_INDEX, WITHDRAW_ROOT_INDEX, }; @@ -155,6 +157,9 @@ pub(crate) struct ExtractedHashCells { chunks_are_padding: Vec>, } +// Computed cells to be constrained against public input. These cells are processed into hi/lo format from ExtractedHashCells. +pub(crate) struct HashDerivedPublicInputCells(Vec>); + impl ExtractedHashCells { /// Assign the cells for hash input/outputs and their RLCs. 
/// Padded the number of hashes to N_SNARKS @@ -180,7 +185,7 @@ impl ExtractedHashCells { let mut data_lens = vec![]; // preimages are padded as follows - // - the first hash is batch_public_input_hash + // - the first hash is batch_hash // - the next hashes are chunk\[i\].piHash, we padded it to N_SNARKS by repeating the last // chunk // - the last hash is batch_data_hash, its input is padded to 32*N_SNARKS @@ -233,11 +238,14 @@ impl ExtractedHashCells { { let mut preimage_cells = vec![]; + for input in batch_data_hash_padded_preimage { let v = Fr::from(input as u64); let cell = plonk_config.load_private(region, &v, offset)?; + preimage_cells.push(cell); } + let input_rlc = plonk_config.rlc_with_flag( region, &preimage_cells, @@ -245,19 +253,23 @@ impl ExtractedHashCells { chunk_is_valid_cell32s, offset, )?; + inputs.push(preimage_cells); input_rlcs.push(input_rlc); } { let mut digest_cells = vec![]; + for output in batch_data_hash_digest.iter() { let v = Fr::from(*output as u64); let cell = plonk_config.load_private(region, &v, offset)?; digest_cells.push(cell); } + let output_rlc = plonk_config.rlc(region, &digest_cells, evm_word_challenge, offset)?; + outputs.push(digest_cells); output_rlcs.push(output_rlc) } @@ -323,6 +335,7 @@ pub(crate) struct AssignedBatchHash { pub(crate) blob: ExpectedBlobCells, pub(crate) num_valid_snarks: AssignedCell, pub(crate) chunks_are_padding: Vec>, + pub(crate) hash_derived_public_input_cells: Vec>, } /// Input the hash input bytes, @@ -332,11 +345,7 @@ pub(crate) struct AssignedBatchHash { // // This function asserts the following constraints on the hashes // -// 1. batch_data_hash digest is reused for public input hash -// 2. batch_pi_hash used same roots as chunk_pi_hash -// 2.1. batch_pi_hash and chunk[0] use a same prev_state_root -// 2.2. batch_pi_hash and chunk[N_SNARKS-1] use a same post_state_root -// 2.3. batch_pi_hash and chunk[N_SNARKS-1] use a same withdraw_root +// 1. batch_data_hash digest is reused for batch hash // 3. batch_data_hash and chunk[i].pi_hash use a same chunk[i].data_hash when chunk[i] is not padded // 4. chunks are continuous: they are linked via the state roots // 5. batch and all its chunks use a same chain id @@ -348,6 +357,7 @@ pub(crate) struct AssignedBatchHash { // - batch's data_hash length is 32 * number_of_valid_snarks // 8. batch data hash is correct w.r.t. its RLCs // 9. is_final_cells are set correctly +#[allow(clippy::too_many_arguments)] pub(crate) fn assign_batch_hashes( keccak_config: &KeccakCircuitConfig, rlc_config: &RlcConfig, @@ -367,27 +377,23 @@ pub(crate) fn assign_batch_hashes( // 6. chunk[i]'s chunk_pi_hash_rlc_cells == chunk[i-1].chunk_pi_hash_rlc_cells when chunk[i] is // padded // 7. batch data hash is correct w.r.t. its RLCs - let extracted_hash_cells = conditional_constraints::( - rlc_config, - layouter, - challenges, - chunks_are_valid, - num_valid_chunks, - preimages, - )?; - - // 2. batch_pi_hash used same roots as chunk_pi_hash - // 2.1. batch_pi_hash and chunk[0] use a same prev_state_root - // 2.2. batch_pi_hash and chunk[N_SNARKS-1] use a same post_state_root - // 2.3. batch_pi_hash and chunk[N_SNARKS-1] use a same withdraw_root - // 5. 
batch and all its chunks use a same chain id - copy_constraints::(layouter, &extracted_hash_cells.inputs)?; - - let batch_pi_input = &extracted_hash_cells.inputs[0]; //[0..INPUT_LEN_PER_ROUND * 2]; + let (extracted_hash_cells, hash_derived_public_input_cells) = + conditional_constraints::( + rlc_config, + layouter, + challenges, + chunks_are_valid, + num_valid_chunks, + preimages, + )?; + + let batch_hash_input = &extracted_hash_cells.inputs[0]; //[0..INPUT_LEN_PER_ROUND * 2]; let expected_blob_cells = ExpectedBlobCells { - z: batch_pi_input[BATCH_Z_OFFSET..BATCH_Z_OFFSET + DIGEST_LEN].to_vec(), - y: batch_pi_input[BATCH_Y_OFFSET..BATCH_Y_OFFSET + DIGEST_LEN].to_vec(), - versioned_hash: batch_pi_input[BATCH_VH_OFFSET..BATCH_VH_OFFSET + DIGEST_LEN].to_vec(), + z: batch_hash_input[BATCH_Z_OFFSET..BATCH_Z_OFFSET + DIGEST_LEN].to_vec(), + y: batch_hash_input[BATCH_Y_OFFSET..BATCH_Y_OFFSET + DIGEST_LEN].to_vec(), + versioned_hash: batch_hash_input + [BATCH_BLOB_VERSIONED_HASH_OFFSET..BATCH_BLOB_VERSIONED_HASH_OFFSET + DIGEST_LEN] + .to_vec(), chunk_tx_data_digests: (0..N_SNARKS) .map(|i| { extracted_hash_cells.inputs[i + 1] @@ -402,6 +408,7 @@ pub(crate) fn assign_batch_hashes( blob: expected_blob_cells, num_valid_snarks: extracted_hash_cells.num_valid_snarks, chunks_are_padding: extracted_hash_cells.chunks_are_padding, + hash_derived_public_input_cells: hash_derived_public_input_cells.0, }) } @@ -416,13 +423,16 @@ pub(crate) fn assign_keccak_table( let timer = start_timer!(|| ("multi keccak").to_string()); // preimages consists of the following parts - // (1) batchPiHash preimage = - // (chain_id || - // chunk[0].prev_state_root || - // chunk[k-1].post_state_root || - // chunk[k-1].withdraw_root || - // batch_data_hash|| - // z || y ||versioned_hash) + // (1) batchHash preimage = + // (version || + // batch_index || + // l1_message_popped || + // total_l1_message_popped || + // batch_data_hash || + // versioned_hash || + // parent_batch_hash || + // last_block_timestamp || + // z || y) // (2) chunk[i].piHash preimage = // (chain id || // chunk[i].prevStateRoot || chunk[i].postStateRoot || @@ -451,152 +461,9 @@ pub(crate) fn assign_keccak_table( Ok(()) } -// Assert the following constraints -// 2. batch_pi_hash used same roots as chunk_pi_hash -// 2.1. batch_pi_hash and chunk[0] use a same prev_state_root -// 2.2. batch_pi_hash and chunk[N_SNARKS-1] use a same post_state_root -// 2.3. batch_pi_hash and chunk[N_SNARKS-1] use a same withdraw_root -// 5. 
batch and all its chunks use a same chain id -fn copy_constraints( - layouter: &mut impl Layouter, - hash_input_cells: &[Vec>], -) -> Result<(), Error> { - let mut is_first_time = true; - - layouter - .assign_region( - || "copy constraints", - |mut region| -> Result<(), halo2_proofs::plonk::Error> { - if is_first_time { - // this region only use copy constraints and do not affect the shape of the - // layouter - is_first_time = false; - return Ok(()); - } - // ==================================================== - // parse the hashes - // ==================================================== - // preimages - let ( - batch_pi_hash_preimage, - chunk_pi_hash_preimages, - _potential_batch_data_hash_preimage, - ) = parse_hash_preimage_cells::(hash_input_cells); - - // ==================================================== - // Constraint the relations between hash preimages - // via copy constraints - // ==================================================== - // - // 2 batch_pi_hash used same roots as chunk_pi_hash - // - // batch_pi_hash = - // keccak( - // chain_id || - // chunk[0].prev_state_root || - // chunk[k-1].post_state_root || - // chunk[k-1].withdraw_root || - // batch_data_hash || - // z || - // y || - // versioned_hash - // ) - // - // chunk[i].piHash = - // keccak( - // chain id || - // chunk[i].prevStateRoot || - // chunk[i].postStateRoot || - // chunk[i].withdrawRoot || - // chunk[i].datahash || - // chunk[i].tx_data_hash - // ) - // - // PREV_STATE_ROOT_INDEX, POST_STATE_ROOT_INDEX, WITHDRAW_ROOT_INDEX - // used below are byte positions for - // prev_state_root, post_state_root, withdraw_root - for i in 0..DIGEST_LEN { - // 2.1 chunk[0].prev_state_root - // sanity check - assert_equal( - &batch_pi_hash_preimage[i + PREV_STATE_ROOT_INDEX], - &chunk_pi_hash_preimages[0][i + PREV_STATE_ROOT_INDEX], - format!( - "chunk and batch's prev_state_root do not match: {:?} {:?}", - &batch_pi_hash_preimage[i + PREV_STATE_ROOT_INDEX].value(), - &chunk_pi_hash_preimages[0][i + PREV_STATE_ROOT_INDEX].value(), - ) - .as_str(), - )?; - region.constrain_equal( - batch_pi_hash_preimage[i + PREV_STATE_ROOT_INDEX].cell(), - chunk_pi_hash_preimages[0][i + PREV_STATE_ROOT_INDEX].cell(), - )?; - // 2.2 chunk[k-1].post_state_root - // sanity check - assert_equal( - &batch_pi_hash_preimage[i + POST_STATE_ROOT_INDEX], - &chunk_pi_hash_preimages[N_SNARKS - 1][i + POST_STATE_ROOT_INDEX], - format!( - "chunk and batch's post_state_root do not match: {:?} {:?}", - &batch_pi_hash_preimage[i + POST_STATE_ROOT_INDEX].value(), - &chunk_pi_hash_preimages[N_SNARKS - 1][i + POST_STATE_ROOT_INDEX] - .value(), - ) - .as_str(), - )?; - region.constrain_equal( - batch_pi_hash_preimage[i + POST_STATE_ROOT_INDEX].cell(), - chunk_pi_hash_preimages[N_SNARKS - 1][i + POST_STATE_ROOT_INDEX].cell(), - )?; - // 2.3 chunk[k-1].withdraw_root - assert_equal( - &batch_pi_hash_preimage[i + WITHDRAW_ROOT_INDEX], - &chunk_pi_hash_preimages[N_SNARKS - 1][i + WITHDRAW_ROOT_INDEX], - format!( - "chunk and batch's withdraw_root do not match: {:?} {:?}", - &batch_pi_hash_preimage[i + WITHDRAW_ROOT_INDEX].value(), - &chunk_pi_hash_preimages[N_SNARKS - 1][i + WITHDRAW_ROOT_INDEX].value(), - ) - .as_str(), - )?; - region.constrain_equal( - batch_pi_hash_preimage[i + WITHDRAW_ROOT_INDEX].cell(), - chunk_pi_hash_preimages[N_SNARKS - 1][i + WITHDRAW_ROOT_INDEX].cell(), - )?; - } - - // 5 assert hashes use a same chain id - for (i, chunk_pi_hash_preimage) in chunk_pi_hash_preimages.iter().enumerate() { - for (lhs, rhs) in batch_pi_hash_preimage - 
.iter() - .take(CHAIN_ID_LEN) - .zip(chunk_pi_hash_preimage.iter().take(CHAIN_ID_LEN)) - { - // sanity check - assert_equal( - lhs, - rhs, - format!( - "chunk_{i} and batch's chain id do not match: {:?} {:?}", - &lhs.value(), - &rhs.value(), - ) - .as_str(), - )?; - region.constrain_equal(lhs.cell(), rhs.cell())?; - } - } - Ok(()) - }, - ) - .map_err(|e| Error::AssertionFailure(format!("assign keccak rows: {e}")))?; - Ok(()) -} - // Assert the following constraints // This function asserts the following constraints on the hashes -// 1. batch_data_hash digest is reused for public input hash +// 1. batch_data_hash digest is reused for batch hash // 3. batch_data_hash and chunk[i].pi_hash use a same chunk[i].data_hash when chunk[i] is not padded // 4. chunks are continuous: they are linked via the state roots // 6. chunk[i]'s chunk_pi_hash_rlc_cells == chunk[i-1].chunk_pi_hash_rlc_cells when chunk[i] is @@ -615,11 +482,14 @@ pub(crate) fn conditional_constraints( chunks_are_valid: &[bool], num_valid_chunks: usize, preimages: &[Vec], -) -> Result, Error> { +) -> Result<(ExtractedHashCells, HashDerivedPublicInputCells), Error> { layouter .assign_region( || "rlc conditional constraints", - |mut region| -> Result, halo2_proofs::plonk::Error> { + |mut region| -> Result< + (ExtractedHashCells, HashDerivedPublicInputCells), + halo2_proofs::plonk::Error, + > { let mut offset = 0; rlc_config.init(&mut region)?; // ==================================================== @@ -630,6 +500,8 @@ pub(crate) fn conditional_constraints( rlc_config.read_challenge1(&mut region, challenges, &mut offset)?; let evm_word_challenge = rlc_config.read_challenge2(&mut region, challenges, &mut offset)?; + let byte_accumulator = + rlc_config.load_private(&mut region, &Fr::from(256), &mut offset)?; let chunk_is_valid_cells = chunks_are_valid .iter() @@ -644,7 +516,7 @@ pub(crate) fn conditional_constraints( let chunk_is_valid_cell32s = chunk_is_valid_cells .iter() - .flat_map(|cell| vec![cell; 32]) + .flat_map(|cell: &AssignedCell| vec![cell; 32]) .cloned() .collect::>(); @@ -683,7 +555,7 @@ pub(crate) fn conditional_constraints( // parse the hashes // ==================================================== // preimages - let (batch_pi_hash_preimage, chunk_pi_hash_preimages, batch_data_hash_preimage) = + let (batch_hash_preimage, chunk_pi_hash_preimages, batch_data_hash_preimage) = parse_hash_preimage_cells::(&assigned_hash_cells.inputs); // ==================================================== @@ -691,27 +563,28 @@ pub(crate) fn conditional_constraints( // ==================================================== // // ==================================================== - // 1. batch_data_hash digest is reused for public input hash + // 1. batch_data_hash digest is reused for batch hash // ==================================================== // - // - // public input hash is build as - // public_input_hash = keccak( - // chain_id || - // chunk[0].prev_state_root || - // chunk[k-1].post_state_root || - // chunk[k-1].withdraw_root || + // batch_hash = keccak256( + // version || + // batch_index || + // l1_message_popped || + // total_l1_message_popped || // batch_data_hash || - // z || y || versioned_hash) + // versioned_hash || + // parent_batch_hash || + // last_block_timestamp || + // z || y) // // batchDataHash = keccak(chunk[0].dataHash || ... 
|| chunk[k-1].dataHash) - // the strategy here is to generate the RLCs of the batch_pi_hash_preimage and + // the strategy here is to generate the RLCs of the batch_hash_preimage and // compare it with batchDataHash's input RLC let batch_data_hash_rlc = rlc_config.rlc( &mut region, - batch_pi_hash_preimage - [CHUNK_DATA_HASH_INDEX..CHUNK_DATA_HASH_INDEX + DIGEST_LEN] + batch_hash_preimage + [BATCH_DATA_HASH_OFFSET..BATCH_DATA_HASH_OFFSET + DIGEST_LEN] .as_ref(), &evm_word_challenge, &mut offset, @@ -731,6 +604,43 @@ pub(crate) fn conditional_constraints( assigned_hash_cells.output_rlcs[N_SNARKS + 1].cell(), )?; + // ==================================================== + // 1.a batch_parent_batch_hash is the same from public input + // ==================================================== + let batch_parent_batch_hash_hi = rlc_config.rlc( + &mut region, + batch_hash_preimage + [BATCH_PARENT_BATCH_HASH..BATCH_PARENT_BATCH_HASH + DIGEST_LEN / 2] + .as_ref(), + &byte_accumulator, + &mut offset, + )?; + let batch_parent_batch_hash_lo = rlc_config.rlc( + &mut region, + batch_hash_preimage[BATCH_PARENT_BATCH_HASH + DIGEST_LEN / 2 + ..BATCH_PARENT_BATCH_HASH + DIGEST_LEN] + .as_ref(), + &byte_accumulator, + &mut offset, + )?; + + // ==================================================== + // 1.b result batch_hash is the same from public input + // ==================================================== + let batch_hash_results = assigned_hash_cells.outputs[0].clone(); + let batch_hash_hi = rlc_config.rlc( + &mut region, + batch_hash_results[0..DIGEST_LEN / 2].as_ref(), + &byte_accumulator, + &mut offset, + )?; + let batch_hash_lo = rlc_config.rlc( + &mut region, + batch_hash_results[DIGEST_LEN / 2..DIGEST_LEN].as_ref(), + &byte_accumulator, + &mut offset, + )?; + // 3 batch_data_hash and chunk[i].pi_hash use a same chunk[i].data_hash when // chunk[i] is not padded // @@ -782,7 +692,7 @@ pub(crate) fn conditional_constraints( ); // ==================================================== - // 4 __valid__ chunks are continuous: they are linked via the state roots + // 4.a __valid__ chunks are continuous: they are linked via the state roots // ==================================================== // chunk[i].piHash = // keccak( @@ -797,7 +707,7 @@ pub(crate) fn conditional_constraints( &chunk_pi_hash_preimages[i][POST_STATE_ROOT_INDEX + j], &chunk_is_valid_cells[i + 1], format!( - "chunk_{i} is not continuous: {:?} {:?} {:?}", + "chunk_{i} is not continuous (state roots): {:?} {:?} {:?}", &chunk_pi_hash_preimages[i + 1][PREV_STATE_ROOT_INDEX + j].value(), &chunk_pi_hash_preimages[i][POST_STATE_ROOT_INDEX + j].value(), &chunk_is_valid_cells[i + 1].value(), @@ -814,6 +724,34 @@ pub(crate) fn conditional_constraints( } } + // ==================================================== + // 4.b __valid__ chunks are continuous: chain_id are the same + // ==================================================== + for i in 0..N_SNARKS - 1 { + for j in 0..CHAIN_ID_LEN { + // sanity check + assert_conditional_equal( + &chunk_pi_hash_preimages[i + 1][CHUNK_CHAIN_ID_INDEX + j], + &chunk_pi_hash_preimages[i][CHUNK_CHAIN_ID_INDEX + j], + &chunk_is_valid_cells[i + 1], + format!( + "chunk_{i} is not continuous (chain_id): {:?} {:?} {:?}", + &chunk_pi_hash_preimages[i + 1][CHUNK_CHAIN_ID_INDEX + j].value(), + &chunk_pi_hash_preimages[i][CHUNK_CHAIN_ID_INDEX + j].value(), + &chunk_is_valid_cells[i + 1].value(), + ) + .as_str(), + )?; + rlc_config.conditional_enforce_equal( + &mut region, + &chunk_pi_hash_preimages[i + 
1][CHUNK_CHAIN_ID_INDEX + j], + &chunk_pi_hash_preimages[i][CHUNK_CHAIN_ID_INDEX + j], + &chunk_is_valid_cells[i + 1], + &mut offset, + )?; + } + } + // ==================================================== // 6. chunk[i]'s chunk_pi_hash_rlc_cells == chunk[i-1].chunk_pi_hash_rlc_cells when // chunk[i] is padded @@ -857,8 +795,115 @@ pub(crate) fn conditional_constraints( assigned_hash_cells.input_rlcs[N_SNARKS + 1].cell(), )?; + // ============================================================================= + // 8. state roots in public input corresponds correctly to chunk-level preimages + // ============================================================================= + + // Values in the public input are split into hi lo components + // To compare byte-wise assigned hash pre-image cells on the chunk-level, + // reconstruct two values for each pre-image. + + // chunk[i].piHash = + // keccak( + // &chain id || + // chunk[i].prevStateRoot || + // chunk[i].postStateRoot || + // chunk[i].withdrawRoot || + // chunk[i].datahash || + // chunk[i].tx_data_hash + // ) + + // BatchCircuit PI + // - parent state root (2 cells: hi, lo) + // - parent batch hash .. + // - current state root .. + // - current batch hash .. + // - chain id (1 Fr cell) + // - current withdraw root .. + + // pi.parent_state_root = chunks[0].prev_state_root + let chunk_prev_state_hi = rlc_config.rlc( + &mut region, + chunk_pi_hash_preimages[0] + [PREV_STATE_ROOT_INDEX..PREV_STATE_ROOT_INDEX + DIGEST_LEN / 2] + .as_ref(), + &byte_accumulator, + &mut offset, + )?; + let chunk_prev_state_lo = rlc_config.rlc( + &mut region, + chunk_pi_hash_preimages[0][PREV_STATE_ROOT_INDEX + DIGEST_LEN / 2 + ..PREV_STATE_ROOT_INDEX + DIGEST_LEN] + .as_ref(), + &byte_accumulator, + &mut offset, + )?; + + // pi.current_state_root = chunks[N_SNARKS - 1].post_state_root + let chunk_current_state_hi = rlc_config.rlc( + &mut region, + chunk_pi_hash_preimages[N_SNARKS - 1] + [POST_STATE_ROOT_INDEX..POST_STATE_ROOT_INDEX + DIGEST_LEN / 2] + .as_ref(), + &byte_accumulator, + &mut offset, + )?; + let chunk_current_state_lo = rlc_config.rlc( + &mut region, + chunk_pi_hash_preimages[N_SNARKS - 1][POST_STATE_ROOT_INDEX + DIGEST_LEN / 2 + ..POST_STATE_ROOT_INDEX + DIGEST_LEN] + .as_ref(), + &byte_accumulator, + &mut offset, + )?; + + // pi.current_withdraw_root = chunks[N_SNARKS - 1].withdraw_root + let chunk_current_withdraw_root_hi = rlc_config.rlc( + &mut region, + chunk_pi_hash_preimages[N_SNARKS - 1] + [WITHDRAW_ROOT_INDEX..WITHDRAW_ROOT_INDEX + DIGEST_LEN / 2] + .as_ref(), + &byte_accumulator, + &mut offset, + )?; + let chunk_current_withdraw_root_lo = rlc_config.rlc( + &mut region, + chunk_pi_hash_preimages[N_SNARKS - 1] + [WITHDRAW_ROOT_INDEX + DIGEST_LEN / 2..WITHDRAW_ROOT_INDEX + DIGEST_LEN] + .as_ref(), + &byte_accumulator, + &mut offset, + )?; + + // pi.chain_id = chunks[N_SNARKS - 1].chain_id + // Note: Chunk-chaining constraints in 4.b guarantee that previously assigned chain_id cells have the same values. 
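// [Sketch] `byte_accumulator` is the constant 256, so each `rlc` call above is not a
// random linear combination but plain big-endian base-256 recombination: for 16 byte
// cells b0..b15 it computes b0*256^15 + b1*256^14 + ... + b15. Assuming `split_h256`
// (gadgets::util) splits an H256 into its big-endian high/low 16-byte halves, the hi/lo
// cells built here therefore equal the values the prover places in the instance column:
//
//     fn fold_half(half: &[u8; 16]) -> u128 {
//         half.iter().fold(0u128, |acc, &b| (acc << 8) | b as u128)
//     }
//     // hi = fold_half(&h[0..16]), lo = fold_half(&h[16..32])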
+                let chunk_chain_id = rlc_config.rlc(
+                    &mut region,
+                    chunk_pi_hash_preimages[N_SNARKS - 1]
+                        [CHUNK_CHAIN_ID_INDEX..CHUNK_CHAIN_ID_INDEX + CHAIN_ID_LEN]
+                        .as_ref(),
+                    &byte_accumulator,
+                    &mut offset,
+                )?;
+
                 log::trace!("rlc chip uses {} rows", offset);
-                Ok(assigned_hash_cells)
+
+                Ok((
+                    assigned_hash_cells,
+                    HashDerivedPublicInputCells(vec![
+                        batch_parent_batch_hash_hi,
+                        batch_parent_batch_hash_lo,
+                        batch_hash_hi,
+                        batch_hash_lo,
+                        chunk_prev_state_hi,
+                        chunk_prev_state_lo,
+                        chunk_current_state_hi,
+                        chunk_current_state_lo,
+                        chunk_current_withdraw_root_hi,
+                        chunk_current_withdraw_root_lo,
+                        chunk_chain_id,
+                    ]),
+                ))
             },
         )
         .map_err(|e| Error::AssertionFailure(format!("aggregation: {e}")))
diff --git a/aggregator/src/lib.rs b/aggregator/src/lib.rs
index deea31d278..e31abe4279 100644
--- a/aggregator/src/lib.rs
+++ b/aggregator/src/lib.rs
@@ -1,5 +1,4 @@
 #![feature(lazy_cell)]
-
 /// proof aggregation
 mod aggregation;
 /// This module implements `Batch` related data types.
@@ -7,6 +6,8 @@ mod aggregation;
 mod batch;
 /// blob struct and constants
 mod blob;
+/// Config to recursively aggregate multiple aggregations
+mod recursion;
 // This module implements `Chunk` related data types.
 // A chunk is a list of blocks.
 mod chunk;
@@ -26,10 +27,11 @@ mod tests;
 
 pub use self::core::extract_proof_and_instances_with_pairing_check;
 pub use aggregation::*;
-pub use batch::BatchHash;
+pub use batch::{BatchHash, BatchHeader};
 pub use blob::BatchData;
 pub use chunk::ChunkInfo;
 pub use compression::*;
 pub use constants::MAX_AGG_SNARKS;
 pub(crate) use constants::*;
 pub use param::*;
+pub use recursion::*;
diff --git a/aggregator/src/recursion.rs b/aggregator/src/recursion.rs
new file mode 100644
index 0000000000..8f5a7c55b6
--- /dev/null
+++ b/aggregator/src/recursion.rs
@@ -0,0 +1,116 @@
+//! A recursion circuit generates a new proof for multiple
+//! target circuits (currently the compression circuit) in a recursive fashion.
+//! It uses the initial and final inputs (block hashes) of the aggregated snarks.
+//! The design is based on https://github.com/axiom-crypto/snark-verifier/blob/main/snark-verifier/examples/recursion.rs
+
+/// Circuit implementation of recursion circuit.
+mod circuit;
+
+/// Common functionality utilised by the recursion circuit.
+mod common;
+
+/// Config for recursion circuit
+mod config;
+
+/// Some utility functions.
+mod util;
+
+pub use circuit::RecursionCircuit;
+pub(crate) use common::dynamic_verify;
+pub use util::{gen_recursion_pk, initial_recursion_snark};
+
+use halo2_proofs::{
+    halo2curves::{
+        bn256::{Bn256, Fq, Fr, G1Affine},
+        group::ff::Field,
+    },
+    plonk::{Circuit, ConstraintSystem, Error, ProvingKey, Selector, VerifyingKey},
+};
+use itertools::Itertools;
+use rand::Rng;
+use snark_verifier::{
+    loader::{
+        halo2::halo2_ecc::halo2_base as sv_halo2_base, native::NativeLoader, Loader, ScalarLoader,
+    },
+    system::halo2::{compile, Config},
+    verifier::{PlonkProof, PlonkVerifier},
+};
+use snark_verifier_sdk::{
+    types::{PoseidonTranscript, POSEIDON_SPEC},
+    CircuitExt, Snark,
+};
+use sv_halo2_base::halo2_proofs;
+
+use crate::constants::{BITS, LIMBS};
+
+/// Any data that can be recursively bundled must implement the described state transition
+/// trait.
+pub trait StateTransition: Sized {
+    type Input: Clone;
+    type Circuit: CircuitExt;
+
+    /// Initialise a new type that implements the state transition behaviour.
+    fn new(state: Self::Input) -> Self;
+
+    /// Transition to the next state.
+    fn state_transition(&self, round: usize) -> Self::Input;
+
+    /// Returns the number of fields used to represent state. The public input consists of twice
+    /// this number as both the previous and current states are included in the public input.
+    fn num_transition_instance() -> usize;
+
+    /// Returns the number of fields required by the circuit in addition to the fields
+    /// representing its state.
+    fn num_additional_instance() -> usize {
+        0
+    }
+
+    /// The number of instance cells for the circuit.
+    fn num_instance() -> usize {
+        Self::num_accumulator_instance()
+            + Self::num_transition_instance() * 2
+            + Self::num_additional_instance()
+    }
+
+    /// Returns the number of instance cells used to hold the accumulator.
+    fn num_accumulator_instance() -> usize {
+        Self::Circuit::accumulator_indices()
+            .map(|v| v.len())
+            .unwrap_or_default()
+    }
+
+    /// The indices of the accumulator cells in the instance layout.
+    ///
+    /// By default the StateTransition circuit assumes a single instance column with the
+    /// layout:
+    ///
+    /// accumulator | prev_state | state | additional
+    ///
+    /// For example, with a 12-cell accumulator and a single state field, indices 0..12 hold
+    /// the accumulator, index 12 the previous state and index 13 the current state.
+    ///
+    /// Note that we do not verify the internal layout of the accumulator; we simply assume
+    /// its cells are placed at the beginning.
+    fn accumulator_indices() -> Vec<usize> {
+        let start = 0;
+        let end = Self::num_accumulator_instance();
+        (start..end).collect()
+    }
+
+    /// The accumulator is followed by the instance cells representing the previous state.
+    fn state_prev_indices() -> Vec<usize> {
+        let start = Self::num_accumulator_instance();
+        let end = start + Self::num_transition_instance();
+        (start..end).collect()
+    }
+
+    /// The previous state is followed by the instance cells representing the current state.
+    fn state_indices() -> Vec<usize> {
+        let start = Self::num_accumulator_instance() + Self::num_transition_instance();
+        let end = start + Self::num_transition_instance();
+        (start..end).collect()
+    }
+
+    /// The indices of any other instance cells in addition to the accumulator and state
+    /// transition cells.
+ fn additional_indices() -> Vec { + let start = Self::num_accumulator_instance() + 2 * Self::num_transition_instance(); + let end = Self::num_instance(); + (start..end).collect() + } +} diff --git a/aggregator/src/recursion/circuit.rs b/aggregator/src/recursion/circuit.rs new file mode 100644 index 0000000000..7cb7cb8a9e --- /dev/null +++ b/aggregator/src/recursion/circuit.rs @@ -0,0 +1,477 @@ +#![allow(clippy::type_complexity)] +use std::{fs::File, iter, marker::PhantomData, rc::Rc}; + +use halo2_proofs::{ + circuit::{Cell, Layouter, SimpleFloorPlanner, Value}, + poly::{commitment::ParamsProver, kzg::commitment::ParamsKZG}, +}; +use snark_verifier::{ + loader::halo2::{halo2_ecc::halo2_base as sv_halo2_base, EccInstructions, IntegerInstructions}, + pcs::{ + kzg::{Bdfg21, Kzg, KzgAccumulator, KzgAs, KzgSuccinctVerifyingKey}, + AccumulationScheme, AccumulationSchemeProver, + }, + util::{ + arithmetic::{fe_to_fe, fe_to_limbs}, + hash, + }, +}; +use snark_verifier_sdk::{ + types::{Halo2Loader, Plonk}, + SnarkWitness, +}; +use sv_halo2_base::{ + gates::GateInstructions, halo2_proofs, AssignedValue, Context, ContextParams, + QuantumCell::Existing, +}; + +use crate::param::ConfigParams as BatchCircuitConfigParams; + +use super::*; + +type Svk = KzgSuccinctVerifyingKey; +type Pcs = Kzg; +type As = KzgAs; + +fn select_accumulator<'a>( + loader: &Rc>, + condition: &AssignedValue, + lhs: &KzgAccumulator>>, + rhs: &KzgAccumulator>>, +) -> Result>>, Error> { + let [lhs, rhs]: [_; 2] = [lhs.lhs.assigned(), lhs.rhs.assigned()] + .iter() + .zip([rhs.lhs.assigned(), rhs.rhs.assigned()].iter()) + .map(|(lhs, rhs)| { + loader + .ecc_chip() + .select(&mut loader.ctx_mut(), lhs, rhs, condition) + }) + .collect::>() + .try_into() + .unwrap(); + Ok(KzgAccumulator::new( + loader.ec_point_from_assigned(lhs), + loader.ec_point_from_assigned(rhs), + )) +} + +fn accumulate<'a>( + loader: &Rc>, + accumulators: Vec>>>, + as_proof: Value<&'_ [u8]>, +) -> KzgAccumulator>> { + let mut transcript = PoseidonTranscript::, _>::new(loader, as_proof); + let proof = As::read_proof(&Default::default(), &accumulators, &mut transcript).unwrap(); + As::verify(&Default::default(), &accumulators, &proof).unwrap() +} + +#[derive(Clone)] +pub struct RecursionCircuit { + svk: Svk, + default_accumulator: KzgAccumulator, + app: SnarkWitness, + previous: SnarkWitness, + round: usize, + instances: Vec, + as_proof: Value>, + app_is_aggregation: bool, + _marker: PhantomData, +} + +impl RecursionCircuit { + const PREPROCESSED_DIGEST_ROW: usize = 4 * LIMBS; + const INITIAL_STATE_ROW: usize = Self::PREPROCESSED_DIGEST_ROW + 1; + + pub fn new( + params: &ParamsKZG, + app: Snark, + previous: Snark, + rng: impl Rng + Send, + round: usize, + ) -> Self { + let svk = params.get_g()[0].into(); + let default_accumulator = KzgAccumulator::new(params.get_g()[1], params.get_g()[0]); + + let succinct_verify = |snark: &Snark| { + let mut transcript = PoseidonTranscript::::new(snark.proof.as_slice()); + let proof = + Plonk::::read_proof(&svk, &snark.protocol, &snark.instances, &mut transcript); + Plonk::succinct_verify(&svk, &snark.protocol, &snark.instances, &proof) + }; + + let accumulators = iter::empty() + .chain(succinct_verify(&app)) + .chain( + (round > 0) + .then(|| succinct_verify(&previous)) + .unwrap_or_else(|| { + let num_accumulator = 1 + previous.protocol.accumulator_indices.len(); + vec![default_accumulator.clone(); num_accumulator] + }), + ) + .collect_vec(); + + let (accumulator, as_proof) = { + let mut transcript = 
PoseidonTranscript::::new(Vec::new()); + let accumulator = + As::create_proof(&Default::default(), &accumulators, &mut transcript, rng).unwrap(); + (accumulator, transcript.finalize()) + }; + + let init_instances = if round > 0 { + // pick from prev snark + Vec::from( + &previous.instances[0][Self::INITIAL_STATE_ROW + ..Self::INITIAL_STATE_ROW + ST::num_transition_instance()], + ) + } else { + // pick from app + ST::state_prev_indices() + .into_iter() + .map(|i| app.instances[0][i]) + .collect::>() + }; + + let state_instances = ST::state_indices() + .into_iter() + .map(|i| &app.instances[0][i]) + .chain( + ST::additional_indices() + .into_iter() + .map(|i| &app.instances[0][i]), + ); + + let preprocessed_digest = { + let inputs = previous + .protocol + .preprocessed + .iter() + .flat_map(|preprocessed| [preprocessed.x, preprocessed.y]) + .map(fe_to_fe) + .chain(previous.protocol.transcript_initial_state) + .collect_vec(); + let mut hasher = hash::Poseidon::from_spec(&NativeLoader, POSEIDON_SPEC.clone()); + hasher.update(&inputs); + hasher.squeeze() + }; + + let instances = [ + accumulator.lhs.x, + accumulator.lhs.y, + accumulator.rhs.x, + accumulator.rhs.y, + ] + .into_iter() + .flat_map(fe_to_limbs::<_, _, LIMBS, BITS>) + .chain(iter::once(preprocessed_digest)) + .chain(init_instances) + .chain(state_instances.copied()) + .chain(iter::once(Fr::from(round as u64))) + .collect(); + + log::debug!("recursive instance: {:#?}", instances); + + Self { + svk, + default_accumulator, + app: app.into(), + previous: previous.into(), + round, + instances, + as_proof: Value::known(as_proof), + app_is_aggregation: true, + _marker: Default::default(), + } + } + + fn as_proof(&self) -> Value<&[u8]> { + self.as_proof.as_ref().map(Vec::as_slice) + } + + fn load_default_accumulator<'a>( + &self, + loader: &Rc>, + ) -> Result>>, Error> { + let [lhs, rhs] = + [self.default_accumulator.lhs, self.default_accumulator.rhs].map(|default| { + let assigned = loader + .ecc_chip() + .assign_constant(&mut loader.ctx_mut(), default) + .unwrap(); + loader.ec_point_from_assigned(assigned) + }); + Ok(KzgAccumulator::new(lhs, rhs)) + } + + /// Returns the number of instance cells in the Recursion Circuit, help to refine the CircuitExt trait + pub fn num_instance_fixed() -> usize { + // [..lhs, ..rhs, preprocessed_digest, initial_state, state, round] + 4 * LIMBS + 2 * ST::num_transition_instance() + ST::num_additional_instance() + 2 + } +} + +impl Circuit for RecursionCircuit { + type Config = config::RecursionConfig; + type FloorPlanner = SimpleFloorPlanner; + + fn without_witnesses(&self) -> Self { + Self { + svk: self.svk, + default_accumulator: self.default_accumulator.clone(), + app: self.app.without_witnesses(), + previous: self.previous.without_witnesses(), + round: self.round, + instances: self.instances.clone(), + as_proof: Value::unknown(), + _marker: Default::default(), + app_is_aggregation: self.app_is_aggregation, + } + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + let path = std::env::var("BUNDLE_CONFIG") + .unwrap_or_else(|_| "configs/bundle_circuit.config".to_owned()); + let params: BatchCircuitConfigParams = serde_json::from_reader( + File::open(path.as_str()).unwrap_or_else(|err| panic!("{err:?}")), + ) + .unwrap(); + + Self::Config::configure(meta, params) + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + config.range().load_lookup_table(&mut layouter)?; + let max_rows = config.range().gate.max_rows; + let main_gate = 
config.gate(); + + let mut first_pass = halo2_base::SKIP_FIRST_PASS; // assume using simple floor planner + let assigned_instances = layouter.assign_region( + || "", + |region| -> Result, Error> { + if first_pass { + first_pass = false; + return Ok(vec![]); + } + let mut ctx = Context::new( + region, + ContextParams { + max_rows, + num_context_ids: 1, + fixed_columns: config.base_field_config.range.gate.constants.clone(), + }, + ); + + let init_state_row_beg = Self::INITIAL_STATE_ROW; + let state_row_beg = init_state_row_beg + ST::num_transition_instance(); + let addition_state_beg = state_row_beg + ST::num_transition_instance(); + let round_row = addition_state_beg + ST::num_additional_instance(); + log::debug!( + "state position: init {}|cur {}|add {}", + state_row_beg, + addition_state_beg, + round_row + ); + + let [preprocessed_digest, round] = [ + self.instances[Self::PREPROCESSED_DIGEST_ROW], + self.instances[round_row], + ] + .map(|instance| { + main_gate + .assign_integer(&mut ctx, Value::known(instance)) + .unwrap() + }); + + let initial_state = self.instances[init_state_row_beg..state_row_beg] + .iter() + .map(|&instance| { + main_gate + .assign_integer(&mut ctx, Value::known(instance)) + .unwrap() + }) + .collect::>(); + + let state = self.instances[state_row_beg..round_row] + .iter() + .map(|&instance| { + main_gate + .assign_integer(&mut ctx, Value::known(instance)) + .unwrap() + }) + .collect::>(); + + let first_round = main_gate.is_zero(&mut ctx, &round); + let not_first_round = main_gate.not(&mut ctx, Existing(first_round)); + + let loader = Halo2Loader::new(config.ecc_chip(), ctx); + let (mut app_instances, app_accumulators) = + dynamic_verify::(&self.svk, &loader, &self.app, None); + let (mut previous_instances, previous_accumulators) = dynamic_verify::( + &self.svk, + &loader, + &self.previous, + Some(preprocessed_digest), + ); + + let default_accumulator = self.load_default_accumulator(&loader)?; + let previous_accumulators = previous_accumulators + .iter() + .map(|previous_accumulator| { + select_accumulator( + &loader, + &first_round, + &default_accumulator, + previous_accumulator, + ) + }) + .collect::, Error>>()?; + + let KzgAccumulator { lhs, rhs } = accumulate( + &loader, + [app_accumulators, previous_accumulators].concat(), + self.as_proof(), + ); + + let lhs = lhs.into_assigned(); + let rhs = rhs.into_assigned(); + let app_instances = app_instances.pop().unwrap(); + let previous_instances = previous_instances.pop().unwrap(); + + let mut ctx = loader.ctx_mut(); + let initial_state_propagate = initial_state + .iter() + .zip_eq(previous_instances[init_state_row_beg..state_row_beg].iter()) + .zip_eq( + ST::state_prev_indices() + .into_iter() + .map(|i| &app_instances[i]), + ) + .flat_map(|((&st, &previous_st), &app_inst)| { + [ + // Propagate initial_state + ( + main_gate.mul(&mut ctx, Existing(st), Existing(not_first_round)), + previous_st, + ), + // Verify initial_state is same as the first application snark + ( + main_gate.mul(&mut ctx, Existing(st), Existing(first_round)), + main_gate.mul(&mut ctx, Existing(app_inst), Existing(first_round)), + ), + ] + }) + .collect::>(); + + // Verify current state is same as the current application snark + let verify_app_state = state + .iter() + .zip_eq( + ST::state_indices() + .into_iter() + .map(|i| &app_instances[i]) + .chain( + ST::additional_indices() + .into_iter() + .map(|i| &app_instances[i]), + ), + ) + .map(|(&st, &app_inst)| (st, app_inst)) + .collect::>(); + + // Verify previous state (additional state not 
included) is the same as the current application snark
+                let verify_app_init_state = previous_instances[state_row_beg..addition_state_beg]
+                    .iter()
+                    .zip_eq(
+                        ST::state_prev_indices()
+                            .into_iter()
+                            .map(|i| &app_instances[i]),
+                    )
+                    .map(|(&st, &app_inst)| {
+                        (
+                            main_gate.mul(&mut ctx, Existing(app_inst), Existing(not_first_round)),
+                            st,
+                        )
+                    })
+                    .collect::<Vec<_>>();
+
+                for (lhs, rhs) in [
+                    // Propagate preprocessed_digest
+                    (
+                        main_gate.mul(
+                            &mut ctx,
+                            Existing(preprocessed_digest),
+                            Existing(not_first_round),
+                        ),
+                        previous_instances[Self::PREPROCESSED_DIGEST_ROW],
+                    ),
+                    // Verify round is increased by 1 when not at first round
+                    (
+                        round,
+                        main_gate.add(
+                            &mut ctx,
+                            Existing(not_first_round),
+                            Existing(previous_instances[round_row]),
+                        ),
+                    ),
+                ]
+                .into_iter()
+                .chain(initial_state_propagate)
+                .chain(verify_app_state)
+                .chain(verify_app_init_state)
+                {
+                    ctx.region.constrain_equal(lhs.cell(), rhs.cell())?;
+                }
+
+                // IMPORTANT:
+                config.base_field_config.finalize(&mut ctx);
+                #[cfg(feature = "display")]
+                dbg!(ctx.total_advice);
+                #[cfg(feature = "display")]
+                println!("Advice columns used: {}", ctx.advice_alloc[0][0].0 + 1);
+
+                Ok([lhs.x(), lhs.y(), rhs.x(), rhs.y()]
+                    .into_iter()
+                    .flat_map(|coordinate| coordinate.limbs())
+                    .chain(iter::once(&preprocessed_digest))
+                    .chain(initial_state.iter())
+                    .chain(state.iter())
+                    .chain(iter::once(&round))
+                    .map(|assigned| assigned.cell())
+                    .collect())
+            },
+        )?;
+
+        assert_eq!(assigned_instances.len(), self.num_instance()[0]);
+        for (row, limb) in assigned_instances.into_iter().enumerate() {
+            layouter.constrain_instance(limb, config.instance, row)?;
+        }
+
+        Ok(())
+    }
+}
+
+impl<ST: StateTransition> CircuitExt<Fr> for RecursionCircuit<ST> {
+    fn num_instance(&self) -> Vec<usize> {
+        vec![Self::num_instance_fixed()]
+    }
+
+    fn instances(&self) -> Vec<Vec<Fr>> {
+        vec![self.instances.clone()]
+    }
+
+    fn accumulator_indices() -> Option<Vec<(usize, usize)>> {
+        Some((0..4 * LIMBS).map(|idx| (0, idx)).collect())
+    }
+
+    fn selectors(config: &Self::Config) -> Vec<Selector> {
+        config.base_field_config.range.gate.basic_gates[0]
+            .iter()
+            .map(|gate| gate.q_enable)
+            .collect()
+    }
+}
diff --git a/aggregator/src/recursion/common.rs b/aggregator/src/recursion/common.rs
new file mode 100644
index 0000000000..dc18efce8a
--- /dev/null
+++ b/aggregator/src/recursion/common.rs
@@ -0,0 +1,86 @@
+use std::rc::Rc;
+
+use snark_verifier::{
+    loader::halo2::EccInstructions,
+    pcs::{kzg::KzgAccumulator, MultiOpenScheme, PolynomialCommitmentScheme},
+    util::hash,
+};
+use snark_verifier_sdk::{
+    types::{BaseFieldEccChip, Halo2Loader, Plonk},
+    SnarkWitness,
+};
+
+use super::*;
+
+type AssignedScalar<'a> = <BaseFieldEccChip as EccInstructions<'a, G1Affine>>::AssignedScalar;
+
+fn poseidon<L: Loader<G1Affine>>(loader: &L, inputs: &[L::LoadedScalar]) -> L::LoadedScalar {
+    let mut hasher = hash::Poseidon::from_spec(loader, POSEIDON_SPEC.clone());
+    hasher.update(inputs);
+    hasher.squeeze()
+}
+
+/// It is similar to the `succinct_verify` method inside snark-verifier,
+/// but it allows the loader to load the preprocessed part as witness, so that
+/// ANY circuit can be verified.
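+///
+/// As a reference, a native (out-of-circuit) sketch of the digest binding enforced here,
+/// mirroring the computation in `RecursionCircuit::new` (illustrative only; assumes
+/// snark-verifier's `POSEIDON_SPEC`):
+///
+/// ```ignore
+/// let inputs = protocol
+///     .preprocessed
+///     .iter()
+///     .flat_map(|preprocessed| [preprocessed.x, preprocessed.y])
+///     .map(fe_to_fe)
+///     .chain(protocol.transcript_initial_state)
+///     .collect_vec();
+/// let mut hasher = hash::Poseidon::from_spec(&NativeLoader, POSEIDON_SPEC.clone());
+/// hasher.update(&inputs);
+/// let preprocessed_digest = hasher.squeeze();
+/// ```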
+pub fn dynamic_verify<'a, PCS>( + svk: &PCS::SuccinctVerifyingKey, + loader: &Rc>, + snark: &SnarkWitness, + preprocessed_digest: Option>, +) -> (Vec>>, Vec) +where + PCS: PolynomialCommitmentScheme< + G1Affine, + Rc>, + Accumulator = KzgAccumulator>>, + > + MultiOpenScheme>>, +{ + let protocol = if let Some(preprocessed_digest) = preprocessed_digest { + let preprocessed_digest = loader.scalar_from_assigned(preprocessed_digest); + let protocol = snark.protocol.loaded_preprocessed_as_witness(loader); + let inputs = protocol + .preprocessed + .iter() + .flat_map(|preprocessed| { + let assigned = preprocessed.assigned(); + [assigned.x(), assigned.y()] + .map(|coordinate| loader.scalar_from_assigned(*coordinate.native())) + }) + .chain(protocol.transcript_initial_state.clone()) + .collect_vec(); + loader + .assert_eq("", &poseidon(loader, &inputs), &preprocessed_digest) + .unwrap(); + protocol + } else { + snark.protocol.loaded(loader) + }; + + let instances = snark + .instances + .iter() + .map(|instances| { + instances + .iter() + .map(|instance| loader.assign_scalar(*instance)) + .collect_vec() + }) + .collect_vec(); + let mut transcript = PoseidonTranscript::, _>::new(loader, snark.proof()); + let proof = Plonk::::read_proof(svk, &protocol, &instances, &mut transcript); + let accumulators = Plonk::::succinct_verify(svk, &protocol, &instances, &proof); + + ( + instances + .into_iter() + .map(|instance| { + instance + .into_iter() + .map(|instance| instance.into_assigned()) + .collect() + }) + .collect(), + accumulators, + ) +} diff --git a/aggregator/src/recursion/config.rs b/aggregator/src/recursion/config.rs new file mode 100644 index 0000000000..68f55ce55e --- /dev/null +++ b/aggregator/src/recursion/config.rs @@ -0,0 +1,64 @@ +use halo2_proofs::plonk::{Column, Instance}; +use snark_verifier::loader::halo2::halo2_ecc::{ + ecc::{BaseFieldEccChip, EccChip}, + fields::fp::FpConfig, + halo2_base::gates::{flex_gate::FlexGateConfig, range::RangeConfig}, +}; + +use crate::param::ConfigParams as RecursionCircuitConfigParams; + +use super::*; + +#[derive(Clone)] +pub struct RecursionConfig { + /// The non-native field arithmetic config from halo2-lib. + pub base_field_config: FpConfig, + /// The single instance column to hold the public input to the [`RecursionCircuit`]. 
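+    ///
+    /// Per `RecursionCircuit::new`, its cells are laid out as
+    /// `[..lhs, ..rhs, preprocessed_digest, initial_state, state, round]`:
+    /// `4 * LIMBS` accumulator limbs, then the preprocessed digest, the previous and
+    /// current states, and the round counter.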
+    pub instance: Column<Instance>,
+}
+
+impl RecursionConfig {
+    pub fn configure(
+        meta: &mut ConstraintSystem<Fr>,
+        params: RecursionCircuitConfigParams,
+    ) -> Self {
+        assert!(
+            params.limb_bits == BITS && params.num_limbs == LIMBS,
+            "For now we fix limb_bits = {}, otherwise change code",
+            BITS
+        );
+        let base_field_config = FpConfig::configure(
+            meta,
+            params.strategy,
+            &params.num_advice,
+            &params.num_lookup_advice,
+            params.num_fixed,
+            params.lookup_bits,
+            params.limb_bits,
+            params.num_limbs,
+            halo2_base::utils::modulus::<Fq>(),
+            0,
+            params.degree as usize,
+        );
+
+        let instance = meta.instance_column();
+        meta.enable_equality(instance);
+
+        Self {
+            base_field_config,
+            instance,
+        }
+    }
+
+    pub fn gate(&self) -> &FlexGateConfig<Fr> {
+        &self.base_field_config.range.gate
+    }
+
+    pub fn range(&self) -> &RangeConfig<Fr> {
+        &self.base_field_config.range
+    }
+
+    pub fn ecc_chip(&self) -> BaseFieldEccChip<G1Affine> {
+        EccChip::construct(self.base_field_config.clone())
+    }
+}
diff --git a/aggregator/src/recursion/util.rs b/aggregator/src/recursion/util.rs
new file mode 100644
index 0000000000..14e0203af0
--- /dev/null
+++ b/aggregator/src/recursion/util.rs
@@ -0,0 +1,187 @@
+use std::path::Path;
+
+use halo2_proofs::{
+    circuit::Layouter,
+    plonk::keygen_vk,
+    poly::{commitment::ParamsProver, kzg::commitment::ParamsKZG},
+};
+use snark_verifier::{
+    pcs::kzg::{Bdfg21, Kzg},
+    util::{arithmetic::fe_to_limbs, transcript::TranscriptWrite},
+};
+use snark_verifier_sdk::{gen_pk, CircuitExt, Snark};
+
+use super::*;
+
+mod dummy_circuit {
+    use super::*;
+    use std::marker::PhantomData;
+
+    pub struct CsProxy<F, C>(PhantomData<(F, C)>);
+
+    impl<F, C> Default for CsProxy<F, C> {
+        fn default() -> Self {
+            Self(Default::default())
+        }
+    }
+
+    impl<F: Field, C: CircuitExt<F>> Circuit<F> for CsProxy<F, C> {
+        type Config = C::Config;
+        type FloorPlanner = C::FloorPlanner;
+        #[cfg(feature = "circuit-params")]
+        type Params = ();
+
+        fn without_witnesses(&self) -> Self {
+            CsProxy(PhantomData)
+        }
+
+        fn configure(meta: &mut ConstraintSystem<F>) -> Self::Config {
+            C::configure(meta)
+        }
+
+        fn synthesize(
+            &self,
+            config: Self::Config,
+            mut layouter: impl Layouter<F>,
+        ) -> Result<(), Error> {
+            // When `C` has simple selectors, we tell `CsProxy` not to over-optimize
+            // the selectors (e.g., by compressing them all into one) by turning all
+            // selectors on in the first row. Currently this only works if all simple
+            // selector columns are used in the actual circuit and there are overlaps
+            // amongst all enabled selectors (i.e., the actual circuit will not
+            // optimize the constraint system further).
+            layouter.assign_region(
+                || "proxy constraint system",
+                |mut region| {
+                    for q in C::selectors(&config).iter() {
+                        q.enable(&mut region, 0)?;
+                    }
+
+                    Ok(())
+                },
+            )?;
+            Ok(())
+        }
+    }
+}
+
+/// Generate a "dummy" snark in case we need to "skip" the verify part
+/// inside the recursive circuit: the cost would be high if we applied conditional
+/// selection over the verify circuits (it is in fact an ecc chip, and
+/// selection increases the maximum degree by 1).
+///
+/// Instead, a "dummy" snark ensures the ecc chip is valid on the provided
+/// witness while we simply skip the output accumulator later. It can "mock" any
+/// circuit (with the vk provided as an argument) specified by ConcreteCircuit.
+fn gen_dummy_snark<ConcreteCircuit: CircuitExt<Fr>>(
+    params: &ParamsKZG<Bn256>,
+    vk: &VerifyingKey<G1Affine>,
+    num_instance: &[usize],
+    mut rng: impl Rng + Send,
+) -> Snark {
+    use snark_verifier::cost::CostEstimation;
+    use std::iter;
+    type Pcs = Kzg<Bn256, Bdfg21>;
+
+    let protocol = compile(
+        params,
+        vk,
+        Config::kzg()
+            .with_num_instance(Vec::from(num_instance))
+            .with_accumulator_indices(ConcreteCircuit::accumulator_indices()),
+    );
+    let instances = num_instance
+        .iter()
+        .map(|&n| iter::repeat_with(|| Fr::random(&mut rng)).take(n).collect())
+        .collect();
+    let proof = {
+        let mut transcript = PoseidonTranscript::<NativeLoader, _>::new(Vec::new());
+        for _ in 0..protocol
+            .num_witness
+            .iter()
+            .chain(Some(&protocol.quotient.num_chunk()))
+            .sum::<usize>()
+        {
+            transcript
+                .write_ec_point(G1Affine::random(&mut rng))
+                .unwrap();
+        }
+        for _ in 0..protocol.evaluations.len() {
+            transcript.write_scalar(Fr::random(&mut rng)).unwrap();
+        }
+        let queries = PlonkProof::<G1Affine, NativeLoader, Pcs>::empty_queries(&protocol);
+        for _ in 0..Pcs::estimate_cost(&queries).num_commitment {
+            transcript
+                .write_ec_point(G1Affine::random(&mut rng))
+                .unwrap();
+        }
+        transcript.finalize()
+    };
+
+    Snark::new(protocol, instances, proof)
+}
+
+/// Generate a dummy snark to construct the first recursion snark.
+/// We should allow it to be generated even without the corresponding
+/// vk, which is required when constructing a circuit to generate the pk.
pub fn initial_recursion_snark<ST: StateTransition>(
+    params: &ParamsKZG<Bn256>,
+    recursion_vk: Option<&VerifyingKey<G1Affine>>,
+    mut rng: impl Rng + Send,
+) -> Snark {
+    let mut snark = if let Some(vk) = recursion_vk {
+        gen_dummy_snark::<RecursionCircuit<ST>>(
+            params,
+            vk,
+            &[RecursionCircuit::<ST>::num_instance_fixed()],
+            &mut rng,
+        )
+    } else {
+        // To generate the pk we need to construct a recursion circuit,
+        // which requires another snark built from itself (and hence needs
+        // a pk). To break this cycle we use a "dummy" circuit for
+        // generating the snark.
+        let vk = &keygen_vk(
+            params,
+            &dummy_circuit::CsProxy::<Fr, RecursionCircuit<ST>>::default(),
+        )
+        .unwrap();
+        gen_dummy_snark::<RecursionCircuit<ST>>(
+            params,
+            vk,
+            &[RecursionCircuit::<ST>::num_instance_fixed()],
+            &mut rng,
+        )
+    };
+
+    let g = params.get_g();
+    // The accumulator must be set to the initial state so that the first "real"
+    // recursion circuit (which also merges the accumulator from this snark)
+    // can start with a correct accumulator state.
+    snark.instances = vec![[g[1].x, g[1].y, g[0].x, g[0].y]
+        .into_iter()
+        .flat_map(fe_to_limbs::<_, _, LIMBS, BITS>)
+        .chain(std::iter::repeat(Fr::ZERO))
+        .take(RecursionCircuit::<ST>::num_instance_fixed())
+        .collect_vec()];
+
+    snark
+}
+
+/// Generate the proving key for recursion.
+pub fn gen_recursion_pk<ST: StateTransition>(
+    recursion_params: &ParamsKZG<Bn256>,
+    app_params: &ParamsKZG<Bn256>,
+    app_vk: &VerifyingKey<G1Affine>,
+    mut rng: impl Rng + Send,
+    path: Option<&Path>,
+) -> ProvingKey<G1Affine> {
+    let app_snark =
+        gen_dummy_snark::<ST::Circuit>(app_params, app_vk, &[ST::num_instance()], &mut rng);
+
+    let recursive_snark = initial_recursion_snark::<ST>(recursion_params, None, &mut rng);
+
+    let recursion =
+        RecursionCircuit::<ST>::new(recursion_params, app_snark, recursive_snark, &mut rng, 0);
+    gen_pk(recursion_params, &recursion, path)
+}
diff --git a/aggregator/src/tests.rs b/aggregator/src/tests.rs
index 88fdb10710..deeb6f0ab0 100644
--- a/aggregator/src/tests.rs
+++ b/aggregator/src/tests.rs
@@ -2,6 +2,7 @@ mod aggregation;
 mod blob;
 mod compression;
 mod mock_chunk;
+mod recursion;
 mod rlc;
 
 #[macro_export]
macro_rules!
layer_0 { ); log::trace!("finished layer 0 pk generation for circuit"); - let snark = - gen_snark_shplonk(¶m, &pk, $circuit.clone(), &mut rng, None::).unwrap(); + let snark = gen_snark_shplonk(¶m, &pk, $circuit.clone(), &mut rng, None::) + .expect("Snark generated successfully"); log::trace!("finished layer 0 snark generation for circuit"); assert!(verify_snark_shplonk::<$circuit_type>( @@ -76,7 +77,7 @@ macro_rules! compression_layer_snark { &mut rng, None::, // Some(&$path.join(Path::new("layer_1.snark"))), ) - .unwrap(); + .expect("Snark generated successfully"); log::trace!( "finished layer {} snark generation for circuit", $layer_index @@ -159,19 +160,19 @@ macro_rules! aggregation_layer_snark { let mut rng = test_rng(); - let aggregation_circuit = AggregationCircuit::new( + let batch_circuit = BatchCircuit::new( &$param, $previous_snarks.as_ref(), &mut rng, $chunks.as_ref(), ); - let pk = gen_pk(&$param, &aggregation_circuit, None); + let pk = gen_pk(&$param, &batch_circuit, None); // build the snark for next layer let snark = gen_snark_shplonk( ¶m, &pk, - aggregation_circuit.clone(), + batch_circuit.clone(), &mut rng, None::, // Some(&$path.join(Path::new("layer_3.snark"))), ); @@ -180,7 +181,7 @@ macro_rules! aggregation_layer_snark { $layer_index ); - assert!(verify_snark_shplonk::( + assert!(verify_snark_shplonk::( ¶m, snark.clone(), pk.get_vk() diff --git a/aggregator/src/tests/aggregation.rs b/aggregator/src/tests/aggregation.rs index b1335ec9de..7167d3de76 100644 --- a/aggregator/src/tests/aggregation.rs +++ b/aggregator/src/tests/aggregation.rs @@ -7,18 +7,22 @@ use snark_verifier::loader::halo2::halo2_ecc::halo2_base::utils::fs::gen_srs; use snark_verifier_sdk::{gen_pk, gen_snark_shplonk, verify_snark_shplonk, CircuitExt}; use crate::{ - aggregation::AggregationCircuit, batch::BatchHash, constants::MAX_AGG_SNARKS, layer_0, - tests::mock_chunk::MockChunkCircuit, ChunkInfo, + aggregation::BatchCircuit, + batch::{BatchHash, BatchHeader}, + constants::MAX_AGG_SNARKS, + layer_0, + tests::mock_chunk::MockChunkCircuit, + ChunkInfo, }; // See https://github.com/scroll-tech/zkevm-circuits/pull/1311#issuecomment-2139559866 #[ignore] #[test] -fn test_max_agg_snarks_aggregation_circuit() { +fn test_max_agg_snarks_batch_circuit() { let k = 21; // This set up requires one round of keccak for chunk's data hash - let circuit: AggregationCircuit = build_new_aggregation_circuit(2, k); + let circuit: BatchCircuit = build_new_batch_circuit(2, k); let instance = circuit.instances(); let mock_prover = MockProver::::run(k, &circuit, instance).unwrap(); mock_prover.assert_satisfied_par(); @@ -26,10 +30,10 @@ fn test_max_agg_snarks_aggregation_circuit() { #[ignore] #[test] -fn test_2_snark_aggregation_circuit() { +fn test_2_snark_batch_circuit() { let k = 21; - let circuit: AggregationCircuit<2> = build_new_aggregation_circuit(1, k); + let circuit: BatchCircuit<2> = build_new_batch_circuit(1, k); let instance = circuit.instances(); let mock_prover = MockProver::::run(k, &circuit, instance).unwrap(); mock_prover.assert_satisfied_par(); @@ -37,10 +41,10 @@ fn test_2_snark_aggregation_circuit() { #[ignore] #[test] -fn test_14_snark_aggregation_circuit() { +fn test_14_snark_batch_circuit() { let k = 21; - let circuit: AggregationCircuit<14> = build_new_aggregation_circuit(12, k); + let circuit: BatchCircuit<14> = build_new_batch_circuit(12, k); let instance = circuit.instances(); let mock_prover = MockProver::::run(k, &circuit, instance).unwrap(); mock_prover.assert_satisfied_par(); @@ -48,7 +52,7 @@ fn 
test_14_snark_aggregation_circuit() { #[ignore = "it takes too much time"] #[test] -fn test_aggregation_circuit_all_possible_num_snarks() { +fn test_batch_circuit_all_possible_num_snarks() { //env_logger::init(); let k = 20; @@ -56,18 +60,18 @@ fn test_aggregation_circuit_all_possible_num_snarks() { for i in 1..=MAX_AGG_SNARKS { println!("{i} real chunks and {} padded chunks", MAX_AGG_SNARKS - i); // This set up requires one round of keccak for chunk's data hash - let circuit: AggregationCircuit = build_new_aggregation_circuit(i, k); + let circuit: BatchCircuit = build_new_batch_circuit(i, k); let instance = circuit.instances(); let mock_prover = MockProver::::run(k, &circuit, instance).unwrap(); mock_prover.assert_satisfied_par(); } } -/// - Test aggregation proof generation and verification. +/// - Test batch aggregation proof generation and verification. /// - Test a same pk can be used for various number of chunk proofs. #[ignore = "it takes too much time"] #[test] -fn test_aggregation_circuit_full() { +fn test_batch_circuit_full() { //env_logger::init(); let process_id = process::id(); let k = 25; @@ -77,7 +81,7 @@ fn test_aggregation_circuit_full() { fs::create_dir(path).unwrap(); // This set up requires one round of keccak for chunk's data hash - let circuit: AggregationCircuit = build_new_aggregation_circuit(2, k); + let circuit: BatchCircuit = build_new_batch_circuit(2, k); let instance = circuit.instances(); let mock_prover = MockProver::::run(k, &circuit, instance).unwrap(); mock_prover.assert_satisfied_par(); @@ -90,10 +94,11 @@ fn test_aggregation_circuit_full() { let pk = gen_pk(¶m, &circuit, None); log::trace!("finished pk generation for circuit"); - let snark = gen_snark_shplonk(¶m, &pk, circuit.clone(), &mut rng, None::).unwrap(); + let snark = gen_snark_shplonk(¶m, &pk, circuit.clone(), &mut rng, None::) + .expect("Snark generated successfully"); log::trace!("finished snark generation for circuit"); - assert!(verify_snark_shplonk::>( + assert!(verify_snark_shplonk::>( ¶m, snark, pk.get_vk() @@ -101,11 +106,12 @@ fn test_aggregation_circuit_full() { log::trace!("finished verification for circuit"); // This set up requires two rounds of keccak for chunk's data hash - let circuit: AggregationCircuit = build_new_aggregation_circuit(5, k); - let snark = gen_snark_shplonk(¶m, &pk, circuit, &mut rng, None::).unwrap(); + let circuit: BatchCircuit = build_new_batch_circuit(5, k); + let snark = gen_snark_shplonk(¶m, &pk, circuit, &mut rng, None::) + .expect("Snark generated successfully"); log::trace!("finished snark generation for circuit"); - assert!(verify_snark_shplonk::>( + assert!(verify_snark_shplonk::>( ¶m, snark, pk.get_vk() @@ -115,14 +121,14 @@ fn test_aggregation_circuit_full() { #[test] #[ignore = "it takes too much time"] -fn test_aggregation_circuit_variadic() { +fn test_batch_circuit_variadic() { let k = 20; - let circuit1: AggregationCircuit = build_new_aggregation_circuit(5, k); + let circuit1: BatchCircuit = build_new_batch_circuit(5, k); let instance1 = circuit1.instances(); let prover1 = MockProver::::run(k, &circuit1, instance1).unwrap(); - let circuit2: AggregationCircuit = build_new_aggregation_circuit(10, k); + let circuit2: BatchCircuit = build_new_batch_circuit(10, k); let instance2 = circuit2.instances(); let prover2 = MockProver::::run(k, &circuit2, instance2).unwrap(); @@ -130,10 +136,10 @@ fn test_aggregation_circuit_variadic() { assert_eq!(prover1.permutation(), prover2.permutation()); } -fn build_new_aggregation_circuit( +fn build_new_batch_circuit( 
num_real_chunks: usize, _k: u32, -) -> AggregationCircuit { +) -> BatchCircuit { // inner circuit: Mock circuit let k0 = 8; @@ -180,9 +186,9 @@ fn build_new_aggregation_circuit( // ========================== // batch // ========================== - let batch_hash = BatchHash::construct(&chunks_with_padding); + let batch_hash = BatchHash::construct(&chunks_with_padding, BatchHeader::default()); - AggregationCircuit::new( + BatchCircuit::new( ¶ms, [real_snarks, padded_snarks].concat().as_ref(), rng, diff --git a/aggregator/src/tests/recursion.rs b/aggregator/src/tests/recursion.rs new file mode 100644 index 0000000000..4881f1885d --- /dev/null +++ b/aggregator/src/tests/recursion.rs @@ -0,0 +1,400 @@ +use halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner}, + halo2curves::bn256::Fr, + plonk::{Advice, Circuit, Column, ConstraintSystem, Error, Instance, Selector}, + poly::Rotation, + SerdeFormat, +}; +use snark_verifier::loader::halo2::halo2_ecc::halo2_base as sv_halo2_base; +use snark_verifier_sdk::{gen_pk, gen_snark_shplonk, verify_snark_shplonk, CircuitExt, Snark}; +use std::fs; +use sv_halo2_base::utils::fs::gen_srs; + +use crate::{param::ConfigParams as AggregationConfigParams, recursion::*}; +use ark_std::{end_timer, start_timer, test_rng}; + +fn test_recursion_impl(app_degree: u32, init_state: Fr) -> Snark +where + App: CircuitExt + StateTransition, +{ + let app_params = gen_srs(app_degree); + let recursion_config: AggregationConfigParams = + serde_json::from_reader(fs::File::open("configs/bundle_circuit.config").unwrap()).unwrap(); + let k = recursion_config.degree; + let recursion_params = gen_srs(k); + + let app = App::new(Default::default()); + let app_pk = gen_pk(&app_params, &app, None); + let mut rng = test_rng(); + + let pk_time = start_timer!(|| "Generate recursion pk"); + // this is the pk from default app and dummy self-snark + let recursion_pk = gen_recursion_pk::( + &recursion_params, + &app_params, + app_pk.get_vk(), + &mut rng, + None, + ); + end_timer!(pk_time); + + let app = App::new(init_state); + let next_state = app.state_transition(0); + let app_snark = gen_snark_shplonk(&app_params, &app_pk, app, &mut rng, None::) + .expect("Snark generated successfully"); + let init_snark = + initial_recursion_snark::(&recursion_params, Some(recursion_pk.get_vk()), &mut rng); + + let recursion = + RecursionCircuit::::new(&recursion_params, app_snark, init_snark, &mut rng, 0); + + let pk_time = start_timer!(|| "Generate secondary recursion pk for test"); + { + let r_pk_2 = gen_pk(&recursion_params, &recursion, None); + assert_eq!( + r_pk_2.get_vk().to_bytes(SerdeFormat::RawBytesUnchecked), + recursion_pk + .get_vk() + .to_bytes(SerdeFormat::RawBytesUnchecked), + ); + } + end_timer!(pk_time); + + let pf_time = start_timer!(|| "Generate first recursive snark"); + + let snark = gen_snark_shplonk( + &recursion_params, + &recursion_pk, + recursion, + &mut rng, + None::, + ) + .expect("Snark generated successfully"); + + end_timer!(pf_time); + //assert_eq!(final_state, Fr::from(2u64).pow(&[1 << num_round, 0, 0, 0])); + + assert!(verify_snark_shplonk::>( + &recursion_params, + snark.clone(), + recursion_pk.get_vk() + )); + + let app = App::new(next_state); + let app_snark = gen_snark_shplonk(&app_params, &app_pk, app, &mut rng, None::) + .expect("Snark generated successfully"); + + let recursion = + RecursionCircuit::::new(&recursion_params, app_snark, snark, test_rng(), 1); + + let pk_time = start_timer!(|| "Generate third recursion pk for test"); + { + let r_pk_3 = 
gen_pk(&recursion_params, &recursion, None); + assert_eq!( + r_pk_3.get_vk().to_bytes(SerdeFormat::RawBytesUnchecked), + recursion_pk + .get_vk() + .to_bytes(SerdeFormat::RawBytesUnchecked), + ); + } + end_timer!(pk_time); + + let pf_time = start_timer!(|| "Generate next recursive snark"); + + let snark = gen_snark_shplonk( + &recursion_params, + &recursion_pk, + recursion, + &mut rng, + None::, + ) + .expect("Snark generated successfully"); + + end_timer!(pf_time); + + assert!(verify_snark_shplonk::>( + &recursion_params, + snark.clone(), + recursion_pk.get_vk() + )); + + snark +} + +mod app { + use super::*; + + #[derive(Clone, Default)] + struct Square(Fr); + + impl Circuit for Square { + type Config = Selector; + type FloorPlanner = SimpleFloorPlanner; + #[cfg(feature = "circuit-params")] + type Params = (); + + fn without_witnesses(&self) -> Self { + Self::default() + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + let q = meta.selector(); + let i = meta.instance_column(); + meta.create_gate("square", |meta| { + let q = meta.query_selector(q); + let [i, i_w] = [0, 1].map(|rotation| meta.query_instance(i, Rotation(rotation))); + Some(q * (i.clone() * i - i_w)) + }); + q + } + + fn synthesize( + &self, + q: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter.assign_region(|| "", |mut region| q.enable(&mut region, 0)) + } + } + + impl CircuitExt for Square { + fn num_instance(&self) -> Vec { + vec![2] + } + + fn instances(&self) -> Vec> { + vec![vec![self.0, self.0.square()]] + } + } + + impl StateTransition for Square { + type Input = Fr; + type Circuit = Self; + + fn new(state: Self::Input) -> Self { + Self(state) + } + + fn num_transition_instance() -> usize { + 1 + } + + fn state_transition(&self, _: usize) -> Self::Input { + self.0.square() + } + } + + #[derive(Clone, Default)] + struct SquareBundle(Fr); + + impl StateTransition for SquareBundle { + type Input = Fr; + type Circuit = RecursionCircuit; + + fn new(state: Self::Input) -> Self { + Self(state) + } + + fn num_transition_instance() -> usize { + Square::num_transition_instance() + } + + fn state_transition(&self, _: usize) -> Self::Input { + self.0.square().square() + } + + fn num_additional_instance() -> usize { + 2 + } + + fn state_indices() -> Vec { + let beg = 13 + Self::num_transition_instance(); + (beg..beg + Self::num_transition_instance()).collect() + } + + fn state_prev_indices() -> Vec { + (13..13 + Self::num_transition_instance()).collect() + } + + fn additional_indices() -> Vec { + vec![12, 13 + Self::num_transition_instance() * 2] + } + } + + #[test] + fn test_recursion_circuit() { + test_recursion_impl::(3, Fr::from(2u64)); + } + + #[test] + fn test_recursion_agg_circuit() { + let square_snark1 = test_recursion_impl::(3, Fr::from(2u64)); + let square_snark2 = test_recursion_impl::(3, Fr::from(16u64)); + + let recursion_config: AggregationConfigParams = + serde_json::from_reader(fs::File::open("configs/bundle_circuit.config").unwrap()) + .unwrap(); + let k = recursion_config.degree; + let recursion_params = gen_srs(k); + let mut rng = test_rng(); + + let pk_time = start_timer!(|| "Generate agg recursion pk"); + let recursion_for_pk = RecursionCircuit::::new( + &recursion_params, + square_snark1.clone(), + initial_recursion_snark::(&recursion_params, None, &mut rng), + &mut rng, + 0, + ); + let recursion_pk = gen_pk(&recursion_params, &recursion_for_pk, None); + end_timer!(pk_time); + + let init_snark = initial_recursion_snark::( + &recursion_params, + 
Some(recursion_pk.get_vk()), + &mut rng, + ); + + let pf_time = start_timer!(|| "Generate first recursive snark"); + let recursion = RecursionCircuit::::new( + &recursion_params, + square_snark1, + init_snark, + &mut rng, + 0, + ); + + let snark = gen_snark_shplonk( + &recursion_params, + &recursion_pk, + recursion, + &mut rng, + None::, + ) + .expect("Snark generated successfully"); + + end_timer!(pf_time); + + let pf_time = start_timer!(|| "Generate second recursive snark"); + let recursion = RecursionCircuit::::new( + &recursion_params, + square_snark2, + snark, + &mut rng, + 1, + ); + + let snark = gen_snark_shplonk( + &recursion_params, + &recursion_pk, + recursion, + &mut rng, + None::, + ) + .expect("Snark generated successfully"); + end_timer!(pf_time); + + assert!(verify_snark_shplonk::>( + &recursion_params, + snark.clone(), + recursion_pk.get_vk() + )); + } +} + +mod app_add_inst { + use super::*; + + #[derive(Clone, Default)] + struct Square(Fr); + + impl Circuit for Square { + type Config = (Selector, Column, Column); + type FloorPlanner = SimpleFloorPlanner; + #[cfg(feature = "circuit-params")] + type Params = (); + + fn without_witnesses(&self) -> Self { + Self::default() + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + let q = meta.selector(); + let i = meta.instance_column(); + meta.create_gate("square", |meta| { + let q = meta.query_selector(q); + let [i, i_w] = [0, 1].map(|rotation| meta.query_instance(i, Rotation(rotation))); + Some(q * (i.clone() * i - i_w)) + }); + let s = meta.fixed_column(); + meta.enable_constant(s); + let a = meta.advice_column(); + meta.enable_equality(a); + meta.enable_equality(i); + (q, a, i) + } + + fn synthesize( + &self, + (q, a, i): Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter.assign_region( + || "", + |mut region| { + q.enable(&mut region, 0)?; + q.enable(&mut region, 1)?; + region.assign_advice_from_instance(|| "copy inst (3)", i, 3, a, 0)?; + region.assign_advice_from_constant( + || "fix inst to 42", + a, + 0, + Fr::from(42u64), + )?; + Ok(()) + }, + ) + } + } + + impl CircuitExt for Square { + fn num_instance(&self) -> Vec { + vec![4] + } + + fn instances(&self) -> Vec> { + vec![vec![ + self.0, + self.0.square(), + self.0.square().square(), + Fr::from(42u64), + ]] + } + } + + impl StateTransition for Square { + type Input = Fr; + type Circuit = Self; + + fn new(state: Self::Input) -> Self { + Self(state) + } + + fn num_additional_instance() -> usize { + 2 + } + + fn num_transition_instance() -> usize { + 1 + } + + fn state_transition(&self, _: usize) -> Self::Input { + self.0.square() + } + } + + #[test] + fn test_recursion_circuit() { + test_recursion_impl::(4, Fr::from(2u64)); + } +} diff --git a/aggregator/src/tests/rlc/dynamic_hashes.rs b/aggregator/src/tests/rlc/dynamic_hashes.rs index 6856c6a795..2cf70b41cc 100644 --- a/aggregator/src/tests/rlc/dynamic_hashes.rs +++ b/aggregator/src/tests/rlc/dynamic_hashes.rs @@ -43,7 +43,7 @@ impl Circuit for DynamicHashCircuit { let challenges = Challenges::construct_p1(meta); // hash config - // hash configuration for aggregation circuit + // hash configuration for batch circuit let keccak_circuit_config = { let keccak_table = KeccakTable::construct(meta); let challenges_exprs = challenges.exprs(meta); @@ -221,7 +221,8 @@ fn test_dynamic_hash_circuit() { // pk verifies the original circuit { - let snark = gen_snark_shplonk(¶ms, &pk, circuit, &mut rng, None::).unwrap(); + let snark = gen_snark_shplonk(¶ms, &pk, circuit, &mut rng, 
None::) + .expect("Snark generated successfully"); assert!(verify_snark_shplonk::( ¶ms, snark, @@ -234,7 +235,8 @@ fn test_dynamic_hash_circuit() { let a: Vec = (0..LEN * 3).map(|x| x as u8).collect::>(); let circuit = DynamicHashCircuit { inputs: a }; - let snark = gen_snark_shplonk(¶ms, &pk, circuit, &mut rng, None::).unwrap(); + let snark = gen_snark_shplonk(¶ms, &pk, circuit, &mut rng, None::) + .expect("Snark generated successfully"); assert!(verify_snark_shplonk::( ¶ms, snark, diff --git a/aggregator/src/util.rs b/aggregator/src/util.rs index dff5c854e0..44f905c37e 100644 --- a/aggregator/src/util.rs +++ b/aggregator/src/util.rs @@ -8,22 +8,6 @@ fn init_env_logger() { env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("error")).init(); } -#[inline] -// assert two cells have same value -// (NOT constraining equality in circuit) -pub(crate) fn assert_equal( - a: &AssignedCell, - b: &AssignedCell, - description: &str, -) -> Result<(), Error> { - a.value().zip(b.value()).error_if_known_and(|(&a, &b)| { - if a != b { - log::error!("{description}"); - } - a != b - }) -} - #[inline] // if cond = 1, assert two cells have same value; // (NOT constraining equality in circuit) @@ -55,7 +39,7 @@ pub(crate) fn parse_hash_preimage_cells( // each pi hash has INPUT_LEN_PER_ROUND bytes as input // keccak will pad the input with another INPUT_LEN_PER_ROUND bytes // we extract all those bytes - let batch_pi_hash_preimage = &hash_input_cells[0]; + let batch_hash_preimage = &hash_input_cells[0]; let mut chunk_pi_hash_preimages = vec![]; for i in 0..N_SNARKS { chunk_pi_hash_preimages.push(&hash_input_cells[i + 1]); @@ -63,7 +47,7 @@ pub(crate) fn parse_hash_preimage_cells( let batch_data_hash_preimage = hash_input_cells.last().unwrap(); ( - batch_pi_hash_preimage, + batch_hash_preimage, chunk_pi_hash_preimages, batch_data_hash_preimage, ) @@ -78,14 +62,14 @@ pub(crate) fn parse_hash_digest_cells( Vec<&Vec>>, &[AssignedCell], ) { - let batch_pi_hash_digest = &hash_output_cells[0]; + let batch_hash_digest = &hash_output_cells[0]; let mut chunk_pi_hash_digests = vec![]; for i in 0..N_SNARKS { chunk_pi_hash_digests.push(&hash_output_cells[i + 1]); } let batch_data_hash_digest = &hash_output_cells[N_SNARKS + 1]; ( - batch_pi_hash_digest, + batch_hash_digest, chunk_pi_hash_digests, batch_data_hash_digest, ) diff --git a/gadgets/src/util.rs b/gadgets/src/util.rs index 50726faa9f..a0335437e6 100644 --- a/gadgets/src/util.rs +++ b/gadgets/src/util.rs @@ -2,7 +2,7 @@ use crate::Field; use eth_types::{ evm_types::{GasCost, OpcodeId}, - U256, + H256, U256, }; use halo2_proofs::plonk::Expression; @@ -246,3 +246,60 @@ pub fn split_u256_limb64(value: &U256) -> [U256; 4] { U256([value.0[3], 0, 0, 0]), ] } + +/// Split a 32-bytes hash into (hi, lo) Field elements. 
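+///
+/// `hi` is the big-endian value of bytes `0..16` and `lo` of bytes `16..32`, so the
+/// original digest satisfies `value = hi * 2^128 + lo`. For example (mirroring the
+/// unit test below), an `H256` with `bytes[0] = 0x01` and `bytes[31] = 0x01` splits
+/// into `hi = 2^120` and `lo = 1`.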
+pub fn split_h256(value: H256) -> (F, F) { + let be_bytes = value.to_fixed_bytes(); + let mut hi_le_bytes = [0u8; 32]; + let mut lo_le_bytes = [0u8; 32]; + hi_le_bytes[0x10..0x20].copy_from_slice(&be_bytes[0x00..0x10]); + lo_le_bytes[0x10..0x20].copy_from_slice(&be_bytes[0x10..0x20]); + hi_le_bytes.reverse(); + lo_le_bytes.reverse(); + ( + F::from_repr(hi_le_bytes).expect("try F from 128-bits should not fail"), + F::from_repr(lo_le_bytes).expect("try F from 128-bits should not fail"), + ) +} + +#[cfg(test)] +mod tests { + use eth_types::H256; + use halo2_proofs::halo2curves::bn256::Fr; + + use super::split_h256; + + #[test] + fn test_split_h256() { + let zero = Fr::zero(); + let in_outs = [ + // all zeroes + (H256::zero(), zero, zero), + ( + H256([ + 0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0x01, + ]), + Fr::from_raw([0, 1 << 56 /* 256 ^ 7 */, 0, 0]), + Fr::from_raw([0x01, 0, 0, 0]), + ), + // 0xFB, 0xFC, 0, 0, 0, 0, 0, 0, + // 0, 0, 0, 0, 0, 0, 0xFD, 0xFE, + // 0x01, 0x02, 0, 0, 0, 0, 0, 0, + // 0, 0, 0, 0, 0, 0, 0x03, 0x04 + ( + H256([ + 0xfb, 0xfc, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xfd, 0xfe, 0x01, 0x02, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x03, 0x04, + ]), + Fr::from_raw([0xfd * 256 + 0xfe, 0xfb * (1 << 56) + 0xfc * (1 << 48), 0, 0]), + Fr::from_raw([0x03 * 256 + 0x04, (1 << 56) + 0x02 * (1 << 48), 0, 0]), + ), + ]; + for (hash_in, expected_hi, expected_lo) in in_outs { + let (hi, lo) = split_h256::(hash_in); + assert_eq!(hi, expected_hi); + assert_eq!(lo, expected_lo); + } + } +} diff --git a/prover/src/aggregator.rs b/prover/src/aggregator.rs index cee92f22aa..177c313c1a 100644 --- a/prover/src/aggregator.rs +++ b/prover/src/aggregator.rs @@ -2,5 +2,5 @@ mod prover; mod verifier; pub use self::prover::{check_chunk_hashes, Prover}; -pub use aggregator::{BatchData, BatchHash, MAX_AGG_SNARKS}; +pub use aggregator::{BatchData, BatchHash, BatchHeader, MAX_AGG_SNARKS}; pub use verifier::Verifier; diff --git a/prover/src/aggregator/prover.rs b/prover/src/aggregator/prover.rs index 16223d824c..33e8f820e8 100644 --- a/prover/src/aggregator/prover.rs +++ b/prover/src/aggregator/prover.rs @@ -1,37 +1,51 @@ +use std::{env, iter::repeat}; + +use aggregator::{BatchHash, BatchHeader, ChunkInfo, MAX_AGG_SNARKS}; +use anyhow::{bail, Result}; +use eth_types::H256; +use sha2::{Digest, Sha256}; +use snark_verifier_sdk::Snark; + use crate::{ common, config::{LayerId, AGG_DEGREES}, - consts::{AGG_KECCAK_ROW, AGG_VK_FILENAME, CHUNK_PROTOCOL_FILENAME}, + consts::{BATCH_KECCAK_ROW, BATCH_VK_FILENAME, BUNDLE_VK_FILENAME, CHUNK_PROTOCOL_FILENAME}, io::{force_to_read, try_to_read}, + proof::BundleProof, + types::BundleProvingTask, BatchProof, BatchProvingTask, ChunkProof, }; -use aggregator::{ChunkInfo, MAX_AGG_SNARKS}; -use anyhow::{bail, Result}; -use sha2::{Digest, Sha256}; -use snark_verifier_sdk::Snark; -use std::{env, iter::repeat}; #[derive(Debug)] pub struct Prover { // Make it public for testing with inner functions (unnecessary for FFI). 
pub prover_impl: common::Prover, pub chunk_protocol: Vec, - raw_vk: Option>, + raw_vk_batch: Option>, + raw_vk_bundle: Option>, } impl Prover { pub fn from_dirs(params_dir: &str, assets_dir: &str) -> Self { - log::debug!("set env KECCAK_ROWS={}", AGG_KECCAK_ROW.to_string()); - env::set_var("KECCAK_ROWS", AGG_KECCAK_ROW.to_string()); + log::debug!("set env KECCAK_ROWS={}", BATCH_KECCAK_ROW.to_string()); + env::set_var("KECCAK_ROWS", BATCH_KECCAK_ROW.to_string()); let prover_impl = common::Prover::from_params_dir(params_dir, &AGG_DEGREES); let chunk_protocol = force_to_read(assets_dir, &CHUNK_PROTOCOL_FILENAME); - let raw_vk = try_to_read(assets_dir, &AGG_VK_FILENAME); - if raw_vk.is_none() { + let raw_vk_batch = try_to_read(assets_dir, &BATCH_VK_FILENAME); + let raw_vk_bundle = try_to_read(assets_dir, &BUNDLE_VK_FILENAME); + if raw_vk_batch.is_none() { + log::warn!( + "batch-prover: {} doesn't exist in {}", + *BATCH_VK_FILENAME, + assets_dir + ); + } + if raw_vk_bundle.is_none() { log::warn!( - "agg-prover: {} doesn't exist in {}", - *AGG_VK_FILENAME, + "batch-prover: {} doesn't exist in {}", + *BUNDLE_VK_FILENAME, assets_dir ); } @@ -39,7 +53,8 @@ impl Prover { Self { prover_impl, chunk_protocol, - raw_vk, + raw_vk_batch, + raw_vk_bundle, } } @@ -60,14 +75,20 @@ impl Prover { }) } - pub fn get_vk(&self) -> Option> { + pub fn get_batch_vk(&self) -> Option> { self.prover_impl .raw_vk(LayerId::Layer4.id()) - .or_else(|| self.raw_vk.clone()) + .or_else(|| self.raw_vk_batch.clone()) + } + + pub fn get_bundle_vk(&self) -> Option> { + self.prover_impl + .raw_vk(LayerId::Layer6.id()) + .or_else(|| self.raw_vk_bundle.clone()) } // Return the EVM proof for verification. - pub fn gen_agg_evm_proof( + pub fn gen_batch_proof( &mut self, batch: BatchProvingTask, name: Option<&str>, @@ -75,10 +96,11 @@ impl Prover { ) -> Result { let name = name.map_or_else(|| batch.identifier(), |name| name.to_string()); - let layer3_snark = self.load_or_gen_last_agg_snark(&name, batch, output_dir)?; + let (layer3_snark, batch_hash) = + self.load_or_gen_last_agg_snark::(&name, batch, output_dir)?; // Load or generate final compression thin EVM proof (layer-4). - let evm_proof = self.prover_impl.load_or_gen_comp_evm_proof( + let layer4_snark = self.prover_impl.load_or_gen_comp_snark( &name, LayerId::Layer4.id(), true, @@ -88,9 +110,10 @@ impl Prover { )?; log::info!("Got final compression thin EVM proof (layer-4): {name}"); - self.check_vk(); + self.check_batch_vk(); - let batch_proof = BatchProof::from(evm_proof.proof); + let pk = self.prover_impl.pk(LayerId::Layer4.id()); + let batch_proof = BatchProof::new(layer4_snark, pk, batch_hash)?; if let Some(output_dir) = output_dir { batch_proof.dump(output_dir, "agg")?; } @@ -100,12 +123,12 @@ impl Prover { // Generate layer3 snark. // Then it could be used to generate a layer4 proof. - pub fn load_or_gen_last_agg_snark( + pub fn load_or_gen_last_agg_snark( &mut self, name: &str, batch: BatchProvingTask, output_dir: Option<&str>, - ) -> Result { + ) -> Result<(Snark, H256)> { let real_chunk_count = batch.chunk_proofs.len(); assert!((1..=MAX_AGG_SNARKS).contains(&real_chunk_count)); @@ -134,22 +157,105 @@ impl Prover { } // Load or generate aggregation snark (layer-3). 
+ let batch_header = BatchHeader::construct_from_chunks( + batch.batch_header.version, + batch.batch_header.batch_index, + batch.batch_header.l1_message_popped, + batch.batch_header.total_l1_message_popped, + batch.batch_header.parent_batch_hash, + batch.batch_header.last_block_timestamp, + &chunk_hashes, + ); + + // sanity check between: + // - BatchHeader supplied from infra + // - BatchHeader re-constructed by circuits + // + // for the fields data_hash, z, y, blob_versioned_hash. + assert_eq!( + batch_header.data_hash, batch.batch_header.data_hash, + "BatchHeader(sanity) mismatch data_hash expected={}, got={}", + batch.batch_header.data_hash, batch_header.data_hash + ); + assert_eq!( + batch_header.blob_data_proof[0], batch.batch_header.blob_data_proof[0], + "BatchHeader(sanity) mismatch blob data proof (z) expected={}, got={}", + batch.batch_header.blob_data_proof[0], batch_header.blob_data_proof[0], + ); + assert_eq!( + batch_header.blob_data_proof[1], batch.batch_header.blob_data_proof[1], + "BatchHeader(sanity) mismatch blob data proof (y) expected={}, got={}", + batch.batch_header.blob_data_proof[1], batch_header.blob_data_proof[1], + ); + assert_eq!( + batch_header.blob_versioned_hash, batch.batch_header.blob_versioned_hash, + "BatchHeader(sanity) mismatch blob versioned hash expected={}, got={}", + batch.batch_header.blob_versioned_hash, batch_header.blob_versioned_hash, + ); + + let batch_hash = batch_header.batch_hash(); + let batch_info: BatchHash = BatchHash::construct(&chunk_hashes, batch_header); + let layer3_snark = self.prover_impl.load_or_gen_agg_snark( name, LayerId::Layer3.id(), LayerId::Layer3.degree(), - &chunk_hashes, + batch_info, &layer2_snarks, output_dir, )?; log::info!("Got aggregation snark (layer-3): {name}"); - Ok(layer3_snark) + Ok((layer3_snark, batch_hash)) + } + + // Given a bundle proving task that consists of a list of batch proofs for all intermediate + // batches, bundles them into a single bundle proof using the RecursionCircuit, effectively + // proving the validity of all those batches. 
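+    //
+    // A hypothetical usage sketch (task construction elided; names follow this module):
+    //
+    //     let batch_proof = prover.gen_batch_proof(batch_task, None, Some(out_dir))?;
+    //     // ... collect batch proofs into a BundleProvingTask ...
+    //     let bundle_proof = prover.gen_bundle_proof(bundle_task, None, Some(out_dir))?;
+    //     assert!(verifier.verify_bundle_proof(bundle_proof));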
+ pub fn gen_bundle_proof( + &mut self, + bundle: BundleProvingTask, + name: Option<&str>, + output_dir: Option<&str>, + ) -> Result { + let name = name.map_or_else(|| bundle.identifier(), |name| name.to_string()); + + let bundle_snarks = bundle + .batch_proofs + .iter() + .map(|proof| proof.into()) + .collect::>(); + + let layer5_snark = self.prover_impl.load_or_gen_recursion_snark( + &name, + LayerId::Layer5.id(), + LayerId::Layer5.degree(), + &bundle_snarks, + output_dir, + )?; + + let layer6_evm_proof = self.prover_impl.load_or_gen_comp_evm_proof( + &name, + LayerId::Layer6.id(), + true, + LayerId::Layer6.degree(), + layer5_snark, + output_dir, + )?; + + self.check_bundle_vk(); + + let bundle_proof: BundleProof = layer6_evm_proof.proof.into(); + if let Some(output_dir) = output_dir { + bundle_proof.dump(output_dir, "recursion")?; + } + + Ok(bundle_proof) } /// Check vk generated is same with vk loaded from assets - fn check_vk(&self) { - if self.raw_vk.is_some() { + fn check_batch_vk(&self) { + if self.raw_vk_batch.is_some() { let gen_vk = self .prover_impl .raw_vk(LayerId::Layer4.id()) @@ -158,10 +264,32 @@ impl Prover { log::warn!("no gen_vk found, skip check_vk"); return; } - let init_vk = self.raw_vk.clone().unwrap_or_default(); + let init_vk = self.raw_vk_batch.clone().unwrap_or_default(); + if gen_vk != init_vk { + log::error!( + "batch-prover: generated VK is different with init one - gen_vk = {}, init_vk = {}", + base64::encode(gen_vk), + base64::encode(init_vk), + ); + } + } + } + + /// Check vk generated is same with vk loaded from assets + fn check_bundle_vk(&self) { + if self.raw_vk_bundle.is_some() { + let gen_vk = self + .prover_impl + .raw_vk(LayerId::Layer6.id()) + .unwrap_or_default(); + if gen_vk.is_empty() { + log::warn!("no gen_vk found, skip check_vk"); + return; + } + let init_vk = self.raw_vk_bundle.clone().unwrap_or_default(); if gen_vk != init_vk { log::error!( - "agg-prover: generated VK is different with init one - gen_vk = {}, init_vk = {}", + "bundle-prover: generated VK is different with init one - gen_vk = {}, init_vk = {}", base64::encode(gen_vk), base64::encode(init_vk), ); diff --git a/prover/src/aggregator/verifier.rs b/prover/src/aggregator/verifier.rs index 4dd6622983..b4c55bba48 100644 --- a/prover/src/aggregator/verifier.rs +++ b/prover/src/aggregator/verifier.rs @@ -1,9 +1,9 @@ use crate::{ common, config::{LAYER4_CONFIG_PATH, LAYER4_DEGREE}, - consts::{agg_vk_filename, DEPLOYMENT_CODE_FILENAME}, - io::force_to_read, - BatchProof, + consts::{batch_vk_filename, DEPLOYMENT_CODE_FILENAME}, + io::{force_to_read, try_to_read}, + proof::BundleProof, }; use aggregator::CompressionCircuit; use halo2_proofs::{ @@ -12,13 +12,14 @@ use halo2_proofs::{ poly::kzg::commitment::ParamsKZG, }; use snark_verifier_sdk::verify_evm_calldata; +use snark_verifier_sdk::Snark; use std::env; #[derive(Debug)] pub struct Verifier { // Make it public for testing with inner functions (unnecessary for FFI). 
pub inner: common::Verifier, - deployment_code: Vec, + deployment_code: Option>, } impl Verifier { @@ -31,13 +32,13 @@ impl Verifier { Self { inner, - deployment_code, + deployment_code: Some(deployment_code), } } pub fn from_dirs(params_dir: &str, assets_dir: &str) -> Self { - let raw_vk = force_to_read(assets_dir, &agg_vk_filename()); - let deployment_code = force_to_read(assets_dir, &DEPLOYMENT_CODE_FILENAME); + let raw_vk = force_to_read(assets_dir, &batch_vk_filename()); + let deployment_code = try_to_read(assets_dir, &DEPLOYMENT_CODE_FILENAME); env::set_var("COMPRESSION_CONFIG", &*LAYER4_CONFIG_PATH); let inner = common::Verifier::from_params_dir(params_dir, *LAYER4_DEGREE, &raw_vk); @@ -48,7 +49,16 @@ impl Verifier { } } - pub fn verify_agg_evm_proof(&self, batch_proof: BatchProof) -> bool { - verify_evm_calldata(self.deployment_code.clone(), batch_proof.calldata()) + pub fn verify_batch_proof(&self, snark: impl Into) -> bool { + self.inner.verify_snark(snark.into()) + } + + pub fn verify_bundle_proof(&self, bundle_proof: BundleProof) -> bool { + if let Some(deployment_code) = self.deployment_code.clone() { + verify_evm_calldata(deployment_code, bundle_proof.calldata()) + } else { + log::warn!("No deployment_code found for EVM verifier"); + false + } } } diff --git a/prover/src/common/prover.rs b/prover/src/common/prover.rs index 78310c374e..bd7dd5916a 100644 --- a/prover/src/common/prover.rs +++ b/prover/src/common/prover.rs @@ -12,6 +12,7 @@ mod compression; mod evm; mod inner; mod mock; +mod recursion; mod utils; #[derive(Debug)] diff --git a/prover/src/common/prover/aggregation.rs b/prover/src/common/prover/aggregation.rs index 2bf3cdab53..1384d02039 100644 --- a/prover/src/common/prover/aggregation.rs +++ b/prover/src/common/prover/aggregation.rs @@ -4,38 +4,36 @@ use crate::{ io::{load_snark, write_snark}, utils::gen_rng, }; -use aggregator::{AggregationCircuit, BatchHash, ChunkInfo, MAX_AGG_SNARKS}; +use aggregator::{BatchCircuit, BatchHash}; use anyhow::{anyhow, Result}; use rand::Rng; use snark_verifier_sdk::Snark; use std::env; impl Prover { - pub fn gen_agg_snark( + pub fn gen_agg_snark( &mut self, id: &str, degree: u32, mut rng: impl Rng + Send, - chunk_hashes: &[ChunkInfo], + batch_info: BatchHash, previous_snarks: &[Snark], ) -> Result { env::set_var("AGGREGATION_CONFIG", layer_config_path(id)); - let batch_hash = BatchHash::construct(chunk_hashes); - - let circuit: AggregationCircuit = - AggregationCircuit::new(self.params(degree), previous_snarks, &mut rng, batch_hash) + let circuit: BatchCircuit = + BatchCircuit::new(self.params(degree), previous_snarks, &mut rng, batch_info) .map_err(|err| anyhow!("Failed to construct aggregation circuit: {err:?}"))?; self.gen_snark(id, degree, &mut rng, circuit, "gen_agg_snark") } - pub fn load_or_gen_agg_snark( + pub fn load_or_gen_agg_snark( &mut self, name: &str, id: &str, degree: u32, - chunk_hashes: &[ChunkInfo], + batch_info: BatchHash, previous_snarks: &[Snark], output_dir: Option<&str>, ) -> Result { @@ -50,7 +48,7 @@ impl Prover { Some(snark) => Ok(snark), None => { let rng = gen_rng(); - let result = self.gen_agg_snark(id, degree, rng, chunk_hashes, previous_snarks); + let result = self.gen_agg_snark(id, degree, rng, batch_info, previous_snarks); if let (Some(_), Ok(snark)) = (output_dir, &result) { write_snark(&file_path, snark); } diff --git a/prover/src/common/prover/recursion.rs b/prover/src/common/prover/recursion.rs new file mode 100644 index 0000000000..6046a85838 --- /dev/null +++ 
@@ -0,0 +1,110 @@
+use std::env;
+
+use aggregator::{initial_recursion_snark, RecursionCircuit, StateTransition, MAX_AGG_SNARKS};
+use anyhow::Result;
+use rand::Rng;
+use snark_verifier_sdk::{gen_snark_shplonk, Snark};
+
+use crate::{
+    config::layer_config_path,
+    io::{load_snark, write_snark},
+    recursion::RecursionTask,
+    utils::gen_rng,
+};
+
+use super::Prover;
+
+impl Prover {
+    pub fn gen_recursion_snark(
+        &mut self,
+        id: &str,
+        degree: u32,
+        mut rng: impl Rng + Send,
+        batch_snarks: &[Snark],
+    ) -> Result<Snark> {
+        // We should have at least a single snark.
+        assert!(!batch_snarks.is_empty());
+
+        env::set_var("BUNDLE_CONFIG", layer_config_path(id));
+        let params = self.params(degree);
+
+        // Generate an initial snark that represents the start of the recursion process.
+        let init_snark =
+            initial_recursion_snark::<RecursionTask<MAX_AGG_SNARKS>>(params, None, &mut rng);
+
+        // The recursion circuit's instance based on this initial snark state should not be used as
+        // the "real" snark output. It doesn't take into account the preprocessed state. The
+        // recursion circuit needs a verification key, which itself needs the recursion circuit. To
+        // break this dependency cycle, we build a throwaway circuit from the initial snark whose
+        // only purpose is to derive the proving key.
+        let circuit_for_pk = RecursionCircuit::<RecursionTask<MAX_AGG_SNARKS>>::new(
+            params,
+            batch_snarks[0].clone(),
+            init_snark,
+            &mut rng,
+            0,
+        );
+        let (params, pk) = self.params_and_pk(id, degree, &circuit_for_pk)?;
+
+        // Using the above generated PK, we can now construct the legitimate starting state.
+        let mut cur_snark = initial_recursion_snark::<RecursionTask<MAX_AGG_SNARKS>>(
+            params,
+            Some(pk.get_vk()),
+            &mut rng,
+        );
+
+        // The recursion task is initialised with all the snarks, and we are at the 0th round
+        // of recursion at the start.
+        let mut task = RecursionTask::<MAX_AGG_SNARKS>::new(batch_snarks);
+        let mut n_rounds = 0;
+
+        while !task.completed() {
+            log::debug!("construct recursion circuit for round {}", n_rounds);
+
+            let circuit = RecursionCircuit::<RecursionTask<MAX_AGG_SNARKS>>::new(
+                params,
+                task.iter_snark(),
+                cur_snark,
+                &mut rng,
+                n_rounds,
+            );
+            cur_snark = gen_snark_shplonk(params, pk, circuit, &mut rng, None::<String>)?;
+
+            log::info!("construct recursion snark for round {} ...done", n_rounds);
+
+            // Increment the round of recursion and transition to the next state.
+            n_rounds += 1;
+            task = RecursionTask::<MAX_AGG_SNARKS>::new(task.state_transition(n_rounds));
+        }
+
+        Ok(cur_snark)
+    }
+
+    pub fn load_or_gen_recursion_snark(
+        &mut self,
+        name: &str,
+        id: &str,
+        degree: u32,
+        batch_snarks: &[Snark],
+        output_dir: Option<&str>,
+    ) -> Result<Snark> {
+        let file_path = format!(
+            "{}/recursion_snark_{}_{}.json",
+            output_dir.unwrap_or_default(),
+            id,
+            name
+        );
+
+        match output_dir.and_then(|_| load_snark(&file_path).ok().flatten()) {
+            Some(snark) => Ok(snark),
+            None => {
+                let rng = gen_rng();
+                let result = self.gen_recursion_snark(id, degree, rng, batch_snarks);
+                if let (Some(_), Ok(snark)) = (output_dir, &result) {
+                    write_snark(&file_path, snark);
+                }
+
+                result
+            }
+        }
+    }
+}
diff --git a/prover/src/config.rs b/prover/src/config.rs
index 5874b61420..8db379416f 100644
--- a/prover/src/config.rs
+++ b/prover/src/config.rs
@@ -16,11 +16,17 @@ pub static LAYER3_CONFIG_PATH: LazyLock<String> =
     LazyLock::new(|| asset_file_path("layer3.config"));
 pub static LAYER4_CONFIG_PATH: LazyLock<String> =
     LazyLock::new(|| asset_file_path("layer4.config"));
+pub static LAYER5_CONFIG_PATH: LazyLock<String> =
+    LazyLock::new(|| asset_file_path("layer5.config"));
+pub static LAYER6_CONFIG_PATH: LazyLock<String> =
+    LazyLock::new(|| asset_file_path("layer6.config"));

 pub static LAYER1_DEGREE: LazyLock<u32> = LazyLock::new(|| layer_degree(&LAYER1_CONFIG_PATH));
 pub static LAYER2_DEGREE: LazyLock<u32> = LazyLock::new(|| layer_degree(&LAYER2_CONFIG_PATH));
 pub static LAYER3_DEGREE: LazyLock<u32> = LazyLock::new(|| layer_degree(&LAYER3_CONFIG_PATH));
 pub static LAYER4_DEGREE: LazyLock<u32> = LazyLock::new(|| layer_degree(&LAYER4_CONFIG_PATH));
+pub static LAYER5_DEGREE: LazyLock<u32> = LazyLock::new(|| layer_degree(&LAYER5_CONFIG_PATH));
+pub static LAYER6_DEGREE: LazyLock<u32> = LazyLock::new(|| layer_degree(&LAYER6_CONFIG_PATH));

 pub static ZKEVM_DEGREES: LazyLock<Vec<u32>> = LazyLock::new(|| {
     Vec::from_iter(HashSet::from([
@@ -30,8 +36,14 @@ pub static ZKEVM_DEGREES: LazyLock<Vec<u32>> = LazyLock::new(|| {
     ]))
 });

-pub static AGG_DEGREES: LazyLock<Vec<u32>> =
-    LazyLock::new(|| Vec::from_iter(HashSet::from([*LAYER3_DEGREE, *LAYER4_DEGREE])));
+pub static AGG_DEGREES: LazyLock<Vec<u32>> = LazyLock::new(|| {
+    Vec::from_iter(HashSet::from([
+        *LAYER3_DEGREE,
+        *LAYER4_DEGREE,
+        *LAYER5_DEGREE,
+        *LAYER6_DEGREE,
+    ]))
+});

 #[derive(Clone, Copy, Debug)]
 pub enum LayerId {
@@ -41,10 +53,14 @@ pub enum LayerId {
     Layer1,
     /// Compression thin layer (to generate chunk-proof)
     Layer2,
-    /// Aggregation layer
+    /// Layer to batch multiple chunk proofs
     Layer3,
     /// Compression thin layer (to generate batch-proof)
     Layer4,
+    /// Recurse over a bundle of batches
+    Layer5,
+    /// Compression thin layer (to generate bundle-proof verifiable in EVM)
+    Layer6,
 }

 impl fmt::Display for LayerId {
@@ -61,6 +77,8 @@ impl LayerId {
             Self::Layer2 => "layer2",
             Self::Layer3 => "layer3",
             Self::Layer4 => "layer4",
+            Self::Layer5 => "layer5",
+            Self::Layer6 => "layer6",
         }
     }

@@ -71,6 +89,8 @@ impl LayerId {
             Self::Layer2 => *LAYER2_DEGREE,
             Self::Layer3 => *LAYER3_DEGREE,
             Self::Layer4 => *LAYER4_DEGREE,
+            Self::Layer5 => *LAYER5_DEGREE,
+            Self::Layer6 => *LAYER6_DEGREE,
         }
     }

@@ -80,6 +100,8 @@ impl LayerId {
             Self::Layer2 => &LAYER2_CONFIG_PATH,
             Self::Layer3 => &LAYER3_CONFIG_PATH,
             Self::Layer4 => &LAYER4_CONFIG_PATH,
+            Self::Layer5 => &LAYER5_CONFIG_PATH,
+            Self::Layer6 => &LAYER6_CONFIG_PATH,
             Self::Inner => unreachable!("No config file for super (inner) circuit"),
         }
     }

@@ -98,6 +120,8 @@ pub fn layer_config_path(id: &str) -> &str {
         "layer2" => &LAYER2_CONFIG_PATH,
         "layer3" => &LAYER3_CONFIG_PATH,
         "layer4" => &LAYER4_CONFIG_PATH,
+        "layer5" => &LAYER5_CONFIG_PATH,
+        "layer6" => &LAYER6_CONFIG_PATH,
         _ => panic!("Wrong id-{id} to get layer config path"),
     }
 }
diff --git a/prover/src/consts.rs b/prover/src/consts.rs
index 7bdef8e292..1cebefccc4 100644
--- a/prover/src/consts.rs
+++ b/prover/src/consts.rs
@@ -3,18 +3,28 @@ use std::sync::LazyLock;

 // TODO: is it a good design to use LazyLock? Why not read env var each time?

-pub fn agg_vk_filename() -> String {
-    read_env_var("AGG_VK_FILENAME", "agg_vk.vkey".to_string())
+pub fn bundle_vk_filename() -> String {
+    read_env_var("BUNDLE_VK_FILENAME", "bundle_vk.vkey".to_string())
+}
+pub fn batch_vk_filename() -> String {
+    read_env_var("BATCH_VK_FILENAME", "batch_vk.vkey".to_string())
 }
 pub fn chunk_vk_filename() -> String {
     read_env_var("CHUNK_VK_FILENAME", "chunk_vk.vkey".to_string())
 }

-// For our k=21 agg circuit, 12 means it can include 2**21 / (12 * 25) * 136.0 = 0.95M bytes
-pub static AGG_KECCAK_ROW: LazyLock<usize> = LazyLock::new(|| read_env_var("AGG_KECCAK_ROW", 12));
-pub static AGG_VK_FILENAME: LazyLock<String> = LazyLock::new(agg_vk_filename);
 pub static CHUNK_PROTOCOL_FILENAME: LazyLock<String> =
     LazyLock::new(|| read_env_var("CHUNK_PROTOCOL_FILENAME", "chunk.protocol".to_string()));
+pub static BATCH_PROTOCOL_FILENAME: LazyLock<String> =
+    LazyLock::new(|| read_env_var("BATCH_PROTOCOL_FILENAME", "batch.protocol".to_string()));
+
 pub static CHUNK_VK_FILENAME: LazyLock<String> = LazyLock::new(chunk_vk_filename);
+pub static BATCH_VK_FILENAME: LazyLock<String> = LazyLock::new(batch_vk_filename);
+pub static BUNDLE_VK_FILENAME: LazyLock<String> = LazyLock::new(bundle_vk_filename);
+
 pub static DEPLOYMENT_CODE_FILENAME: LazyLock<String> =
     LazyLock::new(|| read_env_var("DEPLOYMENT_CODE_FILENAME", "evm_verifier.bin".to_string()));
+
+// For our k=21 agg circuit, 12 means it can include 2**21 / (12 * 25) * 136.0 = 0.95M bytes
+pub static BATCH_KECCAK_ROW: LazyLock<usize> =
+    LazyLock::new(|| read_env_var("BATCH_KECCAK_ROW", 12));
diff --git a/prover/src/lib.rs b/prover/src/lib.rs
index cd568755d1..64a9a14da2 100644
--- a/prover/src/lib.rs
+++ b/prover/src/lib.rs
@@ -17,16 +17,17 @@ mod evm;
 pub mod inner;
 pub mod io;
 pub mod proof;
+pub mod recursion;
 pub mod test;
 pub mod types;
 pub mod utils;
 pub mod zkevm;

-pub use aggregator::{check_chunk_hashes, BatchData, BatchHash, MAX_AGG_SNARKS};
+pub use aggregator::{check_chunk_hashes, BatchData, BatchHash, BatchHeader, MAX_AGG_SNARKS};
 pub use common::{ChunkInfo, CompressionCircuit};
 pub use eth_types;
 pub use eth_types::l2_types::BlockTrace;
-pub use proof::{BatchProof, ChunkProof, EvmProof, Proof};
+pub use proof::{BatchProof, BundleProof, ChunkProof, EvmProof, Proof};
 pub use snark_verifier_sdk::{CircuitExt, Snark};
-pub use types::{BatchProvingTask, ChunkProvingTask, WitnessBlock};
+pub use types::{BatchProvingTask, BundleProvingTask, ChunkProvingTask, WitnessBlock};
 pub use zkevm_circuits;
diff --git a/prover/src/proof.rs b/prover/src/proof.rs
index 91a483ae86..5e662794df 100644
--- a/prover/src/proof.rs
+++ b/prover/src/proof.rs
@@ -13,10 +13,12 @@ use snark_verifier_sdk::{verify_evm_proof, Snark};
 use std::{fs::File, path::PathBuf};

 mod batch;
+mod bundle;
 mod chunk;
 mod evm;

 pub use batch::BatchProof;
+pub use bundle::BundleProof;
 pub use chunk::{compare_chunk_info, ChunkProof};
 pub use evm::EvmProof;

diff --git a/prover/src/proof/batch.rs b/prover/src/proof/batch.rs
index a153fbbd91..5cce5e1a08 100644
--- a/prover/src/proof/batch.rs
+++ b/prover/src/proof/batch.rs
@@ -1,112 +1,57 @@
-use super::{dump_as_json, dump_data, dump_vk, from_json_file, serialize_instance, Proof};
-use crate::utils::short_git_version;
+use super::{dump_as_json, dump_vk, from_json_file, Proof};
+use crate::types::base64;
 use anyhow::Result;
+use eth_types::H256;
+use halo2_proofs::{halo2curves::bn256::G1Affine, plonk::ProvingKey};
 use serde_derive::{Deserialize, Serialize};
-use snark_verifier_sdk::encode_calldata;
-
-const ACC_LEN: usize = 12;
-const PI_LEN: usize = 32;
-
-const ACC_BYTES: usize = ACC_LEN * 32;
-const PI_BYTES: usize = PI_LEN * 32;
+use snark_verifier::Protocol;
+use snark_verifier_sdk::Snark;

 #[derive(Clone, Debug, Deserialize, Serialize)]
 pub struct BatchProof {
+    #[serde(with = "base64")]
+    pub protocol: Vec<u8>,
     #[serde(flatten)]
-    raw: Proof,
+    proof: Proof,
+    pub batch_hash: H256,
 }

-impl From<Proof> for BatchProof {
-    fn from(proof: Proof) -> Self {
-        let instances = proof.instances();
-        assert_eq!(instances.len(), 1);
-        assert_eq!(instances[0].len(), ACC_LEN + PI_LEN);
-
-        let vk = proof.vk;
-        let git_version = proof.git_version;
-
-        // "onchain proof" = accumulator + proof
-        let proof = serialize_instance(&instances[0][..ACC_LEN])
-            .into_iter()
-            .chain(proof.proof)
-            .collect();
-
-        // "onchain instances" = pi_data
-        let instances = serialize_instance(&instances[0][ACC_LEN..]);
+impl From<&BatchProof> for Snark {
+    fn from(value: &BatchProof) -> Self {
+        let instances = value.proof.instances();
+        let protocol = serde_json::from_slice::<Protocol<G1Affine>>(&value.protocol).unwrap();

         Self {
-            raw: Proof {
-                proof,
-                instances,
-                vk,
-                git_version,
-            },
+            protocol,
+            proof: value.proof.proof.clone(),
+            instances,
         }
     }
 }

 impl BatchProof {
-    pub fn from_json_file(dir: &str, name: &str) -> Result<Self> {
-        from_json_file(dir, &dump_filename(name))
-    }
+    pub fn new(snark: Snark, pk: Option<&ProvingKey<G1Affine>>, batch_hash: H256) -> Result<Self> {
+        let protocol = serde_json::to_vec(&snark.protocol)?;
+        let proof = Proof::new(snark.proof, &snark.instances, pk);

-    /// Returns the calldata given to YUL verifier.
-    /// Format: Accumulator(12x32bytes) || PIHASH(32x32bytes) || Proof
-    pub fn calldata(self) -> Vec<u8> {
-        let proof = self.proof_to_verify();
-
-        // calldata = instances + proof
-        let mut calldata = proof.instances;
-        calldata.extend(proof.proof);
+        Ok(Self {
+            protocol,
+            proof,
+            batch_hash,
+        })
+    }

-        calldata
+    pub fn from_json_file(dir: &str, name: &str) -> Result<Self> {
+        from_json_file(dir, &dump_filename(name))
     }

     pub fn dump(&self, dir: &str, name: &str) -> Result<()> {
         let filename = dump_filename(name);

-        dump_data(dir, &format!("pi_{filename}.data"), &self.raw.instances);
-        dump_data(dir, &format!("proof_{filename}.data"), &self.raw.proof);
-
-        dump_vk(dir, &filename, &self.raw.vk);
+        dump_vk(dir, &filename, &self.proof.vk);

         dump_as_json(dir, &filename, &self)
     }
-
-    // Recover a `Proof` which follows halo2 semantic of "proof" and "instance",
-    // where "accumulators" are instance instead of proof, not like "onchain proof".
-    pub fn proof_to_verify(self) -> Proof {
-        // raw.proof is accumulator + proof
-        assert!(self.raw.proof.len() > ACC_BYTES);
-        // raw.instances is PI
-        assert_eq!(self.raw.instances.len(), PI_BYTES);
-
-        // instances = raw_proof[..12] (acc) + raw_instances (pi_data)
-        // proof = raw_proof[12..]
-        let mut instances = self.raw.proof;
-        let proof = instances.split_off(ACC_BYTES);
-        instances.extend(self.raw.instances);
-
-        let vk = self.raw.vk;
-        let git_version = Some(short_git_version());
-
-        Proof {
-            proof,
-            instances,
-            vk,
-            git_version,
-        }
-    }
-
-    pub fn assert_calldata(self) {
-        let real_calldata = self.clone().calldata();
-
-        let proof = self.proof_to_verify();
-        // encode_calldata output: instances || proof
-        let expected_calldata = encode_calldata(&proof.instances(), &proof.proof);
-
-        assert_eq!(real_calldata, expected_calldata);
-    }
 }

 fn dump_filename(name: &str) -> String {
diff --git a/prover/src/proof/bundle.rs b/prover/src/proof/bundle.rs
new file mode 100644
index 0000000000..52fa290db9
--- /dev/null
+++ b/prover/src/proof/bundle.rs
@@ -0,0 +1,115 @@
+use super::{dump_as_json, dump_data, dump_vk, serialize_instance};
+use crate::{utils::short_git_version, Proof};
+use anyhow::Result;
+use serde_derive::{Deserialize, Serialize};
+
+// 3 limbs per field element, 4 field elements
+const ACC_LEN: usize = 12;
+
+// - Accumulator (4*LIMBS)
+// - PREPROCESS_DIGEST, ROUND
+// - (hi, lo) finalised state root
+// - (hi, lo) finalised batch hash
+// - (hi, lo) pending state root
+// - (hi, lo) pending batch hash
+// - chain id
+// - (hi, lo) pending withdraw root
+// - bundle count
+
+const PI_LEN: usize = 13;
+
+const ACC_BYTES: usize = ACC_LEN * 32;
+const PI_BYTES: usize = PI_LEN * 32;
+
+#[derive(Clone, Debug, Deserialize, Serialize)]
+pub struct BundleProof {
+    #[serde(flatten)]
+    on_chain_proof: Proof,
+}
+
+impl From<Proof> for BundleProof {
+    fn from(proof: Proof) -> Self {
+        let instances = proof.instances();
+        assert_eq!(instances.len(), 1);
+        assert_eq!(instances[0].len(), ACC_LEN + PI_LEN);
+
+        let vk = proof.vk;
+        let git_version = proof.git_version;
+
+        // "onchain proof" = accumulator + proof
+        let proof = serialize_instance(&instances[0][..ACC_LEN])
+            .into_iter()
+            .chain(proof.proof)
+            .collect();
+
+        // "onchain instances" = pi_data
+        let instances = serialize_instance(&instances[0][ACC_LEN..]);
+
+        Self {
+            on_chain_proof: Proof {
+                proof,
+                instances,
+                vk,
+                git_version,
+            },
+        }
+    }
+}
+
+impl BundleProof {
+    /// Returns the calldata given to YUL verifier.
+    /// Format: Accumulator(12x32bytes) || PI(13x32bytes) || Proof
+    pub fn calldata(self) -> Vec<u8> {
+        let proof = self.proof_to_verify();
+
+        // calldata = instances + proof
+        let mut calldata = proof.instances;
+        calldata.extend(proof.proof);
+
+        calldata
+    }
+
+    pub fn dump(&self, dir: &str, name: &str) -> Result<()> {
+        let filename = format!("bundle_{name}");
+
+        dump_data(
+            dir,
+            &format!("pi_{filename}.data"),
+            &self.on_chain_proof.instances,
+        );
+        dump_data(
+            dir,
+            &format!("proof_{filename}.data"),
+            &self.on_chain_proof.proof,
+        );
+
+        dump_vk(dir, &filename, &self.on_chain_proof.vk);
+
+        dump_as_json(dir, &filename, &self)
+    }
+
+    // Recover a `Proof` which follows halo2 semantic of "proof" and "instance",
+    // where "accumulators" are instance instead of proof, not like "onchain proof".
+    pub fn proof_to_verify(self) -> Proof {
+        // raw.proof is accumulator + proof
+        assert!(self.on_chain_proof.proof.len() > ACC_BYTES);
+        // raw.instances is PI
+        assert_eq!(self.on_chain_proof.instances.len(), PI_BYTES);
+
+        // instances = raw_proof[..12] (acc) + raw_instances (pi_data)
+        // proof = raw_proof[12..]
+        let mut instances = self.on_chain_proof.proof;
+        let proof = instances.split_off(ACC_BYTES);
+        instances.extend(self.on_chain_proof.instances);
+
+        let vk = self.on_chain_proof.vk;
+        let git_version = Some(short_git_version());
+
+        Proof {
+            proof,
+            instances,
+            vk,
+            git_version,
+        }
+    }
+}
diff --git a/prover/src/recursion.rs b/prover/src/recursion.rs
new file mode 100644
index 0000000000..aa50a41700
--- /dev/null
+++ b/prover/src/recursion.rs
@@ -0,0 +1,70 @@
+use halo2_proofs::halo2curves::bn256::Fr;
+
+use aggregator::{BatchCircuit, StateTransition};
+use snark_verifier_sdk::Snark;
+
+/// 4 fields for 2 hashes (Hi, Lo)
+const ST_INSTANCE: usize = 4;
+
+/// Additional public inputs, specifically:
+/// - withdraw root (hi, lo)
+/// - chain ID
+const ADD_INSTANCE: usize = 3;
+
+/// Number of public inputs to describe the state.
+const NUM_INSTANCES: usize = ST_INSTANCE + ADD_INSTANCE;
+
+/// Number of public inputs to describe the initial state.
+const NUM_INIT_INSTANCES: usize = ST_INSTANCE;
+
+#[derive(Clone, Debug)]
+pub struct RecursionTask<'a, const N_SNARK: usize> {
+    /// The [`snarks`][snark] from the [`BatchCircuit`][batch_circuit].
+    ///
+    /// [snark]: snark_verifier_sdk::Snark
+    /// [batch_circuit]: aggregator::BatchCircuit
+    snarks: &'a [Snark],
+}
+
+impl<const N_SNARK: usize> RecursionTask<'_, N_SNARK> {
+    pub fn init_instances(&self) -> [Fr; NUM_INIT_INSTANCES] {
+        self.snarks.first().unwrap().instances[0][..ST_INSTANCE]
+            .try_into()
+            .unwrap()
+    }
+
+    pub fn state_instances(&self) -> [Fr; NUM_INSTANCES] {
+        self.snarks.first().unwrap().instances[0][ST_INSTANCE..]
+            .try_into()
+            .unwrap()
+    }
+
+    pub fn iter_snark(&self) -> Snark {
+        self.snarks.first().unwrap().clone()
+    }
+
+    pub fn completed(&self) -> bool {
+        self.snarks.is_empty()
+    }
+}
+
+impl<'a, const N_SNARK: usize> StateTransition for RecursionTask<'a, N_SNARK> {
+    type Input = &'a [Snark];
+    type Circuit = BatchCircuit<N_SNARK>;
+
+    fn new(state: Self::Input) -> Self {
+        Self { snarks: state }
+    }
+
+    fn state_transition(&self, _round: usize) -> Self::Input {
+        &self.snarks[1..]
+    }
+
+    fn num_transition_instance() -> usize {
+        ST_INSTANCE
+    }
+
+    fn num_additional_instance() -> usize {
+        ADD_INSTANCE
+    }
+}
diff --git a/prover/src/test.rs b/prover/src/test.rs
index 3946970294..982dba2a11 100644
--- a/prover/src/test.rs
+++ b/prover/src/test.rs
@@ -2,6 +2,6 @@ mod batch;
 mod chunk;
 mod inner;

-pub use batch::batch_prove;
+pub use batch::{batch_prove, bundle_prove};
 pub use chunk::chunk_prove;
 pub use inner::inner_prove;
diff --git a/prover/src/test/batch.rs b/prover/src/test/batch.rs
index 6a10fbdb54..95f5ed0eba 100644
--- a/prover/src/test/batch.rs
+++ b/prover/src/test/batch.rs
@@ -3,6 +3,7 @@ use crate::{
     config::LayerId,
     consts::DEPLOYMENT_CODE_FILENAME,
     io::force_to_read,
+    types::BundleProvingTask,
     utils::read_env_var,
     BatchProvingTask,
 };
@@ -24,14 +25,14 @@ static BATCH_VERIFIER: LazyLock<Mutex<Verifier>> = LazyLock::new(|| {
     let mut prover = BATCH_PROVER.lock().expect("poisoned batch-prover");
     let params = prover.prover_impl.params(LayerId::Layer4.degree()).clone();

+    let deployment_code = force_to_read(&assets_dir, &DEPLOYMENT_CODE_FILENAME);
+
     let pk = prover
         .prover_impl
         .pk(LayerId::Layer4.id())
         .expect("Failed to get batch-prove PK");
     let vk = pk.get_vk().clone();

-    let deployment_code = force_to_read(&assets_dir, &DEPLOYMENT_CODE_FILENAME);
-
     let verifier = Verifier::new(params, vk, deployment_code);
     log::info!("Constructed batch-verifier");

@@ -44,15 +45,34 @@ pub fn batch_prove(test: &str, batch: BatchProvingTask) {
     let proof = BATCH_PROVER
         .lock()
         .expect("poisoned batch-prover")
-        .gen_agg_evm_proof(batch, None, None)
+        .gen_batch_proof(batch, None, None)
         .unwrap_or_else(|err| panic!("{test}: failed to generate batch proof: {err}"));
     log::info!("{test}: generated batch proof");

     let verified = BATCH_VERIFIER
         .lock()
         .expect("poisoned batch-verifier")
-        .verify_agg_evm_proof(proof);
+        .verify_batch_proof(&proof);
     assert!(verified, "{test}: failed to verify batch proof");

     log::info!("{test}: batch-prove END");
 }
+
+pub fn bundle_prove(test: &str, bundle: BundleProvingTask) {
+    log::info!("{test}: bundle-prove BEGIN");
+
+    let proof = BATCH_PROVER
+        .lock()
+        .expect("poisoned batch-prover")
+        .gen_bundle_proof(bundle, None, None)
+        .unwrap_or_else(|err| panic!("{test}: failed to generate bundle proof: {err}"));
+    log::info!("{test}: generated bundle proof");
+
+    let verified = BATCH_VERIFIER
+        .lock()
+        .expect("poisoned batch-verifier")
+        .verify_bundle_proof(proof);
+    assert!(verified, "{test}: failed to verify bundle proof");
+
+    log::info!("{test}: bundle-prove END");
+}
diff --git a/prover/src/types.rs b/prover/src/types.rs
index e733dc5454..ef57cba56d 100644
--- a/prover/src/types.rs
+++ b/prover/src/types.rs
@@ -1,4 +1,4 @@
-use aggregator::ChunkInfo;
+use aggregator::{BatchHeader, ChunkInfo, MAX_AGG_SNARKS};
 use eth_types::l2_types::BlockTrace;
 use serde::{Deserialize, Serialize};
 use zkevm_circuits::evm_circuit::witness::Block;
@@ -11,7 +11,7 @@ pub struct BlockTraceJsonRpcResult {
 }
 pub use eth_types::base64;

-use crate::ChunkProof;
+use crate::{BatchProof, ChunkProof};

 #[derive(Debug, Clone, Deserialize, Serialize)]
 pub struct ChunkProvingTask {
@@ -44,6 +44,7 @@ impl ChunkProvingTask {
 #[derive(Debug, Clone, Deserialize, Serialize)]
 pub struct BatchProvingTask {
     pub chunk_proofs: Vec<ChunkProof>,
+    pub batch_header: BatchHeader<MAX_AGG_SNARKS>,
 }

 impl BatchProvingTask {
@@ -58,3 +59,14 @@ impl BatchProvingTask {
             .to_string()
     }
 }
+
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct BundleProvingTask {
+    pub batch_proofs: Vec<BatchProof>,
+}
+
+impl BundleProvingTask {
+    pub fn identifier(&self) -> String {
+        self.batch_proofs.last().unwrap().batch_hash.to_string()
+    }
+}
diff --git a/zkevm-circuits/src/root_circuit.rs b/zkevm-circuits/src/root_circuit.rs
index 8a47fb5bb6..440424fa18 100644
--- a/zkevm-circuits/src/root_circuit.rs
+++ b/zkevm-circuits/src/root_circuit.rs
@@ -142,4 +142,3 @@ impl<'a, M: MultiMillerLoop> Circuit<M::Scalar> for RootCircuit<'a, M> {
         Ok(())
     }
 }
-            max_inner_blocks: 64,
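
The round-by-round driver in prover/src/common/prover/recursion.rs is the heart of the new bundling flow: starting from an initial snark, each round folds exactly one batch snark into the running recursion snark until the task has consumed them all. Below is a minimal, self-contained sketch of that control flow; `Snark(u64)` and `prove_round` are illustrative stand-ins, not the crate's actual types or API.

```rust
// Toy model of the recursion loop: one batch snark consumed per round.
#[derive(Debug)]
struct Snark(u64); // stand-in for snark_verifier_sdk::Snark

/// Stand-in for `RecursionTask::state_transition`: drop the snark that was
/// just folded in and hand the remainder to the next round.
fn state_transition(snarks: &[Snark]) -> &[Snark] {
    &snarks[1..]
}

/// Stand-in for one `RecursionCircuit` proof: fold the round's snark into the
/// running snark (here by just mixing the payloads).
fn prove_round(cur: &Snark, next: &Snark, round: usize) -> Snark {
    Snark(cur.0.wrapping_mul(31).wrapping_add(next.0) ^ round as u64)
}

fn main() {
    let batch_snarks = vec![Snark(11), Snark(22), Snark(33)];
    assert!(!batch_snarks.is_empty());

    // Plays the role of `initial_recursion_snark(..)`.
    let mut cur_snark = Snark(0);
    let mut task: &[Snark] = &batch_snarks;
    let mut n_rounds = 0;

    // Mirrors the `while !task.completed()` loop above.
    while !task.is_empty() {
        cur_snark = prove_round(&cur_snark, &task[0], n_rounds);
        n_rounds += 1;
        task = state_transition(task);
    }

    assert_eq!(n_rounds, 3);
    println!("final recursion snark after {n_rounds} rounds: {cur_snark:?}");
}
```

This also shows why the proving-key bootstrap in the diff is needed only once: the loop reuses the same `pk` for every round, since each round proves the same circuit shape over a different snark.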
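The new `BundleProof` (prover/src/proof/bundle.rs) stores the proof in its on-chain form: the 12 accumulator words are serialized in front of the proof bytes, and only the 13 public-input words remain as instances. The sketch below models just the byte bookkeeping of `proof_to_verify` and `calldata` with plain byte vectors; the struct and helper names are illustrative, only the constants mirror the diff.

```rust
// Byte layout: accumulator(12 * 32) || public inputs(13 * 32) || proof.
const ACC_BYTES: usize = 12 * 32;
const PI_BYTES: usize = 13 * 32;

/// On-chain form: `proof` carries the serialized accumulator up front.
struct OnChainProof {
    proof: Vec<u8>,     // accumulator bytes || raw proof bytes
    instances: Vec<u8>, // PI bytes only
}

/// Recovered halo2 form: the accumulator re-joins the instances.
struct Halo2Proof {
    instances: Vec<u8>, // accumulator bytes || PI bytes
    proof: Vec<u8>,     // raw proof bytes only
}

fn proof_to_verify(on_chain: OnChainProof) -> Halo2Proof {
    assert!(on_chain.proof.len() > ACC_BYTES);
    assert_eq!(on_chain.instances.len(), PI_BYTES);

    // Split the accumulator off the front of the on-chain proof and append
    // the PI bytes after it, as `BundleProof::proof_to_verify` does.
    let mut instances = on_chain.proof;
    let proof = instances.split_off(ACC_BYTES);
    instances.extend(on_chain.instances);

    Halo2Proof { instances, proof }
}

fn main() {
    let on_chain = OnChainProof {
        proof: [vec![0xAA; ACC_BYTES], vec![0xBB; 64]].concat(),
        instances: vec![0xCC; PI_BYTES],
    };
    let halo2 = proof_to_verify(on_chain);

    // calldata = instances || proof, as in `BundleProof::calldata`.
    let calldata: Vec<u8> = [halo2.instances, halo2.proof].concat();
    assert_eq!(calldata.len(), ACC_BYTES + PI_BYTES + 64);
    println!("calldata is {} bytes", calldata.len());
}
```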