Skip to content

Commit

Permalink
allow usage of different openai compatible clients in embedder and encoder
Browse files Browse the repository at this point in the history
  • Loading branch information
  • Loading branch information
Hedrekao committed Feb 26, 2025
1 parent 29a071b commit 0374108
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 6 deletions.
5 changes: 5 additions & 0 deletions graphiti_core/cross_encoder/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .bge_reranker_client import BGERerankerClient
from .client import CrossEncoderClient
from .openai_reranker_client import OpenAIRerankerClient

__all__ = ['CrossEncoderClient', 'BGERerankerClient', 'OpenAIRerankerClient']
14 changes: 10 additions & 4 deletions graphiti_core/cross_encoder/openai_reranker_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,27 @@ class BooleanClassifier(BaseModel):


class OpenAIRerankerClient(CrossEncoderClient):
def __init__(self, config: LLMConfig | None = None):
def __init__(
    self,
    config: LLMConfig | None = None,
    client: Any = None,
):
    """
    Initialize the OpenAIRerankerClient with the provided configuration and client.

    Args:
        config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens. Defaults to a fresh LLMConfig when omitted.
        client (Any | None): An optional async client instance to use (any OpenAI-compatible client). If not provided, a new AsyncOpenAI client is created from `config`.
    """
    # Fall back to a default config so the attribute reads below are always safe.
    if config is None:
        config = LLMConfig()

    self.config = config
    # Only build an AsyncOpenAI client when the caller did not inject one;
    # this is what allows alternative OpenAI-compatible clients to be used.
    if client is None:
        self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
    else:
        self.client = client

async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
openai_messages_list: Any = [
Expand All @@ -62,7 +68,7 @@ async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]
Message(
role='user',
content=f"""
Respond with "True" if PASSAGE is relevant to QUERY and "False" otherwise.
Respond with "True" if PASSAGE is relevant to QUERY and "False" otherwise.
<PASSAGE>
{passage}
</PASSAGE>
Expand Down
13 changes: 11 additions & 2 deletions graphiti_core/embedder/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""

from collections.abc import Iterable
from typing import Any

from openai import AsyncOpenAI
from openai.types import EmbeddingModel
Expand All @@ -35,11 +36,19 @@ class OpenAIEmbedder(EmbedderClient):
OpenAI Embedder Client
"""

def __init__(self, config: OpenAIEmbedderConfig | None = None):
def __init__(
    self,
    config: OpenAIEmbedderConfig | None = None,
    client: Any = None,
):
    """
    Initialize the OpenAIEmbedder with an optional configuration and client.

    Args:
        config (OpenAIEmbedderConfig | None): Embedder configuration; a default
            OpenAIEmbedderConfig is used when omitted.
        client (Any | None): An optional pre-built async client (any
            OpenAI-compatible client). When omitted, an AsyncOpenAI client is
            created from the configuration's api_key and base_url.
    """
    self.config = config if config is not None else OpenAIEmbedderConfig()
    # Prefer the injected client; otherwise construct the default AsyncOpenAI one.
    self.client = (
        client
        if client is not None
        else AsyncOpenAI(api_key=self.config.api_key, base_url=self.config.base_url)
    )

async def create(
self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
Expand Down

0 comments on commit 0374108

Please sign in to comment.