Skip to content

Commit

Permalink
Refactoring client and providers.
Browse files Browse the repository at this point in the history
Added Azure support & refactored code.
Refactoring includes -
- ProviderFactory provides a specific provider impl.
- Lazily import the provider based on config passed to Client.
- Reduced the number of files.
  • Loading branch information
rohitprasad15 authored and rohit-rptless committed Sep 11, 2024
1 parent 00105ae commit 0a041e4
Show file tree
Hide file tree
Showing 30 changed files with 440 additions and 189 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.idea/
.vscode/
__pycache__/
env/
1 change: 1 addition & 0 deletions aisuite/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .client import Client
from .provider import ProviderNames
84 changes: 84 additions & 0 deletions aisuite/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from .provider import ProviderFactory, ProviderNames


class Client:
def __init__(self, provider_configs: dict = {}):
"""
Initialize the client with provider configurations.
Use the ProviderFactory to create provider instances.
"""
self.providers = {}
self.provider_configs = provider_configs
for provider_key, config in provider_configs.items():
# Check if the provider key is a valid ProviderNames enum
if not isinstance(provider_key, ProviderNames):
raise ValueError(f"Provider {provider_key} is not a valid ProviderNames enum")
# Store the value of the enum in the providers dictionary
self.providers[provider_key.value] = ProviderFactory.create_provider(provider_key, config)

self._chat = None

def configure(self, provider_configs: dict = None):
"""
Configure the client with provider configurations.
"""
if provider_configs is None:
return

self.provider_configs.update(provider_configs)

for provider_key, config in self.provider_configs.items():
if not isinstance(provider_key, ProviderNames):
raise ValueError(f"Provider {provider_key} is not a valid ProviderNames enum")
self.providers[provider_key.value] = ProviderFactory.create_provider(provider_key, config)

@property
def chat(self):
"""Return the chat API interface."""
if not self._chat:
self._chat = Chat(self)
return self._chat


class Chat:
def __init__(self, client: 'Client'):
self.client = client

@property
def completions(self):
"""Return the completions interface."""
return Completions(self.client)


class Completions:
def __init__(self, client: 'Client'):
self.client = client

def create(self, model: str, messages: list, **kwargs):
"""
Create chat completion based on the model, messages, and any extra arguments.
"""
# Check that correct format is used
if ':' not in model:
raise ValueError(f"Invalid model format. Expected 'provider:model', got '{model}'")

# Extract the provider key from the model identifier, e.g., "aws-bedrock:model-name"
provider_key, model_name = model.split(":", 1)

if provider_key not in ProviderNames._value2member_map_:
raise ValueError(f"Provider {provider_key} is not a valid ProviderNames enum")

if provider_key not in self.client.providers:
config = {}
if provider_key in self.client.provider_configs:
config = self.client.provider_configs[provider_key]
self.client.providers[provider_key] = ProviderFactory.create_provider(ProviderNames(provider_key), config)

provider = self.client.providers.get(provider_key)
if not provider:
raise ValueError(f"Could not load provider for {provider_key}.")

# Delegate the chat completion to the correct provider's implementation
# Any additional arguments will be passed to the provider's implementation.
# Eg: max_tokens, temperature, etc.
return provider.chat_completions_create(model_name, messages, **kwargs)
3 changes: 0 additions & 3 deletions aisuite/client/__init__.py

This file was deleted.

18 changes: 0 additions & 18 deletions aisuite/client/chat.py

This file was deleted.

90 changes: 0 additions & 90 deletions aisuite/client/client.py

This file was deleted.

37 changes: 0 additions & 37 deletions aisuite/client/completions.py

This file was deleted.

3 changes: 0 additions & 3 deletions aisuite/framework/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1 @@
"""Provides the ProviderInterface for defining the interface that all FM providers must implement."""

from .provider_interface import ProviderInterface
from .chat_completion_response import ChatCompletionResponse
25 changes: 0 additions & 25 deletions aisuite/framework/provider_interface.py

This file was deleted.

13 changes: 13 additions & 0 deletions aisuite/old_providers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Provides the individual provider interfaces for each FM provider."""

from .anthropic_interface import AnthropicInterface
from .aws_bedrock_interface import AWSBedrockInterface
from .fireworks_interface import FireworksInterface
from .groq_interface import GroqInterface
from .mistral_interface import MistralInterface
from .octo_interface import OctoInterface
from .ollama_interface import OllamaInterface
from .openai_interface import OpenAIInterface
from .replicate_interface import ReplicateInterface
from .together_interface import TogetherInterface
from .google_interface import GoogleInterface
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
59 changes: 59 additions & 0 deletions aisuite/provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from abc import ABC, abstractmethod
from enum import Enum
import importlib

class LLMError(Exception):
"""Custom exception for LLM errors."""
def __init__(self, message):
super().__init__(message)


class Provider(ABC):
@abstractmethod
def chat_completions_create(self, model, messages):
"""Abstract method for chat completion calls, to be implemented by each provider."""
pass


class ProviderNames(str, Enum):
OPENAI = 'openai'
AWS_BEDROCK = 'aws-bedrock'
ANTHROPIC = 'anthropic'
AZURE = 'azure'


class ProviderFactory:
"""Factory to register and create provider instances based on keys."""

_provider_info = {
ProviderNames.OPENAI: ('aisuite.providers.openai_provider', 'OpenAIProvider'),
ProviderNames.AWS_BEDROCK: ('aisuite.providers.aws_bedrock_provider', 'AWSBedrockProvider'),
ProviderNames.ANTHROPIC: ('aisuite.providers.anthropic_provider', 'AnthropicProvider'),
ProviderNames.AZURE: ('aisuite.providers.azure_provider', 'AzureProvider'),
}


@classmethod
def create_provider(cls, provider_key, config):
"""Dynamically import and create an instance of a provider based on the provider key."""
if not isinstance(provider_key, ProviderNames):
raise ValueError(f"Provider {provider_key} is not a valid ProviderNames enum")

module_name, class_name = cls._get_provider_info(provider_key)
if not module_name:
raise ValueError(f"Provider {provider_key.value} is not supported")

# Lazily load the module
try:
module = importlib.import_module(module_name)
except ImportError as e:
raise ImportError(f"Could not import module {module_name}: {str(e)}")

# Instantiate the provider class
provider_class = getattr(module, class_name)
return provider_class(**config)

@classmethod
def _get_provider_info(cls, provider_key):
"""Return the module name and class name for a given provider key."""
return cls._provider_info.get(provider_key, (None, None))
14 changes: 1 addition & 13 deletions aisuite/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1 @@
"""Provides the individual provider interfaces for each FM provider."""

from .anthropic_interface import AnthropicInterface
from .aws_bedrock_interface import AWSBedrockInterface
from .fireworks_interface import FireworksInterface
from .groq_interface import GroqInterface
from .mistral_interface import MistralInterface
from .octo_interface import OctoInterface
from .ollama_interface import OllamaInterface
from .openai_interface import OpenAIInterface
from .replicate_interface import ReplicateInterface
from .together_interface import TogetherInterface
from .google_interface import GoogleInterface
# Empty __init__.py
Loading

0 comments on commit 0a041e4

Please sign in to comment.