diff --git a/gkr/src/prover/gkr_square.rs b/gkr/src/prover/gkr_square.rs index 4689421bd..57194d59e 100644 --- a/gkr/src/prover/gkr_square.rs +++ b/gkr/src/prover/gkr_square.rs @@ -19,11 +19,15 @@ pub fn gkr_square_prove>( for _i in 0..circuit.layers.last().unwrap().output_var_num { rz0.push(transcript.generate_challenge_field_element()); } + log::trace!("Initial rz0: {:?}", rz0); let mut r_simd = vec![]; for _i in 0..C::get_field_pack_size().trailing_zeros() { r_simd.push(transcript.generate_challenge_field_element()); } + log::trace!("Initial r_simd: {:?}", r_simd); + + // TODO: MPI support let circuit_output = &circuit.layers.last().unwrap().output_vals; let claimed_v = C::eval_circuit_vals_at_challenge(circuit_output, &rz0, &mut sp.hg_evals); @@ -35,7 +39,7 @@ pub fn gkr_square_prove>( log::trace!("Layer {} proved", i); log::trace!("rz0.0: {:?}", rz0[0]); log::trace!("rz0.1: {:?}", rz0[1]); - log::trace!("rz0.2: {:?}", rz0[2]); + // log::trace!("rz0.2: {:?}", rz0[2]); } end_timer!(timer); diff --git a/gkr/src/tests/gkr_correctness.rs b/gkr/src/tests/gkr_correctness.rs index 77f1980f7..ee7c8580f 100644 --- a/gkr/src/tests/gkr_correctness.rs +++ b/gkr/src/tests/gkr_correctness.rs @@ -307,7 +307,7 @@ fn do_prove_verify(config: Config, circuit: &mut Circuit) { let (mut claimed_v, proof) = prover.prove(circuit); // Verify - // let verifier = Verifier::new(&config); - // let public_input = vec![]; - // assert!(verifier.verify(circuit, &public_input, &mut claimed_v, &proof)) + let verifier = Verifier::new(&config); + let public_input = vec![]; + assert!(verifier.verify(circuit, &public_input, &mut claimed_v, &proof)) } diff --git a/gkr/src/verifier.rs b/gkr/src/verifier.rs index 2be3c0902..c3cbd9f9f 100644 --- a/gkr/src/verifier.rs +++ b/gkr/src/verifier.rs @@ -6,7 +6,7 @@ use std::{ use arith::{Field, FieldSerde}; use ark_std::{end_timer, start_timer}; use circuit::{Circuit, CircuitLayer}; -use config::{Config, FiatShamirHashType, GKRConfig, PolynomialCommitmentType}; +use config::{Config, FiatShamirHashType, GKRConfig, GKRScheme, PolynomialCommitmentType}; use sumcheck::{GKRVerifierHelper, VerifierScratchPad}; use transcript::{ BytesHashTranscript, FieldHashTranscript, Keccak256hasher, MIMCHasher, Proof, SHA256hasher, @@ -17,6 +17,9 @@ use transcript::{ use crate::grind; use crate::RawCommitment; +mod gkr_square; +pub use gkr_square::gkr_square_verify; + #[inline(always)] fn verify_sumcheck_step>( mut proof_reader: impl Read, @@ -41,6 +44,10 @@ fn verify_sumcheck_step>( *claimed_sum = GKRVerifierHelper::degree_2_eval(&ps, r, sp); } else if degree == 3 { *claimed_sum = GKRVerifierHelper::degree_3_eval(&ps, r, sp); + } else if degree == 6 { + *claimed_sum = GKRVerifierHelper::degree_6_eval(&ps, r, sp); + } else { + panic!("unsupported degree"); } verified @@ -287,38 +294,68 @@ impl Verifier { circuit.fill_rnd_coefs(transcript); - let (mut verified, rz0, rz1, r_simd, r_mpi, claimed_v0, claimed_v1) = gkr_verify( - &self.config, - circuit, - public_input, - claimed_v, - transcript, - &mut cursor, - ); - - log::info!("GKR verification: {}", verified); - - match self.config.polynomial_commitment_type { - PolynomialCommitmentType::Raw => { - // for Raw, no need to load from proof - log::trace!("rz0.size() = {}", rz0.len()); - log::trace!("Poly_vals.size() = {}", commitment.poly_vals.len()); - - let v1 = commitment.mpi_verify(&rz0, &r_simd, &r_mpi, claimed_v0); - verified &= v1; - - if rz1.is_some() { - let v2 = commitment.mpi_verify( - rz1.as_ref().unwrap(), - &r_simd, - &r_mpi, - claimed_v1.unwrap(), - ); - verified &= v2; + let verified = match self.config.gkr_scheme { + GKRScheme::Vanilla => { + let (mut verified, rz0, rz1, r_simd, r_mpi, claimed_v0, claimed_v1) = gkr_verify( + &self.config, + circuit, + public_input, + claimed_v, + transcript, + &mut cursor, + ); + + log::info!("GKR verification: {}", verified); + + match self.config.polynomial_commitment_type { + PolynomialCommitmentType::Raw => { + // for Raw, no need to load from proof + log::trace!("rz0.size() = {}", rz0.len()); + log::trace!("Poly_vals.size() = {}", commitment.poly_vals.len()); + + let v1 = commitment.mpi_verify(&rz0, &r_simd, &r_mpi, claimed_v0); + verified &= v1; + + if rz1.is_some() { + let v2 = commitment.mpi_verify( + rz1.as_ref().unwrap(), + &r_simd, + &r_mpi, + claimed_v1.unwrap(), + ); + verified &= v2; + } + } + _ => todo!(), } + verified } - _ => todo!(), - } + GKRScheme::GkrSquare => { + let (mut verified, rz, r_simd, r_mpi, claimed_v) = gkr_square_verify( + &self.config, + circuit, + public_input, + claimed_v, + transcript, + &mut cursor, + ); + + log::info!("GKR verification: {}", verified); + + match self.config.polynomial_commitment_type { + PolynomialCommitmentType::Raw => { + // for Raw, no need to load from proof + log::trace!("rz.size() = {}", rz.len()); + log::trace!("Poly_vals.size() = {}", commitment.poly_vals.len()); + + let v1 = commitment.mpi_verify(&rz, &r_simd, &r_mpi, claimed_v); + verified &= v1; + } + _ => todo!(), + } + verified + } + }; end_timer!(timer); diff --git a/gkr/src/verifier/gkr_square.rs b/gkr/src/verifier/gkr_square.rs new file mode 100644 index 000000000..85bcbd543 --- /dev/null +++ b/gkr/src/verifier/gkr_square.rs @@ -0,0 +1,161 @@ +use super::verify_sumcheck_step; +use arith::{Field, FieldSerde}; +use ark_std::{end_timer, start_timer}; +use circuit::{Circuit, CircuitLayer}; +use config::{Config, GKRConfig}; +use std::{io::Read, vec}; +use sumcheck::{GKRVerifierHelper, VerifierScratchPad}; +use transcript::Transcript; + +#[allow(clippy::type_complexity)] +pub fn gkr_square_verify>( + config: &Config, + circuit: &Circuit, + public_input: &[C::SimdCircuitField], + claimed_v: &C::ChallengeField, + transcript: &mut T, + mut proof_reader: impl Read, +) -> ( + bool, + Vec, + Vec, + Vec, + C::ChallengeField, +) { + let timer = start_timer!(|| "gkr verify"); + let mut sp = VerifierScratchPad::::new(config, circuit); + + let layer_num = circuit.layers.len(); + let mut rz = vec![]; + let mut r_simd = vec![]; + let mut r_mpi = vec![]; + + for _ in 0..circuit.layers.last().unwrap().output_var_num { + rz.push(transcript.generate_challenge_field_element()); + } + log::trace!("rz {:?}", rz); + + for _ in 0..C::get_field_pack_size().trailing_zeros() { + r_simd.push(transcript.generate_challenge_field_element()); + } + log::trace!("r_simd {:?}", r_simd); + + // TODO: MPI support + assert_eq!( + config.mpi_config.world_size().trailing_zeros(), + 0, + "MPI not supported yet" + ); + // for _ in 0..config.mpi_config.world_size().trailing_zeros() { + // r_mpi.push(transcript.generate_challenge_field_element()); + // } + + let mut verified = true; + let mut current_claim = *claimed_v; + for i in (0..layer_num).rev() { + let cur_verified; + (cur_verified, rz, r_simd, r_mpi, current_claim) = sumcheck_verify_gkr_square_layer( + config, + &circuit.layers[i], + public_input, + &rz, + &r_simd, + &r_mpi, + current_claim, + &mut proof_reader, + transcript, + &mut sp, + i == layer_num - 1, + ); + verified &= cur_verified; + } + end_timer!(timer); + (verified, rz, r_simd, r_mpi, current_claim) +} + +#[allow(clippy::too_many_arguments)] +#[allow(clippy::type_complexity)] +#[allow(clippy::unnecessary_unwrap)] +fn sumcheck_verify_gkr_square_layer>( + config: &Config, + layer: &CircuitLayer, + public_input: &[C::SimdCircuitField], + rz: &[C::ChallengeField], + r_simd: &Vec, + r_mpi: &Vec, + current_claim: C::ChallengeField, + mut proof_reader: impl Read, + transcript: &mut T, + sp: &mut VerifierScratchPad, + is_output_layer: bool, +) -> ( + bool, + Vec, + Vec, + Vec, + C::ChallengeField, +) { + // GKR2 with Power5 gate has degree 6 polynomial + let degree = 6; + + GKRVerifierHelper::prepare_layer(layer, &None, rz, &None, r_simd, r_mpi, sp, is_output_layer); + + let var_num = layer.input_var_num; + let mut sum = current_claim; + sum -= GKRVerifierHelper::eval_cst(&layer.const_, public_input, sp); + + let mut rx = vec![]; + let mut r_simd_var = vec![]; + let mut r_mpi_var = vec![]; + let mut verified = true; + + for i_var in 0..var_num { + verified &= verify_sumcheck_step::( + &mut proof_reader, + degree, + transcript, + &mut sum, + &mut rx, + sp, + ); + log::trace!("x {} var, verified? {}", i_var, verified); + } + GKRVerifierHelper::set_rx(&rx, sp); + + for i_var in 0..C::get_field_pack_size().trailing_zeros() { + verified &= verify_sumcheck_step::( + &mut proof_reader, + degree, + transcript, + &mut sum, + &mut r_simd_var, + sp, + ); + log::trace!("simd {} var, verified? {}", i_var, verified); + } + GKRVerifierHelper::set_r_simd_xy(&r_simd_var, sp); + + // TODO: nontrivial MPI support + for _i_var in 0..config.mpi_config.world_size().trailing_zeros() { + verified &= verify_sumcheck_step::( + &mut proof_reader, + 3, + transcript, + &mut sum, + &mut r_mpi_var, + sp, + ); + // println!("{} mpi var, verified? {}", _i_var, verified); + } + GKRVerifierHelper::set_r_mpi_xy(&r_mpi_var, sp); + + let v_claim = C::ChallengeField::deserialize_from(&mut proof_reader).unwrap(); + + sum -= v_claim * GKRVerifierHelper::eval_pow_1(&layer.uni, sp) + + v_claim.exp(5) * GKRVerifierHelper::eval_pow_5(&layer.uni, sp); + transcript.append_field_element(&v_claim); + + verified &= sum == C::ChallengeField::ZERO; + + (verified, rx, r_simd_var, r_mpi_var, v_claim) +} diff --git a/sumcheck/src/scratch_pad.rs b/sumcheck/src/scratch_pad.rs index 71c8f6e34..facfb575c 100644 --- a/sumcheck/src/scratch_pad.rs +++ b/sumcheck/src/scratch_pad.rs @@ -91,6 +91,9 @@ pub struct VerifierScratchPad { pub gf2_deg2_eval_coef: C::ChallengeField, // 1 / x(x - 1) pub deg3_eval_at: [C::ChallengeField; 4], pub deg3_lag_denoms_inv: [C::ChallengeField; 4], + // ====== for deg6 eval ====== + pub deg6_eval_at: [C::ChallengeField; 7], + pub deg6_lag_denoms_inv: [C::ChallengeField; 7], } impl VerifierScratchPad { @@ -143,6 +146,32 @@ impl VerifierScratchPad { deg3_lag_denoms_inv[i] = denominator.inv().unwrap(); } + let deg6_eval_at = if C::FIELD_TYPE == FieldType::GF2 { + panic!("GF2 not supported yet"); + } else { + [ + C::ChallengeField::ZERO, + C::ChallengeField::ONE, + C::ChallengeField::from(2), + C::ChallengeField::from(3), + C::ChallengeField::from(4), + C::ChallengeField::from(5), + C::ChallengeField::from(6), + ] + }; + + let mut deg6_lag_denoms_inv = [C::ChallengeField::ZERO; 7]; + for i in 0..7 { + let mut denominator = C::ChallengeField::ONE; + for j in 0..7 { + if j == i { + continue; + } + denominator *= deg6_eval_at[i] - deg6_eval_at[j]; + } + deg6_lag_denoms_inv[i] = denominator.inv().unwrap(); + } + Self { eq_evals_at_rz0: vec![C::ChallengeField::zero(); max_io_size], eq_evals_at_r_simd: vec![C::ChallengeField::zero(); simd_size], @@ -162,6 +191,8 @@ impl VerifierScratchPad { gf2_deg2_eval_coef, deg3_eval_at, deg3_lag_denoms_inv, + deg6_eval_at, + deg6_lag_denoms_inv, } } } diff --git a/sumcheck/src/verifier_helper.rs b/sumcheck/src/verifier_helper.rs index 7837ec89e..a103765d1 100644 --- a/sumcheck/src/verifier_helper.rs +++ b/sumcheck/src/verifier_helper.rs @@ -1,5 +1,5 @@ use arith::{ExtensionField, Field}; -use circuit::{CircuitLayer, CoefType, GateAdd, GateConst, GateMul}; +use circuit::{CircuitLayer, CoefType, GateAdd, GateConst, GateMul, GateUni}; use config::{FieldType, GKRConfig}; use polynomials::EqPolynomial; @@ -140,6 +140,39 @@ impl GKRVerifierHelper { v * sp.eq_r_simd_r_simd_xy * sp.eq_r_mpi_r_mpi_xy } + /// GKR2 equivalent of `eval_add`. (Note that GKR2 uses pow1 gates instead of add gates) + #[inline(always)] + pub fn eval_pow_1( + gates: &[GateUni], + sp: &VerifierScratchPad, + ) -> C::ChallengeField { + let mut v = C::ChallengeField::zero(); + for gate in gates { + // Gates of type 12346 represent an add gate + if gate.gate_type == 12346 { + v += sp.eq_evals_at_rz0[gate.o_id] + * C::challenge_mul_circuit_field(&sp.eq_evals_at_rx[gate.i_ids[0]], &gate.coef); + } + } + v * sp.eq_r_simd_r_simd_xy + } + + #[inline(always)] + pub fn eval_pow_5( + gates: &[GateUni], + sp: &VerifierScratchPad, + ) -> C::ChallengeField { + let mut v = C::ChallengeField::zero(); + for gate in gates { + // Gates of type 12345 represent a pow5 gate + if gate.gate_type == 12345 { + v += sp.eq_evals_at_rz0[gate.o_id] + * C::challenge_mul_circuit_field(&sp.eq_evals_at_rx[gate.i_ids[0]], &gate.coef); + } + } + v * sp.eq_r_simd_r_simd_xy + } + #[inline(always)] pub fn set_rx(rx: &[C::ChallengeField], sp: &mut VerifierScratchPad) { EqPolynomial::::eq_eval_at( @@ -219,13 +252,26 @@ impl GKRVerifierHelper { Self::lag_eval(vals, x, sp) } + #[inline(always)] + pub fn degree_6_eval( + vals: &[C::ChallengeField], + x: C::ChallengeField, + sp: &VerifierScratchPad, + ) -> C::ChallengeField { + Self::lag_eval(vals, x, sp) + } + #[inline(always)] fn lag_eval( vals: &[C::ChallengeField], x: C::ChallengeField, sp: &VerifierScratchPad, ) -> C::ChallengeField { - assert_eq!(sp.deg3_eval_at.len(), vals.len()); + let (evals, lag_denoms_inv) = match vals.len() { + 4 => (sp.deg3_eval_at.to_vec(), sp.deg3_lag_denoms_inv.to_vec()), + 7 => (sp.deg6_eval_at.to_vec(), sp.deg6_lag_denoms_inv.to_vec()), + _ => panic!("unsupported degree"), + }; let mut v = C::ChallengeField::ZERO; for i in 0..vals.len() { @@ -234,9 +280,9 @@ impl GKRVerifierHelper { if j == i { continue; } - numerator *= x - sp.deg3_eval_at[j]; + numerator *= x - evals[j]; } - v += numerator * sp.deg3_lag_denoms_inv[i] * vals[i]; + v += numerator * lag_denoms_inv[i] * vals[i]; } v }