59
59
sklearn_version = parse_version (sklearn .__version__ )
60
60
61
61
62
+ def sample_dataset_generator ():
63
+ X , y = make_classification (
64
+ n_samples = 1000 ,
65
+ n_classes = 3 ,
66
+ n_informative = 4 ,
67
+ weights = [0.2 , 0.3 , 0.5 ],
68
+ random_state = 0 ,
69
+ )
70
+ return X , y
71
+
72
+
73
+ @pytest .fixture (name = "sample_dataset_generator" )
74
+ def sample_dataset_generator_fixture ():
75
+ return sample_dataset_generator ()
76
+
77
+
62
78
def _set_checking_parameters (estimator ):
63
79
params = estimator .get_params ()
64
80
name = estimator .__class__ .__name__
@@ -233,13 +249,7 @@ def check_samplers_fit(name, sampler_orig):
233
249
234
250
def check_samplers_fit_resample (name , sampler_orig ):
235
251
sampler = clone (sampler_orig )
236
- X , y = make_classification (
237
- n_samples = 1000 ,
238
- n_classes = 3 ,
239
- n_informative = 4 ,
240
- weights = [0.2 , 0.3 , 0.5 ],
241
- random_state = 0 ,
242
- )
252
+ X , y = sample_dataset_generator ()
243
253
target_stats = Counter (y )
244
254
X_res , y_res = sampler .fit_resample (X , y )
245
255
if isinstance (sampler , BaseOverSampler ):
@@ -269,13 +279,7 @@ def check_samplers_fit_resample(name, sampler_orig):
269
279
def check_samplers_sampling_strategy_fit_resample (name , sampler_orig ):
270
280
sampler = clone (sampler_orig )
271
281
# in this test we will force all samplers to not change the class 1
272
- X , y = make_classification (
273
- n_samples = 1000 ,
274
- n_classes = 3 ,
275
- n_informative = 4 ,
276
- weights = [0.2 , 0.3 , 0.5 ],
277
- random_state = 0 ,
278
- )
282
+ X , y = sample_dataset_generator ()
279
283
expected_stat = Counter (y )[1 ]
280
284
if isinstance (sampler , BaseOverSampler ):
281
285
sampling_strategy = {2 : 498 , 0 : 498 }
@@ -298,13 +302,7 @@ def check_samplers_sparse(name, sampler_orig):
298
302
sampler = clone (sampler_orig )
299
303
# check that sparse matrices can be passed through the sampler leading to
300
304
# the same results as dense
301
- X , y = make_classification (
302
- n_samples = 1000 ,
303
- n_classes = 3 ,
304
- n_informative = 4 ,
305
- weights = [0.2 , 0.3 , 0.5 ],
306
- random_state = 0 ,
307
- )
305
+ X , y = sample_dataset_generator ()
308
306
X_sparse = sparse .csr_matrix (X )
309
307
X_res_sparse , y_res_sparse = sampler .fit_resample (X_sparse , y )
310
308
sampler = clone (sampler )
@@ -318,13 +316,7 @@ def check_samplers_pandas(name, sampler_orig):
318
316
pd = pytest .importorskip ("pandas" )
319
317
sampler = clone (sampler_orig )
320
318
# Check that the samplers handle pandas dataframe and pandas series
321
- X , y = make_classification (
322
- n_samples = 1000 ,
323
- n_classes = 3 ,
324
- n_informative = 4 ,
325
- weights = [0.2 , 0.3 , 0.5 ],
326
- random_state = 0 ,
327
- )
319
+ X , y = sample_dataset_generator ()
328
320
X_df = pd .DataFrame (X , columns = [str (i ) for i in range (X .shape [1 ])])
329
321
y_df = pd .DataFrame (y )
330
322
y_s = pd .Series (y , name = "class" )
@@ -351,13 +343,7 @@ def check_samplers_pandas(name, sampler_orig):
351
343
def check_samplers_list (name , sampler_orig ):
352
344
sampler = clone (sampler_orig )
353
345
# Check that the samplers can handle simple lists
354
- X , y = make_classification (
355
- n_samples = 1000 ,
356
- n_classes = 3 ,
357
- n_informative = 4 ,
358
- weights = [0.2 , 0.3 , 0.5 ],
359
- random_state = 0 ,
360
- )
346
+ X , y = sample_dataset_generator ()
361
347
X_list = X .tolist ()
362
348
y_list = y .tolist ()
363
349
@@ -374,13 +360,7 @@ def check_samplers_list(name, sampler_orig):
374
360
def check_samplers_multiclass_ova (name , sampler_orig ):
375
361
sampler = clone (sampler_orig )
376
362
# Check that a multiclass target leads to the same results as OVA encoding
377
- X , y = make_classification (
378
- n_samples = 1000 ,
379
- n_classes = 3 ,
380
- n_informative = 4 ,
381
- weights = [0.2 , 0.3 , 0.5 ],
382
- random_state = 0 ,
383
- )
363
+ X , y = sample_dataset_generator ()
384
364
y_ova = label_binarize (y , classes = np .unique (y ))
385
365
X_res , y_res = sampler .fit_resample (X , y )
386
366
X_res_ova , y_res_ova = sampler .fit_resample (X , y_ova )
@@ -391,27 +371,15 @@ def check_samplers_multiclass_ova(name, sampler_orig):
391
371
392
372
def check_samplers_2d_target (name , sampler_orig ):
393
373
sampler = clone (sampler_orig )
394
- X , y = make_classification (
395
- n_samples = 100 ,
396
- n_classes = 3 ,
397
- n_informative = 4 ,
398
- weights = [0.2 , 0.3 , 0.5 ],
399
- random_state = 0 ,
400
- )
374
+ X , y = sample_dataset_generator ()
401
375
402
376
y = y .reshape (- 1 , 1 ) # Make the target 2d
403
377
sampler .fit_resample (X , y )
404
378
405
379
406
380
def check_samplers_preserve_dtype (name , sampler_orig ):
407
381
sampler = clone (sampler_orig )
408
- X , y = make_classification (
409
- n_samples = 1000 ,
410
- n_classes = 3 ,
411
- n_informative = 4 ,
412
- weights = [0.2 , 0.3 , 0.5 ],
413
- random_state = 0 ,
414
- )
382
+ X , y = sample_dataset_generator ()
415
383
# Cast X and y to non-default dtypes
416
384
X = X .astype (np .float32 )
417
385
y = y .astype (np .int32 )
@@ -422,13 +390,7 @@ def check_samplers_preserve_dtype(name, sampler_orig):
422
390
423
391
def check_samplers_sample_indices (name , sampler_orig ):
424
392
sampler = clone (sampler_orig )
425
- X , y = make_classification (
426
- n_samples = 1000 ,
427
- n_classes = 3 ,
428
- n_informative = 4 ,
429
- weights = [0.2 , 0.3 , 0.5 ],
430
- random_state = 0 ,
431
- )
393
+ X , y = sample_dataset_generator ()
432
394
sampler .fit_resample (X , y )
433
395
sample_indices = sampler ._get_tags ().get ("sample_indices" , None )
434
396
if sample_indices :
0 commit comments