From bc1dc0c4586e472bbae90287a89da73d5f7b127d Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Thu, 30 Jan 2025 12:46:17 -0500 Subject: [PATCH 01/15] ontology --- graphiti_core/graphiti.py | 69 ++++++++++++++++++++------------------- 1 file changed, 35 insertions(+), 34 deletions(-) diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 81cf90fc..1fb04ee6 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -91,14 +91,14 @@ class AddEpisodeResults(BaseModel): class Graphiti: def __init__( - self, - uri: str, - user: str, - password: str, - llm_client: LLMClient | None = None, - embedder: EmbedderClient | None = None, - cross_encoder: CrossEncoderClient | None = None, - store_raw_episode_content: bool = True, + self, + uri: str, + user: str, + password: str, + llm_client: LLMClient | None = None, + embedder: EmbedderClient | None = None, + cross_encoder: CrossEncoderClient | None = None, + store_raw_episode_content: bool = True, ): """ Initialize a Graphiti instance. @@ -220,10 +220,10 @@ async def build_indices_and_constraints(self, delete_existing: bool = False): await build_indices_and_constraints(self.driver, delete_existing) async def retrieve_episodes( - self, - reference_time: datetime, - last_n: int = EPISODE_WINDOW_LEN, - group_ids: list[str] | None = None, + self, + reference_time: datetime, + last_n: int = EPISODE_WINDOW_LEN, + group_ids: list[str] | None = None, ) -> list[EpisodicNode]: """ Retrieve the last n episodic nodes from the graph. @@ -253,15 +253,16 @@ async def retrieve_episodes( return await retrieve_episodes(self.driver, reference_time, last_n, group_ids) async def add_episode( - self, - name: str, - episode_body: str, - source_description: str, - reference_time: datetime, - source: EpisodeType = EpisodeType.message, - group_id: str = '', - uuid: str | None = None, - update_communities: bool = False, + self, + name: str, + episode_body: str, + source_description: str, + reference_time: datetime, + source: EpisodeType = EpisodeType.message, + group_id: str = '', + uuid: str | None = None, + update_communities: bool = False, + ontology: list[BaseModel] | None = None, ) -> AddEpisodeResults: """ Process an episode and update the graph. @@ -622,12 +623,12 @@ async def build_communities(self, group_ids: list[str] | None = None) -> list[Co return community_nodes async def search( - self, - query: str, - center_node_uuid: str | None = None, - group_ids: list[str] | None = None, - num_results=DEFAULT_SEARCH_LIMIT, - search_filter: SearchFilters | None = None, + self, + query: str, + center_node_uuid: str | None = None, + group_ids: list[str] | None = None, + num_results=DEFAULT_SEARCH_LIMIT, + search_filter: SearchFilters | None = None, ) -> list[EntityEdge]: """ Perform a hybrid search on the knowledge graph. 
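Patch 01 adds an `ontology` parameter to `add_episode`, typed as a list of Pydantic models. A minimal sketch of the kind of caller-defined ontology this signature accepts; the `Person` model mirrors the example added to `examples/podcast/podcast_runner.py` later in this series, and the field set is otherwise illustrative:

    from pydantic import BaseModel, Field

    class Person(BaseModel):
        first_name: str | None = Field(..., description='First name')
        last_name: str | None = Field(..., description='Last name')
        occupation: str | None = Field(..., description="The person's work occupation")

    # Patch 01 takes the models as a bare list; patch 02 below reworks this
    # into a name -> model mapping called `entity_types`.
    ontology = [Person]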
@@ -681,13 +682,13 @@ async def search( return edges async def _search( - self, - query: str, - config: SearchConfig, - group_ids: list[str] | None = None, - center_node_uuid: str | None = None, - bfs_origin_node_uuids: list[str] | None = None, - search_filter: SearchFilters | None = None, + self, + query: str, + config: SearchConfig, + group_ids: list[str] | None = None, + center_node_uuid: str | None = None, + bfs_origin_node_uuids: list[str] | None = None, + search_filter: SearchFilters | None = None, ) -> SearchResults: return await search( self.driver, From 8abae598c8ec782c5721156ac3b3bbb361084d1e Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Sun, 2 Feb 2025 19:22:59 -0500 Subject: [PATCH 02/15] extract and save node labels --- graphiti_core/graphiti.py | 74 ++++++++++--------- graphiti_core/models/nodes/node_db_queries.py | 2 + graphiti_core/nodes.py | 16 +++- graphiti_core/prompts/extract_nodes.py | 43 ++++++++++- graphiti_core/search/search_utils.py | 16 ++-- graphiti_core/utils/bulk_utils.py | 5 +- .../utils/maintenance/node_operations.py | 26 ++++++- 7 files changed, 132 insertions(+), 50 deletions(-) diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 1fb04ee6..39e7a22b 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -91,14 +91,14 @@ class AddEpisodeResults(BaseModel): class Graphiti: def __init__( - self, - uri: str, - user: str, - password: str, - llm_client: LLMClient | None = None, - embedder: EmbedderClient | None = None, - cross_encoder: CrossEncoderClient | None = None, - store_raw_episode_content: bool = True, + self, + uri: str, + user: str, + password: str, + llm_client: LLMClient | None = None, + embedder: EmbedderClient | None = None, + cross_encoder: CrossEncoderClient | None = None, + store_raw_episode_content: bool = True, ): """ Initialize a Graphiti instance. @@ -220,10 +220,10 @@ async def build_indices_and_constraints(self, delete_existing: bool = False): await build_indices_and_constraints(self.driver, delete_existing) async def retrieve_episodes( - self, - reference_time: datetime, - last_n: int = EPISODE_WINDOW_LEN, - group_ids: list[str] | None = None, + self, + reference_time: datetime, + last_n: int = EPISODE_WINDOW_LEN, + group_ids: list[str] | None = None, ) -> list[EpisodicNode]: """ Retrieve the last n episodic nodes from the graph. @@ -253,16 +253,16 @@ async def retrieve_episodes( return await retrieve_episodes(self.driver, reference_time, last_n, group_ids) async def add_episode( - self, - name: str, - episode_body: str, - source_description: str, - reference_time: datetime, - source: EpisodeType = EpisodeType.message, - group_id: str = '', - uuid: str | None = None, - update_communities: bool = False, - ontology: list[BaseModel] | None = None, + self, + name: str, + episode_body: str, + source_description: str, + reference_time: datetime, + source: EpisodeType = EpisodeType.message, + group_id: str = '', + uuid: str | None = None, + update_communities: bool = False, + entity_types: dict[str, BaseModel] | None = None, ) -> AddEpisodeResults: """ Process an episode and update the graph. 
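Patch 02 renames the parameter to `entity_types` and changes its shape to a mapping from type name to Pydantic model, which is how the podcast example in a later patch of this series invokes it. A hedged usage sketch, to be run inside an async function; the `Person` model and message body are illustrative, only the keyword shape is fixed by this signature:

    from datetime import datetime, timezone

    await graphiti.add_episode(
        name='Message 0',
        episode_body='Alice: I work as a nurse in Seattle.',
        source_description='Podcast Transcript',
        reference_time=datetime.now(timezone.utc),
        group_id='podcast',
        # Keys become candidate node labels; values drive attribute extraction.
        entity_types={'Person': Person},
    )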
@@ -337,7 +337,9 @@ async def add_episode_endpoint(episode_data: EpisodeData): # Extract entities as nodes - extracted_nodes = await extract_nodes(self.llm_client, episode, previous_episodes) + extracted_nodes = await extract_nodes( + self.llm_client, episode, previous_episodes, entity_types + ) logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}') # Calculate Embeddings @@ -623,12 +625,12 @@ async def build_communities(self, group_ids: list[str] | None = None) -> list[Co return community_nodes async def search( - self, - query: str, - center_node_uuid: str | None = None, - group_ids: list[str] | None = None, - num_results=DEFAULT_SEARCH_LIMIT, - search_filter: SearchFilters | None = None, + self, + query: str, + center_node_uuid: str | None = None, + group_ids: list[str] | None = None, + num_results=DEFAULT_SEARCH_LIMIT, + search_filter: SearchFilters | None = None, ) -> list[EntityEdge]: """ Perform a hybrid search on the knowledge graph. @@ -682,13 +684,13 @@ async def search( return edges async def _search( - self, - query: str, - config: SearchConfig, - group_ids: list[str] | None = None, - center_node_uuid: str | None = None, - bfs_origin_node_uuids: list[str] | None = None, - search_filter: SearchFilters | None = None, + self, + query: str, + config: SearchConfig, + group_ids: list[str] | None = None, + center_node_uuid: str | None = None, + bfs_origin_node_uuids: list[str] | None = None, + search_filter: SearchFilters | None = None, ) -> SearchResults: return await search( self.driver, diff --git a/graphiti_core/models/nodes/node_db_queries.py b/graphiti_core/models/nodes/node_db_queries.py index 9010532b..cdd020e0 100644 --- a/graphiti_core/models/nodes/node_db_queries.py +++ b/graphiti_core/models/nodes/node_db_queries.py @@ -31,6 +31,7 @@ ENTITY_NODE_SAVE = """ MERGE (n:Entity {uuid: $uuid}) + SET n:$($labels) SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at} WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding) RETURN n.uuid AS uuid""" @@ -38,6 +39,7 @@ ENTITY_NODE_SAVE_BULK = """ UNWIND $nodes AS node MERGE (n:Entity {uuid: node.uuid}) + SET n:$(n.labels) SET n = {uuid: node.uuid, name: node.name, group_id: node.group_id, summary: node.summary, created_at: node.created_at} WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding) RETURN n.uuid AS uuid diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index 6a490c8c..b7c2ed6a 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -255,6 +255,9 @@ async def get_by_group_ids( class EntityNode(Node): name_embedding: list[float] | None = Field(default=None, description='embedding of the name') summary: str = Field(description='regional summary of surrounding edges', default_factory=str) + properties: dict[str, Any] | None = Field( + default=None, description='Additional properties of the node. 
Dependent on node labels' + ) async def generate_name_embedding(self, embedder: EmbedderClient): start = time() @@ -269,10 +272,12 @@ async def save(self, driver: AsyncDriver): result = await driver.execute_query( ENTITY_NODE_SAVE, uuid=self.uuid, + labels=self.labels + ['Entity'], name=self.name, group_id=self.group_id, summary=self.summary, name_embedding=self.name_embedding, + properties=self.properties, created_at=self.created_at, database_=DEFAULT_DATABASE, ) @@ -292,7 +297,8 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): n.name_embedding AS name_embedding, n.group_id AS group_id, n.created_at AS created_at, - n.summary AS summary + n.summary AS summary, + labels(n) AS labels, """, uuid=uuid, database_=DEFAULT_DATABASE, @@ -317,7 +323,8 @@ async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]): n.name_embedding AS name_embedding, n.group_id AS group_id, n.created_at AS created_at, - n.summary AS summary + n.summary AS summary, + labels(n) AS labels, """, uuids=uuids, database_=DEFAULT_DATABASE, @@ -351,7 +358,8 @@ async def get_by_group_ids( n.name_embedding AS name_embedding, n.group_id AS group_id, n.created_at AS created_at, - n.summary AS summary + n.summary AS summary, + labels(n) AS labels, ORDER BY n.uuid DESC """ + limit_query, @@ -503,7 +511,7 @@ def get_entity_node_from_record(record: Any) -> EntityNode: name=record['name'], group_id=record['group_id'], name_embedding=record['name_embedding'], - labels=['Entity'], + labels=record['labels'], created_at=record['created_at'].to_native(), summary=record['summary'], ) diff --git a/graphiti_core/prompts/extract_nodes.py b/graphiti_core/prompts/extract_nodes.py index 49e2036b..a966d52e 100644 --- a/graphiti_core/prompts/extract_nodes.py +++ b/graphiti_core/prompts/extract_nodes.py @@ -30,11 +30,19 @@ class MissedEntities(BaseModel): missed_entities: list[str] = Field(..., description="Names of entities that weren't extracted") +class EntityClassification(BaseModel): + entity_classification: dict[str, str | None] = Field( + ..., + description='Dictionary of entity classifications. Key is the entity name and value is the entity type', + ) + + class Prompt(Protocol): extract_message: PromptVersion extract_json: PromptVersion extract_text: PromptVersion reflexion: PromptVersion + classify_nodes: PromptVersion class Versions(TypedDict): @@ -42,6 +50,7 @@ class Versions(TypedDict): extract_json: PromptFunction extract_text: PromptFunction reflexion: PromptFunction + classify_nodes: PromptFunction def extract_message(context: dict[str, Any]) -> list[Message]: @@ -109,7 +118,7 @@ def extract_text(context: dict[str, Any]) -> list[Message]: {context['custom_prompt']} -Given the following text, extract entity nodes from the TEXT that are explicitly or implicitly mentioned: +Given the above text, extract entity nodes from the TEXT that are explicitly or implicitly mentioned: Guidelines: 1. Extract significant entities, concepts, or actors mentioned in the conversation. 
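The `EntityClassification` response model introduced above expects a name-to-type mapping in which entities matching none of the provided types map to None, per the guidelines in the `classify_nodes` prompt that follows. A sketch of a structured output conforming to this model (entity names are illustrative); note that patch 04 later flattens the field to a plain string that is parsed with `ast.literal_eval`:

    EntityClassification(
        entity_classification={
            'Jimmy Carter': 'Person',  # matched one of the provided entity types
            'Camp David': None,        # no provided type fits this entity
        }
    )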
@@ -147,9 +156,41 @@ def reflexion(context: dict[str, Any]) -> list[Message]: ] +def classify_nodes(context: dict[str, Any]) -> list[Message]: + sys_prompt = """You are an AI assistant that classifies entity nodes given the context from which they were extracted""" + + user_prompt = f""" + + {json.dumps([ep for ep in context['previous_episodes']], indent=2)} + + + {context["episode_content"]} + + + + {context['extracted_entities']} + + + + {context['entity_types']} + + + Given the above conversation, extracted entities, and provided entity types, classify the extracted entities. + + Guidelines: + 1. Each entity must have exactly one type + 2. If none of the provided entity types accurately classify an extracted node, the type should be set to None +""" + return [ + Message(role='system', content=sys_prompt), + Message(role='user', content=user_prompt), + ] + + versions: Versions = { 'extract_message': extract_message, 'extract_json': extract_json, 'extract_text': extract_text, 'reflexion': reflexion, + 'classify_nodes': classify_nodes, } diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index c4d44fd1..854bd1ae 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -97,7 +97,8 @@ async def get_mentioned_nodes( n.name AS name, n.name_embedding AS name_embedding, n.created_at AS created_at, - n.summary AS summary + n.summary AS summary, + labels(n) AS labels """, uuids=episode_uuids, database_=DEFAULT_DATABASE, @@ -223,8 +224,8 @@ async def edge_similarity_search( query: LiteralString = ( """ - MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity) - """ + MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity) + """ + group_filter_query + filter_query + """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score @@ -341,7 +342,8 @@ async def node_fulltext_search( n.name AS name, n.name_embedding AS name_embedding, n.created_at AS created_at, - n.summary AS summary + n.summary AS summary, + labels(n) AS labels, ORDER BY score DESC LIMIT $limit """, @@ -390,7 +392,8 @@ async def node_similarity_search( n.name AS name, n.name_embedding AS name_embedding, n.created_at AS created_at, - n.summary AS summary + n.summary AS summary, + labels(n) AS labels, ORDER BY score DESC LIMIT $limit """, @@ -427,7 +430,8 @@ async def node_bfs_search( n.name AS name, n.name_embedding AS name_embedding, n.created_at AS created_at, - n.summary AS summary + n.summary AS summary, + labels(n) AS labels, LIMIT $limit """, bfs_origin_node_uuids=bfs_origin_node_uuids, diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index 80f66029..60ef43f0 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -109,8 +109,11 @@ async def add_nodes_and_edges_bulk_tx( episodes = [dict(episode) for episode in episodic_nodes] for episode in episodes: episode['source'] = str(episode['source'].value) + nodes = [dict(entity) for entity in entity_nodes] + for node in nodes: + node['labels'] = list(set(node['labels'] + ['Entity'])) await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes) - await tx.run(ENTITY_NODE_SAVE_BULK, nodes=[dict(entity) for entity in entity_nodes]) + await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes) await tx.run(EPISODIC_EDGE_SAVE_BULK, episodic_edges=[dict(edge) for edge in episodic_edges]) await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=[dict(edge) for edge in entity_edges]) diff --git a/graphiti_core/utils/maintenance/node_operations.py 
b/graphiti_core/utils/maintenance/node_operations.py index 31e916b4..83741454 100644 --- a/graphiti_core/utils/maintenance/node_operations.py +++ b/graphiti_core/utils/maintenance/node_operations.py @@ -17,12 +17,14 @@ import logging from time import time +from pydantic import BaseModel + from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather from graphiti_core.llm_client import LLMClient from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode from graphiti_core.prompts import prompt_library from graphiti_core.prompts.dedupe_nodes import NodeDuplicate -from graphiti_core.prompts.extract_nodes import ExtractedNodes, MissedEntities +from graphiti_core.prompts.extract_nodes import EntityClassification, ExtractedNodes, MissedEntities from graphiti_core.prompts.summarize_nodes import Summary from graphiti_core.utils.datetime_utils import utc_now @@ -114,6 +116,7 @@ async def extract_nodes( llm_client: LLMClient, episode: EpisodicNode, previous_episodes: list[EpisodicNode], + entity_types: dict[str, BaseModel] | None = None, ) -> list[EntityNode]: start = time() extracted_node_names: list[str] = [] @@ -144,15 +147,34 @@ async def extract_nodes( for entity in missing_entities: custom_prompt += f'\n{entity},' + node_classification_context = { + 'episode_content': episode.content, + 'previous_episodes': [ep.content for ep in previous_episodes], + 'extracted_entities': extracted_node_names, + 'entity_types': entity_types.keys(), + } + + node_classifications: dict[str, str | None] = {} + + if entity_types is not None: + llm_response = await llm_client.generate_response( + prompt_library.extract_nodes.classify_nodes(node_classification_context), + response_model=EntityClassification, + ) + node_classifications.update(llm_response.get('entity_classification', {})) + end = time() logger.debug(f'Extracted new nodes: {extracted_node_names} in {(end - start) * 1000} ms') # Convert the extracted data into EntityNode objects new_nodes = [] for name in extracted_node_names: + entity_type = node_classifications.get(name, None) + labels = ['Entity'] if entity_type is None else ['Entity', entity_type] + new_node = EntityNode( name=name, group_id=episode.group_id, - labels=['Entity'], + labels=labels, summary='', created_at=utc_now(), ) From 6eb199d9917f02b4de62b17556cec6cd4502c316 Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Mon, 3 Feb 2025 11:45:21 -0500 Subject: [PATCH 03/15] extract entity type properties --- graphiti_core/models/nodes/node_db_queries.py | 4 ++-- graphiti_core/nodes.py | 18 +++++++++++------- graphiti_core/prompts/summarize_nodes.py | 14 ++++++++++---- .../utils/maintenance/node_operations.py | 13 ++++++++++++- 4 files changed, 35 insertions(+), 14 deletions(-) diff --git a/graphiti_core/models/nodes/node_db_queries.py b/graphiti_core/models/nodes/node_db_queries.py index cdd020e0..38d60305 100644 --- a/graphiti_core/models/nodes/node_db_queries.py +++ b/graphiti_core/models/nodes/node_db_queries.py @@ -32,7 +32,7 @@ ENTITY_NODE_SAVE = """ MERGE (n:Entity {uuid: $uuid}) SET n:$($labels) - SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at} + SET n = $entity_data WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding) RETURN n.uuid AS uuid""" @@ -40,7 +40,7 @@ UNWIND $nodes AS node MERGE (n:Entity {uuid: node.uuid}) SET n:$(n.labels) - SET n = {uuid: node.uuid, name: node.name, group_id: node.group_id, summary: node.summary, created_at: node.created_at} + SET n = node WITH n, 
node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding) RETURN n.uuid AS uuid """ diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index b7c2ed6a..f213b5bc 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -269,16 +269,20 @@ async def generate_name_embedding(self, embedder: EmbedderClient): return self.name_embedding async def save(self, driver: AsyncDriver): + entity_data: dict[str, Any] = { + 'uuid': self.uuid, + 'name': self.name, + 'group_id': self.group_id, + 'summary': self.summary, + 'created_at': self.created_at, + } + + entity_data.update(self.properties) + result = await driver.execute_query( ENTITY_NODE_SAVE, - uuid=self.uuid, labels=self.labels + ['Entity'], - name=self.name, - group_id=self.group_id, - summary=self.summary, - name_embedding=self.name_embedding, - properties=self.properties, - created_at=self.created_at, + entity_data=entity_data, database_=DEFAULT_DATABASE, ) diff --git a/graphiti_core/prompts/summarize_nodes.py b/graphiti_core/prompts/summarize_nodes.py index e00e1bab..d1c77e51 100644 --- a/graphiti_core/prompts/summarize_nodes.py +++ b/graphiti_core/prompts/summarize_nodes.py @@ -24,7 +24,8 @@ class Summary(BaseModel): summary: str = Field( - ..., description='Summary containing the important information from both summaries' + ..., + description='Summary containing the important information about the entity. Under 500 words', ) @@ -68,7 +69,7 @@ def summarize_context(context: dict[str, Any]) -> list[Message]: return [ Message( role='system', - content='You are a helpful assistant that combines summaries with new conversation context.', + content='You are a helpful assistant that extracts entity properties from the provided text.', ), Message( role='user', @@ -81,16 +82,21 @@ def summarize_context(context: dict[str, Any]) -> list[Message]: Given the above MESSAGES and the following ENTITY name and ENTITY CONTEXT, create a summary for the ENTITY. Your summary must only use information from the provided MESSAGES and from the ENTITY CONTEXT. Your summary should also only contain information relevant to the - provided ENTITY. + provided ENTITY. Summaries must be under 500 words. - Summaries must be under 500 words. + In addition, extract any values for the provided entity properties based on their descriptions. 
{context['node_name']} + {context['node_summary']} + + + {json.dumps(context['properties'], indent=2)} + """, ), ] diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py index 83741454..a1c503f6 100644 --- a/graphiti_core/utils/maintenance/node_operations.py +++ b/graphiti_core/utils/maintenance/node_operations.py @@ -17,6 +17,7 @@ import logging from time import time +import pydantic from pydantic import BaseModel from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather @@ -267,6 +268,7 @@ async def resolve_extracted_node( existing_nodes: list[EntityNode], episode: EpisodicNode | None = None, previous_episodes: list[EpisodicNode] | None = None, + entity_types: dict[str, BaseModel] | None = None, ) -> tuple[EntityNode, dict[str, str]]: start = time() @@ -297,13 +299,22 @@ async def resolve_extracted_node( else [], } + entity_type_classes = tuple( + filter( + lambda x: x is not None, + [entity_types.get(entity_type) for entity_type in extracted_node.labels], + ) + ) + llm_response, node_summary_response = await semaphore_gather( llm_client.generate_response( prompt_library.dedupe_nodes.node(context), response_model=NodeDuplicate ), llm_client.generate_response( prompt_library.summarize_nodes.summarize_context(summary_context), - response_model=Summary, + response_model=pydantic.create_model( + 'EntityProperties', __base__=entity_type_classes + (Summary,) + ), ), ) From 1dd91d9c23b8094707f1b962d3ee411c753b4268 Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Tue, 4 Feb 2025 11:45:23 -0500 Subject: [PATCH 04/15] neo4j upgrade needed --- examples/podcast/podcast_runner.py | 10 ++- graphiti_core/graphiti.py | 1 + graphiti_core/prompts/extract_nodes.py | 2 +- graphiti_core/search/search_utils.py | 8 +- graphiti_core/utils/bulk_utils.py | 45 +++++----- .../utils/maintenance/node_operations.py | 82 +++++++++++-------- 6 files changed, 85 insertions(+), 63 deletions(-) diff --git a/examples/podcast/podcast_runner.py b/examples/podcast/podcast_runner.py index 0ee01eb2..b46672d6 100644 --- a/examples/podcast/podcast_runner.py +++ b/examples/podcast/podcast_runner.py @@ -20,6 +20,7 @@ import sys from dotenv import load_dotenv +from pydantic import BaseModel, Field from transcript_parser import parse_podcast_messages from graphiti_core import Graphiti @@ -53,10 +54,16 @@ def setup_logging(): return logger +class Person(BaseModel): + first_name: str | None = Field(..., description='First name') + last_name: str | None = Field(..., description='Last name') + occupation: str | None = Field(..., description="The person's work occupation") + + async def main(): setup_logging() client = Graphiti(neo4j_uri, neo4j_user, neo4j_password) - await clear_data(client.driver) + # await clear_data(client.driver) await client.build_indices_and_constraints() messages = parse_podcast_messages() @@ -67,6 +74,7 @@ async def main(): reference_time=message.actual_timestamp, source_description='Podcast Transcript', group_id='podcast', + entity_types={'Person': Person}, ) diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 39e7a22b..df12d98d 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -365,6 +365,7 @@ async def add_episode_endpoint(episode_data: EpisodeData): existing_nodes_lists, episode, previous_episodes, + entity_types, ), extract_edges( self.llm_client, episode, extracted_nodes, previous_episodes, group_id diff --git a/graphiti_core/prompts/extract_nodes.py 
b/graphiti_core/prompts/extract_nodes.py index a966d52e..3c1535c6 100644 --- a/graphiti_core/prompts/extract_nodes.py +++ b/graphiti_core/prompts/extract_nodes.py @@ -31,7 +31,7 @@ class MissedEntities(BaseModel): class EntityClassification(BaseModel): - entity_classification: dict[str, str | None] = Field( + entity_classification: str = Field( ..., description='Dictionary of entity classifications. Key is the entity name and value is the entity type', ) diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 854bd1ae..8d58c7a7 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -224,8 +224,8 @@ async def edge_similarity_search( query: LiteralString = ( """ - MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity) - """ + MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity) + """ + group_filter_query + filter_query + """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score @@ -343,7 +343,7 @@ async def node_fulltext_search( n.name_embedding AS name_embedding, n.created_at AS created_at, n.summary AS summary, - labels(n) AS labels, + labels(n) AS labels ORDER BY score DESC LIMIT $limit """, @@ -393,7 +393,7 @@ async def node_similarity_search( n.name_embedding AS name_embedding, n.created_at AS created_at, n.summary AS summary, - labels(n) AS labels, + labels(n) AS labels ORDER BY score DESC LIMIT $limit """, diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index 60ef43f0..7f3fc8b8 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -69,7 +69,7 @@ class RawEpisode(BaseModel): async def retrieve_previous_episodes_bulk( - driver: AsyncDriver, episodes: list[EpisodicNode] + driver: AsyncDriver, episodes: list[EpisodicNode] ) -> list[tuple[EpisodicNode, list[EpisodicNode]]]: previous_episodes_list = await semaphore_gather( *[ @@ -87,11 +87,11 @@ async def retrieve_previous_episodes_bulk( async def add_nodes_and_edges_bulk( - driver: AsyncDriver, - episodic_nodes: list[EpisodicNode], - episodic_edges: list[EpisodicEdge], - entity_nodes: list[EntityNode], - entity_edges: list[EntityEdge], + driver: AsyncDriver, + episodic_nodes: list[EpisodicNode], + episodic_edges: list[EpisodicEdge], + entity_nodes: list[EntityNode], + entity_edges: list[EntityEdge], ): async with driver.session() as session: await session.execute_write( @@ -100,11 +100,11 @@ async def add_nodes_and_edges_bulk( async def add_nodes_and_edges_bulk_tx( - tx: AsyncManagedTransaction, - episodic_nodes: list[EpisodicNode], - episodic_edges: list[EpisodicEdge], - entity_nodes: list[EntityNode], - entity_edges: list[EntityEdge], + tx: AsyncManagedTransaction, + episodic_nodes: list[EpisodicNode], + episodic_edges: list[EpisodicEdge], + entity_nodes: list[EntityNode], + entity_edges: list[EntityEdge], ): episodes = [dict(episode) for episode in episodic_nodes] for episode in episodes: @@ -112,6 +112,7 @@ async def add_nodes_and_edges_bulk_tx( nodes = [dict(entity) for entity in entity_nodes] for node in nodes: node['labels'] = list(set(node['labels'] + ['Entity'])) + await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes) await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes) await tx.run(EPISODIC_EDGE_SAVE_BULK, episodic_edges=[dict(edge) for edge in episodic_edges]) @@ -119,7 +120,7 @@ async def add_nodes_and_edges_bulk_tx( async def extract_nodes_and_edges_bulk( - llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]] + llm_client: LLMClient, 
episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]] ) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]: extracted_nodes_bulk = await semaphore_gather( *[ @@ -162,16 +163,16 @@ async def extract_nodes_and_edges_bulk( async def dedupe_nodes_bulk( - driver: AsyncDriver, - llm_client: LLMClient, - extracted_nodes: list[EntityNode], + driver: AsyncDriver, + llm_client: LLMClient, + extracted_nodes: list[EntityNode], ) -> tuple[list[EntityNode], dict[str, str]]: # Compress nodes nodes, uuid_map = node_name_match(extracted_nodes) compressed_nodes, compressed_map = await compress_nodes(llm_client, nodes, uuid_map) - node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)] + node_chunks = [nodes[i: i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)] existing_nodes_chunks: list[list[EntityNode]] = list( await semaphore_gather( @@ -198,13 +199,13 @@ async def dedupe_nodes_bulk( async def dedupe_edges_bulk( - driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge] + driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge] ) -> list[EntityEdge]: # First compress edges compressed_edges = await compress_edges(llm_client, extracted_edges) edge_chunks = [ - compressed_edges[i : i + CHUNK_SIZE] for i in range(0, len(compressed_edges), CHUNK_SIZE) + compressed_edges[i: i + CHUNK_SIZE] for i in range(0, len(compressed_edges), CHUNK_SIZE) ] relevant_edges_chunks: list[list[EntityEdge]] = list( @@ -240,7 +241,7 @@ def node_name_match(nodes: list[EntityNode]) -> tuple[list[EntityNode], dict[str async def compress_nodes( - llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str] + llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str] ) -> tuple[list[EntityNode], dict[str, str]]: # We want to first compress the nodes by deduplicating nodes across each of the episodes added in bulk if len(nodes) == 0: @@ -361,9 +362,9 @@ def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]): async def extract_edge_dates_bulk( - llm_client: LLMClient, - extracted_edges: list[EntityEdge], - episode_pairs: list[tuple[EpisodicNode, list[EpisodicNode]]], + llm_client: LLMClient, + extracted_edges: list[EntityEdge], + episode_pairs: list[tuple[EpisodicNode, list[EpisodicNode]]], ) -> list[EntityEdge]: edges: list[EntityEdge] = [] # confirm that all of our edges have at least one episode diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py index a1c503f6..1f4daecd 100644 --- a/graphiti_core/utils/maintenance/node_operations.py +++ b/graphiti_core/utils/maintenance/node_operations.py @@ -14,6 +14,7 @@ limitations under the License. 
""" +import ast import logging from time import time @@ -33,10 +34,10 @@ async def extract_message_nodes( - llm_client: LLMClient, - episode: EpisodicNode, - previous_episodes: list[EpisodicNode], - custom_prompt='', + llm_client: LLMClient, + episode: EpisodicNode, + previous_episodes: list[EpisodicNode], + custom_prompt='', ) -> list[str]: # Prepare context for LLM context = { @@ -54,10 +55,10 @@ async def extract_message_nodes( async def extract_text_nodes( - llm_client: LLMClient, - episode: EpisodicNode, - previous_episodes: list[EpisodicNode], - custom_prompt='', + llm_client: LLMClient, + episode: EpisodicNode, + previous_episodes: list[EpisodicNode], + custom_prompt='', ) -> list[str]: # Prepare context for LLM context = { @@ -75,7 +76,7 @@ async def extract_text_nodes( async def extract_json_nodes( - llm_client: LLMClient, episode: EpisodicNode, custom_prompt='' + llm_client: LLMClient, episode: EpisodicNode, custom_prompt='' ) -> list[str]: # Prepare context for LLM context = { @@ -93,10 +94,10 @@ async def extract_json_nodes( async def extract_nodes_reflexion( - llm_client: LLMClient, - episode: EpisodicNode, - previous_episodes: list[EpisodicNode], - node_names: list[str], + llm_client: LLMClient, + episode: EpisodicNode, + previous_episodes: list[EpisodicNode], + node_names: list[str], ) -> list[str]: # Prepare context for LLM context = { @@ -114,10 +115,10 @@ async def extract_nodes_reflexion( async def extract_nodes( - llm_client: LLMClient, - episode: EpisodicNode, - previous_episodes: list[EpisodicNode], - entity_types: dict[str, BaseModel] | None = None, + llm_client: LLMClient, + episode: EpisodicNode, + previous_episodes: list[EpisodicNode], + entity_types: dict[str, BaseModel] | None = None, ) -> list[EntityNode]: start = time() extracted_node_names: list[str] = [] @@ -162,7 +163,8 @@ async def extract_nodes( prompt_library.extract_nodes.classify_nodes(node_classification_context), response_model=EntityClassification, ) - node_classifications.update(llm_response.get('entity_classification', {})) + response_string = llm_response.get('entity_classification', '{}') + node_classifications.update(ast.literal_eval(response_string)) end = time() logger.debug(f'Extracted new nodes: {extracted_node_names} in {(end - start) * 1000} ms') @@ -186,9 +188,9 @@ async def extract_nodes( async def dedupe_extracted_nodes( - llm_client: LLMClient, - extracted_nodes: list[EntityNode], - existing_nodes: list[EntityNode], + llm_client: LLMClient, + extracted_nodes: list[EntityNode], + existing_nodes: list[EntityNode], ) -> tuple[list[EntityNode], dict[str, str]]: start = time() @@ -236,11 +238,12 @@ async def dedupe_extracted_nodes( async def resolve_extracted_nodes( - llm_client: LLMClient, - extracted_nodes: list[EntityNode], - existing_nodes_lists: list[list[EntityNode]], - episode: EpisodicNode | None = None, - previous_episodes: list[EpisodicNode] | None = None, + llm_client: LLMClient, + extracted_nodes: list[EntityNode], + existing_nodes_lists: list[list[EntityNode]], + episode: EpisodicNode | None = None, + previous_episodes: list[EpisodicNode] | None = None, + entity_types: dict[str, str] | None = None, ) -> tuple[list[EntityNode], dict[str, str]]: uuid_map: dict[str, str] = {} resolved_nodes: list[EntityNode] = [] @@ -248,7 +251,12 @@ async def resolve_extracted_nodes( await semaphore_gather( *[ resolve_extracted_node( - llm_client, extracted_node, existing_nodes, episode, previous_episodes + llm_client, + extracted_node, + existing_nodes, + episode, + previous_episodes, + 
entity_types, ) for extracted_node, existing_nodes in zip(extracted_nodes, existing_nodes_lists) ] @@ -263,12 +271,12 @@ async def resolve_extracted_nodes( async def resolve_extracted_node( - llm_client: LLMClient, - extracted_node: EntityNode, - existing_nodes: list[EntityNode], - episode: EpisodicNode | None = None, - previous_episodes: list[EpisodicNode] | None = None, - entity_types: dict[str, BaseModel] | None = None, + llm_client: LLMClient, + extracted_node: EntityNode, + existing_nodes: list[EntityNode], + episode: EpisodicNode | None = None, + previous_episodes: list[EpisodicNode] | None = None, + entity_types: dict[str, BaseModel] | None = None, ) -> tuple[EntityNode, dict[str, str]]: start = time() @@ -297,6 +305,7 @@ async def resolve_extracted_node( 'previous_episodes': [ep.content for ep in previous_episodes] if previous_episodes is not None else [], + 'properties': [] } entity_type_classes = tuple( @@ -305,6 +314,9 @@ async def resolve_extracted_node( [entity_types.get(entity_type) for entity_type in extracted_node.labels], ) ) + for entity_type in entity_type_classes: + for field_name in entity_type.__fields__.keys(): + summary_context['properties'].append(field_name) llm_response, node_summary_response = await semaphore_gather( llm_client.generate_response( @@ -350,8 +362,8 @@ async def resolve_extracted_node( async def dedupe_node_list( - llm_client: LLMClient, - nodes: list[EntityNode], + llm_client: LLMClient, + nodes: list[EntityNode], ) -> tuple[list[EntityNode], dict[str, str]]: start = time() From 668d317ceb1c9a2f8642e6a464d0019158e7c8fa Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Fri, 7 Feb 2025 13:27:19 -0500 Subject: [PATCH 05/15] add entity types --- .../msc_eval.py | 142 ------------------ .../msc_runner.py | 89 ----------- .../parse_msc_messages.py | 85 ----------- graphiti_core/models/nodes/node_db_queries.py | 2 +- graphiti_core/nodes.py | 10 +- graphiti_core/prompts/extract_nodes.py | 1 + graphiti_core/prompts/summarize_nodes.py | 10 +- graphiti_core/search/search_utils.py | 14 +- graphiti_core/utils/bulk_utils.py | 46 +++--- .../utils/maintenance/node_operations.py | 85 ++++++----- 10 files changed, 90 insertions(+), 394 deletions(-) delete mode 100644 examples/multi_session_conversation_memory/msc_eval.py delete mode 100644 examples/multi_session_conversation_memory/msc_runner.py delete mode 100644 examples/multi_session_conversation_memory/parse_msc_messages.py diff --git a/examples/multi_session_conversation_memory/msc_eval.py b/examples/multi_session_conversation_memory/msc_eval.py deleted file mode 100644 index db61482b..00000000 --- a/examples/multi_session_conversation_memory/msc_eval.py +++ /dev/null @@ -1,142 +0,0 @@ -""" -Copyright 2024, Zep Software, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-""" - -import asyncio -import csv -import logging -import os -import sys -from time import time - -from dotenv import load_dotenv - -from examples.multi_session_conversation_memory.parse_msc_messages import conversation_q_and_a -from graphiti_core import Graphiti -from graphiti_core.helpers import semaphore_gather -from graphiti_core.prompts import prompt_library -from graphiti_core.search.search_config_recipes import COMBINED_HYBRID_SEARCH_RRF - -load_dotenv() - -neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687' -neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j' -neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password' - - -def setup_logging(): - # Create a logger - logger = logging.getLogger() - logger.setLevel(logging.INFO) # Set the logging level to INFO - - # Create console handler and set level to INFO - console_handler = logging.StreamHandler(sys.stdout) - console_handler.setLevel(logging.INFO) - - # Create formatter - formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') - - # Add formatter to console handler - console_handler.setFormatter(formatter) - - # Add console handler to logger - logger.addHandler(console_handler) - - return logger - - -async def evaluate_qa(graphiti: Graphiti, group_id: str, query: str, answer: str): - search_start = time() - results = await graphiti._search( - query, - COMBINED_HYBRID_SEARCH_RRF, - group_ids=[str(group_id)], - ) - search_end = time() - search_duration = search_end - search_start - - facts = [edge.fact for edge in results.edges] - entity_summaries = [node.name + ': ' + node.summary for node in results.nodes] - context = { - 'facts': facts, - 'entity_summaries': entity_summaries, - 'query': 'Bob: ' + query, - } - - llm_response = await graphiti.llm_client.generate_response( - prompt_library.eval.qa_prompt(context) - ) - response = llm_response.get('ANSWER', '') - - eval_context = { - 'query': 'Bob: ' + query, - 'answer': 'Alice: ' + answer, - 'response': 'Alice: ' + response, - } - - eval_llm_response = await graphiti.llm_client.generate_response( - prompt_library.eval.eval_prompt(eval_context) - ) - eval_response = 1 if eval_llm_response.get('is_correct', False) else 0 - - return { - 'Group id': group_id, - 'Question': query, - 'Answer': answer, - 'Response': response, - 'Score': eval_response, - 'Search Duration (ms)': search_duration * 1000, - } - - -async def main(): - setup_logging() - graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password) - - fields = [ - 'Group id', - 'Question', - 'Answer', - 'Response', - 'Score', - 'Search Duration (ms)', - ] - with open('../data/msc_eval.csv', 'w', newline='') as file: - writer = csv.DictWriter(file, fieldnames=fields) - writer.writeheader() - - qa = conversation_q_and_a()[0:500] - i = 0 - while i < 500: - qa_chunk = qa[i : i + 20] - group_ids = range(len(qa))[i : i + 20] - results = list( - await semaphore_gather( - *[ - evaluate_qa(graphiti, str(group_id), query, answer) - for group_id, (query, answer) in zip(group_ids, qa_chunk) - ] - ) - ) - - with open('../data/msc_eval.csv', 'a', newline='') as file: - writer = csv.DictWriter(file, fieldnames=fields) - writer.writerows(results) - i += 20 - - await graphiti.close() - - -asyncio.run(main()) diff --git a/examples/multi_session_conversation_memory/msc_runner.py b/examples/multi_session_conversation_memory/msc_runner.py deleted file mode 100644 index 2cef9c58..00000000 --- a/examples/multi_session_conversation_memory/msc_runner.py +++ /dev/null @@ -1,89 +0,0 @@ -""" -Copyright 2024, 
Zep Software, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -import asyncio -import logging -import os -import sys - -from dotenv import load_dotenv - -from examples.multi_session_conversation_memory.parse_msc_messages import ( - ParsedMscMessage, - parse_msc_messages, -) -from graphiti_core import Graphiti -from graphiti_core.helpers import semaphore_gather - -load_dotenv() - -neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687' -neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j' -neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password' - - -def setup_logging(): - # Create a logger - logger = logging.getLogger() - logger.setLevel(logging.INFO) # Set the logging level to INFO - - # Create console handler and set level to INFO - console_handler = logging.StreamHandler(sys.stdout) - console_handler.setLevel(logging.INFO) - - # Create formatter - formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') - - # Add formatter to console handler - console_handler.setFormatter(formatter) - - # Add console handler to logger - logger.addHandler(console_handler) - - return logger - - -async def add_conversation(graphiti: Graphiti, group_id: str, messages: list[ParsedMscMessage]): - for i, message in enumerate(messages): - await graphiti.add_episode( - name=f'Message {group_id + "-" + str(i)}', - episode_body=f'{message.speaker_name}: {message.content}', - reference_time=message.actual_timestamp, - source_description='Multi-Session Conversation', - group_id=group_id, - ) - - -async def main(): - setup_logging() - graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password) - msc_messages = parse_msc_messages() - i = 0 - while i < len(msc_messages): - msc_message_slice = msc_messages[i : i + 10] - group_ids = range(len(msc_messages))[i : i + 10] - - await semaphore_gather( - *[ - add_conversation(graphiti, str(group_id), messages) - for group_id, messages in zip(group_ids, msc_message_slice) - ] - ) - - i += 10 - - -asyncio.run(main()) diff --git a/examples/multi_session_conversation_memory/parse_msc_messages.py b/examples/multi_session_conversation_memory/parse_msc_messages.py deleted file mode 100644 index 4c5bd219..00000000 --- a/examples/multi_session_conversation_memory/parse_msc_messages.py +++ /dev/null @@ -1,85 +0,0 @@ -""" -Copyright 2024, Zep Software, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-""" - -import json -from datetime import datetime, timezone - -from pydantic import BaseModel - - -class ParsedMscMessage(BaseModel): - speaker_name: str - actual_timestamp: datetime - content: str - group_id: str - - -def parse_msc_messages() -> list[list[ParsedMscMessage]]: - msc_messages: list[list[ParsedMscMessage]] = [] - speakers = ['Alice', 'Bob'] - - with open('../data/msc.jsonl') as file: - data = [json.loads(line) for line in file] - for i, conversation in enumerate(data): - messages: list[ParsedMscMessage] = [] - for previous_dialog in conversation['previous_dialogs']: - dialog = previous_dialog['dialog'] - speaker_idx = 0 - - for utterance in dialog: - content = utterance['text'] - messages.append( - ParsedMscMessage( - speaker_name=speakers[speaker_idx], - content=content, - actual_timestamp=datetime.now(timezone.utc), - group_id=str(i), - ) - ) - speaker_idx += 1 - speaker_idx %= 2 - - dialog = conversation['dialog'] - speaker_idx = 0 - for utterance in dialog: - content = utterance['text'] - messages.append( - ParsedMscMessage( - speaker_name=speakers[speaker_idx], - content=content, - actual_timestamp=datetime.now(timezone.utc), - group_id=str(i), - ) - ) - speaker_idx += 1 - speaker_idx %= 2 - - msc_messages.append(messages) - - return msc_messages - - -def conversation_q_and_a() -> list[tuple[str, str]]: - with open('../data/msc.jsonl') as file: - data = [json.loads(line) for line in file] - - qa: list[tuple[str, str]] = [] - for conversation in data: - query = conversation['self_instruct']['B'] - answer = conversation['self_instruct']['A'] - - qa.append((query, answer)) - return qa diff --git a/graphiti_core/models/nodes/node_db_queries.py b/graphiti_core/models/nodes/node_db_queries.py index 38d60305..a34c5cb5 100644 --- a/graphiti_core/models/nodes/node_db_queries.py +++ b/graphiti_core/models/nodes/node_db_queries.py @@ -39,7 +39,7 @@ ENTITY_NODE_SAVE_BULK = """ UNWIND $nodes AS node MERGE (n:Entity {uuid: node.uuid}) - SET n:$(n.labels) + SET n:$(node.labels) SET n = node WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding) RETURN n.uuid AS uuid diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index f213b5bc..a898f2f5 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -255,8 +255,8 @@ async def get_by_group_ids( class EntityNode(Node): name_embedding: list[float] | None = Field(default=None, description='embedding of the name') summary: str = Field(description='regional summary of surrounding edges', default_factory=str) - properties: dict[str, Any] | None = Field( - default=None, description='Additional properties of the node. Dependent on node labels' + attributes: dict[str, Any] | None = Field( + default=None, description='Additional attributes of the node. 
Dependent on node labels' ) async def generate_name_embedding(self, embedder: EmbedderClient): @@ -277,7 +277,7 @@ async def save(self, driver: AsyncDriver): 'created_at': self.created_at, } - entity_data.update(self.properties) + entity_data.update(self.attributes) result = await driver.execute_query( ENTITY_NODE_SAVE, @@ -303,6 +303,7 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): n.created_at AS created_at, n.summary AS summary, labels(n) AS labels, + properties(n) AS attributes """, uuid=uuid, database_=DEFAULT_DATABASE, @@ -329,6 +330,7 @@ async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]): n.created_at AS created_at, n.summary AS summary, labels(n) AS labels, + properties(n) AS attributes """, uuids=uuids, database_=DEFAULT_DATABASE, @@ -364,6 +366,7 @@ async def get_by_group_ids( n.created_at AS created_at, n.summary AS summary, labels(n) AS labels, + properties(n) AS attributes ORDER BY n.uuid DESC """ + limit_query, @@ -518,6 +521,7 @@ def get_entity_node_from_record(record: Any) -> EntityNode: labels=record['labels'], created_at=record['created_at'].to_native(), summary=record['summary'], + attributes=record['attributes'], ) diff --git a/graphiti_core/prompts/extract_nodes.py b/graphiti_core/prompts/extract_nodes.py index 3c1535c6..ebbaec87 100644 --- a/graphiti_core/prompts/extract_nodes.py +++ b/graphiti_core/prompts/extract_nodes.py @@ -75,6 +75,7 @@ def extract_message(context: dict[str, Any]) -> list[Message]: 4. DO NOT create nodes for temporal information like dates, times or years (these will be added to edges later). 5. Be as explicit as possible in your node names, using full names. 6. DO NOT extract entities mentioned only in PREVIOUS MESSAGES, those messages are only to provide context. +7. Extract preferences as their own nodes """ return [ Message(role='system', content=sys_prompt), diff --git a/graphiti_core/prompts/summarize_nodes.py b/graphiti_core/prompts/summarize_nodes.py index d1c77e51..0a880a82 100644 --- a/graphiti_core/prompts/summarize_nodes.py +++ b/graphiti_core/prompts/summarize_nodes.py @@ -80,8 +80,8 @@ def summarize_context(context: dict[str, Any]) -> list[Message]: {json.dumps(context['episode_content'], indent=2)} - Given the above MESSAGES and the following ENTITY name and ENTITY CONTEXT, create a summary for the ENTITY. Your summary must only use - information from the provided MESSAGES and from the ENTITY CONTEXT. Your summary should also only contain information relevant to the + Given the above MESSAGES and the following ENTITY name, create a summary for the ENTITY. Your summary must only use + information from the provided MESSAGES. Your summary should also only contain information relevant to the provided ENTITY. Summaries must be under 500 words. In addition, extract any values for the provided entity properties based on their descriptions. 
@@ -94,9 +94,9 @@ def summarize_context(context: dict[str, Any]) -> list[Message]:
 
         {context['node_summary']}
-
-        {json.dumps(context['properties'], indent=2)}
-
+
+        {json.dumps(context['attributes'], indent=2)}
+
         """,
         ),
     ]
diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py
index 8d58c7a7..2710bf07 100644
--- a/graphiti_core/search/search_utils.py
+++ b/graphiti_core/search/search_utils.py
@@ -98,7 +98,8 @@ async def get_mentioned_nodes(
             n.name_embedding AS name_embedding,
             n.created_at AS created_at,
             n.summary AS summary,
-            labels(n) AS labels
+            labels(n) AS labels,
+            properties(n) AS properties
         """,
         uuids=episode_uuids,
         database_=DEFAULT_DATABASE,
@@ -224,8 +225,8 @@ async def edge_similarity_search(
 
     query: LiteralString = (
         """
-        MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
-        """
+    MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
+    """
         + group_filter_query
         + filter_query
         + """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
@@ -343,7 +344,8 @@ async def node_fulltext_search(
             n.name_embedding AS name_embedding,
             n.created_at AS created_at,
             n.summary AS summary,
-            labels(n) AS labels
+            labels(n) AS labels,
+            properties(n) AS properties
         ORDER BY score DESC
         LIMIT $limit
         """,
@@ -393,7 +395,8 @@ async def node_similarity_search(
             n.name_embedding AS name_embedding,
             n.created_at AS created_at,
             n.summary AS summary,
-            labels(n) AS labels
+            labels(n) AS labels,
+            properties(n) AS properties
         ORDER BY score DESC
         LIMIT $limit
         """,
@@ -432,6 +435,7 @@ async def node_bfs_search(
             n.created_at AS created_at,
             n.summary AS summary,
             labels(n) AS labels,
+            properties(n) AS properties
             LIMIT $limit
         """,
         bfs_origin_node_uuids=bfs_origin_node_uuids,
diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py
index 7f3fc8b8..3877b0e1 100644
--- a/graphiti_core/utils/bulk_utils.py
+++ b/graphiti_core/utils/bulk_utils.py
@@ -69,7 +69,7 @@ class RawEpisode(BaseModel):
 
 
 async def retrieve_previous_episodes_bulk(
-        driver: AsyncDriver, episodes: list[EpisodicNode]
+    driver: AsyncDriver, episodes: list[EpisodicNode]
 ) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
     previous_episodes_list = await semaphore_gather(
         *[
@@ -87,11 +87,11 @@ async def retrieve_previous_episodes_bulk(
 
 
 async def add_nodes_and_edges_bulk(
-        driver: AsyncDriver,
-        episodic_nodes: list[EpisodicNode],
-        episodic_edges: list[EpisodicEdge],
-        entity_nodes: list[EntityNode],
-        entity_edges: list[EntityEdge],
+    driver: AsyncDriver,
+    episodic_nodes: list[EpisodicNode],
+    episodic_edges: list[EpisodicEdge],
+    entity_nodes: list[EntityNode],
+    entity_edges: list[EntityEdge],
 ):
     async with driver.session() as session:
         await session.execute_write(
@@ -100,11 +100,11 @@ async def add_nodes_and_edges_bulk(
 
 
 async def add_nodes_and_edges_bulk_tx(
-        tx: AsyncManagedTransaction,
-        episodic_nodes: list[EpisodicNode],
-        episodic_edges: list[EpisodicEdge],
-        entity_nodes: list[EntityNode],
-        entity_edges: list[EntityEdge],
+    tx: AsyncManagedTransaction,
+    episodic_nodes: list[EpisodicNode],
+    episodic_edges: list[EpisodicEdge],
+    entity_nodes: list[EntityNode],
+    entity_edges: list[EntityEdge],
 ):
     episodes = [dict(episode) for episode in episodic_nodes]
     for episode in episodes:
@@ -112,7 +112,7 @@ async def add_nodes_and_edges_bulk_tx(
     nodes = [dict(entity) for entity in entity_nodes]
     for node in nodes:
         node['labels'] = list(set(node['labels'] + ['Entity']))
-    
+
     await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
     await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
     await tx.run(EPISODIC_EDGE_SAVE_BULK, episodic_edges=[dict(edge) for edge in episodic_edges])
@@ -120,7 +120,7 @@ async def add_nodes_and_edges_bulk_tx(
 
 
 async def extract_nodes_and_edges_bulk(
-        llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
+    llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
 ) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
     extracted_nodes_bulk = await semaphore_gather(
         *[
@@ -163,16 +163,16 @@ async def extract_nodes_and_edges_bulk(
 
 
 async def dedupe_nodes_bulk(
-        driver: AsyncDriver,
-        llm_client: LLMClient,
-        extracted_nodes: list[EntityNode],
+    driver: AsyncDriver,
+    llm_client: LLMClient,
+    extracted_nodes: list[EntityNode],
 ) -> tuple[list[EntityNode], dict[str, str]]:
     # Compress nodes
     nodes, uuid_map = node_name_match(extracted_nodes)
 
     compressed_nodes, compressed_map = await compress_nodes(llm_client, nodes, uuid_map)
 
-    node_chunks = [nodes[i: i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
+    node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
 
     existing_nodes_chunks: list[list[EntityNode]] = list(
         await semaphore_gather(
@@ -199,13 +199,13 @@ async def dedupe_nodes_bulk(
 
 
 async def dedupe_edges_bulk(
-        driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
+    driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
 ) -> list[EntityEdge]:
     # First compress edges
     compressed_edges = await compress_edges(llm_client, extracted_edges)
 
     edge_chunks = [
-        compressed_edges[i: i + CHUNK_SIZE] for i in range(0, len(compressed_edges), CHUNK_SIZE)
+        compressed_edges[i : i + CHUNK_SIZE] for i in range(0, len(compressed_edges), CHUNK_SIZE)
     ]
 
     relevant_edges_chunks: list[list[EntityEdge]] = list(
@@ -241,7 +241,7 @@ def node_name_match(nodes: list[EntityNode]) -> tuple[list[EntityNode], dict[str
 
 
 async def compress_nodes(
-        llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
+    llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
 ) -> tuple[list[EntityNode], dict[str, str]]:
     # We want to first compress the nodes by deduplicating nodes across each of the episodes added in bulk
     if len(nodes) == 0:
@@ -362,9 +362,9 @@ def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]):
 
 
 async def extract_edge_dates_bulk(
-        llm_client: LLMClient,
-        extracted_edges: list[EntityEdge],
-        episode_pairs: list[tuple[EpisodicNode, list[EpisodicNode]]],
+    llm_client: LLMClient,
+    extracted_edges: list[EntityEdge],
+    episode_pairs: list[tuple[EpisodicNode, list[EpisodicNode]]],
 ) -> list[EntityEdge]:
     edges: list[EntityEdge] = []
 
     # confirm that all of our edges have at least one episode
diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py
index 1f4daecd..95ee72ec 100644
--- a/graphiti_core/utils/maintenance/node_operations.py
+++ b/graphiti_core/utils/maintenance/node_operations.py
@@ -34,10 +34,10 @@
 
 
 async def extract_message_nodes(
-        llm_client: LLMClient,
-        episode: EpisodicNode,
-        previous_episodes: list[EpisodicNode],
-        custom_prompt='',
+    llm_client: LLMClient,
+    episode: EpisodicNode,
+    previous_episodes: list[EpisodicNode],
+    custom_prompt='',
 ) -> list[str]:
     # Prepare context for LLM
     context = {
@@ -55,10 +55,10 @@ async def extract_message_nodes(
 
 
 async def extract_text_nodes(
-        llm_client: LLMClient,
-        episode: EpisodicNode,
-        previous_episodes: list[EpisodicNode],
-        custom_prompt='',
+    llm_client: LLMClient,
+    episode: EpisodicNode,
+    previous_episodes: list[EpisodicNode],
+    custom_prompt='',
 ) -> list[str]:
     # Prepare context for LLM
     context = {
@@ -76,7 +76,7 @@ async def extract_text_nodes(
 
 
 async def extract_json_nodes(
-        llm_client: LLMClient, episode: EpisodicNode, custom_prompt=''
+    llm_client: LLMClient, episode: EpisodicNode, custom_prompt=''
 ) -> list[str]:
     # Prepare context for LLM
     context = {
@@ -94,10 +94,10 @@ async def extract_json_nodes(
 
 
 async def extract_nodes_reflexion(
-        llm_client: LLMClient,
-        episode: EpisodicNode,
-        previous_episodes: list[EpisodicNode],
-        node_names: list[str],
+    llm_client: LLMClient,
+    episode: EpisodicNode,
+    previous_episodes: list[EpisodicNode],
+    node_names: list[str],
 ) -> list[str]:
     # Prepare context for LLM
     context = {
@@ -115,10 +115,10 @@ async def extract_nodes_reflexion(
 
 
 async def extract_nodes(
-        llm_client: LLMClient,
-        episode: EpisodicNode,
-        previous_episodes: list[EpisodicNode],
-        entity_types: dict[str, BaseModel] | None = None,
+    llm_client: LLMClient,
+    episode: EpisodicNode,
+    previous_episodes: list[EpisodicNode],
+    entity_types: dict[str, BaseModel] | None = None,
 ) -> list[EntityNode]:
     start = time()
     extracted_node_names: list[str] = []
@@ -188,9 +188,9 @@ async def extract_nodes(
 
 
 async def dedupe_extracted_nodes(
-        llm_client: LLMClient,
-        extracted_nodes: list[EntityNode],
-        existing_nodes: list[EntityNode],
+    llm_client: LLMClient,
+    extracted_nodes: list[EntityNode],
+    existing_nodes: list[EntityNode],
 ) -> tuple[list[EntityNode], dict[str, str]]:
     start = time()
 
@@ -238,12 +238,12 @@ async def dedupe_extracted_nodes(
 
 
 async def resolve_extracted_nodes(
-        llm_client: LLMClient,
-        extracted_nodes: list[EntityNode],
-        existing_nodes_lists: list[list[EntityNode]],
-        episode: EpisodicNode | None = None,
-        previous_episodes: list[EpisodicNode] | None = None,
-        entity_types: dict[str, str] | None = None,
+    llm_client: LLMClient,
+    extracted_nodes: list[EntityNode],
+    existing_nodes_lists: list[list[EntityNode]],
+    episode: EpisodicNode | None = None,
+    previous_episodes: list[EpisodicNode] | None = None,
+    entity_types: dict[str, str] | None = None,
 ) -> tuple[list[EntityNode], dict[str, str]]:
     uuid_map: dict[str, str] = {}
     resolved_nodes: list[EntityNode] = []
@@ -271,12 +271,12 @@ async def resolve_extracted_nodes(
 
 
 async def resolve_extracted_node(
-        llm_client: LLMClient,
-        extracted_node: EntityNode,
-        existing_nodes: list[EntityNode],
-        episode: EpisodicNode | None = None,
-        previous_episodes: list[EpisodicNode] | None = None,
-        entity_types: dict[str, BaseModel] | None = None,
+    llm_client: LLMClient,
+    extracted_node: EntityNode,
+    existing_nodes: list[EntityNode],
+    episode: EpisodicNode | None = None,
+    previous_episodes: list[EpisodicNode] | None = None,
+    entity_types: dict[str, BaseModel] | None = None,
 ) -> tuple[EntityNode, dict[str, str]]:
     start = time()
 
@@ -305,7 +305,7 @@ async def resolve_extracted_node(
         'previous_episodes': [ep.content for ep in previous_episodes]
         if previous_episodes is not None
         else [],
-        'properties': []
+        'attributes': [],
     }
 
     entity_type_classes = tuple(
         filter(
             lambda x: x is not None,
             [entity_types.get(entity_type) for entity_type in extracted_node.labels],
         )
     )
 
     for entity_type in entity_type_classes:
         for field_name in entity_type.__fields__.keys():
-            summary_context['properties'].append(field_name)
+            summary_context['attributes'].append(field_name)
 
-    llm_response, node_summary_response = await semaphore_gather(
+    entity_attributes_model = pydantic.create_model(
+        'EntityAttributes', __base__=entity_type_classes + (Summary,)
+    )
+
+    llm_response, node_attributes_response = await semaphore_gather(
         llm_client.generate_response(
             prompt_library.dedupe_nodes.node(context), response_model=NodeDuplicate
         ),
         llm_client.generate_response(
             prompt_library.summarize_nodes.summarize_context(summary_context),
-            response_model=pydantic.create_model(
-                'EntityProperties', __base__=entity_type_classes + (Summary,)
-            ),
+            response_model=entity_attributes_model,
         ),
     )
 
-    extracted_node.summary = node_summary_response.get('summary', '')
+    extracted_node.summary = node_attributes_response.get('summary', '')
+    extracted_node.attributes.update(node_attributes_response)
 
     is_duplicate: bool = llm_response.get('is_duplicate', False)
     uuid: str | None = llm_response.get('uuid', None)
@@ -362,8 +365,8 @@ async def resolve_extracted_node(
 
 
 async def dedupe_node_list(
-        llm_client: LLMClient,
-        nodes: list[EntityNode],
+    llm_client: LLMClient,
+    nodes: list[EntityNode],
 ) -> tuple[list[EntityNode], dict[str, str]]:
     start = time()
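A note on the dynamic-model pattern this commit settles on: the caller-supplied entity-type classes are merged with the summary response model, so one structured LLM call returns the node summary and the typed attribute fields together. A minimal sketch of the idea, assuming pydantic v2; Person and Summary below are illustrative stand-ins, not the library's own definitions:

import pydantic
from pydantic import BaseModel, Field


class Summary(BaseModel):  # stand-in for the prompt library's Summary model
    summary: str = Field(description='summary of the node')


class Person(BaseModel):  # hypothetical caller-defined entity type
    occupation: str | None = None


# Merge the entity type with Summary; the result can serve as a response_model.
EntityAttributes = pydantic.create_model('EntityAttributes', __base__=(Person, Summary))

assert set(EntityAttributes.model_fields) == {'occupation', 'summary'}

Because __base__ accepts a tuple, any number of entity-type classes can be combined per node, which is why the diff keeps entity_type_classes as a tuple and appends (Summary,) to it.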
From 7f7867849d32e02ca464d91c2a53469b112af96e Mon Sep 17 00:00:00 2001
From: prestonrasmussen
Date: Tue, 11 Feb 2025 11:41:41 -0500
Subject: [PATCH 06/15] update typing

---
 examples/podcast/podcast_runner.py                 | 2 +-
 graphiti_core/nodes.py                             | 6 +++---
 graphiti_core/utils/maintenance/node_operations.py | 6 +++---
 3 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/examples/podcast/podcast_runner.py b/examples/podcast/podcast_runner.py
index b46672d6..d5511375 100644
--- a/examples/podcast/podcast_runner.py
+++ b/examples/podcast/podcast_runner.py
@@ -63,7 +63,7 @@ class Person(BaseModel):
 async def main():
     setup_logging()
     client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
-    # await clear_data(client.driver)
+    await clear_data(client.driver)
     await client.build_indices_and_constraints()
 
     messages = parse_podcast_messages()
diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py
index a898f2f5..bd1da026 100644
--- a/graphiti_core/nodes.py
+++ b/graphiti_core/nodes.py
@@ -255,8 +255,8 @@ async def get_by_group_ids(
 class EntityNode(Node):
     name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
     summary: str = Field(description='regional summary of surrounding edges', default_factory=str)
-    attributes: dict[str, Any] | None = Field(
-        default=None, description='Additional attributes of the node. Dependent on node labels'
+    attributes: dict[str, Any] = Field(
+        default={}, description='Additional attributes of the node. Dependent on node labels'
     )
 
     async def generate_name_embedding(self, embedder: EmbedderClient):
@@ -277,7 +277,7 @@ async def save(self, driver: AsyncDriver):
             'created_at': self.created_at,
         }
 
-        entity_data.update(self.attributes)
+        entity_data.update(self.attributes or {})
 
         result = await driver.execute_query(
             ENTITY_NODE_SAVE,
diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py
index 95ee72ec..77531a3f 100644
--- a/graphiti_core/utils/maintenance/node_operations.py
+++ b/graphiti_core/utils/maintenance/node_operations.py
@@ -153,7 +153,7 @@ async def extract_nodes(
         'episode_content': episode.content,
         'previous_episodes': [ep.content for ep in previous_episodes],
         'extracted_entities': extracted_node_names,
-        'entity_types': entity_types.keys(),
+        'entity_types': entity_types.keys() if entity_types is not None else [],
     }
 
     node_classifications: dict[str, str | None] = {}
@@ -171,7 +171,7 @@ async def extract_nodes(
     # Convert the extracted data into EntityNode objects
     new_nodes = []
     for name in extracted_node_names:
-        entity_type = node_classifications.get(name, None)
+        entity_type = node_classifications.get(name)
         labels = ['Entity'] if entity_type is None else ['Entity', entity_type]
 
         new_node = EntityNode(
@@ -243,7 +243,7 @@ async def resolve_extracted_nodes(
     existing_nodes_lists: list[list[EntityNode]],
     episode: EpisodicNode | None = None,
     previous_episodes: list[EpisodicNode] | None = None,
-    entity_types: dict[str, str] | None = None,
+    entity_types: dict[str, BaseModel] | None = None,
 ) -> tuple[list[EntityNode], dict[str, str]]:
     uuid_map: dict[str, str] = {}
     resolved_nodes: list[EntityNode] = []

From 9550e00fdf680cbc5ebe62ec588d2569a2edc8c3 Mon Sep 17 00:00:00 2001
From: prestonrasmussen
Date: Tue, 11 Feb 2025 11:59:32 -0500
Subject: [PATCH 07/15] update types

---
 .../utils/maintenance/node_operations.py           | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py
index 77531a3f..34f092ea 100644
--- a/graphiti_core/utils/maintenance/node_operations.py
+++ b/graphiti_core/utils/maintenance/node_operations.py
@@ -308,12 +308,18 @@ async def resolve_extracted_node(
         'attributes': [],
     }
 
-    entity_type_classes = tuple(
-        filter(
-            lambda x: x is not None,
-            [entity_types.get(entity_type) for entity_type in extracted_node.labels],
+    entity_type_classes: tuple[BaseModel] = tuple()
+    if entity_types is not None:
+        entity_type_classes: tuple[BaseModel] = tuple(
+            filter(
+                lambda x: x is not None,
+                [
+                    entity_types.get(entity_type, BaseModel())
+                    for entity_type in extracted_node.labels
+                ],
+            )
         )
-    )
+
     for entity_type in entity_type_classes:
         for field_name in entity_type.__fields__.keys():
             summary_context['attributes'].append(field_name)
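The guard introduced above matters because a node's labels routinely include 'Entity' and other labels with no registered type, so the lookup has to tolerate misses. In isolation, and with illustrative names, the lookup-and-filter step behaves like this:

from pydantic import BaseModel


class Person(BaseModel):  # hypothetical registered entity type
    occupation: str | None = None


entity_types: dict[str, type[BaseModel]] = {'Person': Person}
labels = ['Entity', 'Person']  # 'Entity' has no registered type

# get() returns None for unregistered labels; filter() drops those Nones.
entity_type_classes = tuple(
    filter(lambda x: x is not None, [entity_types.get(label) for label in labels])
)
assert entity_type_classes == (Person,)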
From 8d2e7c604bf31c8cc625b3794a6ac871db92e1ef Mon Sep 17 00:00:00 2001
From: prestonrasmussen
Date: Wed, 12 Feb 2025 14:22:27 -0500
Subject: [PATCH 08/15] updates

---
 graphiti_core/utils/maintenance/node_operations.py | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py
index 34f092ea..3c9261d2 100644
--- a/graphiti_core/utils/maintenance/node_operations.py
+++ b/graphiti_core/utils/maintenance/node_operations.py
@@ -310,19 +310,16 @@ async def resolve_extracted_node(
 
     entity_type_classes: tuple[BaseModel] = tuple()
     if entity_types is not None:
-        entity_type_classes: tuple[BaseModel] = tuple(
+        entity_type_classes + tuple(
             filter(
                 lambda x: x is not None,
-                [
-                    entity_types.get(entity_type, BaseModel())
-                    for entity_type in extracted_node.labels
-                ],
+                [entity_types.get(entity_type, BaseModel) for entity_type in extracted_node.labels],
             )
         )
 
     for entity_type in entity_type_classes:
-        for field_name in entity_type.__fields__.keys():
-            summary_context['attributes'].append(field_name)
+        for field_name in entity_type.model_fields.keys():
+            summary_context.get('attributes', []).append(field_name)
 
     entity_attributes_model = pydantic.create_model(
         'EntityAttributes', __base__=entity_type_classes + (Summary,)

From 89965a224830ba3c00e5432c218850480ef80e84 Mon Sep 17 00:00:00 2001
From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com>
Date: Wed, 12 Feb 2025 14:28:00 -0500
Subject: [PATCH 09/15] Update graphiti_core/utils/maintenance/node_operations.py

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
---
 graphiti_core/utils/maintenance/node_operations.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py
index 3c9261d2..4c997614 100644
--- a/graphiti_core/utils/maintenance/node_operations.py
+++ b/graphiti_core/utils/maintenance/node_operations.py
@@ -310,7 +310,7 @@ async def resolve_extracted_node(
 
     entity_type_classes: tuple[BaseModel] = tuple()
     if entity_types is not None:
-        entity_type_classes + tuple(
+        entity_type_classes = entity_type_classes + tuple(
             filter(
                 lambda x: x is not None,
                 [entity_types.get(entity_type, BaseModel) for entity_type in extracted_node.labels],
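The one-token fix above is easy to miss but real: tuples are immutable, so a bare `entity_type_classes + tuple(...)` expression builds a new tuple and immediately discards it. A short illustration of why the rebinding is required:

classes: tuple[str, ...] = ('Entity',)
classes + ('Person',)            # computes ('Entity', 'Person') and throws it away
assert classes == ('Entity',)    # unchanged - the previous commit's bug
classes = classes + ('Person',)  # rebinding is what actually extends the tuple
assert classes == ('Entity', 'Person')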
From 46b279bb9b1dab604b9c6361e49bf5b2a3118de7 Mon Sep 17 00:00:00 2001
From: prestonrasmussen
Date: Wed, 12 Feb 2025 14:28:51 -0500
Subject: [PATCH 10/15] fix warning

---
 graphiti_core/utils/maintenance/node_operations.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py
index 4c997614..d277fe0b 100644
--- a/graphiti_core/utils/maintenance/node_operations.py
+++ b/graphiti_core/utils/maintenance/node_operations.py
@@ -308,7 +308,7 @@ async def resolve_extracted_node(
         'attributes': [],
     }
 
-    entity_type_classes: tuple[BaseModel] = tuple()
+    entity_type_classes: tuple[BaseModel, ...] = tuple()
     if entity_types is not None:
         entity_type_classes = entity_type_classes + tuple(
             filter(

From e52b3c4cde7cd71acbe6926d7514b54ee06c49a7 Mon Sep 17 00:00:00 2001
From: prestonrasmussen
Date: Thu, 13 Feb 2025 11:37:26 -0500
Subject: [PATCH 11/15] mypy updates

---
 graphiti_core/utils/maintenance/node_operations.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py
index d277fe0b..e5432fa0 100644
--- a/graphiti_core/utils/maintenance/node_operations.py
+++ b/graphiti_core/utils/maintenance/node_operations.py
@@ -318,10 +318,11 @@ async def resolve_extracted_node(
 
     for entity_type in entity_type_classes:
-        for field_name in entity_type.model_fields.keys():
+        for field_name in entity_type.model_fields:
             summary_context.get('attributes', []).append(field_name)
 
-    entity_attributes_model = pydantic.create_model(
+    # type: ignore[arg-type]
+    entity_attributes_model: BaseModel = pydantic.create_model(
         'EntityAttributes', __base__=entity_type_classes + (Summary,)
     )
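For readers tracking the pydantic churn in these commits: v2 renamed __fields__ to model_fields, and since model_fields is a plain dict keyed by field name, iterating it directly yields the names, making .keys() redundant. A small sketch with an illustrative model:

from pydantic import BaseModel


class Person(BaseModel):  # illustrative entity type
    occupation: str | None = None
    birth_year: int | None = None


attributes: list[str] = []
for field_name in Person.model_fields:  # same names as Person.model_fields.keys()
    attributes.append(field_name)

assert attributes == ['occupation', 'birth_year']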
From 802ef255d6e1a01f477f1404045d9c9de7f537f8 Mon Sep 17 00:00:00 2001
From: prestonrasmussen
Date: Thu, 13 Feb 2025 12:02:48 -0500
Subject: [PATCH 12/15] update properties

---
 examples/podcast/podcast_runner.py                 |  2 +-
 graphiti_core/nodes.py                             |  1 +
 graphiti_core/search/search_utils.py               | 12 ++++++------
 graphiti_core/utils/bulk_utils.py                  | 18 +++++++++++++---
 .../utils/maintenance/node_operations.py           |  7 ++++---
 5 files changed, 27 insertions(+), 13 deletions(-)

diff --git a/examples/podcast/podcast_runner.py b/examples/podcast/podcast_runner.py
index d5511375..b46672d6 100644
--- a/examples/podcast/podcast_runner.py
+++ b/examples/podcast/podcast_runner.py
@@ -63,7 +63,7 @@ class Person(BaseModel):
 async def main():
     setup_logging()
     client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
-    await clear_data(client.driver)
+    # await clear_data(client.driver)
     await client.build_indices_and_constraints()
 
     messages = parse_podcast_messages()
diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py
index bd1da026..508a0d4a 100644
--- a/graphiti_core/nodes.py
+++ b/graphiti_core/nodes.py
@@ -272,6 +272,7 @@ async def save(self, driver: AsyncDriver):
         entity_data: dict[str, Any] = {
             'uuid': self.uuid,
             'name': self.name,
+            'name_embedding': self.name_embedding,
             'group_id': self.group_id,
             'summary': self.summary,
             'created_at': self.created_at,
diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py
index 2710bf07..ef9b6eb2 100644
--- a/graphiti_core/search/search_utils.py
+++ b/graphiti_core/search/search_utils.py
@@ -99,7 +99,7 @@ async def get_mentioned_nodes(
             n.created_at AS created_at,
             n.summary AS summary,
             labels(n) AS labels,
-            properties(n) AS properties
+            properties(n) AS attributes
         """,
         uuids=episode_uuids,
         database_=DEFAULT_DATABASE,
@@ -225,8 +225,8 @@ async def edge_similarity_search(
 
     query: LiteralString = (
         """
-    MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
-    """
+        MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
+        """
         + group_filter_query
         + filter_query
         + """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
@@ -345,7 +345,7 @@ async def node_fulltext_search(
             n.name_embedding AS name_embedding,
             n.created_at AS created_at,
             n.summary AS summary,
             labels(n) AS labels,
-            properties(n) AS properties
+            properties(n) AS attributes
         ORDER BY score DESC
         LIMIT $limit
         """,
@@ -396,7 +396,7 @@ async def node_similarity_search(
             n.name_embedding AS name_embedding,
             n.created_at AS created_at,
             n.summary AS summary,
             labels(n) AS labels,
-            properties(n) AS properties
+            properties(n) AS attributes
         ORDER BY score DESC
         LIMIT $limit
         """,
@@ -435,7 +435,7 @@ async def node_bfs_search(
             n.created_at AS created_at,
             n.summary AS summary,
             labels(n) AS labels,
-            properties(n) AS properties
+            properties(n) AS attributes
             LIMIT $limit
         """,
         bfs_origin_node_uuids=bfs_origin_node_uuids,
diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py
index 3877b0e1..b2340d63 100644
--- a/graphiti_core/utils/bulk_utils.py
+++ b/graphiti_core/utils/bulk_utils.py
@@ -23,6 +23,7 @@
 from neo4j import AsyncDriver, AsyncManagedTransaction
 from numpy import dot, sqrt
 from pydantic import BaseModel
+from typing_extensions import Any
 
 from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
 from graphiti_core.helpers import semaphore_gather
@@ -109,9 +110,20 @@ async def add_nodes_and_edges_bulk_tx(
     episodes = [dict(episode) for episode in episodic_nodes]
     for episode in episodes:
         episode['source'] = str(episode['source'].value)
-    nodes = [dict(entity) for entity in entity_nodes]
-    for node in nodes:
-        node['labels'] = list(set(node['labels'] + ['Entity']))
+    nodes: list[dict[str, Any]] = []
+    for node in entity_nodes:
+        entity_data: dict[str, Any] = {
+            'uuid': node.uuid,
+            'name': node.name,
+            'name_embedding': node.name_embedding,
+            'group_id': node.group_id,
+            'summary': node.summary,
+            'created_at': node.created_at,
+        }
+
+        entity_data.update(node.attributes or {})
+        entity_data['labels'] = list(set(node.labels + ['Entity']))
+        nodes.append(entity_data)
 
     await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
     await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py
index e5432fa0..8f71a781 100644
--- a/graphiti_core/utils/maintenance/node_operations.py
+++ b/graphiti_core/utils/maintenance/node_operations.py
@@ -309,11 +309,12 @@ async def resolve_extracted_node(
     }
 
     entity_type_classes: tuple[BaseModel, ...] = tuple()
+    # type: ignore
    if entity_types is not None:
         entity_type_classes = entity_type_classes + tuple(
             filter(
                 lambda x: x is not None,
-                [entity_types.get(entity_type, BaseModel) for entity_type in extracted_node.labels],
+                [entity_types.get(entity_type) for entity_type in extracted_node.labels],
             )
         )
 
@@ -321,8 +322,8 @@ async def resolve_extracted_node(
     for entity_type in entity_type_classes:
         for field_name in entity_type.model_fields:
             summary_context.get('attributes', []).append(field_name)
 
-    # type: ignore[arg-type]
-    entity_attributes_model: BaseModel = pydantic.create_model(
+    # type: ignore
+    entity_attributes_model = pydantic.create_model(
         'EntityAttributes', __base__=entity_type_classes + (Summary,)
     )
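The bulk-save rework above mirrors EntityNode.save(): instead of dict(entity), each node's fixed fields and its free-form attributes dict are flattened into a single property map before the Cypher bulk query runs. A self-contained sketch of that flattening, using a simplified stand-in for EntityNode:

from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any


@dataclass
class FakeEntityNode:  # simplified stand-in for graphiti_core.nodes.EntityNode
    uuid: str
    name: str
    name_embedding: list[float] | None
    group_id: str
    summary: str
    created_at: datetime
    labels: list[str] = field(default_factory=list)
    attributes: dict[str, Any] = field(default_factory=dict)


def to_entity_data(node: FakeEntityNode) -> dict[str, Any]:
    entity_data: dict[str, Any] = {
        'uuid': node.uuid,
        'name': node.name,
        'name_embedding': node.name_embedding,
        'group_id': node.group_id,
        'summary': node.summary,
        'created_at': node.created_at,
    }
    entity_data.update(node.attributes or {})  # typed attributes become flat properties
    entity_data['labels'] = list(set(node.labels + ['Entity']))  # 'Entity' always present
    return entity_data


node = FakeEntityNode(
    uuid='n1', name='Jane', name_embedding=None, group_id='', summary='',
    created_at=datetime.now(timezone.utc), labels=['Person'],
    attributes={'occupation': 'engineer'},
)
assert to_entity_data(node)['occupation'] == 'engineer'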
From 37178e67f16d957f205c067443b60e33fc3c6d53 Mon Sep 17 00:00:00 2001
From: prestonrasmussen
Date: Thu, 13 Feb 2025 12:09:28 -0500
Subject: [PATCH 13/15] mypy ignore

---
 examples/podcast/podcast_runner.py                 |  2 +-
 graphiti_core/utils/maintenance/node_operations.py | 13 ++++++-------
 2 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/examples/podcast/podcast_runner.py b/examples/podcast/podcast_runner.py
index b46672d6..d5511375 100644
--- a/examples/podcast/podcast_runner.py
+++ b/examples/podcast/podcast_runner.py
@@ -63,7 +63,7 @@ class Person(BaseModel):
 async def main():
     setup_logging()
     client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
-    # await clear_data(client.driver)
+    await clear_data(client.driver)
     await client.build_indices_and_constraints()
 
     messages = parse_podcast_messages()
diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py
index 8f71a781..4853cd84 100644
--- a/graphiti_core/utils/maintenance/node_operations.py
+++ b/graphiti_core/utils/maintenance/node_operations.py
@@ -309,22 +309,21 @@ async def resolve_extracted_node(
     }
 
     entity_type_classes: tuple[BaseModel, ...] = tuple()
-    # type: ignore
-    if entity_types is not None:
+    if entity_types is not None:  # type: ignore
         entity_type_classes = entity_type_classes + tuple(
             filter(
                 lambda x: x is not None,
-                [entity_types.get(entity_type) for entity_type in extracted_node.labels],
+                [entity_types.get(entity_type) for entity_type in extracted_node.labels],  # type: ignore
             )
         )
 
     for entity_type in entity_type_classes:
         for field_name in entity_type.model_fields:
-            summary_context.get('attributes', []).append(field_name)
+            summary_context.get('attributes', []).append(field_name)  # type: ignore
 
-    # type: ignore
-    entity_attributes_model = pydantic.create_model(
-        'EntityAttributes', __base__=entity_type_classes + (Summary,)
+    entity_attributes_model = pydantic.create_model(  # type: ignore
+        'EntityAttributes',
+        __base__=entity_type_classes + (Summary,),  # type: ignore
     )

From ae0e2488f20f0f0636086d14a36ebd2a04902fe7 Mon Sep 17 00:00:00 2001
From: prestonrasmussen
Date: Thu, 13 Feb 2025 12:13:32 -0500
Subject: [PATCH 14/15] mypy types

---
 graphiti_core/utils/maintenance/node_operations.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py
index 4853cd84..3d4f6bf3 100644
--- a/graphiti_core/utils/maintenance/node_operations.py
+++ b/graphiti_core/utils/maintenance/node_operations.py
@@ -312,7 +312,7 @@ async def resolve_extracted_node(
         entity_type_classes = entity_type_classes + tuple(
             filter(
-                lambda x: x is not None,
+                lambda x: x is not None,  # type: ignore
                 [entity_types.get(entity_type) for entity_type in extracted_node.labels],  # type: ignore
             )
         )

From 58168d4908a4a636bb80368a571dff61c7175915 Mon Sep 17 00:00:00 2001
From: prestonrasmussen
Date: Thu, 13 Feb 2025 12:14:08 -0500
Subject: [PATCH 15/15] bump version

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index aa0a52f1..790a3a72 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "graphiti-core"
-version = "0.6.1"
+version = "0.7.0"
 description = "A temporal graph building library"
 authors = [
     "Paul Paliychuk ",