Skip to content

Commit

Permalink
Added tests to confirm featurebuilder behavior on unhashable columns …
Browse files Browse the repository at this point in the history
…is correct
  • Loading branch information
xehu committed Mar 2, 2025
1 parent 040bfc2 commit 4e4097f
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 4 deletions.
3 changes: 0 additions & 3 deletions src/team_comm_tools/feature_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,6 @@ def __init__(
# Fail fast when the caller passes something other than a pandas DataFrame.
if not isinstance(input_df, pd.DataFrame):
    # Report the type of the value that was actually received; the check is on
    # `input_df`, so the message must inspect `input_df` as well.
    raise TypeError(f"Expected a Pandas DataFrame as input_df, but got {type(input_df).__name__}")

# if not vector_directory:
# raise ValueError("You must pass in a valid directory to cache vectors! For example: ./vector_data/") # NOTE: This is redundant because we have a default value

print("Initializing Featurization...")

###### Set all parameters ######
Expand Down
17 changes: 16 additions & 1 deletion src/team_comm_tools/utils/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,23 @@ def remove_unhashable_cols(df: pd.DataFrame, column_names: dict) -> pd.DataFrame
"""
# Check cols with unhashable types
unhashable_cols = {}

def is_unhashable(obj):
    """
    Report whether a value cannot be hashed.

    Hashability is probed by calling the built-in ``hash`` on the value itself, so no
    hard-coded list of hashable (or unhashable) types has to be maintained.

    :param obj: the value to probe
    :return: True if ``hash(obj)`` raises a TypeError, False otherwise
    :rtype: bool
    """
    try:
        hash(obj)
    except TypeError:
        # hash() rejected the value, so it is unhashable.
        return True
    else:
        return False

for col in df.columns:
unhashable_values = df[col].apply(lambda x: isinstance(x, (set, list, dict)))
unhashable_values = df[col].apply(lambda x: is_unhashable(x))
if unhashable_values.any():
unique_types = df[col][unhashable_values].apply(lambda x: type(x)).unique()
unhashable_cols[col] = unique_types
Expand Down
45 changes: 45 additions & 0 deletions tests/run_package_grouping_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
# Importing the Feature Generating Class
from team_comm_tools import FeatureBuilder
import pandas as pd
import warnings
warnings.simplefilter("always", category=UserWarning)

# Main Function
if __name__ == "__main__":
Expand Down Expand Up @@ -231,3 +233,46 @@
user_aggregation = False,
)
custom_agg_fb_no_agg.featurize()


"""
Test that it's possible to run the FB with unhashable types that are NOT required columns.
"""

print("Testing unhashable types on non-required columns...")

# The three required columns hold plain hashable values; only the extra
# "unhashable_col" column carries dicts, a set, and a list.
test_df_with_nonrequired_unhashables = pd.DataFrame(
    {
        "conversation_id": [1, 1, 2, 2, 2],
        "text": ["test1", "test2", "test3", "test4", "test5"],
        "speaker_id": [1, 2, 1, 2, 1],
        "unhashable_col": [
            {},
            set(),
            ["foo", "bar"],
            {"foo": "bar"},
            {"foo": ["bar1", "bar2", "bar3"]},
        ],
    }
)

# Constructing the FeatureBuilder is the whole check here: this section passes
# as long as no exception is raised.
fb_nonrequred_unhashable_test = FeatureBuilder(
    input_df=test_df_with_nonrequired_unhashables,
    conversation_id_col="conversation_id",
    message_col="text",
    speaker_id_col="speaker_id",
)

"""
Test that, if we have unhashable types as required columns, we throw a ValueError as expected.
"""

print("Testing unhashable types on required columns...")
test_df_with_required_unhashables = pd.DataFrame({
    "conversation_id": [[1], [1], [2], [2], [2]],  # conversation_id is a required col and contains lists
    "text": ["test1", "test2", "test3", "test4", "test5"],
    "speaker_id": [1, 2, 1, 2, 1]
})

try:
    fb_required_unhashable_test = FeatureBuilder(
        input_df=test_df_with_required_unhashables,
        conversation_id_col="conversation_id",
        message_col="text",
        speaker_id_col="speaker_id"
    )
except ValueError as e:
    assert 'has unhashable data types' in str(e)  # make sure we caught the right error!
    print("Test has properly exited with error.")
else:
    # Without this branch the test would pass silently if no error were raised,
    # so it could never detect the regression it is meant to guard against.
    raise AssertionError(
        "Expected a ValueError for unhashable data in a required column, but none was raised."
    )

0 comments on commit 4e4097f

Please sign in to comment.