Skip to content

Commit

Permalink
Allowing provider name as string in config()
Browse files Browse the repository at this point in the history
Changed client.ipynb
  • Loading branch information
rohit-rptless committed Sep 12, 2024
1 parent 7d6d3a4 commit 510b881
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 96 deletions.
41 changes: 24 additions & 17 deletions aisuite/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,33 @@ def __init__(self, provider_configs: dict = {}):
"""
self.providers = {}
self.provider_configs = provider_configs
for provider_key, config in provider_configs.items():
# Check if the provider key is a valid ProviderNames enum
if not isinstance(provider_key, ProviderNames):
raise ValueError(
f"Provider {provider_key} is not a valid ProviderNames enum"
)
# Store the value of the enum in the providers dictionary
self._chat = None
self._initialize_providers()

def _initialize_providers(self):
"""Helper method to initialize or update providers."""
for provider_key, config in self.provider_configs.items():
provider_key = self._validate_provider_key(provider_key)
self.providers[provider_key.value] = ProviderFactory.create_provider(
provider_key, config
)

self._chat = None
def _validate_provider_key(self, provider_key):
"""
Validate if the provider key is part of ProviderNames enum.
Allow strings as well and convert them to ProviderNames.
"""
if isinstance(provider_key, str):
if provider_key not in ProviderNames._value2member_map_:
raise ValueError(f"Provider {provider_key} is not a valid provider")
return ProviderNames(provider_key)

if isinstance(provider_key, ProviderNames):
return provider_key

raise ValueError(
f"Provider {provider_key} should either be a string or enum ProviderNames"
)

def configure(self, provider_configs: dict = None):
"""
Expand All @@ -30,15 +45,7 @@ def configure(self, provider_configs: dict = None):
return

self.provider_configs.update(provider_configs)

for provider_key, config in self.provider_configs.items():
if not isinstance(provider_key, ProviderNames):
raise ValueError(
f"Provider {provider_key} is not a valid ProviderNames enum"
)
self.providers[provider_key.value] = ProviderFactory.create_provider(
provider_key, config
)
self._initialize_providers() # NOTE: This will override existing provider instances.

