Skip to content

Commit 697fe34

Browse files
committed
build(mat.rs): add helpers stack and augment
1 parent a1506a9 commit 697fe34

2 files changed

Lines changed: 130 additions & 6 deletions

File tree

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ pub mod ajtai;
22
pub mod mat;
33
pub mod ntt;
44
pub mod poly;
5+
pub mod relations;
56
pub mod ring;
67
pub mod rok;
78
pub mod zq;

src/mat.rs

Lines changed: 129 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,18 +74,48 @@ impl<R: Clone> Mat<R> {
7474
}
7575

7676
/// Returns column `j` as an owned `Vec<R>`.
77-
///
78-
/// Row-major storage means column `j`'s elements live at offsets
79-
/// `j, j + ncols, j + 2*ncols, ...` — they are NOT contiguous, so
80-
/// we cannot return a `&[R]` slice (a slice must be contiguous in
81-
/// memory). That is why this method clones into an owned `Vec`
82-
/// and why the bound `R: Clone` is required.
8377
pub fn col(&self, j: usize) -> Vec<R> {
8478
assert!(j < self.ncols, "col index out of bounds");
8579

8680
self.data.iter().skip(j).step_by(self.ncols).cloned().collect()
8781
}
8882

83+
/// A.augment(B) = [A | B]
84+
pub fn augment(self, other: Self) -> Self {
85+
assert_eq!(self.nrows, other.nrows, "row mismatch");
86+
let new_ncols = self.ncols + other.ncols;
87+
let mut data = Vec::with_capacity(self.nrows*new_ncols);
88+
// Fill A into `data`
89+
for i in 0..self.nrows {
90+
for j in 0..self.ncols {
91+
data.push(self[(i, j)].clone());
92+
}
93+
for j in 0..other.ncols {
94+
data.push(other[(i, j)].clone());
95+
}
96+
}
97+
98+
Self {
99+
data,
100+
ncols: self.ncols + other.ncols,
101+
nrows: self.nrows,
102+
}
103+
}
104+
105+
/// A.stack(B) = [A
106+
/// B]
107+
pub fn stack(self, other: Self) -> Self {
108+
assert_eq!(self.ncols, other.ncols, "col mismatch");
109+
110+
let mut data = self.data.clone();
111+
data.extend(other.data);
112+
Self {
113+
data,
114+
nrows: self.nrows + other.nrows,
115+
ncols: self.ncols
116+
}
117+
}
118+
89119
pub fn transpose(&self) -> Self {
90120
Mat::from_fn(self.ncols, self.nrows, |i, j| {
91121
self.data[j * self.ncols + i].clone()
@@ -442,4 +472,97 @@ mod tests {
442472
assert_eq!(c[(0,0)], r([11, 11, 11, 11]));
443473
}
444474

475+
// ─── stack ───
476+
477+
#[test]
478+
fn test_stack_basic() {
479+
// A: 2×3, B: 1×3 → 3×3
480+
let a = m(&[[1, 2, 3], [4, 5, 6]]);
481+
let b = m(&[[7, 8, 9]]);
482+
let c = a.stack(b);
483+
assert_eq!(c.dimensions(), (3, 3));
484+
assert_eq!(c.row(0), &[z(1), z(2), z(3)]);
485+
assert_eq!(c.row(1), &[z(4), z(5), z(6)]);
486+
assert_eq!(c.row(2), &[z(7), z(8), z(9)]);
487+
}
488+
489+
#[test]
490+
fn test_stack_preserves_ncols() {
491+
let a = m(&[[1, 2]]);
492+
let b = m(&[[3, 4], [5, 6]]);
493+
let c = a.stack(b);
494+
assert_eq!(c.dimensions(), (3, 2));
495+
}
496+
497+
#[test]
498+
#[should_panic(expected = "col mismatch")]
499+
fn test_stack_ncols_mismatch_panics() {
500+
let a = m(&[[1, 2, 3]]);
501+
let b = m(&[[4, 5]]);
502+
let _ = a.stack(b);
503+
}
504+
505+
#[test]
506+
fn test_stack_zero_rows_on_top() {
507+
// 0×3 stack 2×3 → 2×3 (the zero matrix vanishes on top)
508+
let empty = Mat::<F>::from_fn(0, 3, |_, _| z(0));
509+
let b = m(&[[1, 2, 3], [4, 5, 6]]);
510+
let c = empty.stack(b);
511+
assert_eq!(c.dimensions(), (2, 3));
512+
assert_eq!(c.row(0), &[z(1), z(2), z(3)]);
513+
assert_eq!(c.row(1), &[z(4), z(5), z(6)]);
514+
}
515+
516+
#[test]
517+
fn test_stack_zero_rows_on_bottom() {
518+
// 2×3 stack 0×3 → 2×3 (the zero matrix vanishes on bottom)
519+
let a = m(&[[1, 2, 3], [4, 5, 6]]);
520+
let empty = Mat::<F>::from_fn(0, 3, |_, _| z(0));
521+
let c = a.stack(empty);
522+
assert_eq!(c.dimensions(), (2, 3));
523+
assert_eq!(c.row(0), &[z(1), z(2), z(3)]);
524+
assert_eq!(c.row(1), &[z(4), z(5), z(6)]);
525+
}
526+
527+
#[test]
528+
fn test_stack_zero_rows_both() {
529+
// 0×3 stack 0×3 → 0×3
530+
let a = Mat::<F>::from_fn(0, 3, |_, _| z(0));
531+
let b = Mat::<F>::from_fn(0, 3, |_, _| z(0));
532+
let c = a.stack(b);
533+
assert_eq!(c.dimensions(), (0, 3));
534+
}
535+
536+
// ─── augment ───
537+
538+
#[test]
539+
fn test_augment_basic() {
540+
// A = [1 2] B = [5 6] A | B = [1 2 5 6]
541+
// [3 4] [7 8] [3 4 7 8]
542+
let a = m(&[[1, 2], [3, 4]]);
543+
let b = m(&[[5, 6], [7, 8]]);
544+
let c = a.augment(b);
545+
assert_eq!(c.dimensions(), (2, 4));
546+
assert_eq!(c.row(0), &[z(1), z(2), z(5), z(6)]);
547+
assert_eq!(c.row(1), &[z(3), z(4), z(7), z(8)]);
548+
}
549+
550+
#[test]
551+
fn test_augment_different_widths() {
552+
let a = m(&[[1], [2], [3]]); // 3 × 1
553+
let b = m(&[[4, 5], [6, 7], [8, 9]]); // 3 × 2
554+
let c = a.augment(b);
555+
assert_eq!(c.dimensions(), (3, 3));
556+
assert_eq!(c.row(0), &[z(1), z(4), z(5)]);
557+
assert_eq!(c.row(1), &[z(2), z(6), z(7)]);
558+
assert_eq!(c.row(2), &[z(3), z(8), z(9)]);
559+
}
560+
561+
#[test]
562+
#[should_panic(expected = "row mismatch")]
563+
fn test_augment_nrows_mismatch_panics() {
564+
let a = m(&[[1, 2], [3, 4]]);
565+
let b = m(&[[5, 6]]);
566+
let _ = a.augment(b);
567+
}
445568
}

0 commit comments

Comments
 (0)