Skip to content

Commit 492c2f7

Browse files
committed
feat(sumcheck): sumcheck with bookkeeping with Zq
1 parent 9728c8a commit 492c2f7

3 files changed

Lines changed: 127 additions & 19 deletions

File tree

src/ring.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,19 +118,19 @@ impl<const Q: u64, const D: usize> Rq<Q, D> {
118118
}
119119

120120
/// Negate and rotate all coefficients but the constant term.
121-
///
121+
///
122122
/// E.g. a(x) = a_0 + a_1 x^1 + a_2 x^2 + a_3 x^3
123123
/// \bar a(x) = a(x^{-1})
124124
/// = a_0 - a_3 x^1 - a_2 x^2 - a_1 x^3
125125
fn conjugate(&self) -> Self {
126126
Self {
127127
coeffs: std::array::from_fn(|i| {
128-
if i == 0 {
129-
self.coeffs[i].clone()
128+
if i == 0 {
129+
self.coeffs[i]
130130
} else {
131-
Zq::<Q>::new(Q - self.coeffs[D-i].value())
131+
Zq::<Q>::new(Q - self.coeffs[D - i].value())
132132
}
133-
})
133+
}),
134134
}
135135
}
136136
}

src/rok/norm.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,26 @@
1616
//! These dependencies are flagged where used; the stubs below give the shape
1717
//! the SALSAA Python prototype expects so the translation has clear anchor points.
1818
19+
use rand::Rng;
20+
1921
use crate::{mat::Mat, relations::LinRelation, ring::Rq, zq::Zq};
2022

