From a18516bb1f0fc18f6f0f2f9f3c7db0bb53dcf6e7 Mon Sep 17 00:00:00 2001 From: relf Date: Fri, 16 Feb 2024 21:37:36 +0100 Subject: [PATCH] Adjust test following LHS changes --- gp/src/algorithm.rs | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/gp/src/algorithm.rs b/gp/src/algorithm.rs index 8177610d..af49745e 100644 --- a/gp/src/algorithm.rs +++ b/gp/src/algorithm.rs @@ -1561,6 +1561,7 @@ mod tests { let mut rng = Xoshiro256Plus::seed_from_u64(42); let xt = egobox_doe::Lhs::new(&array![[-$limit, $limit], [-$limit, $limit]]).with_rng(rng.clone()).sample($nt); let yt = [<$func>](&xt); + println!(stringify!(<$func>)); let gp = GaussianProcess::], [<$corr Corr>] >::params( [<$regr Mean>]::default(), @@ -1573,7 +1574,7 @@ mod tests { let x = Array::random_using((2,), Uniform::new(-$limit, $limit), &mut rng); let xa: f64 = x[0]; let xb: f64 = x[1]; - let e = 1e-4; + let e = 1e-5; let x = array![ [xa, xb], @@ -1595,13 +1596,8 @@ mod tests { let diff_g = (y_pred[[1, 0]] - y_pred[[2, 0]]) / (2. * e); let diff_d = (y_pred[[3, 0]] - y_pred[[4, 0]]) / (2. * e); - if "[<$corr>]" == "SquaredExponential" { - assert_rel_or_abs_error(y_deriv[[0, 0]], diff_g); - assert_rel_or_abs_error(y_deriv[[0, 1]], diff_d); - } else { - assert_abs_diff_eq!(y_deriv[[0, 0]], diff_g, epsilon=5e-1); - assert_abs_diff_eq!(y_deriv[[1, 0]], diff_d, epsilon=5e-1); - } + assert_rel_or_abs_error(y_deriv[[0, 0]], diff_g); + assert_rel_or_abs_error(y_deriv[[0, 1]], diff_d); } } } @@ -1708,14 +1704,14 @@ mod tests { fn assert_rel_or_abs_error(y_deriv: f64, fdiff: f64) { println!("analytic deriv = {y_deriv}, fdiff = {fdiff}"); - if fdiff.abs() < 2e-1 { + if fdiff.abs() < 2e-1 || y_deriv.abs() < 2e-1 { let atol = 2e-1; - println!("Check absolute error: should be < {atol}"); + println!("Check absolute error: abs({y_deriv}) should be < {atol}"); assert_abs_diff_eq!(y_deriv, 0.0, epsilon = atol); // check absolute when close to zero } else { - let rtol = 2e-1; - let rel_error = (y_deriv - fdiff).abs() / fdiff; // check relative - println!("Check relative error: should be < {rtol}"); + let rtol = 3e-1; + let rel_error = (y_deriv - fdiff).abs() / fdiff.abs(); // check relative + println!("Check relative error: {rel_error} should be < {rtol}"); assert_abs_diff_eq!(rel_error, 0.0, epsilon = rtol); } }