diff --git a/boruta/test/test_boruta.py b/boruta/test/test_boruta.py new file mode 100644 index 0000000..27d42e6 --- /dev/null +++ b/boruta/test/test_boruta.py @@ -0,0 +1,67 @@ +import numpy as np +import pandas as pd +import pytest +from sklearn.ensemble import RandomForestClassifier + +from boruta import BorutaPy + + +@pytest.mark.parametrize("tree_n,expected", [(10, 44), (100, 141)]) +def test_get_tree_num(tree_n, expected): + rfc = RandomForestClassifier(max_depth=10) + bt = BorutaPy(rfc) + assert bt._get_tree_num(tree_n) == expected + + +@pytest.fixture(scope="module") +def Xy(): + np.random.seed(42) + y = np.random.binomial(1, 0.5, 1000) + X = np.zeros((1000, 10)) + + z = (y - np.random.binomial(1, 0.1, 1000) + + np.random.binomial(1, 0.1, 1000)) + z[z == -1] = 0 + z[z == 2] = 1 + + # 5 relevant features + X[:, 0] = z + X[:, 1] = (y * np.abs(np.random.normal(0, 1, 1000)) + + np.random.normal(0, 0.1, 1000)) + X[:, 2] = y + np.random.normal(0, 1, 1000) + X[:, 3] = y**2 + np.random.normal(0, 1, 1000) + X[:, 4] = np.sqrt(y) + np.random.binomial(2, 0.1, 1000) + + # 5 irrelevant features + X[:, 5] = np.random.normal(0, 1, 1000) + X[:, 6] = np.random.poisson(1, 1000) + X[:, 7] = np.random.binomial(1, 0.3, 1000) + X[:, 8] = np.random.normal(0, 1, 1000) + X[:, 9] = np.random.poisson(1, 1000) + + return X, y + + +def test_if_boruta_extracts_relevant_features(Xy): + X, y = Xy + rfc = RandomForestClassifier() + bt = BorutaPy(rfc) + bt.fit(X, y) + assert list(range(5)) == list(np.where(bt.support_)[0]) + + +def test_if_it_works_with_dataframe_input(Xy): + X, y = Xy + X_df, y_df = pd.DataFrame(X), pd.Series(y) + bt = BorutaPy(RandomForestClassifier()) + bt.fit(X_df, y_df) + assert list(range(5)) == list(np.where(bt.support_)[0]) + + +def test_dataframe_is_returned(Xy): + X, y = Xy + X_df, y_df = pd.DataFrame(X), pd.Series(y) + rfc = RandomForestClassifier() + bt = BorutaPy(rfc) + bt.fit(X_df, y_df) + assert isinstance(bt.transform(X_df, return_df=True), pd.DataFrame) diff --git a/boruta/test/unit_tests.py b/boruta/test/unit_tests.py deleted file mode 100644 index 5d5ce9f..0000000 --- a/boruta/test/unit_tests.py +++ /dev/null @@ -1,57 +0,0 @@ -import unittest -from boruta import BorutaPy -import pandas as pd -from sklearn.ensemble import RandomForestClassifier -import numpy as np - - -class BorutaTestCases(unittest.TestCase): - - def test_get_tree_num(self): - rfc = RandomForestClassifier(max_depth=10) - bt = BorutaPy(rfc) - self.assertEqual(bt._get_tree_num(10), 44, "Tree Est. Math Fail") - self.assertEqual(bt._get_tree_num(100), 141, "Tree Est. Math Fail") - - def test_if_boruta_extracts_relevant_features(self): - np.random.seed(42) - y = np.random.binomial(1, 0.5, 1000) - X = np.zeros((1000, 10)) - - z = y - np.random.binomial(1, 0.1, 1000) + np.random.binomial(1, 0.1, 1000) - z[z == -1] = 0 - z[z == 2] = 1 - - # 5 relevant features - X[:, 0] = z - X[:, 1] = y * np.abs(np.random.normal(0, 1, 1000)) + np.random.normal(0, 0.1, 1000) - X[:, 2] = y + np.random.normal(0, 1, 1000) - X[:, 3] = y ** 2 + np.random.normal(0, 1, 1000) - X[:, 4] = np.sqrt(y) + np.random.binomial(2, 0.1, 1000) - - # 5 irrelevant features - X[:, 5] = np.random.normal(0, 1, 1000) - X[:, 6] = np.random.poisson(1, 1000) - X[:, 7] = np.random.binomial(1, 0.3, 1000) - X[:, 8] = np.random.normal(0, 1, 1000) - X[:, 9] = np.random.poisson(1, 1000) - - rfc = RandomForestClassifier() - bt = BorutaPy(rfc) - bt.fit(X, y) - - # make sure that only all the relevant features are returned - self.assertListEqual(list(range(5)), list(np.where(bt.support_)[0])) - - # test if this works as expected for dataframe input - X_df, y_df = pd.DataFrame(X), pd.Series(y) - bt.fit(X_df, y_df) - self.assertListEqual(list(range(5)), list(np.where(bt.support_)[0])) - - # check it dataframe is returned when return_df=True - self.assertIsInstance(bt.transform(X_df, return_df=True), pd.DataFrame) - -if __name__ == '__main__': - unittest.main() - - diff --git a/test_requirements.txt b/test_requirements.txt new file mode 100644 index 0000000..eedc6fd --- /dev/null +++ b/test_requirements.txt @@ -0,0 +1,7 @@ +-r requirements.txt +pytest>=5.4.1 + +# repo maintenance tooling +black>=21.5b1 +flake8>=3.9.2 +isort>=5.8.0 \ No newline at end of file