Skip to content

Commit

Permalink
bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
sundy1994 committed Feb 28, 2025
1 parent 674b31d commit 09ccd7d
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 46 deletions.
11 changes: 6 additions & 5 deletions src/team_comm_tools/feature_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,14 +557,10 @@ def preprocess_chat_data(self) -> None:
if self.grouping_keys and not self.cumulative_grouping and self.conversation_id_col != "conversation_num":
warnings.warn("WARNING: When grouping by the unique combination of a list of keys (`grouping_keys`), the conversation identifier must be auto-generated (`conversation_num`) rather than a user-provided column. Resetting conversation_id.")
self.conversation_id_col = "conversation_num"
# set new identifier column for cumulative grouping.
if self.cumulative_grouping and len(self.grouping_keys) == 3:
warnings.warn("NOTE: User has requested cumulative grouping. Auto-generating the key `conversation_num` as the conversation identifier for cumulative conversations.")
self.conversation_id_col = "conversation_num"

# create the appropriate grouping variables and assert the columns are present
assert_key_columns_present(self.chat_data, self.column_names)
self.chat_data = preprocess_conversation_columns(self.chat_data, self.column_names, self.grouping_keys, self.cumulative_grouping, self.within_task)
# assert_key_columns_present(self.chat_data, self.column_names)
self.chat_data = remove_unhashable_cols(self.chat_data, self.column_names)

# save original column with no preprocessing
Expand All @@ -583,6 +579,11 @@ def preprocess_chat_data(self) -> None:
# Save the preprocessed data (so we don't have to do this again)
self.preprocessed_data = self.chat_data

# set new identifier column for cumulative grouping.
if self.cumulative_grouping and len(self.grouping_keys) == 3:
warnings.warn("NOTE: User has requested cumulative grouping. Auto-generating the key `conversation_num` as the conversation identifier for cumulative conversations.")
self.conversation_id_col = "conversation_num"

def chat_level_features(self) -> None:
"""
Instantiate and use the ChatLevelFeaturesCalculator to create chat-level features.
Expand Down
76 changes: 35 additions & 41 deletions src/team_comm_tools/utils/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,47 +2,7 @@
import pandas as pd
import warnings

def assert_key_columns_present(df: pd.DataFrame, column_names: dict) -> None:
"""Ensure that the DataFrame has essential columns and handle missing values.
This function if the essential columns `conversation_id_col`, `speaker_id_col`, and
`message_col` are present. If any of these columns are missing, a
KeyError is raised.
:param df: The DataFrame to check and process.
:type df: pandas.DataFrame
:param column_names: Columns to preprocess.
:type column_names: dict
:raises KeyError: If one of `conversation_id_col`, `speaker_id_col`, and `message_col` columns is missing.
"""

# conversation_id_col = column_names['conversation_id_col']
# speaker_id_col = column_names['speaker_id_col']
# message_col = column_names['message_col']

# remove all special characters from df
df.columns = df.columns.str.replace('[^A-Za-z0-9_]', '', regex=True)
# Assert that key columns are present
for role, col in column_names.items():
if role == 'timestamp_col':
continue # skip timestamp column
if col not in df.columns:
raise KeyError(f"Missing required columns in DataFrame: '{col}' (expected for {role})\n Columns available: {df.columns}")
else:
print(f"Confirmed that data has {role} column: {col}!")
df[col] = df[col].fillna('')

# if {conversation_id_col, speaker_id_col, message_col}.issubset(df.columns):
# print(f"Confirmed that data has conversation_id: {conversation_id_col}, speaker_id: {speaker_id_col} and message: {message_col} columns!")
# # ensure no NA's in essential columns #NOTE: This is moved to preprocess_conversation_columns
# df[conversation_id_col] = df[conversation_id_col].fillna(0)
# df[speaker_id_col] = df[speaker_id_col].fillna(0)
# df[message_col] = df[message_col].fillna('')
# else:
# print("One or more of conversation_id, speaker_id or message columns are missing! Raising error...")
# print("Columns available: ")
# print(df.columns)
# raise KeyError

def preprocess_conversation_columns(df: pd.DataFrame, column_names: dict, grouping_keys: list,
cumulative_grouping: bool = False, within_task: bool = False) -> pd.DataFrame:
Expand All @@ -66,7 +26,8 @@ def preprocess_conversation_columns(df: pd.DataFrame, column_names: dict, groupi
:return: The preprocessed DataFrame with a conversation number column.
:rtype: pd.DataFrame
"""

# remove all special characters from df
df.columns = df.columns.str.replace('[^A-Za-z0-9_]', '', regex=True)
if not grouping_keys: # case 1: single identifier
return df
if not set(grouping_keys).issubset(df.columns):
Expand All @@ -76,9 +37,42 @@ def preprocess_conversation_columns(df: pd.DataFrame, column_names: dict, groupi
else: # case 2: grouping multiple keys, or case 3 but not 3 layers
df['conversation_num'] = df.groupby(grouping_keys).ngroup()
df = df[df.columns.tolist()[-1:] + df.columns.tolist()[0:-1]] # make the new column first
# assert key columns are present
for role, col in column_names.items():
if role == 'timestamp_col':
continue # skip timestamp column
if col not in df.columns:
raise KeyError(f"Missing required columns in DataFrame: '{col}' (expected for {role})\n Columns available: {df.columns}")
else:
print(f"Confirmed that data has {role} column: {col}!")
df[col] = df[col].fillna('')

return df

def assert_key_columns_present(df: pd.DataFrame, column_names: dict) -> None:
"""Ensure that the DataFrame has essential columns and handle missing values.
This function if the essential columns `conversation_id_col`, `speaker_id_col`, and
`message_col` are present. If any of these columns are missing, a
KeyError is raised.
:param df: The DataFrame to check and process.
:type df: pandas.DataFrame
:param column_names: Columns to preprocess.
:type column_names: dict
:raises KeyError: If one of `conversation_id_col`, `speaker_id_col`, and `message_col` columns is missing.
"""

# Assert that key columns are present
for role, col in column_names.items():
if role == 'timestamp_col':
continue # skip timestamp column
if col not in df.columns:
raise KeyError(f"Missing required columns in DataFrame: '{col}' (expected for {role})\n Columns available: {df.columns}")
else:
print(f"Confirmed that data has {role} column: {col}!")
df[col] = df[col].fillna('')

def remove_unhashable_cols(df: pd.DataFrame, column_names: dict) -> pd.DataFrame:
"""
If a required column contains unhashable types, raise an error.
Expand Down

0 comments on commit 09ccd7d

Please sign in to comment.