Skip to content

Commit 4e4097f

Browse files
committed
Added tests to confirm featurebuilder behavior on unhashable columns is correct
1 parent 040bfc2 commit 4e4097f

File tree

3 files changed

+61
-4
lines changed

3 files changed

+61
-4
lines changed

src/team_comm_tools/feature_builder.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,6 @@ def __init__(
151151
if not isinstance(input_df, pd.DataFrame):
152152
raise TypeError(f"Expected a Pandas DataFrame as input_df, but got {type(df).__name__})")
153153

154-
# if not vector_directory:
155-
# raise ValueError("You must pass in a valid directory to cache vectors! For example: ./vector_data/") # NOTE: This is redundant because we have a default value
156-
157154
print("Initializing Featurization...")
158155

159156
###### Set all parameters ######

src/team_comm_tools/utils/preprocess.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,23 @@ def remove_unhashable_cols(df: pd.DataFrame, column_names: dict) -> pd.DataFrame
8080
"""
8181
# Check cols with unhashable types
8282
unhashable_cols = {}
83+
84+
def is_unhashable(obj):
85+
"""
86+
Small function to test whether a data type is hashable, without storing a hard-coded list of hashable types.
87+
88+
:param obj: an object to test hashability
89+
:return: Whether or not the object is unhashable
90+
:rtype: bool
91+
"""
92+
try:
93+
hash(obj)
94+
return False
95+
except TypeError:
96+
return True
97+
8398
for col in df.columns:
84-
unhashable_values = df[col].apply(lambda x: isinstance(x, (set, list, dict)))
99+
unhashable_values = df[col].apply(lambda x: is_unhashable(x))
85100
if unhashable_values.any():
86101
unique_types = df[col][unhashable_values].apply(lambda x: type(x)).unique()
87102
unhashable_cols[col] = unique_types

tests/run_package_grouping_tests.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
# Importing the Feature Generating Class
99
from team_comm_tools import FeatureBuilder
1010
import pandas as pd
11+
import warnings
12+
warnings.simplefilter("always", category=UserWarning)
1113

1214
# Main Function
1315
if __name__ == "__main__":
@@ -231,3 +233,46 @@
231233
user_aggregation = False,
232234
)
233235
custom_agg_fb_no_agg.featurize()
236+
237+
238+
"""
239+
Test that it's possible to run the FB with unhashable types that are NOT required columns.
240+
"""
241+
242+
print("Testing unhashable types on non-required columns...")
243+
test_df_with_nonrequired_unhashables = pd.DataFrame({
244+
"conversation_id": [1, 1, 2, 2, 2],
245+
"text": ["test1", "test2", "test3", "test4", "test5"],
246+
"speaker_id": [1, 2, 1, 2, 1],
247+
"unhashable_col": [{}, set(), ["foo", "bar"], {"foo": "bar"}, {"foo": ["bar1", "bar2", "bar3"]}]
248+
})
249+
250+
fb_nonrequred_unhashable_test = FeatureBuilder(
251+
input_df = test_df_with_nonrequired_unhashables,
252+
conversation_id_col = "conversation_id",
253+
message_col = "text",
254+
speaker_id_col = "speaker_id"
255+
)
256+
257+
"""
258+
Test that, if we have unhashable types as required columns, we throw a ValueError as expected.
259+
"""
260+
261+
print("Testing unhashable types on required columns...")
262+
test_df_with_required_unhashables = pd.DataFrame({
263+
"conversation_id": [[1], [1], [2], [2], [2]], # conversation_id is a required col and contains lists
264+
"text": ["test1", "test2", "test3", "test4", "test5"],
265+
"speaker_id": [1, 2, 1, 2, 1]
266+
})
267+
268+
try:
269+
fb_required_unhashable_test = FeatureBuilder(
270+
input_df = test_df_with_required_unhashables,
271+
conversation_id_col = "conversation_id",
272+
message_col = "text",
273+
speaker_id_col = "speaker_id"
274+
)
275+
except ValueError as e:
276+
assert('has unhashable data types' in str(e)) # make sure we caught the right error!
277+
print("Test has properly exited with error.")
278+

0 commit comments

Comments
 (0)