diff --git a/.gitignore b/.gitignore
index f1f1d58d..e1084c98 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
 .idea/
 .vscode/
 __pycache__/
+env/
diff --git a/aisuite/__init__.py b/aisuite/__init__.py
index 3ff722bf..7f5ee70b 100644
--- a/aisuite/__init__.py
+++ b/aisuite/__init__.py
@@ -1 +1,2 @@
 from .client import Client
+from .provider import ProviderNames
diff --git a/aisuite/client.py b/aisuite/client.py
new file mode 100644
index 00000000..326a83c0
--- /dev/null
+++ b/aisuite/client.py
@@ -0,0 +1,123 @@
+from .provider import ProviderFactory, ProviderNames
+
+
+class Client:
+    def __init__(self, provider_configs: dict = {}):
+        """
+        Initialize the client with provider configurations.
+        Use the ProviderFactory to create provider instances.
+
+        Args:
+            provider_configs (dict): A dictionary containing provider configurations.
+                Each key should be a ProviderNames enum or its string representation,
+                and the value should be a dictionary of configuration options for that provider.
+                For example:
+                {
+                    ProviderNames.OPENAI: {"api_key": "your_openai_api_key"},
+                    "aws-bedrock": {
+                        "aws_access_key": "your_aws_access_key",
+                        "aws_secret_key": "your_aws_secret_key",
+                        "aws_region": "us-west-2"
+                    }
+                }
+        """
+        self.providers = {}
+        self.provider_configs = provider_configs
+        self._chat = None
+        self._initialize_providers()
+
+    def _initialize_providers(self):
+        """Helper method to initialize or update providers."""
+        for provider_key, config in self.provider_configs.items():
+            provider_key = self._validate_provider_key(provider_key)
+            self.providers[provider_key.value] = ProviderFactory.create_provider(
+                provider_key, config
+            )
+
+    def _validate_provider_key(self, provider_key):
+        """
+        Validate that the provider key is part of the ProviderNames enum.
+        Allow strings as well and convert them to ProviderNames.
+        """
+        if isinstance(provider_key, str):
+            if provider_key not in ProviderNames._value2member_map_:
+                raise ValueError(f"Provider {provider_key} is not a valid provider")
+            return ProviderNames(provider_key)
+
+        if isinstance(provider_key, ProviderNames):
+            return provider_key
+
+        raise ValueError(
+            f"Provider {provider_key} should either be a string or a ProviderNames enum"
+        )
+
+    def configure(self, provider_configs: dict = None):
+        """
+        Configure the client with provider configurations.
+        """
+        if provider_configs is None:
+            return
+
+        self.provider_configs.update(provider_configs)
+        self._initialize_providers()  # NOTE: This will override existing provider instances.
+
+    @property
+    def chat(self):
+        """Return the chat API interface."""
+        if not self._chat:
+            self._chat = Chat(self)
+        return self._chat
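+
+# A minimal usage sketch (editor's illustration; the model name is a placeholder):
+#
+#   client = Client({ProviderNames.OPENAI: {"api_key": "your_openai_api_key"}})
+#   response = client.chat.completions.create(
+#       "openai:gpt-4o",
+#       messages=[{"role": "user", "content": "Hello!"}],
+#   )
+#   print(response.choices[0].message.content)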
+
+
+class Chat:
+    def __init__(self, client: "Client"):
+        self.client = client
+        self._completions = Completions(self.client)
+
+    @property
+    def completions(self):
+        """Return the completions interface."""
+        return self._completions
+
+
+class Completions:
+    def __init__(self, client: "Client"):
+        self.client = client
+
+    def create(self, model: str, messages: list, **kwargs):
+        """
+        Create a chat completion based on the model, messages, and any extra arguments.
+        """
+        # Check that the correct format is used
+        if ":" not in model:
+            raise ValueError(
+                f"Invalid model format. Expected 'provider:model', got '{model}'"
+            )
+
+        # Extract the provider key from the model identifier, e.g., "aws-bedrock:model-name"
+        provider_key, model_name = model.split(":", 1)
+
+        if provider_key not in ProviderNames._value2member_map_:
+            # If the provider key does not match, give a clearer message to guide the user
+            valid_providers = ", ".join([p.value for p in ProviderNames])
+            raise ValueError(
+                f"Invalid provider key '{provider_key}'. Expected one of: {valid_providers}. "
+                "Make sure the model string is formatted correctly as 'provider:model'."
+            )
+
+        if provider_key not in self.client.providers:
+            config = {}
+            if provider_key in self.client.provider_configs:
+                config = self.client.provider_configs[provider_key]
+            self.client.providers[provider_key] = ProviderFactory.create_provider(
+                ProviderNames(provider_key), config
+            )
+
+        provider = self.client.providers.get(provider_key)
+        if not provider:
+            raise ValueError(f"Could not load provider for {provider_key}.")
+
+        # Delegate the chat completion to the correct provider's implementation.
+        # Any additional arguments (e.g., max_tokens, temperature) are passed through
+        # to the provider's implementation.
+        return provider.chat_completions_create(model_name, messages, **kwargs)
diff --git a/aisuite/client/__init__.py b/aisuite/client/__init__.py
deleted file mode 100644
index 75e4cc81..00000000
--- a/aisuite/client/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-"""Provides the Client for managing chats across many FM providers."""
-
-from .client import Client
diff --git a/aisuite/client/chat.py b/aisuite/client/chat.py
deleted file mode 100644
index 6b51ae0f..00000000
--- a/aisuite/client/chat.py
+++ /dev/null
@@ -1,18 +0,0 @@
-"""Chat is instantiated with a client and manages completions."""
-
-from .completions import Completions
-
-
-class Chat:
-    """Manage chat sessions with multiple providers."""
-
-    def __init__(self, topmost_instance):
-        """Initialize a new Chat instance.
-
-        Args:
-        ----
-        topmost_instance: The chat session's client instance (Client).
-
-        """
-        self.topmost_instance = topmost_instance
-        self.completions = Completions(topmost_instance)
diff --git a/aisuite/client/client.py b/aisuite/client/client.py
deleted file mode 100644
index e2654126..00000000
--- a/aisuite/client/client.py
+++ /dev/null
@@ -1,90 +0,0 @@
-"""Client manages a Chat across multiple provider interfaces."""
-
-from .chat import Chat
-from ..providers import (
-    AnthropicInterface,
-    AWSBedrockInterface,
-    FireworksInterface,
-    GroqInterface,
-    MistralInterface,
-    OctoInterface,
-    OllamaInterface,
-    OpenAIInterface,
-    ReplicateInterface,
-    TogetherInterface,
-    GoogleInterface,
-)
-
-
-class Client:
-    """Manages multiple provider interfaces."""
-
-    _MODEL_FORMAT_ERROR_MESSAGE_TEMPLATE = (
-        "Expected ':' in model identifier to specify provider:model. Got {model}."
-    )
-    _NO_FACTORY_ERROR_MESSAGE_TEMPLATE = (
-        "Could not find factory to create interface for provider '{provider}'."
-    )
-
-    def __init__(self):
-        """Initialize the Client instance.
-
-        Attributes
-        ----------
-        chat (Chat): The chat session.
-        all_interfaces (dict): Stores interface instances by provider names.
-        all_factories (dict): Maps provider names to their corresponding interfaces.
-
-        """
-        self.chat = Chat(self)
-        self.all_interfaces = {}
-        self.all_factories = {
-            "anthropic": AnthropicInterface,
-            "aws": AWSBedrockInterface,
-            "fireworks": FireworksInterface,
-            "groq": GroqInterface,
-            "mistral": MistralInterface,
-            "octo": OctoInterface,
-            "ollama": OllamaInterface,
-            "openai": OpenAIInterface,
-            "replicate": ReplicateInterface,
-            "together": TogetherInterface,
-            "google": GoogleInterface,
-        }
-
-    def get_provider_interface(self, model):
-        """Retrieve or create a provider interface based on a model identifier.
-
-        Args:
-        ----
-        model (str): The model identifier in the format 'provider:model'.
-
-        Raises:
-        ------
-        ValueError: If the model identifier does colon-separate provider and model.
-        Exception: If no factory is found from the supplied model.
-
-        Returns:
-        -------
-        The interface instance for the provider and the model name.
-
-        """
-        if ":" not in model:
-            raise ValueError(
-                self._MODEL_FORMAT_ERROR_MESSAGE_TEMPLATE.format(model=model)
-            )
-
-        model_parts = model.split(":", maxsplit=1)
-        provider = model_parts[0]
-        model_name = model_parts[1]
-
-        if provider in self.all_interfaces:
-            return self.all_interfaces[provider], model_name
-
-        if provider not in self.all_factories:
-            raise Exception(
-                self._NO_FACTORY_ERROR_MESSAGE_TEMPLATE.format(provider=provider)
-            )
-
-        self.all_interfaces[provider] = self.all_factories[provider]()
-        return self.all_interfaces[provider], model_name
diff --git a/aisuite/client/completions.py b/aisuite/client/completions.py
deleted file mode 100644
index c87c19f3..00000000
--- a/aisuite/client/completions.py
+++ /dev/null
@@ -1,37 +0,0 @@
-"""Completions is instantiated with a client and manages completion requests in chat sessions."""
-
-
-class Completions:
-    """Manage completion requests in chat sessions."""
-
-    def __init__(self, topmost_instance):
-        """Initialize a new Completions instance.
-
-        Args:
-        ----
-        topmost_instance: The chat session's client instance (Client).
-
-        """
-        self.topmost_instance = topmost_instance
-
-    def create(self, model=None, temperature=0, messages=None):
-        """Create a completion request using a specified provider/model combination.
-
-        Args:
-        ----
-        model (str): The model identifier with format 'provider:model'.
-        temperature (float): The sampling temperature.
-        messages (list): A list of previous messages.
-
-        Returns:
-        -------
-        The resulting completion.
-
-        """
-        interface, model_name = self.topmost_instance.get_provider_interface(model)
-
-        return interface.chat_completion_create(
-            messages=messages,
-            model=model_name,
-            temperature=temperature,
-        )
diff --git a/aisuite/framework/__init__.py b/aisuite/framework/__init__.py
index 2d72fd37..aad7ebd2 100644
--- a/aisuite/framework/__init__.py
+++ b/aisuite/framework/__init__.py
@@ -1,4 +1,2 @@
-"""Provides the ProviderInterface for defining the interface that all FM providers must implement."""
-
 from .provider_interface import ProviderInterface
 from .chat_completion_response import ChatCompletionResponse
diff --git a/aisuite/provider.py b/aisuite/provider.py
new file mode 100644
index 00000000..c28e096a
--- /dev/null
+++ b/aisuite/provider.py
@@ -0,0 +1,68 @@
+from abc import ABC, abstractmethod
+from enum import Enum
+import importlib
+
+
+class LLMError(Exception):
+    """Custom exception for LLM errors."""
+
+    def __init__(self, message):
+        super().__init__(message)
+
+
+class Provider(ABC):
+    @abstractmethod
+    def chat_completions_create(self, model, messages):
+        """Abstract method for chat completion calls, to be implemented by each provider."""
+        pass
+
+
+class ProviderNames(str, Enum):
+    OPENAI = "openai"
+    AWS_BEDROCK = "aws-bedrock"
+    ANTHROPIC = "anthropic"
+    AZURE = "azure"
+
+
+class ProviderFactory:
+    """Factory to register and create provider instances based on keys."""
+
+    _provider_info = {
+        ProviderNames.OPENAI: ("aisuite.providers.openai_provider", "OpenAIProvider"),
+        ProviderNames.AWS_BEDROCK: (
+            "aisuite.providers.aws_bedrock_provider",
+            "AWSBedrockProvider",
+        ),
+        ProviderNames.ANTHROPIC: (
+            "aisuite.providers.anthropic_provider",
+            "AnthropicProvider",
+        ),
+        ProviderNames.AZURE: ("aisuite.providers.azure_provider", "AzureProvider"),
+    }
+
+    @classmethod
+    def create_provider(cls, provider_key, config):
+        """Dynamically import and create an instance of a provider based on the provider key."""
+        if not isinstance(provider_key, ProviderNames):
+            raise ValueError(
+                f"Provider {provider_key} is not a valid ProviderNames enum"
+            )
+
+        module_name, class_name = cls._get_provider_info(provider_key)
+        if not module_name:
+            raise ValueError(f"Provider {provider_key.value} is not supported")
+
+        # Lazily load the module
+        try:
+            module = importlib.import_module(module_name)
+        except ImportError as e:
+            raise ImportError(f"Could not import module {module_name}: {str(e)}")
+
+        # Instantiate the provider class
+        provider_class = getattr(module, class_name)
+        return provider_class(**config)
+
+    @classmethod
+    def _get_provider_info(cls, provider_key):
+        """Return the module name and class name for a given provider key."""
+        return cls._provider_info.get(provider_key, (None, None))
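+
+# A minimal sketch of the factory in use (editor's illustration; config values are placeholders):
+#
+#   provider = ProviderFactory.create_provider(
+#       ProviderNames.OPENAI, {"api_key": "your_openai_api_key"}
+#   )
+#   # The provider module is imported lazily, only when first requested.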
diff --git a/aisuite/providers/anthropic_provider.py b/aisuite/providers/anthropic_provider.py
new file mode 100644
index 00000000..c4868b41
--- /dev/null
+++ b/aisuite/providers/anthropic_provider.py
@@ -0,0 +1,40 @@
+import anthropic
+from aisuite.provider import Provider
+from aisuite.framework import ChatCompletionResponse
+
+# Define a constant for the default max_tokens value
+DEFAULT_MAX_TOKENS = 4096
+
+
+class AnthropicProvider(Provider):
+    def __init__(self, **config):
+        """
+        Initialize the Anthropic provider with the given configuration.
+        Pass the entire configuration dictionary to the Anthropic client constructor.
+        """
+
+        self.client = anthropic.Anthropic(**config)
+
+    def chat_completions_create(self, model, messages, **kwargs):
+        # Check if the first message is a system message
+        if messages[0]["role"] == "system":
+            system_message = messages[0]["content"]
+            messages = messages[1:]
+        else:
+            system_message = None
+
+        # kwargs.setdefault('max_tokens', DEFAULT_MAX_TOKENS)
+        if "max_tokens" not in kwargs:
+            kwargs["max_tokens"] = DEFAULT_MAX_TOKENS
+
+        return self.normalize_response(
+            self.client.messages.create(
+                model=model, system=system_message, messages=messages, **kwargs
+            )
+        )
+
+    def normalize_response(self, response):
+        """Normalize the response from the Anthropic API to match OpenAI's response format."""
+        normalized_response = ChatCompletionResponse()
+        normalized_response.choices[0].message.content = response.content[0].text
+        return normalized_response
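+
+# Sketch of the message handling above (editor's illustration; messages are examples).
+# An OpenAI-style history such as
+#   [{"role": "system", "content": "Be brief."}, {"role": "user", "content": "Hi"}]
+# is sent to Anthropic as system="Be brief." and
+#   messages=[{"role": "user", "content": "Hi"}],
+# with max_tokens defaulting to DEFAULT_MAX_TOKENS when the caller does not set it.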
diff --git a/aisuite/providers/aws_bedrock_provider.py b/aisuite/providers/aws_bedrock_provider.py
new file mode 100644
index 00000000..c6991c82
--- /dev/null
+++ b/aisuite/providers/aws_bedrock_provider.py
@@ -0,0 +1,88 @@
+import boto3
+from aisuite.provider import Provider, LLMError
+from aisuite.framework import ChatCompletionResponse
+
+
+class AWSBedrockProvider(Provider):
+    def __init__(self, **config):
+        """
+        Initialize the AWS Bedrock provider with the given configuration.
+
+        This class uses the AWS Bedrock Converse API, which provides a consistent interface
+        for all Amazon Bedrock models that support messages. Examples include:
+        - anthropic.claude-v2
+        - meta.llama3-70b-instruct-v1:0
+        - mistral.mixtral-8x7b-instruct-v0:1
+
+        The model value can be a baseModelId for on-demand throughput or a provisionedModelArn
+        for higher throughput. To obtain a provisionedModelArn, use the CreateProvisionedModelThroughput API.
+
+        For more information on model IDs, see:
+        https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
+
+        Note:
+        - The Anthropic Bedrock client uses default AWS credential providers, such as
+          ~/.aws/credentials or the "AWS_SECRET_ACCESS_KEY" and "AWS_ACCESS_KEY_ID" environment variables.
+        - If the region is not set, it defaults to us-west-1, which may lead to a
+          "Could not connect to the endpoint URL" error.
+        - The client constructor does not accept additional parameters.
+
+        Args:
+            **config: Configuration options for the provider.
+
+        """
+        self.client = boto3.client("bedrock-runtime", region_name="us-west-2")
+        self.inference_parameters = [
+            "maxTokens",
+            "temperature",
+            "topP",
+            "stopSequences",
+        ]
+
+    def normalize_response(self, response):
+        """Normalize the response from the Bedrock API to match OpenAI's response format."""
+        norm_response = ChatCompletionResponse()
+        norm_response.choices[0].message.content = response["output"]["message"][
+            "content"
+        ][0]["text"]
+        return norm_response
+
+    def chat_completions_create(self, model, messages, **kwargs):
+        # Any exception raised by Bedrock will be returned to the caller.
+        # Maybe we should catch them and raise a custom LLMError.
+        # https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html
+        system_message = None
+        if messages[0]["role"] == "system":
+            system_message = [{"text": messages[0]["content"]}]
+            messages = messages[1:]
+
+        formatted_messages = []
+        for message in messages:
+            # Silently ignore any "system" messages except the first one.
+            if message["role"] != "system":
+                formatted_messages.append(
+                    {"role": message["role"], "content": [{"text": message["content"]}]}
+                )
+
+        # Maintain a list of the inference parameters that Bedrock supports.
+        # These fields must be passed via inferenceConfig;
+        # all other fields are passed as additionalModelRequestFields.
+        inference_config = {}
+        additional_model_request_fields = {}
+
+        # Iterate over kwargs and separate the inference parameters from the
+        # additional model request fields.
+        for key, value in kwargs.items():
+            if key in self.inference_parameters:
+                inference_config[key] = value
+            else:
+                additional_model_request_fields[key] = value
+
+        # Call the Bedrock Converse API.
+        response = self.client.converse(
+            modelId=model,  # baseModelId or provisionedModelArn
+            messages=formatted_messages,
+            system=system_message,
+            inferenceConfig=inference_config,
+            additionalModelRequestFields=additional_model_request_fields,
+        )
+        return self.normalize_response(response)
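+
+# Sketch of the kwargs split above (editor's illustration; parameter names are examples):
+#
+#   chat_completions_create(model, messages, maxTokens=256, top_k=40)
+# results in
+#   inferenceConfig={"maxTokens": 256}
+#   additionalModelRequestFields={"top_k": 40}
+# when the Converse API is called.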
diff --git a/aisuite/providers/azure_provider.py b/aisuite/providers/azure_provider.py
new file mode 100644
index 00000000..3e51a4fd
--- /dev/null
+++ b/aisuite/providers/azure_provider.py
@@ -0,0 +1,44 @@
+import urllib.request
+import urllib.error
+import json
+from aisuite.provider import Provider
+from aisuite.framework import ChatCompletionResponse
+
+
+class AzureProvider(Provider):
+    def __init__(self, **config):
+        self.base_url = config.get("base_url")
+        self.api_key = config.get("api_key")
+        if not self.api_key:
+            raise ValueError("api_key is required in the config")
+
+    def chat_completions_create(self, model, messages, **kwargs):
+        # TODO: Decide whether we need to use base_url or just ignore it.
+        # TODO: Remove the hardcoded region name and use an environment variable instead.
+        url = f"https://{model}.westus3.models.ai.azure.com/v1/chat/completions"
+        if self.base_url:
+            url = f"{self.base_url}/chat/completions"
+
+        # Remove 'stream' from kwargs if present
+        kwargs.pop("stream", None)
+        data = {"messages": messages, **kwargs}
+
+        body = json.dumps(data).encode("utf-8")
+        headers = {"Content-Type": "application/json", "Authorization": self.api_key}
+
+        try:
+            req = urllib.request.Request(url, body, headers)
+            with urllib.request.urlopen(req) as response:
+                result = response.read()
+                resp_json = json.loads(result)
+                completion_response = ChatCompletionResponse()
+                # TODO: Add checks for fields being present in resp_json.
+                completion_response.choices[0].message.content = resp_json["choices"][
+                    0
+                ]["message"]["content"]
+                return completion_response
+
+        except urllib.error.HTTPError as error:
+            error_message = f"The request failed with status code: {error.code}\n"
+            error_message += f"Headers: {error.info()}\n"
+            error_message += error.read().decode("utf-8", "ignore")
+            raise Exception(error_message)
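+
+# Illustrative configuration (editor's sketch; the endpoint shape is an assumption
+# based on the URL constructed above, not a documented Azure contract):
+#
+#   provider = AzureProvider(
+#       api_key="your_azure_api_key",
+#       base_url="https://<deployment>.<region>.models.ai.azure.com/v1",
+#   )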
diff --git a/aisuite/providers/gcp_provider.py b/aisuite/providers/gcp_provider.py
new file mode 100644
index 00000000..8f8676d6
--- /dev/null
+++ b/aisuite/providers/gcp_provider.py
@@ -0,0 +1,9 @@
+from aisuite.provider import Provider
+
+
+class GcpProvider(Provider):
+    def __init__(self) -> None:
+        pass
+
+    def chat_completions_create(self, model, messages):
+        raise ValueError("GCP Provider not yet implemented.")
diff --git a/aisuite/providers/groq_provider.py b/aisuite/providers/groq_provider.py
new file mode 100644
index 00000000..6ddde342
--- /dev/null
+++ b/aisuite/providers/groq_provider.py
@@ -0,0 +1,9 @@
+from aisuite.provider import Provider
+
+
+class GroqProvider(Provider):
+    def __init__(self) -> None:
+        pass
+
+    def chat_completions_create(self, model, messages):
+        raise ValueError("Groq provider not yet implemented.")
diff --git a/aisuite/providers/openai_provider.py b/aisuite/providers/openai_provider.py
new file mode 100644
index 00000000..5823c45a
--- /dev/null
+++ b/aisuite/providers/openai_provider.py
@@ -0,0 +1,33 @@
+import openai
+import os
+from aisuite.provider import Provider, LLMError
+
+
+class OpenAIProvider(Provider):
+    def __init__(self, **config):
+        """
+        Initialize the OpenAI provider with the given configuration.
+        Pass the entire configuration dictionary to the OpenAI client constructor.
+        """
+        # Ensure the API key is provided either in the config or via an environment variable.
+        config.setdefault("api_key", os.getenv("OPENAI_API_KEY"))
+        if not config["api_key"]:
+            raise ValueError(
+                "OpenAI API key is missing. Please provide it in the config or set the OPENAI_API_KEY environment variable."
+            )
+
+        # NOTE: We could choose to remove the api_key lines above, since OpenAI will
+        # automatically infer certain values from environment variables,
+        # e.g., OPENAI_API_KEY, OPENAI_ORG_ID, OPENAI_PROJECT_ID, OPENAI_BASE_URL, etc.
+
+        # Pass the entire config to the OpenAI client constructor.
+        self.client = openai.OpenAI(**config)
+
+    def chat_completions_create(self, model, messages, **kwargs):
+        # Any exception raised by OpenAI will be returned to the caller.
+        # Maybe we should catch them and raise a custom LLMError.
+        return self.client.chat.completions.create(
+            model=model,
+            messages=messages,
+            **kwargs  # Pass any additional arguments to the OpenAI API
+        )
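+
+# Minimal usage sketch (editor's illustration; assumes OPENAI_API_KEY is set in the
+# environment and uses a model name that appears in the example notebook):
+#
+#   provider = OpenAIProvider()
+#   response = provider.chat_completions_create(
+#       "gpt-3.5-turbo",
+#       [{"role": "user", "content": "Hello!"}],
+#       temperature=0,
+#   )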
diff --git a/examples/client.ipynb b/examples/client.ipynb
index a8049a37..f6578cd1 100644
--- a/examples/client.ipynb
+++ b/examples/client.ipynb
@@ -18,7 +18,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 22,
+   "execution_count": 1,
    "id": "initial_id",
    "metadata": {
    "ExecuteTime": {
@@ -26,48 +26,63 @@
     "start_time": "2024-07-04T15:30:02.051986Z"
    }
   },
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "True"
-      ]
-     },
-     "execution_count": 22,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "outputs": [],
    "source": [
     "import sys\n",
-    "sys.path.append('../../aisuite')\n",
-    "\n",
    "from dotenv import load_dotenv, find_dotenv\n",
    "\n",
-    "load_dotenv(find_dotenv())"
+    "sys.path.append('../../aisuite')"
   ]
  },
 {
   "cell_type": "code",
-   "execution_count": 2,
-   "id": "a54491b7-6aa9-4337-9aba-3a0aef263bb4",
+   "execution_count": 3,
+   "id": "f75736ee",
   "metadata": {},
   "outputs": [],
   "source": [
-    "import os \n",
+    "import os\n",
+    "def configure_environment(additional_env_vars=None):\n",
+    "    \"\"\"\n",
+    "    Load environment variables from .env file and apply any additional variables.\n",
+    "    :param additional_env_vars: A dictionary of additional environment variables to apply.\n",
+    "    \"\"\"\n",
+    "    # Load from .env file if available\n",
+    "    load_dotenv(find_dotenv())\n",
+    "\n",
+    "    # Apply additional environment variables\n",
+    "    if additional_env_vars:\n",
+    "        for key, value in additional_env_vars.items():\n",
+    "            os.environ[key] = value\n",
    "\n",
-    "os.environ['GROQ_API_KEY'] = 'xxx' # get a free key at https://console.groq.com/keys\n",
-    "os.environ['FIREWORKS_API_KEY'] = 'xxx' # get a free key at https://fireworks.ai/api-keys\n",
-    "os.environ['REPLICATE_API_KEY'] = 'xxx' # get a free key at https://replicate.com/account/api-tokens\n",
-    "os.environ['TOGETHER_API_KEY'] = 'xxx' # get a free key at https://api.together.ai\n",
-    "os.environ['OCTO_API_KEY'] = 'xxx' # get a free key at https://octoai.cloud/settings\n",
-    "os.environ['AWS_ACCESS_KEY_ID'] = 'xxx' # get or create at https://console.aws.amazon.com/iam/home\n",
-    "os.environ['AWS_SECRET_ACCESS_KEY'] = 'xxx' # get or create at https://console.aws.amazon.com/iam/home"
+    "# Define additional API keys and AWS credentials\n",
+    "additional_keys = {\n",
+    "    'GROQ_API_KEY': 'xxx',\n",
+    "    'FIREWORKS_API_KEY': 'xxx', \n",
+    "    'REPLICATE_API_KEY': 'xxx', \n",
+    "    'TOGETHER_API_KEY': 'xxx', \n",
+    "    'OCTO_API_KEY': 'xxx',\n",
+    "    'AWS_ACCESS_KEY_ID': 'xxx',\n",
+    "    'AWS_SECRET_ACCESS_KEY': 'xxx',\n",
+    "}\n",
+    "\n",
+    "# Configure environment\n",
+    "configure_environment(additional_env_vars=additional_keys)"
   ]
  },
 {
   "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": null,
+   "id": "744c5c15",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print(os.environ[\"AWS_SECRET_ACCESS_KEY\"])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
   "id": "4de3a24f",
   "metadata": {
    "ExecuteTime": {
@@ -80,21 +95,23 @@
    "import aisuite as ai\n",
    "\n",
    "client = ai.Client()\n",
-    "\n",
    "messages = [\n",
    "    {\"role\": \"system\", \"content\": \"Respond in Pirate English.\"},\n",
-    "    {\"role\": \"user\", \"content\": \"Tell me a joke\"},\n",
+    "    {\"role\": \"user\", \"content\": \"Tell me a joke about Captain Jack Sparrow\"},\n",
    "]"
   ]
  },
 {
   "cell_type": "code",
   "execution_count": null,
-   "id": "1ffe9a49-638e-4304-b9de-49ee21d9ac8d",
+   "id": "520a6879",
   "metadata": {},
   "outputs": [],
   "source": [
-    "#!pip install boto3"
+    "# print(os.environ[\"ANTHROPIC_API_KEY\"])\n",
+    "anthropic_claude_3_opus = \"anthropic:claude-3-5-sonnet-20240620\"\n",
+    "response = client.chat.completions.create(model=anthropic_claude_3_opus, messages=messages)\n",
+    "print(response.choices[0].message.content)"
   ]
  },
@@ -104,11 +121,39 @@
   "metadata": {},
   "outputs": [],
   "source": [
-    "aws_bedrock_llama3_8b = \"aws:meta.llama3-8b-instruct-v1:0\"\n",
-    "#aws_bedrock_llama3_8b = \"aws:meta.llama3-70b-instruct-v1:0\"\n",
-    "\n",
+    "# print(os.environ['AWS_SECRET_ACCESS_KEY'])\n",
+    "# print(os.environ['AWS_ACCESS_KEY_ID'])\n",
+    "# print(os.environ['AWS_REGION'])\n",
+    "aws_bedrock_llama3_8b = \"aws-bedrock:meta.llama3-1-8b-instruct-v1:0\"\n",
    "response = client.chat.completions.create(model=aws_bedrock_llama3_8b, messages=messages)\n",
-    "\n",
+    "print(response.choices[0].message.content)"
   ]
  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "7e46c20a",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "Arrr, listen close me hearties! Here be a joke for ye:\n",
+      "\n",
+      "Why did Captain Jack Sparrow go to the doctor?\n",
+      "\n",
+      "Because he had a bit o' a \"crabby\" day! (get it? crabby? like a crustacean, but also feeling grumpy? Ah, never mind, matey, ye landlubbers wouldn't understand...\n"
+     ]
+    }
+   ],
+   "source": [
+    "client2 = ai.Client({\"azure\" : {\n",
+    "    \"api_key\": os.environ[\"AZURE_API_KEY\"],\n",
+    "}});\n",
+    "azure_model = \"azure:aisuite-Meta-Llama-3-8B-Inst\"\n",
+    "response = client2.chat.completions.create(model=azure_model, messages=messages)\n",
+    "print(response.choices[0].message.content)"
+   ]
+  },
@@ -197,37 +242,6 @@
    "print(response.choices[0].message.content)"
   ]
  },
-  {
-   "cell_type": "code",
-   "execution_count": 4,
-   "id": "adebd2f0b578a909",
-   "metadata": {
-    "ExecuteTime": {
-     "end_time": "2024-07-04T15:31:25.060689Z",
-     "start_time": "2024-07-04T15:31:16.131205Z"
-    }
-   },
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Arrr, me bucko, 'ere be a jolly jest fer ye!\n",
-      "\n",
-      "What did th' pirate say on 'is 80th birthday? \"Aye matey!\"\n",
-      "\n",
-      "Ye see, it be a play on words, as \"Aye matey\" sounds like \"I'm eighty\". Har har har! 'Tis a clever bit o' pirate humor, if I do say so meself. Now, 'ow about ye fetch me a mug o' grog while I spin ye another yarn?\n"
-     ]
-    }
-   ],
-   "source": [
-    "anthropic_claude_3_opus = \"anthropic:claude-3-opus-20240229\"\n",
-    "\n",
-    "response = client.chat.completions.create(model=anthropic_claude_3_opus, messages=messages)\n",
-    "\n",
-    "print(response.choices[0].message.content)"
-   ]
-  },
 {
   "cell_type": "code",
   "execution_count": null,
@@ -263,19 +277,10 @@
  },
 {
   "cell_type": "code",
-   "execution_count": 23,
+   "execution_count": null,
   "id": "611210a4dc92845f",
   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Why did the pirate go to the seafood restaurant? \n",
-      "Because he heard they had some great fish tales! Arrr!\n"
-     ]
-    }
-   ],
+   "outputs": [],
   "source": [
    "openai_gpt35 = \"openai:gpt-3.5-turbo\"\n",
@@ -301,7 +306,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-   "version": "3.10.14"
+   "version": "3.12.4"
  }
 },
 "nbformat": 4,
diff --git a/tests/client/test_client.py b/tests/client/test_client.py
index 1232a37d..7884a7fe 100644
--- a/tests/client/test_client.py
+++ b/tests/client/test_client.py
@@ -1,45 +1,128 @@
-import pytest
-from aisuite.client.client import Client, AnthropicInterface
+import unittest
+from unittest.mock import patch
+from aisuite import Client
+from aisuite import ProviderNames


-def test_get_provider_interface_with_new_instance():
-    """Test that get_provider_interface creates a new instance of the interface."""
-    client = Client()
-    interface, model_name = client.get_provider_interface("anthropic:some-model:v1")
-    assert isinstance(interface, AnthropicInterface)
-    assert model_name == "some-model:v1"
-    assert client.all_interfaces["anthropic"] == interface
+class TestClient(unittest.TestCase):
+    @patch("aisuite.providers.openai_provider.OpenAIProvider.chat_completions_create")
+    @patch(
+        "aisuite.providers.aws_bedrock_provider.AWSBedrockProvider.chat_completions_create"
+    )
+    @patch("aisuite.providers.azure_provider.AzureProvider.chat_completions_create")
+    @patch(
+        "aisuite.providers.anthropic_provider.AnthropicProvider.chat_completions_create"
+    )
+    def test_client_chat_completions(
+        self, mock_anthropic, mock_azure, mock_bedrock, mock_openai
+    ):
+        # Mock responses from providers
+        mock_openai.return_value = "OpenAI Response"
+        mock_bedrock.return_value = "AWS Bedrock Response"
+        mock_azure.return_value = "Azure Response"
+        mock_anthropic.return_value = "Anthropic Response"

-def test_get_provider_interface_with_existing_instance():
-    """Test that get_provider_interface returns an existing instance of the interface, if already created."""
-    client = Client()
+        # Provider configurations
+        provider_configs = {
+            ProviderNames.OPENAI: {"api_key": "test_openai_api_key"},
+            ProviderNames.AWS_BEDROCK: {
+                "aws_access_key": "test_aws_access_key",
+                "aws_secret_key": "test_aws_secret_key",
+                "aws_session_token": "test_aws_session_token",
+                "aws_region": "us-west-2",
+            },
+            ProviderNames.AZURE: {
+                "api_key": "azure-api-key",
+            },
+        }

-    # New interface instance
-    new_instance, _ = client.get_provider_interface("anthropic:some-model:v2")
+        # Initialize the client
+        client = Client()
+        client.configure(provider_configs)
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "Who won the world series in 2020?"},
+        ]

-    # Call twice, get same instance back
-    same_instance, _ = client.get_provider_interface("anthropic:some-model:v2")
+        # Test OpenAI model
+        open_ai_model = ProviderNames.OPENAI + ":" + "gpt-4o"
+        openai_response = client.chat.completions.create(
+            open_ai_model, messages=messages
+        )
+        self.assertEqual(openai_response, "OpenAI Response")
+        mock_openai.assert_called_once()

-    assert new_instance is same_instance
+        # Test AWS Bedrock model
+        bedrock_model = ProviderNames.AWS_BEDROCK + ":" + "claude-v3"
+        bedrock_response = client.chat.completions.create(
+            bedrock_model, messages=messages
+        )
+        self.assertEqual(bedrock_response, "AWS Bedrock Response")
+        mock_bedrock.assert_called_once()

+        # Test Azure model
+        azure_model = ProviderNames.AZURE + ":" + "azure-model"
+        azure_response = client.chat.completions.create(azure_model, messages=messages)
+        self.assertEqual(azure_response, "Azure Response")
+        mock_azure.assert_called_once()

-def test_get_provider_interface_with_invalid_format():
-    client = Client()
+        # Test Anthropic model
+        anthropic_model = ProviderNames.ANTHROPIC + ":" + "anthropic-model"
+        anthropic_response = client.chat.completions.create(
+            anthropic_model, messages=messages
+        )
+        self.assertEqual(anthropic_response, "Anthropic Response")
+        mock_anthropic.assert_called_once()

-    with pytest.raises(ValueError) as exc_info:
-        client.get_provider_interface("invalid-model-no-colon")
+        # Test that new instances of Completions are not created on each inference call.
+        compl_instance = client.chat.completions
+        next_compl_instance = client.chat.completions
+        assert compl_instance is next_compl_instance

-    assert "Expected ':' in model identifier" in str(exc_info.value)
+    @patch("aisuite.providers.openai_provider.OpenAIProvider.chat_completions_create")
+    def test_invalid_provider_in_client_config(self, mock_openai):
+        # Test an invalid provider name in the configuration
+        invalid_provider_configs = {
+            "INVALID_PROVIDER": {"api_key": "invalid_api_key"},
+        }

+        # Expect a ValueError when initializing Client with an invalid provider
+        with self.assertRaises(ValueError) as context:
+            client = Client(invalid_provider_configs)

-def test_get_provider_interface_with_unknown_interface():
-    client = Client()
+        # Verify the error message
+        self.assertIn(
+            "Provider INVALID_PROVIDER is not a valid provider",
+            str(context.exception),
+        )

-    with pytest.raises(Exception) as exc_info:
-        client.get_provider_interface("unknown-interface:some-model")
+    @patch("aisuite.providers.openai_provider.OpenAIProvider.chat_completions_create")
+    def test_invalid_model_format_in_create(self, mock_openai):
+        # Valid provider configurations
+        provider_configs = {
+            ProviderNames.OPENAI: {"api_key": "test_openai_api_key"},
+        }

-    assert (
-        "Could not find factory to create interface for provider 'unknown-interface'"
-        in str(exc_info.value)
-    )
+        # Initialize the client with a valid provider
+        client = Client(provider_configs)
+        client.configure(provider_configs)
+
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "Tell me a joke."},
+        ]
+
+        # Invalid model format
+        invalid_model = "invalidmodel"
+
+        # Expect a ValueError when calling create with an invalid model format
+        with self.assertRaises(ValueError) as context:
+            client.chat.completions.create(invalid_model, messages=messages)
+
+        # Verify the error message
+        self.assertIn(
+            "Invalid model format. Expected 'provider:model'", str(context.exception)
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()