Mixture of sparse GP (#131)
* Add GpSurrogate traits

* Add Sgp surrogate implementations

* Finally move SGP surrogate traits in MOE

* Make test reproducible

* Add MoE of sparse GP

* Test Sgp in Python

* Renaming Python SparseGpMix

* Renaming GpMixture, SparseGpMixture

* Renaming SparseGpMixParams, SparseGpMixValidParams

* Renaming GpMixParams, GpMixValidParams

* Fix doc comments

* Renaming Moe::gp_parameters, Moe::gp_algorithm

* Renaming SparseGpx

* Renaming SparseGpMixtureParams

* Fix python sgp test

* Add trait in scope

* Fix feature activation
relf authored Jan 25, 2024
1 parent e7705fd commit 82c24c8
Showing 23 changed files with 1,712 additions and 329 deletions.
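
At a glance, this commit renames the mixture-of-experts types (Moe → GpMixture, MoeParams → GpMixParams) and introduces the sparse variant SparseGpMixture. A minimal sketch of the renamed API, pieced together from the moe/examples diffs below; the toy objective function is an illustrative assumption, not code from this commit:

use egobox_doe::{Lhs, SamplingMethod};
use egobox_moe::{GpMixture, Recombination};
use linfa::prelude::{Dataset, Fit};
use ndarray::arr2;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Train on 50 LHS samples of a 1D toy function (placeholder objective).
    let xtrain = Lhs::new(&arr2(&[[0., 1.]])).sample(50);
    let ytrain = xtrain.mapv(|x| (6. * x).sin());
    let ds = Dataset::new(xtrain, ytrain);
    // Formerly `Moe::params()`; the builder chain is unchanged under the new name.
    let _moe = GpMixture::params()
        .n_clusters(3)
        .recombination(Recombination::Hard)
        .fit(&ds)?;
    Ok(())
}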
2 changes: 1 addition & 1 deletion ego/src/criteria/wb2.rs
@@ -100,7 +100,7 @@ mod tests {
fn test_grad_wbs2() {
let xt = egobox_doe::Lhs::new(&array![[-10., 10.], [-10., 10.]]).sample(10);
let yt = sphere(&xt);
- let gp = Moe::params()
+ let gp = GpMixture::params()
.regression_spec(RegressionSpec::CONSTANT)
.correlation_spec(CorrelationSpec::SQUAREDEXPONENTIAL)
.recombination(Recombination::Hard)
6 changes: 3 additions & 3 deletions ego/src/egor.rs
@@ -95,7 +95,7 @@ use crate::errors::Result;
use crate::mixint::*;
use crate::types::*;

- use egobox_moe::MoeParams;
+ use egobox_moe::GpMixParams;
use log::info;
use ndarray::{concatenate, Array2, ArrayBase, Axis, Data, Ix2};
use ndarray_rand::rand::SeedableRng;
@@ -136,7 +136,7 @@ impl<O: GroupFunc> EgorBuilder<O> {
pub fn min_within(
self,
xlimits: &ArrayBase<impl Data<Elem = f64>, Ix2>,
- ) -> Egor<O, MoeParams<f64, Xoshiro256Plus>> {
+ ) -> Egor<O, GpMixParams<f64, Xoshiro256Plus>> {
let rng = if let Some(seed) = self.config.seed {
Xoshiro256Plus::seed_from_u64(seed)
} else {
@@ -155,7 +155,7 @@ impl<O: GroupFunc> EgorBuilder<O> {
/// Build an Egor optimizer to minimize the function R^n -> R^p taking
/// inputs specified with given xtypes where some of components may be
/// discrete variables (mixed-integer optimization).
- pub fn min_within_mixint_space(self, xtypes: &[XType]) -> Egor<O, MixintMoeParams> {
+ pub fn min_within_mixint_space(self, xtypes: &[XType]) -> Egor<O, MixintGpMixParams> {
let rng = if let Some(seed) = self.config.seed {
Xoshiro256Plus::seed_from_u64(seed)
} else {
6 changes: 3 additions & 3 deletions ego/src/egor_service.rs
@@ -44,7 +44,7 @@ use crate::egor_solver::*;
use crate::mixint::*;
use crate::types::*;

- use egobox_moe::MoeParams;
+ use egobox_moe::GpMixParams;
use ndarray::{Array2, ArrayBase, Data, Ix2};
use ndarray_rand::rand::SeedableRng;
use rand_xoshiro::Xoshiro256Plus;
@@ -80,7 +80,7 @@ impl EgorServiceBuilder {
pub fn min_within(
self,
xlimits: &ArrayBase<impl Data<Elem = f64>, Ix2>,
- ) -> EgorService<MoeParams<f64, Xoshiro256Plus>> {
+ ) -> EgorService<GpMixParams<f64, Xoshiro256Plus>> {
let rng = if let Some(seed) = self.config.seed {
Xoshiro256Plus::seed_from_u64(seed)
} else {
@@ -98,7 +98,7 @@ impl EgorServiceBuilder {
/// Build an Egor optimizer to minimize the function R^n -> R^p taking
/// inputs specified with given xtypes where some of components may be
/// discrete variables (mixed-integer optimization).
- pub fn min_within_mixint_space(self, xtypes: &[XType]) -> EgorService<MixintMoeParams> {
+ pub fn min_within_mixint_space(self, xtypes: &[XType]) -> EgorService<MixintGpMixParams> {
let rng = if let Some(seed) = self.config.seed {
Xoshiro256Plus::seed_from_u64(seed)
} else {
16 changes: 8 additions & 8 deletions ego/src/egor_solver.rs
@@ -9,7 +9,7 @@
//! use ndarray::{array, Array2, ArrayView1, ArrayView2, Zip};
//! use egobox_doe::{Lhs, SamplingMethod};
//! use egobox_ego::{EgorBuilder, EgorConfig, InfillStrategy, InfillOptimizer, ObjFunc, EgorSolver, to_xtypes};
- //! use egobox_moe::MoeParams;
+ //! use egobox_moe::GpMixParams;
//! use rand_xoshiro::Xoshiro256Plus;
//! use ndarray_rand::rand::SeedableRng;
//! use argmin::core::Executor;
@@ -28,7 +28,7 @@
//! let xtypes = to_xtypes(&array![[-2., 2.], [-2., 2.]]);
//! let fobj = ObjFunc::new(rosenb);
//! let config = EgorConfig::default().xtypes(&xtypes);
- //! let solver: EgorSolver<MoeParams<f64, Xoshiro256Plus>> = EgorSolver::new(config, rng);
+ //! let solver: EgorSolver<GpMixParams<f64, Xoshiro256Plus>> = EgorSolver::new(config, rng);
//! let res = Executor::new(fobj, solver)
//! .configure(|state| state.max_iters(20))
//! .run()
@@ -48,7 +48,7 @@
//! use ndarray::{array, Array2, ArrayView1, ArrayView2, Zip};
//! use egobox_doe::{Lhs, SamplingMethod};
//! use egobox_ego::{EgorBuilder, EgorConfig, InfillStrategy, InfillOptimizer, ObjFunc, EgorSolver, to_xtypes};
- //! use egobox_moe::MoeParams;
+ //! use egobox_moe::GpMixParams;
//! use rand_xoshiro::Xoshiro256Plus;
//! use ndarray_rand::rand::SeedableRng;
//! use argmin::core::Executor;
@@ -95,7 +95,7 @@
//! .doe(&doe)
//! .target(-5.5080);
//!
- //! let solver: EgorSolver<MoeParams<f64, Xoshiro256Plus>> =
+ //! let solver: EgorSolver<GpMixParams<f64, Xoshiro256Plus>> =
//! EgorSolver::new(config, rng);
//!
//! let res = Executor::new(fobj, solver)
@@ -117,7 +117,7 @@ use crate::types::*;
use crate::utils::{compute_cstr_scales, update_data};

use egobox_doe::{Lhs, LhsKind, SamplingMethod};
- use egobox_moe::{ClusteredSurrogate, Clustering, CorrelationSpec, MoeParams, RegressionSpec};
+ use egobox_moe::{ClusteredSurrogate, Clustering, CorrelationSpec, GpMixParams, RegressionSpec};
use env_logger::{Builder, Env};
use finitediff::FiniteDiff;
use linfa::ParamGuard;
@@ -164,14 +164,14 @@ pub struct EgorSolver<SB: SurrogateBuilder> {
pub(crate) rng: Xoshiro256Plus,
}

- impl SurrogateBuilder for MoeParams<f64, Xoshiro256Plus> {
+ impl SurrogateBuilder for GpMixParams<f64, Xoshiro256Plus> {
/// Constructor from domain space specified with types
/// **panic** if xtypes contains other types than continuous type `Float`
fn new_with_xtypes(xtypes: &[XType]) -> Self {
if crate::utils::discrete(xtypes) {
panic!("MoeParams cannot be created with discrete types!");
panic!("GpMixParams cannot be created with discrete types!");
}
- MoeParams::new()
+ GpMixParams::new()
}

/// Sets the allowed regression models used in gaussian processes.
57 changes: 30 additions & 27 deletions ego/src/mixint.rs
@@ -6,8 +6,8 @@ use crate::errors::{EgoError, Result};
use crate::types::{SurrogateBuilder, XType};
use egobox_doe::{FullFactorial, Lhs, Random};
use egobox_moe::{
- Clustered, ClusteredSurrogate, Clustering, CorrelationSpec, Moe, MoeParams, RegressionSpec,
- Surrogate,
+ Clustered, ClusteredSurrogate, Clustering, CorrelationSpec, FullGpSurrogate, GpMixParams,
+ GpMixture, GpSurrogate, RegressionSpec,
};
use linfa::traits::{Fit, PredictInplace};
use linfa::{DatasetBase, Float, ParamGuard};
@@ -279,14 +279,14 @@ impl<F: Float, S: egobox_doe::SamplingMethod<F>> egobox_doe::SamplingMethod<F>
}

/// Moe type builder for mixed-integer Egor optimizer
- pub type MoeBuilder = MoeParams<f64, Xoshiro256Plus>;
+ pub type MoeBuilder = GpMixParams<f64, Xoshiro256Plus>;
/// A decorator of Moe surrogate builder that takes into account XType specifications
///
/// It allows to implement continuous relaxation over continuous Moe builder.
#[derive(Clone, Serialize, Deserialize)]
pub struct MixintMoeValidParams {
/// The surrogate factory
- surrogate_builder: MoeParams<f64, Xoshiro256Plus>,
+ surrogate_builder: GpMixParams<f64, Xoshiro256Plus>,
/// The input specifications
xtypes: Vec<XType>,
/// whether data are in given in folded space (enum indexes) or not (enum masks)
@@ -309,12 +309,12 @@ impl MixintMoeValidParams {

/// Parameters for mixture of experts surrogate model
#[derive(Clone, Serialize, Deserialize)]
- pub struct MixintMoeParams(MixintMoeValidParams);
+ pub struct MixintGpMixParams(MixintMoeValidParams);

- impl MixintMoeParams {
+ impl MixintGpMixParams {
/// Constructor given `xtypes` specifications and given surrogate builder
pub fn new(xtypes: &[XType], surrogate_builder: &MoeBuilder) -> Self {
- MixintMoeParams(MixintMoeValidParams {
+ MixintGpMixParams(MixintMoeValidParams {
surrogate_builder: surrogate_builder.clone(),
xtypes: xtypes.to_vec(),
work_in_folded_space: false,
@@ -384,9 +384,9 @@ impl MixintMoeValidParams {
}
}

- impl SurrogateBuilder for MixintMoeParams {
+ impl SurrogateBuilder for MixintGpMixParams {
fn new_with_xtypes(xtypes: &[XType]) -> Self {
- MixintMoeParams::new(xtypes, &MoeParams::new())
+ MixintGpMixParams::new(xtypes, &GpMixParams::new())
}

/// Sets the allowed regression models used in gaussian processes.
@@ -468,7 +468,7 @@ impl<D: Data<Elem = f64>> Fit<ArrayBase<D, Ix2>, ArrayBase<D, Ix2>, EgoError>
}
}

- impl ParamGuard for MixintMoeParams {
+ impl ParamGuard for MixintGpMixParams {
type Checked = MixintMoeValidParams;
type Error = EgoError;

@@ -482,17 +482,17 @@ impl ParamGuard for MixintMoeParams {
}
}

- impl From<MixintMoeValidParams> for MixintMoeParams {
+ impl From<MixintMoeValidParams> for MixintGpMixParams {
fn from(item: MixintMoeValidParams) -> Self {
- MixintMoeParams(item)
+ MixintGpMixParams(item)
}
}

/// The Moe model that takes into account XType specifications
#[derive(Serialize, Deserialize)]
pub struct MixintMoe {
/// the decorated Moe
- moe: Moe,
+ moe: GpMixture,
/// The input specifications
xtypes: Vec<XType>,
/// whether training input data are in given in folded space (enum indexes) or not (enum masks)
@@ -527,7 +527,7 @@ impl Clustered for MixintMoe {
}

#[typetag::serde]
- impl Surrogate for MixintMoe {
+ impl GpSurrogate for MixintMoe {
fn predict_values(&self, x: &ArrayView2<f64>) -> egobox_moe::Result<Array2<f64>> {
let mut xcast = if self.work_in_folded_space {
unfold_with_enum_mask(&self.xtypes, x)
@@ -548,6 +548,20 @@ impl Surrogate for MixintMoe {
self.moe.predict_variances(&xcast)
}

+ /// Save Moe model in given file.
+ fn save(&self, path: &str) -> egobox_moe::Result<()> {
+ let mut file = fs::File::create(path).unwrap();
+ let bytes = match serde_json::to_string(self) {
+ Ok(b) => b,
+ Err(err) => return Err(MoeError::SaveError(err)),
+ };
+ file.write_all(bytes.as_bytes())?;
+ Ok(())
+ }
+ }
+
+ #[typetag::serde]
+ impl FullGpSurrogate for MixintMoe {
fn predict_derivatives(&self, x: &ArrayView2<f64>) -> egobox_moe::Result<Array2<f64>> {
let mut xcast = if self.work_in_folded_space {
unfold_with_enum_mask(&self.xtypes, x)
@@ -577,17 +591,6 @@ impl Surrogate for MixintMoe {
cast_to_discrete_values_mut(&self.xtypes, &mut xcast);
self.moe.sample(&xcast.view(), n_traj)
}
-
- /// Save Moe model in given file.
- fn save(&self, path: &str) -> egobox_moe::Result<()> {
- let mut file = fs::File::create(path).unwrap();
- let bytes = match serde_json::to_string(self) {
- Ok(b) => b,
- Err(err) => return Err(MoeError::SaveError(err)),
- };
- file.write_all(bytes.as_bytes())?;
- Ok(())
- }
}

impl ClusteredSurrogate for MixintMoe {}
@@ -609,7 +612,7 @@ impl<D: Data<Elem = f64>> PredictInplace<ArrayBase<D, Ix2>, Array2<f64>> for Mix
}
}

- struct MoeVariancePredictor<'a>(&'a Moe);
+ struct MoeVariancePredictor<'a>(&'a GpMixture);
impl<'a, D: Data<Elem = f64>> PredictInplace<ArrayBase<D, Ix2>, Array2<f64>>
for MoeVariancePredictor<'a>
{
@@ -708,7 +711,7 @@ impl MixintContext {
surrogate_builder: &MoeBuilder,
dataset: &DatasetBase<Array2<f64>, Array2<f64>>,
) -> Result<MixintMoe> {
- let mut params = MixintMoeParams::new(&self.xtypes, surrogate_builder);
+ let mut params = MixintGpMixParams::new(&self.xtypes, surrogate_builder);
let params = params.work_in_folded_space(self.work_in_folded_space);
params.fit(dataset)
}
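
For orientation, a hedged sketch of how the renamed mixed-integer builder is driven, following the create_surrogate code above. The XType constructor signatures, the toy training data, and the crate-root re-export of MixintGpMixParams are illustrative assumptions; only MixintGpMixParams::new and the linfa Fit/ParamGuard wiring appear in this diff:

use egobox_ego::{MixintGpMixParams, XType};
use egobox_moe::GpMixParams;
use linfa::prelude::{Dataset, Fit};
use ndarray::{array, Axis};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // One continuous input in [0, 4], one integer input in {0, 1, 2, 3} (assumed variants).
    let xtypes = vec![XType::Cont(0., 4.), XType::Int(0, 3)];
    let xt = array![
        [0.5, 0.], [1.5, 1.], [2.5, 2.], [3.5, 3.],
        [0.2, 3.], [1.2, 2.], [2.8, 1.], [3.8, 0.]
    ];
    let yt = xt.map_axis(Axis(1), |row| row.sum()).insert_axis(Axis(1));
    // Decorate the continuous GpMixParams builder with the input specifications,
    // then fit through the Fit impl shown in the diff above.
    let _mixint_moe = MixintGpMixParams::new(&xtypes, &GpMixParams::new())
        .fit(&Dataset::new(xt, yt))?;
    Ok(())
}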
12 changes: 5 additions & 7 deletions gp/Cargo.toml
@@ -13,7 +13,8 @@ categories = ["algorithms", "mathematics", "science"]

default = []

serializable = ["serde", "linfa/serde"]
serializable = ["serde", "typetag", "linfa/serde"]
persistent = ["serializable", "serde_json"]
blas = ["ndarray-linalg", "linfa/ndarray-linalg", "linfa-pls/blas"]

[dependencies]
@@ -37,12 +38,9 @@ num-traits = "0.2"
thiserror = "1"
log = "0.4"

- [dependencies.serde]
- package = "serde"
- version = "1.0"
- default-features = false
- features = ["std", "derive"]
- optional = true
+ serde = { version = "1", features = ["derive"], optional = true }
+ serde_json = { version = "1", optional = true }
+ typetag = { version = "0.1", optional = true }

[dev-dependencies]
criterion = "0.4.0"
12 changes: 12 additions & 0 deletions gp/src/errors.rs
@@ -27,4 +27,16 @@ pub enum GpError {
/// When a linfa error occurs
#[error(transparent)]
LinfaError(#[from] linfa::error::Error),
+ #[cfg(feature = "persistent")]
+ #[error("Save error: {0}")]
+ SaveError(#[from] serde_json::Error),
+ /// When error during loading
+ #[error("Load IO error")]
+ LoadIoError(#[from] std::io::Error),
+ /// When error during loading
+ #[error("Load error: {0}")]
+ LoadError(String),
+ /// When error during loading
+ #[error("InvalidValue error: {0}")]
+ InvalidValueError(String),
}
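
These variants back the new `persistent` feature: serde_json and I/O failures convert into GpError through the #[from] attributes. A hypothetical helper, not an API of the crate, just to show how the conversions compose inside gp/src:

#[cfg(feature = "persistent")]
fn save_json<T: serde::Serialize>(model: &T, path: &str) -> Result<(), GpError> {
    // serde_json::Error -> GpError::SaveError via #[from]
    let json = serde_json::to_string(model)?;
    // std::io::Error -> GpError::LoadIoError via #[from]
    std::fs::write(path, json)?;
    Ok(())
}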
8 changes: 7 additions & 1 deletion gp/src/sgp_algorithm.rs
@@ -804,10 +804,11 @@ mod tests {
let eta2: f64 = 0.01;
let (xt, yt) = make_test_data(nt, eta2, &mut rng);

- let xplot = Array::linspace(-1., 1., 100).insert_axis(Axis(1));
+ let xplot = Array::linspace(-0.5, 0.5, 100).insert_axis(Axis(1));
let n_inducings = 30;

let sgp = SparseKriging::params(Inducings::Randomized(n_inducings))
+ .seed(Some(42))
.initial_theta(Some(vec![0.1]))
.fit(&Dataset::new(xt.clone(), yt.clone()))
.expect("GP fitted");
@@ -816,7 +817,12 @@
assert_abs_diff_eq!(eta2, sgp.noise_variance());

let sgp_vals = sgp.predict_values(&xplot).unwrap();
+ let yplot = f_obj(&xplot);
+ let errvals = (yplot - &sgp_vals).mapv(|v| v.abs());
+ assert_abs_diff_eq!(errvals, Array2::zeros((xplot.nrows(), 1)), epsilon = 1.0);
let sgp_vars = sgp.predict_variances(&xplot).unwrap();
+ let errvars = (&sgp_vars - Array2::from_elem((xplot.nrows(), 1), 0.01)).mapv(|v| v.abs());
+ assert_abs_diff_eq!(errvars, Array2::zeros((xplot.nrows(), 1)), epsilon = 0.05);

save_data(&xt, &yt, sgp.inducings(), &xplot, &sgp_vals, &sgp_vars);
}
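
Pulling the tested pieces together, a minimal sketch of the sparse-GP entry point exercised above. SparseKriging, Inducings::Randomized, seed, and initial_theta all come from this diff; the toy data and the egobox_gp import paths are assumptions:

use egobox_gp::{Inducings, SparseKriging};
use linfa::prelude::{Dataset, Fit};
use ndarray::{Array, Axis};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // 1D toy data over the same [-0.5, 0.5] range the test now plots.
    let xt = Array::linspace(-0.5, 0.5, 100).insert_axis(Axis(1));
    let yt = xt.mapv(|x: f64| (7.5 * x).sin());
    let sgp = SparseKriging::params(Inducings::Randomized(30))
        .seed(Some(42))                 // reproducible inducing-point draw
        .initial_theta(Some(vec![0.1])) // initial kernel lengthscale
        .fit(&Dataset::new(xt.clone(), yt))?;
    let _means = sgp.predict_values(&xt)?;
    let _variances = sgp.predict_variances(&xt)?;
    Ok(())
}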
6 changes: 3 additions & 3 deletions moe/examples/clustering.rs
@@ -1,5 +1,5 @@
use egobox_doe::{Lhs, SamplingMethod};
- use egobox_moe::{Moe, Recombination};
+ use egobox_moe::{GpMixture, Recombination};
use linfa::prelude::{Dataset, Fit};
use ndarray::{arr2, Array, Array2, Axis, Zip};
use std::error::Error;
@@ -22,8 +22,8 @@ fn main() -> Result<(), Box<dyn Error>> {
let xtrain = Lhs::new(&arr2(&[[0., 1.]])).sample(50);
let ytrain = f3parts(&xtrain);
let ds = Dataset::new(xtrain, ytrain);
- let moe1 = Moe::params().fit(&ds)?;
- let moe3 = Moe::params()
+ let moe1 = GpMixture::params().fit(&ds)?;
+ let moe3 = GpMixture::params()
.n_clusters(3)
.recombination(Recombination::Hard)
.fit(&ds)?;
4 changes: 2 additions & 2 deletions moe/examples/moe_norm1.rs
@@ -1,6 +1,6 @@
use csv::ReaderBuilder;
use egobox_doe::{FullFactorial, SamplingMethod};
- use egobox_moe::Moe;
+ use egobox_moe::GpMixture;
use linfa::{traits::Fit, Dataset};
use ndarray::{arr2, s, Array2, Axis};
use ndarray_csv::Array2Reader;
@@ -24,7 +24,7 @@ fn main() -> Result<(), Box<dyn Error>> {
let xtrain = data_train.slice(s![.., ..2_usize]).to_owned();
let ytrain = data_train.slice(s![.., 2_usize..]).to_owned();
let ds = Dataset::new(xtrain, ytrain);
- let moe = Moe::params().n_clusters(4).fit(&ds)?;
+ let moe = GpMixture::params().n_clusters(4).fit(&ds)?;

let xlimits = arr2(&[[-1., 1.], [-1., 1.]]);
let xtest = FullFactorial::new(&xlimits).sample(100);
6 changes: 3 additions & 3 deletions moe/examples/norm1.rs
@@ -1,5 +1,5 @@
use egobox_doe::{Lhs, SamplingMethod};
- use egobox_moe::{Moe, Recombination};
+ use egobox_moe::{GpMixture, Recombination};
use linfa::{traits::Fit, Dataset};
use ndarray::{arr2, Array2, Axis};
use std::error::Error;
@@ -12,8 +12,8 @@ fn main() -> Result<(), Box<dyn Error>> {
let xtrain = Lhs::new(&arr2(&[[-1., 1.], [-1., 1.]])).sample(200);
let ytrain = norm1(&xtrain);
let ds = Dataset::new(xtrain, ytrain);
- let moe1 = Moe::params().fit(&ds)?;
- let moe5 = Moe::params()
+ let moe1 = GpMixture::params().fit(&ds)?;
+ let moe5 = GpMixture::params()
.n_clusters(6)
.recombination(Recombination::Hard)
.fit(&ds)?;
(Diffs for the remaining 12 changed files are not rendered here.)
