Skip to content

Commit e0c88ed

Browse files
Tests are failing at the moment
1 parent 14ece40 commit e0c88ed

File tree

2 files changed

+11
-12
lines changed

2 files changed

+11
-12
lines changed

tests/imputers/test_ArbitraryImputer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515

1616
# Dataframe used exclusively in this testing script
17-
def create_downcast_df(library):
17+
def create_downcast_df(library="pandas"):
1818
"""Create a dataframe with mixed dtypes to use in downcasting tests."""
1919
df = pd.DataFrame(
2020
{

tubular/imputers.py

+10-11
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,11 @@
66
from typing import TYPE_CHECKING
77

88
import narwhals as nw
9-
import pandas as pd
109

1110
from tubular.base import BaseTransformer
1211
from tubular.mixins import WeightColumnMixin
1312

1413
if TYPE_CHECKING:
15-
import pandas as pd
1614
from narwhals.typing import FrameT
1715

1816

@@ -87,7 +85,7 @@ class ArbitraryImputer(BaseImputer):
8785
class attribute indicating whether the transformer has been converted to the polars/pandas-agnostic narwhals framework
8886
"""
8987

90-
polars_compatible = False
88+
polars_compatible = True
9189

9290
FITS = False
9391

@@ -113,18 +111,19 @@ def __init__(
113111
for c in self.columns:
114112
self.impute_values_[c] = self.impute_value
115113

116-
def transform(self, X: pd.DataFrame) -> pd.DataFrame:
114+
@nw.narwhalify
115+
def transform(self, X: FrameT) -> FrameT:
117116
"""Impute missing values with the supplied impute_value.
118117
If columns is None all columns in X will be imputed.
119118
120119
Parameters
121120
----------
122-
X : pd.DataFrame
121+
X : pd/pl.DataFrame
123122
Data containing columns to impute.
124123
125124
Returns
126125
-------
127-
X : pd.DataFrame
126+
X : pd/pl.DataFrame
128127
Transformed input X with nulls imputed with the specified impute_value, for the specified columns.
129128
130129
Additions
@@ -136,11 +135,11 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame:
136135
self.columns_check(X)
137136
for c in self.columns:
138137
if (
139-
"category" in X[c].dtype.name
140-
and self.impute_value not in X[c].cat.categories
138+
"category" in str(X[c].dtype)
139+
and self.impute_value not in X[c].unique().to_list()
141140
):
142-
X[c] = X[c].cat.add_categories(
143-
self.impute_value,
141+
X = X.with_columns(
142+
X[c].cast(nw.Categorical).cat.add_categories(self.impute_value),
144143
) # add new category
145144

146145
# Calling the BaseImputer's transform method to impute the values
@@ -149,7 +148,7 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame:
149148
# casting imputer value as same dtype as original column
150149
for c in self.columns:
151150
dtype = X[c].dtype # get the dtype of original column
152-
X_transformed[c] = X_transformed[c].astype(dtype)
151+
X_transformed[c] = X_transformed.with_columns(X_transformed[c].cast(dtype))
153152

154153
return X_transformed
155154

0 commit comments

Comments
 (0)