2
2
import re
3
3
4
4
import pandas as pd
5
+ import polars as pl
5
6
import pytest
6
- import test_aide as ta
7
7
8
8
import tests .test_data as d
9
9
from tests .base_tests import (
12
12
GenericTransformTests ,
13
13
OtherBaseBehaviourTests ,
14
14
)
15
- from tests .utils import assert_frame_equal_dispatch
15
+ from tests .utils import assert_frame_equal_dispatch , dataframe_init_dispatch
16
16
from tubular .mapping import BaseMappingTransformMixin
17
17
18
18
# Note there are no tests that need inheriting from this file as the only difference is an expected transform output
21
21
@pytest.fixture()
def mapping():
    """Provide the per-column mappings assigned to the transformer in these tests.

    The returned dict is keyed by column name. Column "a" maps the integers
    1-9 to the letters "a"-"h" (with 9 mapped to None); column "b" holds the
    reverse lookup, including a None key mapped to 9.
    """
    int_to_letter = {
        1: "a",
        2: "b",
        3: "c",
        4: "d",
        5: "e",
        6: "f",
        7: "g",
        8: "h",
        9: None,
    }
    letter_to_int = {
        "a": 1,
        "b": 2,
        "c": 3,
        "d": 4,
        "e": 5,
        "f": 6,
        "g": 7,
        "h": 8,
        None: 9,
    }
    return {"a": int_to_letter, "b": letter_to_int}
27
27
28
28
@@ -51,100 +51,133 @@ class TestTransform(GenericTransformTests):
51
51
def setup_class (cls ):
52
52
cls .transformer_name = "BaseMappingTransformMixin"
53
53
54
@pytest.mark.parametrize("library", ["pandas", "polars"])
def test_expected_output(self, mapping, library):
    """Test that transform returns the expected mapped dataframe.

    Runs once per dataframe library. The input comes from
    ``d.create_df_1`` and the expected frame is built in the same library
    so the comparison is like-for-like.
    """
    df = d.create_df_1(library=library)

    expected = dataframe_init_dispatch(
        dataframe_dict={
            "a": ["a", "b", "c", "d", "e", "f"],
            "b": [1, 2, 3, 4, 5, 6],
        },
        library=library,
    )

    transformer = BaseMappingTransformMixin(columns=["a", "b"])

    # Report an explicit skip rather than returning early: a bare return
    # would make the test show up as passed even though nothing was checked.
    if not transformer.polars_compatible and isinstance(df, pl.DataFrame):
        pytest.skip("BaseMappingTransformMixin is not yet polars compatible")

    transformer.mappings = mapping
    transformer.return_dtypes = {"a": "String", "b": "Int64"}

    df_transformed = transformer.transform(df)

    assert_frame_equal_dispatch(expected, df_transformed)
77
82
78
@pytest.mark.parametrize("library", ["pandas", "polars"])
def test_mappings_unchanged(self, mapping, library):
    """Test that mappings is unchanged in transform."""
    # local import keeps this test self-contained
    import copy

    df = d.create_df_1(library=library)

    transformer = BaseMappingTransformMixin(columns=["a", "b"])

    # Report an explicit skip rather than returning early: a bare return
    # would make the test show up as passed even though nothing was checked.
    if not transformer.polars_compatible and isinstance(df, pl.DataFrame):
        pytest.skip("BaseMappingTransformMixin is not yet polars compatible")

    # Snapshot the mappings before handing them to the transformer. The
    # attribute assignment below stores the very same dict object, so
    # comparing ``mapping`` with ``transformer.mappings`` directly would
    # compare an object with itself and could never detect in-place mutation.
    expected_mappings = copy.deepcopy(mapping)

    transformer.mappings = mapping
    transformer.return_dtypes = {
        "a": "String",
        "b": "Int64",
    }

    transformer.transform(df)

    assert (
        expected_mappings == transformer.mappings
    ), f"BaseMappingTransformer.transform has changed self.mappings unexpectedly, expected {expected_mappings} but got {transformer.mappings}"
