Skip to content

Commit

Permalink
Renaming Python SparseGpMix
Browse files Browse the repository at this point in the history
  • Loading branch information
relf committed Jan 25, 2024
1 parent 5a7961d commit d6ff6a2
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import egobox as egx
import logging
import time

logging.basicConfig(level=logging.INFO)

Expand Down Expand Up @@ -32,7 +33,11 @@ def test_sgp(self):
random_idx = rng.permutation(nt)[:n_inducing]
Z = xt[random_idx].copy()

egx.GpSparse(z=Z).fit(xt, yt)
start = time.time()
sgp = egx.GpSparse(z=Z).fit(xt, yt)
elapsed = time.time() - start
print(elapsed)
sgp.save("sgp.json")


if __name__ == "__main__":
Expand Down
6 changes: 3 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

mod egor;
mod gp_mix;
mod gp_sparse;
mod sampling;
mod sparse_gp_mix;
pub(crate) mod types;

use egor::*;
use gp_mix::*;
use gp_sparse::*;
use sampling::*;
use sparse_gp_mix::*;
use types::*;

use pyo3::prelude::*;
Expand Down Expand Up @@ -40,7 +40,7 @@ fn egobox(_py: Python, m: &PyModule) -> PyResult<()> {
// Surrogate Model
m.add_class::<GpMix>()?;
m.add_class::<Gpx>()?;
m.add_class::<GpSparse>()?;
m.add_class::<SparseGpMix>()?;
m.add_class::<Gps>()?;

// Optimizer
Expand Down
10 changes: 5 additions & 5 deletions src/gp_sparse.rs → src/sparse_gp_mix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ use rand_xoshiro::Xoshiro256Plus;
/// Random generator seed to allow computation reproducibility.
///
#[pyclass]
pub(crate) struct GpSparse {
pub(crate) struct SparseGpMix {
pub correlation_spec: CorrelationSpec,
pub kpls_dim: Option<usize>,
pub nz: Option<usize>,
Expand All @@ -63,7 +63,7 @@ pub(crate) struct GpSparse {
}

#[pymethods]
impl GpSparse {
impl SparseGpMix {
#[new]
#[pyo3(signature = (
corr_spec = CorrelationSpec::SQUARED_EXPONENTIAL,
Expand All @@ -80,7 +80,7 @@ impl GpSparse {
z: Option<PyReadonlyArray2<f64>>,
seed: Option<u64>,
) -> Self {
GpSparse {
SparseGpMix {
correlation_spec: CorrelationSpec(corr_spec),
kpls_dim,
nz,
Expand Down Expand Up @@ -154,8 +154,8 @@ impl Gps {
nz: Option<usize>,
z: Option<PyReadonlyArray2<f64>>,
seed: Option<u64>,
) -> GpSparse {
GpSparse::new(corr_spec, kpls_dim, nz, z, seed)
) -> SparseGpMix {
SparseGpMix::new(corr_spec, kpls_dim, nz, z, seed)
}

/// Returns the String representation from serde json serializer
Expand Down

0 comments on commit d6ff6a2

Please sign in to comment.