Skip to content

Commit 30213a6

Browse files
committed
Add __call__ to Client
1 parent afe2af1 commit 30213a6

File tree

2 files changed

+90
-0
lines changed

2 files changed

+90
-0
lines changed

aisuite/client.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,27 @@ def __init__(self, provider_configs: dict = {}):
2626
self._chat = None
2727
self._initialize_providers()
2828

29+
def __call__(self, prompt: str, model: str, system_message: str = None, **kwargs):
30+
"""
31+
Call the client directly with a prompt and model.
32+
33+
Args:
34+
prompt (str): The user's prompt or question.
35+
model (str): The model to use in the format "provider:model".
36+
system_message (str, optional): A system message to set the context.
37+
**kwargs: Additional arguments to pass to the chat completion.
38+
39+
Returns:
40+
The response from the AI model.
41+
"""
42+
messages = []
43+
if system_message:
44+
messages.append({"role": "system", "content": system_message})
45+
messages.append({"role": "user", "content": prompt})
46+
47+
response = self.chat.completions.create(model, messages, **kwargs)
48+
return response.choices[0].message.content
49+
2950
def _initialize_providers(self):
3051
"""Helper method to initialize or update providers."""
3152
for provider_key, config in self.provider_configs.items():

tests/client/test_client.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,75 @@ def test_invalid_model_format_in_create(self, mock_openai):
171171
"Invalid model format. Expected 'provider:model'", str(context.exception)
172172
)
173173

174+
def create_mock_provider_response(self, expected_response):
175+
"""
176+
Helper method to create a mock provider response.
177+
"""
178+
mock_response = unittest.mock.Mock()
179+
mock_response.choices = [unittest.mock.Mock()]
180+
mock_response.choices[0].message.content = expected_response
181+
return mock_response
182+
183+
@patch("aisuite.providers.openai_provider.OpenAIProvider.chat_completions_create")
184+
def test_client_call_method(self, mock_openai):
185+
expected_response = "expected-text-response"
186+
provider_response = self.create_mock_provider_response(expected_response)
187+
mock_openai.return_value = provider_response
188+
189+
config = {
190+
ProviderNames.OPENAI: {"api_key": "test_openai_api_key"},
191+
}
192+
193+
new_client = Client(config)
194+
195+
# Test __call__ method (without system message)
196+
response = new_client("test-user-prompt", "openai:gpt-3.5-turbo")
197+
self.assertEqual(response, expected_response)
198+
mock_openai.assert_called_once()
199+
200+
mock_openai.reset_mock()
201+
202+
# Test __call__ method (with system message)
203+
response = new_client(
204+
"test-user-prompt",
205+
"openai:gpt-3.5-turbo",
206+
system_message="You are a helpful assistant.",
207+
)
208+
self.assertEqual(response, expected_response)
209+
mock_openai.assert_called_once()
210+
211+
# Ensure the correct messages were passed to the provider
212+
called_messages = mock_openai.call_args[0][1]
213+
self.assertEqual(len(called_messages), 2)
214+
self.assertEqual(called_messages[0]["role"], "system")
215+
self.assertEqual(called_messages[0]["content"], "You are a helpful assistant.")
216+
self.assertEqual(called_messages[1]["role"], "user")
217+
self.assertEqual(called_messages[1]["content"], "test-user-prompt")
218+
219+
@patch("aisuite.providers.openai_provider.OpenAIProvider.chat_completions_create")
220+
def test_client_call_method_with_kwargs(self, mock_openai):
221+
expected_response = "expected-text-response"
222+
provider_response = self.create_mock_provider_response(expected_response)
223+
mock_openai.return_value = provider_response
224+
225+
config = {
226+
ProviderNames.OPENAI: {"api_key": "test_openai_api_key"},
227+
}
228+
229+
new_client = Client(config)
230+
231+
# Test __call__ method with additional kwargs
232+
response = new_client(
233+
"test-user-prompt", "openai:gpt-3.5-turbo", max_tokens=100, temperature=0.7
234+
)
235+
self.assertEqual(response, expected_response)
236+
mock_openai.assert_called_once()
237+
238+
# Ensure the additional kwargs were passed to the provider
239+
_, kwargs = mock_openai.call_args
240+
self.assertEqual(kwargs.get("max_tokens"), 100)
241+
self.assertEqual(kwargs.get("temperature"), 0.7)
242+
174243

175244
if __name__ == "__main__":
176245
unittest.main()

0 commit comments

Comments
 (0)