Skip to content

Commit 59c8909

Browse files
committed
call it a stop here on the cache friendly transpose
1 parent d529716 commit 59c8909

File tree

2 files changed

+118
-32
lines changed

2 files changed

+118
-32
lines changed

gkr/src/poly_commit/orion.rs

+112-26
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
//! Includes implementation for Orion Expander-Code.
33
44
use arith::Field;
5+
use polynomials::MultiLinearPoly;
56
use rand::seq::index;
67
use thiserror::Error;
78

@@ -125,6 +126,8 @@ impl OrionCodeParameter {
125126
}
126127
}
127128

129+
// TODO: fix a set of Orion code parameters for message length ranging 2^5 - 2^15
130+
128131
#[derive(Clone)]
129132
pub struct OrionExpanderGraphPositioned {
130133
pub graph: OrionExpanderGraph,
@@ -244,24 +247,56 @@ impl OrionCode {
244247

245248
#[inline(always)]
246249
pub fn encode<F: Field>(&self, msg: &[F]) -> OrionResult<OrionCodeword<F>> {
247-
if msg.len() != self.msg_len() {
250+
let mut codeword = vec![F::ZERO; self.code_len()];
251+
self.encode_in_place(msg, &mut codeword)?;
252+
Ok(codeword)
253+
}
254+
255+
#[inline(always)]
256+
pub fn encode_in_place<F: Field>(&self, msg: &[F], buffer: &mut [F]) -> OrionResult<()> {
257+
if msg.len() != self.msg_len() || buffer.len() != self.code_len() {
248258
return Err(OrionPCSError::ParameterUnmatchError);
249259
}
250260

251-
let mut codeword = vec![F::ZERO; self.code_len()];
252-
codeword[..self.msg_len()].copy_from_slice(msg);
253-
261+
buffer[..self.msg_len()].copy_from_slice(msg);
254262
let mut scratch = vec![F::ZERO; self.code_len()];
255263

256264
self.g0s
257265
.iter()
258266
.chain(self.g1s.iter())
259-
.try_for_each(|g| g.expander_mul(&mut codeword, &mut scratch))?;
260-
261-
Ok(codeword)
267+
.try_for_each(|g| g.expander_mul(buffer, &mut scratch))
262268
}
263269
}
264270

271+
/****************************************
272+
* IMPLEMENTATIONS FOR MATRIX TRANSPOSE *
273+
****************************************/
274+
275+
pub(crate) const fn cache_batch_size<F: Sized>() -> usize {
276+
const CACHE_SIZE: usize = 1 << 16;
277+
CACHE_SIZE / size_of::<F>()
278+
}
279+
280+
// NOTE we assume that the matrix has sides of length po2
281+
pub(crate) fn transpose_in_place<F: Field>(mat: &mut [F], scratch: &mut [F], row_num: usize) {
282+
let col_num = mat.len() / row_num;
283+
let batch_size = cache_batch_size::<F>();
284+
285+
mat.chunks(batch_size)
286+
.enumerate()
287+
.for_each(|(i, ith_batch)| {
288+
let src_starts = i * batch_size;
289+
let dst_starts = (src_starts / col_num) + (src_starts % col_num) * row_num;
290+
291+
ith_batch
292+
.iter()
293+
.enumerate()
294+
.for_each(|(j, &elem_j)| scratch[dst_starts + j * row_num] = elem_j)
295+
});
296+
297+
mat.copy_from_slice(scratch);
298+
}
299+
265300
/**********************************************************
266301
* IMPLEMENTATIONS FOR ORION POLYNOMIAL COMMITMENT SCHEME *
267302
**********************************************************/
@@ -273,45 +308,81 @@ pub struct OrionPCSImpl {
273308
pub code_instance: OrionCode,
274309
}
275310

276-
// TODO use interleaved codeword and commit against interleaved alphabets
277-
#[allow(unused)]
278-
type InterleavedOrionCodeword<F> = Vec<OrionCodeword<F>>;
279-
280311
impl OrionPCSImpl {
281-
// TODO: check num_variables ~ code_params.msg_len()
282-
pub fn new(num_variables: usize, code_instance: OrionCode) -> Self {
312+
fn row_col_from_variables(num_variables: usize) -> (usize, usize) {
313+
let poly_variables: usize = num_variables;
314+
315+
// NOTE(Hang): rounding up here in halving the poly variable num
316+
// up to discussion if we want to half by round down
317+
let row_num: usize = 1 << ((poly_variables + 1) / 2);
318+
let msg_size: usize = (1 << poly_variables) / row_num;
319+
320+
(row_num, msg_size)
321+
}
322+
323+
pub fn new(num_variables: usize, code_instance: OrionCode) -> OrionResult<Self> {
324+
let (_, msg_size) = Self::row_col_from_variables(num_variables);
325+
if msg_size != code_instance.msg_len() {
326+
return Err(OrionPCSError::ParameterUnmatchError);
327+
}
328+
283329
// NOTE: we just move the instance of code,
284330
// don't think the instance of expander code will be used elsewhere
285-
Self {
331+
Ok(Self {
286332
num_variables,
287333
code_instance,
288-
}
334+
})
289335
}
290336

291-
// TODO: check num_variables ~ code_params.msg_len()
292337
pub fn from_random(
293338
num_variables: usize,
339+
// TODO: should be removed with a precomputed list of params
294340
code_params: OrionCodeParameter,
295341
mut rng: impl rand::RngCore,
296-
) -> Self {
297-
Self {
342+
) -> OrionResult<Self> {
343+
let (_, msg_size) = Self::row_col_from_variables(num_variables);
344+
if msg_size != code_params.input_message_len {
345+
return Err(OrionPCSError::ParameterUnmatchError);
346+
}
347+
348+
Ok(Self {
298349
num_variables,
299350
code_instance: OrionCode::new(code_params, &mut rng),
300-
}
351+
})
301352
}
302353

303354
// TODO query complexity for how many queries one need for interleaved codeword
304-
pub fn query_complexity(#[allow(unused)] soundness_bits: usize) -> usize {
355+
pub fn query_complexity(&self, #[allow(unused)] soundness_bits: usize) -> usize {
305356
todo!()
306357
}
307358

308-
// TODO multilinear polynomial
309-
// TODO write to matrix, encode each row (k x k matrix)
310-
// TODO need a merkle tree to commit each column (k x n matrix)
311-
// - TODO need a cache friendly transpose
312-
// TODO need a merkle tree to commit against all merkle tree roots
313359
// TODO commitment with data
314-
pub fn commit() {
360+
pub fn commit<F: Field>(&self, poly: &MultiLinearPoly<F>) -> OrionResult<()> {
361+
let (row_num, msg_size) = Self::row_col_from_variables(poly.get_num_vars());
362+
363+
// NOTE(Hang): another idea - if the inv_code_rate happens to be a po2
364+
// then it would very much favor us, as matrix will be square,
365+
// or composed by 2 squared matrices
366+
367+
let mut interleaved_codeword_buffer =
368+
vec![F::ZERO; row_num * self.code_instance.code_len()];
369+
370+
// NOTE: now the interleaved codeword is k x n matrix from expander code
371+
poly.coeffs
372+
.chunks(msg_size)
373+
.zip(interleaved_codeword_buffer.chunks_mut(self.code_instance.msg_len()))
374+
.try_for_each(|(row_i, codeword_i)| {
375+
self.code_instance.encode_in_place(row_i, codeword_i)
376+
})?;
377+
378+
// NOTE: the interleaved codeword buffer is n x k matrix
379+
// with each column being an expander code
380+
let mut scratch = vec![F::ZERO; row_num * self.code_instance.code_len()];
381+
transpose_in_place(&mut interleaved_codeword_buffer, &mut scratch, row_num);
382+
drop(scratch);
383+
384+
// TODO need a merkle tree to commit against all merkle tree roots
385+
315386
todo!()
316387
}
317388

@@ -326,6 +397,21 @@ impl OrionPCSImpl {
326397
pub fn verify() {
327398
todo!()
328399
}
400+
401+
// TODO after commit and open
402+
pub fn batch_commit() {
403+
todo!()
404+
}
405+
406+
// TODO after commit and open
407+
pub fn batch_open() {
408+
todo!()
409+
}
410+
411+
// TODO after commit and open
412+
pub fn batch_verify() {
413+
todo!()
414+
}
329415
}
330416

331417
// TODO waiting on a unified multilinear PCS trait - align OrionPCSImpl against PCS trait

gkr/src/poly_commit/orion_test.rs

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use arith::Field;
22
use ark_std::test_rng;
33
use gf2_128::GF2_128;
4+
use mersenne31::M31Ext3;
45

56
use crate::{OrionCode, OrionCodeParameter};
67

@@ -14,9 +15,7 @@ fn gen_msg_codeword<F: Field>(code: &OrionCode, mut rng: impl rand::RngCore) ->
1415
(random_msg, codeword)
1516
}
1617

17-
fn linear_combine<F: Field>(vec_s: &Vec<Vec<F>>, scalars: &[F]) -> Vec<F> {
18-
assert_eq!(vec_s.len(), scalars.len());
19-
18+
fn row_combination<F: Field>(vec_s: &[Vec<F>], scalars: &[F]) -> Vec<F> {
2019
let mut out = vec![F::ZERO; vec_s[0].len()];
2120

2221
scalars
@@ -64,8 +63,8 @@ fn test_orion_code_generic<F: Field>() {
6463
.map(|_| gen_msg_codeword(&orion_code, &mut rng))
6564
.unzip();
6665

67-
let msg_linear_combined = linear_combine(&msgs, &random_scalrs);
68-
let codeword_linear_combined = linear_combine(&codewords, &random_scalrs);
66+
let msg_linear_combined = row_combination(&msgs, &random_scalrs);
67+
let codeword_linear_combined = row_combination(&codewords, &random_scalrs);
6968

7069
let codeword_computed = orion_code.encode(&msg_linear_combined).unwrap();
7170

@@ -74,5 +73,6 @@ fn test_orion_code_generic<F: Field>() {
7473

7574
#[test]
7675
fn test_orion_code() {
77-
test_orion_code_generic::<GF2_128>()
76+
test_orion_code_generic::<GF2_128>();
77+
test_orion_code_generic::<M31Ext3>();
7878
}

0 commit comments

Comments
 (0)