Add a check for unhashable values, refactor some functions to improve readability
sundy1994 committed Feb 28, 2025
1 parent 3333d9b commit 2bf6896
Showing 2 changed files with 304 additions and 250 deletions.
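The headline change is a call to a new helper, remove_unhashable_cols, during preprocessing (see the preprocess_chat_data hunk below). Its implementation is not part of this file's diff (it presumably lives in the second changed file), so the following is a minimal, hypothetical sketch of what such a check might look like, assuming it drops non-key columns whose cells hold unhashable values (lists, dicts, sets) and treats an unhashable key column as a hard error. The body, names, and error handling are illustrative, not the commit's actual code.

import warnings

import pandas as pd

def remove_unhashable_cols(chat_data: pd.DataFrame, column_names: dict) -> pd.DataFrame:
    # Hypothetical sketch; the real implementation is in the other changed file.
    key_cols = set(column_names.values())
    cols_to_drop = []
    for col in chat_data.columns:
        # lists, dicts, and sets are unhashable; a type check is cheaper than hashing every cell
        if chat_data[col].apply(lambda v: isinstance(v, (list, dict, set))).any():
            if col in key_cols:
                raise ValueError(f"Key column '{col}' contains unhashable values.")
            warnings.warn(f"WARNING: Dropping column '{col}' because it contains unhashable values.")
            cols_to_drop.append(col)
    return chat_data.drop(columns=cols_to_drop)

Deduplication (drop_duplicates) and groupby operations require hashable cell values, which is why such a check belongs before the conversation-level dataframe is derived.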
183 changes: 96 additions & 87 deletions src/team_comm_tools/feature_builder.py
@@ -148,39 +148,47 @@ def __init__(
) -> None:

# Some error catching
if type(input_df) != pd.DataFrame:
raise ValueError("You must pass in a valid dataframe as the input_df!")
if not vector_directory:
raise ValueError("You must pass in a valid directory to cache vectors! For example: ./vector_data/")
if not isinstance(input_df, pd.DataFrame):
raise TypeError(f"Expected a Pandas DataFrame as input_df, but got {type(df).__name__})")

# if not vector_directory:
# raise ValueError("You must pass in a valid directory to cache vectors! For example: ./vector_data/") # NOTE: This is redundant because we have a default value

print("Initializing Featurization...")

###### Set all parameters ######

assert(all(0 <= x <= 1 for x in analyze_first_pct)) # sanity-check that every value lies between 0 and 1
self.first_pct = analyze_first_pct # Set first pct of conversation you want to analyze
self.turns = turns
self.conversation_id_col = conversation_id_col
self.speaker_id_col = speaker_id_col
self.message_col = message_col
self.timestamp_col = timestamp_col
self.timestamp_unit = timestamp_unit
self.column_names = {
'conversation_id_col': conversation_id_col,
'speaker_id_col': speaker_id_col,
'message_col': message_col,
'timestamp_col': timestamp_col
}
self.grouping_keys = grouping_keys
self.cumulative_grouping = cumulative_grouping # for grouping the chat data
self.within_task = within_task
self.ner_cutoff = ner_cutoff
self.regenerate_vectors = regenerate_vectors
self.convo_aggregation = convo_aggregation
self.convo_methods = convo_methods
self.convo_columns = convo_columns
self.user_aggregation = user_aggregation
self.user_methods = user_methods
self.user_columns = user_columns
# Defining input and output paths.
self.chat_data = input_df.copy()
self.orig_data = input_df.copy()
self.ner_training = ner_training_df
self.vector_directory = vector_directory

print("Initializing Featurization...")

if not custom_liwc_dictionary_path:
self.custom_liwc_dictionary = {}
else:
# Read .dic file if the path is provided
custom_liwc_dictionary_path = Path(custom_liwc_dictionary_path)
if not custom_liwc_dictionary_path.exists():
print(f"WARNING: The custom LIWC dictionary file does not exist: {custom_liwc_dictionary_path}")
self.custom_liwc_dictionary = {}
elif not custom_liwc_dictionary_path.suffix == '.dic':
print(f"WARNING: The custom LIWC dictionary file is not a .dic file: {custom_liwc_dictionary_path}")
self.custom_liwc_dictionary = {}
else:
with open(custom_liwc_dictionary_path, 'r', encoding='utf-8-sig') as file:
dicText = file.read()
try:
self.custom_liwc_dictionary = load_liwc_dict(dicText)
except Exception as e:
print(f"WARNING: Failed loading custom liwc dictionary: {e}")
self.custom_liwc_dictionary = {}

self.custom_liwc_dictionary = self.load_custom_liwc_dict(custom_liwc_dictionary_path)
# Set features to generate
# TODO --- think through more carefully which ones we want to exclude and why
self.feature_dict = feature_dict
@@ -219,7 +227,6 @@ def __init__(
"Conversation Level Aggregates",
"User Level Aggregates"
]

# warning if user added invalid custom/exclude features
self.custom_features = []
invalid_features = set()
@@ -230,14 +237,12 @@
invalid_features.add(feat)
if invalid_features:
invalid_features_str = ', '.join(invalid_features)
print(f"WARNING: Invalid custom features provided. Ignoring `{invalid_features_str}`.")

warnings.warn(f"WARNING: Invalid custom features provided. Ignoring `{invalid_features_str}`.")
# keep track of which features we are generating
self.feature_names = self.default_features + self.custom_features
# remove named entities if we didn't pass in the column
if(self.ner_training is None):
self.feature_names.remove("Named Entity Recognition")

# deduplicate functions and append them into a list for calculation
self.feature_methods_chat = []
self.feature_methods_conv = []
@@ -258,65 +263,13 @@
self.chat_data = self.chat_data.drop(columns=columns_to_drop)
self.orig_data = self.orig_data.drop(columns=columns_to_drop)

# Set first pct of conversation you want to analyze
assert(all(0 <= x <= 1 for x in analyze_first_pct)) # first, type check that this is a list of numbers between 0 and 1
self.first_pct = analyze_first_pct

# Parameters for preprocessing chat data
self.turns = turns
self.conversation_id_col = conversation_id_col
self.speaker_id_col = speaker_id_col
self.message_col = message_col
self.timestamp_col = timestamp_col
self.timestamp_unit = timestamp_unit
self.column_names = {
'conversation_id_col': conversation_id_col,
'speaker_id_col': speaker_id_col,
'message_col': message_col,
'timestamp_col': timestamp_col
}
self.grouping_keys = grouping_keys
self.cumulative_grouping = cumulative_grouping # for grouping the chat data
self.within_task = within_task
self.ner_cutoff = ner_cutoff
self.regenerate_vectors = regenerate_vectors
self.convo_aggregation = convo_aggregation
self.convo_methods = convo_methods
self.convo_columns = convo_columns
self.user_aggregation = user_aggregation
self.user_methods = user_methods
self.user_columns = user_columns

if(compute_vectors_from_preprocessed == True):
if compute_vectors_from_preprocessed:
self.vector_colname = self.message_col # because the message col will eventually get preprocessed
else:
self.vector_colname = self.message_col + "_original" # because this contains the original message

# check grouping rules
if self.conversation_id_col not in self.chat_data.columns and len(self.grouping_keys)==0:
if(self.conversation_id_col == "conversation_num"):
raise ValueError("Conversation identifier not present in data. Did you perhaps forget to pass in a `conversation_id_col`?")
raise ValueError("Conversation identifier not present in data.")
if self.cumulative_grouping and len(grouping_keys) == 0:
warnings.warn("WARNING: No grouping keys provided. Ignoring `cumulative_grouping` argument.")
self.cumulative_grouping = False
if self.cumulative_grouping and len(grouping_keys) != 3:
warnings.warn("WARNING: Can only perform cumulative grouping for three-layer nesting. Ignoring cumulative command and grouping by unique combinations in the grouping_keys.")
self.cumulative_grouping = False
self.conversation_id_col = "conversation_num"
if self.cumulative_grouping and self.conversation_id_col not in self.grouping_keys:
raise ValueError("Conversation identifier for cumulative grouping must be one of the grouping keys.")
if self.grouping_keys and not self.cumulative_grouping and self.conversation_id_col != "conversation_num":
warnings.warn("WARNING: When grouping by the unique combination of a list of keys (`grouping_keys`), the conversation identifier must be auto-generated (`conversation_num`) rather than a user-provided column. Resetting conversation_id.")
self.conversation_id_col = "conversation_num"

self.preprocess_chat_data()

# set new identifier column for cumulative grouping.
if self.cumulative_grouping and len(grouping_keys) == 3:
warnings.warn("NOTE: User has requested cumulative grouping. Auto-generating the key `conversation_num` as the conversation identifier for cumulative conversations.")
self.conversation_id_col = "conversation_num"

# Set all paths for vector retrieval (contingent on turns)
df_type = "turns" if self.turns else "chats"
if(self.cumulative_grouping): # create special vector paths for cumulative groupings
@@ -450,6 +403,8 @@ def __init__(
# Deriving the base conversation level dataframe.
self.conv_data = self.chat_data[[self.conversation_id_col]].drop_duplicates()



def set_self_conv_data(self) -> None:
"""
Derives the base conversation level dataframe.
@@ -585,16 +540,38 @@ def preprocess_chat_data(self) -> None:
:return: None
:rtype: None
"""
# check grouping rules
if self.conversation_id_col not in self.chat_data.columns and len(self.grouping_keys)==0:
if(self.conversation_id_col == "conversation_num"):
raise ValueError("Conversation identifier not present in data. Did you perhaps forget to pass in a `conversation_id_col`?")
raise ValueError("Conversation identifier not present in data.")
if self.cumulative_grouping and len(self.grouping_keys) == 0:
warnings.warn("WARNING: No grouping keys provided. Ignoring `cumulative_grouping` argument.")
self.cumulative_grouping = False
if self.cumulative_grouping and len(self.grouping_keys) != 3:
warnings.warn("WARNING: Can only perform cumulative grouping for three-layer nesting. Ignoring cumulative command and grouping by unique combinations in the grouping_keys.")
self.cumulative_grouping = False
self.conversation_id_col = "conversation_num"
if self.cumulative_grouping and self.conversation_id_col not in self.grouping_keys:
raise ValueError("Conversation identifier for cumulative grouping must be one of the grouping keys.")
if self.grouping_keys and not self.cumulative_grouping and self.conversation_id_col != "conversation_num":
warnings.warn("WARNING: When grouping by the unique combination of a list of keys (`grouping_keys`), the conversation identifier must be auto-generated (`conversation_num`) rather than a user-provided column. Resetting conversation_id.")
self.conversation_id_col = "conversation_num"
# set new identifier column for cumulative grouping.
if self.cumulative_grouping and len(self.grouping_keys) == 3:
warnings.warn("NOTE: User has requested cumulative grouping. Auto-generating the key `conversation_num` as the conversation identifier for cumulative conversations.")
self.conversation_id_col = "conversation_num"

# create the appropriate grouping variables and assert the columns are present
self.chat_data = preprocess_conversation_columns(self.chat_data, self.conversation_id_col, self.timestamp_col, self.grouping_keys, self.cumulative_grouping, self.within_task)
assert_key_columns_present(self.chat_data, self.column_names)
self.chat_data = preprocess_conversation_columns(self.chat_data, self.column_names, self.grouping_keys, self.cumulative_grouping, self.within_task)
self.chat_data = remove_unhashable_cols(self.chat_data, self.column_names)

# save original column with no preprocessing
self.chat_data[self.message_col + "_original"] = self.chat_data[self.message_col]

# create new column that retains punctuation
self.chat_data["message_lower_with_punc"] = self.chat_data[self.message_col].astype(str).apply(preprocess_text_lowercase_but_retain_punctuation)
self.chat_data["message_lower_with_punc"] = self.chat_data[self.message_col].astype(str).apply(lambda x: x.lower())

# Preprocessing the text in `message_col` and then overwriting the column `message_col`.
# TODO: We should probably use classes to abstract preprocessing module as well?
@@ -723,4 +700,36 @@ def save_features(self) -> None:
"""
self.chat_data.to_csv(self.output_file_path_chat_level, index=False)
self.user_data.to_csv(self.output_file_path_user_level, index=False)
self.conv_data.to_csv(self.output_file_path_conv_level, index=False)

def load_custom_liwc_dict(self, custom_liwc_dictionary_path: str) -> dict:
"""
Load the custom LIWC dictionary from the provided path and return it as a dict.
Returns an empty dictionary if the path is empty, the file does not exist, the file is not a .dic file, or parsing fails.
:param custom_liwc_dictionary_path: Path to the custom LIWC dictionary file
:type custom_liwc_dictionary_path: str
:return: Custom LIWC dictionary
:rtype: dict
"""
if not custom_liwc_dictionary_path:
return {}
else:
# Read .dic file if the path is provided
custom_liwc_dictionary_path = Path(custom_liwc_dictionary_path)
if not custom_liwc_dictionary_path.exists():
warnings.warn(f"WARNING: The custom LIWC dictionary file does not exist: {custom_liwc_dictionary_path}")
return {}
elif not custom_liwc_dictionary_path.suffix == '.dic':
warnings.warn(f"WARNING: The custom LIWC dictionary file is not a .dic file: {custom_liwc_dictionary_path}")
return {}
else:
with open(custom_liwc_dictionary_path, 'r', encoding='utf-8-sig') as file:
dicText = file.read()
try:
return load_liwc_dict(dicText)
except Exception as e:
warnings.warn(f"WARNING: Failed loading custom liwc dictionary: {e}")
return {}
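To illustrate the refactor, a hypothetical caller-side sketch (the constructor arguments, the DataFrame my_chat_df, and the .dic path are illustrative, not taken from this commit):

builder = FeatureBuilder(input_df=my_chat_df, vector_directory="./vector_data/")
liwc_dict = builder.load_custom_liwc_dict("./my_custom_dictionary.dic")
if not liwc_dict:
    print("Fell back to an empty dictionary (missing path, non-.dic file, or parse failure).")

Centralizing every failure mode to return {} in one method removes the nested if/elif chain from __init__, which is in line with the commit's stated readability goal.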