chore: Fix Typing Issues #27

Merged · 10 commits · Aug 23, 2024 · changes shown from 9 commits
Makefile (2 changes: 1 addition & 1 deletion)

```diff
@@ -22,7 +22,7 @@ format:
 # Lint code
 lint:
 	$(RUFF) check
-	$(MYPY) . --show-column-numbers --show-error-codes --pretty
+	$(MYPY) ./core --show-column-numbers --show-error-codes --pretty

 # Run tests
 test:
```
core/graphiti.py (44 changes: 20 additions & 24 deletions)

```diff
@@ -56,7 +56,7 @@ def __init__(self, uri: str, user: str, password: str, llm_client: LLMClient | N
         else:
             self.llm_client = OpenAIClient(
                 LLMConfig(
-                    api_key=os.getenv('OPENAI_API_KEY'),
+                    api_key=os.getenv('OPENAI_API_KEY', default=''),
                     model='gpt-4o-mini',
                     base_url='https://api.openai.com/v1',
                 )
```
```diff
@@ -72,28 +72,16 @@ async def retrieve_episodes(
         self,
         reference_time: datetime,
         last_n: int = EPISODE_WINDOW_LEN,
-        sources: list[str] | None = 'messages',
     ) -> list[EpisodicNode]:
         """Retrieve the last n episodic nodes from the graph"""
-        return await retrieve_episodes(self.driver, reference_time, last_n, sources)
-
-    # Invalidate edges that are no longer valid
-    async def invalidate_edges(
-        self,
-        episode: EpisodicNode,
-        new_nodes: list[EntityNode],
-        new_edges: list[EntityEdge],
-        relevant_schema: dict[str, any],
-        previous_episodes: list[EpisodicNode],
-    ): ...
+        return await retrieve_episodes(self.driver, reference_time, last_n)

     async def add_episode(
         self,
         name: str,
         episode_body: str,
         source_description: str,
-        reference_time: datetime | None = None,
-        episode_type: str | None = 'string',  # TODO: this field isn't used yet?
+        reference_time: datetime,
         success_callback: Callable | None = None,
         error_callback: Callable | None = None,
     ):
```
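The deleted `sources` parameter was itself one of the typing errors this PR targets: its default value `'messages'` is a `str`, which is not a member of the annotated type `list[str] | None`. Rather than correct the default, the PR drops the unused parameter, along with the unimplemented `invalidate_edges` stub, whose `relevant_schema: dict[str, any]` had the same lowercase-`any` problem fixed elsewhere. A minimal repro of the default-value error (not project code):

```python
def retrieve(sources: list[str] | None = 'messages') -> None:  # default has the wrong type
    ...

# mypy: Incompatible default for argument "sources" (default has type "str",
#       argument has type "list[str] | None")  [assignment]
```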
```diff
@@ -104,7 +92,7 @@ async def add_episode(
         nodes: list[EntityNode] = []
         entity_edges: list[EntityEdge] = []
         episodic_edges: list[EpisodicEdge] = []
-        embedder = self.llm_client.client.embeddings
+        embedder = self.llm_client.get_embedder()
         now = datetime.now()

         previous_episodes = await self.retrieve_episodes(reference_time)
@@ -234,7 +222,7 @@ async def add_episode_bulk(
     ):
         try:
             start = time()
-            embedder = self.llm_client.client.embeddings
+            embedder = self.llm_client.get_embedder()
             now = datetime.now()

             episodes = [
@@ -276,14 +264,22 @@ async def add_episode_bulk(
             await asyncio.gather(*[node.save(self.driver) for node in nodes])

             # re-map edge pointers so that they don't point to discard dupe nodes
-            extracted_edges: list[EntityEdge] = resolve_edge_pointers(extracted_edges, uuid_map)
-            episodic_edges: list[EpisodicEdge] = resolve_edge_pointers(episodic_edges, uuid_map)
+            extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
+                extracted_edges, uuid_map
+            )
+            episodic_edges_with_resolved_pointers: list[EpisodicEdge] = resolve_edge_pointers(
+                episodic_edges, uuid_map
+            )

             # save episodic edges to KG
-            await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
+            await asyncio.gather(
+                *[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers]
+            )

             # Dedupe extracted edges
-            edges = await dedupe_edges_bulk(self.driver, self.llm_client, extracted_edges)
+            edges = await dedupe_edges_bulk(
+                self.driver, self.llm_client, extracted_edges_with_resolved_pointers
+            )
             logger.info(f'extracted edge length: {len(edges)}')

             # invalidate edges
```
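The verbose names are not purely cosmetic: the old code re-annotated existing locals (`extracted_edges: list[EntityEdge] = resolve_edge_pointers(extracted_edges, uuid_map)`), and mypy typically rejects a second annotation of a name that already carries one in the same scope. A minimal repro of the pitfall (names are illustrative):

```python
from typing import Any


def resolve(items: list[Any]) -> list[int]:
    return [int(x) for x in items]


items: list[Any] = ['1', '2']
items: list[int] = resolve(items)           # mypy typically flags this re-annotation  [no-redef]
resolved_items: list[int] = resolve(items)  # a fresh name avoids it and makes the flow explicit
```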
```diff
@@ -302,18 +298,18 @@ async def search(self, query: str, num_results=10):
         edges = (
             await hybrid_search(
                 self.driver,
-                self.llm_client.client.embeddings,
+                self.llm_client.get_embedder(),
                 query,
                 datetime.now(),
                 search_config,
             )
-        )['edges']
+        ).edges

         facts = [edge.fact for edge in edges]

         return facts

     async def _search(self, query: str, timestamp: datetime, config: SearchConfig):
         return await hybrid_search(
-            self.driver, self.llm_client.client.embeddings, query, timestamp, config
+            self.driver, self.llm_client.get_embedder(), query, timestamp, config
         )
```
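The switch from `['edges']` to `.edges` implies that `hybrid_search` now returns a structured result object rather than an untyped dict, which is what lets mypy see an element type for `edges` at all. A sketch of the idea (the real result class lives alongside `hybrid_search`; this shape is an assumption):

```python
from dataclasses import dataclass, field


@dataclass
class Edge:  # stand-in for the project's EntityEdge
    fact: str = ''


@dataclass
class SearchResults:  # assumed shape of hybrid_search's return value, for illustration
    edges: list[Edge] = field(default_factory=list)


results = SearchResults(edges=[Edge(fact='Alice works at Acme')])
facts = [edge.fact for edge in results.edges]  # attribute access is fully type-checked
# With a dict[str, Any], results['edges'] gives mypy nothing to check.
```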
core/llm_client/client.py (8 changes: 7 additions & 1 deletion)

```diff
@@ -1,5 +1,7 @@
+import typing
 from abc import ABC, abstractmethod

+from ..prompts.models import Message
 from .config import LLMConfig


@@ -9,5 +11,9 @@ def __init__(self, config: LLMConfig):
         pass

     @abstractmethod
-    async def generate_response(self, messages: list[dict[str, str]]) -> dict[str, any]:
+    def get_embedder(self) -> typing.Any:
+        pass
+
+    @abstractmethod
+    async def generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
         pass
```
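Declaring `get_embedder()` on the abstract base is what frees `graphiti.py` from reaching into `self.llm_client.client.embeddings`, an attribute only the OpenAI client has. Each backend now decides what its embedder is. A toy sketch of a second implementation (hypothetical, not part of the PR):

```python
import typing
from abc import ABC, abstractmethod


class Message:  # stand-in for the pydantic Message in core/prompts/models.py
    def __init__(self, role: str, content: str):
        self.role, self.content = role, content


class LLMClient(ABC):  # contract restated from core/llm_client/client.py
    @abstractmethod
    def get_embedder(self) -> typing.Any: ...

    @abstractmethod
    async def generate_response(self, messages: list[Message]) -> dict[str, typing.Any]: ...


class EchoClient(LLMClient):
    """Hypothetical non-OpenAI backend showing what implementers must provide."""

    def get_embedder(self) -> typing.Any:
        raise NotImplementedError('this toy backend exposes no embedding endpoint')

    async def generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
        return {'echo': [m.content for m in messages]}
```

The `typing.Any` return type is a pragmatic compromise: the embedder's concrete type differs per SDK, and pinning it would couple the interface to one vendor.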
core/llm_client/openai_client.py (19 changes: 16 additions & 3 deletions)

```diff
@@ -1,8 +1,11 @@
 import json
 import logging
+import typing

 from openai import AsyncOpenAI
+from openai.types.chat import ChatCompletionMessageParam

+from ..prompts.models import Message
 from .client import LLMClient
 from .config import LLMConfig

@@ -14,16 +17,26 @@ def __init__(self, config: LLMConfig):
         self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
         self.model = config.model

-    async def generate_response(self, messages: list[dict[str, str]]) -> dict[str, any]:
+    def get_embedder(self) -> typing.Any:
+        return self.client.embeddings
+
+    async def generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
+        openai_messages: list[ChatCompletionMessageParam] = []
+        for m in messages:
+            if m.role == 'user':
+                openai_messages.append({'role': 'user', 'content': m.content})
+            elif m.role == 'system':
+                openai_messages.append({'role': 'system', 'content': m.content})
         try:
             response = await self.client.chat.completions.create(
                 model=self.model,
-                messages=messages,
+                messages=openai_messages,
                 temperature=0.1,
                 max_tokens=3000,
                 response_format={'type': 'json_object'},
             )
-            return json.loads(response.choices[0].message.content)
+            result = response.choices[0].message.content or ''
+            return json.loads(result)
         except Exception as e:
             logger.error(f'Error in generating LLM response: {e}')
             raise
```
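Two typing fixes land here. The explicit `openai_messages` loop converts the project's `Message` models into the SDK's `ChatCompletionMessageParam` typed dicts, and the `or ''` guard handles the SDK typing `message.content` as `Optional[str]`. The guard's behavior is worth spelling out (sketch, not project code):

```python
import json


def parse_llm_json(content: str | None) -> dict:
    # json.loads(None) would raise TypeError and fail type checking;
    # falling back to '' keeps mypy satisfied and turns an empty response
    # into a clear json.JSONDecodeError inside the existing try/except.
    return json.loads(content or '')
```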
core/prompts/dedupe_edges.py (9 changes: 5 additions & 4 deletions)

```diff
@@ -1,12 +1,13 @@
 import json
-from typing import Protocol, TypedDict
+from typing import Any, Protocol, TypedDict

 from .models import Message, PromptFunction, PromptVersion


 class Prompt(Protocol):
     v1: PromptVersion
     v2: PromptVersion
+    edge_list: PromptVersion


 class Versions(TypedDict):
@@ -15,7 +16,7 @@ class Versions(TypedDict):
     edge_list: PromptFunction


-def v1(context: dict[str, any]) -> list[Message]:
+def v1(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',
@@ -55,7 +56,7 @@ def v1(context: dict[str, any]) -> list[Message]:
     ]


-def v2(context: dict[str, any]) -> list[Message]:
+def v2(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',
@@ -97,7 +98,7 @@ def v2(context: dict[str, any]) -> list[Message]:
     ]


-def edge_list(context: dict[str, any]) -> list[Message]:
+def edge_list(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',
```
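This is the fix repeated across every prompt module: lowercase `any` is the builtin function `any(iterable) -> bool`, not a type, so mypy rejects `dict[str, any]` annotations even though Python accepts them at runtime (annotations are not evaluated as types). `typing.Any` is the intended dynamic type:

```python
from typing import Any


def v_bad(context: dict[str, any]) -> None:   # mypy: Function "builtins.any" is not valid as a type  [valid-type]
    ...


def v_good(context: dict[str, Any]) -> None:  # checks cleanly
    ...
```

The same one-character fix recurs in dedupe_nodes.py, extract_edges.py, extract_nodes.py, invalidate_edges.py, lib.py, and models.py below.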
core/prompts/dedupe_nodes.py (8 changes: 4 additions & 4 deletions)

```diff
@@ -1,5 +1,5 @@
 import json
-from typing import Protocol, TypedDict
+from typing import Any, Protocol, TypedDict

 from .models import Message, PromptFunction, PromptVersion

@@ -16,7 +16,7 @@ class Versions(TypedDict):
     node_list: PromptVersion


-def v1(context: dict[str, any]) -> list[Message]:
+def v1(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',
@@ -56,7 +56,7 @@ def v1(context: dict[str, any]) -> list[Message]:
     ]


-def v2(context: dict[str, any]) -> list[Message]:
+def v2(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',
@@ -96,7 +96,7 @@ def v2(context: dict[str, any]) -> list[Message]:
     ]


-def node_list(context: dict[str, any]) -> list[Message]:
+def node_list(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',
```
core/prompts/extract_edges.py (6 changes: 3 additions & 3 deletions)

```diff
@@ -1,5 +1,5 @@
 import json
-from typing import Protocol, TypedDict
+from typing import Any, Protocol, TypedDict

 from .models import Message, PromptFunction, PromptVersion

@@ -14,7 +14,7 @@ class Versions(TypedDict):
     v2: PromptFunction


-def v1(context: dict[str, any]) -> list[Message]:
+def v1(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',
@@ -70,7 +70,7 @@ def v1(context: dict[str, any]) -> list[Message]:
     ]


-def v2(context: dict[str, any]) -> list[Message]:
+def v2(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',
```
core/prompts/extract_nodes.py (8 changes: 4 additions & 4 deletions)

```diff
@@ -1,5 +1,5 @@
 import json
-from typing import Protocol, TypedDict
+from typing import Any, Protocol, TypedDict

 from .models import Message, PromptFunction, PromptVersion

@@ -16,7 +16,7 @@ class Versions(TypedDict):
     v3: PromptFunction


-def v1(context: dict[str, any]) -> list[Message]:
+def v1(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',
@@ -64,7 +64,7 @@ def v1(context: dict[str, any]) -> list[Message]:
     ]


-def v2(context: dict[str, any]) -> list[Message]:
+def v2(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',
@@ -105,7 +105,7 @@ def v2(context: dict[str, any]) -> list[Message]:
     ]


-def v3(context: dict[str, any]) -> list[Message]:
+def v3(context: dict[str, Any]) -> list[Message]:
     sys_prompt = """You are an AI assistant that extracts entity nodes from conversational text. Your primary task is to identify and extract the speaker and other significant entities mentioned in the conversation."""

     user_prompt = f"""
```
core/prompts/invalidate_edges.py (4 changes: 2 additions & 2 deletions)

```diff
@@ -1,4 +1,4 @@
-from typing import Protocol, TypedDict
+from typing import Any, Protocol, TypedDict

 from .models import Message, PromptFunction, PromptVersion

@@ -11,7 +11,7 @@ class Versions(TypedDict):
     v1: PromptFunction


-def v1(context: dict[str, any]) -> list[Message]:
+def v1(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',
```
core/prompts/lib.py (9 changes: 4 additions & 5 deletions)

```diff
@@ -1,4 +1,4 @@
-from typing import Protocol, TypedDict
+from typing import Any, Protocol, TypedDict

 from .dedupe_edges import (
     Prompt as DedupeEdgesPrompt,
@@ -68,7 +68,7 @@ class VersionWrapper:
     def __init__(self, func: PromptFunction):
         self.func = func

-    def __call__(self, context: dict[str, any]) -> list[Message]:
+    def __call__(self, context: dict[str, Any]) -> list[Message]:
         return self.func(context)


@@ -81,7 +81,7 @@ def __init__(self, versions: dict[str, PromptFunction]):
 class PromptLibraryWrapper:
     def __init__(self, library: PromptLibraryImpl):
         for prompt_type, versions in library.items():
-            setattr(self, prompt_type, PromptTypeWrapper(versions))
+            setattr(self, prompt_type, PromptTypeWrapper(versions))  # type: ignore[arg-type]


 PROMPT_LIBRARY_IMPL: PromptLibraryImpl = {
@@ -91,5 +91,4 @@ def __init__(self, library: PromptLibraryImpl):
     'dedupe_edges': dedupe_edges_versions,
     'invalidate_edges': invalidate_edges_versions,
 }
-
-prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL)
+prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL)  # type: ignore[assignment]
```
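Both suppressions are scoped to a single error code rather than a bare `# type: ignore`, which pairs with the `--show-error-codes` flag the Makefile already passes: the bracketed code in mypy's output is exactly what goes in the comment. A sketch of the difference (illustrative names):

```python
from typing import Any


def takes_int(x: int) -> int:
    return x


takes_int('5')                            # mypy: ... incompatible type "str"  [arg-type]
takes_int('5')  # type: ignore[arg-type]  # silences only arg-type errors on this line
takes_int('5')  # type: ignore            # bare ignore: also hides any future error here
```

Here the ignores paper over the dynamic `setattr`-based wrapper construction, which mypy cannot relate back to the PromptLibrary protocol on its own.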
core/prompts/models.py (6 changes: 3 additions & 3 deletions)

```diff
@@ -1,4 +1,4 @@
-from typing import Callable, Protocol
+from typing import Any, Callable, Protocol

 from pydantic import BaseModel

@@ -9,7 +9,7 @@ class Message(BaseModel):


 class PromptVersion(Protocol):
-    def __call__(self, context: dict[str, any]) -> list[Message]: ...
+    def __call__(self, context: dict[str, Any]) -> list[Message]: ...


-PromptFunction = Callable[[dict[str, any]], list[Message]]
+PromptFunction = Callable[[dict[str, Any]], list[Message]]
```
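Because `PromptVersion` is a `Protocol`, the plain functions in the prompt modules satisfy it structurally: no inheritance or registration, only a matching signature, which is why the `dict[str, Any]` spelling has to agree everywhere. A compact illustration (stand-ins for the real classes):

```python
from typing import Any, Callable, Protocol


class Message:  # stand-in for the pydantic Message model
    def __init__(self, role: str, content: str):
        self.role, self.content = role, content


class PromptVersion(Protocol):
    def __call__(self, context: dict[str, Any]) -> list[Message]: ...


PromptFunction = Callable[[dict[str, Any]], list[Message]]


def v1(context: dict[str, Any]) -> list[Message]:
    return [Message('system', 'You extract entity nodes.')]


version: PromptVersion = v1  # accepted: the signature matches structurally
```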