Skip to content

Commit c44c25d

Browse files
rohitprasad15rohit-rptless
authored andcommitted
Refactoring client and providers.
Added Azure support & refactored code. Refactoring includes - - ProviderFactory provides a specific provider impl. - Lazily import the provider based on config passed to Client. - Reduced the number of files.
1 parent 00105ae commit c44c25d

30 files changed

+441
-189
lines changed

Diff for: .gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
.idea/
22
.vscode/
33
__pycache__/
4+
env/

Diff for: aisuite/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from .client import Client
2+
from .provider import ProviderNames

Diff for: aisuite/client.py

+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
from .provider import ProviderFactory, ProviderNames
2+
3+
4+
class Client:
5+
def __init__(self, provider_configs: dict = {}):
6+
"""
7+
Initialize the client with provider configurations.
8+
Use the ProviderFactory to create provider instances.
9+
"""
10+
self.providers = {}
11+
self.provider_configs = provider_configs
12+
for provider_key, config in provider_configs.items():
13+
# Check if the provider key is a valid ProviderNames enum
14+
if not isinstance(provider_key, ProviderNames):
15+
raise ValueError(f"Provider {provider_key} is not a valid ProviderNames enum")
16+
# Store the value of the enum in the providers dictionary
17+
self.providers[provider_key.value] = ProviderFactory.create_provider(provider_key, config)
18+
19+
self._chat = None
20+
21+
def configure(self, provider_configs: dict = None):
22+
"""
23+
Configure the client with provider configurations.
24+
"""
25+
if provider_configs is None:
26+
return
27+
28+
self.provider_configs.update(provider_configs)
29+
30+
for provider_key, config in self.provider_configs.items():
31+
if not isinstance(provider_key, ProviderNames):
32+
raise ValueError(f"Provider {provider_key} is not a valid ProviderNames enum")
33+
self.providers[provider_key.value] = ProviderFactory.create_provider(provider_key, config)
34+
35+
@property
36+
def chat(self):
37+
"""Return the chat API interface."""
38+
if not self._chat:
39+
self._chat = Chat(self)
40+
return self._chat
41+
42+
43+
class Chat:
44+
def __init__(self, client: 'Client'):
45+
self.client = client
46+
47+
@property
48+
def completions(self):
49+
"""Return the completions interface."""
50+
return Completions(self.client)
51+
52+
53+
class Completions:
54+
def __init__(self, client: 'Client'):
55+
self.client = client
56+
57+
def create(self, model: str, messages: list, **kwargs):
58+
"""
59+
Create chat completion based on the model, messages, and any extra arguments.
60+
"""
61+
# Check that correct format is used
62+
if ':' not in model:
63+
raise ValueError(f"Invalid model format. Expected 'provider:model', got '{model}'")
64+
65+
# Extract the provider key from the model identifier, e.g., "aws-bedrock:model-name"
66+
provider_key, model_name = model.split(":", 1)
67+
68+
if provider_key not in ProviderNames._value2member_map_:
69+
raise ValueError(f"Provider {provider_key} is not a valid ProviderNames enum")
70+
71+
if provider_key not in self.client.providers:
72+
config = {}
73+
if provider_key in self.client.provider_configs:
74+
config = self.client.provider_configs[provider_key]
75+
self.client.providers[provider_key] = ProviderFactory.create_provider(ProviderNames(provider_key), config)
76+
77+
provider = self.client.providers.get(provider_key)
78+
if not provider:
79+
raise ValueError(f"Could not load provider for {provider_key}.")
80+
81+
# Delegate the chat completion to the correct provider's implementation
82+
# Any additional arguments will be passed to the provider's implementation.
83+
# Eg: max_tokens, temperature, etc.
84+
return provider.chat_completions_create(model_name, messages, **kwargs)

Diff for: aisuite/client/__init__.py

-3
This file was deleted.

Diff for: aisuite/client/chat.py

-18
This file was deleted.

Diff for: aisuite/client/client.py

-90
This file was deleted.

Diff for: aisuite/client/completions.py

-37
This file was deleted.

Diff for: aisuite/framework/__init__.py

