@@ -75,7 +75,6 @@ impl<R: Rng + SeedableRng + Clone> GpMixValidParams<f64, R> {
75
75
yt : & ArrayBase < impl Data < Elem = f64 > , Ix2 > ,
76
76
) -> Result < GpMixture > {
77
77
trace ! ( "Moe training..." ) ;
78
- let _opt = env_logger:: try_init ( ) . ok ( ) ;
79
78
let nx = xt. ncols ( ) ;
80
79
let data = concatenate ( Axis ( 1 ) , & [ xt. view ( ) , yt. view ( ) ] ) . unwrap ( ) ;
81
80
@@ -1085,24 +1084,23 @@ mod tests {
1085
1084
let x = Array1 :: linspace ( 0. , 1. , 50 ) . insert_axis ( Axis ( 1 ) ) ;
1086
1085
let preds = moe. predict_values ( & x) . expect ( "MOE prediction" ) ;
1087
1086
let dpreds = moe. predict_derivatives ( & x) . expect ( "MOE drv prediction" ) ;
1088
- println ! ( "dpred = {dpreds}" ) ;
1089
1087
1090
1088
let test_dir = "target/tests" ;
1091
1089
std:: fs:: create_dir_all ( test_dir) . ok ( ) ;
1092
- write_npy ( format ! ( "{test_dir}/x_hard .npy" ) , & x) . expect ( "x saved" ) ;
1093
- write_npy ( format ! ( "{test_dir}/preds_hard .npy" ) , & preds) . expect ( "preds saved" ) ;
1094
- write_npy ( format ! ( "{test_dir}/dpreds_hard .npy" ) , & dpreds) . expect ( "dpreds saved" ) ;
1090
+ write_npy ( format ! ( "{test_dir}/x_moe_smooth .npy" ) , & x) . expect ( "x saved" ) ;
1091
+ write_npy ( format ! ( "{test_dir}/preds_moe_smooth .npy" ) , & preds) . expect ( "preds saved" ) ;
1092
+ write_npy ( format ! ( "{test_dir}/dpreds_moe_smooth .npy" ) , & dpreds) . expect ( "dpreds saved" ) ;
1095
1093
1096
1094
let mut rng = Xoshiro256Plus :: seed_from_u64 ( 42 ) ;
1097
1095
for _ in 0 ..20 {
1098
1096
let x1: f64 = rng. gen_range ( 0.0 ..1.0 ) ;
1099
1097
1100
- let h = 1e-4 ;
1098
+ let h = 1e-8 ;
1101
1099
let xtest = array ! [ [ x1] ] ;
1102
1100
1103
1101
let x = array ! [ [ x1] , [ x1 + h] , [ x1 - h] ] ;
1104
- let preds = moe. predict_derivatives ( & x) . unwrap ( ) ;
1105
- let fdiff = preds[ [ 1 , 0 ] ] - preds[ [ 1 , 0 ] ] / 2. * h;
1102
+ let preds = moe. predict_values ( & x) . unwrap ( ) ;
1103
+ let fdiff = ( preds[ [ 1 , 0 ] ] - preds[ [ 2 , 0 ] ] ) / ( 2. * h) ;
1106
1104
1107
1105
let drv = moe. predict_derivatives ( & xtest) . unwrap ( ) ;
1108
1106
let df = df_test_1d ( & xtest) ;
@@ -1115,6 +1113,7 @@ mod tests {
1115
1113
println ! (
1116
1114
"Test predicted derivatives at {xtest}: drv {drv}, true df {df}, fdiff {fdiff}"
1117
1115
) ;
1116
+ println ! ( "preds(x, x+h, x-h)={}" , preds) ;
1118
1117
assert_abs_diff_eq ! ( err, 0.0 , epsilon = 2.5e-1 ) ;
1119
1118
}
1120
1119
}
0 commit comments