-
Notifications
You must be signed in to change notification settings - Fork 901
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added Azure support & refactored code. Refactoring includes - - ProviderFactory provides a specific provider impl. - Lazily import the provider based on config passed to Client. - Reduced the number of files.
- Loading branch information
1 parent
00105ae
commit 0a041e4
Showing
30 changed files
with
440 additions
and
189 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
.idea/ | ||
.vscode/ | ||
__pycache__/ | ||
env/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
from .client import Client | ||
from .provider import ProviderNames |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
from .provider import ProviderFactory, ProviderNames | ||
|
||
|
||
class Client: | ||
def __init__(self, provider_configs: dict = {}): | ||
""" | ||
Initialize the client with provider configurations. | ||
Use the ProviderFactory to create provider instances. | ||
""" | ||
self.providers = {} | ||
self.provider_configs = provider_configs | ||
for provider_key, config in provider_configs.items(): | ||
# Check if the provider key is a valid ProviderNames enum | ||
if not isinstance(provider_key, ProviderNames): | ||
raise ValueError(f"Provider {provider_key} is not a valid ProviderNames enum") | ||
# Store the value of the enum in the providers dictionary | ||
self.providers[provider_key.value] = ProviderFactory.create_provider(provider_key, config) | ||
|
||
self._chat = None | ||
|
||
def configure(self, provider_configs: dict = None): | ||
""" | ||
Configure the client with provider configurations. | ||
""" | ||
if provider_configs is None: | ||
return | ||
|
||
self.provider_configs.update(provider_configs) | ||
|
||
for provider_key, config in self.provider_configs.items(): | ||
if not isinstance(provider_key, ProviderNames): | ||
raise ValueError(f"Provider {provider_key} is not a valid ProviderNames enum") | ||
self.providers[provider_key.value] = ProviderFactory.create_provider(provider_key, config) | ||
|
||
@property | ||
def chat(self): | ||
"""Return the chat API interface.""" | ||
if not self._chat: | ||
self._chat = Chat(self) | ||
return self._chat | ||
|
||
|
||
class Chat: | ||
def __init__(self, client: 'Client'): | ||
self.client = client | ||
|
||
@property | ||
def completions(self): | ||
"""Return the completions interface.""" | ||
return Completions(self.client) | ||
|
||
|
||
class Completions: | ||
def __init__(self, client: 'Client'): | ||
self.client = client | ||
|
||
def create(self, model: str, messages: list, **kwargs): | ||
""" | ||
Create chat completion based on the model, messages, and any extra arguments. | ||
""" | ||
# Check that correct format is used | ||
if ':' not in model: | ||
raise ValueError(f"Invalid model format. Expected 'provider:model', got '{model}'") | ||
|
||
# Extract the provider key from the model identifier, e.g., "aws-bedrock:model-name" | ||
provider_key, model_name = model.split(":", 1) | ||
|
||
if provider_key not in ProviderNames._value2member_map_: | ||
raise ValueError(f"Provider {provider_key} is not a valid ProviderNames enum") | ||
|
||
if provider_key not in self.client.providers: | ||
config = {} | ||
if provider_key in self.client.provider_configs: | ||
config = self.client.provider_configs[provider_key] | ||
self.client.providers[provider_key] = ProviderFactory.create_provider(ProviderNames(provider_key), config) | ||
|
||
provider = self.client.providers.get(provider_key) | ||
if not provider: | ||
raise ValueError(f"Could not load provider for {provider_key}.") | ||
|
||
# Delegate the chat completion to the correct provider's implementation | ||
# Any additional arguments will be passed to the provider's implementation. | ||
# Eg: max_tokens, temperature, etc. | ||
return provider.chat_completions_create(model_name, messages, **kwargs) |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1 @@ | ||
"""Provides the ProviderInterface for defining the interface that all FM providers must implement.""" | ||
|
||
from .provider_interface import ProviderInterface | ||
from .chat_completion_response import ChatCompletionResponse |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
"""Provides the individual provider interfaces for each FM provider.""" | ||
|
||
from .anthropic_interface import AnthropicInterface | ||
from .aws_bedrock_interface import AWSBedrockInterface | ||
from .fireworks_interface import FireworksInterface | ||
from .groq_interface import GroqInterface | ||
from .mistral_interface import MistralInterface | ||
from .octo_interface import OctoInterface | ||
from .ollama_interface import OllamaInterface | ||
from .openai_interface import OpenAIInterface | ||
from .replicate_interface import ReplicateInterface | ||
from .together_interface import TogetherInterface | ||
from .google_interface import GoogleInterface |
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from abc import ABC, abstractmethod | ||
from enum import Enum | ||
import importlib | ||
|
||
class LLMError(Exception): | ||
"""Custom exception for LLM errors.""" | ||
def __init__(self, message): | ||
super().__init__(message) | ||
|
||
|
||
class Provider(ABC): | ||
@abstractmethod | ||
def chat_completions_create(self, model, messages): | ||
"""Abstract method for chat completion calls, to be implemented by each provider.""" | ||
pass | ||
|
||
|
||
class ProviderNames(str, Enum): | ||
OPENAI = 'openai' | ||
AWS_BEDROCK = 'aws-bedrock' | ||
ANTHROPIC = 'anthropic' | ||
AZURE = 'azure' | ||
|
||
|
||
class ProviderFactory: | ||
"""Factory to register and create provider instances based on keys.""" | ||
|
||
_provider_info = { | ||
ProviderNames.OPENAI: ('aisuite.providers.openai_provider', 'OpenAIProvider'), | ||
ProviderNames.AWS_BEDROCK: ('aisuite.providers.aws_bedrock_provider', 'AWSBedrockProvider'), | ||
ProviderNames.ANTHROPIC: ('aisuite.providers.anthropic_provider', 'AnthropicProvider'), | ||
ProviderNames.AZURE: ('aisuite.providers.azure_provider', 'AzureProvider'), | ||
} | ||
|
||
|
||
@classmethod | ||
def create_provider(cls, provider_key, config): | ||
"""Dynamically import and create an instance of a provider based on the provider key.""" | ||
if not isinstance(provider_key, ProviderNames): | ||
raise ValueError(f"Provider {provider_key} is not a valid ProviderNames enum") | ||
|
||
module_name, class_name = cls._get_provider_info(provider_key) | ||
if not module_name: | ||
raise ValueError(f"Provider {provider_key.value} is not supported") | ||
|
||
# Lazily load the module | ||
try: | ||
module = importlib.import_module(module_name) | ||
except ImportError as e: | ||
raise ImportError(f"Could not import module {module_name}: {str(e)}") | ||
|
||
# Instantiate the provider class | ||
provider_class = getattr(module, class_name) | ||
return provider_class(**config) | ||
|
||
@classmethod | ||
def _get_provider_info(cls, provider_key): | ||
"""Return the module name and class name for a given provider key.""" | ||
return cls._provider_info.get(provider_key, (None, None)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1 @@ | ||
"""Provides the individual provider interfaces for each FM provider.""" | ||
|
||
from .anthropic_interface import AnthropicInterface | ||
from .aws_bedrock_interface import AWSBedrockInterface | ||
from .fireworks_interface import FireworksInterface | ||
from .groq_interface import GroqInterface | ||
from .mistral_interface import MistralInterface | ||
from .octo_interface import OctoInterface | ||
from .ollama_interface import OllamaInterface | ||
from .openai_interface import OpenAIInterface | ||
from .replicate_interface import ReplicateInterface | ||
from .together_interface import TogetherInterface | ||
from .google_interface import GoogleInterface | ||
# Empty __init__.py |
Oops, something went wrong.