Commit
add partial completion of full ingestion and other things
1 parent 34b4b9c
commit 6f3def2
Showing 7 changed files with 3,362 additions and 1,257 deletions.
This file was deleted.
Large diffs are not rendered by default.
2,429 changes: 2,429 additions & 0 deletions
tests/evals/data/LongMemEval_mini_dataset_loading.ipynb
Large diffs are not rendered by default.
@@ -0,0 +1,164 @@
import pandas as pd
from datetime import datetime, timedelta
import json
from graphiti_core.nodes import EpisodicNode, EpisodeType, EntityNode


def create_episodes_from_messages(input_message, input_previous_messages):
    """
    Create an episode and a list of previous episodes from input messages.
    """
    # Current time for the episode
    current_time = datetime.now()

    # Create the current episode
    role = input_message["role"]
    content = input_message["content"]
    message_content = f"{role}: {content}"
    episode = EpisodicNode(
        name="",
        group_id="",
        source=EpisodeType.message,
        source_description="",
        content=message_content,
        valid_at=current_time,
    )

    # Create previous episodes, spaced one minute apart and ending just before the current episode
    num_previous_messages = len(input_previous_messages)
    previous_times = [
        current_time - timedelta(minutes=num_previous_messages - i)
        for i in range(num_previous_messages)
    ]
    previous_episodes = [
        EpisodicNode(
            name="",
            group_id="",
            source=EpisodeType.message,
            source_description="",
            content=f"{message['role']}: {message['content']}",
            valid_at=previous_time,
        )
        for message, previous_time in zip(input_previous_messages, previous_times)
    ]

    return episode, previous_episodes
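
For orientation, a minimal sketch of calling create_episodes_from_messages on its own; the message dicts below are made-up examples, not drawn from the dataset:

# Hypothetical chat turns in the {"role": ..., "content": ...} shape the function expects
example_message = {"role": "user", "content": "I adopted a cat named Miso last week."}
example_history = [
    {"role": "user", "content": "Do you have any pet care tips?"},
    {"role": "assistant", "content": "Sure - what kind of pet do you have?"},
]
episode, previous_episodes = create_episodes_from_messages(example_message, example_history)
print(episode.content)                          # "user: I adopted a cat named Miso last week."
print([e.valid_at for e in previous_episodes])  # one minute apart, ending just before episode.valid_at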
async def ingest_and_label_snippet(llm_client, snippet_df, output_column_name):
    # Import necessary functions
    from graphiti_core.utils.maintenance.node_operations import extract_nodes, resolve_extracted_nodes
    from graphiti_core.utils.maintenance.edge_operations import extract_edges

    # Loop through each unique message_index_within_snippet in sorted order
    for message_index in sorted(snippet_df['message_index_within_snippet'].unique()):
        message_df = snippet_df[snippet_df['message_index_within_snippet'] == message_index]

        #### Process 'extract_nodes' task
        extract_nodes_row = message_df[message_df['task_name'] == 'extract_nodes']
        assert len(extract_nodes_row) == 1, f"There should be exactly one row for 'extract_nodes' but there are {len(extract_nodes_row)}"
        input_message = json.loads(extract_nodes_row.iloc[0]['input_message'])
        input_previous_messages = json.loads(extract_nodes_row.iloc[0]['input_previous_messages'])
        episode, previous_episodes = create_episodes_from_messages(input_message, input_previous_messages)
        extracted_nodes = await extract_nodes(llm_client, episode, previous_episodes)
        snippet_df.at[extract_nodes_row.index[0], output_column_name] = json.dumps([entity_to_dict(node) for node in extracted_nodes])

        #### Process 'dedupe_nodes' task
        dedupe_nodes_row = message_df[message_df['task_name'] == 'dedupe_nodes']
        assert len(dedupe_nodes_row) == 1, f"There should be exactly one row for 'dedupe_nodes' but there are {len(dedupe_nodes_row)}"

        # Collect nodes extracted from earlier messages in this snippet as the existing-node candidates
        existing_nodes = []
        for prev_message_index in sorted(snippet_df['message_index_within_snippet'].unique()):
            if prev_message_index >= message_index:
                break

            # Filter for previous messages with 'extract_nodes' task
            prev_message_df = snippet_df[
                (snippet_df['message_index_within_snippet'] == prev_message_index) &
                (snippet_df['task_name'] == 'extract_nodes')
            ]

            # Retrieve and deserialize the nodes
            serialized_nodes = prev_message_df.iloc[0][output_column_name]
            node_dicts = json.loads(serialized_nodes)
            nodes = [dict_to_entity(node_dict, EntityNode) for node_dict in node_dicts]
            existing_nodes.extend(nodes)

        existing_nodes_lists = [existing_nodes for _ in range(len(extracted_nodes))]
        resolved_nodes, uuid_map = await resolve_extracted_nodes(llm_client, extracted_nodes, existing_nodes_lists, episode, previous_episodes)
        snippet_df.at[dedupe_nodes_row.index[0], output_column_name] = json.dumps([entity_to_dict(node) for node in resolved_nodes])

        #### Process 'extract_edges' task
        extract_edges_row = message_df[message_df['task_name'] == 'extract_edges']
        assert len(extract_edges_row) == 1, f"There should be exactly one row for 'extract_edges' but there are {len(extract_edges_row)}"
        extracted_edges = await extract_edges(
            llm_client,
            episode,
            extracted_nodes,
            previous_episodes,
            group_id='',
        )
        snippet_df.at[extract_edges_row.index[0], output_column_name] = json.dumps([entity_to_dict(edge) for edge in extracted_edges])

        #### Process 'dedupe_edges' task
        # dedupe_edges_row = message_df[message_df['task_name'] == 'dedupe_edges']
        # assert len(dedupe_edges_row) == 1, "There should be exactly one row for 'dedupe_edges'"
        # output = dedupe_extracted_edge(
        #     llm_client,
        #     extracted_edge,
        #     related_edges,
        # )
        # snippet_df.at[dedupe_edges_row.index[0], output_column_name] = output

        #### Process 'extract_edge_dates' task
        # extract_edge_dates_row = message_df[message_df['task_name'] == 'extract_edge_dates']
        # assert len(extract_edge_dates_row) == 1, "There should be exactly one row for 'extract_edge_dates'"
        # output = extract_edge_dates(extract_edge_dates_row.iloc[0]['input_extracted_edge_dates'])
        # snippet_df.at[extract_edge_dates_row.index[0], output_column_name] = output

        #### Process 'edge_invalidation' task
        # edge_invalidation_row = message_df[message_df['task_name'] == 'edge_invalidation']
        # assert len(edge_invalidation_row) == 1, "There should be exactly one row for 'edge_invalidation'"
        # output = edge_invalidation(edge_invalidation_row.iloc[0]['input_edge_invalidation'])
        # snippet_df.at[edge_invalidation_row.index[0], output_column_name] = output

    return snippet_df
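
ingest_and_label_snippet expects one row per (message, task) pair, with the message columns JSON-encoded. A minimal sketch of that shape with made-up values (the real rows come from the LongMemEval mini dataset notebook); at least the 'extract_nodes', 'dedupe_nodes', and 'extract_edges' tasks must be present for each message:

# Hypothetical single-message snippet frame; real snippets carry many messages per snippet
example_snippet_df = pd.DataFrame([
    {
        "snippet_index": 0,
        "message_index_within_snippet": 0,
        "task_name": task,
        "input_message": json.dumps({"role": "user", "content": "I moved to Lisbon in May."}),
        "input_previous_messages": json.dumps([]),
    }
    for task in ("extract_nodes", "dedupe_nodes", "extract_edges")
])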
async def ingest_and_label_minidataset(llm_client, minidataset_df, output_column_name):
    # Add a new column with the specified name, initialized with empty values
    minidataset_df[output_column_name] = None

    minidataset_labelled_df = None
    for snippet_index in sorted(minidataset_df['snippet_index'].unique()):
        snippet_df = minidataset_df[minidataset_df['snippet_index'] == snippet_index]

        # Pass the output column name to the ingest_and_label_snippet function
        snippet_df_labelled = await ingest_and_label_snippet(llm_client, snippet_df, output_column_name)

        if minidataset_labelled_df is None:
            minidataset_labelled_df = snippet_df_labelled
        else:
            minidataset_labelled_df = pd.concat([minidataset_labelled_df, snippet_df_labelled])

    return minidataset_labelled_df
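
A sketch of a driver for the full mini dataset, assuming graphiti_core's OpenAIClient with OPENAI_API_KEY set in the environment (any LLMClient implementation should work) and reusing example_snippet_df from the sketch above; the column name "baseline_output" is arbitrary:

import asyncio
from graphiti_core.llm_client import OpenAIClient

async def main():
    llm_client = OpenAIClient()  # assumed to pick up OPENAI_API_KEY from the environment
    labelled_df = await ingest_and_label_minidataset(llm_client, example_snippet_df, "baseline_output")
    labelled_df.to_csv("labelled_minidataset.csv", index=False)

asyncio.run(main())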
def entity_to_dict(entity):
    """
    Convert an entity object to a dictionary, handling datetime serialization.

    Note: vars() returns the entity's own attribute dict, so datetime fields are
    converted to strings in place on the original object.
    """
    entity_dict = vars(entity)
    for key, value in entity_dict.items():
        if isinstance(value, datetime):
            entity_dict[key] = value.isoformat()  # Convert datetime to ISO 8601 string
    return entity_dict


def dict_to_entity(entity_dict, entity_class):
    """
    Convert a dictionary back to an entity object, handling datetime deserialization.
    """
    for key, value in entity_dict.items():
        try:
            # Attempt to parse strings back to datetime objects
            entity_dict[key] = datetime.fromisoformat(value)
        except (ValueError, TypeError):
            # If parsing fails, keep the original value
            pass
    return entity_class(**entity_dict)
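
A round-trip sketch for the two helpers, built on the functions above with a made-up message; since both helpers edit their inputs in place, the dict is copied before reconstruction:

episode, _ = create_episodes_from_messages({"role": "user", "content": "I moved to Lisbon in May."}, [])
as_dict = entity_to_dict(episode)                        # datetime fields -> ISO 8601 strings
restored = dict_to_entity(dict(as_dict), EpisodicNode)   # copy, since dict_to_entity mutates its argument
assert isinstance(restored.valid_at, datetime)
assert restored.content == "user: I moved to Lisbon in May."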