Skip to content

Commit d9ede71

Browse files
rohitprasad15rohit-rptless
authored andcommitted
Refactoring client and providers.
Added Azure support & refactored code. Refactoring includes - - ProviderFactory provides a specific provider impl. - Lazily import the provider based on config passed to Client. - Reduced the number of files.
1 parent 00105ae commit d9ede71

30 files changed

+484
-189
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
.idea/
22
.vscode/
33
__pycache__/
4+
env/

aisuite/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from .client import Client
2+
from .provider import ProviderNames

aisuite/client.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
from .provider import ProviderFactory, ProviderNames
2+
3+
4+
class Client:
5+
def __init__(self, provider_configs: dict = {}):
6+
"""
7+
Initialize the client with provider configurations.
8+
Use the ProviderFactory to create provider instances.
9+
"""
10+
self.providers = {}
11+
self.provider_configs = provider_configs
12+
for provider_key, config in provider_configs.items():
13+
# Check if the provider key is a valid ProviderNames enum
14+
if not isinstance(provider_key, ProviderNames):
15+
raise ValueError(
16+
f"Provider {provider_key} is not a valid ProviderNames enum"
17+
)
18+
# Store the value of the enum in the providers dictionary
19+
self.providers[provider_key.value] = ProviderFactory.create_provider(
20+
provider_key, config
21+
)
22+
23+
self._chat = None
24+
25+
def configure(self, provider_configs: dict = None):
26+
"""
27+
Configure the client with provider configurations.
28+
"""
29+
if provider_configs is None:
30+
return
31+
32+
self.provider_configs.update(provider_configs)
33+
34+
for provider_key, config in self.provider_configs.items():
35+
if not isinstance(provider_key, ProviderNames):
36+
raise ValueError(
37+
f"Provider {provider_key} is not a valid ProviderNames enum"
38+
)
39+
self.providers[provider_key.value] = ProviderFactory.create_provider(
40+
provider_key, config
41+
)
42+
43+
@property
44+
def chat(self):
45+
"""Return the chat API interface."""
46+
if not self._chat:
47+
self._chat = Chat(self)
48+
return self._chat
49+
50+
51+
class Chat:
52+
def __init__(self, client: "Client"):
53+
self.client = client
54+
55+
@property
56+
def completions(self):
57+
"""Return the completions interface."""
58+
return Completions(self.client)
59+
60+
61+
class Completions:
62+
def __init__(self, client: "Client"):
63+
self.client = client
64+
65+
def create(self, model: str, messages: list, **kwargs):
66+
"""
67+
Create chat completion based on the model, messages, and any extra arguments.
68+
"""
69+
# Check that correct format is used
70+
if ":" not in model:
71+
raise ValueError(
72+
f"Invalid model format. Expected 'provider:model', got '{model}'"
73+
)
74+
75+
# Extract the provider key from the model identifier, e.g., "aws-bedrock:model-name"
76+
provider_key, model_name = model.split(":", 1)
77+
78+
if provider_key not in ProviderNames._value2member_map_:
79+
raise ValueError(
80+
f"Provider {provider_key} is not a valid ProviderNames enum"
81+
)
82+
83+
if provider_key not in self.client.providers:
84+
config = {}
85+
if provider_key in self.client.provider_configs:
86+
config = self.client.provider_configs[provider_key]
87+
self.client.providers[provider_key] = ProviderFactory.create_provider(
88+
ProviderNames(provider_key), config
89+
)
90+
91+
provider = self.client.providers.get(provider_key)
92+
if not provider:
93+
raise ValueError(f"Could not load provider for {provider_key}.")
94+
95+
# Delegate the chat completion to the correct provider's implementation
96+
# Any additional arguments will be passed to the provider's implementation.
97+
# Eg: max_tokens, temperature, etc.
98+
return provider.chat_completions_create(model_name, messages, **kwargs)

aisuite/client/__init__.py

Lines changed: 0 additions & 3 deletions
This file was deleted.

aisuite/client/chat.py

Lines changed: 0 additions & 18 deletions
This file was deleted.

aisuite/client/client.py

Lines changed: 0 additions & 90 deletions
This file was deleted.

aisuite/client/completions.py

Lines changed: 0 additions & 37 deletions
This file was deleted.