93
105
106
@pytest.mark.parametrize("library", ["pandas", "polars"])
@pytest.mark.parametrize("non_df", [1, True, "a", [1, 2], {"a": 1}, None])
def test_non_pd_type_error(
    self,
    non_df,
    mapping,
    library,
):
    """Test that an error is raised in transform if X is not a dataframe.

    The transformer is fitted on a valid frame from ``d.create_df_10`` and
    then given each ``non_df`` value, which must raise a ``TypeError``.
    """
    df = d.create_df_10(library=library)

    transformer = BaseMappingTransformMixin(columns=["a"])

    # Report an explicit skip rather than returning early: a bare return
    # would make the test show up as passed even though nothing was checked.
    if not transformer.polars_compatible and isinstance(df, pl.DataFrame):
        pytest.skip("BaseMappingTransformMixin is not yet polars compatible")

    transformer.mappings = mapping
    transformer.return_dtypes = {
        "a": "String",
    }

    x_fitted = transformer.fit(df, df["c"])

    with pytest.raises(
        TypeError,
        match="BaseMappingTransformMixin: X should be a polars or pandas DataFrame/LazyFrame",
    ):
        x_fitted.transform(X=non_df)
115
136
116
@pytest.mark.parametrize("library", ["pandas", "polars"])
def test_no_rows_error(self, mapping, library):
    """Test an error is raised if X has no rows.

    The transformer is fitted on a populated frame, then transform is called
    with an empty three-column frame built in the library under test.
    """
    df = d.create_df_10(library=library)

    transformer = BaseMappingTransformMixin(columns=["a"])

    # Report an explicit skip rather than returning early: a bare return
    # would make the test show up as passed even though nothing was checked.
    if not transformer.polars_compatible and isinstance(df, pl.DataFrame):
        pytest.skip("BaseMappingTransformMixin is not yet polars compatible")

    transformer.mappings = mapping
    transformer.return_dtypes = {"a": "String"}

    transformer = transformer.fit(df, df["c"])

    # Build the empty frame in the parametrized library; always using pandas
    # here would mean the polars run never passes a polars frame to transform.
    empty_columns = ["a", "b", "c"]
    if library == "polars":
        empty_df = pl.DataFrame(schema=empty_columns)
    else:
        empty_df = pd.DataFrame(columns=empty_columns)

    with pytest.raises(
        ValueError,
        match=re.escape("BaseMappingTransformMixin: X has no rows; (0, 3)"),
    ):
        transformer.transform(empty_df)
133
160
134
@pytest.mark.parametrize("library", ["pandas", "polars"])
def test_original_df_not_updated(self, mapping, library):
    """Test that the original dataframe is not transformed when transform method used.

    After fit and transform, the input frame is compared with a freshly
    created copy from ``d.create_df_10`` for the same library.
    """
    df = d.create_df_10(library=library)

    transformer = BaseMappingTransformMixin(columns=["a"])

    # Report an explicit skip rather than returning early: a bare return
    # would make the test show up as passed even though nothing was checked.
    if not transformer.polars_compatible and isinstance(df, pl.DataFrame):
        pytest.skip("BaseMappingTransformMixin is not yet polars compatible")

    transformer.mappings = mapping
    transformer.return_dtypes = {"a": "String", "b": "Int64"}

    transformer = transformer.fit(df, df["c"])

    # the return value is irrelevant here; only the input frame is checked
    _ = transformer.transform(df)

    assert_frame_equal_dispatch(df, d.create_df_10(library=library))
148
181
149
182
@pytest .mark .parametrize (
150
183
"minimal_dataframe_lookup" ,
@@ -160,17 +193,23 @@ def test_pandas_index_not_updated(
160
193
"""Test that the original (pandas) dataframe index is not transformed when transform method used."""
161
194
162
195
df = minimal_dataframe_lookup [self .transformer_name ]
163
- x = initialized_transformers [self .transformer_name ]
164
- x .mappings = mapping
196
+ transformer = initialized_transformers [self .transformer_name ]
197
+
198
+ # if transformer is not yet polars compatible, skip this test
199
+ if not transformer .polars_compatible and isinstance (df , pl .DataFrame ):
200
+ return
201
+
202
+ transformer .mappings = mapping
203
+ transformer .return_dtypes = {"a" : "String" , "b" : "String" }
165
204
166
205
# update to abnormal index
167
206
df .index = [2 * i for i in df .index ]
168
207
169
208
original_df = copy .deepcopy (df )
170
209
171
- x = x .fit (df , df ["a" ])
210
+ transformer = transformer .fit (df , df ["a" ])
172
211
173
- _ = x .transform (df )
212
+ _ = transformer .transform (df )
174
213
175
214
assert_frame_equal_dispatch (df , original_df )
176
215
0 commit comments