Skip to content

Commit

Permalink
fix: Address graph disconnect (#7)
Browse files Browse the repository at this point in the history
* fix: Address graph disconnect

* chore: Remove valid_to and valid_from setting in extract edges step (will be handled during invalidation step)
  • Loading branch information
paul-paliychuk authored Aug 19, 2024
1 parent 4db3906 commit 40e74a2
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 25 deletions.
14 changes: 12 additions & 2 deletions core/graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,16 @@ async def add_episode(
await asyncio.gather(
*[node.generate_name_embedding(embedder) for node in extracted_nodes]
)

existing_nodes = await get_relevant_nodes(extracted_nodes, self.driver)

logger.info(
f"Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}"
)
new_nodes = await dedupe_extracted_nodes(
self.llm_client, extracted_nodes, existing_nodes
)
logger.info(
f"Deduped touched nodes: {[(n.name, n.uuid) for n in new_nodes]}"
)
nodes.extend(new_nodes)

extracted_edges = await extract_edges(
Expand All @@ -130,11 +134,17 @@ async def add_episode(
)

existing_edges = await get_relevant_edges(extracted_edges, self.driver)
logger.info(f"Existing edges: {[(e.name, e.uuid) for e in existing_edges]}")
logger.info(
f"Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}"
)

new_edges = await dedupe_extracted_edges(
self.llm_client, extracted_edges, existing_edges
)

logger.info(f"Deduped edges: {[(e.name, e.uuid) for e in new_edges]}")

entity_edges.extend(new_edges)
episodic_edges.extend(
build_episodic_edges(
Expand Down
13 changes: 7 additions & 6 deletions core/prompts/extract_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,17 +135,18 @@ def v2(context: dict[str, any]) -> list[Message]:
Message(
role="user",
content=f"""
Given the following context, extract new edges (relationships) that need to be added to the knowledge graph:
Given the following context, extract edges (relationships) that need to be added to the knowledge graph:
Nodes:
{json.dumps(context['nodes'], indent=2)}
New Episode:
Content: {context['episode_content']}
Previous Episodes:
Episodes:
{json.dumps([ep['content'] for ep in context['previous_episodes']], indent=2)}
{context['episode_content']} <-- New Episode
Extract new entity edges based on the content of the current episode, the given nodes, and context from previous episodes.
Extract entity edges based on the content of the current episode, the given nodes, and context from previous episodes.
Guidelines:
1. Create edges only between the provided nodes.
Expand All @@ -168,7 +169,7 @@ def v2(context: dict[str, any]) -> list[Message]:
]
}}
If no new edges need to be added, return an empty list for "new_edges".
If no edges need to be added, return an empty list for "edges".
""",
),
]
Expand Down
37 changes: 36 additions & 1 deletion core/prompts/extract_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
class Prompt(Protocol):
    """Structural type declaring one attribute per prompt version (v1, v2, v3)."""

    v1: PromptVersion
    v2: PromptVersion
    v3: PromptVersion


class Versions(TypedDict):
    """Typed dictionary shape holding one prompt function per version key (v1, v2, v3)."""

    v1: PromptFunction
    v2: PromptFunction
    v3: PromptFunction


def v1(context: dict[str, any]) -> list[Message]:
Expand Down Expand Up @@ -103,4 +105,37 @@ def v2(context: dict[str, any]) -> list[Message]:
]


versions: Versions = {"v1": v1, "v2": v2}
def v3(context: dict[str, any]) -> list[Message]:
    """Build the v3 node-extraction prompt for conversational text.

    Reads ``context['previous_episodes']`` (the ``'content'`` of each item)
    and ``context['episode_content']``, and returns a two-element list: a
    system message followed by a user message that asks for a JSON object
    containing a ``new_nodes`` list.
    """
    # Resolve the interpolated values up front, in the same order the
    # template consumes them: earlier episodes first, then the current one.
    earlier_contents = [ep['content'] for ep in context['previous_episodes']]
    conversation_history = json.dumps(earlier_contents, indent=2)
    current_episode = context["episode_content"]

    system_text = """You are an AI assistant that extracts entity nodes from conversational text. Your primary task is to identify and extract the speaker and other significant entities mentioned in the conversation."""

    # Doubled braces below are literal braces in the rendered prompt.
    user_text = f"""
Given the following conversation, extract entity nodes that are explicitly or implicitly mentioned:
Conversation:
{conversation_history}
{current_episode}
Guidelines:
1. ALWAYS extract the speaker/actor as the first node. The speaker is the part before the colon in each line of dialogue.
2. Extract other significant entities, concepts, or actors mentioned in the conversation.
3. Provide concise but informative summaries for each extracted node.
4. Avoid creating nodes for relationships or actions.
Respond with a JSON object in the following format:
{{
"new_nodes": [
{{
"name": "Unique identifier for the node (use the speaker's name for speaker nodes)",
"labels": ["Entity", "Speaker" for speaker nodes, "OptionalAdditionalLabel"],
"summary": "Brief summary of the node's role or significance"
}}
]
}}
"""

    system_message = Message(role="system", content=system_text)
    user_message = Message(role="user", content=user_text)
    return [system_message, user_message]


# Registry mapping each version key to its prompt-building function in this module.
versions: Versions = {"v1": v1, "v2": v2, "v3": v3}
29 changes: 15 additions & 14 deletions core/utils/maintenance/edge_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,20 +170,21 @@ async def extract_edges(
# Convert the extracted data into EntityEdge objects
edges = []
for edge_data in edges_data:
edge = EntityEdge(
source_node_uuid=edge_data["source_node_uuid"],
target_node_uuid=edge_data["target_node_uuid"],
name=edge_data["relation_type"],
fact=edge_data["fact"],
episodes=[episode.uuid],
created_at=datetime.now(),
valid_at=edge_data["valid_at"],
invalid_at=edge_data["invalid_at"],
)
edges.append(edge)
logger.info(
f"Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})"
)
if edge_data["target_node_uuid"] and edge_data["source_node_uuid"]:
edge = EntityEdge(
source_node_uuid=edge_data["source_node_uuid"],
target_node_uuid=edge_data["target_node_uuid"],
name=edge_data["relation_type"],
fact=edge_data["fact"],
episodes=[episode.uuid],
created_at=datetime.now(),
valid_at=None,
invalid_at=None,
)
edges.append(edge)
logger.info(
f"Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})"
)

return edges

Expand Down
2 changes: 1 addition & 1 deletion core/utils/maintenance/node_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ async def extract_nodes(
}

llm_response = await llm_client.generate_response(
prompt_library.extract_nodes.v2(context)
prompt_library.extract_nodes.v3(context)
)
new_nodes_data = llm_response.get("new_nodes", [])
logger.info(f"Extracted new nodes: {new_nodes_data}")
Expand Down
2 changes: 1 addition & 1 deletion runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ async def main():
)
await client.add_episode(
name="Message 2",
episode_body="Paul: I love bananas",
episode_body="Paul: I own many bananas",
source_description="WhatsApp Message",
)
await client.add_episode(
Expand Down

0 comments on commit 40e74a2

Please sign in to comment.