Commit 57bfe88: vector summarization updates

amytangzheng committed Dec 18, 2024
1 parent 0c89fa9 commit 57bfe88
Showing 3 changed files with 245 additions and 125 deletions.
129 changes: 78 additions & 51 deletions examples/featurize.py
@@ -40,59 +40,86 @@
"""

# Tiny Juries
print("Tiny Juries Example...")
tiny_juries_feature_builder = FeatureBuilder(
    input_df = tiny_juries_df,
    grouping_keys = ["batch_num", "round_num"],
    output_file_base = "jury_TINY_output", # Naming output files using the output_file_base parameter (recommended)
    # turns = False,  # previously False -- ERROR reported when turns is True
    turns = True,
    custom_features = [
        "(BERT) Mimicry",
        "Moving Mimicry",
        "Forward Flow",
        "Discursive Diversity"]
)
tiny_juries_feature_builder.featurize()
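# featurize() writes its outputs into the `output/chat`, `output/conv`, and `output/user`
# folders (see the note in the commented-out multi-task example further below).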

# Tiny Juries with custom aggregations
print("Tiny Juries with Custom Aggregation...")
tiny_juries_feature_builder_custom_agg = FeatureBuilder(
    input_df = tiny_juries_df,
    grouping_keys = ["batch_num", "round_num"],
    output_file_base = "jury_TINY_output_custom_agg", # Naming output files using the output_file_base parameter (recommended)
    turns = False,
    custom_features = [
        "(BERT) Mimicry",
        "Moving Mimicry",
        "Forward Flow",
        "Discursive Diversity"],
    convo_methods = ['max', 'median'], # This will aggregate ONLY the "positive_bert" column at the conversation level, using max and median.
    convo_columns = ['positive_bert'],
    user_methods = ['mean'], # This will aggregate ONLY "negative_bert" at the speaker/user level, using mean.
    user_columns = ['negative_bert'],
)
tiny_juries_feature_builder_custom_agg.featurize()
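
# A hedged, optional sanity check: list which aggregated sentiment columns ended up in the
# conversation-level output. The exact output filename is an assumption here, so we glob for
# it under ./output/conv/ rather than hard-coding it (pandas is assumed to already be
# imported as pd at the top of this script, as it is used below).
import glob

conv_outputs = glob.glob("./output/conv/jury_TINY_output_custom_agg*.csv")
if conv_outputs:
    conv_df = pd.read_csv(conv_outputs[0])
    print([c for c in conv_df.columns if "positive_bert" in c or "negative_bert" in c])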

# Tiny multi-task
# tiny_multi_task_feature_builder = FeatureBuilder(
#     input_df = tiny_multi_task_df,
#     conversation_id_col = "stageId",
#     # alternatively, you can name each output file separately. NOTE, however, that we don't directly use this path;
#     # we modify the path to place outputs within the `output/chat`, `output/conv`, and `output/user` folders.
#     output_file_path_chat_level = "./multi_task_TINY_output_chat_level_stageId_cumulative.csv",
#     output_file_path_user_level = "./multi_task_TINY_output_user_level_stageId_cumulative.csv",
#     output_file_path_conv_level = "./multi_task_TINY_output_conversation_level_stageId_cumulative.csv",
#     turns = False
# )
# tiny_multi_task_feature_builder.featurize()

# Test Vectors
# print("Testing vectors valid ...")
# test_vector_feature_builder = FeatureBuilder(
#     input_df = test_vector_df,
#     output_file_base = "test_vector",
#     custom_vect_path = "../tests/vector_data/sentence/chats/test_vector_valid.csv",
#     turns = False,
# )
# test_vector_feature_builder.featurize()

# FULL DATASETS BELOW ------------------------------------- #

# Juries
# jury_feature_builder = FeatureBuilder(
#     input_df = juries_df,

valid_df = pd.read_csv("../tests/vector_data/sentence/chats/test_vector_valid.csv", encoding='utf-8')
vector_row_mismatch_df = pd.read_csv("../tests/vector_data/sentence/chats/test_vector_valid.csv", encoding='utf-8')
vector_data_mismatch_df = pd.read_csv("../tests/vector_data/sentence/chats/test_vector_valid.csv", encoding='utf-8')
no_message_embedding_df = pd.read_csv("../tests/vector_data/sentence/chats/test_vector_valid.csv", encoding='utf-8')
no_turn_level_data_df = pd.read_csv("../tests/vector_data/sentence/chats/test_vector_valid.csv", encoding='utf-8')
vect_diff_length_df = pd.read_csv("../tests/vector_data/sentence/chats/test_vector_valid.csv", encoding='utf-8')
vect_null = pd.read_csv("../tests/vector_data/sentence/chats/test_vector_valid.csv", encoding='utf-8')
vect_nan = pd.read_csv("../tests/vector_data/sentence/chats/test_vector_valid.csv", encoding='utf-8')
vect_no_one_to_one = pd.read_csv("../tests/vector_data/sentence/chats/test_vector_valid.csv", encoding='utf-8')
test_convo_num_issue = pd.read_csv("../tests/vector_data/sentence/chats/test_turns_convo_num_issue.csv", encoding='utf-8')
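
# Each DataFrame above starts from the same valid vector file; the lines below apply one
# targeted corruption per case so that each FeatureBuilder run exercises a single
# validation path.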

# test number of rows mismatch
vector_row_mismatch_df = vector_row_mismatch_df.iloc[:-1]

# test chat data not equal to vector data (message)
vector_data_mismatch_df.loc[0, 'message'] = 'goodbye'

# test no message_embedding column
no_message_embedding_df.rename(columns={'message_embedding': 'temp'}, inplace=True)

# test vectors not same length
vect_diff_length_df.loc[0, 'message_embedding'] = '[0.9]'

# test null vectors
vect_null.loc[0, 'message_embedding'] = '[]'

# test nan vectors
vect_nan.loc[0, 'message_embedding'] = '[np.nan, np.nan]'

# test no 1-1 mapping
vect_no_one_to_one.loc[0, 'message_embedding'] = '[0.1, 0.2]'

test_cases = {
    "Valid DataFrame": valid_df,
    "Vector Row Mismatch": vector_row_mismatch_df,
    "Vector Data Mismatch": vector_data_mismatch_df,
    "No Message Embedding Column": no_message_embedding_df,
    "No Turn-Level Data (turns=True)": no_turn_level_data_df,
    "Vectors Not of Same Length": vect_diff_length_df,
    "Vectors Null": vect_null,
    "Vectors Nan": vect_nan,
    "Custom File Equals Default Dir": valid_df,
    "No 1-1 Mapping": vect_no_one_to_one,
}
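
# Note: "Custom File Equals Default Dir" reuses the valid data; in the loop below its
# custom_vect_path is swapped to a file inside the default ./vector_data/ directory to cover
# the case where the custom path points at the default vector location.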

for name, df in test_cases.items():
    custom_vect_path = "../tests/vector_data/sentence/chats/test_vector.csv"
    print(name)
    df.to_csv(custom_vect_path, index=False, encoding='utf-8')

    if name == "No Turn-Level Data (turns=True)":
        turns = True
    else:
        turns = False

    if name == "Custom File Equals Default Dir":
        custom_vect_path = "./vector_data/sentence/chats/test_vector_chat_level.csv"

    test_vector_feature_builder = FeatureBuilder(
        input_df=test_vector_df,
        output_file_base="test_vector",
        custom_vect_path=custom_vect_path,
        # Simulate turns=True for "No Turn-Level Data" case
        turns=turns,
    )
    test_vector_feature_builder.featurize()
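
# Note: most of the cases above are intended to fail vector validation. If FeatureBuilder
# raises an exception on invalid custom vectors (an assumption -- depending on the installed
# version it may only warn), the loop above stops at the first failing case. A hedged
# alternative is to wrap the call:
#
#     try:
#         test_vector_feature_builder.featurize()
#     except Exception as e:
#         print(f"  -> {name}: {e}")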

# Tiny Juries
# tiny_juries_feature_builder = FeatureBuilder(
# input_df = tiny_juries_df,
# grouping_keys = ["batch_num", "round_num"],
# vector_directory = "./vector_data/",
# # custom_vect_path = "C:/Users/amyta/Documents/GitHub/team_comm_tools/examples/vector_data/sentence/turns/jury_TINY_output_chat_level.csv", # testing turns = False but data mismatch