diff --git a/Cargo.lock b/Cargo.lock index 6155537ce5..7506201fcf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4594,7 +4594,7 @@ checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" [[package]] name = "snark-verifier" version = "0.1.0" -source = "git+https://github.com/scroll-tech/snark-verifier?branch=develop#572ef69d1595fca82213d3b05e859eaf355a5fa1" +source = "git+https://github.com/scroll-tech/snark-verifier?branch=develop#948671cac73f11e66187a15483e38ab3626dc2a3" dependencies = [ "bytes", "ethereum-types", @@ -4617,7 +4617,7 @@ dependencies = [ [[package]] name = "snark-verifier-sdk" version = "0.0.1" -source = "git+https://github.com/scroll-tech/snark-verifier?branch=develop#572ef69d1595fca82213d3b05e859eaf355a5fa1" +source = "git+https://github.com/scroll-tech/snark-verifier?branch=develop#948671cac73f11e66187a15483e38ab3626dc2a3" dependencies = [ "bincode", "ethereum-types", diff --git a/aggregator/src/aggregation.rs b/aggregator/src/aggregation.rs index bb067bfb0c..a616965750 100644 --- a/aggregator/src/aggregation.rs +++ b/aggregator/src/aggregation.rs @@ -19,3 +19,42 @@ pub(crate) use rlc::{RlcConfig, POWS_OF_256}; pub use circuit::BatchCircuit; pub use config::BatchCircuitConfig; +use halo2_base::halo2_proofs::halo2curves::bn256::{Fr, G1Affine}; +use snark_verifier::Protocol; + +/// Alias for a list of G1 points. +pub type PreprocessedPolyCommits = Vec; +/// Alias for the transcript's initial state. +pub type TranscriptInitState = Fr; + +/// Alias for the fixed part of the protocol which consists of the commitments to the preprocessed +/// polynomials and the initial state of the transcript. +#[derive(Clone)] +pub struct FixedProtocol { + /// The commitments to the preprocessed polynomials. + pub preprocessed: PreprocessedPolyCommits, + /// The initial state of the transcript. + pub init_state: TranscriptInitState, +} + +impl From> for FixedProtocol { + fn from(protocol: Protocol) -> Self { + Self { + preprocessed: protocol.preprocessed, + init_state: protocol + .transcript_initial_state + .expect("protocol transcript init state None"), + } + } +} + +impl From<&Protocol> for FixedProtocol { + fn from(protocol: &Protocol) -> Self { + Self { + preprocessed: protocol.preprocessed.clone(), + init_state: protocol + .transcript_initial_state + .expect("protocol transcript init state None"), + } + } +} diff --git a/aggregator/src/aggregation/circuit.rs b/aggregator/src/aggregation/circuit.rs index b9a3cc7c79..06cb53c9b0 100644 --- a/aggregator/src/aggregation/circuit.rs +++ b/aggregator/src/aggregation/circuit.rs @@ -1,5 +1,6 @@ use ark_std::{end_timer, start_timer}; use halo2_proofs::{ + arithmetic::Field, circuit::{Layouter, SimpleFloorPlanner, Value}, halo2curves::bn256::{Bn256, Fr, G1Affine}, plonk::{Circuit, ConstraintSystem, Error, Selector}, @@ -11,14 +12,20 @@ use snark_verifier::{ loader::halo2::{ halo2_ecc::{ ecc::EccChip, - fields::fp::FpConfig, - halo2_base::{AssignedValue, Context, ContextParams}, + fields::{fp::FpConfig, FieldChip}, + halo2_base::{ + gates::{GateInstructions, RangeInstructions}, + AssignedValue, Context, ContextParams, + QuantumCell::Existing, + }, }, - Halo2Loader, + Halo2Loader, IntegerInstructions, }, pcs::kzg::{Bdfg21, Kzg, KzgSuccinctVerifyingKey}, }; -use snark_verifier_sdk::{aggregate, flatten_accumulator, CircuitExt, Snark, SnarkWitness}; +use snark_verifier_sdk::{ + aggregate_as_witness, flatten_accumulator, CircuitExt, Snark, SnarkWitness, +}; use std::{env, fs::File, rc::Rc}; use zkevm_circuits::util::Challenges; @@ -30,8 +37,8 @@ use crate::{ core::{assign_batch_hashes, extract_proof_and_instances_with_pairing_check}, util::parse_hash_digest_cells, witgen::{zstd_encode, MultiBlockProcessResult}, - ConfigParams, LOG_DEGREE, PI_CHAIN_ID, PI_CURRENT_BATCH_HASH, PI_CURRENT_STATE_ROOT, - PI_CURRENT_WITHDRAW_ROOT, PI_PARENT_BATCH_HASH, PI_PARENT_STATE_ROOT, + ConfigParams, FixedProtocol, LOG_DEGREE, PI_CHAIN_ID, PI_CURRENT_BATCH_HASH, + PI_CURRENT_STATE_ROOT, PI_CURRENT_WITHDRAW_ROOT, PI_PARENT_BATCH_HASH, PI_PARENT_STATE_ROOT, }; /// Batch circuit, the chunk aggregation routine below recursion circuit @@ -55,14 +62,21 @@ pub struct BatchCircuit { // batch hash circuit for which the snarks are generated // the chunks in this batch are also padded already pub batch_hash: BatchHash, + + /// The SNARK protocol from the halo2-based inner circuit route. + pub halo2_protocol: FixedProtocol, + /// The SNARK protocol from the sp1-based inner circuit route. + pub sp1_protocol: FixedProtocol, } impl BatchCircuit { - pub fn new( + pub fn new>( params: &ParamsKZG, snarks_with_padding: &[Snark], rng: impl Rng + Send, batch_hash: BatchHash, + halo2_protocol: P, + sp1_protocol: P, ) -> Result { let timer = start_timer!(|| "generate aggregation circuit"); @@ -120,6 +134,8 @@ impl BatchCircuit { flattened_instances, as_proof: Value::known(as_proof), batch_hash, + halo2_protocol: halo2_protocol.into(), + sp1_protocol: sp1_protocol.into(), }) } @@ -209,22 +225,21 @@ impl Circuit for BatchCircuit { let loader: Rc>>> = Halo2Loader::new(ecc_chip, ctx); - // // extract the assigned values for // - instances which are the public inputs of each chunk (prefixed with 12 // instances from previous accumulators) // - new accumulator - // - log::debug!("aggregation: chunk aggregation"); - let (assigned_aggregation_instances, acc) = aggregate::>( + let ( + assigned_aggregation_instances, + acc, + preprocessed_poly_sets, + transcript_init_states, + ) = aggregate_as_witness::>( &self.svk, &loader, &self.snarks_with_padding, self.as_proof(), ); - for (i, e) in assigned_aggregation_instances[0].iter().enumerate() { - log::trace!("{}-th instance: {:?}", i, e.value) - } // extract the following cells for later constraints // - the accumulators @@ -238,13 +253,118 @@ impl Circuit for BatchCircuit { .iter() .flat_map(|instance_column| instance_column.iter().skip(ACC_LEN)), ); + for (i, e) in assigned_aggregation_instances[0].iter().enumerate() { + log::trace!("{}-th instance: {:?}", i, e.value) + } - loader - .ctx_mut() - .print_stats(&["snark aggregation [chunks -> batch]"]); + loader.ctx_mut().print_stats(&["snark aggregation"]); let mut ctx = Rc::into_inner(loader).unwrap().into_ctx(); - log::debug!("batching: assigning barycentric"); + + // We must ensure that the commitments to preprocessed polynomial and initial + // state of transcripts for every SNARK that is being aggregated belongs to the + // fixed set of values expected. + // + // First we load the constants. + let mut preprocessed_polys_halo2 = Vec::with_capacity(7); + let mut preprocessed_polys_sp1 = Vec::with_capacity(7); + for &preprocessed_poly in self.halo2_protocol.preprocessed.iter() { + preprocessed_polys_halo2.push( + config + .ecc_chip() + .assign_constant_point(&mut ctx, preprocessed_poly), + ); + } + for &preprocessed_poly in self.sp1_protocol.preprocessed.iter() { + preprocessed_polys_sp1.push( + config + .ecc_chip() + .assign_constant_point(&mut ctx, preprocessed_poly), + ); + } + let transcript_init_state_halo2 = config + .ecc_chip() + .field_chip() + .range() + .gate() + .assign_constant(&mut ctx, self.halo2_protocol.init_state) + .expect("IntegerInstructions::assign_constant infallible"); + let transcript_init_state_sp1 = config + .ecc_chip() + .field_chip() + .range() + .gate() + .assign_constant(&mut ctx, self.sp1_protocol.init_state) + .expect("IntegerInstructions::assign_constant infallible"); + + // Commitments to the preprocessed polynomials. + // + // check_1: halo2-route + // check_2: sp1-route + // + // OR(check_1, check_2) == 1 + let mut route_check = Vec::with_capacity(N_SNARKS); + for preprocessed_polys in preprocessed_poly_sets.iter() { + let mut preprocessed_check_1 = + config.flex_gate().load_constant(&mut ctx, Fr::ONE); + let mut preprocessed_check_2 = + config.flex_gate().load_constant(&mut ctx, Fr::ONE); + for ((commitment, comm_halo2), comm_sp1) in preprocessed_polys + .iter() + .zip_eq(preprocessed_polys_halo2.iter()) + .zip_eq(preprocessed_polys_sp1.iter()) + { + let check_1 = + config.ecc_chip().is_equal(&mut ctx, commitment, comm_halo2); + let check_2 = + config.ecc_chip().is_equal(&mut ctx, commitment, comm_sp1); + preprocessed_check_1 = config.flex_gate().and( + &mut ctx, + Existing(preprocessed_check_1), + Existing(check_1), + ); + preprocessed_check_2 = config.flex_gate().and( + &mut ctx, + Existing(preprocessed_check_2), + Existing(check_2), + ); + } + route_check.push(preprocessed_check_1); + let preprocessed_check = config.flex_gate().or( + &mut ctx, + Existing(preprocessed_check_1), + Existing(preprocessed_check_2), + ); + config + .flex_gate() + .assert_is_const(&mut ctx, &preprocessed_check, Fr::ONE); + } + + // Transcript initial state. + // + // If the SNARK belongs to halo2-route, the initial state is the halo2-initial + // state. Otherwise sp1-initial state. + for (transcript_init_state, &route) in + transcript_init_states.iter().zip_eq(route_check.iter()) + { + let transcript_init_state = transcript_init_state + .expect("SNARK should have an initial state for transcript"); + let init_state_expected = config.flex_gate().select( + &mut ctx, + Existing(transcript_init_state_halo2), + Existing(transcript_init_state_sp1), + Existing(route), + ); + GateInstructions::assert_equal( + config.flex_gate(), + &mut ctx, + Existing(transcript_init_state), + Existing(init_state_expected), + ); + } + + ctx.print_stats(&["protocol check"]); + let barycentric = config.blob_consistency_config.assign_barycentric( &mut ctx, &self.batch_hash.blob_bytes, diff --git a/aggregator/src/tests/aggregation.rs b/aggregator/src/tests/aggregation.rs index ec374c420b..96c5034464 100644 --- a/aggregator/src/tests/aggregation.rs +++ b/aggregator/src/tests/aggregation.rs @@ -209,6 +209,7 @@ fn build_new_batch_circuit( }) .collect_vec() }; + let snark_protocol = real_snarks[0].protocol.clone(); // ========================== // padded chunks @@ -225,6 +226,8 @@ fn build_new_batch_circuit( [real_snarks, padded_snarks].concat().as_ref(), rng, batch_hash, + &snark_protocol, + &snark_protocol, ) .unwrap() } @@ -293,6 +296,8 @@ fn build_batch_circuit_skip_encoding() -> BatchCircuit() -> BatchCircuit Prover<'params> { LayerId::Layer3.id(), LayerId::Layer3.degree(), batch_info, + &self.halo2_protocol, + &self.sp1_protocol, &layer2_snarks, output_dir, )?; diff --git a/prover/src/common/prover/aggregation.rs b/prover/src/common/prover/aggregation.rs index 4d4ca2bc1b..d17e838a94 100644 --- a/prover/src/common/prover/aggregation.rs +++ b/prover/src/common/prover/aggregation.rs @@ -6,34 +6,52 @@ use crate::{ }; use aggregator::{BatchCircuit, BatchHash}; use anyhow::{anyhow, Result}; +use halo2_proofs::halo2curves::bn256::G1Affine; use rand::Rng; use snark_verifier_sdk::Snark; use std::env; impl<'params> Prover<'params> { + #[allow(clippy::too_many_arguments)] pub fn gen_agg_snark( &mut self, id: &str, degree: u32, mut rng: impl Rng + Send, batch_info: BatchHash, + halo2_protocol: &[u8], + sp1_protocol: &[u8], previous_snarks: &[Snark], ) -> Result { env::set_var("AGGREGATION_CONFIG", layer_config_path(id)); - let circuit: BatchCircuit = - BatchCircuit::new(self.params(degree), previous_snarks, &mut rng, batch_info) - .map_err(|err| anyhow!("Failed to construct aggregation circuit: {err:?}"))?; + let halo2_protocol = + serde_json::from_slice::>(halo2_protocol)?; + let sp1_protocol = + serde_json::from_slice::>(sp1_protocol)?; + + let circuit: BatchCircuit = BatchCircuit::new( + self.params(degree), + previous_snarks, + &mut rng, + batch_info, + halo2_protocol, + sp1_protocol, + ) + .map_err(|err| anyhow!("Failed to construct aggregation circuit: {err:?}"))?; self.gen_snark(id, degree, &mut rng, circuit, "gen_agg_snark") } + #[allow(clippy::too_many_arguments)] pub fn load_or_gen_agg_snark( &mut self, name: &str, id: &str, degree: u32, batch_info: BatchHash, + halo2_protocol: &[u8], + sp1_protocol: &[u8], previous_snarks: &[Snark], output_dir: Option<&str>, ) -> Result { @@ -48,7 +66,15 @@ impl<'params> Prover<'params> { Some(snark) => Ok(snark), None => { let rng = gen_rng(); - let result = self.gen_agg_snark(id, degree, rng, batch_info, previous_snarks); + let result = self.gen_agg_snark( + id, + degree, + rng, + batch_info, + halo2_protocol, + sp1_protocol, + previous_snarks, + ); if let (Some(_), Ok(snark)) = (output_dir, &result) { write_snark(&file_path, snark); } diff --git a/prover/src/consts.rs b/prover/src/consts.rs index 978594092d..19b1800ddb 100644 --- a/prover/src/consts.rs +++ b/prover/src/consts.rs @@ -13,8 +13,24 @@ pub fn chunk_vk_filename() -> String { read_env_var("CHUNK_VK_FILENAME", "vk_chunk.vkey".to_string()) } -pub static CHUNK_PROTOCOL_FILENAME: LazyLock = - LazyLock::new(|| read_env_var("CHUNK_PROTOCOL_FILENAME", "chunk.protocol".to_string())); +/// The file descriptor for the JSON serialised SNARK [`protocol`][protocol] that +/// defines the [`CompressionCircuit`][compr_circuit] SNARK that uses halo2-based +/// [`SuperCircuit`][super_circuit]. +/// +/// [protocol]: snark_verifier::Protocol +/// [compr_circuit]: aggregator::CompressionCircuit +/// [super_circuit]: zkevm_circuits::super_circuit::SuperCircuit +pub static FD_HALO2_CHUNK_PROTOCOL: LazyLock = + LazyLock::new(|| read_env_var("HALO2_CHUNK_PROTOCOL", "chunk_halo2.protocol".to_string())); + +/// The file descriptor for the JSON serialised SNARK [`protocol`][protocol] that +/// defines the [`CompressionCircuit`][compr_circuit] SNARK that uses sp1-based +/// STARK that is SNARKified using a halo2-backend. +/// +/// [protocol]: snark_verifier::Protocol +/// [compr_circuit]: aggregator::CompressionCircuit +pub static FD_SP1_CHUNK_PROTOCOL: LazyLock = + LazyLock::new(|| read_env_var("SP1_CHUNK_PROTOCOL", "chunk_sp1.protocol".to_string())); pub static CHUNK_VK_FILENAME: LazyLock = LazyLock::new(chunk_vk_filename); pub static BATCH_VK_FILENAME: LazyLock = LazyLock::new(batch_vk_filename); diff --git a/testool/src/statetest/executor.rs b/testool/src/statetest/executor.rs index 2dcecc8d4c..34dbbc4485 100644 --- a/testool/src/statetest/executor.rs +++ b/testool/src/statetest/executor.rs @@ -644,7 +644,7 @@ pub fn run_test( eth_types::constants::set_env_coinbase(&st.env.current_coinbase); prover::test::chunk_prove( &test_id, - prover::ChunkProvingTask::from(vec![_scroll_trace]), + prover::ChunkProvingTask::new(vec![_scroll_trace], prover::ChunkKind::Halo2), ); }