Skip to content

Commit 9bdc7a3

Browse files
committed
feat(fold.rs): add rok_fold and skeleton for the rests
1 parent 70ac0dc commit 9bdc7a3

9 files changed

Lines changed: 1535 additions & 8 deletions

File tree

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,6 @@ pub mod poly;
55
pub mod relations;
66
pub mod ring;
77
pub mod rok;
8+
pub mod salsaa;
9+
pub mod sumcheck;
810
pub mod zq;

src/rok/decompose.rs

Lines changed: 339 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,340 @@
11
//! b-ary decomposition of a witness matrix into low-norm pieces.
2-
//!
3-
//! Reference: `06_salsaa/rok/decompose.py`.
2+
3+
use crate::{mat::Mat, relations::LinRelation, ring::Rq, zq::Zq};
4+
5+
/// ℓ = ⌈log_b(2β + 1)⌉ — number of base-b digits needed to encode the range
6+
/// [-β, β]. Panics on β = 0 (range is degenerate).
7+
pub fn get_l(_beta: u64, _b: u64) -> usize {
8+
todo!()
9+
}
10+
11+
/// Balanced b-ary decomposition of a Z_q element into ℓ digits in [-⌊b/2⌋, ⌊b/2⌋].
12+
///
13+
/// E.g. b = 2: 7 → [ 1, 1, 1, 0, ...]
14+
/// -7 → [-1, -1, -1, 0, ...]
15+
/// E.g. b = 3: 5 → [-1, -1, 1, 0, ...]
16+
///
17+
/// Strategy: take the centered representative, peel off digits via repeated
18+
/// mod-b. If a digit exceeds b/2, subtract b from it and carry +b into the
19+
/// remaining value, keeping every digit balanced.
20+
pub fn balanced_b_ary_decompose_zq<const Q: u64>(_f: Zq<Q>, _b: u64, _l: usize) -> Vec<Zq<Q>> {
21+
todo!()
22+
}
23+
24+
/// Inverse of `balanced_b_ary_decompose_zq`: Σ_i coeffs[i] · b^i.
25+
pub fn compose_zq<const Q: u64>(_coeffs: &[Zq<Q>], _b: u64) -> Zq<Q> {
26+
todo!()
27+
}
28+
29+
/// Decompose witness W into ℓ matrices V_0, ..., V_{ℓ-1} such that
30+
/// W = Σ_k b^k · V_k, with each V_k's polynomial coefficients in [-⌊b/2⌋, ⌊b/2⌋].
31+
///
32+
/// Per-entry decomposition runs `balanced_b_ary_decompose_zq` on every
33+
/// coefficient of every R_q entry of W:
34+
/// r = 4 + 5x + 3x^2 → for each coefficient c at exponent `exp`:
35+
/// c·x^{exp} = (d_0·b^0 + d_1·b^1 + ...) · x^{exp}
36+
/// = d_0·b^0·x^{exp} + d_1·b^1·x^{exp} + ...
37+
/// V_0 V_1 ...
38+
pub fn decompose_w<const Q: u64, const D: usize>(
39+
_w: &Mat<Rq<Q, D>>,
40+
_b: u64,
41+
_l: usize,
42+
) -> Vec<Mat<Rq<Q, D>>> {
43+
todo!()
44+
}
45+
46+
/// Π^b-decomp: decomposes the witness W into ℓ low-norm chunks (V_0, ..., V_{ℓ-1})
47+
/// and widens (W, Y) into (Ŵ, Ŷ) = (V_0 | ... | V_{ℓ-1}, Z_0 | ... | Z_{ℓ-1})
48+
/// where Z_k = H · F · V_k.
49+
///
50+
/// Effect: H, F_com, F_eval, m, n, n̂ preserved; r grows r → ℓ·r; β tightens
51+
/// from the per-entry centered bound.
52+
pub fn rok_decompose<const Q: u64, const D: usize>(
53+
lin: &LinRelation<Q, D>,
54+
b: u64,
55+
) -> LinRelation<Q, D> {
56+
let beta = lin.beta();
57+
let l = get_l(beta, b);
58+
59+
//
60+
// Prover
61+
//
62+
let h = &lin.instance.h;
63+
let f = lin.instance.f();
64+
let w = &lin.witness.w;
65+
// Vs = decompose_w(W, b, l)
66+
// Zs = [H * F * V_k for V_k in Vs]
67+
68+
// V_tilde = [V_0 || ... || V_{l-1}]
69+
// Z_tilde = [Z_0 || ... || Z_{l-1}]
70+
71+
//
72+
// Verifier
73+
//
74+
let y = &lin.instance.y;
75+
// Y ?= Σ_{i=0}^{l-1} b^i · Z_i — verifier recomputes and checks.
76+
77+
//
78+
// Both
79+
//
80+
// Per-coefficient bound after balanced b-ary decomp is [-b/2, b/2].
81+
// column ℓ_2^2 <= m · d · (b//2)^2
82+
// β <= ⌊b/2⌋ · √(m · d)
83+
// Uses isqrt so it stays integer (floor when m·d is not a square).
84+
// new_beta = (b // 2) * isqrt(m * d)
85+
let _ = (l, h, f, w, y, b);
86+
87+
todo!()
88+
}
89+
90+
#[cfg(test)]
91+
mod tests {
92+
use super::*;
93+
use crate::mat::Mat;
94+
use crate::ring::Rq;
95+
use crate::zq::Zq;
96+
97+
const Q: u64 = 17;
98+
const D: usize = 4;
99+
type F = Zq<Q>;
100+
type R = Rq<Q, D>;
101+
102+
/// Z_q element from a signed integer (handles negatives via centered repr).
103+
fn zq(i: i64) -> F {
104+
let q = Q as i64;
105+
let v = i.rem_euclid(q) as u64;
106+
F::new(v)
107+
}
108+
109+
/// Constant polynomial of value `v` in R_q (other coefficients zero).
110+
fn c(v: u64) -> R {
111+
let mut coeffs = [F::zero(); D];
112+
coeffs[0] = F::new(v);
113+
R::new(coeffs)
114+
}
115+
116+
/// Monomial v·x^exp in R_q.
117+
fn mono(v: i64, exp: usize) -> R {
118+
assert!(exp < D);
119+
let mut coeffs = [F::zero(); D];
120+
coeffs[exp] = zq(v);
121+
R::new(coeffs)
122+
}
123+
124+
/// Build a `Mat<R>` of constant-polynomial entries from u64 rows.
125+
fn mat<const N: usize>(rows: &[[u64; N]]) -> Mat<R> {
126+
let v: Vec<Vec<R>> = rows
127+
.iter()
128+
.map(|row| row.iter().map(|&v| c(v)).collect())
129+
.collect();
130+
Mat::new(v)
131+
}
132+
133+
/// Build a valid LinRelation with H = I_{n_top + n_eval} and
134+
/// Y = H · F · W (so the relation invariant holds by construction).
135+
fn build_rel(f_com: Mat<R>, f_eval: Mat<R>, w: Mat<R>, beta: u64) -> LinRelation<Q, D> {
136+
let n_total = f_com.nrows() + f_eval.nrows();
137+
let h = Mat::<R>::identity(n_total);
138+
let f = f_com.stack(&f_eval);
139+
let y = h.clone() * f * w.clone();
140+
let inst = LinInstance::new(h, f_com, f_eval, y, beta);
141+
let wit = LinWitness::new(w);
142+
LinRelation::new(inst, wit)
143+
}
144+
145+
// ─── get_l ───
146+
147+
/// ℓ = number of base-b digits needed for the range [-β, β].
148+
#[test]
149+
fn test_get_l_explicit() {
150+
// β=1, b=2: 2β+1 = 3, need 2 binary digits.
151+
assert_eq!(get_l(1, 2), 2);
152+
// β=4, b=2: 2β+1 = 9, need 4 binary digits.
153+
assert_eq!(get_l(4, 2), 4);
154+
// β=7, b=3: 2β+1 = 15, need 3 ternary digits.
155+
assert_eq!(get_l(7, 3), 3);
156+
}
157+
158+
/// β = 0 is a degenerate range; should panic.
159+
#[test]
160+
#[should_panic]
161+
fn test_get_l_beta_zero_panics() {
162+
let _ = get_l(0, 2);
163+
}
164+
165+
// ─── balanced b-ary decomposition (Z_q-level) ───
166+
167+
/// Concrete digit lists for documented cases (mirrors `test_decompose_Fq_explicit`).
168+
#[test]
169+
fn test_balanced_decompose_zq_explicit() {
170+
// 7 = 1 + 2 + 4 = 0b0111 → [1, 1, 1, 0]
171+
assert_eq!(
172+
balanced_b_ary_decompose_zq::<Q>(zq(7), 2, 4),
173+
vec![zq(1), zq(1), zq(1), zq(0)]
174+
);
175+
// Sign carries through: -7 → [-1, -1, -1, 0]
176+
assert_eq!(
177+
balanced_b_ary_decompose_zq::<Q>(zq(-7), 2, 4),
178+
vec![zq(-1), zq(-1), zq(-1), zq(0)]
179+
);
180+
// Zero → all zeros.
181+
assert_eq!(
182+
balanced_b_ary_decompose_zq::<Q>(zq(0), 2, 4),
183+
vec![zq(0); 4]
184+
);
185+
// Balanced ternary stress test: 5 with b=3 exercises the carry step.
186+
// Non-balanced would give [2, 1] (digit 2 ∉ {-1, 0, 1}); balanced uses
187+
// carry to push 2 → -1 with +3 added to the next position:
188+
// 5 = (-1)·1 + (-1)·3 + 1·9 → [-1, -1, 1]
189+
assert_eq!(
190+
balanced_b_ary_decompose_zq::<Q>(zq(5), 3, 3),
191+
vec![zq(-1), zq(-1), zq(1)]
192+
);
193+
// Sign symmetry for the same case.
194+
assert_eq!(
195+
balanced_b_ary_decompose_zq::<Q>(zq(-5), 3, 3),
196+
vec![zq(1), zq(1), zq(-1)]
197+
);
198+
}
199+
200+
/// Reverse direction: `compose_zq` reassembles the digit lists above.
201+
#[test]
202+
fn test_compose_zq_explicit() {
203+
assert_eq!(compose_zq::<Q>(&[zq(1), zq(1), zq(1), zq(0)], 2), zq(7));
204+
assert_eq!(compose_zq::<Q>(&[zq(-1), zq(-1), zq(-1), zq(0)], 2), zq(-7));
205+
assert_eq!(compose_zq::<Q>(&[zq(0); 4], 2), zq(0));
206+
// Balanced ternary recompose: [-1, -1, 1] · (1, 3, 9) = -1 - 3 + 9 = 5.
207+
assert_eq!(compose_zq::<Q>(&[zq(-1), zq(-1), zq(1)], 3), zq(5));
208+
assert_eq!(compose_zq::<Q>(&[zq(1), zq(1), zq(-1)], 3), zq(-5));
209+
}
210+
211+
/// compose(decompose(f)) == f for all f ∈ [-β, β] and several (b, β).
212+
#[test]
213+
fn test_decompose_zq_roundtrip() {
214+
for b in [2u64, 3] {
215+
for beta in [1u64, 4, 7] {
216+
let l = get_l(beta, b);
217+
for f_int in -(beta as i64)..=(beta as i64) {
218+
let f = zq(f_int);
219+
let coeffs = balanced_b_ary_decompose_zq::<Q>(f, b, l);
220+
assert_eq!(
221+
coeffs.len(),
222+
l,
223+
"decompose_zq must return ℓ={l} digits, got {} (f={f_int}, b={b})",
224+
coeffs.len(),
225+
);
226+
let f_back = compose_zq::<Q>(&coeffs, b);
227+
assert_eq!(
228+
f_back, f,
229+
"roundtrip mismatch: f={f_int}, b={b}, β={beta}, coeffs={coeffs:?}",
230+
);
231+
}
232+
}
233+
}
234+
}
235+
236+
// ─── decompose_w (matrix-level) ───
237+
238+
/// W = Σ_k b^k · V_k where V = decompose_w(W, b, ℓ).
239+
#[test]
240+
fn test_decompose_w_roundtrip() {
241+
// W with mixed-degree polys; max coeff magnitude 3 → β=4 is safe.
242+
let w: Mat<R> = Mat::new(vec![
243+
vec![c(1) + mono(2, 1), c(3)],
244+
vec![R::zero(), c(0) - mono(1, 2)],
245+
]);
246+
let b = 2u64;
247+
let beta = 4u64;
248+
let l = get_l(beta, b);
249+
let v = decompose_w(&w, b, l);
250+
251+
// Shape: ℓ matrices, each same dim as W.
252+
assert_eq!(v.len(), l, "expected ℓ={l} matrices");
253+
for (k, v_k) in v.iter().enumerate() {
254+
assert_eq!(v_k.nrows(), w.nrows(), "V_{k} row count");
255+
assert_eq!(v_k.ncols(), w.ncols(), "V_{k} col count");
256+
}
257+
258+
// Round-trip: Σ_k b^k · V_k must reassemble to W.
259+
// (Mat lacks scalar-mul; do it cell-wise.)
260+
let w_back = Mat::<R>::from_fn(w.nrows(), w.ncols(), |i, j| {
261+
let mut acc = R::zero();
262+
for (k, v_k) in v.iter().enumerate() {
263+
let bk = c(b.pow(k as u32));
264+
acc = acc + bk * v_k.row(i)[j];
265+
}
266+
acc
267+
});
268+
assert_eq!(w_back, w, "Σ_k b^k · V_k must equal W");
269+
}
270+
271+
/// Every coefficient of every V_k lives in {-⌊b/2⌋, ..., ⌊b/2⌋}.
272+
/// For b=2, |coeff| ≤ 1 (i.e. coeff ∈ {-1, 0, 1}).
273+
#[test]
274+
fn test_decompose_w_norm_bound() {
275+
let w: Mat<R> = Mat::new(vec![vec![c(7), -mono(3, 1)], vec![mono(5, 2), R::zero()]]);
276+
let b = 2u64;
277+
let beta = 7u64;
278+
let l = get_l(beta, b);
279+
let v = decompose_w(&w, b, l);
280+
281+
let bound = (b / 2) as i64;
282+
for (k, v_k) in v.iter().enumerate() {
283+
for i in 0..v_k.nrows() {
284+
for j in 0..v_k.ncols() {
285+
for &coeff in v_k.row(i)[j].coeffs() {
286+
let cv = coeff.to_centered().abs();
287+
assert!(
288+
cv <= bound,
289+
"V_{k}[{i}][{j}] has |centered coeff|={cv} > ⌊b/2⌋={bound}",
290+
);
291+
}
292+
}
293+
}
294+
}
295+
}
296+
297+
// ─── rok_decompose smoke ───
298+
299+
/// Π^b-decomp: r grows by integer factor ℓ; H, F, m, n, n̂ unchanged.
300+
#[test]
301+
fn test_rok_decompose_smoke() {
302+
// β = 4, b = 2 → ℓ = ⌈log_2(9)⌉ = 4. r should grow r_in → 4·r_in.
303+
let rel = build_rel(
304+
mat(&[[1, 2]]),
305+
mat(&[[3, 4], [5, 6]]),
306+
mat(&[[1, 0], [0, 1]]),
307+
4,
308+
);
309+
let out = rok_decompose(&rel, 2);
310+
311+
assert_eq!(out.m(), rel.m(), "m unchanged");
312+
assert_eq!(out.n(), rel.n(), "n unchanged");
313+
assert_eq!(out.n_hat(), rel.n_hat(), "n̂ unchanged");
314+
assert_eq!(out.n_top(), rel.n_top(), "n_top unchanged");
315+
assert_eq!(out.instance.f_com, rel.instance.f_com, "F_com preserved");
316+
assert_eq!(out.instance.f_eval, rel.instance.f_eval, "F_eval preserved");
317+
assert_eq!(out.instance.h, rel.instance.h, "H preserved");
318+
319+
assert!(out.r() > rel.r(), "r must strictly grow");
320+
assert_eq!(out.r() % rel.r(), 0, "r grows by integer factor ℓ");
321+
let ell = out.r() / rel.r();
322+
assert!(ell > 1, "ℓ > 1 for non-trivial decomposition");
323+
324+
// β should not grow (per-entry bound tightens).
325+
assert!(out.beta() <= rel.beta(), "β must not grow");
326+
}
327+
328+
/// rok_decompose MUST produce a `LinRelation` whose `H · F · Ŵ = Ŷ` holds.
329+
/// Reaching the end without panic = LinRelation::new's invariant check passed.
330+
#[test]
331+
fn test_rok_decompose_produces_valid_relation() {
332+
let rel = build_rel(
333+
mat(&[[1, 2]]),
334+
mat(&[[3, 4], [5, 6]]),
335+
mat(&[[1, 0], [0, 1]]),
336+
4,
337+
);
338+
let _out = rok_decompose(&rel, 2);
339+
}
340+
}

0 commit comments

Comments
 (0)