Skip to content

Commit 510b881

Browse files
committed
Allowing provider name as string in config()
Changed client.ipynb
1 parent 7d6d3a4 commit 510b881

File tree

3 files changed

+108
-96
lines changed

3 files changed

+108
-96
lines changed

aisuite/client.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,33 @@ def __init__(self, provider_configs: dict = {}):
99
"""
1010
self.providers = {}
1111
self.provider_configs = provider_configs
12-
for provider_key, config in provider_configs.items():
13-
# Check if the provider key is a valid ProviderNames enum
14-
if not isinstance(provider_key, ProviderNames):
15-
raise ValueError(
16-
f"Provider {provider_key} is not a valid ProviderNames enum"
17-
)
18-
# Store the value of the enum in the providers dictionary
12+
self._chat = None
13+
self._initialize_providers()
14+
15+
def _initialize_providers(self):
16+
"""Helper method to initialize or update providers."""
17+
for provider_key, config in self.provider_configs.items():
18+
provider_key = self._validate_provider_key(provider_key)
1919
self.providers[provider_key.value] = ProviderFactory.create_provider(
2020
provider_key, config
2121
)
2222

23-
self._chat = None
23+
def _validate_provider_key(self, provider_key):
24+
"""
25+
Validate if the provider key is part of ProviderNames enum.
26+
Allow strings as well and convert them to ProviderNames.
27+
"""
28+
if isinstance(provider_key, str):
29+
if provider_key not in ProviderNames._value2member_map_:
30+
raise ValueError(f"Provider {provider_key} is not a valid provider")
31+
return ProviderNames(provider_key)
32+
33+
if isinstance(provider_key, ProviderNames):
34+
return provider_key
35+
36+
raise ValueError(
37+
f"Provider {provider_key} should either be a string or enum ProviderNames"
38+
)
2439

2540
def configure(self, provider_configs: dict = None):
2641
"""
@@ -30,15 +45,7 @@ def configure(self, provider_configs: dict = None):
3045
return
3146

3247
self.provider_configs.update(provider_configs)
33-
34-
for provider_key, config in self.provider_configs.items():
35-
if not isinstance(provider_key, ProviderNames):
36-
raise ValueError(
37-
f"Provider {provider_key} is not a valid ProviderNames enum"
38-
)
39-
self.providers[provider_key.value] = ProviderFactory.create_provider(
40-
provider_key, config
41-
)
48+
self._initialize_providers() # NOTE: This will override existing provider instances.
4249

4350
@property
4451
def chat(self):

examples/client.ipynb

