Skip to content

Commit

Permalink
Renaming GpMixParams, GpMixValidParams
Browse files Browse the repository at this point in the history
  • Loading branch information
relf committed Jan 25, 2024
1 parent e5c72c5 commit e8a013d
Show file tree
Hide file tree
Showing 8 changed files with 59 additions and 59 deletions.
6 changes: 3 additions & 3 deletions ego/src/egor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ use crate::errors::Result;
use crate::mixint::*;
use crate::types::*;

use egobox_moe::MoeParams;
use egobox_moe::GpMixParams;
use log::info;
use ndarray::{concatenate, Array2, ArrayBase, Axis, Data, Ix2};
use ndarray_rand::rand::SeedableRng;
Expand Down Expand Up @@ -136,7 +136,7 @@ impl<O: GroupFunc> EgorBuilder<O> {
pub fn min_within(
self,
xlimits: &ArrayBase<impl Data<Elem = f64>, Ix2>,
) -> Egor<O, MoeParams<f64, Xoshiro256Plus>> {
) -> Egor<O, GpMixParams<f64, Xoshiro256Plus>> {
let rng = if let Some(seed) = self.config.seed {
Xoshiro256Plus::seed_from_u64(seed)
} else {
Expand All @@ -155,7 +155,7 @@ impl<O: GroupFunc> EgorBuilder<O> {
/// Build an Egor optimizer to minimize the function R^n -> R^p taking
/// inputs specified with given xtypes where some of components may be
/// discrete variables (mixed-integer optimization).
pub fn min_within_mixint_space(self, xtypes: &[XType]) -> Egor<O, MixintMoeParams> {
pub fn min_within_mixint_space(self, xtypes: &[XType]) -> Egor<O, MixintGpMixParams> {
let rng = if let Some(seed) = self.config.seed {
Xoshiro256Plus::seed_from_u64(seed)
} else {
Expand Down
6 changes: 3 additions & 3 deletions ego/src/egor_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ use crate::egor_solver::*;
use crate::mixint::*;
use crate::types::*;

use egobox_moe::MoeParams;
use egobox_moe::GpMixParams;
use ndarray::{Array2, ArrayBase, Data, Ix2};
use ndarray_rand::rand::SeedableRng;
use rand_xoshiro::Xoshiro256Plus;
Expand Down Expand Up @@ -80,7 +80,7 @@ impl EgorServiceBuilder {
pub fn min_within(
self,
xlimits: &ArrayBase<impl Data<Elem = f64>, Ix2>,
) -> EgorService<MoeParams<f64, Xoshiro256Plus>> {
) -> EgorService<GpMixParams<f64, Xoshiro256Plus>> {
let rng = if let Some(seed) = self.config.seed {
Xoshiro256Plus::seed_from_u64(seed)
} else {
Expand All @@ -98,7 +98,7 @@ impl EgorServiceBuilder {
/// Build an Egor optimizer to minimize the function R^n -> R^p taking
/// inputs specified with given xtypes where some of components may be
/// discrete variables (mixed-integer optimization).
pub fn min_within_mixint_space(self, xtypes: &[XType]) -> EgorService<MixintMoeParams> {
pub fn min_within_mixint_space(self, xtypes: &[XType]) -> EgorService<MixintGpMixParams> {
let rng = if let Some(seed) = self.config.seed {
Xoshiro256Plus::seed_from_u64(seed)
} else {
Expand Down
16 changes: 8 additions & 8 deletions ego/src/egor_solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
//! use ndarray::{array, Array2, ArrayView1, ArrayView2, Zip};
//! use egobox_doe::{Lhs, SamplingMethod};
//! use egobox_ego::{EgorBuilder, EgorConfig, InfillStrategy, InfillOptimizer, ObjFunc, EgorSolver, to_xtypes};
//! use egobox_moe::MoeParams;
//! use egobox_moe::GpMixParams;
//! use rand_xoshiro::Xoshiro256Plus;
//! use ndarray_rand::rand::SeedableRng;
//! use argmin::core::Executor;
Expand All @@ -28,7 +28,7 @@
//! let xtypes = to_xtypes(&array![[-2., 2.], [-2., 2.]]);
//! let fobj = ObjFunc::new(rosenb);
//! let config = EgorConfig::default().xtypes(&xtypes);
//! let solver: EgorSolver<MoeParams<f64, Xoshiro256Plus>> = EgorSolver::new(config, rng);
//! let solver: EgorSolver<GpMixParams<f64, Xoshiro256Plus>> = EgorSolver::new(config, rng);
//! let res = Executor::new(fobj, solver)
//! .configure(|state| state.max_iters(20))
//! .run()
Expand All @@ -48,7 +48,7 @@
//! use ndarray::{array, Array2, ArrayView1, ArrayView2, Zip};
//! use egobox_doe::{Lhs, SamplingMethod};
//! use egobox_ego::{EgorBuilder, EgorConfig, InfillStrategy, InfillOptimizer, ObjFunc, EgorSolver, to_xtypes};
//! use egobox_moe::MoeParams;
//! use egobox_moe::GpMixParams;
//! use rand_xoshiro::Xoshiro256Plus;
//! use ndarray_rand::rand::SeedableRng;
//! use argmin::core::Executor;
Expand Down Expand Up @@ -95,7 +95,7 @@
//! .doe(&doe)
//! .target(-5.5080);
//!
//! let solver: EgorSolver<MoeParams<f64, Xoshiro256Plus>> =
//! let solver: EgorSolver<GpMixParams<f64, Xoshiro256Plus>> =
//! EgorSolver::new(config, rng);
//!
//! let res = Executor::new(fobj, solver)
Expand All @@ -117,7 +117,7 @@ use crate::types::*;
use crate::utils::{compute_cstr_scales, update_data};

use egobox_doe::{Lhs, LhsKind, SamplingMethod};
use egobox_moe::{ClusteredSurrogate, Clustering, CorrelationSpec, MoeParams, RegressionSpec};
use egobox_moe::{ClusteredSurrogate, Clustering, CorrelationSpec, GpMixParams, RegressionSpec};
use env_logger::{Builder, Env};
use finitediff::FiniteDiff;
use linfa::ParamGuard;
Expand Down Expand Up @@ -164,14 +164,14 @@ pub struct EgorSolver<SB: SurrogateBuilder> {
pub(crate) rng: Xoshiro256Plus,
}

impl SurrogateBuilder for MoeParams<f64, Xoshiro256Plus> {
impl SurrogateBuilder for GpMixParams<f64, Xoshiro256Plus> {
/// Constructor from domain space specified with types
/// **panic** if xtypes contains other types than continuous type `Float`
fn new_with_xtypes(xtypes: &[XType]) -> Self {
if crate::utils::discrete(xtypes) {
panic!("MoeParams cannot be created with discrete types!");
panic!("GpMixParams cannot be created with discrete types!");
}
MoeParams::new()
GpMixParams::new()
}

/// Sets the allowed regression models used in gaussian processes.
Expand Down
26 changes: 13 additions & 13 deletions ego/src/mixint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use crate::errors::{EgoError, Result};
use crate::types::{SurrogateBuilder, XType};
use egobox_doe::{FullFactorial, Lhs, Random};
use egobox_moe::{
Clustered, ClusteredSurrogate, Clustering, CorrelationSpec, FullGpSurrogate, GpMixture,
GpSurrogate, MoeParams, RegressionSpec,
Clustered, ClusteredSurrogate, Clustering, CorrelationSpec, FullGpSurrogate, GpMixParams,
GpMixture, GpSurrogate, RegressionSpec,
};
use linfa::traits::{Fit, PredictInplace};
use linfa::{DatasetBase, Float, ParamGuard};
Expand Down Expand Up @@ -279,14 +279,14 @@ impl<F: Float, S: egobox_doe::SamplingMethod<F>> egobox_doe::SamplingMethod<F>
}

/// Moe type builder for mixed-integer Egor optimizer
pub type MoeBuilder = MoeParams<f64, Xoshiro256Plus>;
pub type MoeBuilder = GpMixParams<f64, Xoshiro256Plus>;
/// A decorator of Moe surrogate builder that takes into account XType specifications
///
/// It allows to implement continuous relaxation over continuous Moe builder.
#[derive(Clone, Serialize, Deserialize)]
pub struct MixintMoeValidParams {
/// The surrogate factory
surrogate_builder: MoeParams<f64, Xoshiro256Plus>,
surrogate_builder: GpMixParams<f64, Xoshiro256Plus>,
/// The input specifications
xtypes: Vec<XType>,
/// whether data are in given in folded space (enum indexes) or not (enum masks)
Expand All @@ -309,12 +309,12 @@ impl MixintMoeValidParams {

/// Parameters for mixture of experts surrogate model
#[derive(Clone, Serialize, Deserialize)]
pub struct MixintMoeParams(MixintMoeValidParams);
pub struct MixintGpMixParams(MixintMoeValidParams);

impl MixintMoeParams {
impl MixintGpMixParams {
/// Constructor given `xtypes` specifications and given surrogate builder
pub fn new(xtypes: &[XType], surrogate_builder: &MoeBuilder) -> Self {
MixintMoeParams(MixintMoeValidParams {
MixintGpMixParams(MixintMoeValidParams {
surrogate_builder: surrogate_builder.clone(),
xtypes: xtypes.to_vec(),
work_in_folded_space: false,
Expand Down Expand Up @@ -384,9 +384,9 @@ impl MixintMoeValidParams {
}
}

impl SurrogateBuilder for MixintMoeParams {
impl SurrogateBuilder for MixintGpMixParams {
fn new_with_xtypes(xtypes: &[XType]) -> Self {
MixintMoeParams::new(xtypes, &MoeParams::new())
MixintGpMixParams::new(xtypes, &GpMixParams::new())
}

/// Sets the allowed regression models used in gaussian processes.
Expand Down Expand Up @@ -468,7 +468,7 @@ impl<D: Data<Elem = f64>> Fit<ArrayBase<D, Ix2>, ArrayBase<D, Ix2>, EgoError>
}
}

impl ParamGuard for MixintMoeParams {
impl ParamGuard for MixintGpMixParams {
type Checked = MixintMoeValidParams;
type Error = EgoError;

Expand All @@ -482,9 +482,9 @@ impl ParamGuard for MixintMoeParams {
}
}

impl From<MixintMoeValidParams> for MixintMoeParams {
impl From<MixintMoeValidParams> for MixintGpMixParams {
fn from(item: MixintMoeValidParams) -> Self {
MixintMoeParams(item)
MixintGpMixParams(item)
}
}

Expand Down Expand Up @@ -711,7 +711,7 @@ impl MixintContext {
surrogate_builder: &MoeBuilder,
dataset: &DatasetBase<Array2<f64>, Array2<f64>>,
) -> Result<MixintMoe> {
let mut params = MixintMoeParams::new(&self.xtypes, surrogate_builder);
let mut params = MixintGpMixParams::new(&self.xtypes, surrogate_builder);
let params = params.work_in_folded_space(self.work_in_folded_space);
params.fit(dataset)
}
Expand Down
12 changes: 6 additions & 6 deletions moe/src/algorithm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::clustering::{find_best_number_of_clusters, sort_by_cluster};
use crate::errors::MoeError;
use crate::errors::Result;
use crate::expertise_macros::*;
use crate::parameters::{MoeParams, MoeValidParams};
use crate::parameters::{GpMixParams, GpMixValidParams};
use crate::surrogates::*;
use crate::types::*;

Expand Down Expand Up @@ -47,7 +47,7 @@ macro_rules! check_allowed {
}

impl<D: Data<Elem = f64>, R: Rng + SeedableRng + Clone>
Fit<ArrayBase<D, Ix2>, ArrayBase<D, Ix2>, MoeError> for MoeValidParams<f64, R>
Fit<ArrayBase<D, Ix2>, ArrayBase<D, Ix2>, MoeError> for GpMixValidParams<f64, R>
{
type Object = GpMixture;

Expand All @@ -68,7 +68,7 @@ impl<D: Data<Elem = f64>, R: Rng + SeedableRng + Clone>
}
}

impl<R: Rng + SeedableRng + Clone> MoeValidParams<f64, R> {
impl<R: Rng + SeedableRng + Clone> GpMixValidParams<f64, R> {
pub fn train(
&self,
xt: &ArrayBase<impl Data<Elem = f64>, Ix2>,
Expand Down Expand Up @@ -174,7 +174,7 @@ impl<R: Rng + SeedableRng + Clone> MoeValidParams<f64, R> {
let ytest = test.slice(s![.., nx..]).to_owned();
let factor = self.optimize_heaviside_factor(&experts, gmx, &xtest, &ytest);
info!("Retrain mixture with optimized heaviside factor={}", factor);
let moe = MoeParams::from(self.clone())
let moe = GpMixParams::from(self.clone())
.n_clusters(gmx.n_clusters())
.recombination(Recombination::Smooth(Some(factor)))
.check()?
Expand Down Expand Up @@ -463,8 +463,8 @@ impl ClusteredSurrogate for GpMixture {}

impl GpMixture {
/// Constructor of mixture of experts parameters
pub fn params() -> MoeParams<f64, Xoshiro256Plus> {
MoeParams::new()
pub fn params() -> GpMixParams<f64, Xoshiro256Plus> {
GpMixParams::new()
}

/// Recombination mode
Expand Down
4 changes: 2 additions & 2 deletions moe/src/clustering.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#![allow(dead_code)]
use crate::parameters::MoeParams;
use crate::parameters::GpMixParams;
use crate::types::*;
use log::{debug, info};

Expand Down Expand Up @@ -122,7 +122,7 @@ pub fn find_best_number_of_clusters<R: Rng + Clone>(
let gmm = Box::new(gmm);
// Cross Validation
for (train, valid) in dataset.fold(5).into_iter() {
if let Ok(mixture) = MoeParams::default()
if let Ok(mixture) = GpMixParams::default()
.n_clusters(n_clusters)
.regression_spec(regression_spec)
.correlation_spec(correlation_spec)
Expand Down
4 changes: 2 additions & 2 deletions moe/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
//!
//! ```no_run
//! use ndarray::{Array2, Array1, Zip, Axis};
//! use egobox_moe::{Moe, Recombination};
//! use egobox_moe::{GpMixture, Recombination};
//! use ndarray_rand::{RandomExt, rand::SeedableRng, rand_distr::Uniform};
//! use rand_xoshiro::Xoshiro256Plus;
//! use linfa::{traits::Fit, ParamGuard, Dataset};
Expand Down Expand Up @@ -70,7 +70,7 @@
//!
//! // Predictions
//! let observations = Array1::linspace(0., 1., 100).insert_axis(Axis(1));
//! let predictions = Moe::params()
//! let predictions = GpMixture::params()
//! .n_clusters(3)
//! .recombination(Recombination::Hard)
//! .fit(&ds)
Expand Down
44 changes: 22 additions & 22 deletions moe/src/parameters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use serde::{Deserialize, Serialize};
/// Mixture of experts checked parameters
#[derive(Clone)]
#[cfg_attr(feature = "serializable", derive(Serialize, Deserialize))]
pub struct MoeValidParams<F: Float, R: Rng + Clone> {
pub struct GpMixValidParams<F: Float, R: Rng + Clone> {
/// Number of clusters (i.e. number of experts)
n_clusters: usize,
/// [Recombination] mode
Expand All @@ -40,9 +40,9 @@ pub struct MoeValidParams<F: Float, R: Rng + Clone> {
rng: R,
}

impl<F: Float, R: Rng + SeedableRng + Clone> Default for MoeValidParams<F, R> {
fn default() -> MoeValidParams<F, R> {
MoeValidParams {
impl<F: Float, R: Rng + SeedableRng + Clone> Default for GpMixValidParams<F, R> {
fn default() -> GpMixValidParams<F, R> {
GpMixValidParams {
n_clusters: 1,
recombination: Recombination::Smooth(Some(F::one())),
regression_spec: RegressionSpec::ALL,
Expand All @@ -55,7 +55,7 @@ impl<F: Float, R: Rng + SeedableRng + Clone> Default for MoeValidParams<F, R> {
}
}

impl<F: Float, R: Rng + Clone> MoeValidParams<F, R> {
impl<F: Float, R: Rng + Clone> GpMixValidParams<F, R> {
/// The number of clusters, hence the number of experts of the mixture.
pub fn n_clusters(&self) -> usize {
self.n_clusters
Expand Down Expand Up @@ -101,15 +101,15 @@ impl<F: Float, R: Rng + Clone> MoeValidParams<F, R> {
/// Mixture of experts parameters
#[derive(Clone)]
#[cfg_attr(feature = "serializable", derive(Serialize, Deserialize))]
pub struct MoeParams<F: Float, R: Rng + Clone>(MoeValidParams<F, R>);
pub struct GpMixParams<F: Float, R: Rng + Clone>(GpMixValidParams<F, R>);

impl<F: Float> Default for MoeParams<F, Xoshiro256Plus> {
fn default() -> MoeParams<F, Xoshiro256Plus> {
MoeParams(MoeValidParams::default())
impl<F: Float> Default for GpMixParams<F, Xoshiro256Plus> {
fn default() -> GpMixParams<F, Xoshiro256Plus> {
GpMixParams(GpMixValidParams::default())
}
}

impl<F: Float> MoeParams<F, Xoshiro256Plus> {
impl<F: Float> GpMixParams<F, Xoshiro256Plus> {
/// Constructor of Moe parameters with `n_clusters`.
///
/// Default values are provided as follows:
Expand All @@ -119,17 +119,17 @@ impl<F: Float> MoeParams<F, Xoshiro256Plus> {
/// * correlation_spec: `ALL`
/// * kpls_dim: `None`
#[allow(clippy::new_ret_no_self)]
pub fn new() -> MoeParams<F, Xoshiro256Plus> {
pub fn new() -> GpMixParams<F, Xoshiro256Plus> {
Self::new_with_rng(Xoshiro256Plus::from_entropy())
}
}

impl<F: Float, R: Rng + Clone> MoeParams<F, R> {
impl<F: Float, R: Rng + Clone> GpMixParams<F, R> {
/// Constructor of Moe parameters specifying randon number generator for reproducibility
///
/// See [MoeParams::new] for default parameters.
pub fn new_with_rng(rng: R) -> MoeParams<F, R> {
Self(MoeValidParams {
/// See [GpMixParams::new] for default parameters.
pub fn new_with_rng(rng: R) -> GpMixParams<F, R> {
Self(GpMixValidParams {
n_clusters: 1,
recombination: Recombination::Smooth(Some(F::one())),
regression_spec: RegressionSpec::ALL,
Expand Down Expand Up @@ -198,8 +198,8 @@ impl<F: Float, R: Rng + Clone> MoeParams<F, R> {
}

/// Sets the random number generator for reproducibility
pub fn with_rng<R2: Rng + Clone>(self, rng: R2) -> MoeParams<F, R2> {
MoeParams(MoeValidParams {
pub fn with_rng<R2: Rng + Clone>(self, rng: R2) -> GpMixParams<F, R2> {
GpMixParams(GpMixValidParams {
n_clusters: self.0.n_clusters(),
recombination: self.0.recombination(),
regression_spec: self.0.regression_spec(),
Expand All @@ -212,8 +212,8 @@ impl<F: Float, R: Rng + Clone> MoeParams<F, R> {
}
}

impl<F: Float, R: Rng + Clone> ParamGuard for MoeParams<F, R> {
type Checked = MoeValidParams<F, R>;
impl<F: Float, R: Rng + Clone> ParamGuard for GpMixParams<F, R> {
type Checked = GpMixValidParams<F, R>;
type Error = MoeError;

fn check_ref(&self) -> Result<&Self::Checked> {
Expand All @@ -233,8 +233,8 @@ impl<F: Float, R: Rng + Clone> ParamGuard for MoeParams<F, R> {
}
}

impl<F: Float, R: Rng + Clone> From<MoeValidParams<F, R>> for MoeParams<F, R> {
fn from(item: MoeValidParams<F, R>) -> Self {
MoeParams(item)
impl<F: Float, R: Rng + Clone> From<GpMixValidParams<F, R>> for GpMixParams<F, R> {
fn from(item: GpMixValidParams<F, R>) -> Self {
GpMixParams(item)
}
}

0 comments on commit e8a013d

Please sign in to comment.