@@ -258,7 +258,7 @@ def check_samplers_pandas(name, Sampler):
258
258
X_res_df , y_res_df = sampler .fit_resample (X_df , y_df )
259
259
X_res , y_res = sampler .fit_resample (X , y )
260
260
261
- # check that we return the same type for dataframes or seires types
261
+ # check that we return the same type for dataframes or series types
262
262
assert isinstance (X_res_df , pd .DataFrame )
263
263
assert isinstance (y_res_df , pd .DataFrame )
264
264
assert isinstance (y_res_s , pd .Series )
@@ -272,6 +272,36 @@ def check_samplers_pandas(name, Sampler):
272
272
assert_allclose (y_res_s .to_numpy (), y_res )
273
273
274
274
275
+ def check_samplers_list (name , Sampler ):
276
+ # Check that the can samplers handle simple lists
277
+ X , y = make_classification (
278
+ n_samples = 1000 ,
279
+ n_classes = 3 ,
280
+ n_informative = 4 ,
281
+ weights = [0.2 , 0.3 , 0.5 ],
282
+ random_state = 0 ,
283
+ )
284
+ X_list = X .tolist ()
285
+ y_list = y .tolist ()
286
+ sampler = Sampler ()
287
+ if isinstance (Sampler (), NearMiss ):
288
+ samplers = [Sampler (version = version ) for version in (1 , 2 , 3 )]
289
+
290
+ else :
291
+ samplers = [Sampler ()]
292
+
293
+ for sampler in samplers :
294
+ set_random_state (sampler )
295
+ X_res , y_res = sampler .fit_resample (X , y )
296
+ X_res_list , y_res_list = sampler .fit_resample (X_list , y_list )
297
+
298
+ assert isinstance (X_res_list , list )
299
+ assert isinstance (y_res_list , list )
300
+
301
+ assert_allclose (X_res , X_res_list )
302
+ assert_allclose (y_res , y_res_list )
303
+
304
+
275
305
def check_samplers_multiclass_ova (name , Sampler ):
276
306
# Check that multiclass target lead to the same results than OVA encoding
277
307
X , y = make_classification (
0 commit comments