
Commit

Fix moe smooth derivatives test and some cleanup
relf committed Feb 12, 2024
1 parent 0e90141 commit 5b0b4ef
Showing 3 changed files with 7 additions and 14 deletions.
5 changes: 0 additions & 5 deletions moe/src/clustering.rs
@@ -348,7 +348,6 @@ mod tests {
use crate::gp_algorithm::GpMixture;
use approx::assert_abs_diff_eq;
use egobox_doe::{FullFactorial, Lhs, SamplingMethod};
- use env_logger::{Builder, Env};
#[cfg(not(feature = "blas"))]
use linfa_linalg::norm::*;
use ndarray::{array, Array1, Array2, Axis, Zip};
@@ -413,10 +412,6 @@ mod tests {

#[test]
fn test_find_best_cluster_nb_2d() {
- let env = Env::new().filter_or("EGOBOX_LOG", "info");
- let mut builder = Builder::from_env(env);
- let builder = builder.target(env_logger::Target::Stdout);
- builder.try_init().ok();
let doe = egobox_doe::FullFactorial::new(&array![[-1., 1.], [-1., 1.]]);
let xtrain = doe.sample(100);
let ytrain = l1norm(&xtrain);
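Note: this cleanup drops the per-test logger setup; a caller that still wants log output can perform the same `env_logger` initialization itself. A minimal sketch of that setup, reusing the configuration from the removed lines (the helper name is illustrative, not from the crate):

```rust
use env_logger::{Builder, Env};

/// Sketch of the logger setup the deleted test lines performed.
/// Call once from a test or binary; library code stays free of logger init.
fn init_test_logging() {
    // Read the log filter from EGOBOX_LOG, defaulting to "info".
    let env = Env::new().filter_or("EGOBOX_LOG", "info");
    Builder::from_env(env)
        .target(env_logger::Target::Stdout)
        .try_init() // ignore the error if a logger is already set
        .ok();
}
```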
15 changes: 7 additions & 8 deletions moe/src/gp_algorithm.rs
@@ -75,7 +75,6 @@ impl<R: Rng + SeedableRng + Clone> GpMixValidParams<f64, R> {
yt: &ArrayBase<impl Data<Elem = f64>, Ix2>,
) -> Result<GpMixture> {
trace!("Moe training...");
- let _opt = env_logger::try_init().ok();
let nx = xt.ncols();
let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap();

@@ -1085,24 +1084,23 @@ mod tests {
let x = Array1::linspace(0., 1., 50).insert_axis(Axis(1));
let preds = moe.predict_values(&x).expect("MOE prediction");
let dpreds = moe.predict_derivatives(&x).expect("MOE drv prediction");
- println!("dpred = {dpreds}");

let test_dir = "target/tests";
std::fs::create_dir_all(test_dir).ok();
- write_npy(format!("{test_dir}/x_hard.npy"), &x).expect("x saved");
- write_npy(format!("{test_dir}/preds_hard.npy"), &preds).expect("preds saved");
- write_npy(format!("{test_dir}/dpreds_hard.npy"), &dpreds).expect("dpreds saved");
+ write_npy(format!("{test_dir}/x_moe_smooth.npy"), &x).expect("x saved");
+ write_npy(format!("{test_dir}/preds_moe_smooth.npy"), &preds).expect("preds saved");
+ write_npy(format!("{test_dir}/dpreds_moe_smooth.npy"), &dpreds).expect("dpreds saved");

let mut rng = Xoshiro256Plus::seed_from_u64(42);
for _ in 0..20 {
let x1: f64 = rng.gen_range(0.0..1.0);

- let h = 1e-4;
+ let h = 1e-8;
let xtest = array![[x1]];

let x = array![[x1], [x1 + h], [x1 - h]];
- let preds = moe.predict_derivatives(&x).unwrap();
- let fdiff = preds[[1, 0]] - preds[[1, 0]] / 2. * h;
+ let preds = moe.predict_values(&x).unwrap();
+ let fdiff = (preds[[1, 0]] - preds[[2, 0]]) / (2. * h);

let drv = moe.predict_derivatives(&xtest).unwrap();
let df = df_test_1d(&xtest);
@@ -1115,6 +1113,7 @@
println!(
"Test predicted derivatives at {xtest}: drv {drv}, true df {df}, fdiff {fdiff}"
);
+ println!("preds(x, x+h, x-h)={}", preds);
assert_abs_diff_eq!(err, 0.0, epsilon = 2.5e-1);
}
}
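The substantive fix is in the hunk above: the finite-difference reference was previously built from derivative predictions with a broken formula, and is now a central difference of the predicted values, (f(x+h) - f(x-h)) / (2h), compared against the analytical derivative prediction. A minimal, self-contained sketch of that check pattern (stand-in closures and an illustrative step size, not the crate's code):

```rust
/// Central finite difference of a scalar function at x with step h.
fn central_diff(f: impl Fn(f64) -> f64, x: f64, h: f64) -> f64 {
    (f(x + h) - f(x - h)) / (2.0 * h)
}

fn main() {
    // Stand-ins for moe.predict_values / moe.predict_derivatives.
    let f = |x: f64| x.sin();
    let df = |x: f64| x.cos();

    let x = 0.3;
    let h = 1e-6; // illustrative step; the test above uses 1e-8
    let fdiff = central_diff(f, x, h);
    let err = (df(x) - fdiff).abs();

    println!("analytic = {}, fdiff = {}, err = {:e}", df(x), fdiff, err);
    assert!(err < 1e-6); // central difference truncation error is O(h^2)
}
```

Differencing the predicted values, rather than the derivative predictions as before, is what makes the comparison against `predict_derivatives` meaningful.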
1 change: 0 additions & 1 deletion moe/src/sgp_algorithm.rs
@@ -76,7 +76,6 @@ impl<R: Rng + SeedableRng + Clone> SparseGpMixtureValidParams<f64, R> {
yt: &ArrayBase<impl Data<Elem = f64>, Ix2>,
) -> Result<SparseGpMixture> {
trace!("Sgp training...");
- let _opt = env_logger::try_init().ok();
let nx = xt.ncols();
let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap();

