Skip to content

Commit c46794b

Browse files
committed
MAINT: Add numeric tests for mat mul with column / row
This adds numeric tests to accompany #342
1 parent f8f48ce commit c46794b

File tree

1 file changed

+56
-0
lines changed

1 file changed

+56
-0
lines changed

numeric-tests/tests/accuracy.rs

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,62 @@ fn accurate_mul_f64() {
197197
}
198198

199199

200+
#[test]
201+
fn accurate_mul_with_column_f64() {
202+
// pick a few random sizes
203+
let mut rng = rand::weak_rng();
204+
for i in 0..10 {
205+
let m = rng.gen_range(1, 350);
206+
let k = rng.gen_range(1, 350);
207+
let a = gen_f64(Ix2(m, k));
208+
let b_owner = gen_f64(Ix2(k, k));
209+
let b_row_col;
210+
let b_sq;
211+
212+
// pick dense square or broadcasted to square matrix
213+
match i {
214+
0 ... 3 => b_sq = b_owner.view(),
215+
4 ... 7 => {
216+
b_row_col = b_owner.column(0);
217+
b_sq = b_row_col.broadcast((k, k)).unwrap();
218+
}
219+
_otherwise => {
220+
b_row_col = b_owner.row(0);
221+
b_sq = b_row_col.broadcast((k, k)).unwrap();
222+
}
223+
};
224+
225+
for j in 0..k {
226+
for &flip in &[true, false] {
227+
let j = j as isize;
228+
let b = if flip {
229+
// one row in 2D
230+
b_sq.slice(s![j..j + 1, ..]).reversed_axes()
231+
} else {
232+
// one column in 2D
233+
b_sq.slice(s![.., j..j + 1])
234+
};
235+
println!("Testing size ({} × {}) by ({} × {})", a.shape()[0], a.shape()[1], b.shape()[0], b.shape()[1]);
236+
println!("Strides ({:?}) by ({:?})", a.strides(), b.strides());
237+
let c = a.dot(&b);
238+
let reference = reference_mat_mul(&a, &b);
239+
let diff = (&c - &reference).mapv_into(f64::abs);
240+
241+
let rtol = 1e-7;
242+
let atol = 1e-12;
243+
let crtol = c.mapv(|x| x.abs() * rtol);
244+
let tol = crtol + atol;
245+
let tol_m_diff = &diff - &tol;
246+
let maxdiff = *tol_m_diff.max();
247+
println!("diff offset from tolerance level= {:.2e}", maxdiff);
248+
if maxdiff > 0. {
249+
panic!("results differ");
250+
}
251+
}
252+
}
253+
}
254+
}
255+
200256

201257
trait Utils {
202258
type Elem;

0 commit comments

Comments
 (0)