Skip to content

Commit 0c68a2d

Browse files
committed
chore: update gkr2 correctness test (M31 only)
1 parent a5dbff5 commit 0c68a2d

File tree

1 file changed

+31
-14
lines changed

1 file changed

+31
-14
lines changed

gkr/src/tests/gkr_correctness.rs

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -289,35 +289,52 @@ pub fn gkr_square_test_circuit<C: GKRConfig>() -> Circuit<C> {
289289
circuit
290290
}
291291

292+
// CMD: RUSTFLAGS="-C target-cpu=native" mpiexec -n 4 cargo test --package gkr --lib -- tests::gkr_correctness::gkr_square_correctness_test --exact --show-output
292293
#[test]
293-
fn gkr_square_correctness() {
294-
type GkrConfigType = config::M31ExtConfigSha2;
295-
294+
fn gkr_square_correctness_test() {
296295
env_logger::init();
296+
let mpi_config = MPIConfig::new();
297+
assert!(gkr_square_correctness::<config::M31ExtConfigSha2>(
298+
mpi_config.clone()
299+
));
300+
assert!(gkr_square_correctness::<config::M31ExtConfigKeccak>(
301+
mpi_config
302+
));
303+
MPIConfig::finalize();
304+
}
297305

298-
let mut circuit = gkr_square_test_circuit::<GkrConfigType>();
306+
fn gkr_square_correctness<C: GKRConfig>(mpi_config: MPIConfig) -> bool {
307+
let config = Config::<C>::new(GKRScheme::GkrSquare, mpi_config);
308+
309+
let mut circuit = gkr_square_test_circuit::<C>();
299310
// Set input layers with N_2_0 = 3, N_2_1 = 5, N_2_2 = 7,
300-
// and N_2_3 varying from 0 to 15
301-
let final_vals = (0..16).map(|x| x.into()).collect::<Vec<_>>();
302-
let final_vals = <GkrConfigType as GKRConfig>::SimdCircuitField::pack(&final_vals);
311+
// and N_2_3 varying from 0 to SIMD packing size
312+
let mut final_vals = (0..16).map(|x| x.into()).collect::<Vec<_>>();
313+
// Add variety for MPI participants
314+
final_vals[0] += C::CircuitField::from(config.mpi_config.world_rank as u32);
315+
let final_vals = C::SimdCircuitField::pack(&final_vals);
303316
circuit.layers[0].input_vals = vec![2.into(), 3.into(), 5.into(), final_vals];
304317
// Set public input PI[0] = 13
305318
circuit.public_input = vec![13.into()];
306319

307-
let config = Config::<GkrConfigType>::new(GKRScheme::GkrSquare, MPIConfig::default());
308-
do_prove_verify(config, &mut circuit);
320+
let result = do_prove_verify(config, &mut circuit);
321+
result
309322
}
310323

311-
fn do_prove_verify<C: GKRConfig>(config: Config<C>, circuit: &mut Circuit<C>) {
324+
fn do_prove_verify<C: GKRConfig>(config: Config<C>, circuit: &mut Circuit<C>) -> bool {
312325
circuit.evaluate();
313326

314327
// Prove
315328
let mut prover = Prover::new(&config);
316329
prover.prepare_mem(&circuit);
317330
let (mut claimed_v, proof) = prover.prove(circuit);
318331

319-
// Verify
320-
let verifier = Verifier::new(&config);
321-
let public_input = circuit.public_input.clone();
322-
assert!(verifier.verify(circuit, &public_input, &mut claimed_v, &proof))
332+
// Verify if root process
333+
if config.mpi_config.is_root() {
334+
let verifier = Verifier::new(&config);
335+
let public_input = circuit.public_input.clone();
336+
return verifier.verify(circuit, &public_input, &mut claimed_v, &proof);
337+
}
338+
// Non-root processes return true
339+
return true;
323340
}

0 commit comments

Comments
 (0)