Skip to content

Commit

Permalink
feat: Add utils to convert sync functions to async functions (#690)
Browse files Browse the repository at this point in the history
Co-authored-by: Wendong-Fan <[email protected]>
  • Loading branch information
zechengz and Wendong-Fan authored Jul 3, 2024
1 parent 8082e80 commit 12b058d
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 0 deletions.
42 changes: 42 additions & 0 deletions camel/utils/async_func.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
import asyncio
from copy import deepcopy

from camel.functions.openai_function import OpenAIFunction


def sync_funcs_to_async(funcs: list[OpenAIFunction]) -> list[OpenAIFunction]:
r"""Convert a list of Python synchronous functions to Python
asynchronous functions.
Args:
funcs (list[OpenAIFunction]): List of Python synchronous
functions in the :obj:`OpenAIFunction` format.
Returns:
list[OpenAIFunction]: List of Python asynchronous functions
in the :obj:`OpenAIFunction` format.
"""
async_funcs = []
for func in funcs:
sync_func = func.func

def async_callable(*args, **kwargs):
return asyncio.to_thread(sync_func, *args, **kwargs) # noqa: B023

async_funcs.append(
OpenAIFunction(async_callable, deepcopy(func.openai_tool_schema))
)
return async_funcs
47 changes: 47 additions & 0 deletions test/agents/test_chat_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
RoleType,
TaskType,
)
from camel.utils.async_func import sync_funcs_to_async

parametrize = pytest.mark.parametrize(
'model',
Expand Down Expand Up @@ -420,6 +421,52 @@ def test_tool_calling_sync():
assert tool_calls[0].result == 16


@pytest.mark.model_backend
@pytest.mark.asyncio
async def test_tool_calling_math_async():
system_message = BaseMessage(
role_name="assistant",
role_type=RoleType.ASSISTANT,
meta_dict=None,
content="You are a help assistant.",
)
math_funcs = sync_funcs_to_async(MATH_FUNCS)
model_config = ChatGPTConfig(tools=[*math_funcs])
model = ModelFactory.create(
model_platform=ModelPlatformType.OPENAI,
model_type=ModelType.GPT_3_5_TURBO,
model_config_dict=model_config.__dict__,
)
agent = ChatAgent(
system_message=system_message,
model=model,
tools=math_funcs,
)

ref_funcs = math_funcs

assert len(agent.func_dict) == len(ref_funcs)

user_msg = BaseMessage(
role_name="User",
role_type=RoleType.USER,
meta_dict=dict(),
content="Calculate the result of: 2*8-10.",
)
agent_response = await agent.step_async(user_msg)

tool_calls: List[FunctionCallingRecord] = agent_response.info['tool_calls']
for called_func in tool_calls:
print(str(called_func))

assert len(tool_calls) > 0
assert str(tool_calls[0]).startswith("Function Execution")

assert tool_calls[0].func_name == "mul"
assert tool_calls[0].args == {"a": 2, "b": 8}
assert tool_calls[0].result == 16


@pytest.mark.model_backend
@pytest.mark.asyncio
async def test_tool_calling_async():
Expand Down

0 comments on commit 12b058d

Please sign in to comment.