Skip to content

Commit

Permalink
feat: Initial version of temporal invalidation + tests (#8)
Browse files Browse the repository at this point in the history
* feat: Initial version of temporal invalidation + tests

* fix: don't run int tests on CI

* fix: don't run int tests on CI

* fix: don't run int tests on CI

* fix: time of day issue

* fix: running non int tests in ci

* fix: running non int tests in ci

* fix: running non int tests in ci

* fix: running non int tests in ci

* fix: running non int tests in ci

* fix: running non int tests in ci

* fix: running non int tests in ci

* revert: Tests structural changes

* chore: Remove idea file

* chore: Get rid of NodesWithEdges class and define a triplet type instead
  • Loading branch information
paul-paliychuk authored Aug 20, 2024
1 parent 40e74a2 commit a6fd0dd
Show file tree
Hide file tree
Showing 21 changed files with 895 additions and 3,091 deletions.
39 changes: 39 additions & 0 deletions .github/workflows/unit_tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# CI workflow: run the non-integration (unit) test suite on pushes and PRs to main.
name: Unit Tests

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.11'
      # Cache the Poetry installation itself (~/.local), keyed on the lockfile.
      - name: Load cached Poetry installation
        uses: actions/cache@v3
        with:
          path: ~/.local
          key: poetry-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }}
      - name: Install Poetry
        uses: snok/install-poetry@v1
        with:
          virtualenvs-create: true
          virtualenvs-in-project: true
      # Cache the project virtualenv (.venv) so dependency installs are fast on cache hit.
      - name: Load cached dependencies
        uses: actions/cache@v3
        with:
          path: .venv
          key: venv-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }}
      - name: Install dependencies
        run: poetry install --no-interaction --no-root
      - name: Run non-integration tests
        env:
          # Project root on PYTHONPATH so `core.*` imports resolve under pytest.
          PYTHONPATH: ${{ github.workspace }}
        run: |
          poetry run pytest -m "not integration"
4 changes: 2 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -158,5 +158,5 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

.idea/
.vscode/
6 changes: 6 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import os
import sys

# Put the project root (the directory containing this conftest.py) at the front
# of sys.path so `core.*` imports resolve when pytest runs from any directory.
# Without this, test collection can fail with ModuleNotFoundError for project
# modules. (`os.path.join` with a single argument was a no-op and is dropped.)
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
57 changes: 32 additions & 25 deletions core/graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,22 @@

from core.llm_client.config import EMBEDDING_DIM
from core.nodes import EntityNode, EpisodicNode, Node
from core.edges import EntityEdge, Edge, EpisodicEdge
from core.edges import EntityEdge, EpisodicEdge
from core.utils import (
build_episodic_edges,
retrieve_relevant_schema,
extract_new_edges,
extract_new_nodes,
clear_data,
retrieve_episodes,
)
from core.llm_client import LLMClient, OpenAIClient, LLMConfig
from core.utils.maintenance.edge_operations import extract_edges, dedupe_extracted_edges
from core.utils.maintenance.edge_operations import (
extract_edges,
dedupe_extracted_edges,
)

