diff --git a/ndarray-linalg/tests/solveh.rs b/ndarray-linalg/tests/solveh.rs index 1074057f..ad5904ed 100644 --- a/ndarray-linalg/tests/solveh.rs +++ b/ndarray-linalg/tests/solveh.rs @@ -1,30 +1,115 @@ use ndarray::*; use ndarray_linalg::*; +use num_complex::{Complex32, Complex64}; -#[test] -fn solveh_random() { - let a: Array2 = random_hpd(3); - let x: Array1 = random(3); +fn test_solveh(n: usize, transpose: bool, rtol: A::Real) +where + A: Scalar + Lapack, +{ + let a: Array2 = if transpose { + random_hpd(n).reversed_axes() + } else { + random_hpd(n) + }; + let x: Array1 = random(n); let b = a.dot(&x); - let y = a.solveh_into(b).unwrap(); - assert_close_l2!(&x, &y, 1e-7); + let mut solutions = Vec::new(); + solutions.push(a.solveh(&b).unwrap()); + solutions.push(a.factorizeh().unwrap().solveh(&b).unwrap()); + solutions.push(a.factorizeh_into().unwrap().solveh(&b).unwrap()); + for solution in solutions { + assert_close_l2!(&x, &solution, rtol); + } +} +fn test_solveh_into(n: usize, transpose: bool, rtol: A::Real) +where + A: Scalar + Lapack, +{ + let a: Array2 = if transpose { + random_hpd(n).reversed_axes() + } else { + random_hpd(n) + }; + let x: Array1 = random(n); let b = a.dot(&x); - let f = a.factorizeh_into().unwrap(); - let y = f.solveh_into(b).unwrap(); - assert_close_l2!(&x, &y, 1e-7); + let mut solutions = Vec::new(); + solutions.push(a.solveh_into(b.clone()).unwrap()); + solutions.push(a.factorizeh().unwrap().solveh_into(b.clone()).unwrap()); + solutions.push(a.factorizeh_into().unwrap().solveh_into(b.clone()).unwrap()); + for solution in solutions { + assert_close_l2!(&x, &solution, rtol); + } } #[test] -fn solveh_random_t() { - let a: Array2 = random_hpd(3).reversed_axes(); - let x: Array1 = random(3); - let b = a.dot(&x); - let y = a.solveh_into(b).unwrap(); - assert_close_l2!(&x, &y, 1e-7); +fn solveh_empty() { + test_solveh::(0, false, 0.); + test_solveh::(0, false, 0.); + test_solveh::(0, false, 0.); + test_solveh::(0, false, 0.); +} - let b = a.dot(&x); - let f = a.factorizeh_into().unwrap(); - let y = f.solveh_into(b).unwrap(); - assert_close_l2!(&x, &y, 1e-7); +#[test] +fn solveh_random_float() { + for n in 1..=8 { + test_solveh::(n, false, 1e-6); + test_solveh::(n, false, 1e-9); + } +} + +#[test] +fn solveh_random_complex() { + for n in 1..=8 { + test_solveh::(n, false, 1e-6); + test_solveh::(n, false, 1e-9); + } +} + +#[test] +fn solveh_into_random_float() { + for n in 1..=8 { + test_solveh_into::(n, false, 1e-6); + test_solveh_into::(n, false, 1e-9); + } +} + +#[test] +fn solveh_into_random_complex() { + for n in 1..=8 { + test_solveh_into::(n, false, 1e-6); + test_solveh_into::(n, false, 1e-9); + } +} + +#[test] +fn solveh_random_float_t() { + for n in 1..=8 { + test_solveh::(n, true, 1e-6); + test_solveh::(n, true, 1e-9); + } +} + +#[test] +fn solveh_random_complex_t() { + for n in 1..=8 { + test_solveh::(n, true, 1e-6); + test_solveh::(n, true, 1e-9); + } +} + +#[test] +fn solveh_into_random_float_t() { + for n in 1..=8 { + test_solveh_into::(n, true, 1e-6); + test_solveh_into::(n, true, 1e-9); + } +} + +#[test] +fn solveh_into_random_complex_t() { + for n in 1..=8 { + test_solveh_into::(n, true, 1e-6); + test_solveh_into::(n, true, 1e-9); + } }