@property
def chat(self):
Expand Down
161 changes: 83 additions & 78 deletions examples/client.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -18,56 +18,71 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 1,
"id": "initial_id",
"metadata": {
"ExecuteTime": {
"end_time": "2024-07-04T15:30:02.064319Z",
"start_time": "2024-07-04T15:30:02.051986Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"import sys\n",
"sys.path.append('../../aisuite')\n",
"\n",
"from dotenv import load_dotenv, find_dotenv\n",
"\n",
"load_dotenv(find_dotenv())"
"sys.path.append('../../aisuite')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a54491b7-6aa9-4337-9aba-3a0aef263bb4",
"execution_count": 3,
"id": "f75736ee",
"metadata": {},
"outputs": [],
"source": [
"import os \n",
"import os\n",
"def configure_environment(additional_env_vars=None):\n",
" \"\"\"\n",
" Load environment variables from .env file and apply any additional variables.\n",
" :param additional_env_vars: A dictionary of additional environment variables to apply.\n",
" \"\"\"\n",
" # Load from .env file if available\n",
" load_dotenv(find_dotenv())\n",
"\n",
" # Apply additional environment variables\n",
" if additional_env_vars:\n",
" for key, value in additional_env_vars.items():\n",
" os.environ[key] = value\n",
"\n",
"os.environ['GROQ_API_KEY'] = 'xxx' # get a free key at https://console.groq.com/keys\n",
"os.environ['FIREWORKS_API_KEY'] = 'xxx' # get a free key at https://fireworks.ai/api-keys\n",
"os.environ['REPLICATE_API_KEY'] = 'xxx' # get a free key at https://replicate.com/account/api-tokens\n",
"os.environ['TOGETHER_API_KEY'] = 'xxx' # get a free key at https://api.together.ai\n",
"os.environ['OCTO_API_KEY'] = 'xxx' # get a free key at https://octoai.cloud/settings\n",
"os.environ['AWS_ACCESS_KEY_ID'] = 'xxx' # get or create at https://console.aws.amazon.com/iam/home\n",
"os.environ['AWS_SECRET_ACCESS_KEY'] = 'xxx' # get or create at https://console.aws.amazon.com/iam/home"
"# Define additional API keys and AWS credentials\n",
"additional_keys = {\n",
" 'GROQ_API_KEY': 'xxx',\n",
" 'FIREWORKS_API_KEY': 'xxx', \n",
" 'REPLICATE_API_KEY': 'xxx', \n",
" 'TOGETHER_API_KEY': 'xxx', \n",
" 'OCTO_API_KEY': 'xxx',\n",
" 'AWS_ACCESS_KEY_ID': 'xxx',\n",
" 'AWS_SECRET_ACCESS_KEY': 'xxx',\n",
"}\n",
"\n",
"# Configure environment\n",
"configure_environment(additional_env_vars=additional_keys)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"id": "744c5c15",
"metadata": {},
"outputs": [],
"source": [
"print(os.environ[\"AWS_SECRET_ACCESS_KEY\"])"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "4de3a24f",
"metadata": {
"ExecuteTime": {
Expand All @@ -80,21 +95,23 @@
"import aisuite as ai\n",
"\n",
"client = ai.Client()\n",
"\n",
"messages = [\n",
" {\"role\": \"system\", \"content\": \"Respond in Pirate English.\"},\n",
" {\"role\": \"user\", \"content\": \"Tell me a joke\"},\n",
" {\"role\": \"user\", \"content\": \"Tell me a joke about Captain Jack Sparrow\"},\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1ffe9a49-638e-4304-b9de-49ee21d9ac8d",
"id": "520a6879",
"metadata": {},
"outputs": [],
"source": [
"#!pip install boto3"
"# print(os.environ[\"ANTHROPIC_API_KEY\"])\n",
"anthropic_claude_3_opus = \"anthropic:claude-3-5-sonnet-20240620\"\n",
"response = client.chat.completions.create(model=anthropic_claude_3_opus, messages=messages)\n",
"print(response.choices[0].message.content)"
]
},
{
Expand All @@ -104,11 +121,39 @@
"metadata": {},
"outputs": [],
"source": [
"aws_bedrock_llama3_8b = \"aws:meta.llama3-8b-instruct-v1:0\"\n",
"#aws_bedrock_llama3_8b = \"aws:meta.llama3-70b-instruct-v1:0\"\n",
"\n",
"# print(os.environ['AWS_SECRET_ACCESS_KEY'])\n",
"# print(os.environ['AWS_ACCESS_KEY_ID'])\n",
"# print(os.environ['AWS_REGION'])\n",
"aws_bedrock_llama3_8b = \"aws-bedrock:meta.llama3-1-8b-instruct-v1:0\"\n",
"response = client.chat.completions.create(model=aws_bedrock_llama3_8b, messages=messages)\n",
"\n",
"print(response.choices[0].message.content)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "7e46c20a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<aisuite.framework.chat_completion_response.ChatCompletionResponse object at 0x10944bdd0>\n",
"Arrr, listen close me hearties! Here be a joke for ye:\n",
"\n",
"Why did Captain Jack Sparrow go to the doctor?\n",
"\n",
"Because he had a bit o' a \"crabby\" day! (get it? crabby? like a crustacean, but also feeling grumpy? Ah, never mind, matey, ye landlubbers wouldn't understand...\n"
]
}
],
"source": [
"client2 = ai.Client({\"azure\" : {\n",
" \"api_key\": os.environ[\"AZURE_API_KEY\"],\n",
"}});\n",
"azure_model = \"azure:aisuite-Meta-Llama-3-8B-Inst\"\n",
"response = client2.chat.completions.create(model=azure_model, messages=messages)\n",
"print(response.choices[0].message.content)"
]
},
Expand Down Expand Up @@ -197,37 +242,6 @@
"print(response.choices[0].message.content)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "adebd2f0b578a909",
"metadata": {
"ExecuteTime": {
"end_time": "2024-07-04T15:31:25.060689Z",
"start_time": "2024-07-04T15:31:16.131205Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Arrr, me bucko, 'ere be a jolly jest fer ye!\n",
"\n",
"What did th' pirate say on 'is 80th birthday? \"Aye matey!\"\n",
"\n",
"Ye see, it be a play on words, as \"Aye matey\" sounds like \"I'm eighty\". Har har har! 'Tis a clever bit o' pirate humor, if I do say so meself. Now, 'ow about ye fetch me a mug o' grog while I spin ye another yarn?\n"
]
}
],
"source": [
"anthropic_claude_3_opus = \"anthropic:claude-3-opus-20240229\"\n",
"\n",
"response = client.chat.completions.create(model=anthropic_claude_3_opus, messages=messages)\n",
"\n",
"print(response.choices[0].message.content)"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -263,19 +277,10 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": null,
"id": "611210a4dc92845f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Why did the pirate go to the seafood restaurant? \n",
"Because he heard they had some great fish tales! Arrr!\n"
]
}
],
"outputs": [],
"source": [
"openai_gpt35 = \"openai:gpt-3.5-turbo\"\n",
"\n",
Expand All @@ -301,7 +306,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
"version": "3.12.4"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion tests/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_invalid_provider_in_client_config(self, mock_openai):

# Verify the error message
self.assertIn(
"Provider INVALID_PROVIDER is not a valid ProviderNames enum",
"Provider INVALID_PROVIDER is not a valid provider",
str(context.exception),
)

Expand Down

0 comments on commit 510b881

Please sign in to comment.