@@ -171,6 +171,75 @@ def test_invalid_model_format_in_create(self, mock_openai):
171
171
"Invalid model format. Expected 'provider:model'" , str (context .exception )
172
172
)
173
173
174
+ def create_mock_provider_response (self , expected_response ):
175
+ """
176
+ Helper method to create a mock provider response.
177
+ """
178
+ mock_response = unittest .mock .Mock ()
179
+ mock_response .choices = [unittest .mock .Mock ()]
180
+ mock_response .choices [0 ].message .content = expected_response
181
+ return mock_response
182
+
183
+ @patch ("aisuite.providers.openai_provider.OpenAIProvider.chat_completions_create" )
184
+ def test_client_call_method (self , mock_openai ):
185
+ expected_response = "expected-text-response"
186
+ provider_response = self .create_mock_provider_response (expected_response )
187
+ mock_openai .return_value = provider_response
188
+
189
+ config = {
190
+ ProviderNames .OPENAI : {"api_key" : "test_openai_api_key" },
191
+ }
192
+
193
+ new_client = Client (config )
194
+
195
+ # Test __call__ method (without system message)
196
+ response = new_client ("test-user-prompt" , "openai:gpt-3.5-turbo" )
197
+ self .assertEqual (response , expected_response )
198
+ mock_openai .assert_called_once ()
199
+
200
+ mock_openai .reset_mock ()
201
+
202
+ # Test __call__ method (with system message)
203
+ response = new_client (
204
+ "test-user-prompt" ,
205
+ "openai:gpt-3.5-turbo" ,
206
+ system_message = "You are a helpful assistant." ,
207
+ )
208
+ self .assertEqual (response , expected_response )
209
+ mock_openai .assert_called_once ()
210
+
211
+ # Ensure the correct messages were passed to the provider
212
+ called_messages = mock_openai .call_args [0 ][1 ]
213
+ self .assertEqual (len (called_messages ), 2 )
214
+ self .assertEqual (called_messages [0 ]["role" ], "system" )
215
+ self .assertEqual (called_messages [0 ]["content" ], "You are a helpful assistant." )
216
+ self .assertEqual (called_messages [1 ]["role" ], "user" )
217
+ self .assertEqual (called_messages [1 ]["content" ], "test-user-prompt" )
218
+
219
+ @patch ("aisuite.providers.openai_provider.OpenAIProvider.chat_completions_create" )
220
+ def test_client_call_method_with_kwargs (self , mock_openai ):
221
+ expected_response = "expected-text-response"
222
+ provider_response = self .create_mock_provider_response (expected_response )
223
+ mock_openai .return_value = provider_response
224
+
225
+ config = {
226
+ ProviderNames .OPENAI : {"api_key" : "test_openai_api_key" },
227
+ }
228
+
229
+ new_client = Client (config )
230
+
231
+ # Test __call__ method with additional kwargs
232
+ response = new_client (
233
+ "test-user-prompt" , "openai:gpt-3.5-turbo" , max_tokens = 100 , temperature = 0.7
234
+ )
235
+ self .assertEqual (response , expected_response )
236
+ mock_openai .assert_called_once ()
237
+
238
+ # Ensure the additional kwargs were passed to the provider
239
+ _ , kwargs = mock_openai .call_args
240
+ self .assertEqual (kwargs .get ("max_tokens" ), 100 )
241
+ self .assertEqual (kwargs .get ("temperature" ), 0.7 )
242
+
174
243
175
244
if __name__ == "__main__" :
176
245
unittest .main ()
0 commit comments