1+ use rand:: rand_core:: block;
2+
13use crate :: ring:: Ring ;
2- use std:: ops:: { Add , Index , Mul } ;
4+ use std:: ops:: { Add , Index , Mul , Range } ;
35
46#[ derive( Debug , Clone , PartialEq , Eq ) ]
57pub struct Mat < R > {
@@ -118,6 +120,20 @@ impl<R: Clone> Mat<R> {
118120 self . data [ j * self . ncols + i] . clone ( )
119121 } )
120122 }
123+
124+ /// Return a new Mat containing the rectangular region `rows × cols`.
125+ ///
126+ /// Both ranges are half-open (`start..end`). Panics if the end of either
127+ /// range exceeds the corresponding dimension.
128+ pub fn submatrix ( & self , rows : Range < usize > , cols : Range < usize > ) -> Self {
129+ assert ! ( rows. end <= self . nrows && cols. end <= self . ncols, "submatrix OOB" ) ;
130+ Mat :: < R > :: from_fn (
131+ rows. end - rows. start ,
132+ cols. end - cols. start ,
133+ // i \in [0, nrows), j \in [0, ncols)
134+ |i, j| self [ ( i + rows. start , j + cols. start ) ] . clone ( )
135+ )
136+ }
121137}
122138
123139impl < R > Index < ( usize , usize ) > for Mat < R > {
@@ -137,6 +153,33 @@ impl<R: Ring> Mat<R> {
137153 pub fn identity ( n : usize ) -> Self {
138154 Mat :: from_fn ( n, n, |i, j| if i == j { R :: one ( ) } else { R :: zero ( ) } )
139155 }
156+
157+ /// Build a block-diagonal matrix from the supplied blocks. Off-diagonal
158+ /// entries are filled with `R::zero()`.
159+ ///
160+ /// Example: `block_diagonal(&[I_1, I_2, I_1])` is the 4×4 identity.
161+ pub fn block_diagonal ( blocks : & [ Mat < R > ] ) -> Self {
162+ let nrows_new: usize = blocks. iter ( ) . map ( |x| x. nrows ) . sum ( ) ;
163+ let ncols_new: usize = blocks. iter ( ) . map ( |x| x. ncols ) . sum ( ) ;
164+ let mut data: Vec < R > = vec ! [ R :: zero( ) ; nrows_new * ncols_new] ;
165+ let mut cur_start_row: usize = 0 ;
166+ let mut cur_start_col: usize = 0 ;
167+ // Fill in each b in the diagonal of `data`.
168+ for b in blocks. iter ( ) {
169+ for i in 0 ..b. nrows {
170+ for j in 0 ..b. ncols {
171+ data[ ( cur_start_row + i) * ncols_new + cur_start_col + j] = b[ ( i, j) ] . clone ( ) ;
172+ }
173+ }
174+ cur_start_row += b. nrows ;
175+ cur_start_col += b. ncols ;
176+ }
177+ Self {
178+ data,
179+ nrows : nrows_new,
180+ ncols : ncols_new,
181+ }
182+ }
140183}
141184
142185impl < R : Ring > Add for Mat < R > {
@@ -561,4 +604,110 @@ mod tests {
561604 let b = m ( & [ [ 5 , 6 ] ] ) ;
562605 let _ = a. augment ( & b) ;
563606 }
607+
608+ // ─── submatrix ───
609+
610+ #[ test]
611+ fn test_submatrix_basic ( ) {
612+ // [1 2 3]
613+ // [4 5 6]
614+ // [7 8 9] sub(1..3, 1..3) = [[5 6], [8 9]]
615+ let mat = m ( & [ [ 1 , 2 , 3 ] , [ 4 , 5 , 6 ] , [ 7 , 8 , 9 ] ] ) ;
616+ let sub = mat. submatrix ( 1 ..3 , 1 ..3 ) ;
617+ assert_eq ! ( sub. dimensions( ) , ( 2 , 2 ) ) ;
618+ assert_eq ! ( sub[ ( 0 , 0 ) ] , z( 5 ) ) ;
619+ assert_eq ! ( sub[ ( 0 , 1 ) ] , z( 6 ) ) ;
620+ assert_eq ! ( sub[ ( 1 , 0 ) ] , z( 8 ) ) ;
621+ assert_eq ! ( sub[ ( 1 , 1 ) ] , z( 9 ) ) ;
622+ }
623+
624+ #[ test]
625+ fn test_submatrix_full ( ) {
626+ let mat = m ( & [ [ 1 , 2 ] , [ 3 , 4 ] ] ) ;
627+ let sub = mat. submatrix ( 0 ..2 , 0 ..2 ) ;
628+ assert_eq ! ( sub, mat) ;
629+ }
630+
631+ #[ test]
632+ fn test_submatrix_single_row ( ) {
633+ let mat = m ( & [ [ 1 , 2 , 3 ] , [ 4 , 5 , 6 ] ] ) ;
634+ let sub = mat. submatrix ( 1 ..2 , 0 ..3 ) ;
635+ assert_eq ! ( sub. dimensions( ) , ( 1 , 3 ) ) ;
636+ assert_eq ! ( sub. row( 0 ) , & [ z( 4 ) , z( 5 ) , z( 6 ) ] ) ;
637+ }
638+
639+ #[ test]
640+ fn test_submatrix_single_col ( ) {
641+ let mat = m ( & [ [ 1 , 2 , 3 ] , [ 4 , 5 , 6 ] ] ) ;
642+ let sub = mat. submatrix ( 0 ..2 , 1 ..2 ) ;
643+ assert_eq ! ( sub. dimensions( ) , ( 2 , 1 ) ) ;
644+ assert_eq ! ( sub[ ( 0 , 0 ) ] , z( 2 ) ) ;
645+ assert_eq ! ( sub[ ( 1 , 0 ) ] , z( 5 ) ) ;
646+ }
647+
648+ #[ test]
649+ #[ should_panic( expected = "submatrix OOB" ) ]
650+ fn test_submatrix_row_oob_panics ( ) {
651+ let mat = m ( & [ [ 1 , 2 ] , [ 3 , 4 ] ] ) ;
652+ let _ = mat. submatrix ( 0 ..3 , 0 ..2 ) ; // row end > nrows
653+ }
654+
655+ #[ test]
656+ #[ should_panic( expected = "submatrix OOB" ) ]
657+ fn test_submatrix_col_oob_panics ( ) {
658+ let mat = m ( & [ [ 1 , 2 ] , [ 3 , 4 ] ] ) ;
659+ let _ = mat. submatrix ( 0 ..2 , 0 ..3 ) ; // col end > ncols
660+ }
661+
662+ // ─── block_diagonal ───
663+
664+ #[ test]
665+ fn test_block_diagonal_three_identities_equals_identity ( ) {
666+ // diag(I_1, I_2, I_1) is the 4×4 identity (block_diag of identity blocks).
667+ let bd = Mat :: < F > :: block_diagonal ( & [
668+ Mat :: < F > :: identity ( 1 ) ,
669+ Mat :: < F > :: identity ( 2 ) ,
670+ Mat :: < F > :: identity ( 1 ) ,
671+ ] ) ;
672+ assert_eq ! ( bd, Mat :: <F >:: identity( 4 ) ) ;
673+ }
674+
675+ #[ test]
676+ fn test_block_diagonal_mixed_dims ( ) {
677+ // diag(2×3 block, 1×2 block) → 3×5 matrix
678+ let a = m ( & [ [ 1 , 2 , 3 ] , [ 4 , 5 , 6 ] ] ) ; // 2×3
679+ let b = m ( & [ [ 7 , 8 ] ] ) ; // 1×2
680+ let bd = Mat :: < F > :: block_diagonal ( & [ a, b] ) ;
681+
682+ assert_eq ! ( bd. dimensions( ) , ( 3 , 5 ) ) ;
683+
684+ // top-left 2×3 = a
685+ assert_eq ! ( bd[ ( 0 , 0 ) ] , z( 1 ) ) ;
686+ assert_eq ! ( bd[ ( 0 , 2 ) ] , z( 3 ) ) ;
687+ assert_eq ! ( bd[ ( 1 , 0 ) ] , z( 4 ) ) ;
688+ assert_eq ! ( bd[ ( 1 , 2 ) ] , z( 6 ) ) ;
689+
690+ // top-right 2×2 = zeros
691+ assert_eq ! ( bd[ ( 0 , 3 ) ] , z( 0 ) ) ;
692+ assert_eq ! ( bd[ ( 0 , 4 ) ] , z( 0 ) ) ;
693+ assert_eq ! ( bd[ ( 1 , 3 ) ] , z( 0 ) ) ;
694+ assert_eq ! ( bd[ ( 1 , 4 ) ] , z( 0 ) ) ;
695+
696+ // bottom-left 1×3 = zeros
697+ assert_eq ! ( bd[ ( 2 , 0 ) ] , z( 0 ) ) ;
698+ assert_eq ! ( bd[ ( 2 , 1 ) ] , z( 0 ) ) ;
699+ assert_eq ! ( bd[ ( 2 , 2 ) ] , z( 0 ) ) ;
700+
701+ // bottom-right 1×2 = b
702+ assert_eq ! ( bd[ ( 2 , 3 ) ] , z( 7 ) ) ;
703+ assert_eq ! ( bd[ ( 2 , 4 ) ] , z( 8 ) ) ;
704+ }
705+
706+ #[ test]
707+ fn test_block_diagonal_single_block ( ) {
708+ // diag(A) == A (degenerate case)
709+ let a = m ( & [ [ 1 , 2 ] , [ 3 , 4 ] ] ) ;
710+ let bd = Mat :: < F > :: block_diagonal ( & [ a. clone ( ) ] ) ;
711+ assert_eq ! ( bd, a) ;
712+ }
564713}
0 commit comments