Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Fix Typing Issues #27

Merged
merged 10 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ format:
# Lint code
lint:
$(RUFF) check
$(MYPY) . --show-column-numbers --show-error-codes --pretty
$(MYPY) ./core --show-column-numbers --show-error-codes --pretty

# Run tests
test:
Expand Down
44 changes: 20 additions & 24 deletions core/graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(self, uri: str, user: str, password: str, llm_client: LLMClient | N
else:
self.llm_client = OpenAIClient(
LLMConfig(
api_key=os.getenv('OPENAI_API_KEY'),
api_key=os.getenv('OPENAI_API_KEY') or '',
model='gpt-4o-mini',
base_url='https://api.openai.com/v1',
)
Expand All @@ -72,28 +72,16 @@ async def retrieve_episodes(
self,
reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN,
sources: list[str] | None = 'messages',
) -> list[EpisodicNode]:
"""Retrieve the last n episodic nodes from the graph"""
return await retrieve_episodes(self.driver, reference_time, last_n, sources)

# Invalidate edges that are no longer valid
async def invalidate_edges(
self,
episode: EpisodicNode,
new_nodes: list[EntityNode],
new_edges: list[EntityEdge],
relevant_schema: dict[str, any],
previous_episodes: list[EpisodicNode],
): ...
return await retrieve_episodes(self.driver, reference_time, last_n)

async def add_episode(
self,
name: str,
episode_body: str,
source_description: str,
reference_time: datetime | None = None,
episode_type: str | None = 'string', # TODO: this field isn't used yet?
reference_time: datetime,
success_callback: Callable | None = None,
error_callback: Callable | None = None,
):
Expand All @@ -104,7 +92,7 @@ async def add_episode(
nodes: list[EntityNode] = []
entity_edges: list[EntityEdge] = []
episodic_edges: list[EpisodicEdge] = []
embedder = self.llm_client.client.embeddings
embedder = self.llm_client.get_embedder()
now = datetime.now()

previous_episodes = await self.retrieve_episodes(reference_time)
Expand Down Expand Up @@ -234,7 +222,7 @@ async def add_episode_bulk(
):
try:
start = time()
embedder = self.llm_client.client.embeddings
embedder = self.llm_client.get_embedder()
now = datetime.now()

episodes = [
Expand Down Expand Up @@ -276,14 +264,22 @@ async def add_episode_bulk(
await asyncio.gather(*[node.save(self.driver) for node in nodes])

# re-map edge pointers so that they don't point to discard dupe nodes
extracted_edges: list[EntityEdge] = resolve_edge_pointers(extracted_edges, uuid_map)
episodic_edges: list[EpisodicEdge] = resolve_edge_pointers(episodic_edges, uuid_map)
extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
extracted_edges, uuid_map
)
episodic_edges_with_resolved_pointers: list[EpisodicEdge] = resolve_edge_pointers(
episodic_edges, uuid_map
)

# save episodic edges to KG
await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
await asyncio.gather(
*[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers]
)

# Dedupe extracted edges
edges = await dedupe_edges_bulk(self.driver, self.llm_client, extracted_edges)
edges = await dedupe_edges_bulk(
self.driver, self.llm_client, extracted_edges_with_resolved_pointers
)
logger.info(f'extracted edge length: {len(edges)}')

# invalidate edges
Expand All @@ -302,18 +298,18 @@ async def search(self, query: str, num_results=10):
edges = (
await hybrid_search(
self.driver,
self.llm_client.client.embeddings,
self.llm_client.get_embedder(),
query,
datetime.now(),
search_config,
)
)['edges']

facts = [edge.fact for edge in edges]
facts = [edge.fact for edge in edges if isinstance(edge, EntityEdge)]

return facts

async def _search(self, query: str, timestamp: datetime, config: SearchConfig):
return await hybrid_search(
self.driver, self.llm_client.client.embeddings, query, timestamp, config
self.driver, self.llm_client.get_embedder(), query, timestamp, config
)
8 changes: 7 additions & 1 deletion core/llm_client/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import typing
from abc import ABC, abstractmethod

from ..prompts.models import Message
from .config import LLMConfig


Expand All @@ -9,5 +11,9 @@ def __init__(self, config: LLMConfig):
pass

@abstractmethod
async def generate_response(self, messages: list[dict[str, str]]) -> dict[str, any]:
def get_embedder(self) -> typing.Any:
pass

@abstractmethod
async def generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
pass
19 changes: 16 additions & 3 deletions core/llm_client/openai_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import json
import logging
import typing

from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam

from ..prompts.models import Message
from .client import LLMClient
from .config import LLMConfig

Expand All @@ -14,16 +17,26 @@ def __init__(self, config: LLMConfig):
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
self.model = config.model

async def generate_response(self, messages: list[dict[str, str]]) -> dict[str, any]:
def get_embedder(self) -> typing.Any:
return self.client.embeddings

async def generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
openai_messages: list[ChatCompletionMessageParam] = []
for m in messages:
if m.role == 'user':
openai_messages.append({'role': 'user', 'content': m.content})
elif m.role == 'system':
openai_messages.append({'role': 'system', 'content': m.content})
try:
response = await self.client.chat.completions.create(
model=self.model,
messages=messages,
messages=openai_messages,
temperature=0.1,
max_tokens=3000,
response_format={'type': 'json_object'},
)
return json.loads(response.choices[0].message.content)
result = response.choices[0].message.content or ''
return json.loads(result)
except Exception as e:
logger.error(f'Error in generating LLM response: {e}')
raise
3 changes: 2 additions & 1 deletion core/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ class EpisodicNode(Node):
source: str = Field(description='source type')
source_description: str = Field(description='description of the data source')
content: str = Field(description='raw episode data')
valid_at: datetime = Field(
valid_at: datetime | None = Field(
description='datetime of when the original document was created',
default=None,
)
entity_edges: list[str] = Field(
description='list of entity edges referenced in this episode',
Expand Down
15 changes: 8 additions & 7 deletions core/prompts/dedupe_edges.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import json
from typing import Protocol, TypedDict
from typing import Any, Protocol, TypedDict

from .models import Message, PromptFunction, PromptVersion


class Prompt(Protocol):
v1: PromptVersion
v2: PromptVersion
edge_list: PromptVersion


class Versions(TypedDict):
Expand All @@ -15,7 +16,7 @@ class Versions(TypedDict):
edge_list: PromptFunction


def v1(context: dict[str, any]) -> list[Message]:
def v1(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
Expand All @@ -34,7 +35,7 @@ def v1(context: dict[str, any]) -> list[Message]:

Task:
1. start with the list of edges from New Edges
2. If any edge in New Edges is a duplicate of an edge in Existing Edges, replace the new edge with the existing
2. If any edge in New Edges is a duplicate of an edge in Existing Edges, replace the new edge with the existing
edge in the list
3. Respond with the resulting list of edges

Expand All @@ -55,7 +56,7 @@ def v1(context: dict[str, any]) -> list[Message]:
]


def v2(context: dict[str, any]) -> list[Message]:
def v2(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
Expand All @@ -74,7 +75,7 @@ def v2(context: dict[str, any]) -> list[Message]:

Task:
1. start with the list of edges from New Edges
2. If any edge in New Edges is a duplicate of an edge in Existing Edges, replace the new edge with the existing
2. If any edge in New Edges is a duplicate of an edge in Existing Edges, replace the new edge with the existing
edge in the list
3. Respond with the resulting list of edges

Expand All @@ -97,7 +98,7 @@ def v2(context: dict[str, any]) -> list[Message]:
]


def edge_list(context: dict[str, any]) -> list[Message]:
def edge_list(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
Expand All @@ -112,7 +113,7 @@ def edge_list(context: dict[str, any]) -> list[Message]:
{json.dumps(context['edges'], indent=2)}

Task:
If any edge in Edges is a duplicate of another edge, return the fact of only one of the duplicate edges
If any edge in Edges is a duplicate of another edge, return the fact of only one of the duplicate edges

Guidelines:
1. Use both the name and fact of edges to determine if they are duplicates,
Expand Down
12 changes: 6 additions & 6 deletions core/prompts/dedupe_nodes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Protocol, TypedDict
from typing import Any, Protocol, TypedDict

from .models import Message, PromptFunction, PromptVersion

Expand All @@ -16,7 +16,7 @@ class Versions(TypedDict):
node_list: PromptVersion


def v1(context: dict[str, any]) -> list[Message]:
def v1(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
Expand All @@ -35,7 +35,7 @@ def v1(context: dict[str, any]) -> list[Message]:

Task:
1. start with the list of nodes from New Nodes
2. If any node in New Nodes is a duplicate of a node in Existing Nodes, replace the new node with the existing
2. If any node in New Nodes is a duplicate of a node in Existing Nodes, replace the new node with the existing
node in the list
3. Respond with the resulting list of nodes

Expand All @@ -56,7 +56,7 @@ def v1(context: dict[str, any]) -> list[Message]:
]


def v2(context: dict[str, any]) -> list[Message]:
def v2(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
Expand All @@ -74,7 +74,7 @@ def v2(context: dict[str, any]) -> list[Message]:
{json.dumps(context['extracted_nodes'], indent=2)}

Task:
If any node in New Nodes is a duplicate of a node in Existing Nodes, add their names to the output list
If any node in New Nodes is a duplicate of a node in Existing Nodes, add their names to the output list

Guidelines:
1. Use both the name and summary of nodes to determine if they are duplicates,
Expand All @@ -96,7 +96,7 @@ def v2(context: dict[str, any]) -> list[Message]:
]


def node_list(context: dict[str, any]) -> list[Message]:
def node_list(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
Expand Down
6 changes: 3 additions & 3 deletions core/prompts/extract_edges.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Protocol, TypedDict
from typing import Any, Protocol, TypedDict

from .models import Message, PromptFunction, PromptVersion

Expand All @@ -14,7 +14,7 @@ class Versions(TypedDict):
v2: PromptFunction


def v1(context: dict[str, any]) -> list[Message]:
def v1(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
Expand Down Expand Up @@ -70,7 +70,7 @@ def v1(context: dict[str, any]) -> list[Message]:
]


def v2(context: dict[str, any]) -> list[Message]:
def v2(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
Expand Down
8 changes: 4 additions & 4 deletions core/prompts/extract_nodes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Protocol, TypedDict
from typing import Any, Protocol, TypedDict

from .models import Message, PromptFunction, PromptVersion

Expand All @@ -16,7 +16,7 @@ class Versions(TypedDict):
v3: PromptFunction


def v1(context: dict[str, any]) -> list[Message]:
def v1(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
Expand Down Expand Up @@ -64,7 +64,7 @@ def v1(context: dict[str, any]) -> list[Message]:
]


def v2(context: dict[str, any]) -> list[Message]:
def v2(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
Expand Down Expand Up @@ -105,7 +105,7 @@ def v2(context: dict[str, any]) -> list[Message]:
]


def v3(context: dict[str, any]) -> list[Message]:
def v3(context: dict[str, Any]) -> list[Message]:
sys_prompt = """You are an AI assistant that extracts entity nodes from conversational text. Your primary task is to identify and extract the speaker and other significant entities mentioned in the conversation."""

user_prompt = f"""
Expand Down
6 changes: 3 additions & 3 deletions core/prompts/invalidate_edges.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Protocol, TypedDict
from typing import Any, Protocol, TypedDict

from .models import Message, PromptFunction, PromptVersion

Expand All @@ -11,7 +11,7 @@ class Versions(TypedDict):
v1: PromptFunction


def v1(context: dict[str, any]) -> list[Message]:
def v1(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
Expand All @@ -20,7 +20,7 @@ def v1(context: dict[str, any]) -> list[Message]:
Message(
role='user',
content=f"""
Based on the provided existing edges and new edges with their timestamps, determine which existing relationships, if any, should be invalidated due to contradictions or updates in the new edges.
Based on the provided existing edges and new edges with their timestamps, determine which existing relationships, if any, should be invalidated due to contradictions or updates in the new edges.
Only mark a relationship as invalid if there is clear evidence from new edges that the relationship is no longer true.
Do not invalidate relationships merely because they weren't mentioned in new edges. You may use the current episode and previous episodes as well as the facts of each edge to understand the context of the relationships.

Expand Down
Loading