diff --git a/feature_engine/requirements.txt b/feature_engine/requirements.txt index 86b0e3e4..1b858ef5 100644 --- a/feature_engine/requirements.txt +++ b/feature_engine/requirements.txt @@ -121,6 +121,7 @@ pyphen==0.14.0 pyre-extensions==0.0.30 pyro-api==0.1.2 pyro-ppl==1.8.6 +pytest pytest-runner==6.0.1 python-dateutil==2.8.1 python-editor==1.0.4 diff --git a/feature_engine/testing/test_feature_metrics.py b/feature_engine/testing/test_feature_metrics.py index 750de30c..96c7ca72 100644 --- a/feature_engine/testing/test_feature_metrics.py +++ b/feature_engine/testing/test_feature_metrics.py @@ -4,7 +4,12 @@ from numpy import nan import logging -test_df = pd.read_csv("../output/chat/reddit_test_chat_level.csv") +# test_input_df = pd.read_csv("../feature_engine/reddit_test_chat_level.csv") +test_chat_df = pd.read_csv("../output/chat/reddit_test_chat_level.csv") +# test_conv_df_output = pd.read_csv("../output/conv/reddit_test_chat_level.csv") +# join test_input_df with test_conv_df_output on conversation_num id +# test_conv_df = pd.merge(test_input_df, test_conv_df_output, on='conversation_num') + # test_df['test_pass'] = test_df.apply(lambda row: row[row['expected_column']] == row['expected_value'], axis=1) # test_df['obtained_value'] = test_df.apply(lambda row: row[row['expected_column']], axis=1) # test_df[["message", "expected_column", "expected_value", "obtained_value", "test_pass"]] @@ -18,8 +23,8 @@ console_handler.setLevel(logging.ERROR) logger.addHandler(console_handler) -@pytest.mark.parametrize("row", test_df.iterrows()) -def test_unit_equality(row): +@pytest.mark.parametrize("row", test_chat_df.iterrows()) +def test_chat_unit_equality(row): actual = row[1][row[1]['expected_column']] expected = row[1]['expected_value'] @@ -33,3 +38,21 @@ def test_unit_equality(row): logger.error("Actual value: %s", actual) raise # Re-raise the AssertionError to mark the test as failed + + +# @pytest.mark.parametrize("conversation_num, conversation_rows", test_conv_df.groupby('conversation_num')) +# def test_conv_unit_equality(conversation_num, conversation_rows): +# actual = row[1][row[1]['expected_column']] +# expected = row[1]['expected_value'] + +# try: +# assert actual == expected +# except AssertionError: +# logger.error("") +# logger.error("------TEST FAILED------") +# logger.error("Testing %s for conversation: %s ", row[1]['expected_column'], row[1]['conversation_num']) +# logger.error("Expected value: %s ", expected) +# logger.error("Actual value: %s", actual) + +# raise # Re-raise the AssertionError to mark the test as failed +