@@ -289,35 +289,52 @@ pub fn gkr_square_test_circuit<C: GKRConfig>() -> Circuit<C> {
289289 circuit
290290}
291291
292+ // CMD: RUSTFLAGS="-C target-cpu=native" mpiexec -n 4 cargo test --package gkr --lib -- tests::gkr_correctness::gkr_square_correctness_test --exact --show-output
292293#[ test]
293- fn gkr_square_correctness ( ) {
294- type GkrConfigType = config:: M31ExtConfigSha2 ;
295-
294+ fn gkr_square_correctness_test ( ) {
296295 env_logger:: init ( ) ;
296+ let mpi_config = MPIConfig :: new ( ) ;
297+ assert ! ( gkr_square_correctness:: <config:: M31ExtConfigSha2 >(
298+ mpi_config. clone( )
299+ ) ) ;
300+ assert ! ( gkr_square_correctness:: <config:: M31ExtConfigKeccak >(
301+ mpi_config
302+ ) ) ;
303+ MPIConfig :: finalize ( ) ;
304+ }
297305
298- let mut circuit = gkr_square_test_circuit :: < GkrConfigType > ( ) ;
306+ fn gkr_square_correctness < C : GKRConfig > ( mpi_config : MPIConfig ) -> bool {
307+ let config = Config :: < C > :: new ( GKRScheme :: GkrSquare , mpi_config) ;
308+
309+ let mut circuit = gkr_square_test_circuit :: < C > ( ) ;
299310 // Set input layers with N_2_0 = 3, N_2_1 = 5, N_2_2 = 7,
300- // and N_2_3 varying from 0 to 15
301- let final_vals = ( 0 ..16 ) . map ( |x| x. into ( ) ) . collect :: < Vec < _ > > ( ) ;
302- let final_vals = <GkrConfigType as GKRConfig >:: SimdCircuitField :: pack ( & final_vals) ;
311+ // and N_2_3 varying from 0 to SIMD packing size
312+ let mut final_vals = ( 0 ..16 ) . map ( |x| x. into ( ) ) . collect :: < Vec < _ > > ( ) ;
313+ // Add variety for MPI participants
314+ final_vals[ 0 ] += C :: CircuitField :: from ( config. mpi_config . world_rank as u32 ) ;
315+ let final_vals = C :: SimdCircuitField :: pack ( & final_vals) ;
303316 circuit. layers [ 0 ] . input_vals = vec ! [ 2 . into( ) , 3 . into( ) , 5 . into( ) , final_vals] ;
304317 // Set public input PI[0] = 13
305318 circuit. public_input = vec ! [ 13 . into( ) ] ;
306319
307- let config = Config :: < GkrConfigType > :: new ( GKRScheme :: GkrSquare , MPIConfig :: default ( ) ) ;
308- do_prove_verify ( config , & mut circuit ) ;
320+ let result = do_prove_verify ( config , & mut circuit ) ;
321+ result
309322}
310323
311- fn do_prove_verify < C : GKRConfig > ( config : Config < C > , circuit : & mut Circuit < C > ) {
324+ fn do_prove_verify < C : GKRConfig > ( config : Config < C > , circuit : & mut Circuit < C > ) -> bool {
312325 circuit. evaluate ( ) ;
313326
314327 // Prove
315328 let mut prover = Prover :: new ( & config) ;
316329 prover. prepare_mem ( & circuit) ;
317330 let ( mut claimed_v, proof) = prover. prove ( circuit) ;
318331
319- // Verify
320- let verifier = Verifier :: new ( & config) ;
321- let public_input = circuit. public_input . clone ( ) ;
322- assert ! ( verifier. verify( circuit, & public_input, & mut claimed_v, & proof) )
332+ // Verify if root process
333+ if config. mpi_config . is_root ( ) {
334+ let verifier = Verifier :: new ( & config) ;
335+ let public_input = circuit. public_input . clone ( ) ;
336+ return verifier. verify ( circuit, & public_input, & mut claimed_v, & proof) ;
337+ }
338+ // Non-root processes return true
339+ return true ;
323340}
0 commit comments