aisuite/framework/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1 @@
1-
"""Provides the ProviderInterface for defining the interface that all FM providers must implement."""
2-
3-
from .provider_interface import ProviderInterface
41
from .chat_completion_response import ChatCompletionResponse

aisuite/framework/provider_interface.py

Lines changed: 0 additions & 25 deletions
This file was deleted.

aisuite/old_providers/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""Provides the individual provider interfaces for each FM provider."""
2+
3+
from .anthropic_interface import AnthropicInterface
4+
from .aws_bedrock_interface import AWSBedrockInterface
5+
from .fireworks_interface import FireworksInterface
6+
from .groq_interface import GroqInterface
7+
from .mistral_interface import MistralInterface
8+
from .octo_interface import OctoInterface
9+
from .ollama_interface import OllamaInterface
10+
from .openai_interface import OpenAIInterface
11+
from .replicate_interface import ReplicateInterface
12+
from .together_interface import TogetherInterface
13+
from .google_interface import GoogleInterface

aisuite/provider.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from abc import ABC, abstractmethod
2+
from enum import Enum
3+
import importlib
4+
5+
6+
class LLMError(Exception):
7+
"""Custom exception for LLM errors."""
8+
9+
def __init__(self, message):
10+
super().__init__(message)
11+
12+
13+
class Provider(ABC):
14+
@abstractmethod
15+
def chat_completions_create(self, model, messages):
16+
"""Abstract method for chat completion calls, to be implemented by each provider."""
17+
pass
18+
19+
20+
class ProviderNames(str, Enum):
21+
OPENAI = "openai"
22+
AWS_BEDROCK = "aws-bedrock"
23+
ANTHROPIC = "anthropic"
24+
AZURE = "azure"
25+
26+
27+
class ProviderFactory:
28+
"""Factory to register and create provider instances based on keys."""
29+
30+
_provider_info = {
31+
ProviderNames.OPENAI: ("aisuite.providers.openai_provider", "OpenAIProvider"),
32+
ProviderNames.AWS_BEDROCK: (
33+
"aisuite.providers.aws_bedrock_provider",
34+
"AWSBedrockProvider",
35+
),
36+
ProviderNames.ANTHROPIC: (
37+
"aisuite.providers.anthropic_provider",
38+
"AnthropicProvider",
39+
),
40+
ProviderNames.AZURE: ("aisuite.providers.azure_provider", "AzureProvider"),
41+
}
42+
43+
@classmethod
44+
def create_provider(cls, provider_key, config):
45+
"""Dynamically import and create an instance of a provider based on the provider key."""
46+
if not isinstance(provider_key, ProviderNames):
47+
raise ValueError(
48+
f"Provider {provider_key} is not a valid ProviderNames enum"
49+
)
50+
51+
module_name, class_name = cls._get_provider_info(provider_key)
52+
if not module_name:
53+
raise ValueError(f"Provider {provider_key.value} is not supported")
54+
55+
# Lazily load the module
56+
try:
57+
module = importlib.import_module(module_name)
58+
except ImportError as e:
59+
raise ImportError(f"Could not import module {module_name}: {str(e)}")
60+
61+
# Instantiate the provider class
62+
provider_class = getattr(module, class_name)
63+
return provider_class(**config)
64+
65+
@classmethod
66+
def _get_provider_info(cls, provider_key):
67+
"""Return the module name and class name for a given provider key."""
68+
return cls._provider_info.get(provider_key, (None, None))

aisuite/providers/__init__.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1 @@
1-
"""Provides the individual provider interfaces for each FM provider."""
2-
3-
from .anthropic_interface import AnthropicInterface
4-
from .aws_bedrock_interface import AWSBedrockInterface
5-
from .fireworks_interface import FireworksInterface
6-
from .groq_interface import GroqInterface
7-
from .mistral_interface import MistralInterface
8-
from .octo_interface import OctoInterface
9-
from .ollama_interface import OllamaInterface
10-
from .openai_interface import OpenAIInterface
11-
from .replicate_interface import ReplicateInterface
12-
from .together_interface import TogetherInterface
13-
from .google_interface import GoogleInterface
1+
# Empty __init__.py

0 commit comments

Comments
 (0)