-
Notifications
You must be signed in to change notification settings - Fork 901
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactoring client and providers. #27
Merged
Merged
Changes from 4 commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
d9ede71
Refactoring client and providers.
rohitprasad15 fca12d3
Added older provider interface back for tests to pass.
rohit-rptless 7d6d3a4
Minor changes for making tests pass.
rohit-rptless 510b881
Allowing provider name as string in config()
rohit-rptless 5236a96
Addressing review comments.
rohit-rptless File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
.idea/ | ||
.vscode/ | ||
__pycache__/ | ||
env/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
from .client import Client | ||
from .provider import ProviderNames |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
from .provider import ProviderFactory, ProviderNames | ||
|
||
|
||
class Client: | ||
def __init__(self, provider_configs: dict = {}): | ||
""" | ||
Initialize the client with provider configurations. | ||
Use the ProviderFactory to create provider instances. | ||
""" | ||
self.providers = {} | ||
self.provider_configs = provider_configs | ||
self._chat = None | ||
self._initialize_providers() | ||
|
||
def _initialize_providers(self): | ||
"""Helper method to initialize or update providers.""" | ||
for provider_key, config in self.provider_configs.items(): | ||
provider_key = self._validate_provider_key(provider_key) | ||
self.providers[provider_key.value] = ProviderFactory.create_provider( | ||
provider_key, config | ||
) | ||
|
||
def _validate_provider_key(self, provider_key): | ||
""" | ||
Validate if the provider key is part of ProviderNames enum. | ||
Allow strings as well and convert them to ProviderNames. | ||
""" | ||
if isinstance(provider_key, str): | ||
if provider_key not in ProviderNames._value2member_map_: | ||
raise ValueError(f"Provider {provider_key} is not a valid provider") | ||
return ProviderNames(provider_key) | ||
|
||
if isinstance(provider_key, ProviderNames): | ||
return provider_key | ||
|
||
raise ValueError( | ||
f"Provider {provider_key} should either be a string or enum ProviderNames" | ||
) | ||
|
||
def configure(self, provider_configs: dict = None): | ||
""" | ||
Configure the client with provider configurations. | ||
""" | ||
if provider_configs is None: | ||
return | ||
|
||
self.provider_configs.update(provider_configs) | ||
self._initialize_providers() # NOTE: This will override existing provider instances. | ||
|
||
@property | ||
def chat(self): | ||
"""Return the chat API interface.""" | ||
if not self._chat: | ||
self._chat = Chat(self) | ||
return self._chat | ||
|
||
|
||
class Chat: | ||
def __init__(self, client: "Client"): | ||
self.client = client | ||
self._completions = Completions(self.client) | ||
|
||
@property | ||
def completions(self): | ||
"""Return the completions interface.""" | ||
return self._completions | ||
|
||
|
||
class Completions: | ||
def __init__(self, client: "Client"): | ||
self.client = client | ||
|
||
def create(self, model: str, messages: list, **kwargs): | ||
""" | ||
Create chat completion based on the model, messages, and any extra arguments. | ||
""" | ||
# Check that correct format is used | ||
if ":" not in model: | ||
raise ValueError( | ||
f"Invalid model format. Expected 'provider:model', got '{model}'" | ||
) | ||
ksolo marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# Extract the provider key from the model identifier, e.g., "aws-bedrock:model-name" | ||
provider_key, model_name = model.split(":", 1) | ||
|
||
if provider_key not in ProviderNames._value2member_map_: | ||
raise ValueError( | ||
f"Provider {provider_key} is not a valid ProviderNames enum" | ||
) | ||
|
||
if provider_key not in self.client.providers: | ||
config = {} | ||
if provider_key in self.client.provider_configs: | ||
config = self.client.provider_configs[provider_key] | ||
self.client.providers[provider_key] = ProviderFactory.create_provider( | ||
ProviderNames(provider_key), config | ||
) | ||
ksolo marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
provider = self.client.providers.get(provider_key) | ||
if not provider: | ||
raise ValueError(f"Could not load provider for {provider_key}.") | ||
|
||
# Delegate the chat completion to the correct provider's implementation | ||
# Any additional arguments will be passed to the provider's implementation. | ||
# Eg: max_tokens, temperature, etc. | ||
return provider.chat_completions_create(model_name, messages, **kwargs) |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,2 @@ | ||
"""Provides the ProviderInterface for defining the interface that all FM providers must implement.""" | ||
|
||
from .provider_interface import ProviderInterface | ||
from .chat_completion_response import ChatCompletionResponse |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from abc import ABC, abstractmethod | ||
from enum import Enum | ||
import importlib | ||
|
||
|
||
class LLMError(Exception): | ||
"""Custom exception for LLM errors.""" | ||
|
||
def __init__(self, message): | ||
super().__init__(message) | ||
|
||
|
||
class Provider(ABC): | ||
@abstractmethod | ||
def chat_completions_create(self, model, messages): | ||
"""Abstract method for chat completion calls, to be implemented by each provider.""" | ||
pass | ||
|
||
|
||
class ProviderNames(str, Enum): | ||
OPENAI = "openai" | ||
AWS_BEDROCK = "aws-bedrock" | ||
ANTHROPIC = "anthropic" | ||
AZURE = "azure" | ||
|
||
|
||
class ProviderFactory: | ||
"""Factory to register and create provider instances based on keys.""" | ||
|
||
_provider_info = { | ||
ProviderNames.OPENAI: ("aisuite.providers.openai_provider", "OpenAIProvider"), | ||
ProviderNames.AWS_BEDROCK: ( | ||
"aisuite.providers.aws_bedrock_provider", | ||
"AWSBedrockProvider", | ||
), | ||
ProviderNames.ANTHROPIC: ( | ||
"aisuite.providers.anthropic_provider", | ||
"AnthropicProvider", | ||
), | ||
ProviderNames.AZURE: ("aisuite.providers.azure_provider", "AzureProvider"), | ||
} | ||
|
||
@classmethod | ||
def create_provider(cls, provider_key, config): | ||
"""Dynamically import and create an instance of a provider based on the provider key.""" | ||
if not isinstance(provider_key, ProviderNames): | ||
raise ValueError( | ||
f"Provider {provider_key} is not a valid ProviderNames enum" | ||
) | ||
|
||
module_name, class_name = cls._get_provider_info(provider_key) | ||
if not module_name: | ||
raise ValueError(f"Provider {provider_key.value} is not supported") | ||
|
||
# Lazily load the module | ||
try: | ||
module = importlib.import_module(module_name) | ||
except ImportError as e: | ||
raise ImportError(f"Could not import module {module_name}: {str(e)}") | ||
rohitprasad15 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# Instantiate the provider class | ||
provider_class = getattr(module, class_name) | ||
return provider_class(**config) | ||
|
||
@classmethod | ||
def _get_provider_info(cls, provider_key): | ||
"""Return the module name and class name for a given provider key.""" | ||
return cls._provider_info.get(provider_key, (None, None)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import anthropic | ||
from aisuite.provider import Provider | ||
from aisuite.framework import ChatCompletionResponse | ||
|
||
# Define a constant for the default max_tokens value | ||
DEFAULT_MAX_TOKENS = 4096 | ||
|
||
|
||
class AnthropicProvider(Provider): | ||
def __init__(self, **config): | ||
""" | ||
Initialize the Anthropic provider with the given configuration. | ||
Pass the entire configuration dictionary to the Anthropic client constructor. | ||
""" | ||
|
||
self.client = anthropic.Anthropic(**config) | ||
|
||
def chat_completions_create(self, model, messages, **kwargs): | ||
# Check if the fist message is a system message | ||
if messages[0]["role"] == "system": | ||
system_message = messages[0]["content"] | ||
messages = messages[1:] | ||
else: | ||
system_message = None | ||
|
||
# kwargs.setdefault('max_tokens', DEFAULT_MAX_TOKENS) | ||
if "max_tokens" not in kwargs: | ||
kwargs["max_tokens"] = DEFAULT_MAX_TOKENS | ||
|
||
return self.normalize_response( | ||
self.client.messages.create( | ||
model=model, system=system_message, messages=messages, **kwargs | ||
) | ||
) | ||
|
||
def normalize_response(self, response): | ||
"""Normalize the response from the Anthropic API to match OpenAI's response format.""" | ||
normalized_response = ChatCompletionResponse() | ||
normalized_response.choices[0].message.content = response.content[0].text | ||
return normalized_response |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since the
provider_configs
is a simple dict, it could be helpful to add the args to doc string to explain what type of data is expected.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, thanks!