Skip to content

Commit 12b058d

Browse files
feat: Add utils to convert sync functions to async functions (#690)
Co-authored-by: Wendong-Fan <[email protected]>
1 parent 8082e80 commit 12b058d

File tree

2 files changed

+89
-0
lines changed

2 files changed

+89
-0
lines changed

camel/utils/async_func.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
2+
# Licensed under the Apache License, Version 2.0 (the “License”);
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an “AS IS” BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14+
import asyncio
15+
from copy import deepcopy
16+
17+
from camel.functions.openai_function import OpenAIFunction
18+
19+
20+
def sync_funcs_to_async(funcs: list[OpenAIFunction]) -> list[OpenAIFunction]:
21+
r"""Convert a list of Python synchronous functions to Python
22+
asynchronous functions.
23+
24+
Args:
25+
funcs (list[OpenAIFunction]): List of Python synchronous
26+
functions in the :obj:`OpenAIFunction` format.
27+
28+
Returns:
29+
list[OpenAIFunction]: List of Python asynchronous functions
30+
in the :obj:`OpenAIFunction` format.
31+
"""
32+
async_funcs = []
33+
for func in funcs:
34+
sync_func = func.func
35+
36+
def async_callable(*args, **kwargs):
37+
return asyncio.to_thread(sync_func, *args, **kwargs) # noqa: B023
38+
39+
async_funcs.append(
40+
OpenAIFunction(async_callable, deepcopy(func.openai_tool_schema))
41+
)
42+
return async_funcs

test/agents/test_chat_agent.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
RoleType,
4141
TaskType,
4242
)
43+
from camel.utils.async_func import sync_funcs_to_async
4344

4445
parametrize = pytest.mark.parametrize(
4546
'model',
@@ -420,6 +421,52 @@ def test_tool_calling_sync():
420421
assert tool_calls[0].result == 16
421422

422423

424+
@pytest.mark.model_backend
425+
@pytest.mark.asyncio
426+
async def test_tool_calling_math_async():
427+
system_message = BaseMessage(
428+
role_name="assistant",
429+
role_type=RoleType.ASSISTANT,
430+
meta_dict=None,
431+
content="You are a help assistant.",
432+
)
433+
math_funcs = sync_funcs_to_async(MATH_FUNCS)
434+
model_config = ChatGPTConfig(tools=[*math_funcs])
435+
model = ModelFactory.create(
436+
model_platform=ModelPlatformType.OPENAI,
437+
model_type=ModelType.GPT_3_5_TURBO,
438+
model_config_dict=model_config.__dict__,
439+
)
440+
agent = ChatAgent(
441+
system_message=system_message,
442+
model=model,
443+
tools=math_funcs,
444+
)
445+
446+
ref_funcs = math_funcs
447+
448+
assert len(agent.func_dict) == len(ref_funcs)
449+
450+
user_msg = BaseMessage(
451+
role_name="User",
452+
role_type=RoleType.USER,
453+
meta_dict=dict(),
454+
content="Calculate the result of: 2*8-10.",
455+
)
456+
agent_response = await agent.step_async(user_msg)
457+
458+
tool_calls: List[FunctionCallingRecord] = agent_response.info['tool_calls']
459+
for called_func in tool_calls:
460+
print(str(called_func))
461+
462+
assert len(tool_calls) > 0
463+
assert str(tool_calls[0]).startswith("Function Execution")
464+
465+
assert tool_calls[0].func_name == "mul"
466+
assert tool_calls[0].args == {"a": 2, "b": 8}
467+
assert tool_calls[0].result == 16
468+
469+
423470
@pytest.mark.model_backend
424471
@pytest.mark.asyncio
425472
async def test_tool_calling_async():

0 commit comments

Comments
 (0)