Skip to content

Commit 97ec7a2

Browse files
Merge pull request #349 from lvgig/feature/narwhalify_median_imputer
Feature/narwhalify median imputer
2 parents 5c2aaf3 + 628f7c1 commit 97ec7a2

File tree

3 files changed

+87
-76
lines changed

3 files changed

+87
-76
lines changed

Diff for: CHANGELOG.rst

+1
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ Changed
7777
- Refactored OrdinalEncoderTransformer tests in new format `#330 <https://github.com/lvgig/tubular/issues/330>`_
7878
- Narwhal-ified NullIndicator `#319 <https://github.com/lvgig/tubular/issues/319>`_
7979
- Narwhal-ified NearestMeanResponseImputer `#320 <https://github.com/lvgig/tubular/issues/320>`_
80+
- Narwhal-ified MedianImputer `#317 <https://github.com/lvgig/tubular/issues/317>`_
8081

8182

8283
1.3.1 (2024-07-18)

Diff for: tests/imputers/test_MedianImputer.py

+61-51
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
import narwhals as nw
12
import numpy as np
2-
import pandas as pd
3-
import test_aide as ta
3+
import pytest
44

55
import tests.test_data as d
6+
from tests import utils as u
67
from tests.base_tests import (
78
ColumnStrListInitTests,
89
GenericFitTests,
@@ -33,68 +34,77 @@ class TestFit(WeightColumnFitMixinTests, GenericFitTests):
3334
def setup_class(cls):
3435
cls.transformer_name = "MedianImputer"
3536

36-
def test_learnt_values(self):
37+
@pytest.mark.parametrize("library", ["pandas", "polars"])
38+
def test_learnt_values(self, library):
3739
"""Test that the impute values learnt during fit are expected."""
38-
df = d.create_df_3()
39-
df["d"] = np.nan
40-
41-
x = MedianImputer(columns=["a", "b", "c", "d"])
42-
43-
x.fit(df)
44-
45-
ta.classes.test_object_attributes(
46-
obj=x,
47-
expected_attributes={
48-
"impute_values_": {
49-
"a": df["a"].median(),
50-
"b": df["b"].median(),
51-
"c": df["c"].median(),
52-
"d": np.float64(np.nan),
53-
},
54-
},
55-
msg="impute_values_ attribute",
40+
df = d.create_df_3(library=library)
41+
42+
df = nw.from_native(df)
43+
native_namespace = nw.get_native_namespace(df)
44+
45+
# add column 'd' containing only null values, so its learnt impute value should be None
46+
df = df.with_columns(
47+
nw.new_series(
48+
name="d",
49+
values=[None] * len(df),
50+
native_namespace=native_namespace,
51+
),
5652
)
5753

58-
def test_learnt_values_weighted(self):
54+
df = df.to_native()
55+
56+
transformer = MedianImputer(columns=["a", "b", "c", "d"])
57+
58+
transformer.fit(df)
59+
60+
assert transformer.impute_values_ == {
61+
"a": df["a"].median(),
62+
"b": df["b"].median(),
63+
"c": df["c"].median(),
64+
"d": None,
65+
}, "impute_values_ attribute"
66+
67+
@pytest.mark.parametrize("library", ["pandas", "polars"])
68+
def test_learnt_values_weighted(self, library):
5969
"""Test that the impute values learnt during fit are expected - when using weights."""
60-
df = d.create_df_9()
61-
df["d"] = np.nan
62-
63-
df = pd.DataFrame(
64-
{
65-
"a": [1, 2, 4, 6],
66-
"c": [3, 2, 4, 6],
67-
"d": np.nan,
68-
},
70+
df = d.create_df_9(library=library)
71+
72+
df = nw.from_native(df)
73+
native_namespace = nw.get_native_namespace(df)
74+
75+
# add column 'd' containing only null values, so its learnt impute value should be None
76+
df = df.with_columns(
77+
nw.new_series(
78+
name="d",
79+
values=[None] * len(df),
80+
native_namespace=native_namespace,
81+
),
6982
)
7083

71-
x = MedianImputer(columns=["a", "d"], weights_column="c")
84+
df = df.to_native()
7285

73-
x.fit(df)
86+
transformer = MedianImputer(columns=["a", "d"], weights_column="c")
7487

75-
ta.classes.test_object_attributes(
76-
obj=x,
77-
expected_attributes={
78-
"impute_values_": {
79-
"a": np.int64(4),
80-
"d": np.nan,
81-
},
82-
},
83-
msg="impute_values_ attribute",
84-
)
88+
transformer.fit(df)
89+
90+
assert transformer.impute_values_ == {
91+
"a": np.int64(4),
92+
"d": None,
93+
}, "impute_values_ attribute"
8594

86-
def test_fit_not_changing_data(self):
95+
@pytest.mark.parametrize("library", ["pandas", "polars"])
96+
def test_fit_not_changing_data(self, library):
8797
"""Test fit does not change X."""
88-
df = d.create_df_1()
98+
df = d.create_df_1(library=library)
8999

90-
x = MedianImputer(columns="a")
100+
transformer = MedianImputer(columns="a")
91101

92-
x.fit(df)
102+
transformer.fit(df)
93103

94-
ta.equality.assert_equal_dispatch(
95-
expected=d.create_df_1(),
96-
actual=df,
97-
msg="Check X not changing during fit",
104+
# Check whole dataframes
105+
u.assert_frame_equal_dispatch(
106+
d.create_df_1(library=library),
107+
df,
98108
)
99109

100110

Diff for: tubular/imputers.py

+25-25
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from typing import TYPE_CHECKING
77

88
import narwhals as nw
9-
import numpy as np
109
import pandas as pd
1110

1211
from tubular.base import BaseTransformer
@@ -181,7 +180,7 @@ class attribute, indicates whether transformer has been converted to polars/pand
181180
182181
"""
183182

184-
polars_compatible = False
183+
polars_compatible = True
185184

186185
FITS = True
187186

@@ -195,51 +194,52 @@ def __init__(
195194

196195
WeightColumnMixin.check_and_set_weight(self, weights_column)
197196

198-
def fit(self, X: pd.DataFrame, y: pd.Series | None = None) -> pd.DataFrame:
197+
@nw.narwhalify
198+
def fit(self, X: FrameT, y: nw.Series | None = None) -> FrameT:
199199
"""Calculate median values to impute with from X.
200200
201201
Parameters
202202
----------
203-
X : pd.DataFrame
203+
X : pd/pl.DataFrame
204204
Data to "learn" the median values from.
205205
206-
y : None or pd.DataFrame or pd.Series, default = None
206+
y : None or pd/pl.Series, default = None
207207
Not required.
208208
209209
"""
210210
super().fit(X, y)
211211

212212
self.impute_values_ = {}
213213

214-
if self.weights_column is not None:
215-
WeightColumnMixin.check_weights_column(self, X, self.weights_column)
214+
for c in self.columns:
215+
# filter out null rows so their weight doesn't influence calc
216+
filtered = X.filter(~nw.col(c).is_null())
216217

217-
for c in self.columns:
218-
# filter out null rows so their weight doesn't influence calc
219-
filtered = X[X[c].notna()]
218+
# if column is only nulls, then median is None
219+
if len(filtered) <= 0:
220+
self.impute_values_[c] = None
220221

221-
# below algorithm only works for >1 non null values
222-
if len(filtered) <= 0:
223-
median = np.nan
222+
elif self.weights_column is not None:
223+
WeightColumnMixin.check_weights_column(self, X, self.weights_column)
224224

225-
else:
226-
# first sort df by column to be imputed (order of weight column shouldn't matter for median)
227-
filtered = filtered.sort_values(c)
225+
# first sort df by column to be imputed (order of weight column shouldn't matter for median)
226+
filtered = filtered.sort(c)
228227

229-
# next calculate cumulative weight sums
230-
cumsum = filtered[self.weights_column].cumsum()
228+
# next calculate cumulative weight sums
229+
cumsum = filtered[self.weights_column].cum_sum()
231230

232-
# find midpoint
233-
cutoff = filtered[self.weights_column].sum() / 2.0
231+
# find midpoint
232+
cutoff = filtered[self.weights_column].sum() / 2.0
234233

235-
# find first value >= this point
236-
median = filtered[c][cumsum >= cutoff].iloc[0]
234+
# find first value >= this point
235+
median = filtered.filter(cumsum >= cutoff).select(c)[0].item()
237236

237+
# impute value is weighted median
238238
self.impute_values_[c] = median
239239

240-
else:
241-
for c in self.columns:
242-
self.impute_values_[c] = X[c].median()
240+
else:
241+
# impute value is median without considering weight
242+
self.impute_values_[c] = X.select(nw.col(c).median()).item()
243243

244244
return self
245245

0 commit comments

Comments
 (0)