@@ -626,9 +626,10 @@ fn blas_row_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
626
626
if !same_type :: < A , S :: Elem > ( ) {
627
627
return false ;
628
628
}
629
+ let ( m, n) = a. dim ( ) ;
629
630
let s0 = a. strides ( ) [ 0 ] ;
630
631
let s1 = a. strides ( ) [ 1 ] ;
631
- if s1 != 1 {
632
+ if ! ( s1 == 1 || n == 1 ) {
632
633
return false ;
633
634
}
634
635
if s0 < 1 || s1 < 1 {
@@ -639,11 +640,53 @@ fn blas_row_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
639
640
{
640
641
return false ;
641
642
}
642
- let ( m, n) = a. dim ( ) ;
643
643
if m > blas_index:: max_value ( ) as usize ||
644
644
n > blas_index:: max_value ( ) as usize
645
645
{
646
646
return false ;
647
647
}
648
648
true
649
649
}
650
+
651
+ #[ cfg( test) ]
652
+ mod tests {
653
+
654
+ use super :: * ;
655
+
656
+ #[ test]
657
+ #[ cfg( feature="blas" ) ]
658
+ fn blas_row_major_2d_normal_matrix ( ) {
659
+ let m: Array2 < f32 > = Array2 :: zeros ( ( 3 , 5 ) ) ;
660
+ assert ! ( blas_row_major_2d:: <f32 , _>( & m) ) ;
661
+ }
662
+
663
+ #[ test]
664
+ #[ cfg( feature="blas" ) ]
665
+ fn blas_row_major_2d_row_matrix ( ) {
666
+ let m: Array2 < f32 > = Array2 :: zeros ( ( 1 , 5 ) ) ;
667
+ assert ! ( blas_row_major_2d:: <f32 , _>( & m) ) ;
668
+ }
669
+
670
+ #[ test]
671
+ #[ cfg( feature="blas" ) ]
672
+ fn blas_row_major_2d_column_matrix ( ) {
673
+ let m: Array2 < f32 > = Array2 :: zeros ( ( 5 , 1 ) ) ;
674
+ assert ! ( blas_row_major_2d:: <f32 , _>( & m) ) ;
675
+ }
676
+
677
+ #[ test]
678
+ #[ cfg( feature="blas" ) ]
679
+ fn blas_row_major_2d_transposed_row_matrix ( ) {
680
+ let m: Array2 < f32 > = Array2 :: zeros ( ( 1 , 5 ) ) ;
681
+ let m_t = m. t ( ) ;
682
+ assert ! ( blas_row_major_2d:: <f32 , _>( & m_t) ) ;
683
+ }
684
+
685
+ #[ test]
686
+ #[ cfg( feature="blas" ) ]
687
+ fn blas_row_major_2d_transposed_column_matrix ( ) {
688
+ let m: Array2 < f32 > = Array2 :: zeros ( ( 5 , 1 ) ) ;
689
+ let m_t = m. t ( ) ;
690
+ assert ! ( blas_row_major_2d:: <f32 , _>( & m_t) ) ;
691
+ }
692
+ }
0 commit comments