Skip to content

Commit ab09ade

Browse files
mostly converted - nan to string conversion issue
1 parent 2d10a54 commit ab09ade

File tree

1 file changed

+28
-21
lines changed

1 file changed

+28
-21
lines changed

tubular/imputers.py

+28-21
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ class attribute, indicates whether transformer has been converted to polars/pand
8686
"""
8787

8888
polars_compatible = True
89-
9089
FITS = False
9190

9291
def __init__(
@@ -97,12 +96,11 @@ def __init__(
9796
) -> None:
9897
super().__init__(columns=columns, **kwargs)
9998

100-
if (
101-
not isinstance(impute_value, int)
102-
and not isinstance(impute_value, float)
103-
and not isinstance(impute_value, str)
104-
):
105-
msg = f"{self.classname()}: impute_value should be a single value (int, float or str)"
99+
if not isinstance(impute_value, (int, float, str)):
100+
msg = (
101+
f"{self.classname()}: impute_value should be a single value "
102+
"(int, float or str)"
103+
)
106104
raise ValueError(msg)
107105

108106
self.impute_values_ = {}
@@ -130,26 +128,35 @@ def transform(self, X: FrameT) -> FrameT:
130128
self.check_is_fitted(["impute_value"])
131129
self.columns_check(X)
132130

133-
X = nw.from_native(X)
131+
if len(X) == 0:
132+
msg = f"{self.classname()}: X has no rows; {X.shape}"
133+
raise ValueError(msg)
134134

135-
new_col_expressions = []
135+
# Save the original dtypes BEFORE we cast anything
136+
original_dtypes = {}
136137
for c in self.columns:
137-
# Handle categorical column cases explicitly using Polars' syntax with `nw`
138-
if (
139-
X[c].dtype == nw.Categorical
140-
and self.impute_value not in X[c].cat.categories
141-
):
142-
X[c] = X[c].cat.add_categories(self.impute_value) # add new category
138+
original_dtypes[c] = X[c].dtype
143139

144-
# Apply fill_null() properly within Narwhals
145-
new_col_expressions.append(nw.col(c).fill_null(self.impute_values_[c]))
140+
# If originally categorical, cast to nw.String so fill won't fail
141+
for c in self.columns:
142+
if original_dtypes[c] == nw.Categorical:
143+
# Cast the column itself to String; the fill value is not converted here
144+
X = X.with_columns([nw.col(c).cast(nw.String)])
146145

147-
# Ensure dtype consistency
148-
X_transformed = X.with_columns(new_col_expressions)
146+
# Fill nulls
147+
new_col_exprs = [
148+
nw.col(c).fill_null(self.impute_values_[c]) for c in self.columns
149+
]
150+
X_transformed = X.with_columns(new_col_exprs)
149151

152+
# Cast back if originally categorical
150153
for c in self.columns:
151-
dtype = X[c].dtype # get the dtype of the original column
152-
X_transformed[c] = X_transformed[c].astype(dtype)
154+
if original_dtypes[c] == nw.Categorical:
155+
X_transformed = X_transformed.with_columns(
156+
[
157+
nw.col(c).cast(nw.Categorical),
158+
],
159+
)
153160

154161
return X_transformed
155162

0 commit comments

Comments
 (0)