Skip to content

Commit 37564b4

Browse files
Julien RousselJulien Roussel
authored andcommitted
sphinx doc built
1 parent 9c8d782 commit 37564b4

File tree

4 files changed

+36
-35
lines changed

4 files changed

+36
-35
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ codecov = "^2.1.13"
7474

7575
[tool.poetry.group.docs.dependencies]
7676
numpydoc = "1.1.0"
77-
sphinx = "4.3.2"
77+
sphinx = ">= 5.0"
7878
sphinx-gallery = "0.10.1"
7979
sphinx_rtd_theme = "1.0.0"
8080
sphinx_markdown_tables = "0.0.17"

qolmat/imputations/diffusions/ddpms.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
ResidualBlock,
2121
ResidualBlockTS,
2222
)
23-
from qolmat.imputations.diffusions.utils import get_num_params
2423

2524
logging.basicConfig(
2625
format="%(asctime)s %(levelname)-8s %(message)s",
@@ -176,8 +175,8 @@ def _q_sample(
176175
epsilon = torch.randn_like(x, device=self.device)
177176
return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * epsilon, epsilon
178177

179-
def _set_eps_model(self) -> None:
180-
self._eps_model = AutoEncoder(
178+
def _get_eps_model(self) -> AutoEncoder:
179+
model = AutoEncoder(
181180
num_noise_steps=self.num_noise_steps,
182181
dim_input=self.dim_input,
183182
residual_block=ResidualBlock(
@@ -186,12 +185,35 @@ def _set_eps_model(self) -> None:
186185
dim_embedding=self.dim_embedding,
187186
num_blocks=self.num_blocks,
188187
p_dropout=self.p_dropout,
189-
).to(self.device)
188+
)
189+
return model
190+
191+
def _set_eps_model(self) -> None:
192+
model = self._get_eps_model()
193+
self._eps_model = model.to(self.device)
190194

191195
self.optimiser = torch.optim.Adam(
192196
self._eps_model.parameters(), lr=self.lr
193197
)
194198

199+
def get_num_params(self) -> int:
200+
"""Compute the number of parameters of the underlying model.
201+
202+
Returns
203+
-------
204+
int: Number of parameters if the model has been fitted,
205+
0 otherwise.
206+
207+
"""
208+
if hasattr(self, "_eps_model"):
209+
model_parameters = filter(
210+
lambda p: p.requires_grad, self._eps_model.parameters()
211+
)
212+
params = sum([np.prod(p.size()) for p in model_parameters])
213+
return int(params)
214+
else:
215+
return 0
216+
195217
def _print_valid(self, epoch: int, time_duration: float) -> None:
196218
"""Print model performance on validation data.
197219
@@ -206,8 +228,9 @@ def _print_valid(self, epoch: int, time_duration: float) -> None:
206228
self.time_durations.append(time_duration)
207229
print_step = 1 if int(self.epochs / 10) == 0 else int(self.epochs / 10)
208230
if self.print_valid and epoch == 0:
231+
n_params = self.get_num_params()
209232
logging.info(
210-
f"Num params of {self.__class__.__name__}: {self.num_params}"
233+
f"Num params of {self.__class__.__name__}: {n_params}"
211234
)
212235
if self.print_valid and epoch % print_step == 0:
213236
string_valid = f"Epoch {epoch}: "
@@ -526,7 +549,6 @@ def fit(
526549
)
527550

528551
self._set_eps_model()
529-
self.num_params: int = get_num_params(self._eps_model)
530552
self.summary: Dict[str, List] = {
531553
"epoch_loss": [],
532554
}

qolmat/imputations/diffusions/utils.py

Lines changed: 0 additions & 24 deletions
This file was deleted.

qolmat/imputations/imputers_pytorch.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Script for pytroch imputers."""
22

33
import logging
4+
from copy import copy
45
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
56

67
import numpy as np
@@ -809,6 +810,7 @@ def _fit_element(
809810
model = self.get_model()
810811
hp_fit = self._get_params_fit()
811812
model = model.fit(df, **hp_fit)
813+
self.model_last = copy(model)
812814
return model
813815

814816
def _transform_element(
@@ -876,7 +878,7 @@ def get_summary_training(self) -> Dict:
876878
Summary of the training
877879
878880
"""
879-
model = self.get_model()
881+
model = self.model_last
880882
return model.summary
881883

882884
def get_summary_architecture(self) -> Dict:
@@ -888,8 +890,9 @@ def get_summary_architecture(self) -> Dict:
888890
Summary of the architecture
889891
890892
"""
891-
model = self.get_model()
893+
model = self.model_last
894+
eps_model = model._get_eps_model()
892895
return {
893-
"number_parameters": model.num_params,
894-
"epsilon_model": model._eps_model,
896+
"number_parameters": model.get_num_params(),
897+
"epsilon_model": eps_model,
895898
}

0 commit comments

Comments
 (0)