-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1 @@
1-
"""Provides the ProviderInterface for defining the interface that all FM providers must implement."""
2-
3-
from .provider_interface import ProviderInterface
41
from .chat_completion_response import ChatCompletionResponse

Diff for: aisuite/framework/provider_interface.py

-25
This file was deleted.

Diff for: aisuite/old_providers/__init__.py

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""Provides the individual provider interfaces for each FM provider."""
2+
3+
from .anthropic_interface import AnthropicInterface
4+
from .aws_bedrock_interface import AWSBedrockInterface
5+
from .fireworks_interface import FireworksInterface
6+
from .groq_interface import GroqInterface
7+
from .mistral_interface import MistralInterface
8+
from .octo_interface import OctoInterface
9+
from .ollama_interface import OllamaInterface
10+
from .openai_interface import OpenAIInterface
11+
from .replicate_interface import ReplicateInterface
12+
from .together_interface import TogetherInterface
13+
from .google_interface import GoogleInterface
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

Diff for: aisuite/provider.py

+59
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from abc import ABC, abstractmethod
2+
from enum import Enum
3+
import importlib
4+
5+
class LLMError(Exception):
6+
"""Custom exception for LLM errors."""
7+
def __init__(self, message):
8+
super().__init__(message)
9+
10+
11+
class Provider(ABC):
12+
@abstractmethod
13+
def chat_completions_create(self, model, messages):
14+
"""Abstract method for chat completion calls, to be implemented by each provider."""
15+
pass
16+
17+
18+
class ProviderNames(str, Enum):
19+
OPENAI = 'openai'
20+
AWS_BEDROCK = 'aws-bedrock'
21+
ANTHROPIC = 'anthropic'
22+
AZURE = 'azure'
23+
24+
25+
class ProviderFactory:
26+
"""Factory to register and create provider instances based on keys."""
27+
28+
_provider_info = {
29+
ProviderNames.OPENAI: ('aisuite.providers.openai_provider', 'OpenAIProvider'),
30+
ProviderNames.AWS_BEDROCK: ('aisuite.providers.aws_bedrock_provider', 'AWSBedrockProvider'),
31+
ProviderNames.ANTHROPIC: ('aisuite.providers.anthropic_provider', 'AnthropicProvider'),
32+
ProviderNames.AZURE: ('aisuite.providers.azure_provider', 'AzureProvider'),
33+
}
34+
35+
36+
@classmethod
37+
def create_provider(cls, provider_key, config):
38+
"""Dynamically import and create an instance of a provider based on the provider key."""
39+
if not isinstance(provider_key, ProviderNames):
40+
raise ValueError(f"Provider {provider_key} is not a valid ProviderNames enum")
41+
42+
module_name, class_name = cls._get_provider_info(provider_key)
43+
if not module_name:
44+
raise ValueError(f"Provider {provider_key.value} is not supported")
45+
46+
# Lazily load the module
47+
try:
48+
module = importlib.import_module(module_name)
49+
except ImportError as e:
50+
raise ImportError(f"Could not import module {module_name}: {str(e)}")
51+
52+
# Instantiate the provider class
53+
provider_class = getattr(module, class_name)
54+
return provider_class(**config)
55+
56+
@classmethod
57+
def _get_provider_info(cls, provider_key):
58+
"""Return the module name and class name for a given provider key."""
59+
return cls._provider_info.get(provider_key, (None, None))

Diff for: aisuite/providers/__init__.py

+1-13
Original file line numberDiff line numberDiff line change
@@ -1,13 +1 @@
1-
"""Provides the individual provider interfaces for each FM provider."""
2-
3-
from .anthropic_interface import AnthropicInterface
4-
from .aws_bedrock_interface import AWSBedrockInterface
5-
from .fireworks_interface import FireworksInterface
6-
from .groq_interface import GroqInterface
7-
from .mistral_interface import MistralInterface
8-
from .octo_interface import OctoInterface
9-
from .ollama_interface import OllamaInterface
10-
from .openai_interface import OpenAIInterface
11-
from .replicate_interface import ReplicateInterface
12-
from .together_interface import TogetherInterface
13-
from .google_interface import GoogleInterface
1+
# Empty __init__.py

0 commit comments

Comments
 (0)