2123
/// Sample u ∈ Z_q\{0} and return the Vandermonde column (u^0, u^1, ..., u^{r·d/e - 1}).
2224
///
2325
/// Used by `rok_bar_sum` as the RLC coefficient vector across all NTT slots
2426
/// (there are r·d/e NTT slots total: r columns × d/e slots per Rq element).
25-
pub fn get_u_vec<const Q: u64>(_r: usize, _d: usize, _e: usize) -> Vec<Zq<Q>> {
27+
pub fn sample_u_vec<const Q: u64, const D: usize>(
28+
r: usize,
29+
e: usize,
30+
rng: &mut impl Rng,
31+
) -> Mat<Zq<Q>> {
2632
// u = random nonzero Z_q
33+
let mut u = Zq::<Q>::random(rng);
34+
while u == Zq::zero() {
35+
u = Zq::<Q>::random(rng);
36+
}
2737
// return [u^0, u^1, ..., u^{r·d/e - 1}]
28-
todo!()
38+
Mat::from_fn(1, r * D / e, |i, j| u.pow((i * j) as u64))
2939
}
3040

3141
/// Π^bar-sum: sumcheck on the RLC of CRT(LDE[W] · LDE[W̄]).

src/sumcheck.rs

Lines changed: 110 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,42 @@
1313
1414
use rand::Rng;
1515

16-
use crate::zq::Zq;
16+
use crate::{ring::Ring, zq::Zq};
1717

1818
/// Sumcheck output: the verifier's running claim after the last round plus
1919
/// the challenges chosen along the way. After receiving this, the verifier
2020
/// still owes the final oracle check:
2121
/// a_l ?= f(rands[0], ..., rands[l-1])
2222
/// which lives OUTSIDE this routine (the caller must verify it via whatever
2323
/// commitment / lookup is appropriate for f).
24-
pub struct SumcheckOutput<const Q: u64> {
25-
pub a_l: Zq<Q>,
26-
pub rands: Vec<Zq<Q>>,
24+
pub struct SumcheckOutput<T> {
25+
pub a_l: T,
26+
pub rands: Vec<T>,
27+
}
28+
29+
/// Calculate current \tilde f(r_0, ..., x, \vec x2)
30+
///
31+
/// Since we know f(X, \vec x2) is linear with \vec x2 as constant, given `f` is a multilinear extension,
32+
/// interpolation can be done with f(0, \vec x2) and f(1, \vec x2).
33+
/// So, f(X, \vec x2) = (1-X) f(0, \vec x2) + X f(1, \vec x2)
34+
///
35+
/// i = (x_{j+1}, ..., x_{l-1}) = \vec x2
36+
/// e.g. i = 0 -> (0, 0), i = 1 -> (0, 1)
37+
fn cal_f_x<const Q: u64>(table: &[Zq<Q>], x: Zq<Q>, i: usize) -> Zq<Q> {
38+
let half_idx = table.len() / 2;
39+
// p(x) = (1-x)*lo + x*hi
40+
(Zq::one() - x) * table[i] + x * table[half_idx + i]
41+
}
42+
43+
/// Derive h_j(x) = Σ_{b_{j+1}} ... Σ_{b_{l-1}} f(r_0, ..., r_{j-1}, X, b_{j+1}, ..., b_{l-1})
44+
///
45+
/// We know h(X) = Σ_{\vec x_2 \in [d_h]^{l-(j+1)} f(X, \vec x2) + f(X, \vec x2)
46+
/// We then know f(2, \vec x2) and f(2, \vec x2) with X=2 and \vec x2 passed in.
47+
fn h_x<const Q: u64>(table: &[Zq<Q>], x: Zq<Q>) -> Zq<Q> {
48+
let half_idx = table.len() / 2;
49+
(0..half_idx)
50+
.map(|i| cal_f_x(table, x, i))
51+
.fold(Zq::<Q>::zero(), |acc, v| acc + v)
2752
}
2853

2954
/// Sumcheck over hypercube [d_h]^num_vars for a function given by its
@@ -39,16 +64,86 @@ pub struct SumcheckOutput<const Q: u64> {
3964
/// - `rng`: verifier challenge source. (Replace with `&mut Transcript` once
4065
/// Fiat–Shamir lands — see README Future.)
4166
pub fn sumcheck<const Q: u64>(
42-
book: Vec<Zq<Q>>,
67+
f: Vec<Zq<Q>>,
68+
f_bar: Vec<Zq<Q>>,
4369
claimed_sum: Zq<Q>,
4470
num_vars: usize,
4571
d_h: usize,
4672
rng: &mut impl Rng,
47-
) -> SumcheckOutput<Q> {
73+
) -> SumcheckOutput<Zq<Q>> {
74+
assert_eq!(f.len(), f_bar.len(), "f and f_bar must be the same length");
75+
assert_eq!(f.len(), d_h.pow(num_vars as u32), "f size is not [d_h]^l");
4876
// Claim: Σ_{b_0} ... Σ_{b_{l-1}} f(b_0, ..., b_{l-1}) = a_0
4977
// a_j = the verifier's running claim before round j; a_0 = claimed_sum.
78+
let mut a = claimed_sum;
79+
80+
// received_randoms = [r_0, r_1, ..., r_{l-1}] accumulated each round.
81+
let mut received_randoms = Vec::<Zq<Q>>::with_capacity(num_vars);
82+
83+
let mut table_f = f.clone();
84+
let _table_f_bar = f_bar.clone();
85+
86+
for _j in 0..num_vars {
87+
//
88+
// Prover
89+
//
90+
let h_0 = h_x(&table_f, Zq::zero());
91+
let h_1 = h_x(&table_f, Zq::one());
5092

51-
// received_randoms = [r_0, r_1, ..., r_{j-1}] accumulated each round.
93+
// h(2) needs some tricks since f and \tilde f only agree on the hypercube [d]^l
94+
// =====
95+
// we know h(2) = (f(2, 0) * \bar f(2, 0)) + (f(2, 1) * \bar f(2, 1))
96+
// -> need to derive f(X, x_2) first
97+
// Since we know f(X, x2) with x2 as constant is linear given f is a multilinear extension,
98+
// interpolation can be done with f(0, x_2) and f(1, x_2).
99+
// So, f(X, x2) = (1-X) f(0, x2) + X f(1, x2)
100+
// We then know f(2, 0) and f(2, 1) with X=2 and x2={0,1} passed in.
101+
// Do the same cal for \bar f so we can calculate h(2)
102+
// =====
103+
let _h_2 = h_x(&table_f, Zq::new(2));
104+
105+
// Send h_0, h_1, h_2 as g(x) to Verifier
106+
107+
//
108+
// Verifier
109+
//
110+
// V is not sure if g_j(x) = h_j(x) as P claimed
111+
// and needs to verify
112+
// 1. a_j = g_j(0) + ... + g_j(d-1)
113+
// 2. g_j(r) ?= \sum_{b_{j+1}} ... \sum_{b_{l-1}} f(r_0, ..., r_j, b_{j+1}, ..., b_{l-1}), by SZDL
114+
// - recursion: this is done by running sumcheck again with P
115+
116+
// 1. Verify a_j == g_j(0) + g_j(1)
117+
assert_eq!(
118+
a,
119+
h_0 + h_1,
120+
"a_j does not match h_j(0)+...+h_j(d_h-1): a_j={a:?}, h_0={h_0:?}, h_1={h_1:?}"
121+
);
122+
123+
// 2. SZDL: g_j(r) ?= \sum_{b_{j+1}} ... \sum_{b_{l-1}} f(r_0, ..., r_j, b_{j+1}, ..., b_{l-1})
124+
// Verifier samples random r_j
125+
let r = Zq::<Q>::random(rng);
126+
// Send r_j to Prover
127+
128+
//
129+
// Prover
130+
//
131+
// Calculate a_{j+1} = g_j(r_j)
132+
a = h_x(&table_f, r);
133+
// Send a_{j+1} to Verifier
134+
135+
// Fold the table for the next round
136+
let half_idx = table_f.len() / 2;
137+
let mut table_f_new = Vec::with_capacity(half_idx);
138+
for i in 0..half_idx {
139+
// w(0,0) =
140+
table_f_new.push(cal_f_x(&table_f, r, i));
141+
}
142+
table_f = table_f_new;
143+
144+
// Save all `r`s from verifier
145+
received_randoms.push(r);
146+
}
52147

53148
// For each round j = 0..l:
54149
//
@@ -71,8 +166,10 @@ pub fn sumcheck<const Q: u64>(
71166
// — this final oracle check is the caller's responsibility (e.g. via an LDE
72167
// evaluation or commitment opening).
73168

74-
let _ = (book, claimed_sum, num_vars, d_h, rng);
75-
todo!()
169+
SumcheckOutput {
170+
a_l: a,
171+
rands: received_randoms,
172+
}
76173
}
77174

78175
#[cfg(test)]
@@ -100,7 +197,7 @@ mod tests {
100197
let book = vec![zq(0), zq(1), zq(1), zq(2)];
101198
let claimed = zq(4);
102199
let mut rng = rand::rng();
103-
let out = sumcheck::<Q>(book, claimed, 2, 2, &mut rng);
200+
let out = sumcheck(book.clone(), book, claimed, 2, 2, &mut rng);
104201
assert_eq!(out.rands.len(), 2, "one challenge per variable");
105202
}
106203

@@ -111,19 +208,20 @@ mod tests {
111208
let book = vec![zq(0), zq(1), zq(1), zq(2)]; // true sum is 4
112209
let bogus = zq(5);
113210
let mut rng = rand::rng();
114-
let _ = sumcheck::<Q>(book, bogus, 2, 2, &mut rng);
211+
let _ = sumcheck(book.clone(), book, bogus, 2, 2, &mut rng);
115212
}
116213

117214
// ─── shape / API correctness ───
118215

119216
/// Constant function f ≡ c over [d_h]^l: sum = d_h^l · c.
120217
#[test]
218+
#[ignore = "d_h > 2 not yet supported: needs Lagrange interp through d_h points + d_h-term verifier sum"]
121219
fn test_sumcheck_constant_function() {
122220
// d_h = 3, l = 2 → hypercube has 9 points, all equal to c=2. Sum = 9·2 = 18 = 1 mod 17.
123221
let book = vec![zq(2); 9];
124222
let claimed = zq(18 % 17);
125223
let mut rng = rand::rng();
126-
let out = sumcheck::<Q>(book, claimed, 2, 3, &mut rng);
224+
let out = sumcheck(book.clone(), book, claimed, 2, 3, &mut rng);
127225
assert_eq!(out.rands.len(), 2);
128226
}
129227
}

0 commit comments

Comments
 (0)