from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
from core.utils.maintenance.temporal_operations import (
prepare_edges_for_invalidation,
invalidate_edges,
)
from core.utils.search.search_utils import (
edge_similarity_search,
entity_fulltext_search,
Expand Down Expand Up @@ -59,21 +63,6 @@ async def retrieve_episodes(
"""Retrieve the last n episodic nodes from the graph"""
return await retrieve_episodes(self.driver, last_n, sources)

async def retrieve_relevant_schema(self, query: str = None) -> dict[str, any]:
"""Retrieve relevant nodes and edges to a specific query"""
return await retrieve_relevant_schema(self.driver, query)
...

# Invalidate edges that are no longer valid
async def invalidate_edges(
self,
episode: EpisodicNode,
new_nodes: list[EntityNode],
new_edges: list[EntityEdge],
relevant_schema: dict[str, any],
previous_episodes: list[EpisodicNode],
): ...

async def add_episode(
self,
name: str,
Expand Down Expand Up @@ -102,7 +91,6 @@ async def add_episode(
created_at=now,
valid_at=reference_time,
)
# relevant_schema = await self.retrieve_relevant_schema(episode.content)

extracted_nodes = await extract_nodes(
self.llm_client, episode, previous_episodes
Expand Down Expand Up @@ -139,13 +127,32 @@ async def add_episode(
f"Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}"
)

new_edges = await dedupe_extracted_edges(
deduped_edges = await dedupe_extracted_edges(
self.llm_client, extracted_edges, existing_edges
)

logger.info(f"Deduped edges: {[(e.name, e.uuid) for e in new_edges]}")
(
old_edges_with_nodes_pending_invalidation,
new_edges_with_nodes,
) = prepare_edges_for_invalidation(
existing_edges=existing_edges, new_edges=deduped_edges, nodes=nodes
)

invalidated_edges = await invalidate_edges(
self.llm_client,
old_edges_with_nodes_pending_invalidation,
new_edges_with_nodes,
)

entity_edges.extend(invalidated_edges)

logger.info(
f"Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}"
)

logger.info(f"Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}")

entity_edges.extend(new_edges)
entity_edges.extend(deduped_edges)
episodic_edges.extend(
build_episodic_edges(
# There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them
Expand Down
4 changes: 3 additions & 1 deletion core/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ class EntityNode(Node):
name_embedding: list[float] | None = Field(
default=None, description="embedding of the name"
)
summary: str = Field(description="regional summary of surrounding edges")
summary: str = Field(
description="regional summary of surrounding edges", default_factory=str
)

async def update_summary(self, driver: AsyncDriver): ...

Expand Down
50 changes: 50 additions & 0 deletions core/prompts/invalidate_edges.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Any, Protocol, TypedDict

from .models import Message, PromptFunction, PromptVersion


class Prompt(Protocol):
    """Structural interface: an invalidate-edges prompt exposes a ``v1`` version."""

    v1: PromptVersion


class Versions(TypedDict):
    """Mapping of version name to prompt-building function for this prompt family."""

    v1: PromptFunction


def v1(context: dict[str, Any]) -> list[Message]:
    """Build the v1 edge-invalidation prompt.

    ``context`` must provide pre-rendered ``existing_edges`` and ``new_edges``
    strings, each edge formatted as
    "UUID | SOURCE_NODE - EDGE_NAME - TARGET_NODE (TIMESTAMP)".
    Returns a system/user message pair asking the LLM to emit a JSON object
    listing which existing edges to invalidate (``typing.Any`` replaces the
    builtin ``any``, which is a function, not a type).
    """
    return [
        Message(
            role="system",
            content="You are an AI assistant that helps determine which relationships in a knowledge graph should be invalidated based on newer information.",
        ),
        Message(
            role="user",
            content=f"""
Based on the provided existing edges and new edges with their timestamps, determine which existing relationships, if any, should be invalidated due to contradictions or updates in the new edges.
Only mark a relationship as invalid if there is clear evidence from new edges that the relationship is no longer true.
Do not invalidate relationships merely because they weren't mentioned in new edges.
Existing Edges (sorted by timestamp, newest first):
{context['existing_edges']}
New Edges:
{context['new_edges']}
Each edge is formatted as: "UUID | SOURCE_NODE - EDGE_NAME - TARGET_NODE (TIMESTAMP)"
For each existing edge that should be invalidated, respond with a JSON object in the following format:
{{
    "invalidated_edges": [
        {{
            "edge_uuid": "The UUID of the edge to be invalidated (the part before the | character)",
            "reason": "Brief explanation of why this edge is being invalidated"
        }}
    ]
}}
If no relationships need to be invalidated, return an empty list for "invalidated_edges".
""",
        ),
    ]


versions: Versions = {"v1": v1}
9 changes: 9 additions & 0 deletions core/prompts/lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,27 @@
versions as dedupe_edges_versions,
)

from .invalidate_edges import (
Prompt as InvalidateEdgesPrompt,
Versions as InvalidateEdgesVersions,
versions as invalidate_edges_versions,
)


class PromptLibrary(Protocol):
extract_nodes: ExtractNodesPrompt
dedupe_nodes: DedupeNodesPrompt
extract_edges: ExtractEdgesPrompt
dedupe_edges: DedupeEdgesPrompt
invalidate_edges: InvalidateEdgesPrompt


class PromptLibraryImpl(TypedDict):
extract_nodes: ExtractNodesVersions
dedupe_nodes: DedupeNodesVersions
extract_edges: ExtractEdgesVersions
dedupe_edges: DedupeEdgesVersions
invalidate_edges: InvalidateEdgesVersions


class VersionWrapper:
Expand Down Expand Up @@ -66,6 +74,7 @@ def __init__(self, library: PromptLibraryImpl):
"dedupe_nodes": dedupe_nodes_versions,
"extract_edges": extract_edges_versions,
"dedupe_edges": dedupe_edges_versions,
"invalidate_edges": invalidate_edges_versions,
}

prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL)
2 changes: 0 additions & 2 deletions core/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
build_episodic_edges,
extract_new_nodes,
clear_data,
retrieve_relevant_schema,
retrieve_episodes,
)

Expand All @@ -12,6 +11,5 @@
"build_episodic_edges",
"extract_new_nodes",
"clear_data",
"retrieve_relevant_schema",
"retrieve_episodes",
]
3 changes: 1 addition & 2 deletions core/utils/maintenance/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from .node_operations import extract_new_nodes
from .graph_data_operations import (
clear_data,
retrieve_relevant_schema,
retrieve_episodes,
)

Expand All @@ -11,6 +10,6 @@
"build_episodic_edges",
"extract_new_nodes",
"clear_data",
"retrieve_relevant_schema",
"retrieve_episodes",
"invalidate_edges",
]
2 changes: 2 additions & 0 deletions core/utils/maintenance/edge_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from typing import List
from datetime import datetime

from pydantic import BaseModel

from core.nodes import EntityNode, EpisodicNode
from core.edges import EpisodicEdge, EntityEdge
import logging
Expand Down
46 changes: 0 additions & 46 deletions core/utils/maintenance/graph_data_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,52 +17,6 @@ async def delete_all(tx):
await session.execute_write(delete_all)


async def retrieve_relevant_schema(
    driver: AsyncDriver, query: str | None = None
) -> dict[str, any]:
    """Summarize the graph as ``{"nodes": {...}, "relationships": [...]}``.

    Runs a single Cypher scan over all nodes and their outgoing relationships,
    then folds the records into a per-node map plus a flat relationship list.

    NOTE(review): the ``query`` parameter is never used below — the Cypher
    statement scans the whole graph regardless. Confirm whether filtering by
    ``query`` was intended. (Implicit Optional ``str = None`` is now spelled
    ``str | None = None``, matching the file's other signatures.)
    """
    async with driver.session() as session:
        summary_query = """
        MATCH (n)
        OPTIONAL MATCH (n)-[r]->(m)
        RETURN DISTINCT labels(n) AS node_labels, n.uuid AS node_uuid, n.name AS node_name,
        type(r) AS relationship_type, r.name AS relationship_name, m.name AS related_node_name
        """
        result = await session.run(summary_query)
        records = [record async for record in result]

        schema = {"nodes": {}, "relationships": []}

        for record in records:
            node_label = record["node_labels"][0]  # assumes one label per node
            node_uuid = record["node_uuid"]
            node_name = record["node_name"]
            rel_type = record["relationship_type"]
            rel_name = record["relationship_name"]
            related_node = record["related_node_name"]

            # First sighting of a node: register it in the schema map.
            if node_name not in schema["nodes"]:
                schema["nodes"][node_name] = {
                    "uuid": node_uuid,
                    "label": node_label,
                    "relationships": [],
                }

            # OPTIONAL MATCH yields nulls for nodes with no outgoing edges;
            # only record a relationship when one actually exists.
            if rel_type and related_node:
                schema["nodes"][node_name]["relationships"].append(
                    {"type": rel_type, "name": rel_name, "target": related_node}
                )
                schema["relationships"].append(
                    {
                        "source": node_name,
                        "type": rel_type,
                        "name": rel_name,
                        "target": related_node,
                    }
                )

        return schema


async def retrieve_episodes(
driver: AsyncDriver, last_n: int, sources: list[str] | None = "messages"
) -> list[EpisodicNode]:
Expand Down
Loading

0 comments on commit a6fd0dd

Please sign in to comment.