Skip to content

Commit c39ef75

Browse files
committed
use reference in rok_*
1 parent b53538e commit c39ef75

4 files changed

Lines changed: 402 additions & 42 deletions

File tree

src/mat.rs

Lines changed: 92 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use rand::rand_core::block;
21

32
use crate::ring::Ring;
43
use std::ops::{Add, Index, Mul, Range};
@@ -11,9 +10,18 @@ pub struct Mat<R> {
1110
}
1211

1312
impl<R> Mat<R> {
13+
/// Build a matrix from row vectors.
14+
///
15+
/// Panics on empty outer vec — the column count is undeterminable
16+
/// from an empty input. For 0-row matrices, use `Mat::zero(0, ncols)`
17+
/// or `Mat::from_fn(0, ncols, _)` so `ncols` is explicit.
18+
/// `vec![vec![], vec![]]` (M rows of width 0) is fine — ncols = 0.
1419
pub fn new(rows: impl Into<Vec<Vec<R>>>) -> Self {
1520
let rows: Vec<Vec<R>> = rows.into();
16-
assert!(!rows.is_empty(), "Mat::new requires at least one row");
21+
assert!(
22+
!rows.is_empty(),
23+
"Mat::new: empty rows — use Mat::zero(0, ncols) for 0×ncols",
24+
);
1725
let nrows = rows.len();
1826
let ncols = rows[0].len();
1927
assert!(
@@ -126,12 +134,15 @@ impl<R: Clone> Mat<R> {
126134
/// Both ranges are half-open (`start..end`). Panics if the end of either
127135
/// range exceeds the corresponding dimension.
128136
pub fn submatrix(&self, rows: Range<usize>, cols: Range<usize>) -> Self {
129-
assert!(rows.end <= self.nrows && cols.end <= self.ncols, "submatrix OOB");
137+
assert!(
138+
rows.end <= self.nrows && cols.end <= self.ncols,
139+
"submatrix OOB"
140+
);
130141
Mat::<R>::from_fn(
131142
rows.end - rows.start,
132143
cols.end - cols.start,
133144
// i \in [0, nrows), j \in [0, ncols)
134-
|i, j| self[(i + rows.start, j + cols.start)].clone()
145+
|i, j| self[(i + rows.start, j + cols.start)].clone(),
135146
)
136147
}
137148
}
@@ -159,16 +170,16 @@ impl<R: Ring> Mat<R> {
159170
///
160171
/// Example: `block_diagonal(&[I_1, I_2, I_1])` is the 4×4 identity.
161172
pub fn block_diagonal(blocks: &[Mat<R>]) -> Self {
162-
let nrows_new: usize = blocks.iter().map(|x| x.nrows).sum();
173+
let nrows_new: usize = blocks.iter().map(|x| x.nrows).sum();
163174
let ncols_new: usize = blocks.iter().map(|x| x.ncols).sum();
164-
let mut data: Vec<R> = vec![R::zero();nrows_new * ncols_new];
175+
let mut data: Vec<R> = vec![R::zero(); nrows_new * ncols_new];
165176
let mut cur_start_row: usize = 0;
166177
let mut cur_start_col: usize = 0;
167178
// Fill in each b in the diagonal of `data`.
168179
for b in blocks.iter() {
169180
for i in 0..b.nrows {
170181
for j in 0..b.ncols {
171-
data[(cur_start_row + i) * ncols_new + cur_start_col + j] = b[(i, j)].clone();
182+
data[(cur_start_row + i) * ncols_new + cur_start_col + j] = b[(i, j)];
172183
}
173184
}
174185
cur_start_row += b.nrows;
@@ -268,18 +279,90 @@ mod tests {
268279
}
269280

270281
#[test]
271-
#[should_panic(expected = "at least one row")]
282+
#[should_panic(expected = "empty rows")]
272283
fn test_new_empty_panics() {
284+
// Empty outer Vec is ambiguous (ncols undeterminable) — must panic.
273285
let _: Mat<F> = Mat::new(Vec::<Vec<F>>::new());
274286
}
275287

288+
#[test]
289+
fn test_new_zero_cols_from_empty_inner_rows() {
290+
// M × 0: outer has rows, inner rows have width 0 → ncols = 0 is determined.
291+
let mat: Mat<F> = Mat::new(vec![vec![], vec![], vec![]]);
292+
assert_eq!(mat.dimensions(), (3, 0));
293+
}
294+
276295
#[test]
277296
fn test_from_fn_matches_new() {
278297
let by_new = m(&[[0, 1, 2], [10, 11, 12]]);
279298
let by_fn = Mat::<F>::from_fn(2, 3, |i, j| z((i * 10 + j) as u64));
280299
assert_eq!(by_new, by_fn);
281300
}
282301

302+
// ─── 0-dim matrices ───
303+
304+
#[test]
305+
fn test_from_fn_zero_rows() {
306+
// 0 × m: closure never invoked, but ncols is preserved.
307+
let mat = Mat::<F>::from_fn(0, 5, |_, _| panic!("must not be called"));
308+
assert_eq!(mat.dimensions(), (0, 5));
309+
}
310+
311+
#[test]
312+
fn test_from_fn_zero_cols() {
313+
let mat = Mat::<F>::from_fn(3, 0, |_, _| panic!("must not be called"));
314+
assert_eq!(mat.dimensions(), (3, 0));
315+
}
316+
317+
#[test]
318+
fn test_zero_with_zero_rows() {
319+
let mat = Mat::<F>::zero(0, 4);
320+
assert_eq!(mat.dimensions(), (0, 4));
321+
}
322+
323+
#[test]
324+
fn test_zero_with_zero_cols() {
325+
let mat = Mat::<F>::zero(2, 0);
326+
assert_eq!(mat.dimensions(), (2, 0));
327+
}
328+
329+
#[test]
330+
fn test_stack_onto_zero_row_keeps_ncols() {
331+
// 0 × 3 stacked on 2 × 3 = 2 × 3 (caller can drop the empty top half).
332+
let top = Mat::<F>::zero(0, 3);
333+
let bot = m(&[[1, 2, 3], [4, 5, 6]]);
334+
let stacked = top.stack(&bot);
335+
assert_eq!(stacked.dimensions(), (2, 3));
336+
assert_eq!(stacked, bot);
337+
}
338+
339+
#[test]
340+
#[should_panic(expected = "col mismatch")]
341+
fn test_stack_zero_x_zero_onto_real_matrix_panics() {
342+
// 0 × 0 ≠ "any ncols" — stack must reject this.
343+
let zero_zero = Mat::<F>::from_fn(0, 0, |_, _| unreachable!());
344+
let real = m(&[[1, 2, 3]]);
345+
let _ = zero_zero.stack(&real);
346+
}
347+
348+
#[test]
349+
fn test_augment_with_zero_col_keeps_nrows() {
350+
// (2 × 0) augment (2 × 3) = 2 × 3.
351+
let left = Mat::<F>::zero(2, 0);
352+
let right = m(&[[1, 2, 3], [4, 5, 6]]);
353+
let aug = left.augment(&right);
354+
assert_eq!(aug.dimensions(), (2, 3));
355+
assert_eq!(aug, right);
356+
}
357+
358+
#[test]
359+
fn test_transpose_zero_row() {
360+
// (0 × 5)^T = (5 × 0).
361+
let mat = Mat::<F>::zero(0, 5);
362+
let t = mat.transpose();
363+
assert_eq!(t.dimensions(), (5, 0));
364+
}
365+
283366
#[test]
284367
fn test_from_flatten() {
285368
let data = [z(1), z(2), z(3), z(4)];
@@ -676,7 +759,7 @@ mod tests {
676759
fn test_block_diagonal_mixed_dims() {
677760
// diag(2×3 block, 1×2 block) → 3×5 matrix
678761
let a = m(&[[1, 2, 3], [4, 5, 6]]); // 2×3
679-
let b = m(&[[7, 8]]); // 1×2
762+
let b = m(&[[7, 8]]); // 1×2
680763
let bd = Mat::<F>::block_diagonal(&[a, b]);
681764

682765
assert_eq!(bd.dimensions(), (3, 5));

src/relations.rs

Lines changed: 86 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
//!
77
//! where F = F_com stacked over F_eval, and ‖·‖₂ is the column l_2 norm
88
//! over the concatenated coefficient vector of each column of W.
9-
//!
10-
//! Reference: SALSAA paper §3–4. Python: `06_salsaa/relations.py`.
119
1210
use crate::mat::Mat;
1311
use crate::ring::Rq;
@@ -67,7 +65,7 @@ impl<const Q: u64, const D: usize> LinInstance<Q, D> {
6765
self.h.ncols()
6866
}
6967

70-
/// n_top — cols of H = rows of F.
68+
/// n_top — \bar F, number of rows of the commitment matrix.
7169
pub fn n_top(&self) -> usize {
7270
self.f_com.nrows()
7371
}
@@ -113,38 +111,37 @@ impl<const Q: u64, const D: usize> LinWitness<Q, D> {
113111
}
114112
}
115113

116-
/// A LinInstance plus a verified LinWitness. Construction must enforce:
117-
/// 1. Dimension consistency (H, F, W, Y all line up)
118-
/// 2. Algebraic relation: H · (F · W) = Y
119-
/// 3. l_2 norm bound: max_i ‖w_i‖₂ ≤ β
120-
///
121-
/// `new` returns `Self` here for symmetry with Python. You can refactor
122-
/// to `Result<Self, RelError>` (Ch 9 practice) if you want explicit
123-
/// error variants instead of panic.
124-
#[derive(Debug, Clone)]
125-
pub struct LinRelation<const Q: u64, const D: usize> {
126-
pub instance: LinInstance<Q, D>,
127-
pub witness: LinWitness<Q, D>,
128-
}
129-
114+
/// l2 norm squared of a column is the sum of the norm squared of all ring elements
130115
fn col_l2_norm_squared<const Q: u64, const D: usize>(col: &[Rq<Q, D>]) -> u64 {
131116
col.iter().map(|r| r.l2_norm_squared()).sum()
132117
}
133118

134-
fn max_col_l2_norm_squared<const Q: u64, const D: usize>(w: &Mat<Rq<Q, D>>) -> u64 {
119+
/// l2 norm of a matrix is the max norm among its columns.
120+
/// this function returns the max l2 norm **squared**
121+
fn mat_l2_norm_squared<const Q: u64, const D: usize>(w: &Mat<Rq<Q, D>>) -> u64 {
135122
(0..w.ncols())
136123
.map(|j| col_l2_norm_squared(&w.col(j)))
137124
.max()
138125
.unwrap_or(0)
139126
}
140127

128+
/// A LinInstance plus a verified LinWitness. Construction must enforce:
129+
/// 1. Dimension consistency (H, F, W, Y all line up)
130+
/// 2. Algebraic relation: H · (F · W) = Y
131+
/// 3. l_2 norm bound: max_i ‖w_i‖₂ ≤ β
132+
#[derive(Debug, Clone)]
133+
pub struct LinRelation<const Q: u64, const D: usize> {
134+
pub instance: LinInstance<Q, D>,
135+
pub witness: LinWitness<Q, D>,
136+
}
137+
141138
impl<const Q: u64, const D: usize> LinRelation<Q, D> {
142139
pub fn new(instance: LinInstance<Q, D>, witness: LinWitness<Q, D>) -> Self {
143140
// 1. l_2 norm of W must be <= \beta
144-
let l2_norm_squared_w = max_col_l2_norm_squared(&witness.w);
141+
let l2_norm_squared_w = mat_l2_norm_squared(&witness.w);
145142
let l2_norm_bound_squared = instance.beta * instance.beta;
146143
assert!(
147-
l2_norm_squared_w < l2_norm_bound_squared,
144+
l2_norm_squared_w <= l2_norm_bound_squared,
148145
"exceeded norm bound: actual norm squared={}, norm bound={}",
149146
l2_norm_squared_w,
150147
l2_norm_bound_squared,
@@ -217,12 +214,6 @@ mod tests {
217214
Mat::new(v)
218215
}
219216

220-
// NOTE: Initial Σ^lin has F_eval = 0 × m. Current `Mat::new` requires
221-
// at least one row, so 0-row matrices can't be constructed yet. All
222-
// tests below use F_eval with ≥ 1 row to dodge that limitation. When
223-
// you decide how to handle the empty case (Option, Mat::empty, etc.),
224-
// add a test for the initial state.
225-
226217
// ─── LinWitness ───
227218

228219
#[test]
@@ -346,8 +337,73 @@ mod tests {
346337
let _ = LinRelation::new(inst, wit);
347338
}
348339

349-
// NOTE: norm-bound violation test (||w_i||_2 > β) is deferred — it
350-
// needs `Zq::centered()` (mapping v ∈ [0, q) → signed [-q/2, q/2])
351-
// which isn't on Zq yet. Once that exists, add a `should_panic` test
352-
// putting a column of large-coefficient entries into W with β = 1.
340+
// ─── norm-bound checks (the ‖·‖₂ branch of LinRelation::new) ───
341+
342+
#[test]
343+
fn test_relation_norm_at_boundary_constructs() {
344+
// ‖w‖² == β². Bound is non-strict (≤) so this MUST pass.
345+
// W = [[c(1)]]: coeffs [1, 0, 0, 0]. centered norm² = 1. β = 1 → β² = 1.
346+
// F = F_com = [[1]] (no eval block), so F · W = [[1]] = Y.
347+
let h = Mat::<R>::identity(1);
348+
let f_com = mat(&[[1]]);
349+
let f_eval = Mat::<R>::zero(0, 1);
350+
let w = mat(&[[1]]);
351+
let y = mat(&[[1]]);
352+
353+
let inst = LinInstance::new(h, f_com, f_eval, y, 1);
354+
let wit = LinWitness::new(w);
355+
let _ = LinRelation::new(inst, wit); // must not panic
356+
}
357+
358+
#[test]
359+
#[should_panic(expected = "exceeded norm bound")]
360+
fn test_relation_norm_violation_panics() {
361+
// ‖w‖² > β². W = [[c(5)]]: centered coeffs [5, 0, 0, 0], norm² = 25.
362+
// β = 1 → β² = 1. 25 > 1 ⇒ must panic.
363+
// Y = F · W = c(5) so the algebraic relation alone would still hold.
364+
let h = Mat::<R>::identity(1);
365+
let f_com = mat(&[[1]]);
366+
let f_eval = Mat::<R>::zero(0, 1);
367+
let w = mat(&[[5]]);
368+
let y = mat(&[[5]]);
369+
370+
let inst = LinInstance::new(h, f_com, f_eval, y, 1);
371+
let wit = LinWitness::new(w);
372+
let _ = LinRelation::new(inst, wit);
373+
}
374+
375+
// ─── initial state: empty evaluation block ───
376+
377+
#[test]
378+
fn test_instance_initial_state_empty_f_eval() {
379+
// F_eval = 0 × m means "no evaluation rows yet". This is the
380+
// initial Σ^lin shape before any with_extra_eval call. F should
381+
// equal F_com (stacking 0 rows on top changes nothing).
382+
let h = Mat::<R>::identity(2);
383+
let f_com = mat(&[[1, 2], [3, 4]]);
384+
let f_eval = Mat::<R>::zero(0, 2); // explicit 0 × 2 — ncols carried
385+
let y = Mat::<R>::zero(2, 1);
386+
let inst = LinInstance::new(h, f_com.clone(), f_eval, y, 10);
387+
388+
assert_eq!(inst.n_hat(), 2);
389+
assert_eq!(inst.n(), 2); // F_com.nrows + F_eval.nrows = 2 + 0
390+
assert_eq!(inst.n_top(), 2); // F_com.nrows
391+
assert_eq!(inst.m(), 2);
392+
assert_eq!(inst.f(), f_com, "F == F_com when F_eval is empty");
393+
}
394+
395+
#[test]
396+
fn test_relation_with_empty_f_eval_constructs() {
397+
// End-to-end: a satisfied Σ^lin with the initial empty F_eval.
398+
let h = Mat::<R>::identity(2);
399+
let f_com = mat(&[[1, 2], [3, 4]]);
400+
let f_eval = Mat::<R>::zero(0, 2);
401+
let w = mat(&[[1], [1]]);
402+
// F · W = [[1+2], [3+4]] = [[3], [7]] (constant-poly entries)
403+
let y = mat(&[[3], [7]]);
404+
405+
let inst = LinInstance::new(h, f_com, f_eval, y, 10);
406+
let wit = LinWitness::new(w);
407+
let _ = LinRelation::new(inst, wit); // must not panic
408+
}
353409
}

src/rok/batch.rs

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,12 @@
1-
//! Batch reduction: combines multiple sumcheck claims into one.
2-
//!
3-
//! Reference: `06_salsaa/rok/batch.py`.
1+
use crate::relations::LinRelation;
2+
3+
/// Batch evaluation statements into smaller statements.
4+
/// E.g. \tilde f(r) = s
5+
/// \tilde f(\bar r) = \bar s
6+
/// -> tilde f(r) + c \tilde f(\bar r) = s + c \bar s
7+
pub fn rok_join<const Q: u64, const D: usize>(
8+
_lin: &LinRelation<Q, D>,
9+
_n_target_eval_rows: usize,
10+
) -> LinRelation<Q, D> {
11+
todo!()
12+
}

0 commit comments

Comments
 (0)