|
40 | 40 | RoleType,
|
41 | 41 | TaskType,
|
42 | 42 | )
|
| 43 | +from camel.utils.async_func import sync_funcs_to_async |
43 | 44 |
|
44 | 45 | parametrize = pytest.mark.parametrize(
|
45 | 46 | 'model',
|
@@ -420,6 +421,52 @@ def test_tool_calling_sync():
|
420 | 421 | assert tool_calls[0].result == 16
|
421 | 422 |
|
422 | 423 |
|
| 424 | +@pytest.mark.model_backend |
| 425 | +@pytest.mark.asyncio |
| 426 | +async def test_tool_calling_math_async(): |
| 427 | + system_message = BaseMessage( |
| 428 | + role_name="assistant", |
| 429 | + role_type=RoleType.ASSISTANT, |
| 430 | + meta_dict=None, |
| 431 | + content="You are a help assistant.", |
| 432 | + ) |
| 433 | + math_funcs = sync_funcs_to_async(MATH_FUNCS) |
| 434 | + model_config = ChatGPTConfig(tools=[*math_funcs]) |
| 435 | + model = ModelFactory.create( |
| 436 | + model_platform=ModelPlatformType.OPENAI, |
| 437 | + model_type=ModelType.GPT_3_5_TURBO, |
| 438 | + model_config_dict=model_config.__dict__, |
| 439 | + ) |
| 440 | + agent = ChatAgent( |
| 441 | + system_message=system_message, |
| 442 | + model=model, |
| 443 | + tools=math_funcs, |
| 444 | + ) |
| 445 | + |
| 446 | + ref_funcs = math_funcs |
| 447 | + |
| 448 | + assert len(agent.func_dict) == len(ref_funcs) |
| 449 | + |
| 450 | + user_msg = BaseMessage( |
| 451 | + role_name="User", |
| 452 | + role_type=RoleType.USER, |
| 453 | + meta_dict=dict(), |
| 454 | + content="Calculate the result of: 2*8-10.", |
| 455 | + ) |
| 456 | + agent_response = await agent.step_async(user_msg) |
| 457 | + |
| 458 | + tool_calls: List[FunctionCallingRecord] = agent_response.info['tool_calls'] |
| 459 | + for called_func in tool_calls: |
| 460 | + print(str(called_func)) |
| 461 | + |
| 462 | + assert len(tool_calls) > 0 |
| 463 | + assert str(tool_calls[0]).startswith("Function Execution") |
| 464 | + |
| 465 | + assert tool_calls[0].func_name == "mul" |
| 466 | + assert tool_calls[0].args == {"a": 2, "b": 8} |
| 467 | + assert tool_calls[0].result == 16 |
| 468 | + |
| 469 | + |
423 | 470 | @pytest.mark.model_backend
|
424 | 471 | @pytest.mark.asyncio
|
425 | 472 | async def test_tool_calling_async():
|
|
0 commit comments