6
6
from typing import TYPE_CHECKING
7
7
8
8
import narwhals as nw
9
- import pandas as pd
10
9
11
10
from tubular .base import BaseTransformer
12
11
from tubular .mixins import WeightColumnMixin
13
12
14
13
if TYPE_CHECKING :
15
- import pandas as pd
16
14
from narwhals .typing import FrameT
17
15
18
16
@@ -87,7 +85,7 @@ class ArbitraryImputer(BaseImputer):
87
85
class attribute, indicates whether transformer has been converted to polars/pandas agnostic narwhals framework
88
86
"""
89
87
90
- polars_compatible = False
88
+ polars_compatible = True
91
89
92
90
FITS = False
93
91
@@ -113,18 +111,19 @@ def __init__(
113
111
for c in self .columns :
114
112
self .impute_values_ [c ] = self .impute_value
115
113
116
- def transform (self , X : pd .DataFrame ) -> pd .DataFrame :
114
+ @nw .narwhalify
115
+ def transform (self , X : FrameT ) -> FrameT :
117
116
"""Impute missing values with the supplied impute_value.
118
117
If columns is None all columns in X will be imputed.
119
118
120
119
Parameters
121
120
----------
122
- X : pd.DataFrame
121
+ X : pd/pl .DataFrame
123
122
Data containing columns to impute.
124
123
125
124
Returns
126
125
-------
127
- X : pd.DataFrame
126
+ X : pd/pl .DataFrame
128
127
Transformed input X with nulls imputed with the specified impute_value, for the specified columns.
129
128
130
129
Additions
@@ -136,11 +135,11 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame:
136
135
self .columns_check (X )
137
136
for c in self .columns :
138
137
if (
139
- "category" in X [c ].dtype . name
140
- and self .impute_value not in X [c ].cat . categories
138
+ "category" in str ( X [c ].dtype )
139
+ and self .impute_value not in X [c ].unique (). to_list ()
141
140
):
142
- X [ c ] = X [ c ]. cat . add_categories (
143
- self .impute_value ,
141
+ X = X . with_columns (
142
+ X [ c ]. cast ( nw . Categorical ). cat . add_categories ( self .impute_value ) ,
144
143
) # add new category
145
144
146
145
# Calling the BaseImputer's transform method to impute the values
@@ -149,7 +148,7 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame:
149
148
# casting imputer value as same dtype as original column
150
149
for c in self .columns :
151
150
dtype = X [c ].dtype # get the dtype of original column
152
- X_transformed [c ] = X_transformed [c ].astype (dtype )
151
+ X_transformed [c ] = X_transformed . with_columns ( X_transformed [c ].cast (dtype ) )
153
152
154
153
return X_transformed
155
154
0 commit comments