Skip to content

Commit a9598df

Browse files
Julien RousselJulien Roussel
authored andcommitted
tests passing on 3.9 and 3.12
1 parent 25a9443 commit a9598df

File tree

8 files changed

+188
-106
lines changed

8 files changed

+188
-106
lines changed

examples/benchmark.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ from qolmat.imputations.imputers_pytorch import ImputerDiffusion
311311
from qolmat.imputations.diffusions.ddpms import TabDDPM
312312

313313
X = np.array([[1, 1, 1, 1], [np.nan, np.nan, 3, 2], [1, 2, 2, 1], [2, 2, 2, 2]])
314-
imputer = ImputerDiffusion(model=TabDDPM(random_state=11), epochs=50, batch_size=1)
314+
imputer = ImputerDiffusion(epochs=50, batch_size=1, random_state=11)
315315

316316
imputer.fit_transform(X)
317317
```
@@ -322,7 +322,7 @@ from qolmat.imputations.imputers_pytorch import ImputerDiffusion
322322
from qolmat.imputations.diffusions.ddpms import TabDDPM
323323

324324
X = np.array([[1, 1, 1, 1], [np.nan, np.nan, 3, 2], [1, 2, 2, 1], [2, 2, 2, 2]])
325-
imputer = ImputerDiffusion(model=TabDDPM(random_state=11), epochs=50, batch_size=1)
325+
imputer = ImputerDiffusion(epochs=50, batch_size=1, random_state=11)
326326

327327
imputer.fit_transform(X)
328328
```
@@ -358,7 +358,7 @@ encoder, decoder = imputers_pytorch.build_autoencoder(input_dim=n_variables,lat
358358
```python
359359
dict_imputers["MLP"] = imputer_mlp = imputers_pytorch.ImputerRegressorPyTorch(estimator=estimator, groups=('station',), epochs=500)
360360
dict_imputers["Autoencoder"] = imputer_autoencoder = imputers_pytorch.ImputerAutoencoder(encoder, decoder, max_iterations=100, epochs=100)
361-
dict_imputers["Diffusion"] = imputer_diffusion = imputers_pytorch.ImputerDiffusion(model=TabDDPM(num_sampling=5), epochs=100, batch_size=100)
361+
dict_imputers["Diffusion"] = imputer_diffusion = imputers_pytorch.ImputerDiffusion(epochs=100, batch_size=100, num_sampling=5)
362362
```
363363

364364
We can re-run the imputation model benchmark as before.

examples/tutorials/plot_tuto_diffusion_models.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@
7171
df_data_valid = df_data.iloc[:500]
7272

7373
tabddpm = ImputerDiffusion(
74-
model=TabDDPM(),
7574
epochs=10,
7675
batch_size=100,
7776
x_valid=df_data_valid,
@@ -160,12 +159,8 @@
160159
# reconstruction errors (mae) but increases distribution distance (kl_columnwise).
161160

162161
dict_imputers = {
163-
"num_sampling=5": ImputerDiffusion(
164-
model=TabDDPM(num_sampling=5), epochs=10, batch_size=100
165-
),
166-
"num_sampling=10": ImputerDiffusion(
167-
model=TabDDPM(num_sampling=10), epochs=10, batch_size=100
168-
),
162+
"num_sampling=5": ImputerDiffusion(epochs=10, batch_size=100, num_sampling=5),
163+
"num_sampling=10": ImputerDiffusion(epochs=10, batch_size=100, num_sampling=10),
169164
}
170165

171166
comparison = comparator.Comparator(
@@ -187,7 +182,7 @@
187182
#
188183
# Two important hyperparameters for processing time-series data are ``index_datetime``
189184
# and ``freq_str``.
190-
# E.g., ``ImputerDiffusion(model=TabDDPM(), index_datetime='datetime', freq_str='1D')``,
185+
# E.g., ``ImputerDiffusion(index_datetime='datetime', freq_str='1D')``,
191186
#
192187
# * ``index_datetime``: the column name of datetime in index. It must be a pandas datetime object.
193188
#
@@ -210,15 +205,16 @@
210205
# but requires a longer training/inference time.
211206

212207
dict_imputers = {
213-
"tabddpm": ImputerDiffusion(
214-
model=TabDDPM(num_sampling=5), epochs=10, batch_size=100
208+
"tabddpm": ImputerDiffusion(model="TabDDPM", epochs=10, batch_size=100, num_sampling=5
215209
),
216210
"tsddpm": ImputerDiffusion(
217-
model=TsDDPM(num_sampling=5, is_rolling=False),
211+
model="TsDDPM",
218212
epochs=10,
219213
batch_size=5,
220214
index_datetime="date",
221215
freq_str="5D",
216+
num_sampling=5,
217+
is_rolling=False
222218
),
223219
}
224220

qolmat/imputations/diffusions/ddpms.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def __init__(
4747
beta_start: float = 1e-4,
4848
beta_end: float = 0.02,
4949
lr: float = 0.001,
50-
ratio_nan: float = 0.1,
50+
ratio_masked: float = 0.1,
5151
dim_embedding: int = 128,
5252
num_blocks: int = 1,
5353
p_dropout: float = 0.0,
@@ -67,7 +67,7 @@ def __init__(
6767
Range of beta (noise scale value), by default 0.02
6868
lr : float, optional
6969
Learning rate, by default 0.001
70-
ratio_nan : float, optional
70+
ratio_masked : float, optional
7171
Ratio of artificial nan for training and validation, by default 0.1
7272
dim_embedding : int, optional
7373
Embedding dimension, by default 128
@@ -119,7 +119,7 @@ def __init__(
119119
self.loss_func = torch.nn.MSELoss(reduction="none")
120120

121121
self.lr = lr
122-
self.ratio_nan = ratio_nan
122+
self.ratio_masked = ratio_masked
123123
self.num_noise_steps = num_noise_steps
124124
self.dim_embedding = dim_embedding
125125
self.num_blocks = num_blocks
@@ -132,6 +132,21 @@ def __init__(
132132
seed_torch = self.random_state.randint(2**31 - 1)
133133
torch.manual_seed(seed_torch)
134134

135+
def __getstate__(self) -> str:
136+
"""Hashing method used in sklearn check tests.
137+
138+
Returns
139+
-------
140+
________
141+
str
142+
Hashed object containing the underlying model weights
143+
144+
"""
145+
state = self.__dict__.copy()
146+
if "optimiser" in state:
147+
state.pop("optimiser")
148+
return state
149+
135150
def _q_sample(
136151
self, x: torch.Tensor, t: torch.Tensor
137152
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -448,6 +463,9 @@ def fit(
448463
Return Self
449464
450465
"""
466+
seed_torch = self.random_state.randint(2**31 - 1)
467+
torch.manual_seed(seed_torch)
468+
451469
self.dim_input = len(x.columns)
452470
self.epochs = epochs
453471
self.batch_size = batch_size
@@ -486,7 +504,7 @@ def fit(
486504
# (with one mask)
487505
# in validation dataset
488506
x_valid_mask = missing_patterns.UniformHoleGenerator(
489-
n_splits=1, ratio_masked=self.ratio_nan
507+
n_splits=1, ratio_masked=self.ratio_masked
490508
).split(x_valid)[0]
491509
# x_valid_obs_mask is the mask for observed values
492510
x_valid_obs_mask = ~x_valid_mask
@@ -520,7 +538,7 @@ def fit(
520538
for id_batch, (x_batch, mask_x_batch) in enumerate(dataloader):
521539
mask_obs_rand = (
522540
torch.FloatTensor(mask_x_batch.size()).uniform_()
523-
> self.ratio_nan
541+
> self.ratio_masked
524542
)
525543
for col in self.cols_idx_not_imputed:
526544
mask_obs_rand[:, col] = 0.0
@@ -576,6 +594,8 @@ def predict(self, x: pd.DataFrame) -> pd.DataFrame:
576594
Imputed data
577595
578596
"""
597+
seed_torch = self.random_state.randint(2**31 - 1)
598+
torch.manual_seed(seed_torch)
579599
self._eps_model.eval()
580600

581601
x_processed, x_mask, x_indices = self._process_data(

qolmat/imputations/imputers.py

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2003,39 +2003,6 @@ def get_model(self, **hyperparams) -> softimpute.SoftImpute:
20032003

20042004
return model
20052005

2006-
# def _fit_element(
2007-
# self, df: pd.DataFrame, col: str = "__all__", ngroup: int = 0
2008-
# ) -> softimpute.SoftImpute:
2009-
# """
2010-
# Fits the imputer on `df`, at the group and/or column level depending
2011-
# on self.groups and self.columnwise.
2012-
2013-
# Parameters
2014-
# ----------
2015-
# df : pd.DataFrame
2016-
# Dataframe on which the imputer is fitted
2017-
# col : str, optional
2018-
# Column on which the imputer is fitted, by default "__all__"
2019-
# ngroup : int, optional
2020-
# Id of the group on which the method is applied
2021-
2022-
# Returns
2023-
# -------
2024-
# Any
2025-
# Return fitted SoftImpute model
2026-
2027-
# Raises
2028-
# ------
2029-
# NotDataFrame
2030-
# Input has to be a pandas.DataFrame.
2031-
# """
2032-
# self._check_dataframe(df)
2033-
# assert col == "__all__"
2034-
# hyperparams = self.get_hyperparams()
2035-
# model = softimpute.SoftImpute(random_state=self._rng, **hyperparams)
2036-
# model = model.fit(df.values)
2037-
# return model
2038-
20392006
def _transform_element(
20402007
self, df: pd.DataFrame, col: str = "__all__", ngroup: int = 0
20412008
) -> pd.DataFrame:

0 commit comments

Comments
 (0)