Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom ontology #262

Merged
merged 15 commits into from
Feb 13, 2025
142 changes: 0 additions & 142 deletions examples/multi_session_conversation_memory/msc_eval.py

This file was deleted.

89 changes: 0 additions & 89 deletions examples/multi_session_conversation_memory/msc_runner.py

This file was deleted.

85 changes: 0 additions & 85 deletions examples/multi_session_conversation_memory/parse_msc_messages.py

This file was deleted.

8 changes: 8 additions & 0 deletions examples/podcast/podcast_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import sys

from dotenv import load_dotenv
from pydantic import BaseModel, Field
from transcript_parser import parse_podcast_messages

from graphiti_core import Graphiti
Expand Down Expand Up @@ -53,6 +54,12 @@ def setup_logging():
return logger


class Person(BaseModel):
first_name: str | None = Field(..., description='First name')
last_name: str | None = Field(..., description='Last name')
occupation: str | None = Field(..., description="The person's work occupation")


async def main():
setup_logging()
client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
Expand All @@ -67,6 +74,7 @@ async def main():
reference_time=message.actual_timestamp,
source_description='Podcast Transcript',
group_id='podcast',
entity_types={'Person': Person},
)


Expand Down
6 changes: 5 additions & 1 deletion graphiti_core/graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ async def add_episode(
group_id: str = '',
uuid: str | None = None,
update_communities: bool = False,
entity_types: dict[str, BaseModel] | None = None,
) -> AddEpisodeResults:
"""
Process an episode and update the graph.
Expand Down Expand Up @@ -336,7 +337,9 @@ async def add_episode_endpoint(episode_data: EpisodeData):

# Extract entities as nodes

extracted_nodes = await extract_nodes(self.llm_client, episode, previous_episodes)
extracted_nodes = await extract_nodes(
self.llm_client, episode, previous_episodes, entity_types
)
logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')

# Calculate Embeddings
Expand All @@ -362,6 +365,7 @@ async def add_episode_endpoint(episode_data: EpisodeData):
existing_nodes_lists,
episode,
previous_episodes,
entity_types,
),
extract_edges(
self.llm_client, episode, extracted_nodes, previous_episodes, group_id
Expand Down
6 changes: 4 additions & 2 deletions graphiti_core/models/nodes/node_db_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,16 @@

ENTITY_NODE_SAVE = """
MERGE (n:Entity {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
SET n:$($labels)
SET n = $entity_data
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
RETURN n.uuid AS uuid"""

ENTITY_NODE_SAVE_BULK = """
UNWIND $nodes AS node
MERGE (n:Entity {uuid: node.uuid})
SET n = {uuid: node.uuid, name: node.name, group_id: node.group_id, summary: node.summary, created_at: node.created_at}
SET n:$(node.labels)
SET n = node
WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)
RETURN n.uuid AS uuid
"""
Expand Down
Loading