diff --git a/aisuite/client.py b/aisuite/client.py index e175cbe1..f6474096 100644 --- a/aisuite/client.py +++ b/aisuite/client.py @@ -9,18 +9,33 @@ def __init__(self, provider_configs: dict = {}): """ self.providers = {} self.provider_configs = provider_configs - for provider_key, config in provider_configs.items(): - # Check if the provider key is a valid ProviderNames enum - if not isinstance(provider_key, ProviderNames): - raise ValueError( - f"Provider {provider_key} is not a valid ProviderNames enum" - ) - # Store the value of the enum in the providers dictionary + self._chat = None + self._initialize_providers() + + def _initialize_providers(self): + """Helper method to initialize or update providers.""" + for provider_key, config in self.provider_configs.items(): + provider_key = self._validate_provider_key(provider_key) self.providers[provider_key.value] = ProviderFactory.create_provider( provider_key, config ) - self._chat = None + def _validate_provider_key(self, provider_key): + """ + Validate if the provider key is part of ProviderNames enum. + Allow strings as well and convert them to ProviderNames. + """ + if isinstance(provider_key, str): + if provider_key not in ProviderNames._value2member_map_: + raise ValueError(f"Provider {provider_key} is not a valid provider") + return ProviderNames(provider_key) + + if isinstance(provider_key, ProviderNames): + return provider_key + + raise ValueError( + f"Provider {provider_key} should either be a string or enum ProviderNames" + ) def configure(self, provider_configs: dict = None): """ @@ -30,15 +45,7 @@ def configure(self, provider_configs: dict = None): return self.provider_configs.update(provider_configs) - - for provider_key, config in self.provider_configs.items(): - if not isinstance(provider_key, ProviderNames): - raise ValueError( - f"Provider {provider_key} is not a valid ProviderNames enum" - ) - self.providers[provider_key.value] = ProviderFactory.create_provider( - provider_key, config - ) + self._initialize_providers() # NOTE: This will override existing provider instances. @property def chat(self): diff --git a/examples/client.ipynb b/examples/client.ipynb index a8049a37..f6578cd1 100644 --- a/examples/client.ipynb +++ b/examples/client.ipynb @@ -18,7 +18,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 1, "id": "initial_id", "metadata": { "ExecuteTime": { @@ -26,48 +26,63 @@ "start_time": "2024-07-04T15:30:02.051986Z" } }, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "import sys\n", - "sys.path.append('../../aisuite')\n", - "\n", "from dotenv import load_dotenv, find_dotenv\n", "\n", - "load_dotenv(find_dotenv())" + "sys.path.append('../../aisuite')" ] }, { "cell_type": "code", - "execution_count": 2, - "id": "a54491b7-6aa9-4337-9aba-3a0aef263bb4", + "execution_count": 3, + "id": "f75736ee", "metadata": {}, "outputs": [], "source": [ - "import os \n", + "import os\n", + "def configure_environment(additional_env_vars=None):\n", + " \"\"\"\n", + " Load environment variables from .env file and apply any additional variables.\n", + " :param additional_env_vars: A dictionary of additional environment variables to apply.\n", + " \"\"\"\n", + " # Load from .env file if available\n", + " load_dotenv(find_dotenv())\n", + "\n", + " # Apply additional environment variables\n", + " if additional_env_vars:\n", + " for key, value in additional_env_vars.items():\n", + " os.environ[key] = value\n", "\n", - "os.environ['GROQ_API_KEY'] = 'xxx' # get a free key at https://console.groq.com/keys\n", - "os.environ['FIREWORKS_API_KEY'] = 'xxx' # get a free key at https://fireworks.ai/api-keys\n", - "os.environ['REPLICATE_API_KEY'] = 'xxx' # get a free key at https://replicate.com/account/api-tokens\n", - "os.environ['TOGETHER_API_KEY'] = 'xxx' # get a free key at https://api.together.ai\n", - "os.environ['OCTO_API_KEY'] = 'xxx' # get a free key at https://octoai.cloud/settings\n", - "os.environ['AWS_ACCESS_KEY_ID'] = 'xxx' # get or create at https://console.aws.amazon.com/iam/home\n", - "os.environ['AWS_SECRET_ACCESS_KEY'] = 'xxx' # get or create at https://console.aws.amazon.com/iam/home" + "# Define additional API keys and AWS credentials\n", + "additional_keys = {\n", + " 'GROQ_API_KEY': 'xxx',\n", + " 'FIREWORKS_API_KEY': 'xxx', \n", + " 'REPLICATE_API_KEY': 'xxx', \n", + " 'TOGETHER_API_KEY': 'xxx', \n", + " 'OCTO_API_KEY': 'xxx',\n", + " 'AWS_ACCESS_KEY_ID': 'xxx',\n", + " 'AWS_SECRET_ACCESS_KEY': 'xxx',\n", + "}\n", + "\n", + "# Configure environment\n", + "configure_environment(additional_env_vars=additional_keys)" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, + "id": "744c5c15", + "metadata": {}, + "outputs": [], + "source": [ + "print(os.environ[\"AWS_SECRET_ACCESS_KEY\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 4, "id": "4de3a24f", "metadata": { "ExecuteTime": { @@ -80,21 +95,23 @@ "import aisuite as ai\n", "\n", "client = ai.Client()\n", - "\n", "messages = [\n", " {\"role\": \"system\", \"content\": \"Respond in Pirate English.\"},\n", - " {\"role\": \"user\", \"content\": \"Tell me a joke\"},\n", + " {\"role\": \"user\", \"content\": \"Tell me a joke about Captain Jack Sparrow\"},\n", "]" ] }, { "cell_type": "code", "execution_count": null, - "id": "1ffe9a49-638e-4304-b9de-49ee21d9ac8d", + "id": "520a6879", "metadata": {}, "outputs": [], "source": [ - "#!pip install boto3" + "# print(os.environ[\"ANTHROPIC_API_KEY\"])\n", + "anthropic_claude_3_opus = \"anthropic:claude-3-5-sonnet-20240620\"\n", + "response = client.chat.completions.create(model=anthropic_claude_3_opus, messages=messages)\n", + "print(response.choices[0].message.content)" ] }, { @@ -104,11 +121,39 @@ "metadata": {}, "outputs": [], "source": [ - "aws_bedrock_llama3_8b = \"aws:meta.llama3-8b-instruct-v1:0\"\n", - "#aws_bedrock_llama3_8b = \"aws:meta.llama3-70b-instruct-v1:0\"\n", - "\n", + "# print(os.environ['AWS_SECRET_ACCESS_KEY'])\n", + "# print(os.environ['AWS_ACCESS_KEY_ID'])\n", + "# print(os.environ['AWS_REGION'])\n", + "aws_bedrock_llama3_8b = \"aws-bedrock:meta.llama3-1-8b-instruct-v1:0\"\n", "response = client.chat.completions.create(model=aws_bedrock_llama3_8b, messages=messages)\n", - "\n", + "print(response.choices[0].message.content)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "7e46c20a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Arrr, listen close me hearties! Here be a joke for ye:\n", + "\n", + "Why did Captain Jack Sparrow go to the doctor?\n", + "\n", + "Because he had a bit o' a \"crabby\" day! (get it? crabby? like a crustacean, but also feeling grumpy? Ah, never mind, matey, ye landlubbers wouldn't understand...\n" + ] + } + ], + "source": [ + "client2 = ai.Client({\"azure\" : {\n", + " \"api_key\": os.environ[\"AZURE_API_KEY\"],\n", + "}});\n", + "azure_model = \"azure:aisuite-Meta-Llama-3-8B-Inst\"\n", + "response = client2.chat.completions.create(model=azure_model, messages=messages)\n", "print(response.choices[0].message.content)" ] }, @@ -197,37 +242,6 @@ "print(response.choices[0].message.content)" ] }, - { - "cell_type": "code", - "execution_count": 4, - "id": "adebd2f0b578a909", - "metadata": { - "ExecuteTime": { - "end_time": "2024-07-04T15:31:25.060689Z", - "start_time": "2024-07-04T15:31:16.131205Z" - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Arrr, me bucko, 'ere be a jolly jest fer ye!\n", - "\n", - "What did th' pirate say on 'is 80th birthday? \"Aye matey!\"\n", - "\n", - "Ye see, it be a play on words, as \"Aye matey\" sounds like \"I'm eighty\". Har har har! 'Tis a clever bit o' pirate humor, if I do say so meself. Now, 'ow about ye fetch me a mug o' grog while I spin ye another yarn?\n" - ] - } - ], - "source": [ - "anthropic_claude_3_opus = \"anthropic:claude-3-opus-20240229\"\n", - "\n", - "response = client.chat.completions.create(model=anthropic_claude_3_opus, messages=messages)\n", - "\n", - "print(response.choices[0].message.content)" - ] - }, { "cell_type": "code", "execution_count": null, @@ -263,19 +277,10 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": null, "id": "611210a4dc92845f", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Why did the pirate go to the seafood restaurant? \n", - "Because he heard they had some great fish tales! Arrr!\n" - ] - } - ], + "outputs": [], "source": [ "openai_gpt35 = \"openai:gpt-3.5-turbo\"\n", "\n", @@ -301,7 +306,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.14" + "version": "3.12.4" } }, "nbformat": 4, diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 077974e6..7884a7fe 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -91,7 +91,7 @@ def test_invalid_provider_in_client_config(self, mock_openai): # Verify the error message self.assertIn( - "Provider INVALID_PROVIDER is not a valid ProviderNames enum", + "Provider INVALID_PROVIDER is not a valid provider", str(context.exception), )