Skip to content

Commit 0b4cdd4

Browse files
authored
Merge pull request #342 from lloydmeta/enhancement/blas_row_major_2d
Adds blas support for transposed row matrices (Closes #340)
2 parents 84744bd + c1958d2 commit 0b4cdd4

File tree

2 files changed

+46
-3
lines changed

2 files changed

+46
-3
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ exclude = ["docgen/images/*"]
2222
[lib]
2323
name = "ndarray"
2424
bench = false
25-
test = false
25+
test = true
2626

2727
[dependencies.num-traits]
2828
version = "0.1.32"

src/linalg/impl_linalg.rs

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -626,9 +626,10 @@ fn blas_row_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
626626
if !same_type::<A, S::Elem>() {
627627
return false;
628628
}
629+
let (m, n) = a.dim();
629630
let s0 = a.strides()[0];
630631
let s1 = a.strides()[1];
631-
if s1 != 1 {
632+
if !(s1 == 1 || n == 1) {
632633
return false;
633634
}
634635
if s0 < 1 || s1 < 1 {
@@ -639,11 +640,53 @@ fn blas_row_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
639640
{
640641
return false;
641642
}
642-
let (m, n) = a.dim();
643643
if m > blas_index::max_value() as usize ||
644644
n > blas_index::max_value() as usize
645645
{
646646
return false;
647647
}
648648
true
649649
}
650+
651+
#[cfg(test)]
652+
mod tests {
653+
654+
use super::*;
655+
656+
#[test]
657+
#[cfg(feature="blas")]
658+
fn blas_row_major_2d_normal_matrix() {
659+
let m: Array2<f32> = Array2::zeros((3, 5));
660+
assert!(blas_row_major_2d::<f32, _>(&m));
661+
}
662+
663+
#[test]
664+
#[cfg(feature="blas")]
665+
fn blas_row_major_2d_row_matrix() {
666+
let m: Array2<f32> = Array2::zeros((1, 5));
667+
assert!(blas_row_major_2d::<f32, _>(&m));
668+
}
669+
670+
#[test]
671+
#[cfg(feature="blas")]
672+
fn blas_row_major_2d_column_matrix() {
673+
let m: Array2<f32> = Array2::zeros((5, 1));
674+
assert!(blas_row_major_2d::<f32, _>(&m));
675+
}
676+
677+
#[test]
678+
#[cfg(feature="blas")]
679+
fn blas_row_major_2d_transposed_row_matrix() {
680+
let m: Array2<f32> = Array2::zeros((1, 5));
681+
let m_t = m.t();
682+
assert!(blas_row_major_2d::<f32, _>(&m_t));
683+
}
684+
685+
#[test]
686+
#[cfg(feature="blas")]
687+
fn blas_row_major_2d_transposed_column_matrix() {
688+
let m: Array2<f32> = Array2::zeros((5, 1));
689+
let m_t = m.t();
690+
assert!(blas_row_major_2d::<f32, _>(&m_t));
691+
}
692+
}

0 commit comments

Comments
 (0)