Skip to content

Commit d8419e6

Browse files
[ENH] refactor repetitive clone tests with pytest.mark.parametrize, fixes #170 (#392)
#### Reference Issues/PRs Fixes #170 #### What does this implement/fix? Explain your changes. refactored repetitive clone tests with pytest.mark.parametrize
1 parent ea50550 commit d8419e6

File tree

2 files changed

+30
-68
lines changed

2 files changed

+30
-68
lines changed

.all-contributorsrc

+10
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,16 @@
131131
"bug",
132132
"code"
133133
]
134+
},
135+
{
136+
"login": "JahnaviDhanaSri",
137+
"name": "Jahnavi Dhana Sri",
138+
"avatar_url": "https://avatars.githubusercontent.com/u/143936922?v=4",
139+
"profile": "https://github.com/JahnaviDhanaSri",
140+
"contributions": [
141+
"code",
142+
"test"
143+
]
134144
}
135145
],
136146
"projectName": "skbase",

skbase/tests/test_base.py

+20-68
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,7 @@
5151
"test_clone",
5252
"test_clone_2",
5353
"test_clone_raises_error_for_nonconforming_objects",
54-
"test_clone_param_is_none",
55-
"test_clone_empty_array",
56-
"test_clone_sparse_matrix",
57-
"test_clone_nan",
54+
"test_clone_none_and_empty_array_nan_sparse_matrix",
5855
"test_clone_estimator_types",
5956
"test_clone_class_rather_than_instance_raises_error",
6057
"test_clone_sklearn_composite",
@@ -1025,75 +1022,30 @@ def __init__(self, obj, obj_iterable):
10251022
not _check_soft_dependencies("scikit-learn", severity="none"),
10261023
reason="skip test if sklearn is not available",
10271024
) # sklearn is part of the dev dependency set, test should be executed with that
1028-
def test_clone_param_is_none(fixture_class_parent: Type[Parent]):
1029-
"""Test clone with keyword parameter set to None."""
1030-
from sklearn.base import clone
1031-
1032-
base_obj = fixture_class_parent(c=None)
1033-
new_base_obj = clone(base_obj)
1034-
new_base_obj2 = base_obj.clone()
1035-
assert base_obj.c is new_base_obj.c
1036-
assert base_obj.c is new_base_obj2.c
1037-
1038-
1039-
@pytest.mark.skipif(
1040-
not _check_soft_dependencies("scikit-learn", severity="none"),
1041-
reason="skip test if sklearn is not available",
1042-
) # sklearn is part of the dev dependency set, test should be executed with that
1043-
def test_clone_empty_array(fixture_class_parent: Type[Parent]):
1044-
"""Test clone with keyword parameter is scipy sparse matrix.
1045-
1046-
This test is based on scikit-learn regression test to make sure clone
1047-
works with default parameter set to scipy sparse matrix.
1048-
"""
1049-
from sklearn.base import clone
1050-
1051-
# Regression test for cloning estimators with empty arrays
1052-
base_obj = fixture_class_parent(c=np.array([]))
1053-
new_base_obj = clone(base_obj)
1054-
new_base_obj2 = base_obj.clone()
1055-
np.testing.assert_array_equal(base_obj.c, new_base_obj.c)
1056-
np.testing.assert_array_equal(base_obj.c, new_base_obj2.c)
1057-
1058-
1059-
@pytest.mark.skipif(
1060-
not _check_soft_dependencies("scikit-learn", severity="none"),
1061-
reason="skip test if sklearn is not available",
1062-
) # sklearn is part of the dev dependency set, test should be executed with that
1063-
def test_clone_sparse_matrix(fixture_class_parent: Type[Parent]):
1064-
"""Test clone with keyword parameter is scipy sparse matrix.
1065-
1066-
This test is based on scikit-learn regression test to make sure clone
1067-
works with default parameter set to scipy sparse matrix.
1068-
"""
1069-
from sklearn.base import clone
1070-
1071-
base_obj = fixture_class_parent(c=sp.csr_matrix(np.array([[0]])))
1072-
new_base_obj = clone(base_obj)
1073-
new_base_obj2 = base_obj.clone()
1074-
np.testing.assert_array_equal(base_obj.c, new_base_obj.c)
1075-
np.testing.assert_array_equal(base_obj.c, new_base_obj2.c)
1076-
1077-
1078-
@pytest.mark.skipif(
1079-
not _check_soft_dependencies("scikit-learn", severity="none"),
1080-
reason="skip test if sklearn is not available",
1081-
) # sklearn is part of the dev dependency set, test should be executed with that
1082-
def test_clone_nan(fixture_class_parent: Type[Parent]):
1083-
"""Test clone with keyword parameter is np.nan.
1084-
1085-
This test is based on scikit-learn regression test to make sure clone
1086-
works with default parameter set to np.nan.
1087-
"""
1025+
@pytest.mark.parametrize(
1026+
"c_value",
1027+
[
1028+
None,
1029+
np.array([]),
1030+
sp.csr_matrix(np.array([[0]])),
1031+
np.nan,
1032+
],
1033+
)
1034+
def test_clone_none_and_empty_array_nan_sparse_matrix(
1035+
fixture_class_parent: Type[Parent], c_value
1036+
):
10881037
from sklearn.base import clone
10891038

1090-
# Regression test for cloning estimators with default parameter as np.nan
1091-
base_obj = fixture_class_parent(c=np.nan)
1039+
base_obj = fixture_class_parent(c=c_value)
10921040
new_base_obj = clone(base_obj)
10931041
new_base_obj2 = base_obj.clone()
10941042

1095-
assert base_obj.c is new_base_obj.c
1096-
assert base_obj.c is new_base_obj2.c
1043+
if isinstance(base_obj.c, (np.ndarray, type(sp.csr_matrix(np.array([[0]]))))):
1044+
np.testing.assert_array_equal(base_obj.c, new_base_obj.c)
1045+
np.testing.assert_array_equal(base_obj.c, new_base_obj2.c)
1046+
else:
1047+
assert base_obj.c is new_base_obj.c
1048+
assert base_obj.c is new_base_obj2.c
10971049

10981050

10991051
def test_clone_estimator_types(fixture_class_parent: Type[Parent]):

0 commit comments

Comments
 (0)