Skip to content

Commit b53538e

Browse files
committed
build(mat.rs): add helpers of mat. diagonal and submatrix
1 parent 2c339fa commit b53538e

1 file changed

Lines changed: 150 additions & 1 deletion

File tree

src/mat.rs

Lines changed: 150 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
use rand::rand_core::block;
2+
13
use crate::ring::Ring;
2-
use std::ops::{Add, Index, Mul};
4+
use std::ops::{Add, Index, Mul, Range};
35

46
#[derive(Debug, Clone, PartialEq, Eq)]
57
pub struct Mat<R> {
@@ -118,6 +120,20 @@ impl<R: Clone> Mat<R> {
118120
self.data[j * self.ncols + i].clone()
119121
})
120122
}
123+
124+
/// Return a new Mat containing the rectangular region `rows × cols`.
125+
///
126+
/// Both ranges are half-open (`start..end`). Panics if the end of either
127+
/// range exceeds the corresponding dimension.
128+
pub fn submatrix(&self, rows: Range<usize>, cols: Range<usize>) -> Self {
129+
assert!(rows.end <= self.nrows && cols.end <= self.ncols, "submatrix OOB");
130+
Mat::<R>::from_fn(
131+
rows.end - rows.start,
132+
cols.end - cols.start,
133+
// i \in [0, nrows), j \in [0, ncols)
134+
|i, j| self[(i + rows.start, j + cols.start)].clone()
135+
)
136+
}
121137
}
122138

