|
8 | 8 | # Importing the Feature Generating Class |
9 | 9 | from team_comm_tools import FeatureBuilder |
10 | 10 | import pandas as pd |
| 11 | +import warnings |
| 12 | +warnings.simplefilter("always", category=UserWarning) |
11 | 13 |
|
12 | 14 | # Main Function |
13 | 15 | if __name__ == "__main__": |
|
231 | 233 | user_aggregation = False, |
232 | 234 | ) |
233 | 235 | custom_agg_fb_no_agg.featurize() |
| 236 | + |
| 237 | + |
| 238 | + """ |
| 239 | + Test that it's possible to run the FB with unhashable types that are NOT required columns. |
| 240 | + """ |
| 241 | + |
| 242 | + print("Testing unhashable types on non-required columns...") |
| 243 | + test_df_with_nonrequired_unhashables = pd.DataFrame({ |
| 244 | + "conversation_id": [1, 1, 2, 2, 2], |
| 245 | + "text": ["test1", "test2", "test3", "test4", "test5"], |
| 246 | + "speaker_id": [1, 2, 1, 2, 1], |
| 247 | + "unhashable_col": [{}, set(), ["foo", "bar"], {"foo": "bar"}, {"foo": ["bar1", "bar2", "bar3"]}] |
| 248 | + }) |
| 249 | + |
| 250 | + fb_nonrequred_unhashable_test = FeatureBuilder( |
| 251 | + input_df = test_df_with_nonrequired_unhashables, |
| 252 | + conversation_id_col = "conversation_id", |
| 253 | + message_col = "text", |
| 254 | + speaker_id_col = "speaker_id" |
| 255 | + ) |
| 256 | + |
| 257 | + """ |
| 258 | + Test that, if we have unhashable types as required columns, we throw a ValueError as expected. |
| 259 | + """ |
| 260 | + |
| 261 | + print("Testing unhashable types on required columns...") |
| 262 | + test_df_with_required_unhashables = pd.DataFrame({ |
| 263 | + "conversation_id": [[1], [1], [2], [2], [2]], # conversation_id is a required col and contains lists |
| 264 | + "text": ["test1", "test2", "test3", "test4", "test5"], |
| 265 | + "speaker_id": [1, 2, 1, 2, 1] |
| 266 | + }) |
| 267 | + |
| 268 | + try: |
| 269 | + fb_required_unhashable_test = FeatureBuilder( |
| 270 | + input_df = test_df_with_required_unhashables, |
| 271 | + conversation_id_col = "conversation_id", |
| 272 | + message_col = "text", |
| 273 | + speaker_id_col = "speaker_id" |
| 274 | + ) |
| 275 | + except ValueError as e: |
| 276 | + assert('has unhashable data types' in str(e)) # make sure we caught the right error! |
| 277 | + print("Test has properly exited with error.") |
| 278 | + |
0 commit comments