cleanup
jackaldenryan committed Jan 10, 2025
1 parent 6f70ad6 commit f0c1bd9
Showing 4 changed files with 87 additions and 530 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -165,6 +165,8 @@ cython_debug/
# Cache files
cache.db*


# LongMemEval data
longmemeval_data/

# All DS_Store files
.DS_Store
287 changes: 0 additions & 287 deletions tests/evals/data/LongMemEval_mini_dataset_loading.ipynb
@@ -2113,293 +2113,6 @@
"# ######## Save to csv\n",
"# lme_dataset_df_filtered_human_labelling.to_csv(\"lme_dataset_df_filtered_human_labelling.csv\", index=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Archive"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Deserialize the JSON strings back into lists of dictionaries\n",
"list1_dicts = json.loads(eval_minidataset_labelled.iloc[0][output_column_name])\n",
"list2_dicts = json.loads(eval_minidataset_labelled.iloc[1][output_column_name])\n",
"\n",
"list3_dicts = json.loads(eval_minidataset_labelled.iloc[6][output_column_name])\n",
"list4_dicts = json.loads(eval_minidataset_labelled.iloc[7][output_column_name])\n",
"list5_dicts = json.loads(eval_minidataset_labelled.iloc[12][output_column_name])\n",
"list6_dicts = json.loads(eval_minidataset_labelled.iloc[13][output_column_name])\n",
"\n",
"# Convert the dictionaries back into EntityNode objects\n",
"list1_nodes = [dict_to_entity(node_dict, EntityNode) for node_dict in list1_dicts]\n",
"list2_nodes = [dict_to_entity(node_dict, EntityNode) for node_dict in list2_dicts]\n",
"list3_nodes = [dict_to_entity(node_dict, EntityNode) for node_dict in list3_dicts]\n",
"list4_nodes = [dict_to_entity(node_dict, EntityNode) for node_dict in list4_dicts]\n",
"list5_nodes = [dict_to_entity(node_dict, EntityNode) for node_dict in list5_dicts]\n",
"list6_nodes = [dict_to_entity(node_dict, EntityNode) for node_dict in list6_dicts]\n",
"\n",
"# Now list1_nodes, list2_nodes, list3_nodes, and list4_nodes contain the deserialized EntityNode objects\n",
"\n",
"#Convert Each List Into A List Of Only The Node Names\n",
"list1_node_names = [node.name for node in list1_nodes]\n",
"list2_node_names = [node.name for node in list2_nodes]\n",
"list3_node_names = [node.name for node in list3_nodes]\n",
"list4_node_names = [node.name for node in list4_nodes]\n",
"list5_node_names = [node.name for node in list5_nodes]\n",
"list6_node_names = [node.name for node in list6_nodes]\n",
"\n",
"\n",
"#Print Out The Lists\n",
"print(list1_node_names)\n",
"print(list2_node_names)\n",
"print(list3_node_names)\n",
"print(list4_node_names)\n",
"print(list5_node_names)\n",
"print(list6_node_names)\n",
"\n",
"#NowCollectAndPrintTheNoteSummaries For Each List\n",
"\n",
"# Function to print node summaries\n",
"def print_node_summaries(node_list, list_name):\n",
" print(f\"Node summaries for {list_name}:\")\n",
" for node in node_list:\n",
" # Assuming each node has a 'name' and 'attributes' attribute\n",
" print(f\"Name: {node.name}, Summary: {node.summary}\")\n",
" print(\"\\n\")\n",
"\n",
"# Print node summaries for each list\n",
"print_node_summaries(list1_nodes, \"list1\")\n",
"print_node_summaries(list2_nodes, \"list2\")\n",
"print_node_summaries(list3_nodes, \"list3\")\n",
"print_node_summaries(list4_nodes, \"list4\")\n",
"print_node_summaries(list5_nodes, \"list5\")\n",
"print_node_summaries(list6_nodes, \"list6\")\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# def get_output(row, model):\n",
"# \"\"\"\n",
"# Gets the output for a given model\n",
"# \"\"\"\n",
"# if row['task_name'] == \"extract_nodes\":\n",
"# return row['input_extracted_nodes']\n",
"# elif row['task_name'] == \"dedupe_nodes\":\n",
"# return row['input_existing_relevant_nodes']\n",
"# elif row['task_name'] == \"extract_edges\":\n",
"# return row['input_extracted_edges']\n",
"# elif row['task_name'] == \"dedupe_edges\":\n",
"# return row['input_existing_relevant_edges']\n",
"# elif row['task_name'] == \"extract_edge_dates\":\n",
"# return row['input_extracted_edge_dates']\n",
"# elif row['task_name'] == \"edge_invalidation\":\n",
"# return row['input_edge_invalidation']\n",
"\n",
"# def insert_gpt4o_answers(df):\n",
"# \"\"\"\n",
"# Inserts gpt4o answers by doing ingestion in the right order and filling extra input columns as needed\n",
"# \"\"\"\n",
"# for _, row in df.iterrows():\n",
"# # Get the output\n",
"# output_gpt4o = get_output(row, \"gpt4o\")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# ######## Write a multisession to a txt file\n",
"# question_id = \"gpt4_93159ced_abs\"\n",
"# multi_session = lme_dataset_df[lme_dataset_df['question_id'] == question_id][\"haystack_sessions\"].iloc[0]\n",
"\n",
"\n",
"# with open(f'{question_id}.txt', 'w') as f:\n",
"\n",
"\n",
"# for session in multi_session:\n",
"# f.write(\"New session \" + \"*\"*200 + \"\\n\")\n",
"# for message in session:\n",
" \n",
"# f.write(f\"{message}\\n\")\n",
"# f.write(\"\\n\")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# ######## Method to save all of the snippets (or only firsts/lasts) of the specified multi-sessions to a CSV file\n",
"\n",
"\n",
"# def make_messages_readable(messages):\n",
"# if len(messages) == 0:\n",
"# return []\n",
"# if messages == [None]:\n",
"# return None\n",
"# result_string = \"\"\n",
"# for message in messages:\n",
"# if message is None:\n",
"# continue\n",
"# result_string += \"|\"*80 + f\" {message['role']} \" + \"|\"*80 + \"\\n\\n\" + f\"{message['content']}\\n\\n\"\n",
"# return result_string\n",
"\n",
"\n",
"\n",
"# def handle_message_readability(snippet, readableMessages, spreadOutPreviousMessages, max_num_previous_messages):\n",
"# if readableMessages:\n",
"# snippet['message'] = make_messages_readable([snippet['message']])\n",
"# if spreadOutPreviousMessages:\n",
"# for i in range(max_num_previous_messages):\n",
"# snippet[f\"previous_message_{i+1}\"] = make_messages_readable([snippet[f\"previous_message_{i+1}\"]])\n",
"# return snippet\n",
"# snippet['previous_messages'] = make_messages_readable(snippet['previous_messages'])\n",
"# return snippet\n",
" \n",
"# snippet['message'] = json.dumps(snippet['message'])\n",
"# if spreadOutPreviousMessages:\n",
"# for i in range(max_num_previous_messages):\n",
"# snippet[f\"previous_message_{i+1}\"] = json.dumps(snippet[f\"previous_message_{i+1}\"])\n",
"# return snippet\n",
"# snippet['previous_messages'] = json.dumps(snippet['previous_messages'])\n",
"# return snippet\n",
"\n",
"# def save_multi_session_snippets_to_csv(question_ids, max_num_previous_messages=5, mode=\"all\", readableMessages=False, spreadOutPreviousMessages=False, snippet_index=None):\n",
"# \"\"\"\n",
"# Creates a csv where each row is a \"snippet\" from longmemeval. A snippet is a message and set of previous messages.\n",
"\n",
"# mode:\n",
"# mode=\"all\" --> saves all possible snippets for the specified question_ids. \n",
"# mode=\"firsts_only\" --> saves only the first snippet for each question_id.\n",
"# mode=\"lasts_only\" --> saves only the last snippet for each question_id.\n",
"# mode=\"index_only\" --> saves only the snippet at snippet_index for each question_id.\n",
"\n",
"# readableMessages:\n",
"# readableMessages=True --> saves the messages in a readable format, intended for reading in Google sheets.\n",
"# readableMessages=False --> saves the messages in a json format.\n",
"\n",
"# spreadOutPreviousMessages:\n",
"# spreadOutPreviousMessages=True --> spreads out the previous messages across multiple columns.\n",
"# spreadOutPreviousMessages=False --> saves all previous messages in a single column.\n",
" \n",
"# \"\"\"\n",
"\n",
"# all_snippets = []\n",
"# indices = []\n",
"# for question_id in question_ids:\n",
" \n",
"# ######## First, lets combine the sessions and dates into a list of dicts\n",
"# row = lme_dataset_df[lme_dataset_df['question_id'] == question_id]\n",
"\n",
"# if row.empty:\n",
"# raise ValueError(f\"No question found with ID: {question_id}\")\n",
"\n",
"\n",
"# # Extract the haystack_sessions column value\n",
"# sessions = row['haystack_sessions'].iloc[0]\n",
"\n",
"# # Get the haystack_dates column value\n",
"# session_dates = row['haystack_dates'].iloc[0]\n",
"\n",
"# # Combine into list of dictionaries\n",
"# session_and_dates = [\n",
"# {\n",
"# \"session\": session,\n",
"# \"date\": datetime.strptime(date, \"%Y/%m/%d (%a) %H:%M\")\n",
"# } \n",
"# for session, date in zip(sessions, session_dates)\n",
"# ]\n",
"\n",
"# # Sort by date from earliest to latest\n",
"# session_and_dates.sort(key=lambda x: x[\"date\"])\n",
"\n",
"\n",
"# all_snippets_this_session = []\n",
"\n",
"# total_num_messages = sum([len(session_and_date[\"session\"]) for session_and_date in session_and_dates])\n",
"\n",
"# message_index_across_sessions = 0\n",
"# for session_index, session_and_date in enumerate(session_and_dates):\n",
"# for message_index_within_session, message in enumerate(session_and_date[\"session\"]):\n",
" \n",
"# num_previous_messages = min(max_num_previous_messages, message_index_across_sessions)\n",
"# previous_snippets = all_snippets_this_session[message_index_across_sessions-num_previous_messages:]\n",
"# previous_messages_only = [previous_snippet[\"message\"] for previous_snippet in previous_snippets]\n",
"\n",
"# snippet = {\n",
"# \"question_id\": question_id,\n",
"# \"multisession_index\": int(row.index[0]),\n",
"# \"session_index\": session_index,\n",
"# \"message_index_within_session\": message_index_within_session,\n",
"# \"session_date\": session_and_date[\"date\"],\n",
"# }\n",
"\n",
"# if spreadOutPreviousMessages:\n",
"# previous_messages_only_padded = previous_messages_only + [None] * (max_num_previous_messages - len(previous_messages_only))\n",
"# for i, prev_msg in enumerate(previous_messages_only_padded):\n",
"# snippet[f\"previous_message_{i+1}\"] = prev_msg\n",
"# snippet[\"message\"] = message\n",
"# else:\n",
"# snippet[\"message\"] = message\n",
"# snippet[\"previous_messages\"] = previous_messages_only\n",
"\n",
"# all_snippets_this_session.append(snippet)\n",
"# message_index_across_sessions += 1\n",
"\n",
"# if mode == \"firsts_only\":\n",
"# all_snippets_this_session = [all_snippets_this_session[0]]\n",
"# elif mode == \"lasts_only\":\n",
"# all_snippets_this_session = [all_snippets_this_session[-1]]\n",
"# elif mode == \"index_only\":\n",
"# all_snippets_this_session = [all_snippets_this_session[snippet_index]]\n",
"\n",
"# all_snippets.extend(all_snippets_this_session)\n",
"# indices.append(int(row.index[0]))\n",
"\n",
"\n",
"# filename = \"lme_samples_indices=\"\n",
"# indices.sort() \n",
"# num_indices = len(indices)\n",
"# if num_indices < 4:\n",
"# for index in indices:\n",
"# filename += f\"_{index}\"\n",
"# else:\n",
"# filename += f\"_{indices[0]}_etc_{indices[-1]}\"\n",
"\n",
"# if mode == \"firsts_only\":\n",
"# filename += \"_firsts_only\"\n",
"# elif mode == \"lasts_only\":\n",
"# filename += \"_lasts_only\"\n",
"# elif mode == \"index_only\":\n",
"# filename += f\"_index_only={snippet_index}\"\n",
"# if spreadOutPreviousMessages:\n",
"# filename += \"_spreadOutPreviousMessages\"\n",
"# if readableMessages:\n",
"# filename += \"_readable\"\n",
"# filename += \".csv\"\n",
" \n",
" \n",
"# with open(filename, \"w\", newline=\"\") as csvfile:\n",
"# writer = csv.DictWriter(csvfile, fieldnames=all_snippets[0].keys())\n",
"# writer.writeheader()\n",
"# for snippet in all_snippets:\n",
"# processed_snippet = handle_message_readability(snippet, readableMessages, spreadOutPreviousMessages, max_num_previous_messages)\n",
"# writer.writerow(processed_snippet)"
]
}
],
"metadata": {
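
The docstring of the archived save_multi_session_snippets_to_csv cell above explains the mode, readableMessages, and spreadOutPreviousMessages options. Below is a minimal usage sketch of how those options combine, assuming the cell were uncommented and lme_dataset_df loaded as earlier in the notebook; the question_id is the one already used in the archived write-to-txt cell.

# Usage sketch only: requires the archived save_multi_session_snippets_to_csv cell
# to be uncommented and lme_dataset_df to be loaded as earlier in the notebook.

# All snippets for one question, previous messages kept in a single JSON column:
save_multi_session_snippets_to_csv(["gpt4_93159ced_abs"], mode="all")

# Only the last snippet per question, formatted for reading in Google Sheets,
# with up to 5 previous messages spread across separate columns:
save_multi_session_snippets_to_csv(
    ["gpt4_93159ced_abs"],
    max_num_previous_messages=5,
    mode="lasts_only",
    readableMessages=True,
    spreadOutPreviousMessages=True,
)

# Only the snippet at a specific position for each question:
save_multi_session_snippets_to_csv(["gpt4_93159ced_abs"], mode="index_only", snippet_index=3)
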
2 changes: 2 additions & 0 deletions tests/evals/data/utils.py
@@ -97,6 +97,8 @@ async def ingest_and_label_snippet(llm_client, snippet_df, output_column_name):
)
snippet_df.at[extract_edges_row.index[0], output_column_name] = json.dumps([entity_to_dict(edge) for edge in extracted_edges])

####### TODO: finish implementing below

#### Process 'dedupe_edges' task
# dedupe_edges_row = message_df[message_df['task_name'] == 'dedupe_edges']
# assert len(dedupe_edges_row) == 1, "There should be exactly one row for 'dedupe_edges'"
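####### A minimal sketch of the remaining TODO steps, mirroring the 'extract_edges'
####### handling above. dedupe_extracted_edges is a hypothetical helper name, not a
####### confirmed API in this repository, and the 'input_existing_relevant_edges'
####### value may need json.loads before use.
# deduped_edges = await dedupe_extracted_edges(
#     llm_client, extracted_edges, dedupe_edges_row.iloc[0]['input_existing_relevant_edges']
# )
# snippet_df.at[dedupe_edges_row.index[0], output_column_name] = json.dumps(
#     [entity_to_dict(edge) for edge in deduped_edges]
# )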
