@@ -86,7 +86,6 @@ class attribute, indicates whether transformer has been converted to polars/pand
86
86
"""
87
87
88
88
polars_compatible = True
89
-
90
89
FITS = False
91
90
92
91
def __init__ (
@@ -97,12 +96,11 @@ def __init__(
97
96
) -> None :
98
97
super ().__init__ (columns = columns , ** kwargs )
99
98
100
- if (
101
- not isinstance (impute_value , int )
102
- and not isinstance (impute_value , float )
103
- and not isinstance (impute_value , str )
104
- ):
105
- msg = f"{ self .classname ()} : impute_value should be a single value (int, float or str)"
99
+ if not isinstance (impute_value , (int , float , str )):
100
+ msg = (
101
+ f"{ self .classname ()} : impute_value should be a single value "
102
+ "(int, float or str)"
103
+ )
106
104
raise ValueError (msg )
107
105
108
106
self .impute_values_ = {}
@@ -130,26 +128,35 @@ def transform(self, X: FrameT) -> FrameT:
130
128
self .check_is_fitted (["impute_value" ])
131
129
self .columns_check (X )
132
130
133
- X = nw .from_native (X )
131
+ if len (X ) == 0 :
132
+ msg = f"{ self .classname ()} : X has no rows; { X .shape } "
133
+ raise ValueError (msg )
134
134
135
- new_col_expressions = []
135
+ # Save the original dtypes BEFORE we cast anything
136
+ original_dtypes = {}
136
137
for c in self .columns :
137
- # Handle categorical column cases explicitly using Polars' syntax with `nw`
138
- if (
139
- X [c ].dtype == nw .Categorical
140
- and self .impute_value not in X [c ].cat .categories
141
- ):
142
- X [c ] = X [c ].cat .add_categories (self .impute_value ) # add new category
138
+ original_dtypes [c ] = X [c ].dtype
143
139
144
- # Apply fill_null() properly within Narwhals
145
- new_col_expressions .append (nw .col (c ).fill_null (self .impute_values_ [c ]))
140
+ # If originally categorical, cast to nw.String so fill won't fail
141
+ for c in self .columns :
142
+ if original_dtypes [c ] == nw .Categorical :
143
+ # If the fill is numeric, turn fill into a str
144
+ X = X .with_columns ([nw .col (c ).cast (nw .String )])
146
145
147
- # Ensure dtype consistency
148
- X_transformed = X .with_columns (new_col_expressions )
146
+ # Fill nulls
147
+ new_col_exprs = [
148
+ nw .col (c ).fill_null (self .impute_values_ [c ]) for c in self .columns
149
+ ]
150
+ X_transformed = X .with_columns (new_col_exprs )
149
151
152
+ # Cast back if originally categorical
150
153
for c in self .columns :
151
- dtype = X [c ].dtype # get the dtype of the original column
152
- X_transformed [c ] = X_transformed [c ].astype (dtype )
154
+ if original_dtypes [c ] == nw .Categorical :
155
+ X_transformed = X_transformed .with_columns (
156
+ [
157
+ nw .col (c ).cast (nw .Categorical ),
158
+ ],
159
+ )
153
160
154
161
return X_transformed
155
162
0 commit comments