diff --git a/src/team_comm_tools/utils/check_embeddings.py b/src/team_comm_tools/utils/check_embeddings.py index 632826d5..866a512e 100644 --- a/src/team_comm_tools/utils/check_embeddings.py +++ b/src/team_comm_tools/utils/check_embeddings.py @@ -72,9 +72,14 @@ def check_embeddings(chat_data: pd.DataFrame, vect_path: str, bert_path: str, or # check that message in vector data matches chat data preprocessed_chat = chat_data[message_col].astype(str).apply(preprocess_text) - # preprocess vector data + # preprocess vector data, remove _original if message_col contains to preprocess the text + if '_original' in message_col: + message_col = message_col.replace('_original', '') + + print(message_col, message_col[:-9]) preprocessed_vector = vector_df[message_col].astype(str).apply(preprocess_text) + mismatches = chat_data[preprocessed_chat != preprocessed_vector] if len(mismatches) != 0: print("Messages in the vector data do not match the chat data. Regenerating...")