From 35711c46ced3c0e7e9ece29219f3fd80dffc125d Mon Sep 17 00:00:00 2001
From: Daniel Chalef <131175+danielchalef@users.noreply.github.com>
Date: Fri, 23 Aug 2024 08:15:44 -0700
Subject: [PATCH] chore: Fix Typing Issues (#27)

* typing.Any and friends

* message

* chore: Import Message model in llm_client

* fix: :lipstick: mypy errors

* clean up mypy stuff

* mypy

* format

* mypy

* mypy

* mypy

---------

Co-authored-by: paulpaliychuk
Co-authored-by: prestonrasmussen
---
 Makefile                                      |   2 +-
 core/graphiti.py                              |  44 ++-
 core/llm_client/client.py                     |   8 +-
 core/llm_client/openai_client.py              |  19 +-
 core/prompts/dedupe_edges.py                  |   9 +-
 core/prompts/dedupe_nodes.py                  |   8 +-
 core/prompts/extract_edges.py                 |   6 +-
 core/prompts/extract_nodes.py                 |   8 +-
 core/prompts/invalidate_edges.py              |   4 +-
 core/prompts/lib.py                           |   9 +-
 core/prompts/models.py                        |   6 +-
 core/search/search.py                         |  18 +-
 core/search/search_utils.py                   |  23 +-
 core/utils/__init__.py                        |   8 +-
 core/utils/bulk_utils.py                      |  10 +-
 core/utils/maintenance/__init__.py            |   8 +-
 core/utils/maintenance/edge_operations.py     | 138 ---------
 .../maintenance/graph_data_operations.py      |  14 +-
 core/utils/maintenance/node_operations.py     |  63 +---
 core/utils/maintenance/temporal_operations.py |   9 +-
 core/utils/search/search_utils.py             | 292 ------------------
 core/utils/utils.py                           |   4 +-
 pyproject.toml                                |   7 +-
 tests/tests_int_graphiti.py                   |   4 +-
 24 files changed, 134 insertions(+), 587 deletions(-)
 delete mode 100644 core/utils/search/search_utils.py

diff --git a/Makefile b/Makefile
index 23bb46a0..8ef4ca3d 100644
--- a/Makefile
+++ b/Makefile
@@ -22,7 +22,7 @@ format:
 # Lint code
 lint:
 	$(RUFF) check
-	$(MYPY) . --show-column-numbers --show-error-codes --pretty
+	$(MYPY) ./core --show-column-numbers --show-error-codes --pretty
 
 # Run tests
 test:
diff --git a/core/graphiti.py b/core/graphiti.py
index bdfaafbd..cc992d8a 100644
--- a/core/graphiti.py
+++ b/core/graphiti.py
@@ -56,7 +56,7 @@ def __init__(self, uri: str, user: str, password: str, llm_client: LLMClient | N
         else:
             self.llm_client = OpenAIClient(
                 LLMConfig(
-                    api_key=os.getenv('OPENAI_API_KEY'),
+                    api_key=os.getenv('OPENAI_API_KEY', default=''),
                     model='gpt-4o-mini',
                     base_url='https://api.openai.com/v1',
                 )
@@ -72,28 +72,16 @@ async def retrieve_episodes(
         self,
         reference_time: datetime,
         last_n: int = EPISODE_WINDOW_LEN,
-        sources: list[str] | None = 'messages',
     ) -> list[EpisodicNode]:
         """Retrieve the last n episodic nodes from the graph"""
-        return await retrieve_episodes(self.driver, reference_time, last_n, sources)
-
-    # Invalidate edges that are no longer valid
-    async def invalidate_edges(
-        self,
-        episode: EpisodicNode,
-        new_nodes: list[EntityNode],
-        new_edges: list[EntityEdge],
-        relevant_schema: dict[str, any],
-        previous_episodes: list[EpisodicNode],
-    ): ...
+        return await retrieve_episodes(self.driver, reference_time, last_n)
 
     async def add_episode(
         self,
         name: str,
         episode_body: str,
         source_description: str,
-        reference_time: datetime | None = None,
-        episode_type: str | None = 'string',  # TODO: this field isn't used yet?
+        reference_time: datetime,
         success_callback: Callable | None = None,
         error_callback: Callable | None = None,
     ):
@@ -104,7 +92,7 @@ async def add_episode(
         nodes: list[EntityNode] = []
         entity_edges: list[EntityEdge] = []
         episodic_edges: list[EpisodicEdge] = []
-        embedder = self.llm_client.client.embeddings
+        embedder = self.llm_client.get_embedder()
         now = datetime.now()
 
         previous_episodes = await self.retrieve_episodes(reference_time)
@@ -234,7 +222,7 @@ async def add_episode_bulk(
     ):
         try:
             start = time()
-            embedder = self.llm_client.client.embeddings
+            embedder = self.llm_client.get_embedder()
             now = datetime.now()
 
             episodes = [
@@ -276,14 +264,22 @@ async def add_episode_bulk(
             await asyncio.gather(*[node.save(self.driver) for node in nodes])
 
             # re-map edge pointers so that they don't point to discard dupe nodes
-            extracted_edges: list[EntityEdge] = resolve_edge_pointers(extracted_edges, uuid_map)
-            episodic_edges: list[EpisodicEdge] = resolve_edge_pointers(episodic_edges, uuid_map)
+            extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
+                extracted_edges, uuid_map
+            )
+            episodic_edges_with_resolved_pointers: list[EpisodicEdge] = resolve_edge_pointers(
+                episodic_edges, uuid_map
+            )
 
             # save episodic edges to KG
-            await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
+            await asyncio.gather(
+                *[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers]
+            )
 
             # Dedupe extracted edges
-            edges = await dedupe_edges_bulk(self.driver, self.llm_client, extracted_edges)
+            edges = await dedupe_edges_bulk(
+                self.driver, self.llm_client, extracted_edges_with_resolved_pointers
+            )
             logger.info(f'extracted edge length: {len(edges)}')
 
             # invalidate edges
@@ -302,12 +298,12 @@ async def search(self, query: str, num_results=10):
         edges = (
             await hybrid_search(
                 self.driver,
-                self.llm_client.client.embeddings,
+                self.llm_client.get_embedder(),
                 query,
                 datetime.now(),
                 search_config,
             )
-        )['edges']
+        ).edges
 
         facts = [edge.fact for edge in edges]
 
@@ -315,5 +311,5 @@ async def search(self, query: str, num_results=10):
 
     async def _search(self, query: str, timestamp: datetime, config: SearchConfig):
         return await hybrid_search(
-            self.driver, self.llm_client.client.embeddings, query, timestamp, config
+            self.driver, self.llm_client.get_embedder(), query, timestamp, config
         )
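A note on the api_key change above: os.getenv returns str | None when called without a default, so passing the bare result to a parameter annotated as str fails under mypy; supplying default='' narrows the return type to str. A minimal sketch of the two cases (the configure signature is a stand-in for illustration, not the repo's API):

    import os

    def configure(api_key: str) -> None:
        ...

    configure(os.getenv('OPENAI_API_KEY'))              # mypy: incompatible type "str | None"; expected "str"
    configure(os.getenv('OPENAI_API_KEY', default=''))  # OK: the str default makes the return type str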
diff --git a/core/llm_client/client.py b/core/llm_client/client.py
index 85205d43..911c9fdc 100644
--- a/core/llm_client/client.py
+++ b/core/llm_client/client.py
@@ -1,5 +1,7 @@
+import typing
 from abc import ABC, abstractmethod
 
+from ..prompts.models import Message
 from .config import LLMConfig
 
 
@@ -9,5 +11,9 @@ def __init__(self, config: LLMConfig):
         pass
 
     @abstractmethod
-    async def generate_response(self, messages: list[dict[str, str]]) -> dict[str, any]:
+    def get_embedder(self) -> typing.Any:
+        pass
+
+    @abstractmethod
+    async def generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
        pass
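The dict[str, any] → dict[str, typing.Any] change here recurs throughout the patch: lowercase any is the builtin function, not a type, and mypy rejects it in annotations. A small sketch of the difference:

    from typing import Any

    def broken(context: dict[str, any]) -> None:  # mypy: Function "builtins.any" is not valid as a type
        ...

    def fixed(context: dict[str, Any]) -> None:   # OK: typing.Any is the explicit "anything" type
        ...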
diff --git a/core/llm_client/openai_client.py b/core/llm_client/openai_client.py
index b0f46a83..feb096fb 100644
--- a/core/llm_client/openai_client.py
+++ b/core/llm_client/openai_client.py
@@ -1,8 +1,11 @@
 import json
 import logging
+import typing
 
 from openai import AsyncOpenAI
+from openai.types.chat import ChatCompletionMessageParam
 
+from ..prompts.models import Message
 from .client import LLMClient
 from .config import LLMConfig
 
@@ -14,16 +17,26 @@ def __init__(self, config: LLMConfig):
         self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
         self.model = config.model
 
-    async def generate_response(self, messages: list[dict[str, str]]) -> dict[str, any]:
+    def get_embedder(self) -> typing.Any:
+        return self.client.embeddings
+
+    async def generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
+        openai_messages: list[ChatCompletionMessageParam] = []
+        for m in messages:
+            if m.role == 'user':
+                openai_messages.append({'role': 'user', 'content': m.content})
+            elif m.role == 'system':
+                openai_messages.append({'role': 'system', 'content': m.content})
         try:
             response = await self.client.chat.completions.create(
                 model=self.model,
-                messages=messages,
+                messages=openai_messages,
                 temperature=0.1,
                 max_tokens=3000,
                 response_format={'type': 'json_object'},
             )
-            return json.loads(response.choices[0].message.content)
+            result = response.choices[0].message.content or ''
+            return json.loads(result)
         except Exception as e:
             logger.error(f'Error in generating LLM response: {e}')
             raise
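Two patterns are worth noting in generate_response. First, the OpenAI SDK types chat messages as ChatCompletionMessageParam typed dicts, so the internal pydantic Message objects are mapped into that shape explicitly rather than passed through. Second, response.choices[0].message.content is typed str | None, so it is coalesced with `or ''` before json.loads; note this satisfies the type checker rather than adding runtime safety, since json.loads('') still raises JSONDecodeError. A self-contained sketch of the conversion (Message mirrors core/prompts/models.py; to_openai_messages is an illustrative helper, not repo code):

    from pydantic import BaseModel

    class Message(BaseModel):
        role: str
        content: str

    def to_openai_messages(messages: list[Message]) -> list[dict[str, str]]:
        # Map only the roles the prompt library emits; anything else is dropped.
        return [
            {'role': m.role, 'content': m.content}
            for m in messages
            if m.role in ('user', 'system')
        ]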
diff --git a/core/prompts/dedupe_edges.py b/core/prompts/dedupe_edges.py
index b3b1069a..cb41f201 100644
--- a/core/prompts/dedupe_edges.py
+++ b/core/prompts/dedupe_edges.py
@@ -1,5 +1,5 @@
 import json
-from typing import Protocol, TypedDict
+from typing import Any, Protocol, TypedDict
 
 from .models import Message, PromptFunction, PromptVersion
 
@@ -7,6 +7,7 @@
 class Prompt(Protocol):
     v1: PromptVersion
     v2: PromptVersion
+    edge_list: PromptVersion
 
 
 class Versions(TypedDict):
@@ -15,7 +16,7 @@ class Versions(TypedDict):
     edge_list: PromptFunction
 
 
-def v1(context: dict[str, any]) -> list[Message]:
+def v1(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',
@@ -55,7 +56,7 @@ def v1(context: dict[str, any]) -> list[Message]:
     ]
 
 
-def v2(context: dict[str, any]) -> list[Message]:
+def v2(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',
@@ -97,7 +98,7 @@ def v2(context: dict[str, any]) -> list[Message]:
     ]
 
 
-def edge_list(context: dict[str, any]) -> list[Message]:
+def edge_list(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',
diff --git a/core/prompts/dedupe_nodes.py b/core/prompts/dedupe_nodes.py
index f7665770..6c3f459b 100644
--- a/core/prompts/dedupe_nodes.py
+++ b/core/prompts/dedupe_nodes.py
@@ -1,5 +1,5 @@
 import json
-from typing import Protocol, TypedDict
+from typing import Any, Protocol, TypedDict
 
 from .models import Message, PromptFunction, PromptVersion
 
@@ -16,7 +16,7 @@ class Versions(TypedDict):
     node_list: PromptVersion
 
 
-def v1(context: dict[str, any]) -> list[Message]:
+def v1(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',
@@ -56,7 +56,7 @@ def v1(context: dict[str, any]) -> list[Message]:
     ]
 
 
-def v2(context: dict[str, any]) -> list[Message]:
+def v2(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',
@@ -96,7 +96,7 @@ def v2(context: dict[str, any]) -> list[Message]:
     ]
 
 
-def node_list(context: dict[str, any]) -> list[Message]:
+def node_list(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',
diff --git a/core/prompts/extract_edges.py b/core/prompts/extract_edges.py
index f520cae3..c339f634 100644
--- a/core/prompts/extract_edges.py
+++ b/core/prompts/extract_edges.py
@@ -1,5 +1,5 @@
 import json
-from typing import Protocol, TypedDict
+from typing import Any, Protocol, TypedDict
 
 from .models import Message, PromptFunction, PromptVersion
 
@@ -14,7 +14,7 @@ class Versions(TypedDict):
     v2: PromptFunction
 
 
-def v1(context: dict[str, any]) -> list[Message]:
+def v1(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',
@@ -70,7 +70,7 @@ def v1(context: dict[str, any]) -> list[Message]:
     ]
 
 
-def v2(context: dict[str, any]) -> list[Message]:
+def v2(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',
diff --git a/core/prompts/extract_nodes.py b/core/prompts/extract_nodes.py
index 6804290b..7278568e 100644
--- a/core/prompts/extract_nodes.py
+++ b/core/prompts/extract_nodes.py
@@ -1,5 +1,5 @@
 import json
-from typing import Protocol, TypedDict
+from typing import Any, Protocol, TypedDict
 
 from .models import Message, PromptFunction, PromptVersion
 
@@ -16,7 +16,7 @@ class Versions(TypedDict):
     v3: PromptFunction
 
 
-def v1(context: dict[str, any]) -> list[Message]:
+def v1(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',
@@ -64,7 +64,7 @@ def v1(context: dict[str, any]) -> list[Message]:
     ]
 
 
-def v2(context: dict[str, any]) -> list[Message]:
+def v2(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',
@@ -105,7 +105,7 @@ def v2(context: dict[str, any]) -> list[Message]:
     ]
 
 
-def v3(context: dict[str, any]) -> list[Message]:
+def v3(context: dict[str, Any]) -> list[Message]:
     sys_prompt = """You are an AI assistant that extracts entity nodes from conversational text. Your primary task is to identify and extract the speaker and other significant entities mentioned in the conversation."""
 
     user_prompt = f"""
diff --git a/core/prompts/invalidate_edges.py b/core/prompts/invalidate_edges.py
index fea7046e..6b5667e5 100644
--- a/core/prompts/invalidate_edges.py
+++ b/core/prompts/invalidate_edges.py
@@ -1,4 +1,4 @@
-from typing import Protocol, TypedDict
+from typing import Any, Protocol, TypedDict
 
 from .models import Message, PromptFunction, PromptVersion
 
@@ -11,7 +11,7 @@ class Versions(TypedDict):
     v1: PromptFunction
 
 
-def v1(context: dict[str, any]) -> list[Message]:
+def v1(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',
diff --git a/core/prompts/lib.py b/core/prompts/lib.py
index 030c4e6d..d42914cc 100644
--- a/core/prompts/lib.py
+++ b/core/prompts/lib.py
@@ -1,4 +1,4 @@
-from typing import Protocol, TypedDict
+from typing import Any, Protocol, TypedDict
 
 from .dedupe_edges import (
     Prompt as DedupeEdgesPrompt,
@@ -68,7 +68,7 @@ class VersionWrapper:
     def __init__(self, func: PromptFunction):
         self.func = func
 
-    def __call__(self, context: dict[str, any]) -> list[Message]:
+    def __call__(self, context: dict[str, Any]) -> list[Message]:
         return self.func(context)
 
 
@@ -81,7 +81,7 @@ def __init__(self, versions: dict[str, PromptFunction]):
 class PromptLibraryWrapper:
     def __init__(self, library: PromptLibraryImpl):
         for prompt_type, versions in library.items():
-            setattr(self, prompt_type, PromptTypeWrapper(versions))
+            setattr(self, prompt_type, PromptTypeWrapper(versions))  # type: ignore[arg-type]
 
 
 PROMPT_LIBRARY_IMPL: PromptLibraryImpl = {
@@ -91,5 +91,4 @@ def __init__(self, library: PromptLibraryImpl):
     'dedupe_edges': dedupe_edges_versions,
     'invalidate_edges': invalidate_edges_versions,
 }
-
-prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL)
+prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL)  # type: ignore[assignment]
diff --git a/core/prompts/models.py b/core/prompts/models.py
index 713497ed..708a3ea5 100644
--- a/core/prompts/models.py
+++ b/core/prompts/models.py
@@ -1,4 +1,4 @@
-from typing import Callable, Protocol
+from typing import Any, Callable, Protocol
 
 from pydantic import BaseModel
 
@@ -9,7 +9,7 @@ class Message(BaseModel):
 
 
 class PromptVersion(Protocol):
-    def __call__(self, context: dict[str, any]) -> list[Message]: ...
+    def __call__(self, context: dict[str, Any]) -> list[Message]: ...
 
 
-PromptFunction = Callable[[dict[str, any]], list[Message]]
+PromptFunction = Callable[[dict[str, Any]], list[Message]]
diff --git a/core/search/search.py b/core/search/search.py
index e59ea0fc..a773d8fa 100644
--- a/core/search/search.py
+++ b/core/search/search.py
@@ -5,9 +5,9 @@
 from neo4j import AsyncDriver
 from pydantic import BaseModel
 
-from core.edges import Edge
+from core.edges import EntityEdge
 from core.llm_client.config import EMBEDDING_DIM
-from core.nodes import Node
+from core.nodes import EntityNode, EpisodicNode
 from core.search.search_utils import (
     edge_fulltext_search,
     edge_similarity_search,
@@ -28,9 +28,15 @@ class SearchConfig(BaseModel):
     reranker: str = 'rrf'
 
 
+class SearchResults(BaseModel):
+    episodes: list[EpisodicNode]
+    nodes: list[EntityNode]
+    edges: list[EntityEdge]
+
+
 async def hybrid_search(
     driver: AsyncDriver, embedder, query: str, timestamp: datetime, config: SearchConfig
-) -> dict[str, [Node | Edge]]:
+) -> SearchResults:
     start = time()
 
     episodes = []
@@ -86,11 +92,7 @@ async def hybrid_search(
         reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
         edges.extend(reranked_edges)
 
-    context = {
-        'episodes': episodes,
-        'nodes': nodes,
-        'edges': edges,
-    }
+    context = SearchResults(episodes=episodes, nodes=nodes, edges=edges)
 
     end = time()
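Returning a pydantic SearchResults model instead of a plain dict fixes the old annotation dict[str, [Node | Edge]] (a list literal is not valid in a type position) and gives callers checked attribute access — this is what lets graphiti.py above write `).edges` instead of `)['edges']`. A reduced sketch of the pattern, with a stand-in edge model rather than the repo's:

    from pydantic import BaseModel

    class EntityEdge(BaseModel):
        fact: str

    class SearchResults(BaseModel):
        episodes: list = []
        nodes: list = []
        edges: list[EntityEdge] = []

    results = SearchResults(edges=[EntityEdge(fact='Alice knows Bob')])
    facts = [edge.fact for edge in results.edges]  # attribute access, checked by mypy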
diff --git a/core/search/search_utils.py b/core/search/search_utils.py
index c230cae6..6e4b443d 100644
--- a/core/search/search_utils.py
+++ b/core/search/search_utils.py
@@ -1,5 +1,6 @@
 import asyncio
 import logging
+import typing
 from collections import defaultdict
 from datetime import datetime
 from time import time
@@ -15,7 +16,7 @@
 RELEVANT_SCHEMA_LIMIT = 3
 
 
-def parse_db_date(neo_date: neo4j_time.Date | None) -> datetime | None:
+def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
     return neo_date.to_native() if neo_date else None
 
 
@@ -41,7 +42,7 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode])
             uuid=record['uuid'],
             name=record['name'],
             labels=['Entity'],
-            created_at=datetime.now(),
+            created_at=record['created_at'].to_native(),
             summary=record['summary'],
         )
     )
@@ -74,7 +75,7 @@ async def bfs(node_ids: list[str], driver: AsyncDriver):
         node_ids=node_ids,
     )
 
-    context = {}
+    context: dict[str, typing.Any] = {}
 
     for record in records:
         n_uuid = record['source_node_uuid']
@@ -173,7 +174,7 @@ async def entity_similarity_search(
                 uuid=record['uuid'],
                 name=record['name'],
                 labels=['Entity'],
-                created_at=datetime.now(),
+                created_at=record['created_at'].to_native(),
                 summary=record['summary'],
             )
         )
@@ -208,7 +209,7 @@ async def entity_fulltext_search(
                 uuid=record['uuid'],
                 name=record['name'],
                 labels=['Entity'],
-                created_at=datetime.now(),
+                created_at=record['created_at'].to_native(),
                 summary=record['summary'],
             )
         )
@@ -277,7 +278,11 @@ async def get_relevant_nodes(
 
     results = await asyncio.gather(
         *[entity_fulltext_search(node.name, driver) for node in nodes],
-        *[entity_similarity_search(node.name_embedding, driver) for node in nodes],
+        *[
+            entity_similarity_search(node.name_embedding, driver)
+            for node in nodes
+            if node.name_embedding is not None
+        ],
     )
 
     for result in results:
@@ -303,7 +308,11 @@ async def get_relevant_edges(
     relevant_edge_uuids = set()
 
     results = await asyncio.gather(
-        *[edge_similarity_search(edge.fact_embedding, driver) for edge in edges],
+        *[
+            edge_similarity_search(edge.fact_embedding, driver)
+            for edge in edges
+            if edge.fact_embedding is not None
+        ],
         *[edge_fulltext_search(edge.fact, driver) for edge in edges],
     )
diff --git a/core/utils/__init__.py b/core/utils/__init__.py
index 3fcfd227..7978529b 100644
--- a/core/utils/__init__.py
+++ b/core/utils/__init__.py
@@ -1,15 +1,15 @@
 from .maintenance import (
     build_episodic_edges,
     clear_data,
-    extract_new_edges,
-    extract_new_nodes,
+    extract_edges,
+    extract_nodes,
     retrieve_episodes,
 )
 
 __all__ = [
-    'extract_new_edges',
+    'extract_edges',
     'build_episodic_edges',
-    'extract_new_nodes',
+    'extract_nodes',
     'clear_data',
     'retrieve_episodes',
 ]
diff --git a/core/utils/bulk_utils.py b/core/utils/bulk_utils.py
index 31f20d9f..fbffd3e3 100644
--- a/core/utils/bulk_utils.py
+++ b/core/utils/bulk_utils.py
@@ -1,4 +1,5 @@
 import asyncio
+import typing
 from datetime import datetime
 
 from neo4j import AsyncDriver
@@ -121,8 +122,8 @@ async def dedupe_edges_bulk(
 
 
 def node_name_match(nodes: list[EntityNode]) -> tuple[list[EntityNode], dict[str, str]]:
-    uuid_map = {}
-    name_map = {}
+    uuid_map: dict[str, str] = {}
+    name_map: dict[str, EntityNode] = {}
     for node in nodes:
         if node.name in name_map:
             uuid_map[node.uuid] = name_map[node.name].uuid
@@ -182,7 +183,10 @@ def compress_uuid_map(uuid_map: dict[str, str]) -> dict[str, str]:
     return compressed_map
 
 
-def resolve_edge_pointers(edges: list[Edge], uuid_map: dict[str, str]):
+E = typing.TypeVar('E', bound=Edge)
+
+
+def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]):
     for edge in edges:
         source_uuid = edge.source_node_uuid
         target_uuid = edge.target_node_uuid
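The TypeVar in bulk_utils.py is the standard way to keep a helper generic without losing the element type: add_episode_bulk passes both list[EntityEdge] and list[EpisodicEdge] through resolve_edge_pointers, and a plain list[Edge] parameter would erase the subtype for the caller. A reduced sketch — the stand-in classes, return annotation, and body are an illustrative reconstruction, not the repo's exact code:

    import typing

    class Edge:
        source_node_uuid: str = ''
        target_node_uuid: str = ''

    class EntityEdge(Edge): ...
    class EpisodicEdge(Edge): ...

    E = typing.TypeVar('E', bound=Edge)

    def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]) -> list[E]:
        # Re-point edges whose endpoints were deduplicated away.
        for edge in edges:
            edge.source_node_uuid = uuid_map.get(edge.source_node_uuid, edge.source_node_uuid)
            edge.target_node_uuid = uuid_map.get(edge.target_node_uuid, edge.target_node_uuid)
        return edges

    entity_edges = resolve_edge_pointers([EntityEdge()], {})  # inferred as list[EntityEdge]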
diff --git a/core/utils/maintenance/__init__.py b/core/utils/maintenance/__init__.py
index 6dafc907..d552dbb1 100644
--- a/core/utils/maintenance/__init__.py
+++ b/core/utils/maintenance/__init__.py
@@ -1,15 +1,15 @@
-from .edge_operations import build_episodic_edges, extract_new_edges
+from .edge_operations import build_episodic_edges, extract_edges
 from .graph_data_operations import (
     clear_data,
     retrieve_episodes,
 )
-from .node_operations import extract_new_nodes
+from .node_operations import extract_nodes
 from .temporal_operations import invalidate_edges
 
 __all__ = [
-    'extract_new_edges',
+    'extract_edges',
     'build_episodic_edges',
-    'extract_new_nodes',
+    'extract_nodes',
     'clear_data',
     'retrieve_episodes',
     'invalidate_edges',
diff --git a/core/utils/maintenance/edge_operations.py b/core/utils/maintenance/edge_operations.py
index 26f183c9..922ba1a6 100644
--- a/core/utils/maintenance/edge_operations.py
+++ b/core/utils/maintenance/edge_operations.py
@@ -1,4 +1,3 @@
-import json
 import logging
 from datetime import datetime
 from time import time
@@ -8,7 +7,6 @@
 from core.llm_client import LLMClient
 from core.nodes import EntityNode, EpisodicNode
 from core.prompts import prompt_library
-from core.utils.maintenance.temporal_operations import NodeEdgeNodeTriplet
 
 logger = logging.getLogger(__name__)
 
@@ -31,103 +29,6 @@ def build_episodic_edges(
     return edges
 
 
-async def extract_new_edges(
-    llm_client: LLMClient,
-    episode: EpisodicNode,
-    new_nodes: list[EntityNode],
-    relevant_schema: dict[str, any],
-    previous_episodes: list[EpisodicNode],
-) -> tuple[list[EntityEdge], list[EntityNode]]:
-    # Prepare context for LLM
-    context = {
-        'episode_content': episode.content,
-        'episode_timestamp': (episode.valid_at.isoformat() if episode.valid_at else None),
-        'relevant_schema': json.dumps(relevant_schema, indent=2),
-        'new_nodes': [{'name': node.name, 'summary': node.summary} for node in new_nodes],
-        'previous_episodes': [
-            {
-                'content': ep.content,
-                'timestamp': ep.valid_at.isoformat() if ep.valid_at else None,
-            }
-            for ep in previous_episodes
-        ],
-    }
-
-    llm_response = await llm_client.generate_response(prompt_library.extract_edges.v1(context))
-    new_edges_data = llm_response.get('new_edges', [])
-    logger.info(f'Extracted new edges: {new_edges_data}')
-
-    # Convert the extracted data into EntityEdge objects
-    new_edges = []
-    for edge_data in new_edges_data:
-        source_node = next(
-            (node for node in new_nodes if node.name == edge_data['source_node']),
-            None,
-        )
-        target_node = next(
-            (node for node in new_nodes if node.name == edge_data['target_node']),
-            None,
-        )
-
-        # If source or target is not in new_nodes, check if it's an existing node
-        if source_node is None and edge_data['source_node'] in relevant_schema['nodes']:
-            existing_node_data = relevant_schema['nodes'][edge_data['source_node']]
-            source_node = EntityNode(
-                uuid=existing_node_data['uuid'],
-                name=edge_data['source_node'],
-                labels=[existing_node_data['label']],
-                summary='',
-                created_at=datetime.now(),
-            )
-        if target_node is None and edge_data['target_node'] in relevant_schema['nodes']:
-            existing_node_data = relevant_schema['nodes'][edge_data['target_node']]
-            target_node = EntityNode(
-                uuid=existing_node_data['uuid'],
-                name=edge_data['target_node'],
-                labels=[existing_node_data['label']],
-                summary='',
-                created_at=datetime.now(),
-            )
-
-        if (
-            source_node
-            and target_node
-            and not (
-                source_node.name.startswith('Message') or target_node.name.startswith('Message')
-            )
-        ):
-            valid_at = (
-                datetime.fromisoformat(edge_data['valid_at'])
-                if edge_data['valid_at']
-                else episode.valid_at or datetime.now()
-            )
-            invalid_at = (
-                datetime.fromisoformat(edge_data['invalid_at']) if edge_data['invalid_at'] else None
-            )
-
-            new_edge = EntityEdge(
-                source_node=source_node,
-                target_node=target_node,
-                name=edge_data['relation_type'],
-                fact=edge_data['fact'],
-                episodes=[episode.uuid],
-                created_at=datetime.now(),
-                valid_at=valid_at,
-                invalid_at=invalid_at,
-            )
-            new_edges.append(new_edge)
-            logger.info(
-                f'Created new edge: {new_edge.name} from {source_node.name} (UUID: {source_node.uuid}) to {target_node.name} (UUID: {target_node.uuid})'
-            )
-
-    affected_nodes = set()
-
-    for edge in new_edges:
-        affected_nodes.add(edge.source_node)
-        affected_nodes.add(edge.target_node)
-    return new_edges, list(affected_nodes)
-
-
 async def extract_edges(
     llm_client: LLMClient,
     episode: EpisodicNode,
@@ -186,45 +87,6 @@ def create_edge_identifier(
     return f'{source_node.name}-{edge.name}-{target_node.name}'
 
 
-async def dedupe_extracted_edges_v2(
-    llm_client: LLMClient,
-    extracted_edges: list[NodeEdgeNodeTriplet],
-    existing_edges: list[NodeEdgeNodeTriplet],
-) -> list[NodeEdgeNodeTriplet]:
-    # Create edge map
-    edge_map = {}
-    for n1, edge, n2 in existing_edges:
-        edge_map[create_edge_identifier(n1, edge, n2)] = edge
-    for n1, edge, n2 in extracted_edges:
-        if create_edge_identifier(n1, edge, n2) in edge_map:
-            continue
-        edge_map[create_edge_identifier(n1, edge, n2)] = edge
-
-    # Prepare context for LLM
-    context = {
-        'extracted_edges': [
-            {'triplet': create_edge_identifier(n1, edge, n2), 'fact': edge.fact}
-            for n1, edge, n2 in extracted_edges
-        ],
-        'existing_edges': [
-            {'triplet': create_edge_identifier(n1, edge, n2), 'fact': edge.fact}
-            for n1, edge, n2 in extracted_edges
-        ],
-    }
-    logger.info(prompt_library.dedupe_edges.v2(context))
-    llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v2(context))
-    new_edges_data = llm_response.get('new_edges', [])
-    logger.info(f'Extracted new edges: {new_edges_data}')
-
-    # Get full edge data
-    edges = []
-    for edge_data in new_edges_data:
-        edge = edge_map[edge_data['triplet']]
-        edges.append(edge)
-
-    return edges
-
-
 async def dedupe_extracted_edges(
     llm_client: LLMClient,
     extracted_edges: list[EntityEdge],
diff --git a/core/utils/maintenance/graph_data_operations.py b/core/utils/maintenance/graph_data_operations.py
index ca2da0fa..67579ca0 100644
--- a/core/utils/maintenance/graph_data_operations.py
+++ b/core/utils/maintenance/graph_data_operations.py
@@ -52,9 +52,7 @@ async def build_indices_and_constraints(driver: AsyncDriver):
         }}
         """,
     ]
-    index_queries: list[LiteralString] = (
-        range_indices + fulltext_indices + vector_indices
-    )
+    index_queries: list[LiteralString] = range_indices + fulltext_indices + vector_indices
 
     await asyncio.gather(*[driver.execute_query(query) for query in index_queries])
 
@@ -72,7 +70,6 @@ async def retrieve_episodes(
     driver: AsyncDriver,
     reference_time: datetime,
     last_n: int = EPISODE_WINDOW_LEN,
-    sources: list[str] | None = 'messages',
 ) -> list[EpisodicNode]:
     """Retrieve the last n episodic nodes from the graph"""
     result = await driver.execute_query(
@@ -97,14 +94,7 @@ async def retrieve_episodes(
             created_at=datetime.fromtimestamp(
                 record['created_at'].to_native().timestamp(), timezone.utc
             ),
-            valid_at=(
-                datetime.fromtimestamp(
-                    record['valid_at'].to_native().timestamp(),
-                    timezone.utc,
-                )
-                if record['valid_at'] is not None
-                else None
-            ),
+            valid_at=(record['valid_at'].to_native()),
             uuid=record['uuid'],
             source=record['source'],
             name=record['name'],
diff --git a/core/utils/maintenance/node_operations.py b/core/utils/maintenance/node_operations.py
index b9ecab53..e7c3e728 100644
--- a/core/utils/maintenance/node_operations.py
+++ b/core/utils/maintenance/node_operations.py
@@ -9,53 +9,6 @@
 logger = logging.getLogger(__name__)
 
 
-async def extract_new_nodes(
-    llm_client: LLMClient,
-    episode: EpisodicNode,
-    relevant_schema: dict[str, any],
-    previous_episodes: list[EpisodicNode],
-) -> list[EntityNode]:
-    # Prepare context for LLM
-    existing_nodes = [
-        {'name': node_name, 'label': node_info['label'], 'uuid': node_info['uuid']}
-        for node_name, node_info in relevant_schema['nodes'].items()
-    ]
-
-    context = {
-        'episode_content': episode.content,
-        'episode_timestamp': (episode.valid_at.isoformat() if episode.valid_at else None),
-        'existing_nodes': existing_nodes,
-        'previous_episodes': [
-            {
-                'content': ep.content,
-                'timestamp': ep.valid_at.isoformat() if ep.valid_at else None,
-            }
-            for ep in previous_episodes
-        ],
-    }
-
-    llm_response = await llm_client.generate_response(prompt_library.extract_nodes.v1(context))
-    new_nodes_data = llm_response.get('new_nodes', [])
-    logger.info(f'Extracted new nodes: {new_nodes_data}')
-    # Convert the extracted data into EntityNode objects
-    new_nodes = []
-    for node_data in new_nodes_data:
-        # Check if the node already exists
-        if not any(existing_node['name'] == node_data['name'] for existing_node in existing_nodes):
-            new_node = EntityNode(
-                name=node_data['name'],
-                labels=node_data['labels'],
-                summary=node_data['summary'],
-                created_at=datetime.now(),
-            )
-            new_nodes.append(new_node)
-            logger.info(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
-        else:
-            logger.info(f"Node {node_data['name']} already exists, skipping creation.")
-
-    return new_nodes
-
-
 async def extract_nodes(
     llm_client: LLMClient,
     episode: EpisodicNode,
@@ -100,16 +53,16 @@ async def dedupe_extracted_nodes(
     llm_client: LLMClient,
     extracted_nodes: list[EntityNode],
     existing_nodes: list[EntityNode],
-) -> tuple[list[EntityNode], dict[str, str]]:
+) -> tuple[list[EntityNode], dict[str, str], list[EntityNode]]:
     start = time()
 
     # build existing node map
-    node_map = {}
+    node_map: dict[str, EntityNode] = {}
    for node in existing_nodes:
         node_map[node.name] = node
 
     # Temp hack
-    new_nodes_map = {}
+    new_nodes_map: dict[str, EntityNode] = {}
     for node in extracted_nodes:
         new_nodes_map[node.name] = node
 
@@ -134,14 +87,14 @@ async def dedupe_extracted_nodes(
     end = time()
     logger.info(f'Deduplicated nodes: {duplicate_data} in {(end - start) * 1000} ms')
 
-    uuid_map = {}
+    uuid_map: dict[str, str] = {}
     for duplicate in duplicate_data:
         uuid = new_nodes_map[duplicate['name']].uuid
         uuid_value = node_map[duplicate['duplicate_of']].uuid
         uuid_map[uuid] = uuid_value
 
-    nodes = []
-    brand_new_nodes = []
+    nodes: list[EntityNode] = []
+    brand_new_nodes: list[EntityNode] = []
     for node in extracted_nodes:
         if node.uuid in uuid_map:
             existing_uuid = uuid_map[node.uuid]
@@ -149,7 +102,9 @@ async def dedupe_extracted_nodes(
             # can you revisit the node dedup function and make it somewhat cleaner and add more comments/tests please?
             # find an existing node by the uuid from the nodes_map (each key is name, so we need to iterate by uuid value)
             existing_node = next((v for k, v in node_map.items() if v.uuid == existing_uuid), None)
-            nodes.append(existing_node)
+            if existing_node:
+                nodes.append(existing_node)
+
             continue
         brand_new_nodes.append(node)
         nodes.append(node)
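The `if existing_node:` guard above is mypy's Optional-narrowing idiom: next(..., None) is typed EntityNode | None, and appending that to a list[EntityNode] is an error until the None branch is ruled out. The ValueError raise added in temporal_operations.py below does the same job. A minimal sketch of the pattern, using a stand-in dataclass:

    from dataclasses import dataclass

    @dataclass
    class EntityNode:
        uuid: str

    def find_node(nodes: list[EntityNode], uuid: str) -> EntityNode:
        found = next((n for n in nodes if n.uuid == uuid), None)  # EntityNode | None
        if found is None:
            raise ValueError(f'node {uuid} not found')
        return found  # narrowed to EntityNode on this path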
diff --git a/core/utils/maintenance/temporal_operations.py b/core/utils/maintenance/temporal_operations.py
index 37e2d7cb..8634f663 100644
--- a/core/utils/maintenance/temporal_operations.py
+++ b/core/utils/maintenance/temporal_operations.py
@@ -23,6 +23,8 @@ def extract_node_edge_node_triplet(
 ) -> NodeEdgeNodeTriplet:
     source_node = next((node for node in nodes if node.uuid == edge.source_node_uuid), None)
     target_node = next((node for node in nodes if node.uuid == edge.target_node_uuid), None)
+    if not source_node or not target_node:
+        raise ValueError(f'Source or target node not found for edge {edge.uuid}')
     return (source_node, edge, target_node)
 
 
@@ -31,11 +33,8 @@ def prepare_edges_for_invalidation(
     new_edges: list[EntityEdge],
     nodes: list[EntityNode],
 ) -> tuple[list[NodeEdgeNodeTriplet], list[NodeEdgeNodeTriplet]]:
-    existing_edges_pending_invalidation = []  # TODO: this is not yet used?
-    new_edges_with_nodes = []  # TODO: this is not yet used?
-
-    existing_edges_pending_invalidation = []
-    new_edges_with_nodes = []
+    existing_edges_pending_invalidation: list[NodeEdgeNodeTriplet] = []
+    new_edges_with_nodes: list[NodeEdgeNodeTriplet] = []
 
     for edge_list, result_list in [
         (existing_edges, existing_edges_pending_invalidation),
diff --git a/core/utils/search/search_utils.py b/core/utils/search/search_utils.py
deleted file mode 100644
index e34a3314..00000000
--- a/core/utils/search/search_utils.py
+++ /dev/null
@@ -1,292 +0,0 @@
-import asyncio
-import logging
-from datetime import datetime
-from time import time
-
-from neo4j import AsyncDriver
-from neo4j import time as neo4j_time
-
-from core.edges import EntityEdge
-from core.nodes import EntityNode
-
-logger = logging.getLogger(__name__)
-
-RELEVANT_SCHEMA_LIMIT = 3
-
-
-async def bfs(node_ids: list[str], driver: AsyncDriver):
-    records, _, _ = await driver.execute_query(
-        """
-        MATCH (n WHERE n.uuid in $node_ids)-[r]->(m)
-        RETURN
-            n.uuid AS source_node_uuid,
-            n.name AS source_name,
-            n.summary AS source_summary,
-            m.uuid AS target_node_uuid,
-            m.name AS target_name,
-            m.summary AS target_summary,
-            r.uuid AS uuid,
-            r.created_at AS created_at,
-            r.name AS name,
-            r.fact AS fact,
-            r.fact_embedding AS fact_embedding,
-            r.episodes AS episodes,
-            r.expired_at AS expired_at,
-            r.valid_at AS valid_at,
-            r.invalid_at AS invalid_at
-
-        """,
-        node_ids=node_ids,
-    )
-
-    context = {}
-
-    for record in records:
-        n_uuid = record['source_node_uuid']
-        if n_uuid in context:
-            context[n_uuid]['facts'].append(record['fact'])
-        else:
-            context[n_uuid] = {
-                'name': record['source_name'],
-                'summary': record['source_summary'],
-                'facts': [record['fact']],
-            }
-
-        m_uuid = record['target_node_uuid']
-        if m_uuid not in context:
-            context[m_uuid] = {
-                'name': record['target_name'],
-                'summary': record['target_summary'],
-                'facts': [],
-            }
-    logger.info(f'bfs search returned context: {context}')
-    return context
-
-
-async def edge_similarity_search(
-    search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
-) -> list[EntityEdge]:
-    # vector similarity search over embedded facts
-    records, _, _ = await driver.execute_query(
-        """
-        CALL db.index.vector.queryRelationships("fact_embedding", 5, $search_vector)
-        YIELD relationship AS r, score
-        MATCH (n)-[r:RELATES_TO]->(m)
-        RETURN
-            r.uuid AS uuid,
-            n.uuid AS source_node_uuid,
-            m.uuid AS target_node_uuid,
-            r.created_at AS created_at,
-            r.name AS name,
-            r.fact AS fact,
-            r.fact_embedding AS fact_embedding,
-            r.episodes AS episodes,
-            r.expired_at AS expired_at,
-            r.valid_at AS valid_at,
-            r.invalid_at AS invalid_at
-        ORDER BY score DESC LIMIT $limit
-        """,
-        search_vector=search_vector,
-        limit=limit,
-    )
-
-    edges: list[EntityEdge] = []
-
-    for record in records:
-        edge = EntityEdge(
-            uuid=record['uuid'],
-            source_node_uuid=record['source_node_uuid'],
-            target_node_uuid=record['target_node_uuid'],
-            fact=record['fact'],
-            name=record['name'],
-            episodes=record['episodes'],
-            fact_embedding=record['fact_embedding'],
-            created_at=safely_parse_db_date(record['created_at']),
-            expired_at=safely_parse_db_date(record['expired_at']),
-            valid_at=safely_parse_db_date(record['valid_at']),
-            invalid_At=safely_parse_db_date(record['invalid_at']),
-        )
-
-        edges.append(edge)
-
-    return edges
-
-
-async def entity_similarity_search(
-    search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
-) -> list[EntityNode]:
-    # vector similarity search over entity names
-    records, _, _ = await driver.execute_query(
-        """
-        CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
-        YIELD node AS n, score
-        RETURN
-            n.uuid As uuid,
-            n.name AS name,
-            n.created_at AS created_at,
-            n.summary AS summary
-        ORDER BY score DESC
-        """,
-        search_vector=search_vector,
-        limit=limit,
-    )
-    nodes: list[EntityNode] = []
-
-    for record in records:
-        nodes.append(
-            EntityNode(
-                uuid=record['uuid'],
-                name=record['name'],
-                labels=[],
-                created_at=safely_parse_db_date(record['created_at']),
-                summary=record['summary'],
-            )
-        )
-
-    return nodes
-
-
-async def entity_fulltext_search(
-    query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
-) -> list[EntityNode]:
-    # BM25 search to get top nodes
-    fuzzy_query = query + '~'
-    records, _, _ = await driver.execute_query(
-        """
-        CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score
-        RETURN
-            node.uuid As uuid,
-            node.name AS name,
-            node.created_at AS created_at,
-            node.summary AS summary
-        ORDER BY score DESC
-        LIMIT $limit
-        """,
-        query=fuzzy_query,
-        limit=limit,
-    )
-    nodes: list[EntityNode] = []
-
-    for record in records:
-        nodes.append(
-            EntityNode(
-                uuid=record['uuid'],
-                name=record['name'],
-                labels=[],
-                created_at=safely_parse_db_date(record['created_at']),
-                summary=record['summary'],
-            )
-        )
-
-    return nodes
-
-
-async def edge_fulltext_search(
-    query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
-) -> list[EntityEdge]:
-    # fulltext search over facts
-    fuzzy_query = query + '~'
-
-    records, _, _ = await driver.execute_query(
-        """
-        CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
-        YIELD relationship AS r, score
-        MATCH (n:Entity)-[r]->(m:Entity)
-        RETURN
-            r.uuid AS uuid,
-            n.uuid AS source_node_uuid,
-            m.uuid AS target_node_uuid,
-            r.created_at AS created_at,
-            r.name AS name,
-            r.fact AS fact,
-            r.fact_embedding AS fact_embedding,
-            r.episodes AS episodes,
-            r.expired_at AS expired_at,
-            r.valid_at AS valid_at,
-            r.invalid_at AS invalid_at
-        ORDER BY score DESC LIMIT $limit
-        """,
-        query=fuzzy_query,
-        limit=limit,
-    )
-
-    edges: list[EntityEdge] = []
-
-    for record in records:
-        edge = EntityEdge(
-            uuid=record['uuid'],
-            source_node_uuid=record['source_node_uuid'],
-            target_node_uuid=record['target_node_uuid'],
-            fact=record['fact'],
-            name=record['name'],
-            episodes=record['episodes'],
-            fact_embedding=record['fact_embedding'],
-            created_at=safely_parse_db_date(record['created_at']),
-            expired_at=safely_parse_db_date(record['expired_at']),
-            valid_at=safely_parse_db_date(record['valid_at']),
-            invalid_At=safely_parse_db_date(record['invalid_at']),
-        )
-
-        edges.append(edge)
-
-    return edges
-
-
-def safely_parse_db_date(date_str: neo4j_time.Date) -> datetime:
-    if date_str:
-        return datetime.fromisoformat(date_str.iso_format())
-    return None
-
-
-async def get_relevant_nodes(
-    nodes: list[EntityNode],
-    driver: AsyncDriver,
-) -> list[EntityNode]:
-    start = time()
-    relevant_nodes: list[EntityNode] = []
-    relevant_node_uuids = set()
-
-    results = await asyncio.gather(
-        *[entity_fulltext_search(node.name, driver) for node in nodes],
-        *[entity_similarity_search(node.name_embedding, driver) for node in nodes],
-    )
-
-    for result in results:
-        for node in result:
-            if node.uuid in relevant_node_uuids:
-                continue
-
-            relevant_node_uuids.add(node.uuid)
-            relevant_nodes.append(node)
-
-    end = time()
-    logger.info(f'Found relevant nodes: {relevant_node_uuids} in {(end - start) * 1000} ms')
-
-    return relevant_nodes
-
-
-async def get_relevant_edges(
-    edges: list[EntityEdge],
-    driver: AsyncDriver,
-) -> list[EntityEdge]:
-    start = time()
-    relevant_edges: list[EntityEdge] = []
-    relevant_edge_uuids = set()
-
-    results = await asyncio.gather(
-        *[edge_similarity_search(edge.fact_embedding, driver) for edge in edges],
-        *[edge_fulltext_search(edge.fact, driver) for edge in edges],
-    )
-
-    for result in results:
-        for edge in result:
-            if edge.uuid in relevant_edge_uuids:
-                continue
-
-            relevant_edge_uuids.add(edge.uuid)
-            relevant_edges.append(edge)
-
-    end = time()
-    logger.info(f'Found relevant edges: {relevant_edge_uuids} in {(end - start) * 1000} ms')
-
-    return relevant_edges
diff --git a/core/utils/utils.py b/core/utils/utils.py
index 0999439f..8777fa68 100644
--- a/core/utils/utils.py
+++ b/core/utils/utils.py
@@ -14,8 +14,8 @@ def build_episodic_edges(
     for node in entity_nodes:
         edges.append(
             EpisodicEdge(
-                source_node_uuid=episode,
-                target_node_uuid=node,
+                source_node_uuid=episode.uuid,
+                target_node_uuid=node.uuid,
                 created_at=episode.created_at,
             )
         )
diff --git a/pyproject.toml b/pyproject.toml
index 26f4b616..17196cf4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -2,7 +2,10 @@
 name = "graphiti"
 version = "0.0.1"
 description = "Graph building library"
-authors = ["Paul Paliychuk ", "Preston Rasmussen "]
+authors = [
+    "Paul Paliychuk ",
+    "Preston Rasmussen ",
+]
 readme = "README.md"
 
 [tool.poetry.dependencies]
@@ -56,4 +59,4 @@ ignore = ["E501"]
 [tool.ruff.format]
 quote-style = "single"
 indent-style = "tab"
-docstring-code-format = true
\ No newline at end of file
+docstring-code-format = true
diff --git a/tests/tests_int_graphiti.py b/tests/tests_int_graphiti.py
index 53068e72..65c818bc 100644
--- a/tests/tests_int_graphiti.py
+++ b/tests/tests_int_graphiti.py
@@ -103,11 +103,11 @@ async def test_graph_integration():
     bob_node = EntityNode(name='Bob', labels=[], created_at=now, summary='Bob summary')
 
     episodic_edge_1 = EpisodicEdge(
-        source_node_uuid=episode, target_node_uuid=alice_node, created_at=now
+        source_node_uuid=episode.uuid, target_node_uuid=alice_node.uuid, created_at=now
     )
 
     episodic_edge_2 = EpisodicEdge(
-        source_node_uuid=episode, target_node_uuid=bob_node, created_at=now
+        source_node_uuid=episode.uuid, target_node_uuid=bob_node.uuid, created_at=now
     )
 
     entity_edge = EntityEdge(