Skip to content

Commit ef683c8

Browse files
Julien Roussel
authored and committed
comparator holes are now identical for all imputers
1 parent 3cde1b4 commit ef683c8

File tree

9 files changed: 80 additions (+80) and 30 deletions (-30).

qolmat/benchmark/comparator.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55

66
import numpy as np
77
import pandas as pd
8+
from sklearn import utils as sku
89

910
from qolmat.benchmark import hyperparameters, metrics
1011
from qolmat.benchmark.missing_patterns import _HoleGenerator
@@ -169,8 +170,12 @@ def compare(
169170
170171
"""
171172
dict_errors = {}
172-
173+
self.generator_holes.random_state = sku.check_random_state(
174+
self.generator_holes.random_state
175+
)
176+
self.generator_holes.save_rng_state()
173177
for name, imputer in self.dict_imputers.items():
178+
self.generator_holes.load_rng_state()
174179
dict_config_opti_imputer = self.dict_config_opti.get(name, {})
175180

176181
try:

qolmat/benchmark/missing_patterns.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -190,6 +190,12 @@ def _check_subset(self, X: pd.DataFrame):
190190
elif isinstance(self.subset, str):
191191
raise SubsetIsAString(self.subset)
192192

193+
def save_rng_state(self):
194+
self.state_rng = self.random_state.get_state()
195+
196+
def load_rng_state(self):
197+
self.random_state.set_state(self.state_rng)
198+
193199

194200
class UniformHoleGenerator(_HoleGenerator):
195201
"""UniformHoleGenerator class.

qolmat/imputations/em_sampler.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -190,7 +190,6 @@ def __init__(
190190
self.n_iter_ou = n_iter_ou
191191
self.ampli = ampli
192192
self.rng = sku.check_random_state(random_state)
193-
self.cov = np.array([[]])
194193
self.dt = dt
195194
self.tolerance = tolerance
196195
self.stagnation_threshold = stagnation_threshold
@@ -657,6 +656,7 @@ def __init__(
657656
period=period,
658657
verbose=verbose,
659658
)
659+
self.cov = np.array([[]])
660660
self.dict_criteria_stop = {"logliks": [], "means": [], "covs": []}
661661

662662
def get_loglikelihood(self, X: NDArray) -> float:

qolmat/utils/input_check.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
"""Util file for input checks."""
2+
23
import pandas as pd
34

45
from qolmat.utils.exceptions import TypeNotHandled

qolmat/utils/utils.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,7 @@ def _get_numerical_features(df1: pd.DataFrame) -> List[str]:
3333
"""
3434
cols_numerical = df1.select_dtypes(include=np.number).columns.tolist()
3535
if len(cols_numerical) == 0:
36+
print(df1)
3637
raise Exception("No numerical feature is found.")
3738
else:
3839
return cols_numerical

tests/analysis/test_holes_characterization.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
import pandas as pd
33
import pytest
44
from scipy.stats import norm
5+
from sklearn import utils as sku
56

67
from qolmat.analysis.holes_characterization import LittleTest, PKLMTest
78
from qolmat.benchmark.missing_patterns import UniformHoleGenerator
@@ -12,7 +13,7 @@
1213

1314
@pytest.fixture
1415
def mcar_df() -> pd.DataFrame:
15-
rng = np.random.default_rng(42)
16+
rng = sku.check_random_state(42)
1617
matrix = rng.multivariate_normal(
1718
mean=[0, 0], cov=[[1, 0], [0, 1]], size=200
1819
)
@@ -26,7 +27,7 @@ def mcar_df() -> pd.DataFrame:
2627

2728
@pytest.fixture
2829
def mar_hm_df() -> pd.DataFrame:
29-
rng = np.random.default_rng(42)
30+
rng = sku.check_random_state(42)
3031
matrix = rng.multivariate_normal(
3132
mean=[0, 0], cov=[[1, 0], [0, 1]], size=200
3233
)
@@ -42,7 +43,7 @@ def mar_hm_df() -> pd.DataFrame:
4243

4344
@pytest.fixture
4445
def mar_hc_df() -> pd.DataFrame:
45-
rng = np.random.default_rng(42)
46+
rng = sku.check_random_state(42)
4647
matrix = rng.multivariate_normal(
4748
mean=[0, 0], cov=[[1, 0], [0, 1]], size=200
4849
)
@@ -88,7 +89,7 @@ def supported_multitypes_dataframe() -> pd.DataFrame:
8889

8990
@pytest.fixture
9091
def np_matrix_with_nan_mcar() -> np.ndarray:
91-
rng = np.random.default_rng(42)
92+
rng = sku.check_random_state(42)
9293
n_rows, n_cols = 10, 4
9394
matrix = rng.normal(size=(n_rows, n_cols))
9495
num_nan = int(n_rows * n_cols * 0.40)
@@ -104,7 +105,7 @@ def missingness_matrix_mcar(np_matrix_with_nan_mcar):
104105

105106
@pytest.fixture
106107
def missingness_matrix_mcar_perm(missingness_matrix_mcar):
107-
rng = np.random.default_rng(42)
108+
rng = sku.check_random_state(42)
108109
return rng.permutation(missingness_matrix_mcar)
109110

110111

tests/benchmark/test_comparator.py

Lines changed: 52 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -2,20 +2,28 @@
22

33
import numpy as np
44
import pandas as pd
5+
import pytest
56

67
from qolmat.benchmark.comparator import Comparator
8+
from qolmat.benchmark.missing_patterns import UniformHoleGenerator
9+
from qolmat.imputations.imputers import ImputerShuffle
710

8-
generator_holes_mock = MagicMock()
9-
generator_holes_mock.split.return_value = [
10-
pd.DataFrame({"A": [False, False, True], "B": [True, False, False]})
11-
]
1211

13-
comparator = Comparator(
14-
dict_models={},
15-
selected_columns=["A", "B"],
16-
generator_holes=generator_holes_mock,
17-
metrics=["mae", "mse"],
18-
)
12+
@pytest.fixture
13+
def comparator_fix():
14+
generator_holes_mock = MagicMock()
15+
generator_holes_mock.split.return_value = [
16+
pd.DataFrame({"A": [False, False, True], "B": [True, False, False]})
17+
]
18+
generator_holes_mock.random_state = 0
19+
comparator = Comparator(
20+
dict_models={},
21+
selected_columns=["A", "B"],
22+
generator_holes=generator_holes_mock,
23+
metrics=["mae", "mse"],
24+
)
25+
return comparator
26+
1927

2028
imputer_mock = MagicMock()
2129
expected_get_errors = pd.Series(
@@ -27,7 +35,7 @@
2735

2836

2937
@patch("qolmat.benchmark.metrics.get_metric")
30-
def test_get_errors(mock_get_metric):
38+
def test_get_errors(mock_get_metric, comparator_fix):
3139
df_origin = pd.DataFrame({"A": [1, np.nan, 3], "B": [np.nan, 5, 6]})
3240
df_imputed = pd.DataFrame({"A": [1, 2, 4], "B": [4, 5, 7]})
3341
df_mask = pd.DataFrame(
@@ -39,7 +47,7 @@ def test_get_errors(mock_get_metric):
3947
[1.0, 1.0], index=["A", "B"]
4048
)
4149
)
42-
errors = comparator.get_errors(df_origin, df_imputed, df_mask)
50+
errors = comparator_fix.get_errors(df_origin, df_imputed, df_mask)
4351
pd.testing.assert_series_equal(errors, expected_get_errors)
4452

4553

@@ -48,8 +56,10 @@ def test_get_errors(mock_get_metric):
4856
"qolmat.benchmark.comparator.Comparator.get_errors",
4957
return_value=expected_get_errors,
5058
)
51-
def test_evaluate_errors_sample(mock_get_errors, mock_optimize):
52-
errors_mean = comparator.evaluate_errors_sample(
59+
def test_evaluate_errors_sample(
60+
mock_get_errors, mock_optimize, comparator_fix
61+
):
62+
errors_mean = comparator_fix.evaluate_errors_sample(
5363
imputer_mock, pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, np.nan]})
5464
)
5565
expected_errors_mean = expected_get_errors
@@ -62,12 +72,12 @@ def test_evaluate_errors_sample(mock_get_errors, mock_optimize):
6272
"qolmat.benchmark.comparator.Comparator.evaluate_errors_sample",
6373
return_value=expected_get_errors,
6474
)
65-
def test_compare(mock_evaluate_errors_sample):
75+
def test_compare(mock_evaluate_errors_sample, comparator_fix):
6676
df_test = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
6777

6878
imputer1 = MagicMock(name="Imputer1")
6979
imputer2 = MagicMock(name="Imputer2")
70-
comparator.dict_imputers = {"imputer1": imputer1, "imputer2": imputer2}
80+
comparator_fix.dict_imputers = {"imputer1": imputer1, "imputer2": imputer2}
7181

7282
errors_imputer1 = pd.Series([0.1, 0.2], index=["mae", "mse"])
7383
errors_imputer2 = pd.Series([0.3, 0.4], index=["mae", "mse"])
@@ -76,7 +86,7 @@ def test_compare(mock_evaluate_errors_sample):
7686
errors_imputer2,
7787
]
7888

79-
df_errors = comparator.compare(df_test)
89+
df_errors = comparator_fix.compare(df_test)
8090
assert mock_evaluate_errors_sample.call_count == 2
8191

8292
mock_evaluate_errors_sample.assert_any_call(imputer1, df_test, {}, "mse")
@@ -85,3 +95,28 @@ def test_compare(mock_evaluate_errors_sample):
8595
{"imputer1": [0.1, 0.2], "imputer2": [0.3, 0.4]}, index=["mae", "mse"]
8696
)
8797
pd.testing.assert_frame_equal(df_errors, expected_df_errors)
98+
99+
100+
def test_compare_reproducibility():
101+
seed = 123
102+
dict_models = {
103+
"shuffle1": ImputerShuffle(random_state=seed),
104+
"shuffle2": ImputerShuffle(random_state=seed),
105+
}
106+
cols = ["A", "B"]
107+
df_data = pd.DataFrame(
108+
np.random.random((100, 2)), dtype=float, columns=cols
109+
)
110+
generator_holes = UniformHoleGenerator(
111+
n_splits=2, subset=cols, ratio_masked=0.5
112+
)
113+
comparator = Comparator(
114+
dict_models=dict_models,
115+
selected_columns=df_data.columns,
116+
generator_holes=generator_holes,
117+
metrics=["mae", "mse"],
118+
)
119+
df_errors = comparator.compare(df_data)
120+
pd.testing.assert_series_equal(
121+
df_errors["shuffle1"], df_errors["shuffle2"], check_names=False
122+
)

tests/imputations/test_em_sampler.py

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
import scipy
66
from numpy.typing import NDArray
77
from scipy import linalg
8+
from sklearn import utils as sku
89
from sklearn.datasets import make_spd_matrix
910

1011
from qolmat.imputations import em_sampler
@@ -31,8 +32,8 @@
3132

3233
# @pytest.fixture
3334
def generate_multinormal_predefined_mean_cov(d=3, n=500):
34-
rng = np.random.default_rng(42)
35-
seed = rng.integers(np.iinfo(np.int32).max)
35+
rng = sku.check_random_state(42)
36+
seed = rng.randint(np.iinfo(np.int32).max)
3637
random_state = np.random.RandomState(seed=seed)
3738
mean = np.array([rng.uniform(low=0, high=d) for _ in range(d)])
3839
covariance = make_spd_matrix(n_dim=d, random_state=random_state)
@@ -51,7 +52,7 @@ def generate_multinormal_predefined_mean_cov(d=3, n=500):
5152

5253

5354
def get_matrix_B(d, p, eigmax=1):
54-
rng = np.random.default_rng(42)
55+
rng = sku.check_random_state(42)
5556
B = rng.normal(0, 1, size=(d * p + 1, d))
5657
U, S, Vt = linalg.svd(B, check_finite=False, full_matrices=False)
5758
S = rng.uniform(0, eigmax, size=d)
@@ -60,8 +61,8 @@ def get_matrix_B(d, p, eigmax=1):
6061

6162

6263
def generate_varp_process(d=3, n=10000, p=1):
63-
rng = np.random.default_rng(42)
64-
seed = rng.integers(np.iinfo(np.int32).max)
64+
rng = sku.check_random_state(42)
65+
seed = rng.randint(np.iinfo(np.int32).max)
6566
random_state = np.random.RandomState(seed=seed)
6667
B = get_matrix_B(d, p, eigmax=0.9)
6768
nu = B[0, :]
@@ -434,7 +435,7 @@ def test_gradient_X_loglik(em: em_sampler.EM, p: int):
434435
d = 3
435436
X, _, _, _ = generate_varp_process(d=d, n=10, p=p)
436437
em.fit_parameters(X)
437-
rng = np.random.default_rng(42)
438+
rng = sku.check_random_state(42)
438439
X0 = rng.uniform(0, 10, size=X.shape)
439440
# X0 = X
440441
loglik = em.get_loglikelihood(X0)
File renamed without changes.

0 commit comments

Comments
 (0)