From d77f3121eabd85f7b2f2344bad8cec50d31dab72 Mon Sep 17 00:00:00 2001
From: Rohit Prasad
Date: Thu, 12 Sep 2024 12:14:28 -0700
Subject: [PATCH] Refactoring client and providers. (#27)

Added Azure support & refactored code.
Refactoring includes:
- ProviderFactory - lazily import the provider based on the config passed to Client.

Will need to port the older provider files to the new format. Until then,
keeping the older provider interface related tests.

Co-authored-by: rohit-rptless
---
 .gitignore                                |   1 +
 aisuite/__init__.py                       |   1 +
 aisuite/client.py                         | 123 +++++++++++++++++
 aisuite/client/__init__.py                |   3 -
 aisuite/client/chat.py                    |  18 ---
 aisuite/client/client.py                  |  90 ------------
 aisuite/client/completions.py             |  37 -----
 aisuite/framework/__init__.py             |   2 -
 aisuite/provider.py                       |  68 +++++++++
 aisuite/providers/anthropic_provider.py   |  40 ++++++
 aisuite/providers/aws_bedrock_provider.py |  88 ++++++++++++
 aisuite/providers/azure_provider.py       |  44 ++++++
 aisuite/providers/gcp_provider.py         |   9 ++
 aisuite/providers/groq_provider.py        |   9 ++
 aisuite/providers/openai_provider.py      |  33 +++++
 examples/client.ipynb                     | 161 +++++++++++-----------
 tests/client/test_client.py               | 143 +++++++++++++++----
 17 files changed, 612 insertions(+), 258 deletions(-)
 create mode 100644 aisuite/client.py
 delete mode 100644 aisuite/client/__init__.py
 delete mode 100644 aisuite/client/chat.py
 delete mode 100644 aisuite/client/client.py
 delete mode 100644 aisuite/client/completions.py
 create mode 100644 aisuite/provider.py
 create mode 100644 aisuite/providers/anthropic_provider.py
 create mode 100644 aisuite/providers/aws_bedrock_provider.py
 create mode 100644 aisuite/providers/azure_provider.py
 create mode 100644 aisuite/providers/gcp_provider.py
 create mode 100644 aisuite/providers/groq_provider.py
 create mode 100644 aisuite/providers/openai_provider.py

diff --git a/.gitignore b/.gitignore
index f1f1d58d..e1084c98 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
 .idea/
 .vscode/
 __pycache__/
+env/
diff --git a/aisuite/__init__.py b/aisuite/__init__.py
index 3ff722bf..7f5ee70b 100644
--- a/aisuite/__init__.py
+++ b/aisuite/__init__.py
@@ -1 +1,2 @@
 from .client import Client
+from .provider import ProviderNames
diff --git a/aisuite/client.py b/aisuite/client.py
new file mode 100644
index 00000000..326a83c0
--- /dev/null
+++ b/aisuite/client.py
@@ -0,0 +1,123 @@
+from .provider import ProviderFactory, ProviderNames
+
+
+class Client:
+    def __init__(self, provider_configs: dict = {}):
+        """
+        Initialize the client with provider configurations.
+        Use the ProviderFactory to create provider instances.
+
+        Args:
+            provider_configs (dict): A dictionary containing provider configurations.
+                Each key should be a ProviderNames enum or its string representation,
+                and the value should be a dictionary of configuration options for that provider.
+                For example:
+                {
+                    ProviderNames.OPENAI: {"api_key": "your_openai_api_key"},
+                    "aws-bedrock": {
+                        "aws_access_key": "your_aws_access_key",
+                        "aws_secret_key": "your_aws_secret_key",
+                        "aws_region": "us-west-2"
+                    }
+                }
+        """
+        self.providers = {}
+        self.provider_configs = provider_configs
+        self._chat = None
+        self._initialize_providers()
+
+    def _initialize_providers(self):
+        """Helper method to initialize or update providers."""
+        for provider_key, config in self.provider_configs.items():
+            provider_key = self._validate_provider_key(provider_key)
+            self.providers[provider_key.value] = ProviderFactory.create_provider(
+                provider_key, config
+            )
+
+    def _validate_provider_key(self, provider_key):
+        """
+        Validate if the provider key is part of ProviderNames enum.
+        Allow strings as well and convert them to ProviderNames.
+        """
+        if isinstance(provider_key, str):
+            if provider_key not in ProviderNames._value2member_map_:
+                raise ValueError(f"Provider {provider_key} is not a valid provider")
+            return ProviderNames(provider_key)
+
+        if isinstance(provider_key, ProviderNames):
+            return provider_key
+
+        raise ValueError(
+            f"Provider {provider_key} should either be a string or enum ProviderNames"
+        )
+
+    def configure(self, provider_configs: dict = None):
+        """
+        Configure the client with provider configurations.
+        """
+        if provider_configs is None:
+            return
+
+        self.provider_configs.update(provider_configs)
+        self._initialize_providers()  # NOTE: This will override existing provider instances.
+
+    @property
+    def chat(self):
+        """Return the chat API interface."""
+        if not self._chat:
+            self._chat = Chat(self)
+        return self._chat
+
+
+class Chat:
+    def __init__(self, client: "Client"):
+        self.client = client
+        self._completions = Completions(self.client)
+
+    @property
+    def completions(self):
+        """Return the completions interface."""
+        return self._completions
+
+
+class Completions:
+    def __init__(self, client: "Client"):
+        self.client = client
+
+    def create(self, model: str, messages: list, **kwargs):
+        """
+        Create chat completion based on the model, messages, and any extra arguments.
+        """
+        # Check that correct format is used
+        if ":" not in model:
+            raise ValueError(
+                f"Invalid model format. Expected 'provider:model', got '{model}'"
+            )
+
+        # Extract the provider key from the model identifier, e.g., "aws-bedrock:model-name"
+        provider_key, model_name = model.split(":", 1)
+
+        if provider_key not in ProviderNames._value2member_map_:
+            # If the provider key does not match, give a clearer message to guide the user
+            valid_providers = ", ".join([p.value for p in ProviderNames])
+            raise ValueError(
+                f"Invalid provider key '{provider_key}'. Expected one of: {valid_providers}. "
+                "Make sure the model string is formatted correctly as 'provider:model'."
+            )
+
+        if provider_key not in self.client.providers:
+            config = {}
+            if provider_key in self.client.provider_configs:
+                config = self.client.provider_configs[provider_key]
+            self.client.providers[provider_key] = ProviderFactory.create_provider(
+                ProviderNames(provider_key), config
+            )
+
+        provider = self.client.providers.get(provider_key)
+        if not provider:
+            raise ValueError(f"Could not load provider for {provider_key}.")
+
+        # Delegate the chat completion to the correct provider's implementation
+        # Any additional arguments will be passed to the provider's implementation.
+        # Eg: max_tokens, temperature, etc.
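+        # For example, a call like the following (model name is hypothetical)
+        #   client.chat.completions.create("openai:gpt-4o", messages, temperature=0.7)
+        # resolves provider_key "openai" and delegates to
+        #   OpenAIProvider.chat_completions_create("gpt-4o", messages, temperature=0.7)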
+        return provider.chat_completions_create(model_name, messages, **kwargs)
diff --git a/aisuite/client/__init__.py b/aisuite/client/__init__.py
deleted file mode 100644
index 75e4cc81..00000000
--- a/aisuite/client/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-"""Provides the Client for managing chats across many FM providers."""
-
-from .client import Client
diff --git a/aisuite/client/chat.py b/aisuite/client/chat.py
deleted file mode 100644
index 6b51ae0f..00000000
--- a/aisuite/client/chat.py
+++ /dev/null
@@ -1,18 +0,0 @@
-"""Chat is instantiated with a client and manages completions."""
-
-from .completions import Completions
-
-
-class Chat:
-    """Manage chat sessions with multiple providers."""
-
-    def __init__(self, topmost_instance):
-        """Initialize a new Chat instance.
-
-        Args:
-        ----
-        topmost_instance: The chat session's client instance (Client).
-
-        """
-        self.topmost_instance = topmost_instance
-        self.completions = Completions(topmost_instance)
diff --git a/aisuite/client/client.py b/aisuite/client/client.py
deleted file mode 100644
index e2654126..00000000
--- a/aisuite/client/client.py
+++ /dev/null
@@ -1,90 +0,0 @@
-"""Client manages a Chat across multiple provider interfaces."""
-
-from .chat import Chat
-from ..providers import (
-    AnthropicInterface,
-    AWSBedrockInterface,
-    FireworksInterface,
-    GroqInterface,
-    MistralInterface,
-    OctoInterface,
-    OllamaInterface,
-    OpenAIInterface,
-    ReplicateInterface,
-    TogetherInterface,
-    GoogleInterface,
-)
-
-
-class Client:
-    """Manages multiple provider interfaces."""
-
-    _MODEL_FORMAT_ERROR_MESSAGE_TEMPLATE = (
-        "Expected ':' in model identifier to specify provider:model. Got {model}."
-    )
-    _NO_FACTORY_ERROR_MESSAGE_TEMPLATE = (
-        "Could not find factory to create interface for provider '{provider}'."
-    )
-
-    def __init__(self):
-        """Initialize the Client instance.
-
-        Attributes
-        ----------
-        chat (Chat): The chat session.
-        all_interfaces (dict): Stores interface instances by provider names.
-        all_factories (dict): Maps provider names to their corresponding interfaces.
-
-        """
-        self.chat = Chat(self)
-        self.all_interfaces = {}
-        self.all_factories = {
-            "anthropic": AnthropicInterface,
-            "aws": AWSBedrockInterface,
-            "fireworks": FireworksInterface,
-            "groq": GroqInterface,
-            "mistral": MistralInterface,
-            "octo": OctoInterface,
-            "ollama": OllamaInterface,
-            "openai": OpenAIInterface,
-            "replicate": ReplicateInterface,
-            "together": TogetherInterface,
-            "google": GoogleInterface,
-        }
-
-    def get_provider_interface(self, model):
-        """Retrieve or create a provider interface based on a model identifier.
-
-        Args:
-        ----
-        model (str): The model identifier in the format 'provider:model'.
-
-        Raises:
-        ------
-        ValueError: If the model identifier does not colon-separate provider and model.
-        Exception: If no factory is found for the supplied model.
-
-        Returns:
-        -------
-        The interface instance for the provider and the model name.
-
-        """
-        if ":" not in model:
-            raise ValueError(
-                self._MODEL_FORMAT_ERROR_MESSAGE_TEMPLATE.format(model=model)
-            )
-
-        model_parts = model.split(":", maxsplit=1)
-        provider = model_parts[0]
-        model_name = model_parts[1]
-
-        if provider in self.all_interfaces:
-            return self.all_interfaces[provider], model_name
-
-        if provider not in self.all_factories:
-            raise Exception(
-                self._NO_FACTORY_ERROR_MESSAGE_TEMPLATE.format(provider=provider)
-            )
-
-        self.all_interfaces[provider] = self.all_factories[provider]()
-        return self.all_interfaces[provider], model_name
diff --git a/aisuite/client/completions.py b/aisuite/client/completions.py
deleted file mode 100644
index c87c19f3..00000000
--- a/aisuite/client/completions.py
+++ /dev/null
@@ -1,37 +0,0 @@
-"""Completions is instantiated with a client and manages completion requests in chat sessions."""
-
-
-class Completions:
-    """Manage completion requests in chat sessions."""
-
-    def __init__(self, topmost_instance):
-        """Initialize a new Completions instance.
-
-        Args:
-        ----
-        topmost_instance: The chat session's client instance (Client).
-
-        """
-        self.topmost_instance = topmost_instance
-
-    def create(self, model=None, temperature=0, messages=None):
-        """Create a completion request using a specified provider/model combination.
-
-        Args:
-        ----
-        model (str): The model identifier with format 'provider:model'.
-        temperature (float): The sampling temperature.
-        messages (list): A list of previous messages.
-
-        Returns:
-        -------
-        The resulting completion.
-
-        """
-        interface, model_name = self.topmost_instance.get_provider_interface(model)
-
-        return interface.chat_completion_create(
-            messages=messages,
-            model=model_name,
-            temperature=temperature,
-        )
diff --git a/aisuite/framework/__init__.py b/aisuite/framework/__init__.py
index 2d72fd37..aad7ebd2 100644
--- a/aisuite/framework/__init__.py
+++ b/aisuite/framework/__init__.py
@@ -1,4 +1,2 @@
-"""Provides the ProviderInterface for defining the interface that all FM providers must implement."""
-
 from .provider_interface import ProviderInterface
 from .chat_completion_response import ChatCompletionResponse
diff --git a/aisuite/provider.py b/aisuite/provider.py
new file mode 100644
index 00000000..c28e096a
--- /dev/null
+++ b/aisuite/provider.py
@@ -0,0 +1,68 @@
+from abc import ABC, abstractmethod
+from enum import Enum
+import importlib
+
+
+class LLMError(Exception):
+    """Custom exception for LLM errors."""
+
+    def __init__(self, message):
+        super().__init__(message)
+
+
+class Provider(ABC):
+    @abstractmethod
+    def chat_completions_create(self, model, messages):
+        """Abstract method for chat completion calls, to be implemented by each provider."""
+        pass
+
+
+class ProviderNames(str, Enum):
+    OPENAI = "openai"
+    AWS_BEDROCK = "aws-bedrock"
+    ANTHROPIC = "anthropic"
+    AZURE = "azure"
+
+
+class ProviderFactory:
+    """Factory to register and create provider instances based on keys."""
+
+    _provider_info = {
+        ProviderNames.OPENAI: ("aisuite.providers.openai_provider", "OpenAIProvider"),
+        ProviderNames.AWS_BEDROCK: (
+            "aisuite.providers.aws_bedrock_provider",
+            "AWSBedrockProvider",
+        ),
+        ProviderNames.ANTHROPIC: (
+            "aisuite.providers.anthropic_provider",
+            "AnthropicProvider",
+        ),
+        ProviderNames.AZURE: ("aisuite.providers.azure_provider", "AzureProvider"),
+    }
+
+    @classmethod
+    def create_provider(cls, provider_key, config):
+        """Dynamically import and create an instance of a provider based on the provider key."""
+        if not isinstance(provider_key, ProviderNames):
+            raise ValueError(
f"Provider {provider_key} is not a valid ProviderNames enum" + ) + + module_name, class_name = cls._get_provider_info(provider_key) + if not module_name: + raise ValueError(f"Provider {provider_key.value} is not supported") + + # Lazily load the module + try: + module = importlib.import_module(module_name) + except ImportError as e: + raise ImportError(f"Could not import module {module_name}: {str(e)}") + + # Instantiate the provider class + provider_class = getattr(module, class_name) + return provider_class(**config) + + @classmethod + def _get_provider_info(cls, provider_key): + """Return the module name and class name for a given provider key.""" + return cls._provider_info.get(provider_key, (None, None)) diff --git a/aisuite/providers/anthropic_provider.py b/aisuite/providers/anthropic_provider.py new file mode 100644 index 00000000..c4868b41 --- /dev/null +++ b/aisuite/providers/anthropic_provider.py @@ -0,0 +1,40 @@ +import anthropic +from aisuite.provider import Provider +from aisuite.framework import ChatCompletionResponse + +# Define a constant for the default max_tokens value +DEFAULT_MAX_TOKENS = 4096 + + +class AnthropicProvider(Provider): + def __init__(self, **config): + """ + Initialize the Anthropic provider with the given configuration. + Pass the entire configuration dictionary to the Anthropic client constructor. + """ + + self.client = anthropic.Anthropic(**config) + + def chat_completions_create(self, model, messages, **kwargs): + # Check if the fist message is a system message + if messages[0]["role"] == "system": + system_message = messages[0]["content"] + messages = messages[1:] + else: + system_message = None + + # kwargs.setdefault('max_tokens', DEFAULT_MAX_TOKENS) + if "max_tokens" not in kwargs: + kwargs["max_tokens"] = DEFAULT_MAX_TOKENS + + return self.normalize_response( + self.client.messages.create( + model=model, system=system_message, messages=messages, **kwargs + ) + ) + + def normalize_response(self, response): + """Normalize the response from the Anthropic API to match OpenAI's response format.""" + normalized_response = ChatCompletionResponse() + normalized_response.choices[0].message.content = response.content[0].text + return normalized_response diff --git a/aisuite/providers/aws_bedrock_provider.py b/aisuite/providers/aws_bedrock_provider.py new file mode 100644 index 00000000..c6991c82 --- /dev/null +++ b/aisuite/providers/aws_bedrock_provider.py @@ -0,0 +1,88 @@ +import boto3 +from aisuite.provider import Provider, LLMError +from aisuite.framework import ChatCompletionResponse + + +class AWSBedrockProvider(Provider): + def __init__(self, **config): + """ + Initialize the AWS Bedrock provider with the given configuration. + + This class uses the AWS Bedrock converse API, which provides a consistent interface + for all Amazon Bedrock models that support messages. Examples include: + - anthropic.claude-v2 + - meta.llama3-70b-instruct-v1:0 + - mistral.mixtral-8x7b-instruct-v0:1 + + The model value can be a baseModelId for on-demand throughput or a provisionedModelArn + for higher throughput. To obtain a provisionedModelArn, use the CreateProvisionedModelThroughput API. + + For more information on model IDs, see: + https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html + + Note: + - The Anthropic Bedrock client uses default AWS credential providers, such as + ~/.aws/credentials or the "AWS_SECRET_ACCESS_KEY" and "AWS_ACCESS_KEY_ID" environment variables. 
+        - If the region is not set, it defaults to us-west-1, which may lead to a
+          "Could not connect to the endpoint URL" error.
+        - The client constructor does not accept additional parameters.
+
+        Args:
+            **config: Configuration options for the provider.
+
+        """
+        self.client = boto3.client("bedrock-runtime", region_name="us-west-2")
+        self.inference_parameters = [
+            "maxTokens",
+            "temperature",
+            "topP",
+            "stopSequences",
+        ]
+
+    def normalize_response(self, response):
+        """Normalize the response from the Bedrock API to match OpenAI's response format."""
+        norm_response = ChatCompletionResponse()
+        norm_response.choices[0].message.content = response["output"]["message"][
+            "content"
+        ][0]["text"]
+        return norm_response
+
+    def chat_completions_create(self, model, messages, **kwargs):
+        # Any exception raised by the Bedrock client will be returned to the caller.
+        # Maybe we should catch them and raise a custom LLMError.
+        # https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html
+        system_message = None
+        if messages[0]["role"] == "system":
+            system_message = [{"text": messages[0]["content"]}]
+            messages = messages[1:]
+
+        formatted_messages = []
+        for message in messages:
+            # Quietly ignore any "system" messages except the first system message.
+            if message["role"] != "system":
+                formatted_messages.append(
+                    {"role": message["role"], "content": [{"text": message["content"]}]}
+                )
+
+        # Maintain a list of inference parameters which Bedrock supports.
+        # These fields need to be passed using inferenceConfig.
+        # All other fields are passed as additionalModelRequestFields.
+        inference_config = {}
+        additional_model_request_fields = {}
+
+        # Iterate over the kwargs and separate the inference parameters and additional model request fields.
+        for key, value in kwargs.items():
+            if key in self.inference_parameters:
+                inference_config[key] = value
+            else:
+                additional_model_request_fields[key] = value
+
+        # Call the Bedrock Converse API.
+        response = self.client.converse(
+            modelId=model,  # baseModelId or provisionedModelArn
+            messages=formatted_messages,
+            system=system_message,
+            inferenceConfig=inference_config,
+            additionalModelRequestFields=additional_model_request_fields,
+        )
+        return self.normalize_response(response)
diff --git a/aisuite/providers/azure_provider.py b/aisuite/providers/azure_provider.py
new file mode 100644
index 00000000..3e51a4fd
--- /dev/null
+++ b/aisuite/providers/azure_provider.py
@@ -0,0 +1,44 @@
+import urllib.request
+import json
+from aisuite.provider import Provider
+from aisuite.framework import ChatCompletionResponse
+
+
+class AzureProvider(Provider):
+    def __init__(self, **config):
+        self.base_url = config.get("base_url")
+        self.api_key = config.get("api_key")
+        if not self.api_key:
+            raise ValueError("api_key is required in the config")
+
+    def chat_completions_create(self, model, messages, **kwargs):
+        # TODO: Need to decide if we need to use base_url or just ignore it.
+        # TODO: Remove the hardcoded region name and use an environment variable.
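+        # The default URL below assumes the model string is the name of a serverless
+        # deployment in westus3; e.g. the example notebook's "aisuite-Meta-Llama-3-8B-Inst"
+        # maps to (hypothetical endpoint):
+        #   https://aisuite-Meta-Llama-3-8B-Inst.westus3.models.ai.azure.com/v1/chat/completions
+        # Passing base_url in the config overrides this default.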
+ url = f"https://{model}.westus3.models.ai.azure.com/v1/chat/completions" + if self.base_url: + url = f"{self.base_url}/chat/completions" + + # Remove 'stream' from kwargs if present + kwargs.pop("stream", None) + data = {"messages": messages, **kwargs} + + body = json.dumps(data).encode("utf-8") + headers = {"Content-Type": "application/json", "Authorization": self.api_key} + + try: + req = urllib.request.Request(url, body, headers) + with urllib.request.urlopen(req) as response: + result = response.read() + resp_json = json.loads(result) + completion_response = ChatCompletionResponse() + # TODO: Add checks for fields being present in resp_json. + completion_response.choices[0].message.content = resp_json["choices"][ + 0 + ]["message"]["content"] + return completion_response + + except urllib.error.HTTPError as error: + error_message = f"The request failed with status code: {error.code}\n" + error_message += f"Headers: {error.info()}\n" + error_message += error.read().decode("utf-8", "ignore") + raise Exception(error_message) diff --git a/aisuite/providers/gcp_provider.py b/aisuite/providers/gcp_provider.py new file mode 100644 index 00000000..8f8676d6 --- /dev/null +++ b/aisuite/providers/gcp_provider.py @@ -0,0 +1,9 @@ +from aisuite.provider import Provider + + +class GcpProvider(Provider): + def __init__(self) -> None: + pass + + def chat_completions_create(self, model, messages): + raise ValueError("GCP Provider not yet implemented.") diff --git a/aisuite/providers/groq_provider.py b/aisuite/providers/groq_provider.py new file mode 100644 index 00000000..6ddde342 --- /dev/null +++ b/aisuite/providers/groq_provider.py @@ -0,0 +1,9 @@ +from aisuite.provider import Provider + + +class GroqProvider(Provider): + def __init__(self) -> None: + pass + + def chat_completions_create(self, model, messages): + raise ValueError("Groq provider not yet implemented.") diff --git a/aisuite/providers/openai_provider.py b/aisuite/providers/openai_provider.py new file mode 100644 index 00000000..5823c45a --- /dev/null +++ b/aisuite/providers/openai_provider.py @@ -0,0 +1,33 @@ +import openai +import os +from aisuite.provider import Provider, LLMError + + +class OpenAIProvider(Provider): + def __init__(self, **config): + """ + Initialize the OpenAI provider with the given configuration. + Pass the entire configuration dictionary to the OpenAI client constructor. + """ + # Ensure API key is provided either in config or via environment variable + config.setdefault("api_key", os.getenv("OPENAI_API_KEY")) + if not config["api_key"]: + raise ValueError( + "OpenAI API key is missing. Please provide it in the config or set the OPENAI_API_KEY environment variable." + ) + + # NOTE: We could choose to remove above lines for api_key since OpenAI will automatically + # infer certain values from the environment variables. + # Eg: OPENAI_API_KEY, OPENAI_ORG_ID, OPENAI_PROJECT_ID, OPENAI_BASE_URL, etc. + + # Pass the entire config to the OpenAI client constructor + self.client = openai.OpenAI(**config) + + def chat_completions_create(self, model, messages, **kwargs): + # Any exception raised by OpenAI will be returned to the caller. + # Maybe we should catch them and raise a custom LLMError. 
+        return self.client.chat.completions.create(
+            model=model,
+            messages=messages,
+            **kwargs  # Pass any additional arguments to the OpenAI API
+        )
diff --git a/examples/client.ipynb b/examples/client.ipynb
index a8049a37..f6578cd1 100644
--- a/examples/client.ipynb
+++ b/examples/client.ipynb
@@ -18,7 +18,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 22,
+  "execution_count": 1,
   "id": "initial_id",
   "metadata": {
    "ExecuteTime": {
@@ -26,48 +26,63 @@
     "start_time": "2024-07-04T15:30:02.051986Z"
    }
  },
- "outputs": [
-  {
-   "data": {
-    "text/plain": [
-     "True"
-    ]
-   },
-   "execution_count": 22,
-   "metadata": {},
-   "output_type": "execute_result"
-  }
- ],
+ "outputs": [],
  "source": [
   "import sys\n",
-  "sys.path.append('../../aisuite')\n",
-  "\n",
   "from dotenv import load_dotenv, find_dotenv\n",
   "\n",
-  "load_dotenv(find_dotenv())"
+  "sys.path.append('../../aisuite')"
  ]
 },
 {
  "cell_type": "code",
- "execution_count": 2,
- "id": "a54491b7-6aa9-4337-9aba-3a0aef263bb4",
+ "execution_count": 3,
+ "id": "f75736ee",
  "metadata": {},
  "outputs": [],
  "source": [
-  "import os \n",
+  "import os\n",
+  "def configure_environment(additional_env_vars=None):\n",
+  "    \"\"\"\n",
+  "    Load environment variables from .env file and apply any additional variables.\n",
+  "    :param additional_env_vars: A dictionary of additional environment variables to apply.\n",
+  "    \"\"\"\n",
+  "    # Load from .env file if available\n",
+  "    load_dotenv(find_dotenv())\n",
+  "\n",
+  "    # Apply additional environment variables\n",
+  "    if additional_env_vars:\n",
+  "        for key, value in additional_env_vars.items():\n",
+  "            os.environ[key] = value\n",
  "\n",
-  "os.environ['GROQ_API_KEY'] = 'xxx' # get a free key at https://console.groq.com/keys\n",
-  "os.environ['FIREWORKS_API_KEY'] = 'xxx' # get a free key at https://fireworks.ai/api-keys\n",
-  "os.environ['REPLICATE_API_KEY'] = 'xxx' # get a free key at https://replicate.com/account/api-tokens\n",
-  "os.environ['TOGETHER_API_KEY'] = 'xxx' # get a free key at https://api.together.ai\n",
-  "os.environ['OCTO_API_KEY'] = 'xxx' # get a free key at https://octoai.cloud/settings\n",
-  "os.environ['AWS_ACCESS_KEY_ID'] = 'xxx' # get or create at https://console.aws.amazon.com/iam/home\n",
-  "os.environ['AWS_SECRET_ACCESS_KEY'] = 'xxx' # get or create at https://console.aws.amazon.com/iam/home"
+  "# Define additional API keys and AWS credentials\n",
+  "additional_keys = {\n",
+  "    'GROQ_API_KEY': 'xxx',\n",
+  "    'FIREWORKS_API_KEY': 'xxx', \n",
+  "    'REPLICATE_API_KEY': 'xxx', \n",
+  "    'TOGETHER_API_KEY': 'xxx', \n",
+  "    'OCTO_API_KEY': 'xxx',\n",
+  "    'AWS_ACCESS_KEY_ID': 'xxx',\n",
+  "    'AWS_SECRET_ACCESS_KEY': 'xxx',\n",
+  "}\n",
+  "\n",
+  "# Configure environment\n",
+  "configure_environment(additional_env_vars=additional_keys)"
  ]
 },
+{
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "744c5c15",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+  "print(os.environ[\"AWS_SECRET_ACCESS_KEY\"])"
+ ]
+},
 {
  "cell_type": "code",
- "execution_count": 3,
+ "execution_count": 4,
  "id": "4de3a24f",
  "metadata": {
   "ExecuteTime": {
@@ -80,21 +95,23 @@
  "source": [
   "import aisuite as ai\n",
   "\n",
   "client = ai.Client()\n",
-  "\n",
   "messages = [\n",
   "    {\"role\": \"system\", \"content\": \"Respond in Pirate English.\"},\n",
-  "    {\"role\": \"user\", \"content\": \"Tell me a joke\"},\n",
+  "    {\"role\": \"user\", \"content\": \"Tell me a joke about Captain Jack Sparrow\"},\n",
  "]"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
- "id": "1ffe9a49-638e-4304-b9de-49ee21d9ac8d",
+ "id": "520a6879",
  "metadata": {},
 "outputs": [],
 "source": [
-  "#!pip install boto3"
+  "# print(os.environ[\"ANTHROPIC_API_KEY\"])\n",
+  "anthropic_claude_3_5_sonnet = \"anthropic:claude-3-5-sonnet-20240620\"\n",
+  "response = client.chat.completions.create(model=anthropic_claude_3_5_sonnet, messages=messages)\n",
+  "print(response.choices[0].message.content)"
 ]
},
{
@@ -104,11 +121,39 @@
 "metadata": {},
 "outputs": [],
 "source": [
-  "aws_bedrock_llama3_8b = \"aws:meta.llama3-8b-instruct-v1:0\"\n",
-  "#aws_bedrock_llama3_8b = \"aws:meta.llama3-70b-instruct-v1:0\"\n",
-  "\n",
+  "# print(os.environ['AWS_SECRET_ACCESS_KEY'])\n",
+  "# print(os.environ['AWS_ACCESS_KEY_ID'])\n",
+  "# print(os.environ['AWS_REGION'])\n",
+  "aws_bedrock_llama3_8b = \"aws-bedrock:meta.llama3-1-8b-instruct-v1:0\"\n",
   "response = client.chat.completions.create(model=aws_bedrock_llama3_8b, messages=messages)\n",
-  "\n",
+  "print(response.choices[0].message.content)"
  ]
 },
 {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "7e46c20a",
+ "metadata": {},
+ "outputs": [
+  {
+   "name": "stdout",
+   "output_type": "stream",
+   "text": [
+    "\n",
+    "Arrr, listen close me hearties! Here be a joke for ye:\n",
+    "\n",
+    "Why did Captain Jack Sparrow go to the doctor?\n",
+    "\n",
+    "Because he had a bit o' a \"crabby\" day! (get it? crabby? like a crustacean, but also feeling grumpy? Ah, never mind, matey, ye landlubbers wouldn't understand...\n"
+   ]
+  }
+ ],
+ "source": [
+  "client2 = ai.Client({\"azure\" : {\n",
+  "    \"api_key\": os.environ[\"AZURE_API_KEY\"],\n",
+  "}});\n",
+  "azure_model = \"azure:aisuite-Meta-Llama-3-8B-Inst\"\n",
+  "response = client2.chat.completions.create(model=azure_model, messages=messages)\n",
+  "print(response.choices[0].message.content)"
+ ]
+},
@@ -197,37 +242,6 @@
  "print(response.choices[0].message.content)"
  ]
 },
-{
- "cell_type": "code",
- "execution_count": 4,
- "id": "adebd2f0b578a909",
- "metadata": {
-  "ExecuteTime": {
-   "end_time": "2024-07-04T15:31:25.060689Z",
-   "start_time": "2024-07-04T15:31:16.131205Z"
-  }
- },
- "outputs": [
-  {
-   "name": "stdout",
-   "output_type": "stream",
-   "text": [
-    "Arrr, me bucko, 'ere be a jolly jest fer ye!\n",
-    "\n",
-    "What did th' pirate say on 'is 80th birthday? \"Aye matey!\"\n",
-    "\n",
-    "Ye see, it be a play on words, as \"Aye matey\" sounds like \"I'm eighty\". Har har har! 'Tis a clever bit o' pirate humor, if I do say so meself. Now, 'ow about ye fetch me a mug o' grog while I spin ye another yarn?\n"
-   ]
-  }
- ],
- "source": [
-  "anthropic_claude_3_opus = \"anthropic:claude-3-opus-20240229\"\n",
-  "\n",
-  "response = client.chat.completions.create(model=anthropic_claude_3_opus, messages=messages)\n",
-  "\n",
-  "print(response.choices[0].message.content)"
- ]
-},
 {
  "cell_type": "code",
  "execution_count": null,
@@ -263,19 +277,10 @@
 },
 {
  "cell_type": "code",
- "execution_count": 23,
+ "execution_count": null,
  "id": "611210a4dc92845f",
  "metadata": {},
- "outputs": [
-  {
-   "name": "stdout",
-   "output_type": "stream",
-   "text": [
-    "Why did the pirate go to the seafood restaurant? \n",
-    "Because he heard they had some great fish tales! Arrr!\n"
-   ]
-  }
- ],
+ "outputs": [],
 "source": [
  "openai_gpt35 = \"openai:gpt-3.5-turbo\"\n",
  "\n",
@@ -301,7 +306,7 @@
  "name": "python",
  "nbconvert_exporter": "python",
  "pygments_lexer": "ipython3",
- "version": "3.10.14"
+ "version": "3.12.4"
 }
},
"nbformat": 4,
diff --git a/tests/client/test_client.py b/tests/client/test_client.py
index 1232a37d..7884a7fe 100644
--- a/tests/client/test_client.py
+++ b/tests/client/test_client.py
@@ -1,45 +1,128 @@
-import pytest
-from aisuite.client.client import Client, AnthropicInterface
+import unittest
+from unittest.mock import patch
+from aisuite import Client
+from aisuite import ProviderNames
 
 
-def test_get_provider_interface_with_new_instance():
-    """Test that get_provider_interface creates a new instance of the interface."""
-    client = Client()
-    interface, model_name = client.get_provider_interface("anthropic:some-model:v1")
-    assert isinstance(interface, AnthropicInterface)
-    assert model_name == "some-model:v1"
-    assert client.all_interfaces["anthropic"] == interface
+class TestClient(unittest.TestCase):
+    @patch("aisuite.providers.openai_provider.OpenAIProvider.chat_completions_create")
+    @patch(
+        "aisuite.providers.aws_bedrock_provider.AWSBedrockProvider.chat_completions_create"
+    )
+    @patch("aisuite.providers.azure_provider.AzureProvider.chat_completions_create")
+    @patch(
+        "aisuite.providers.anthropic_provider.AnthropicProvider.chat_completions_create"
+    )
+    def test_client_chat_completions(
+        self, mock_anthropic, mock_azure, mock_bedrock, mock_openai
+    ):
+        # Mock responses from providers
+        mock_openai.return_value = "OpenAI Response"
+        mock_bedrock.return_value = "AWS Bedrock Response"
+        mock_azure.return_value = "Azure Response"
+        mock_anthropic.return_value = "Anthropic Response"
 
-def test_get_provider_interface_with_existing_instance():
-    """Test that get_provider_interface returns an existing instance of the interface, if already created."""
-    client = Client()
+        # Provider configurations
+        provider_configs = {
+            ProviderNames.OPENAI: {"api_key": "test_openai_api_key"},
+            ProviderNames.AWS_BEDROCK: {
+                "aws_access_key": "test_aws_access_key",
+                "aws_secret_key": "test_aws_secret_key",
+                "aws_session_token": "test_aws_session_token",
+                "aws_region": "us-west-2",
+            },
+            ProviderNames.AZURE: {
+                "api_key": "azure-api-key",
+            },
+        }
 
-    # New interface instance
-    new_instance, _ = client.get_provider_interface("anthropic:some-model:v2")
+        # Initialize the client
+        client = Client()
+        client.configure(provider_configs)
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "Who won the world series in 2020?"},
+        ]
 
-    # Call twice, get same instance back
-    same_instance, _ = client.get_provider_interface("anthropic:some-model:v2")
+        # Test OpenAI model
+        open_ai_model = ProviderNames.OPENAI + ":" + "gpt-4o"
+        openai_response = client.chat.completions.create(
+            open_ai_model, messages=messages
+        )
+        self.assertEqual(openai_response, "OpenAI Response")
+        mock_openai.assert_called_once()
 
-    assert new_instance is same_instance
+        # Test AWS Bedrock model
+        bedrock_model = ProviderNames.AWS_BEDROCK + ":" + "claude-v3"
+        bedrock_response = client.chat.completions.create(
+            bedrock_model, messages=messages
+        )
+        self.assertEqual(bedrock_response, "AWS Bedrock Response")
+        mock_bedrock.assert_called_once()
 
+        azure_model = ProviderNames.AZURE + ":" + "azure-model"
+        azure_response = client.chat.completions.create(azure_model, messages=messages)
+        self.assertEqual(azure_response, "Azure Response")
Response") + mock_azure.assert_called_once() -def test_get_provider_interface_with_invalid_format(): - client = Client() + anthropic_model = ProviderNames.ANTHROPIC + ":" + "anthropic-model" + anthropic_response = client.chat.completions.create( + anthropic_model, messages=messages + ) + self.assertEqual(anthropic_response, "Anthropic Response") + mock_anthropic.assert_called_once() - with pytest.raises(ValueError) as exc_info: - client.get_provider_interface("invalid-model-no-colon") + # Test that new instances of Completion are not created each time we make an inference call. + compl_instance = client.chat.completions + next_compl_instance = client.chat.completions + assert compl_instance is next_compl_instance - assert "Expected ':' in model identifier" in str(exc_info.value) + @patch("aisuite.providers.openai_provider.OpenAIProvider.chat_completions_create") + def test_invalid_provider_in_client_config(self, mock_openai): + # Testing an invalid provider name in the configuration + invalid_provider_configs = { + "INVALID_PROVIDER": {"api_key": "invalid_api_key"}, + } + # Expect ValueError when initializing Client with invalid provider + with self.assertRaises(ValueError) as context: + client = Client(invalid_provider_configs) -def test_get_provider_interface_with_unknown_interface(): - client = Client() + # Verify the error message + self.assertIn( + "Provider INVALID_PROVIDER is not a valid provider", + str(context.exception), + ) - with pytest.raises(Exception) as exc_info: - client.get_provider_interface("unknown-interface:some-model") + @patch("aisuite.providers.openai_provider.OpenAIProvider.chat_completions_create") + def test_invalid_model_format_in_create(self, mock_openai): + # Valid provider configurations + provider_configs = { + ProviderNames.OPENAI: {"api_key": "test_openai_api_key"}, + } - assert ( - "Could not find factory to create interface for provider 'unknown-interface'" - in str(exc_info.value) - ) + # Initialize the client with valid provider + client = Client(provider_configs) + client.configure(provider_configs) + + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Tell me a joke."}, + ] + + # Invalid model format + invalid_model = "invalidmodel" + + # Expect ValueError when calling create with invalid model format + with self.assertRaises(ValueError) as context: + client.chat.completions.create(invalid_model, messages=messages) + + # Verify the error message + self.assertIn( + "Invalid model format. Expected 'provider:model'", str(context.exception) + ) + + +if __name__ == "__main__": + unittest.main()