Skip to content

Commit

Permalink
wip: GKR2 verifier
Browse files Browse the repository at this point in the history
  • Loading branch information
enpsi committed Nov 5, 2024
1 parent 01981fc commit 362916b
Show file tree
Hide file tree
Showing 6 changed files with 318 additions and 39 deletions.
6 changes: 5 additions & 1 deletion gkr/src/prover/gkr_square.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,15 @@ pub fn gkr_square_prove<C: GKRConfig, T: Transcript<C::ChallengeField>>(
for _i in 0..circuit.layers.last().unwrap().output_var_num {
rz0.push(transcript.generate_challenge_field_element());
}
log::trace!("Initial rz0: {:?}", rz0);

let mut r_simd = vec![];
for _i in 0..C::get_field_pack_size().trailing_zeros() {
r_simd.push(transcript.generate_challenge_field_element());
}
log::trace!("Initial r_simd: {:?}", r_simd);

// TODO: MPI support

let circuit_output = &circuit.layers.last().unwrap().output_vals;
let claimed_v = C::eval_circuit_vals_at_challenge(circuit_output, &rz0, &mut sp.hg_evals);
Expand All @@ -35,7 +39,7 @@ pub fn gkr_square_prove<C: GKRConfig, T: Transcript<C::ChallengeField>>(
log::trace!("Layer {} proved", i);
log::trace!("rz0.0: {:?}", rz0[0]);
log::trace!("rz0.1: {:?}", rz0[1]);
log::trace!("rz0.2: {:?}", rz0[2]);
// log::trace!("rz0.2: {:?}", rz0[2]);
}

end_timer!(timer);
Expand Down
6 changes: 3 additions & 3 deletions gkr/src/tests/gkr_correctness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ fn do_prove_verify<C: GKRConfig>(config: Config<C>, circuit: &mut Circuit<C>) {
let (mut claimed_v, proof) = prover.prove(circuit);

// Verify
// let verifier = Verifier::new(&config);
// let public_input = vec![];
// assert!(verifier.verify(circuit, &public_input, &mut claimed_v, &proof))
let verifier = Verifier::new(&config);
let public_input = vec![];
assert!(verifier.verify(circuit, &public_input, &mut claimed_v, &proof))
}
99 changes: 68 additions & 31 deletions gkr/src/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::{
use arith::{Field, FieldSerde};
use ark_std::{end_timer, start_timer};
use circuit::{Circuit, CircuitLayer};
use config::{Config, FiatShamirHashType, GKRConfig, PolynomialCommitmentType};
use config::{Config, FiatShamirHashType, GKRConfig, GKRScheme, PolynomialCommitmentType};
use sumcheck::{GKRVerifierHelper, VerifierScratchPad};
use transcript::{
BytesHashTranscript, FieldHashTranscript, Keccak256hasher, MIMCHasher, Proof, SHA256hasher,
Expand All @@ -17,6 +17,9 @@ use transcript::{
use crate::grind;
use crate::RawCommitment;

mod gkr_square;
pub use gkr_square::gkr_square_verify;

#[inline(always)]
fn verify_sumcheck_step<C: GKRConfig, T: Transcript<C::ChallengeField>>(
mut proof_reader: impl Read,
Expand All @@ -41,6 +44,10 @@ fn verify_sumcheck_step<C: GKRConfig, T: Transcript<C::ChallengeField>>(
*claimed_sum = GKRVerifierHelper::degree_2_eval(&ps, r, sp);
} else if degree == 3 {
*claimed_sum = GKRVerifierHelper::degree_3_eval(&ps, r, sp);
} else if degree == 6 {
*claimed_sum = GKRVerifierHelper::degree_6_eval(&ps, r, sp);
} else {
panic!("unsupported degree");
}

verified
Expand Down Expand Up @@ -287,38 +294,68 @@ impl<C: GKRConfig> Verifier<C> {

circuit.fill_rnd_coefs(transcript);

let (mut verified, rz0, rz1, r_simd, r_mpi, claimed_v0, claimed_v1) = gkr_verify(
&self.config,
circuit,
public_input,
claimed_v,
transcript,
&mut cursor,
);

log::info!("GKR verification: {}", verified);

match self.config.polynomial_commitment_type {
PolynomialCommitmentType::Raw => {
// for Raw, no need to load from proof
log::trace!("rz0.size() = {}", rz0.len());
log::trace!("Poly_vals.size() = {}", commitment.poly_vals.len());

let v1 = commitment.mpi_verify(&rz0, &r_simd, &r_mpi, claimed_v0);
verified &= v1;

if rz1.is_some() {
let v2 = commitment.mpi_verify(
rz1.as_ref().unwrap(),
&r_simd,
&r_mpi,
claimed_v1.unwrap(),
);
verified &= v2;
let verified = match self.config.gkr_scheme {
GKRScheme::Vanilla => {
let (mut verified, rz0, rz1, r_simd, r_mpi, claimed_v0, claimed_v1) = gkr_verify(
&self.config,
circuit,
public_input,
claimed_v,
transcript,
&mut cursor,
);

log::info!("GKR verification: {}", verified);

match self.config.polynomial_commitment_type {
PolynomialCommitmentType::Raw => {
// for Raw, no need to load from proof
log::trace!("rz0.size() = {}", rz0.len());
log::trace!("Poly_vals.size() = {}", commitment.poly_vals.len());

let v1 = commitment.mpi_verify(&rz0, &r_simd, &r_mpi, claimed_v0);
verified &= v1;

if rz1.is_some() {
let v2 = commitment.mpi_verify(
rz1.as_ref().unwrap(),
&r_simd,
&r_mpi,
claimed_v1.unwrap(),
);
verified &= v2;
}
}
_ => todo!(),
}
verified
}
_ => todo!(),
}
GKRScheme::GkrSquare => {
let (mut verified, rz, r_simd, r_mpi, claimed_v) = gkr_square_verify(
&self.config,
circuit,
public_input,
claimed_v,
transcript,
&mut cursor,
);

log::info!("GKR verification: {}", verified);

match self.config.polynomial_commitment_type {
PolynomialCommitmentType::Raw => {
// for Raw, no need to load from proof
log::trace!("rz.size() = {}", rz.len());
log::trace!("Poly_vals.size() = {}", commitment.poly_vals.len());

let v1 = commitment.mpi_verify(&rz, &r_simd, &r_mpi, claimed_v);
verified &= v1;
}
_ => todo!(),
}
verified
}
};

end_timer!(timer);

Expand Down
161 changes: 161 additions & 0 deletions gkr/src/verifier/gkr_square.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
use super::verify_sumcheck_step;
use arith::{Field, FieldSerde};
use ark_std::{end_timer, start_timer};
use circuit::{Circuit, CircuitLayer};
use config::{Config, GKRConfig};
use std::{io::Read, vec};
use sumcheck::{GKRVerifierHelper, VerifierScratchPad};
use transcript::Transcript;

#[allow(clippy::type_complexity)]
pub fn gkr_square_verify<C: GKRConfig, T: Transcript<C::ChallengeField>>(
config: &Config<C>,
circuit: &Circuit<C>,
public_input: &[C::SimdCircuitField],
claimed_v: &C::ChallengeField,
transcript: &mut T,
mut proof_reader: impl Read,
) -> (
bool,
Vec<C::ChallengeField>,
Vec<C::ChallengeField>,
Vec<C::ChallengeField>,
C::ChallengeField,
) {
let timer = start_timer!(|| "gkr verify");
let mut sp = VerifierScratchPad::<C>::new(config, circuit);

let layer_num = circuit.layers.len();
let mut rz = vec![];
let mut r_simd = vec![];
let mut r_mpi = vec![];

for _ in 0..circuit.layers.last().unwrap().output_var_num {
rz.push(transcript.generate_challenge_field_element());
}
log::trace!("rz {:?}", rz);

for _ in 0..C::get_field_pack_size().trailing_zeros() {
r_simd.push(transcript.generate_challenge_field_element());
}
log::trace!("r_simd {:?}", r_simd);

// TODO: MPI support
assert_eq!(
config.mpi_config.world_size().trailing_zeros(),
0,
"MPI not supported yet"
);
// for _ in 0..config.mpi_config.world_size().trailing_zeros() {
// r_mpi.push(transcript.generate_challenge_field_element());
// }

let mut verified = true;
let mut current_claim = *claimed_v;
for i in (0..layer_num).rev() {
let cur_verified;
(cur_verified, rz, r_simd, r_mpi, current_claim) = sumcheck_verify_gkr_square_layer(
config,
&circuit.layers[i],
public_input,
&rz,
&r_simd,
&r_mpi,
current_claim,
&mut proof_reader,
transcript,
&mut sp,
i == layer_num - 1,
);
verified &= cur_verified;
}
end_timer!(timer);
(verified, rz, r_simd, r_mpi, current_claim)
}

#[allow(clippy::too_many_arguments)]
#[allow(clippy::type_complexity)]
#[allow(clippy::unnecessary_unwrap)]
fn sumcheck_verify_gkr_square_layer<C: GKRConfig, T: Transcript<C::ChallengeField>>(
config: &Config<C>,
layer: &CircuitLayer<C>,
public_input: &[C::SimdCircuitField],
rz: &[C::ChallengeField],
r_simd: &Vec<C::ChallengeField>,
r_mpi: &Vec<C::ChallengeField>,
current_claim: C::ChallengeField,
mut proof_reader: impl Read,
transcript: &mut T,
sp: &mut VerifierScratchPad<C>,
is_output_layer: bool,
) -> (
bool,
Vec<C::ChallengeField>,
Vec<C::ChallengeField>,
Vec<C::ChallengeField>,
C::ChallengeField,
) {
// GKR2 with Power5 gate has degree 6 polynomial
let degree = 6;

GKRVerifierHelper::prepare_layer(layer, &None, rz, &None, r_simd, r_mpi, sp, is_output_layer);

let var_num = layer.input_var_num;
let mut sum = current_claim;
sum -= GKRVerifierHelper::eval_cst(&layer.const_, public_input, sp);

let mut rx = vec![];
let mut r_simd_var = vec![];
let mut r_mpi_var = vec![];
let mut verified = true;

for i_var in 0..var_num {
verified &= verify_sumcheck_step::<C, T>(
&mut proof_reader,
degree,
transcript,
&mut sum,
&mut rx,
sp,
);
log::trace!("x {} var, verified? {}", i_var, verified);
}
GKRVerifierHelper::set_rx(&rx, sp);

for i_var in 0..C::get_field_pack_size().trailing_zeros() {
verified &= verify_sumcheck_step::<C, T>(
&mut proof_reader,
degree,
transcript,
&mut sum,
&mut r_simd_var,
sp,
);
log::trace!("simd {} var, verified? {}", i_var, verified);
}
GKRVerifierHelper::set_r_simd_xy(&r_simd_var, sp);

// TODO: nontrivial MPI support
for _i_var in 0..config.mpi_config.world_size().trailing_zeros() {
verified &= verify_sumcheck_step::<C, T>(
&mut proof_reader,
3,
transcript,
&mut sum,
&mut r_mpi_var,
sp,
);
// println!("{} mpi var, verified? {}", _i_var, verified);
}
GKRVerifierHelper::set_r_mpi_xy(&r_mpi_var, sp);

let v_claim = C::ChallengeField::deserialize_from(&mut proof_reader).unwrap();

sum -= v_claim * GKRVerifierHelper::eval_pow_1(&layer.uni, sp)
+ v_claim.exp(5) * GKRVerifierHelper::eval_pow_5(&layer.uni, sp);
transcript.append_field_element(&v_claim);

verified &= sum == C::ChallengeField::ZERO;

(verified, rx, r_simd_var, r_mpi_var, v_claim)
}
31 changes: 31 additions & 0 deletions sumcheck/src/scratch_pad.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ pub struct VerifierScratchPad<C: GKRConfig> {
pub gf2_deg2_eval_coef: C::ChallengeField, // 1 / x(x - 1)
pub deg3_eval_at: [C::ChallengeField; 4],
pub deg3_lag_denoms_inv: [C::ChallengeField; 4],
// ====== for deg6 eval ======
pub deg6_eval_at: [C::ChallengeField; 7],
pub deg6_lag_denoms_inv: [C::ChallengeField; 7],
}

impl<C: GKRConfig> VerifierScratchPad<C> {
Expand Down Expand Up @@ -143,6 +146,32 @@ impl<C: GKRConfig> VerifierScratchPad<C> {
deg3_lag_denoms_inv[i] = denominator.inv().unwrap();
}

let deg6_eval_at = if C::FIELD_TYPE == FieldType::GF2 {
panic!("GF2 not supported yet");
} else {
[
C::ChallengeField::ZERO,
C::ChallengeField::ONE,
C::ChallengeField::from(2),
C::ChallengeField::from(3),
C::ChallengeField::from(4),
C::ChallengeField::from(5),
C::ChallengeField::from(6),
]
};

let mut deg6_lag_denoms_inv = [C::ChallengeField::ZERO; 7];
for i in 0..7 {
let mut denominator = C::ChallengeField::ONE;
for j in 0..7 {
if j == i {
continue;
}
denominator *= deg6_eval_at[i] - deg6_eval_at[j];
}
deg6_lag_denoms_inv[i] = denominator.inv().unwrap();
}

Self {
eq_evals_at_rz0: vec![C::ChallengeField::zero(); max_io_size],
eq_evals_at_r_simd: vec![C::ChallengeField::zero(); simd_size],
Expand All @@ -162,6 +191,8 @@ impl<C: GKRConfig> VerifierScratchPad<C> {
gf2_deg2_eval_coef,
deg3_eval_at,
deg3_lag_denoms_inv,
deg6_eval_at,
deg6_lag_denoms_inv,
}
}
}
Loading

0 comments on commit 362916b

Please sign in to comment.