Skip to content

Commit 3a7633d

Browse files
awinml and glemaitre
authored
MAINT Add fixture for dataset generation in common tests (#932)
Co-authored-by: Guillaume Lemaitre <[email protected]>
1 parent 5be5670 commit 3a7633d

File tree

1 file changed

+25
-63
lines changed

1 file changed

+25
-63
lines changed

imblearn/utils/estimator_checks.py

+25-63
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,22 @@
5959
sklearn_version = parse_version(sklearn.__version__)
6060

6161

62+
def sample_dataset_generator():
63+
X, y = make_classification(
64+
n_samples=1000,
65+
n_classes=3,
66+
n_informative=4,
67+
weights=[0.2, 0.3, 0.5],
68+
random_state=0,
69+
)
70+
return X, y
71+
72+
73+
@pytest.fixture(name="sample_dataset_generator")
74+
def sample_dataset_generator_fixture():
75+
return sample_dataset_generator()
76+
77+
6278
def _set_checking_parameters(estimator):
6379
params = estimator.get_params()
6480
name = estimator.__class__.__name__
@@ -233,13 +249,7 @@ def check_samplers_fit(name, sampler_orig):
233249

234250
def check_samplers_fit_resample(name, sampler_orig):
235251
sampler = clone(sampler_orig)
236-
X, y = make_classification(
237-
n_samples=1000,
238-
n_classes=3,
239-
n_informative=4,
240-
weights=[0.2, 0.3, 0.5],
241-
random_state=0,
242-
)
252+
X, y = sample_dataset_generator()
243253
target_stats = Counter(y)
244254
X_res, y_res = sampler.fit_resample(X, y)
245255
if isinstance(sampler, BaseOverSampler):
@@ -269,13 +279,7 @@ def check_samplers_fit_resample(name, sampler_orig):
269279
def check_samplers_sampling_strategy_fit_resample(name, sampler_orig):
270280
sampler = clone(sampler_orig)
271281
# in this test we will force all samplers to not change the class 1
272-
X, y = make_classification(
273-
n_samples=1000,
274-
n_classes=3,
275-
n_informative=4,
276-
weights=[0.2, 0.3, 0.5],
277-
random_state=0,
278-
)
282+
X, y = sample_dataset_generator()
279283
expected_stat = Counter(y)[1]
280284
if isinstance(sampler, BaseOverSampler):
281285
sampling_strategy = {2: 498, 0: 498}
@@ -298,13 +302,7 @@ def check_samplers_sparse(name, sampler_orig):
298302
sampler = clone(sampler_orig)
299303
# check that sparse matrices can be passed through the sampler leading to
300304
# the same results as dense
301-
X, y = make_classification(
302-
n_samples=1000,
303-
n_classes=3,
304-
n_informative=4,
305-
weights=[0.2, 0.3, 0.5],
306-
random_state=0,
307-
)
305+
X, y = sample_dataset_generator()
308306
X_sparse = sparse.csr_matrix(X)
309307
X_res_sparse, y_res_sparse = sampler.fit_resample(X_sparse, y)
310308
sampler = clone(sampler)
@@ -318,13 +316,7 @@ def check_samplers_pandas(name, sampler_orig):
318316
pd = pytest.importorskip("pandas")
319317
sampler = clone(sampler_orig)
320318
# Check that the samplers handle pandas dataframe and pandas series
321-
X, y = make_classification(
322-
n_samples=1000,
323-
n_classes=3,
324-
n_informative=4,
325-
weights=[0.2, 0.3, 0.5],
326-
random_state=0,
327-
)
319+
X, y = sample_dataset_generator()
328320
X_df = pd.DataFrame(X, columns=[str(i) for i in range(X.shape[1])])
329321
y_df = pd.DataFrame(y)
330322
y_s = pd.Series(y, name="class")
@@ -351,13 +343,7 @@ def check_samplers_pandas(name, sampler_orig):
351343
def check_samplers_list(name, sampler_orig):
352344
sampler = clone(sampler_orig)
353345
# Check that the samplers can handle simple lists
354-
X, y = make_classification(
355-
n_samples=1000,
356-
n_classes=3,
357-
n_informative=4,
358-
weights=[0.2, 0.3, 0.5],
359-
random_state=0,
360-
)
346+
X, y = sample_dataset_generator()
361347
X_list = X.tolist()
362348
y_list = y.tolist()
363349

@@ -374,13 +360,7 @@ def check_samplers_list(name, sampler_orig):
374360
def check_samplers_multiclass_ova(name, sampler_orig):
375361
sampler = clone(sampler_orig)
376362
# Check that a multiclass target leads to the same results as OVA encoding
377-
X, y = make_classification(
378-
n_samples=1000,
379-
n_classes=3,
380-
n_informative=4,
381-
weights=[0.2, 0.3, 0.5],
382-
random_state=0,
383-
)
363+
X, y = sample_dataset_generator()
384364
y_ova = label_binarize(y, classes=np.unique(y))
385365
X_res, y_res = sampler.fit_resample(X, y)
386366
X_res_ova, y_res_ova = sampler.fit_resample(X, y_ova)
@@ -391,27 +371,15 @@ def check_samplers_multiclass_ova(name, sampler_orig):
391371

392372
def check_samplers_2d_target(name, sampler_orig):
393373
sampler = clone(sampler_orig)
394-
X, y = make_classification(
395-
n_samples=100,
396-
n_classes=3,
397-
n_informative=4,
398-
weights=[0.2, 0.3, 0.5],
399-
random_state=0,
400-
)
374+
X, y = sample_dataset_generator()
401375

402376
y = y.reshape(-1, 1) # Make the target 2d
403377
sampler.fit_resample(X, y)
404378

405379

406380
def check_samplers_preserve_dtype(name, sampler_orig):
407381
sampler = clone(sampler_orig)
408-
X, y = make_classification(
409-
n_samples=1000,
410-
n_classes=3,
411-
n_informative=4,
412-
weights=[0.2, 0.3, 0.5],
413-
random_state=0,
414-
)
382+
X, y = sample_dataset_generator()
415383
# Cast X and y to non-default dtypes
416384
X = X.astype(np.float32)
417385
y = y.astype(np.int32)
@@ -422,13 +390,7 @@ def check_samplers_preserve_dtype(name, sampler_orig):
422390

423391
def check_samplers_sample_indices(name, sampler_orig):
424392
sampler = clone(sampler_orig)
425-
X, y = make_classification(
426-
n_samples=1000,
427-
n_classes=3,
428-
n_informative=4,
429-
weights=[0.2, 0.3, 0.5],
430-
random_state=0,
431-
)
393+
X, y = sample_dataset_generator()
432394
sampler.fit_resample(X, y)
433395
sample_indices = sampler._get_tags().get("sample_indices", None)
434396
if sample_indices:

0 commit comments

Comments
 (0)