add partial eval platform #252

Draft
wants to merge 1 commit into main
6 changes: 6 additions & 0 deletions .gitignore
@@ -164,3 +164,9 @@ cython_debug/
## Other
# Cache files
cache.db*

# LongMemEval data
longmemeval_data/

# All DS_Store files
.DS_Store
1,174 changes: 1,174 additions & 0 deletions tests/evals/data/LongMemEval_Snippetization.ipynb

Large diffs are not rendered by default.

2,124 changes: 2,124 additions & 0 deletions tests/evals/data/LongMemEval_mini_dataset_loading.ipynb

Large diffs are not rendered by default.

495 changes: 495 additions & 0 deletions tests/evals/data/lme_dataset_filtered.csv

Large diffs are not rendered by default.

166 changes: 166 additions & 0 deletions tests/evals/data/utils.py
@@ -0,0 +1,166 @@
import pandas as pd
from datetime import datetime, timedelta
import json
from graphiti_core.nodes import EpisodicNode, EpisodeType, EntityNode

def create_episodes_from_messages(input_message, input_previous_messages):
"""
Create an episode and a list of previous episodes from input messages.
"""
# Current time for the episode
current_time = datetime.now()

# Create the current episode
role = input_message["role"]
content = input_message["content"]
message_content = f"{role}: {content}"
episode = EpisodicNode(
name="",
group_id="",
source=EpisodeType.message,
type=EpisodeType.message,
source_description="",
content=message_content,
valid_at=current_time,
)

# Create previous episodes
num_previous_messages = len(input_previous_messages)
previous_times = [current_time - timedelta(minutes=num_previous_messages - i) for i in range(num_previous_messages)]
previous_episodes = [
EpisodicNode(
name="",
group_id="",
source=EpisodeType.message,
source_description="",
content=f"{message['role']}: {message['content']}",
valid_at=previous_time,
)
for message, previous_time in zip(input_previous_messages, previous_times)
]

return episode, previous_episodes

async def ingest_and_label_snippet(llm_client, snippet_df, output_column_name):
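"""
Run the node/edge extraction tasks over one snippet and write the JSON-serialized
output of each task row into the `output_column_name` column of `snippet_df`.
"""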
# Import necessary functions
from graphiti_core.utils.maintenance.node_operations import extract_nodes, resolve_extracted_nodes
from graphiti_core.utils.maintenance.edge_operations import extract_edges

# Loop through each unique message_index_within_snippet in sorted order
for message_index in sorted(snippet_df['message_index_within_snippet'].unique()):
message_df = snippet_df[snippet_df['message_index_within_snippet'] == message_index]

#### Process 'extract_nodes' task
extract_nodes_row = message_df[message_df['task_name'] == 'extract_nodes']
assert len(extract_nodes_row) == 1, f"There should be exactly one row for 'extract_nodes' but there are {len(extract_nodes_row)}"
input_message = json.loads(extract_nodes_row.iloc[0]['input_message'])
input_previous_messages = json.loads(extract_nodes_row.iloc[0]['input_previous_messages'])
episode, previous_episodes = create_episodes_from_messages(input_message, input_previous_messages)
extracted_nodes = await extract_nodes(llm_client, episode, previous_episodes)
snippet_df.at[extract_nodes_row.index[0], output_column_name] = json.dumps([entity_to_dict(node) for node in extracted_nodes])

#### Process 'dedupe_nodes' task
dedupe_nodes_row = message_df[message_df['task_name'] == 'dedupe_nodes']
assert len(dedupe_nodes_row) == 1, f"There should be exactly one row for 'dedupe_nodes' but there are {len(dedupe_nodes_row)}"

# Calculate existing nodes list
existing_nodes = []
for prev_message_index in sorted(snippet_df['message_index_within_snippet'].unique()):
if prev_message_index >= message_index:
break

# Filter for previous messages with 'extract_nodes' task
prev_message_df = snippet_df[
(snippet_df['message_index_within_snippet'] == prev_message_index) &
(snippet_df['task_name'] == 'extract_nodes')
]

# Retrieve and deserialize the nodes
serialized_nodes = prev_message_df.iloc[0][output_column_name]
node_dicts = json.loads(serialized_nodes)
nodes = [dict_to_entity(node_dict, EntityNode) for node_dict in node_dicts]
existing_nodes.extend(nodes)

existing_nodes_lists = [existing_nodes for _ in range(len(extracted_nodes))]
resolved_nodes, uuid_map = await resolve_extracted_nodes(llm_client, extracted_nodes, existing_nodes_lists, episode, previous_episodes)
snippet_df.at[dedupe_nodes_row.index[0], output_column_name] = json.dumps([entity_to_dict(node) for node in resolved_nodes])

#### Process 'extract_edges' task
extract_edges_row = message_df[message_df['task_name'] == 'extract_edges']
assert len(extract_edges_row) == 1, f"There should be exactly one row for 'extract_edges' but there are {len(extract_edges_row)}"
extracted_edges = await extract_edges(
llm_client,
episode,
extracted_nodes,
previous_episodes,
group_id='',
)
snippet_df.at[extract_edges_row.index[0], output_column_name] = json.dumps([entity_to_dict(edge) for edge in extracted_edges])

########## TODO: Complete the implementation of the below

#### Process 'dedupe_edges' task
# dedupe_edges_row = message_df[message_df['task_name'] == 'dedupe_edges']
# assert len(dedupe_edges_row) == 1, "There should be exactly one row for 'dedupe_edges'"
# output = dedupe_extracted_edge(
# llm_client,
# extracted_edge,
# related_edges,
# )
# snippet_df.at[dedupe_edges_row.index[0], output_column_name] = output

#### Process 'extract_edge_dates' task
# extract_edge_dates_row = message_df[message_df['task_name'] == 'extract_edge_dates']
# assert len(extract_edge_dates_row) == 1, "There should be exactly one row for 'extract_edge_dates'"
# output = extract_edge_dates(extract_edge_dates_row.iloc[0]['input_extracted_edge_dates'])
# snippet_df.at[extract_edge_dates_row.index[0], output_column_name] = output

#### Process 'edge_invalidation' task
# edge_invalidation_row = message_df[message_df['task_name'] == 'edge_invalidation']
# assert len(edge_invalidation_row) == 1, "There should be exactly one row for 'edge_invalidation'"
# output = edge_invalidation(edge_invalidation_row.iloc[0]['input_edge_invalidation'])
# snippet_df.at[edge_invalidation_row.index[0], output_column_name] = output

return snippet_df


async def ingest_and_label_minidataset(llm_client, minidataset_df, output_column_name):
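"""
Label every snippet in the mini-dataset by running ingest_and_label_snippet on each
one, then concatenate the labelled snippets into a single DataFrame.
"""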
# Add a new column with the specified name, initialized with empty values
minidataset_df[output_column_name] = None

minidataset_labelled_df = None
for snippet_index in sorted(minidataset_df['snippet_index'].unique()):
snippet_df = minidataset_df[minidataset_df['snippet_index'] == snippet_index]

# Pass the output column name to the ingest_and_label_snippet function
snippet_df_labelled = await ingest_and_label_snippet(llm_client, snippet_df, output_column_name)

if minidataset_labelled_df is None:
minidataset_labelled_df = snippet_df_labelled
else:
minidataset_labelled_df = pd.concat([minidataset_labelled_df, snippet_df_labelled])

return minidataset_labelled_df

def entity_to_dict(entity):
"""
Convert an entity object to a dictionary, handling datetime serialization.
"""
entity_dict = vars(entity)
for key, value in entity_dict.items():
if isinstance(value, datetime):
entity_dict[key] = value.isoformat() # Convert datetime to ISO 8601 string
return entity_dict

def dict_to_entity(entity_dict, entity_class):
"""
Convert a dictionary back to an entity object, handling datetime deserialization.
"""
for key, value in entity_dict.items():
try:
# Attempt to parse strings back to datetime objects
entity_dict[key] = datetime.fromisoformat(value)
except (ValueError, TypeError):
# If parsing fails, keep the original value
pass
return entity_class(**entity_dict)
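
A minimal usage sketch of these helpers. The CSV path and output column name are hypothetical placeholders; it assumes a mini-dataset produced by the notebooks above (with the snippet_index, message_index_within_snippet, task_name, input_message, and input_previous_messages columns) and an OPENAI_API_KEY in the environment.

import asyncio
import os

import pandas as pd

from graphiti_core.llm_client import OpenAIClient
from graphiti_core.llm_client.config import LLMConfig
from tests.evals.data.utils import ingest_and_label_minidataset


async def main():
    # Hypothetical path: the mini-dataset CSV written by LongMemEval_mini_dataset_loading.ipynb
    minidataset_df = pd.read_csv('tests/evals/data/minidataset.csv')

    llm_client = OpenAIClient(
        config=LLMConfig(api_key=os.getenv('OPENAI_API_KEY'), model='gpt-4o-mini')
    )

    # Runs extract_nodes / resolve_extracted_nodes / extract_edges per message and
    # stores the JSON-serialized outputs in the new column.
    labelled_df = await ingest_and_label_minidataset(llm_client, minidataset_df, 'baseline_output')
    labelled_df.to_csv('labelled_minidataset.csv', index=False)


asyncio.run(main())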
151 changes: 151 additions & 0 deletions tests/evals/eval_extract_nodes.py
@@ -0,0 +1,151 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import json
from tests.evals.utils import setup_logging
from datetime import datetime, timedelta

import pytest
from dotenv import load_dotenv

from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.graphiti import Graphiti
from graphiti_core.nodes import EntityNode, EpisodicNode

from graphiti_core.utils.maintenance.node_operations import extract_nodes
from graphiti_core.llm_client import OpenAIClient
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.nodes import EpisodeType

import csv




############# EVERYTHING BELOW IS OUTDATED

# Setup
load_dotenv()
pytestmark = pytest.mark.integration
pytest_plugins = ('pytest_asyncio',)
logger = setup_logging()


async def general_extract_nodes_test(llm_client, data_sample):
episode = data_sample['episode']
previous_episodes = data_sample['previous_episodes']
gold_answer_names = data_sample['gold_answer_names']

hypothesis_nodes = await extract_nodes(llm_client, episode, previous_episodes)
hypothesis_node_names = [node.name for node in hypothesis_nodes]

# Sort both lists by node name
hypothesis_node_names.sort()
gold_answer_names.sort()

# assert hypothesis_node_names == gold_answer_names, \
# f"""Test Failed. Expected nodes: {gold_answer_names}. Got: {hypothesis_node_names}"""

return hypothesis_node_names





def prepare_data_from_csv(data_file_name, question_id, session_idx, message_idx):
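"""
Load eval samples from tests/evals/data/<data_file_name>.csv and build an
(episode, previous_episodes) pair per row. The question_id / session_idx /
message_idx filters are not applied yet; every row in the CSV is returned.
"""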

samples_csv_path = "tests/evals/data/" + data_file_name + ".csv"

# From CSV path, load everything
with open(samples_csv_path, 'r') as file:
csv_reader = csv.DictReader(file)
lme_samples = list(csv_reader)


data_samples = []

# Loop through each row
for row in lme_samples:

### Prepare episode
current_time = datetime.now()
message = json.loads(row["message"])
role = message["role"]
content = message["content"]
message_content = role + ": " + content
episode = EpisodicNode(
name="",
group_id="",
source=EpisodeType.message,
type=EpisodeType.message,
source_description="",
content=message_content,
valid_at=current_time,
)

### Prepare previous episodes
previous_messages = json.loads(row["previous_messages"])
num_previous_messages = len(previous_messages)
previous_times = [current_time - timedelta(minutes=num_previous_messages-i) for i in range(num_previous_messages)]
previous_episodes = [EpisodicNode(
name="",
group_id="",
source=EpisodeType.message,
source_description="",
content=message["role"] + ": " + message["content"],
valid_at=previous_time,
) for message, previous_time in zip(previous_messages, previous_times)]

### TODO: Prepare gold answer names

### Add to data samples list
data_samples.append({
"episode": episode,
"previous_episodes": previous_episodes,
"gold_answer_names": [],
})

return data_samples





@pytest.mark.asyncio
async def test_extract_nodes():
model_name = 'gpt-4o-mini'
llm_config = LLMConfig(
api_key=os.getenv('OPENAI_API_KEY'),
model=model_name,
)
llm_client = OpenAIClient(config=llm_config)

data_file_name = 'output_short'
question_id = "gpt4_2655b836"
session_idx = 0
message_idx = 0
data_samples = prepare_data_from_csv(data_file_name, question_id, session_idx, message_idx)

for data_sample in data_samples:
print(f"\n\nEpisode: {data_sample['episode']}")
print("*"*50)
print(f"Previous Episodes: {data_sample['previous_episodes']}")
print("*"*50)
# print(f"Gold Answer Names: {gold_answer_names}")

await general_extract_nodes_test(llm_client, data_sample)
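
A minimal sketch for driving this eval directly, outside the pytest runner (it assumes OPENAI_API_KEY is set and that tests/evals/data/output_short.csv exists):

import asyncio

from tests.evals.eval_extract_nodes import test_extract_nodes

# Same flow pytest would run: load the CSV samples, build the episodes,
# and call extract_nodes on each sample with gpt-4o-mini.
asyncio.run(test_extract_nodes())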

4 changes: 4 additions & 0 deletions tests/evals/pytest.ini
@@ -0,0 +1,4 @@
[pytest]
asyncio_default_fixture_loop_scope = function
markers =
integration: marks tests as integration tests
23 changes: 23 additions & 0 deletions tests/evals/utils.py
@@ -0,0 +1,23 @@
import logging
import sys


def setup_logging():
# Create a logger
logger = logging.getLogger()
logger.setLevel(logging.INFO) # Set the logging level to INFO

# Create console handler and set level to INFO
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)

# Create formatter
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# Add formatter to console handler
console_handler.setFormatter(formatter)

# Add console handler to logger
logger.addHandler(console_handler)

return logger
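
A minimal usage sketch (the log message is illustrative):

from tests.evals.utils import setup_logging

logger = setup_logging()
logger.info('starting eval run')  # printed to stdout as 'timestamp - root - INFO - starting eval run'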