diff --git a/.gitignore b/.gitignore
index f1f1d58d..e1084c98 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
 .idea/
 .vscode/
 __pycache__/
+env/
diff --git a/aisuite/__init__.py b/aisuite/__init__.py
index 3ff722bf..7f5ee70b 100644
--- a/aisuite/__init__.py
+++ b/aisuite/__init__.py
@@ -1 +1,2 @@
 from .client import Client
+from .provider import ProviderNames
diff --git a/aisuite/client.py b/aisuite/client.py
new file mode 100644
index 00000000..5a275f18
--- /dev/null
+++ b/aisuite/client.py
@@ -0,0 +1,86 @@
+from .provider import ProviderFactory, ProviderNames
+
+
+class Client:
+    def __init__(self, provider_configs: dict = None):
+        """
+        Initialize the client with provider configurations.
+        Use the ProviderFactory to create provider instances.
+        """
+        self.providers = {}
+        # Copy rather than alias the caller's dict, and avoid a shared mutable
+        # default argument (configure() mutates this dict in place).
+        self.provider_configs = dict(provider_configs) if provider_configs else {}
+        for provider_key, config in self.provider_configs.items():
+            # Check if the provider key is a valid ProviderNames enum
+            if not isinstance(provider_key, ProviderNames):
+                raise ValueError(f"Provider {provider_key} is not a valid ProviderNames enum")
+            # Store the value of the enum in the providers dictionary
+            self.providers[provider_key.value] = ProviderFactory.create_provider(provider_key, config)
+
+        self._chat = None
+
+    def configure(self, provider_configs: dict = None):
+        """
+        Configure the client with provider configurations.
+        """
+        if provider_configs is None:
+            return
+
+        self.provider_configs.update(provider_configs)
+
+        for provider_key, config in self.provider_configs.items():
+            if not isinstance(provider_key, ProviderNames):
+                raise ValueError(f"Provider {provider_key} is not a valid ProviderNames enum")
+            self.providers[provider_key.value] = ProviderFactory.create_provider(provider_key, config)
+
+    @property
+    def chat(self):
+        """Return the chat API interface."""
+        if not self._chat:
+            self._chat = Chat(self)
+        return self._chat
+
+
+class Chat:
+    def __init__(self, client: 'Client'):
+        self.client = client
+
+    @property
+    def completions(self):
+        """Return the completions interface."""
+        return Completions(self.client)
+
+
+class Completions:
+    def __init__(self, client: 'Client'):
+        self.client = client
+
+    def create(self, model: str, messages: list, **kwargs):
+        """
+        Create a chat completion based on the model, messages, and any extra arguments.
+        """
+        # Check that the correct format is used
+        if ':' not in model:
+            raise ValueError(f"Invalid model format. Expected 'provider:model', got '{model}'")
+
+        # Extract the provider key from the model identifier, e.g., "aws-bedrock:model-name"
+        provider_key, model_name = model.split(":", 1)
+
+        if provider_key not in ProviderNames._value2member_map_:
+            raise ValueError(f"Provider {provider_key} is not a valid ProviderNames enum")
+
+        if provider_key not in self.client.providers:
+            config = {}
+            if provider_key in self.client.provider_configs:
+                config = self.client.provider_configs[provider_key]
+            self.client.providers[provider_key] = ProviderFactory.create_provider(ProviderNames(provider_key), config)
+
+        provider = self.client.providers.get(provider_key)
+        if not provider:
+            raise ValueError(f"Could not load provider for {provider_key}.")
+
+        # Delegate the chat completion to the correct provider's implementation.
+        # Any additional arguments will be passed to the provider's implementation,
+        # e.g. max_tokens, temperature, etc.
+        return provider.chat_completions_create(model_name, messages, **kwargs)
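Reviewer note: a minimal usage sketch of the new Client -> Chat -> Completions chain added above. The API key and model name are illustrative placeholders, not values defined in this diff.

import aisuite
from aisuite import ProviderNames

# Keys must be ProviderNames members; each config dict is forwarded to the
# corresponding provider's constructor.
client = aisuite.Client({ProviderNames.OPENAI: {"api_key": "sk-placeholder"}})

# Models are addressed as "provider:model". Providers not created up front are
# instantiated lazily inside Completions.create() using provider_configs (or {}).
response = client.chat.completions.create(
    "openai:gpt-4o",  # model name is illustrative
    messages=[{"role": "user", "content": "Hello!"}],
    temperature=0.7,  # extra kwargs are forwarded to the provider
)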
+
+
diff --git a/aisuite/client/__init__.py b/aisuite/client/__init__.py
deleted file mode 100644
index 75e4cc81..00000000
--- a/aisuite/client/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-"""Provides the Client for managing chats across many FM providers."""
-
-from .client import Client
diff --git a/aisuite/client/chat.py b/aisuite/client/chat.py
deleted file mode 100644
index 6b51ae0f..00000000
--- a/aisuite/client/chat.py
+++ /dev/null
@@ -1,18 +0,0 @@
-"""Chat is instantiated with a client and manages completions."""
-
-from .completions import Completions
-
-
-class Chat:
-    """Manage chat sessions with multiple providers."""
-
-    def __init__(self, topmost_instance):
-        """Initialize a new Chat instance.
-
-        Args:
-        ----
-        topmost_instance: The chat session's client instance (Client).
-
-        """
-        self.topmost_instance = topmost_instance
-        self.completions = Completions(topmost_instance)
diff --git a/aisuite/client/client.py b/aisuite/client/client.py
deleted file mode 100644
index e2654126..00000000
--- a/aisuite/client/client.py
+++ /dev/null
@@ -1,90 +0,0 @@
-"""Client manages a Chat across multiple provider interfaces."""
-
-from .chat import Chat
-from ..providers import (
-    AnthropicInterface,
-    AWSBedrockInterface,
-    FireworksInterface,
-    GroqInterface,
-    MistralInterface,
-    OctoInterface,
-    OllamaInterface,
-    OpenAIInterface,
-    ReplicateInterface,
-    TogetherInterface,
-    GoogleInterface,
-)
-
-
-class Client:
-    """Manages multiple provider interfaces."""
-
-    _MODEL_FORMAT_ERROR_MESSAGE_TEMPLATE = (
-        "Expected ':' in model identifier to specify provider:model. Got {model}."
-    )
-    _NO_FACTORY_ERROR_MESSAGE_TEMPLATE = (
-        "Could not find factory to create interface for provider '{provider}'."
-    )
-
-    def __init__(self):
-        """Initialize the Client instance.
-
-        Attributes
-        ----------
-        chat (Chat): The chat session.
-        all_interfaces (dict): Stores interface instances by provider names.
-        all_factories (dict): Maps provider names to their corresponding interfaces.
-
-        """
-        self.chat = Chat(self)
-        self.all_interfaces = {}
-        self.all_factories = {
-            "anthropic": AnthropicInterface,
-            "aws": AWSBedrockInterface,
-            "fireworks": FireworksInterface,
-            "groq": GroqInterface,
-            "mistral": MistralInterface,
-            "octo": OctoInterface,
-            "ollama": OllamaInterface,
-            "openai": OpenAIInterface,
-            "replicate": ReplicateInterface,
-            "together": TogetherInterface,
-            "google": GoogleInterface,
-        }
-
-    def get_provider_interface(self, model):
-        """Retrieve or create a provider interface based on a model identifier.
-
-        Args:
-        ----
-        model (str): The model identifier in the format 'provider:model'.
-
-        Raises:
-        ------
-        ValueError: If the model identifier does colon-separate provider and model.
-        Exception: If no factory is found from the supplied model.
-
-        Returns:
-        -------
-        The interface instance for the provider and the model name.
- - """ - if ":" not in model: - raise ValueError( - self._MODEL_FORMAT_ERROR_MESSAGE_TEMPLATE.format(model=model) - ) - - model_parts = model.split(":", maxsplit=1) - provider = model_parts[0] - model_name = model_parts[1] - - if provider in self.all_interfaces: - return self.all_interfaces[provider], model_name - - if provider not in self.all_factories: - raise Exception( - self._NO_FACTORY_ERROR_MESSAGE_TEMPLATE.format(provider=provider) - ) - - self.all_interfaces[provider] = self.all_factories[provider]() - return self.all_interfaces[provider], model_name diff --git a/aisuite/client/completions.py b/aisuite/client/completions.py deleted file mode 100644 index c87c19f3..00000000 --- a/aisuite/client/completions.py +++ /dev/null @@ -1,37 +0,0 @@ -"""Completions is instantiated with a client and manages completion requests in chat sessions.""" - - -class Completions: - """Manage completion requests in chat sessions.""" - - def __init__(self, topmost_instance): - """Initialize a new Completions instance. - - Args: - ---- - topmost_instance: The chat session's client instance (Client). - - """ - self.topmost_instance = topmost_instance - - def create(self, model=None, temperature=0, messages=None): - """Create a completion request using a specified provider/model combination. - - Args: - ---- - model (str): The model identifier with format 'provider:model'. - temperature (float): The sampling temperature. - messages (list): A list of previous messages. - - Returns: - ------- - The resulting completion. - - """ - interface, model_name = self.topmost_instance.get_provider_interface(model) - - return interface.chat_completion_create( - messages=messages, - model=model_name, - temperature=temperature, - ) diff --git a/aisuite/framework/__init__.py b/aisuite/framework/__init__.py deleted file mode 100644 index 2d72fd37..00000000 --- a/aisuite/framework/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -"""Provides the ProviderInterface for defining the interface that all FM providers must implement.""" - -from .provider_interface import ProviderInterface -from .chat_completion_response import ChatCompletionResponse diff --git a/aisuite/framework/chat_completion_response.py b/aisuite/framework/chat_completion_response.py deleted file mode 100644 index ef13fc2c..00000000 --- a/aisuite/framework/chat_completion_response.py +++ /dev/null @@ -1,8 +0,0 @@ -from aisuite.framework.choice import Choice - - -class ChatCompletionResponse: - """Used to conform to the response model of OpenAI""" - - def __init__(self): - self.choices = [Choice()] # Adjust the range as needed for more choices diff --git a/aisuite/framework/choice.py b/aisuite/framework/choice.py deleted file mode 100644 index 3542da57..00000000 --- a/aisuite/framework/choice.py +++ /dev/null @@ -1,6 +0,0 @@ -from aisuite.framework.message import Message - - -class Choice: - def __init__(self): - self.message = Message() diff --git a/aisuite/framework/message.py b/aisuite/framework/message.py deleted file mode 100644 index 5aa7f822..00000000 --- a/aisuite/framework/message.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Interface to hold contents of api responses when they do not conform to the OpenAI style response""" - - -class Message: - def __init__(self): - self.content = None diff --git a/aisuite/framework/provider_interface.py b/aisuite/framework/provider_interface.py deleted file mode 100644 index 3b6db766..00000000 --- a/aisuite/framework/provider_interface.py +++ /dev/null @@ -1,25 +0,0 @@ -"""The shared interface for model providers.""" - - -class 
ProviderInterface: - """Defines the expected behavior for provider-specific interfaces.""" - - def chat_completion_create(self, messages=None, model=None, temperature=0) -> None: - """Create a chat completion using the specified messages, model, and temperature. - - This method must be implemented by subclasses to perform completions. - - Args: - ---- - messages (list): The chat history. - model (str): The identifier of the model to be used in the completion. - temperature (float): The temperature to use in the completion. - - Raises: - ------ - NotImplementedError: If this method has not been implemented by a subclass. - - """ - raise NotImplementedError( - "Provider Interface has not implemented chat_completion_create()" - ) diff --git a/aisuite/old_providers/__init__.py b/aisuite/old_providers/__init__.py new file mode 100644 index 00000000..816f790e --- /dev/null +++ b/aisuite/old_providers/__init__.py @@ -0,0 +1,13 @@ +"""Provides the individual provider interfaces for each FM provider.""" + +from .anthropic_interface import AnthropicInterface +from .aws_bedrock_interface import AWSBedrockInterface +from .fireworks_interface import FireworksInterface +from .groq_interface import GroqInterface +from .mistral_interface import MistralInterface +from .octo_interface import OctoInterface +from .ollama_interface import OllamaInterface +from .openai_interface import OpenAIInterface +from .replicate_interface import ReplicateInterface +from .together_interface import TogetherInterface +from .google_interface import GoogleInterface diff --git a/aisuite/providers/anthropic_interface.py b/aisuite/old_providers/anthropic_interface.py similarity index 100% rename from aisuite/providers/anthropic_interface.py rename to aisuite/old_providers/anthropic_interface.py diff --git a/aisuite/providers/aws_bedrock_interface.py b/aisuite/old_providers/aws_bedrock_interface.py similarity index 100% rename from aisuite/providers/aws_bedrock_interface.py rename to aisuite/old_providers/aws_bedrock_interface.py diff --git a/aisuite/providers/fireworks_interface.py b/aisuite/old_providers/fireworks_interface.py similarity index 100% rename from aisuite/providers/fireworks_interface.py rename to aisuite/old_providers/fireworks_interface.py diff --git a/aisuite/providers/google_interface.py b/aisuite/old_providers/google_interface.py similarity index 100% rename from aisuite/providers/google_interface.py rename to aisuite/old_providers/google_interface.py diff --git a/aisuite/providers/groq_interface.py b/aisuite/old_providers/groq_interface.py similarity index 100% rename from aisuite/providers/groq_interface.py rename to aisuite/old_providers/groq_interface.py diff --git a/aisuite/providers/mistral_interface.py b/aisuite/old_providers/mistral_interface.py similarity index 100% rename from aisuite/providers/mistral_interface.py rename to aisuite/old_providers/mistral_interface.py diff --git a/aisuite/providers/octo_interface.py b/aisuite/old_providers/octo_interface.py similarity index 100% rename from aisuite/providers/octo_interface.py rename to aisuite/old_providers/octo_interface.py diff --git a/aisuite/providers/ollama_interface.py b/aisuite/old_providers/ollama_interface.py similarity index 100% rename from aisuite/providers/ollama_interface.py rename to aisuite/old_providers/ollama_interface.py diff --git a/aisuite/providers/openai_interface.py b/aisuite/old_providers/openai_interface.py similarity index 100% rename from aisuite/providers/openai_interface.py rename to 
aisuite/old_providers/openai_interface.py diff --git a/aisuite/providers/replicate_interface.py b/aisuite/old_providers/replicate_interface.py similarity index 100% rename from aisuite/providers/replicate_interface.py rename to aisuite/old_providers/replicate_interface.py diff --git a/aisuite/providers/together_interface.py b/aisuite/old_providers/together_interface.py similarity index 100% rename from aisuite/providers/together_interface.py rename to aisuite/old_providers/together_interface.py diff --git a/aisuite/provider.py b/aisuite/provider.py new file mode 100644 index 00000000..ea51e8d1 --- /dev/null +++ b/aisuite/provider.py @@ -0,0 +1,64 @@ +from abc import ABC, abstractmethod +from enum import Enum +import importlib + +class LLMError(Exception): + """Custom exception for LLM errors.""" + def __init__(self, message): + super().__init__(message) + + +class Provider(ABC): + @abstractmethod + def chat_completions_create(self, model, messages): + """Abstract method for chat completion calls, to be implemented by each provider.""" + pass + + +class ProviderNames(str, Enum): + OPENAI = 'openai' + AWS_BEDROCK = 'aws-bedrock' + ANTHROPIC = 'anthropic' + AZURE = 'azure' + + +class ProviderFactory: + """Factory to register and create provider instances based on keys.""" + + _provider_info = { + ProviderNames.OPENAI: ('aisuite.providers.openai_provider', 'OpenAIProvider'), + ProviderNames.AWS_BEDROCK: ('aisuite.providers.aws_bedrock_provider', 'AWSBedrockProvider'), + ProviderNames.ANTHROPIC: ('aisuite.providers.anthropic_provider', 'AnthropicProvider'), + ProviderNames.AZURE: ('aisuite.providers.azure_provider', 'AzureProvider'), + } + + # TODO: + # jon_provider.py + # More - OpenAI - ASR, Image generation, Embeddings, TTS. + # + + + @classmethod + def create_provider(cls, provider_key, config): + """Dynamically import and create an instance of a provider based on the provider key.""" + if not isinstance(provider_key, ProviderNames): + raise ValueError(f"Provider {provider_key} is not a valid ProviderNames enum") + + module_name, class_name = cls._get_provider_info(provider_key) + if not module_name: + raise ValueError(f"Provider {provider_key.value} is not supported") + + # Lazily load the module + try: + module = importlib.import_module(module_name) + except ImportError as e: + raise ImportError(f"Could not import module {module_name}: {str(e)}") + + # Instantiate the provider class + provider_class = getattr(module, class_name) + return provider_class(**config) + + @classmethod + def _get_provider_info(cls, provider_key): + """Return the module name and class name for a given provider key.""" + return cls._provider_info.get(provider_key, (None, None)) diff --git a/aisuite/providers/__init__.py b/aisuite/providers/__init__.py index 816f790e..5910d76d 100644 --- a/aisuite/providers/__init__.py +++ b/aisuite/providers/__init__.py @@ -1,13 +1 @@ -"""Provides the individual provider interfaces for each FM provider.""" - -from .anthropic_interface import AnthropicInterface -from .aws_bedrock_interface import AWSBedrockInterface -from .fireworks_interface import FireworksInterface -from .groq_interface import GroqInterface -from .mistral_interface import MistralInterface -from .octo_interface import OctoInterface -from .ollama_interface import OllamaInterface -from .openai_interface import OpenAIInterface -from .replicate_interface import ReplicateInterface -from .together_interface import TogetherInterface -from .google_interface import GoogleInterface +# Empty __init__.py diff --git 
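Reviewer note: ProviderFactory resolves a ProviderNames key to a (module, class) pair and imports the module lazily, so a provider's SDK is only imported when that provider is used. A sketch of direct factory use; the key and model name are placeholders.

from aisuite.provider import ProviderFactory, ProviderNames

# Dynamically imports aisuite.providers.openai_provider and returns
# OpenAIProvider(api_key="sk-placeholder").
provider = ProviderFactory.create_provider(
    ProviderNames.OPENAI, {"api_key": "sk-placeholder"}
)
response = provider.chat_completions_create(
    "gpt-4o",  # illustrative model name
    [{"role": "user", "content": "Hi"}],
)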
diff --git a/aisuite/providers/anthropic_provider.py b/aisuite/providers/anthropic_provider.py
new file mode 100644
index 00000000..6cb208a1
--- /dev/null
+++ b/aisuite/providers/anthropic_provider.py
@@ -0,0 +1,46 @@
+import anthropic
+from aisuite.provider import Provider
+
+# Define a constant for the default max_tokens value
+DEFAULT_MAX_TOKENS = 4096
+
+class AnthropicProvider(Provider):
+    def __init__(self, **config):
+        """
+        Initialize the Anthropic provider with the given configuration.
+        Pass the entire configuration dictionary to the Anthropic client constructor.
+        """
+
+        self.client = anthropic.Anthropic(**config)
+
+    def chat_completions_create(self, model, messages, **kwargs):
+        # Check if the first message is a system message
+        if messages[0]["role"] == "system":
+            system_message = messages[0]["content"]
+            messages = messages[1:]
+        else:
+            system_message = None
+
+        if 'max_tokens' not in kwargs:
+            kwargs['max_tokens'] = DEFAULT_MAX_TOKENS
+
+        return self.normalize_response(self.client.messages.create(
+            model=model,
+            system=system_message,
+            messages=messages,
+            **kwargs
+        ))
+
+    def normalize_response(self, response):
+        """Normalize the response from the Anthropic API to match OpenAI's response format."""
+        # The Anthropic SDK returns a Message object (not a dict), and its content
+        # is a list of content blocks, so read attributes and extract the text of
+        # the first block.
+        return {
+            "choices": [
+                {
+                    "message": {
+                        "role": getattr(response, "role", "assistant"),
+                        "content": response.content[0].text if response.content else "",
+                    }
+                }
+            ]
+        }
\ No newline at end of file
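Reviewer note: a sketch of how AnthropicProvider reshapes an OpenAI-style request — the leading "system" message becomes the system= argument, and max_tokens falls back to DEFAULT_MAX_TOKENS when not supplied. The key and model name are placeholders.

import aisuite
from aisuite import ProviderNames

client = aisuite.Client({ProviderNames.ANTHROPIC: {"api_key": "placeholder-key"}})
response = client.chat.completions.create(
    "anthropic:claude-3-5-sonnet",  # illustrative model name
    messages=[
        {"role": "system", "content": "Answer briefly."},  # passed as system=
        {"role": "user", "content": "What does this library do?"},
    ],
)  # max_tokens defaults to 4096 here since it was not supplied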
diff --git a/aisuite/providers/aws_bedrock_provider.py b/aisuite/providers/aws_bedrock_provider.py
new file mode 100644
index 00000000..c2d738d9
--- /dev/null
+++ b/aisuite/providers/aws_bedrock_provider.py
@@ -0,0 +1,81 @@
+import boto3
+from aisuite.provider import Provider, LLMError
+
+# Used to call the AWS Bedrock Converse API.
+# The Converse API provides a consistent interface that works with all Amazon Bedrock
+# models that support messages, e.g.:
+#   anthropic.claude-v2,
+#   meta.llama3-70b-instruct-v1:0,
+#   mistral.mixtral-8x7b-instruct-v0:1
+# The model value can be a baseModelId or a provisionedModelArn.
+# Using a base model id gives on-demand throughput.
+# Use the CreateProvisionedModelThroughput API to get a provisionedModelArn for higher throughput.
+# https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
+class AWSBedrockProvider(Provider):
+    def __init__(self, **config):
+        """
+        Initialize the AWS Bedrock provider. The config dictionary is currently
+        not passed through, since the boto3 client rejects unexpected constructor
+        parameters; credentials come from the default AWS credential chain.
+        """
+        # The boto3 client uses the default AWS credential providers, such as
+        # ~/.aws/credentials or the "AWS_SECRET_ACCESS_KEY" and "AWS_ACCESS_KEY_ID"
+        # environment variables.
+        # If no region is set, it defaults to us-west-1, which can lead to the error
+        # "Could not connect to the endpoint URL".
+        self.client = boto3.client("bedrock-runtime")
+        # Maintain a list of Inference Parameters which Bedrock supports.
+        # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InferenceConfiguration.html
+        self.inference_parameters = ['maxTokens', 'temperature', 'topP', 'stopSequences']
+
+    def normalize_response(self, response):
+        """Normalize the response from the Bedrock API to match OpenAI's response format."""
+        # The Converse API returns message content as a list of content blocks;
+        # extract the text of the first block so "content" is a string, as in
+        # OpenAI-style responses.
+        output_message = response["output"].get("message")
+        return {
+            "choices": [
+                {
+                    "message": {
+                        "content": output_message["content"][0]["text"] if output_message else "",
+                        "role": "assistant"
+                    },
+                }
+            ]
+        }
+
+    def chat_completions_create(self, model, messages, **kwargs):
+        # Any exception raised by the Bedrock client will propagate to the caller.
+        # Maybe we should catch them and raise a custom LLMError.
+        # https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html
+        system_message = None
+        if messages[0]["role"] == "system":
+            system_message = [{"text": messages[0]["content"]}]
+            messages = messages[1:]
+
+        formatted_messages = []
+        for message in messages:
+            # Silently drop any "system" messages other than the first one,
+            # which was handled above.
+            if message["role"] != "system":
+                formatted_messages.append({
+                    "role": message["role"],
+                    "content": [{"text": message["content"]}]
+                })
+
+        # Inference parameters that Bedrock supports must be passed via
+        # inferenceConfig; all other fields are passed as additionalModelRequestFields.
+        inference_config = {}
+        additional_model_request_fields = {}
+
+        # Iterate over the kwargs and separate the inference parameters from the
+        # additional model request fields.
+        for key, value in kwargs.items():
+            if key in self.inference_parameters:
+                inference_config[key] = value
+            else:
+                additional_model_request_fields[key] = value
+
+        # Call the Bedrock Converse API.
+        response = self.client.converse(
+            modelId=model,  # baseModelId or provisionedModelArn
+            messages=formatted_messages,
+            system=system_message,
+            inferenceConfig=inference_config,
+            additionalModelRequestFields=additional_model_request_fields
+        )
+        return self.normalize_response(response)
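Reviewer note: a sketch of the kwargs split above, assuming AWS credentials and a region are configured in the environment. The model id comes from the comment block in the provider; the parameter values (and top_k) are illustrative.

import aisuite
from aisuite import ProviderNames

client = aisuite.Client({ProviderNames.AWS_BEDROCK: {}})
response = client.chat.completions.create(
    "aws-bedrock:meta.llama3-70b-instruct-v1:0",
    messages=[{"role": "user", "content": "Hello"}],
    maxTokens=256,    # in inference_parameters -> sent via inferenceConfig
    temperature=0.5,  # in inference_parameters -> sent via inferenceConfig
    top_k=50,         # anything else -> sent via additionalModelRequestFields
)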
diff --git a/aisuite/providers/azure_provider.py b/aisuite/providers/azure_provider.py
new file mode 100644
index 00000000..0421990d
--- /dev/null
+++ b/aisuite/providers/azure_provider.py
@@ -0,0 +1,44 @@
+import urllib.request
+import urllib.error
+import json
+from aisuite.provider import Provider
+from typing import override  # Requires Python 3.12+; use typing_extensions on older versions.
+
+class AzureProvider(Provider):
+    def __init__(self, **config):
+        self.base_url = config.get('base_url')
+        self.api_key = config.get('api_key')
+        if not self.api_key:
+            raise ValueError("api_key is required in the config")
+
+    @override
+    def chat_completions_create(self, model, messages, **kwargs):
+        # TODO: Need to decide if we need to use base_url or just ignore it.
+        if self.base_url:
+            url = f"{self.base_url}/chat/completions"
+        else:
+            url = f"https://{model}.inference.ai.azure.com/v1/chat/completions"
+
+        # Remove 'stream' from kwargs if present
+        kwargs.pop('stream', None)
+        data = {
+            "messages": messages,
+            **kwargs
+        }
+
+        body = json.dumps(data).encode('utf-8')
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": self.api_key
+        }
+
+        req = urllib.request.Request(url, body, headers)
+
+        try:
+            with urllib.request.urlopen(req) as response:
+                result = response.read()
+                return json.loads(result)
+        except urllib.error.HTTPError as error:
+            error_message = f"The request failed with status code: {error.code}\n"
+            error_message += f"Headers: {error.info()}\n"
+            error_message += error.read().decode("utf-8", "ignore")
+            raise Exception(error_message)
\ No newline at end of file
diff --git a/aisuite/providers/gcp_provider.py b/aisuite/providers/gcp_provider.py
new file mode 100644
index 00000000..f4ca5e0b
--- /dev/null
+++ b/aisuite/providers/gcp_provider.py
@@ -0,0 +1,10 @@
+from aisuite.provider import Provider
+from typing import override  # Requires Python 3.12+; use typing_extensions on older versions.
+
+class GcpProvider(Provider):
+    def __init__(self) -> None:
+        pass
+
+    @override
+    def chat_completions_create(self, model, messages):
+        raise NotImplementedError("GCP Provider not yet implemented.")
\ No newline at end of file
diff --git a/aisuite/providers/groq_provider.py b/aisuite/providers/groq_provider.py
new file mode 100644
index 00000000..3a2611cd
--- /dev/null
+++ b/aisuite/providers/groq_provider.py
@@ -0,0 +1,10 @@
+from aisuite.provider import Provider
+from typing import override  # Requires Python 3.12+; use typing_extensions on older versions.
+
+class GroqProvider(Provider):
+    def __init__(self) -> None:
+        pass
+
+    @override
+    def chat_completions_create(self, model, messages):
+        raise NotImplementedError("Groq provider not yet implemented.")
\ No newline at end of file
diff --git a/aisuite/providers/openai_provider.py b/aisuite/providers/openai_provider.py
new file mode 100644
index 00000000..65634322
--- /dev/null
+++ b/aisuite/providers/openai_provider.py
@@ -0,0 +1,30 @@
+import openai
+import os
+from aisuite.provider import Provider, LLMError
+
+class OpenAIProvider(Provider):
+    def __init__(self, **config):
+        """
+        Initialize the OpenAI provider with the given configuration.
+        Pass the entire configuration dictionary to the OpenAI client constructor.
+        """
+        # Ensure the API key is provided either in the config or via an environment variable.
+        config.setdefault('api_key', os.getenv('OPENAI_API_KEY'))
+        if not config['api_key']:
+            raise ValueError("OpenAI API key is missing. Please provide it in the config or set the OPENAI_API_KEY environment variable.")
+
+        # NOTE: We could choose to remove the api_key handling above, since the OpenAI
+        # client automatically infers certain values from environment variables,
+        # e.g. OPENAI_API_KEY, OPENAI_ORG_ID, OPENAI_PROJECT_ID, OPENAI_BASE_URL, etc.
+
+        # Pass the entire config to the OpenAI client constructor.
+        self.client = openai.OpenAI(**config)
+
+    def chat_completions_create(self, model, messages, **kwargs):
+        # Any exception raised by OpenAI will propagate to the caller.
+        # Maybe we should catch them and raise a custom LLMError.
+        return self.client.chat.completions.create(
+            model=model,
+            messages=messages,
+            **kwargs  # Pass any additional arguments to the OpenAI API
+        )
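Reviewer note: AzureProvider builds its endpoint from base_url when one is configured, otherwise from the model name. A sketch with a hypothetical endpoint and key — neither is defined in this diff, and the exact Azure endpoint shape should be verified.

import aisuite
from aisuite import ProviderNames

client = aisuite.Client({
    ProviderNames.AZURE: {
        "api_key": "placeholder-key",
        "base_url": "https://my-deployment.inference.ai.azure.com/v1",  # hypothetical
    }
})
# POSTs the raw messages plus kwargs to {base_url}/chat/completions and returns
# the parsed JSON as-is, without the normalization the other providers apply.
response = client.chat.completions.create(
    "azure:my-model",  # the model part is only used for the URL when base_url is absent
    messages=[{"role": "user", "content": "Hello"}],
)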