@@ -295,9 +295,10 @@ def check_samplers_pandas(name, sampler):
295
295
assert_allclose (y_res_s .to_numpy (), y_res )
296
296
297
297
298
- def check_samplers_dask_array (name , sampler ):
298
+ def check_samplers_dask_array (name , sampler_orig ):
299
299
pytest .importorskip ("dask" )
300
300
from dask import array
301
+ sampler = clone (sampler_orig )
301
302
# Check that the samplers handle dask array
302
303
X , y = make_classification (
303
304
n_samples = 1000 ,
@@ -309,20 +310,25 @@ def check_samplers_dask_array(name, sampler):
309
310
X_dask = array .from_array (X , chunks = 100 )
310
311
y_dask = array .from_array (y , chunks = 100 )
311
312
312
- X_res_dask , y_res_dask = sampler .fit_resample (X_dask , y_dask )
313
- X_res , y_res = sampler .fit_resample (X , y )
313
+ for validate_if_dask_collection in (True , False ):
314
+ sampler .set_params (
315
+ validate_if_dask_collection = validate_if_dask_collection
316
+ )
317
+ X_res_dask , y_res_dask = sampler .fit_resample (X_dask , y_dask )
318
+ X_res , y_res = sampler .fit_resample (X , y )
314
319
315
- # check that we return the same type for dataframes or series types
316
- assert isinstance (X_res_dask , array .Array )
317
- assert isinstance (y_res_dask , array .Array )
320
+ # check that we return the same type for dataframes or series types
321
+ assert isinstance (X_res_dask , array .Array )
322
+ assert isinstance (y_res_dask , array .Array )
318
323
319
- assert_allclose (X_res_dask , X_res )
320
- assert_allclose (y_res_dask , y_res )
324
+ assert_allclose (X_res_dask , X_res )
325
+ assert_allclose (y_res_dask , y_res )
321
326
322
327
323
- def check_samplers_dask_dataframe (name , sampler ):
328
+ def check_samplers_dask_dataframe (name , sampler_orig ):
324
329
pytest .importorskip ("dask" )
325
330
from dask import dataframe
331
+ sampler = clone (sampler_orig )
326
332
# Check that the samplers handle dask dataframe and dask series
327
333
X , y = make_classification (
328
334
n_samples = 1000 ,
@@ -337,18 +343,22 @@ def check_samplers_dask_dataframe(name, sampler):
337
343
y_s = dataframe .from_array (y )
338
344
y_s = y_s .rename ("target" )
339
345
340
- X_res_df , y_res_s = sampler .fit_resample (X_df , y_s )
341
- X_res , y_res = sampler .fit_resample (X , y )
346
+ for validate_if_dask_collection in (True , False ):
347
+ sampler .set_params (
348
+ validate_if_dask_collection = validate_if_dask_collection
349
+ )
350
+ X_res_df , y_res_s = sampler .fit_resample (X_df , y_s )
351
+ X_res , y_res = sampler .fit_resample (X , y )
342
352
343
- # check that we return the same type for dataframes or series types
344
- assert isinstance (X_res_df , dataframe .DataFrame )
345
- assert isinstance (y_res_s , dataframe .Series )
353
+ # check that we return the same type for dataframes or series types
354
+ assert isinstance (X_res_df , dataframe .DataFrame )
355
+ assert isinstance (y_res_s , dataframe .Series )
346
356
347
- assert X_df .columns .to_list () == X_res_df .columns .to_list ()
348
- assert y_s .name == y_res_s .name
357
+ assert X_df .columns .to_list () == X_res_df .columns .to_list ()
358
+ assert y_s .name == y_res_s .name
349
359
350
- assert_allclose (np .array (X_res_df ), X_res )
351
- assert_allclose (np .array (y_res_s ), y_res )
360
+ assert_allclose (np .array (X_res_df ), X_res )
361
+ assert_allclose (np .array (y_res_s ), y_res )
352
362
353
363
354
364
def check_samplers_list (name , sampler ):
0 commit comments