Lines changed: 83 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -18,56 +18,71 @@
1818
},
1919
{
2020
"cell_type": "code",
21-
"execution_count": 22,
21+
"execution_count": 1,
2222
"id": "initial_id",
2323
"metadata": {
2424
"ExecuteTime": {
2525
"end_time": "2024-07-04T15:30:02.064319Z",
2626
"start_time": "2024-07-04T15:30:02.051986Z"
2727
}
2828
},
29-
"outputs": [
30-
{
31-
"data": {
32-
"text/plain": [
33-
"True"
34-
]
35-
},
36-
"execution_count": 22,
37-
"metadata": {},
38-
"output_type": "execute_result"
39-
}
40-
],
29+
"outputs": [],
4130
"source": [
4231
"import sys\n",
43-
"sys.path.append('../../aisuite')\n",
44-
"\n",
4532
"from dotenv import load_dotenv, find_dotenv\n",
4633
"\n",
47-
"load_dotenv(find_dotenv())"
34+
"sys.path.append('../../aisuite')"
4835
]
4936
},
5037
{
5138
"cell_type": "code",
52-
"execution_count": 2,
53-
"id": "a54491b7-6aa9-4337-9aba-3a0aef263bb4",
39+
"execution_count": 3,
40+
"id": "f75736ee",
5441
"metadata": {},
5542
"outputs": [],
5643
"source": [
57-
"import os \n",
44+
"import os\n",
45+
"def configure_environment(additional_env_vars=None):\n",
46+
" \"\"\"\n",
47+
" Load environment variables from .env file and apply any additional variables.\n",
48+
" :param additional_env_vars: A dictionary of additional environment variables to apply.\n",
49+
" \"\"\"\n",
50+
" # Load from .env file if available\n",
51+
" load_dotenv(find_dotenv())\n",
52+
"\n",
53+
" # Apply additional environment variables\n",
54+
" if additional_env_vars:\n",
55+
" for key, value in additional_env_vars.items():\n",
56+
" os.environ[key] = value\n",
5857
"\n",
59-
"os.environ['GROQ_API_KEY'] = 'xxx' # get a free key at https://console.groq.com/keys\n",
60-
"os.environ['FIREWORKS_API_KEY'] = 'xxx' # get a free key at https://fireworks.ai/api-keys\n",
61-
"os.environ['REPLICATE_API_KEY'] = 'xxx' # get a free key at https://replicate.com/account/api-tokens\n",
62-
"os.environ['TOGETHER_API_KEY'] = 'xxx' # get a free key at https://api.together.ai\n",
63-
"os.environ['OCTO_API_KEY'] = 'xxx' # get a free key at https://octoai.cloud/settings\n",
64-
"os.environ['AWS_ACCESS_KEY_ID'] = 'xxx' # get or create at https://console.aws.amazon.com/iam/home\n",
65-
"os.environ['AWS_SECRET_ACCESS_KEY'] = 'xxx' # get or create at https://console.aws.amazon.com/iam/home"
58+
"# Define additional API keys and AWS credentials\n",
59+
"additional_keys = {\n",
60+
" 'GROQ_API_KEY': 'xxx',\n",
61+
" 'FIREWORKS_API_KEY': 'xxx', \n",
62+
" 'REPLICATE_API_KEY': 'xxx', \n",
63+
" 'TOGETHER_API_KEY': 'xxx', \n",
64+
" 'OCTO_API_KEY': 'xxx',\n",
65+
" 'AWS_ACCESS_KEY_ID': 'xxx',\n",
66+
" 'AWS_SECRET_ACCESS_KEY': 'xxx',\n",
67+
"}\n",
68+
"\n",
69+
"# Configure environment\n",
70+
"configure_environment(additional_env_vars=additional_keys)"
6671
]
6772
},
6873
{
6974
"cell_type": "code",
70-
"execution_count": 3,
75+
"execution_count": null,
76+
"id": "744c5c15",
77+
"metadata": {},
78+
"outputs": [],
79+
"source": [
80+
"print(os.environ[\"AWS_SECRET_ACCESS_KEY\"])"
81+
]
82+
},
83+
{
84+
"cell_type": "code",
85+
"execution_count": 4,
7186
"id": "4de3a24f",
7287
"metadata": {
7388
"ExecuteTime": {
@@ -80,21 +95,23 @@
8095
"import aisuite as ai\n",
8196
"\n",
8297
"client = ai.Client()\n",
83-
"\n",
8498
"messages = [\n",
8599
" {\"role\": \"system\", \"content\": \"Respond in Pirate English.\"},\n",
86-
" {\"role\": \"user\", \"content\": \"Tell me a joke\"},\n",
100+
" {\"role\": \"user\", \"content\": \"Tell me a joke about Captain Jack Sparrow\"},\n",
87101
"]"
88102
]
89103
},
90104
{
91105
"cell_type": "code",
92106
"execution_count": null,
93-
"id": "1ffe9a49-638e-4304-b9de-49ee21d9ac8d",
107+
"id": "520a6879",
94108
"metadata": {},
95109
"outputs": [],
96110
"source": [
97-
"#!pip install boto3"
111+
"# print(os.environ[\"ANTHROPIC_API_KEY\"])\n",
112+
"anthropic_claude_3_opus = \"anthropic:claude-3-5-sonnet-20240620\"\n",
113+
"response = client.chat.completions.create(model=anthropic_claude_3_opus, messages=messages)\n",
114+
"print(response.choices[0].message.content)"
98115
]
99116
},
100117
{
@@ -104,11 +121,39 @@
104121
"metadata": {},
105122
"outputs": [],
106123
"source": [
107-
"aws_bedrock_llama3_8b = \"aws:meta.llama3-8b-instruct-v1:0\"\n",
108-
"#aws_bedrock_llama3_8b = \"aws:meta.llama3-70b-instruct-v1:0\"\n",
109-
"\n",
124+
"# print(os.environ['AWS_SECRET_ACCESS_KEY'])\n",
125+
"# print(os.environ['AWS_ACCESS_KEY_ID'])\n",
126+
"# print(os.environ['AWS_REGION'])\n",
127+
"aws_bedrock_llama3_8b = \"aws-bedrock:meta.llama3-1-8b-instruct-v1:0\"\n",
110128
"response = client.chat.completions.create(model=aws_bedrock_llama3_8b, messages=messages)\n",
111-
"\n",
129+
"print(response.choices[0].message.content)"
130+
]
131+
},
132+
{
133+
"cell_type": "code",
134+
"execution_count": 5,
135+
"id": "7e46c20a",
136+
"metadata": {},
137+
"outputs": [
138+
{
139+
"name": "stdout",
140+
"output_type": "stream",
141+
"text": [
142+
"<aisuite.framework.chat_completion_response.ChatCompletionResponse object at 0x10944bdd0>\n",
143+
"Arrr, listen close me hearties! Here be a joke for ye:\n",
144+
"\n",
145+
"Why did Captain Jack Sparrow go to the doctor?\n",
146+
"\n",
147+
"Because he had a bit o' a \"crabby\" day! (get it? crabby? like a crustacean, but also feeling grumpy? Ah, never mind, matey, ye landlubbers wouldn't understand...\n"
148+
]
149+
}
150+
],
151+
"source": [
152+
"client2 = ai.Client({\"azure\" : {\n",
153+
" \"api_key\": os.environ[\"AZURE_API_KEY\"],\n",
154+
"}});\n",
155+
"azure_model = \"azure:aisuite-Meta-Llama-3-8B-Inst\"\n",
156+
"response = client2.chat.completions.create(model=azure_model, messages=messages)\n",
112157
"print(response.choices[0].message.content)"
113158
]
114159
},
@@ -197,37 +242,6 @@
197242
"print(response.choices[0].message.content)"
198243
]
199244
},
200-
{
201-
"cell_type": "code",
202-
"execution_count": 4,
203-
"id": "adebd2f0b578a909",
204-
"metadata": {
205-
"ExecuteTime": {
206-
"end_time": "2024-07-04T15:31:25.060689Z",
207-
"start_time": "2024-07-04T15:31:16.131205Z"
208-
}
209-
},
210-
"outputs": [
211-
{
212-
"name": "stdout",
213-
"output_type": "stream",
214-
"text": [
215-
"Arrr, me bucko, 'ere be a jolly jest fer ye!\n",
216-
"\n",
217-
"What did th' pirate say on 'is 80th birthday? \"Aye matey!\"\n",
218-
"\n",
219-
"Ye see, it be a play on words, as \"Aye matey\" sounds like \"I'm eighty\". Har har har! 'Tis a clever bit o' pirate humor, if I do say so meself. Now, 'ow about ye fetch me a mug o' grog while I spin ye another yarn?\n"
220-
]
221-
}
222-
],
223-
"source": [
224-
"anthropic_claude_3_opus = \"anthropic:claude-3-opus-20240229\"\n",
225-
"\n",
226-
"response = client.chat.completions.create(model=anthropic_claude_3_opus, messages=messages)\n",
227-
"\n",
228-
"print(response.choices[0].message.content)"
229-
]
230-
},
231245
{
232246
"cell_type": "code",
233247
"execution_count": null,
@@ -263,19 +277,10 @@
263277
},
264278
{
265279
"cell_type": "code",
266-
"execution_count": 23,
280+
"execution_count": null,
267281
"id": "611210a4dc92845f",
268282
"metadata": {},
269-
"outputs": [
270-
{
271-
"name": "stdout",
272-
"output_type": "stream",
273-
"text": [
274-
"Why did the pirate go to the seafood restaurant? \n",
275-
"Because he heard they had some great fish tales! Arrr!\n"
276-
]
277-
}
278-
],
283+
"outputs": [],
279284
"source": [
280285
"openai_gpt35 = \"openai:gpt-3.5-turbo\"\n",
281286
"\n",
@@ -301,7 +306,7 @@
301306
"name": "python",
302307
"nbconvert_exporter": "python",
303308
"pygments_lexer": "ipython3",
304-
"version": "3.10.14"
309+
"version": "3.12.4"
305310
}
306311
},
307312
"nbformat": 4,

tests/client/test_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def test_invalid_provider_in_client_config(self, mock_openai):
9191

9292
# Verify the error message
9393
self.assertIn(
94-
"Provider INVALID_PROVIDER is not a valid ProviderNames enum",
94+
"Provider INVALID_PROVIDER is not a valid provider",
9595
str(context.exception),
9696
)
9797

0 commit comments

Comments
 (0)