Skip to content

Commit

Permalink
Adjust test following LHS changes
Browse files Browse the repository at this point in the history
  • Loading branch information
relf committed Feb 16, 2024
1 parent 523f3ce commit a18516b
Showing 1 changed file with 9 additions and 13 deletions.
22 changes: 9 additions & 13 deletions gp/src/algorithm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1561,6 +1561,7 @@ mod tests {
let mut rng = Xoshiro256Plus::seed_from_u64(42);
let xt = egobox_doe::Lhs::new(&array![[-$limit, $limit], [-$limit, $limit]]).with_rng(rng.clone()).sample($nt);
let yt = [<$func>](&xt);
println!(stringify!(<$func>));

let gp = GaussianProcess::<f64, [<$regr Mean>], [<$corr Corr>] >::params(
[<$regr Mean>]::default(),
Expand All @@ -1573,7 +1574,7 @@ mod tests {
let x = Array::random_using((2,), Uniform::new(-$limit, $limit), &mut rng);
let xa: f64 = x[0];
let xb: f64 = x[1];
let e = 1e-4;
let e = 1e-5;

let x = array![
[xa, xb],
Expand All @@ -1595,13 +1596,8 @@ mod tests {
let diff_g = (y_pred[[1, 0]] - y_pred[[2, 0]]) / (2. * e);
let diff_d = (y_pred[[3, 0]] - y_pred[[4, 0]]) / (2. * e);

if "[<$corr>]" == "SquaredExponential" {
assert_rel_or_abs_error(y_deriv[[0, 0]], diff_g);
assert_rel_or_abs_error(y_deriv[[0, 1]], diff_d);
} else {
assert_abs_diff_eq!(y_deriv[[0, 0]], diff_g, epsilon=5e-1);
assert_abs_diff_eq!(y_deriv[[1, 0]], diff_d, epsilon=5e-1);
}
assert_rel_or_abs_error(y_deriv[[0, 0]], diff_g);
assert_rel_or_abs_error(y_deriv[[0, 1]], diff_d);
}
}
}
Expand Down Expand Up @@ -1708,14 +1704,14 @@ mod tests {

fn assert_rel_or_abs_error(y_deriv: f64, fdiff: f64) {
println!("analytic deriv = {y_deriv}, fdiff = {fdiff}");
if fdiff.abs() < 2e-1 {
if fdiff.abs() < 2e-1 || y_deriv.abs() < 2e-1 {
let atol = 2e-1;
println!("Check absolute error: should be < {atol}");
println!("Check absolute error: abs({y_deriv}) should be < {atol}");
assert_abs_diff_eq!(y_deriv, 0.0, epsilon = atol); // check absolute when close to zero
} else {
let rtol = 2e-1;
let rel_error = (y_deriv - fdiff).abs() / fdiff; // check relative
println!("Check relative error: should be < {rtol}");
let rtol = 3e-1;
let rel_error = (y_deriv - fdiff).abs() / fdiff.abs(); // check relative
println!("Check relative error: {rel_error} should be < {rtol}");
assert_abs_diff_eq!(rel_error, 0.0, epsilon = rtol);
}
}
Expand Down

0 comments on commit a18516b

Please sign in to comment.