Skip to content

Commit d77f312

Browse files
Refactoring client and providers. (#27)
Added Azure support & refactored code. Refactoring includes - - ProviderFactory - Lazily import the provider based on config passed to Client. Will need to port the older provider files to the new format. Till then keeping the older provider interface related tests. Co-authored-by: rohit-rptless <[email protected]>
1 parent 00105ae commit d77f312

17 files changed

+612
-258
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
.idea/
22
.vscode/
33
__pycache__/
4+
env/

aisuite/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from .client import Client
2+
from .provider import ProviderNames

aisuite/client.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
from .provider import ProviderFactory, ProviderNames
2+
3+
4+
class Client:
5+
def __init__(self, provider_configs: dict = {}):
6+
"""
7+
Initialize the client with provider configurations.
8+
Use the ProviderFactory to create provider instances.
9+
10+
Args:
11+
provider_configs (dict): A dictionary containing provider configurations.
12+
Each key should be a ProviderNames enum or its string representation,
13+
and the value should be a dictionary of configuration options for that provider.
14+
For example:
15+
{
16+
ProviderNames.OPENAI: {"api_key": "your_openai_api_key"},
17+
"aws-bedrock": {
18+
"aws_access_key": "your_aws_access_key",
19+
"aws_secret_key": "your_aws_secret_key",
20+
"aws_region": "us-west-2"
21+
}
22+
}
23+
"""
24+
self.providers = {}
25+
self.provider_configs = provider_configs
26+
self._chat = None
27+
self._initialize_providers()
28+
29+
def _initialize_providers(self):
30+
"""Helper method to initialize or update providers."""
31+
for provider_key, config in self.provider_configs.items():
32+
provider_key = self._validate_provider_key(provider_key)
33+
self.providers[provider_key.value] = ProviderFactory.create_provider(
34+
provider_key, config
35+
)
36+
37+
def _validate_provider_key(self, provider_key):
38+
"""
39+
Validate if the provider key is part of ProviderNames enum.
40+
Allow strings as well and convert them to ProviderNames.
41+
"""
42+
if isinstance(provider_key, str):
43+
if provider_key not in ProviderNames._value2member_map_:
44+
raise ValueError(f"Provider {provider_key} is not a valid provider")
45+
return ProviderNames(provider_key)
46+
47+
if isinstance(provider_key, ProviderNames):
48+
return provider_key
49+
50+
raise ValueError(
51+
f"Provider {provider_key} should either be a string or enum ProviderNames"
52+
)
53+
54+
def configure(self, provider_configs: dict = None):
55+
"""
56+
Configure the client with provider configurations.
57+
"""
58+
if provider_configs is None:
59+
return
60+
61+
self.provider_configs.update(provider_configs)
62+
self._initialize_providers() # NOTE: This will override existing provider instances.
63+
64+
@property
65+
def chat(self):
66+
"""Return the chat API interface."""
67+
if not self._chat:
68+
self._chat = Chat(self)
69+
return self._chat
70+
71+
72+
class Chat:
73+
def __init__(self, client: "Client"):
74+
self.client = client
75+
self._completions = Completions(self.client)
76+
77+
@property
78+
def completions(self):
79+
"""Return the completions interface."""
80+
return self._completions
81+
82+
83+
class Completions:
84+
def __init__(self, client: "Client"):
85+
self.client = client
86+
87+
def create(self, model: str, messages: list, **kwargs):
88+
"""
89+
Create chat completion based on the model, messages, and any extra arguments.
90+
"""
91+
# Check that correct format is used
92+
if ":" not in model:
93+
raise ValueError(
94+
f"Invalid model format. Expected 'provider:model', got '{model}'"
95+
)
96+
97+
# Extract the provider key from the model identifier, e.g., "aws-bedrock:model-name"
98+
provider_key, model_name = model.split(":", 1)
99+
100+
if provider_key not in ProviderNames._value2member_map_:
101+
# If the provider key does not match, give a clearer message to guide the user
102+
valid_providers = ", ".join([p.value for p in ProviderNames])
103+
raise ValueError(
104+
f"Invalid provider key '{provider_key}'. Expected one of: {valid_providers}. "
105+
"Make sure the model string is formatted correctly as 'provider:model'."
106+
)
107+
108+
if provider_key not in self.client.providers:
109+
config = {}
110+
if provider_key in self.client.provider_configs:
111+
config = self.client.provider_configs[provider_key]
112+
self.client.providers[provider_key] = ProviderFactory.create_provider(
113+
ProviderNames(provider_key), config
114+
)
115+
116+
provider = self.client.providers.get(provider_key)
117+
if not provider:
118+
raise ValueError(f"Could not load provider for {provider_key}.")
119+
120+
# Delegate the chat completion to the correct provider's implementation
121+
# Any additional arguments will be passed to the provider's implementation.
122+
# Eg: max_tokens, temperature, etc.
123+
return provider.chat_completions_create(model_name, messages, **kwargs)

aisuite/client/__init__.py

Lines changed: 0 additions & 3 deletions
This file was deleted.

aisuite/client/chat.py

Lines changed: 0 additions & 18 deletions
This file was deleted.

aisuite/client/client.py

Lines changed: 0 additions & 90 deletions
This file was deleted.

aisuite/client/completions.py

Lines changed: 0 additions & 37 deletions
This file was deleted.

aisuite/framework/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,2 @@
1-
"""Provides the ProviderInterface for defining the interface that all FM providers must implement."""
2-
31
from .provider_interface import ProviderInterface
42
from .chat_completion_response import ChatCompletionResponse

aisuite/provider.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from abc import ABC, abstractmethod
2+
from enum import Enum
3+
import importlib
4+
5+
6+
class LLMError(Exception):
7+
"""Custom exception for LLM errors."""
8+
9+
def __init__(self, message):
10+
super().__init__(message)
11+
12+
13+
class Provider(ABC):
14+
@abstractmethod
15+
def chat_completions_create(self, model, messages):
16+
"""Abstract method for chat completion calls, to be implemented by each provider."""
17+
pass
18+
19+
20+
class ProviderNames(str, Enum):
21+
OPENAI = "openai"
22+
AWS_BEDROCK = "aws-bedrock"
23+
ANTHROPIC = "anthropic"
24+
AZURE = "azure"
25+
26+
27+
class ProviderFactory:
28+
"""Factory to register and create provider instances based on keys."""
29+
30+
_provider_info = {
31+
ProviderNames.OPENAI: ("aisuite.providers.openai_provider", "OpenAIProvider"),
32+
ProviderNames.AWS_BEDROCK: (
33+
"aisuite.providers.aws_bedrock_provider",
34+
"AWSBedrockProvider",
35+
),
36+
ProviderNames.ANTHROPIC: (
37+
"aisuite.providers.anthropic_provider",
38+
"AnthropicProvider",
39+
),
40+
ProviderNames.AZURE: ("aisuite.providers.azure_provider", "AzureProvider"),
41+
}
42+
43+
@classmethod
44+
def create_provider(cls, provider_key, config):
45+
"""Dynamically import and create an instance of a provider based on the provider key."""
46+
if not isinstance(provider_key, ProviderNames):
47+
raise ValueError(
48+
f"Provider {provider_key} is not a valid ProviderNames enum"
49+
)
50+
51+
module_name, class_name = cls._get_provider_info(provider_key)
52+
if not module_name:
53+
raise ValueError(f"Provider {provider_key.value} is not supported")
54+
55+
# Lazily load the module
56+
try:
57+
module = importlib.import_module(module_name)
58+
except ImportError as e:
59+
raise ImportError(f"Could not import module {module_name}: {str(e)}")
60+
61+
# Instantiate the provider class
62+
provider_class = getattr(module, class_name)
63+
return provider_class(**config)
64+
65+
@classmethod
66+
def _get_provider_info(cls, provider_key):
67+
"""Return the module name and class name for a given provider key."""
68+
return cls._provider_info.get(provider_key, (None, None))
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import anthropic
2+
from aisuite.provider import Provider
3+
from aisuite.framework import ChatCompletionResponse
4+
5+
# Define a constant for the default max_tokens value
6+
DEFAULT_MAX_TOKENS = 4096
7+
8+
9+
class AnthropicProvider(Provider):
10+
def __init__(self, **config):
11+
"""
12+
Initialize the Anthropic provider with the given configuration.
13+
Pass the entire configuration dictionary to the Anthropic client constructor.
14+
"""
15+
16+
self.client = anthropic.Anthropic(**config)
17+
18+
def chat_completions_create(self, model, messages, **kwargs):
19+
# Check if the fist message is a system message
20+
if messages[0]["role"] == "system":
21+
system_message = messages[0]["content"]
22+
messages = messages[1:]
23+
else:
24+
system_message = None
25+
26+
# kwargs.setdefault('max_tokens', DEFAULT_MAX_TOKENS)
27+
if "max_tokens" not in kwargs:
28+
kwargs["max_tokens"] = DEFAULT_MAX_TOKENS
29+
30+
return self.normalize_response(
31+
self.client.messages.create(
32+
model=model, system=system_message, messages=messages, **kwargs
33+
)
34+
)
35+
36+
def normalize_response(self, response):
37+
"""Normalize the response from the Anthropic API to match OpenAI's response format."""
38+
normalized_response = ChatCompletionResponse()
39+
normalized_response.choices[0].message.content = response.content[0].text
40+
return normalized_response

0 commit comments

Comments
 (0)