Skip to content

Commit 3e3b77f

Browse files
Julien RousselJulien Roussel
authored andcommitted
model attributes correclty accessed
1 parent a9598df commit 3e3b77f

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

qolmat/imputations/diffusions/ddpms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import logging
44
import time
55
from datetime import timedelta
6-
from typing import Callable, Dict, List, Tuple, Union
6+
from typing import Any, Callable, Dict, List, Tuple, Union
77

88
import numpy as np
99
import pandas as pd
@@ -132,7 +132,7 @@ def __init__(
132132
seed_torch = self.random_state.randint(2**31 - 1)
133133
torch.manual_seed(seed_torch)
134134

135-
def __getstate__(self) -> str:
135+
def __getstate__(self) -> dict[str, Any]:
136136
"""Hashing method used in sklearn check tests.
137137
138138
Returns

qolmat/imputations/imputers_pytorch.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -876,7 +876,8 @@ def get_summary_training(self) -> Dict:
876876
Summary of the training
877877
878878
"""
879-
return self.model.summary
879+
model = self.get_model()
880+
return model.summary
880881

881882
def get_summary_architecture(self) -> Dict:
882883
"""Get the summary of the architecture.
@@ -887,7 +888,8 @@ def get_summary_architecture(self) -> Dict:
887888
Summary of the architecture
888889
889890
"""
891+
model = self.get_model()
890892
return {
891-
"number_parameters": self.model.num_params,
892-
"epsilon_model": self.model._eps_model,
893+
"number_parameters": model.num_params,
894+
"epsilon_model": model._eps_model,
893895
}

0 commit comments

Comments
 (0)