Skip to content

Commit

Permalink
updates to fix message_original error
Browse files Browse the repository at this point in the history
  • Loading branch information
amytangzheng committed Dec 19, 2024
1 parent d8291db commit d8377dc
Show file tree
Hide file tree
Showing 3 changed files with 246 additions and 244 deletions.
2 changes: 1 addition & 1 deletion src/team_comm_tools/feature_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def __init__(
need_sentiment = True

# check_embeddings(self.chat_data, self.vect_path, self.bert_path, need_sentence, need_sentiment, self.regenerate_vectors, message_col = self.vector_colname)
check_embeddings(self.chat_data, self.vect_path, self.bert_path, self.original_vect_path, need_sentence, need_sentiment, self.regenerate_vectors, message_col = self.vector_colname)
check_embeddings(self.chat_data, self.vect_path, self.bert_path, self.original_vect_path, need_sentence, need_sentiment, self.regenerate_vectors, message_col = self.vector_colname, custom_vect = custom_vect_path is not None)

if(need_sentence):
self.vect_data = pd.read_csv(self.vect_path, encoding='mac_roman')
Expand Down
108 changes: 55 additions & 53 deletions src/team_comm_tools/utils/check_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

# Check if embeddings exist
def check_embeddings(chat_data: pd.DataFrame, vect_path: str, bert_path: str, original_vect_path: str, need_sentence: bool,
need_sentiment: bool, regenerate_vectors: bool, message_col: str = "message"):
need_sentiment: bool, regenerate_vectors: bool, message_col: str = "message", custom_vect: bool = False):
"""
Check if embeddings and required lexicons exist, and generate them if they don't.
Expand All @@ -51,6 +51,8 @@ def check_embeddings(chat_data: pd.DataFrame, vect_path: str, bert_path: str, or
:type regenerate_vectors: bool, optional
:param message_col: A string representing the column name that should be selected as the message. Defaults to "message".
:type message_col: str, optional
:param custom_vect: Whether the user has passed in custom vectors
:type custom_vect: bool, optional
:return: None
:rtype: None
Expand All @@ -61,64 +63,64 @@ def check_embeddings(chat_data: pd.DataFrame, vect_path: str, bert_path: str, or
generate_bert(chat_data, bert_path, message_col)

try:
vector_df = pd.read_csv(vect_path)

# check whether the given vector and bert data matches length of chat data
if len(vector_df) != len(chat_data):
print("ERROR: The length of the vector data does not match the length of the chat data. Regenerating...")
# reset vector path to default/original
generate_vect(chat_data, original_vect_path, message_col)
else:
# check that message in vector data matches chat data
preprocessed_chat = chat_data[message_col].astype(str).apply(preprocess_text)
if custom_vect == True:
vector_df = pd.read_csv(vect_path)

# preprocess vector data, remove _original if message_col contains to preprocess the text
if '_original' in message_col:
message_col = message_col.replace('_original', '')

print(message_col, message_col[:-9])
preprocessed_vector = vector_df[message_col].astype(str).apply(preprocess_text)


mismatches = chat_data[preprocessed_chat != preprocessed_vector]
if len(mismatches) != 0:
print("Messages in the vector data do not match the chat data. Regenerating...")
# check whether the given vector and bert data matches length of chat data
if len(vector_df) != len(chat_data):
print("ERROR: The length of the vector data does not match the length of the chat data. Regenerating...")
# reset vector path to default/original
generate_vect(chat_data, original_vect_path, message_col)

if "message_embedding" in vector_df.columns:
# check that message_embedding is numeric list
if not vector_df["message_embedding"].apply(is_numeric_list).all():
print("message_embedding is not a numeric list. Regenerating ...")
else:
# check that message in vector data matches chat data
preprocessed_chat = chat_data[message_col].astype(str).apply(preprocess_text).fillna("")

# preprocess vector data, remove _original if message_col contains to preprocess the text
while '_original' in message_col:
message_col = message_col.replace('_original', '')

# print(message_col)
preprocessed_vector = vector_df[message_col].astype(str).apply(preprocess_text).fillna("")

mismatches = chat_data[preprocessed_chat != preprocessed_vector]
if len(mismatches) != 0:
print("Messages in the vector data do not match the chat data. Regenerating...")
generate_vect(chat_data, original_vect_path, message_col)
else:
# check if length of all vectors is the same
vect_lengths = vector_df["message_embedding"].apply(lambda x: ast.literal_eval(x)).apply(lambda x : len(x))

if (vect_lengths == 0).any():
print("One or more value in message_embedding are null. Regenerating ...")
generate_vect(chat_data, original_vect_path, message_col)

if len(vect_lengths.unique()) > 1:
print("Not all vectors have the same length. Regenerating ...")
generate_vect(chat_data, original_vect_path, message_col)

# check if vectors have a 1-1 mapping with the text
embedding_message_map = {}
for _, row in vector_df.iterrows():
embedding = row['message_embedding']
message = row['message']

if embedding in embedding_message_map:
if message != embedding_message_map[embedding]:
print("Same embedding maps to multiple unique messages. Regenerating ...")
if "message_embedding" in vector_df.columns:
# check that message_embedding is numeric list
if not vector_df["message_embedding"].apply(is_numeric_list).all():
print("message_embedding is not a numeric list. Regenerating ...")
generate_vect(chat_data, original_vect_path, message_col)
else:
# check if length of all vectors is the same
vect_lengths = vector_df["message_embedding"].apply(lambda x: ast.literal_eval(x)).apply(lambda x : len(x))

if (vect_lengths == 0).any():
print("One or more value in message_embedding are null. Regenerating ...")
generate_vect(chat_data, original_vect_path, message_col)

if len(vect_lengths.unique()) > 1:
print("Not all vectors have the same length. Regenerating ...")
generate_vect(chat_data, original_vect_path, message_col)
break
else:
embedding_message_map[embedding] = message

else:
print("no message_embedding column. Regenerating ...")
generate_vect(chat_data, original_vect_path, message_col)
# check if vectors have a 1-1 mapping with the text
embedding_message_map = {}
for _, row in vector_df.iterrows():
embedding = row['message_embedding']
message = row['message']

if embedding in embedding_message_map:
if message != embedding_message_map[embedding]:
print("Same embedding maps to multiple unique messages. Regenerating ...")
generate_vect(chat_data, original_vect_path, message_col)
break
else:
embedding_message_map[embedding] = message

else:
print("no message_embedding column. Regenerating ...")
generate_vect(chat_data, original_vect_path, message_col)

except FileNotFoundError: # It's OK if we don't have the path, if the sentence vectors are not necessary
if need_sentence:
Expand Down
Loading

0 comments on commit d8377dc

Please sign in to comment.