123139
impl<R> Index<(usize, usize)> for Mat<R> {
@@ -137,6 +153,33 @@ impl<R: Ring> Mat<R> {
137153
pub fn identity(n: usize) -> Self {
138154
Mat::from_fn(n, n, |i, j| if i == j { R::one() } else { R::zero() })
139155
}
156+
157+
/// Build a block-diagonal matrix from the supplied blocks. Off-diagonal
158+
/// entries are filled with `R::zero()`.
159+
///
160+
/// Example: `block_diagonal(&[I_1, I_2, I_1])` is the 4×4 identity.
161+
pub fn block_diagonal(blocks: &[Mat<R>]) -> Self {
162+
let nrows_new: usize = blocks.iter().map(|x| x.nrows).sum();
163+
let ncols_new: usize = blocks.iter().map(|x| x.ncols).sum();
164+
let mut data: Vec<R> = vec![R::zero();nrows_new * ncols_new];
165+
let mut cur_start_row: usize = 0;
166+
let mut cur_start_col: usize = 0;
167+
// Fill in each b in the diagonal of `data`.
168+
for b in blocks.iter() {
169+
for i in 0..b.nrows {
170+
for j in 0..b.ncols {
171+
data[(cur_start_row + i) * ncols_new + cur_start_col + j] = b[(i, j)].clone();
172+
}
173+
}
174+
cur_start_row += b.nrows;
175+
cur_start_col += b.ncols;
176+
}
177+
Self {
178+
data,
179+
nrows: nrows_new,
180+
ncols: ncols_new,
181+
}
182+
}
140183
}
141184

142185
impl<R: Ring> Add for Mat<R> {
@@ -561,4 +604,110 @@ mod tests {
561604
let b = m(&[[5, 6]]);
562605
let _ = a.augment(&b);
563606
}
607+
608+
// ─── submatrix ───
609+
610+
#[test]
611+
fn test_submatrix_basic() {
612+
// [1 2 3]
613+
// [4 5 6]
614+
// [7 8 9] sub(1..3, 1..3) = [[5 6], [8 9]]
615+
let mat = m(&[[1, 2, 3], [4, 5, 6], [7, 8, 9]]);
616+
let sub = mat.submatrix(1..3, 1..3);
617+
assert_eq!(sub.dimensions(), (2, 2));
618+
assert_eq!(sub[(0, 0)], z(5));
619+
assert_eq!(sub[(0, 1)], z(6));
620+
assert_eq!(sub[(1, 0)], z(8));
621+
assert_eq!(sub[(1, 1)], z(9));
622+
}
623+
624+
#[test]
625+
fn test_submatrix_full() {
626+
let mat = m(&[[1, 2], [3, 4]]);
627+
let sub = mat.submatrix(0..2, 0..2);
628+
assert_eq!(sub, mat);
629+
}
630+
631+
#[test]
632+
fn test_submatrix_single_row() {
633+
let mat = m(&[[1, 2, 3], [4, 5, 6]]);
634+
let sub = mat.submatrix(1..2, 0..3);
635+
assert_eq!(sub.dimensions(), (1, 3));
636+
assert_eq!(sub.row(0), &[z(4), z(5), z(6)]);
637+
}
638+
639+
#[test]
640+
fn test_submatrix_single_col() {
641+
let mat = m(&[[1, 2, 3], [4, 5, 6]]);
642+
let sub = mat.submatrix(0..2, 1..2);
643+
assert_eq!(sub.dimensions(), (2, 1));
644+
assert_eq!(sub[(0, 0)], z(2));
645+
assert_eq!(sub[(1, 0)], z(5));
646+
}
647+
648+
#[test]
649+
#[should_panic(expected = "submatrix OOB")]
650+
fn test_submatrix_row_oob_panics() {
651+
let mat = m(&[[1, 2], [3, 4]]);
652+
let _ = mat.submatrix(0..3, 0..2); // row end > nrows
653+
}
654+
655+
#[test]
656+
#[should_panic(expected = "submatrix OOB")]
657+
fn test_submatrix_col_oob_panics() {
658+
let mat = m(&[[1, 2], [3, 4]]);
659+
let _ = mat.submatrix(0..2, 0..3); // col end > ncols
660+
}
661+
662+
// ─── block_diagonal ───
663+
664+
#[test]
665+
fn test_block_diagonal_three_identities_equals_identity() {
666+
// diag(I_1, I_2, I_1) is the 4×4 identity (block_diag of identity blocks).
667+
let bd = Mat::<F>::block_diagonal(&[
668+
Mat::<F>::identity(1),
669+
Mat::<F>::identity(2),
670+
Mat::<F>::identity(1),
671+
]);
672+
assert_eq!(bd, Mat::<F>::identity(4));
673+
}
674+
675+
#[test]
676+
fn test_block_diagonal_mixed_dims() {
677+
// diag(2×3 block, 1×2 block) → 3×5 matrix
678+
let a = m(&[[1, 2, 3], [4, 5, 6]]); // 2×3
679+
let b = m(&[[7, 8]]); // 1×2
680+
let bd = Mat::<F>::block_diagonal(&[a, b]);
681+
682+
assert_eq!(bd.dimensions(), (3, 5));
683+
684+
// top-left 2×3 = a
685+
assert_eq!(bd[(0, 0)], z(1));
686+
assert_eq!(bd[(0, 2)], z(3));
687+
assert_eq!(bd[(1, 0)], z(4));
688+
assert_eq!(bd[(1, 2)], z(6));
689+
690+
// top-right 2×2 = zeros
691+
assert_eq!(bd[(0, 3)], z(0));
692+
assert_eq!(bd[(0, 4)], z(0));
693+
assert_eq!(bd[(1, 3)], z(0));
694+
assert_eq!(bd[(1, 4)], z(0));
695+
696+
// bottom-left 1×3 = zeros
697+
assert_eq!(bd[(2, 0)], z(0));
698+
assert_eq!(bd[(2, 1)], z(0));
699+
assert_eq!(bd[(2, 2)], z(0));
700+
701+
// bottom-right 1×2 = b
702+
assert_eq!(bd[(2, 3)], z(7));
703+
assert_eq!(bd[(2, 4)], z(8));
704+
}
705+
706+
#[test]
707+
fn test_block_diagonal_single_block() {
708+
// diag(A) == A (degenerate case)
709+
let a = m(&[[1, 2], [3, 4]]);
710+
let bd = Mat::<F>::block_diagonal(&[a.clone()]);
711+
assert_eq!(bd, a);
712+
}
564713
}

0 commit comments

Comments
 (0)