Skip to content

Commit 0f8bf0a

Browse files
Julien Roussel
authored and committed
data validation issue patched
1 parent 9dde1aa commit 0f8bf0a

File tree

5 files changed

+11 −10 lines changed

qolmat/benchmark/missing_patterns.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -247,14 +247,15 @@ def generate_mask(self, X: pd.DataFrame) -> pd.DataFrame:
247247
ratio_masked = self.ratio_masked
248248
if self.sample_proportional:
249249
ratio_masked *= self.dict_ratios[col] * len(X.columns)
250-
n_masked_col = math.ceil(self.ratio_masked * len(X))
251-
indices = np.where(X[col].notna())[0]
252-
indices = self.random_state.choice(
253-
indices,
250+
n_masked_col = math.ceil(ratio_masked * len(X))
251+
indices_int = np.where(X[col].notna())[0]
252+
indices_int = self.random_state.choice(
253+
indices_int,
254254
replace=False,
255255
size=n_masked_col,
256256
)
257-
df_mask[col].iloc[indices] = True
257+
indices_int = df_mask.index[indices_int]
258+
df_mask.loc[indices_int, col] = True
258259

259260
return df_mask
260261

qolmat/imputations/preprocessing.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -167,7 +167,7 @@ def fit(self, X: NDArray, y: Optional[NDArray] = None) -> "BinTransformer":
167167
X,
168168
accept_sparse=False,
169169
ensure_all_finite="allow-nan",
170-
reset=False,
170+
reset=True,
171171
dtype=["float", "int", "string", "categorical", "object"],
172172
)
173173
df = utils._validate_input(X)

qolmat/utils/data.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -198,9 +198,9 @@ def get_data(
198198
df = pd.read_csv(path)
199199
df = df[
200200
["Survived", "Sex", "Age", "SibSp", "Parch", "Fare", "Embarked"]
201-
]
201+
].copy()
202202
df["Age"] = pd.to_numeric(df["Age"], errors="coerce")
203-
df["Fare"] = pd.to_numeric(df["Fare"], errors="coerce")
203+
df.loc["Fare"] = pd.to_numeric(df["Fare"], errors="coerce")
204204
return df
205205
elif name_data == "Artificial":
206206
city = "Wonderland"

tests/imputations/test_imputers_diffusions.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@
2626
"col5": [93, 75, np.nan, 12, np.nan],
2727
},
2828
index=pd.date_range("2023-04-17", periods=5, freq="D"),
29+
dtype=float,
2930
)
3031
df_incomplete.index = df_incomplete.index.set_names("datetime")
3132

tests/utils/test_data.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -336,7 +336,6 @@ def test_data_get_data(
336336
"Beijing_online",
337337
"Superconductor_online",
338338
"Monach_weather",
339-
"Monach_weather",
340339
"Monach_electricity_australia",
341340
"Titanic",
342341
"SNCF",
@@ -372,7 +371,7 @@ def test_data_get_data(
372371
pd.testing.assert_frame_equal(df_result, df_monach_elec_preprocess)
373372
elif name_data == "Titanic":
374373
assert mock_read_dl.call_count == 1
375-
assert np.shape(df_result) == (3, 7)
374+
assert np.shape(df_result) == (4, 7)
376375
elif name_data == "SNCF":
377376
assert not df_result.empty
378377
assert df_result.index.name == "station"

0 commit comments

Comments
 (0)