Skip to content

Commit

Permalink
feat: Update function call result message format (#1436)
Browse files Browse the repository at this point in the history
Co-authored-by: Zack <[email protected]>
Co-authored-by: Wendong <[email protected]>
Co-authored-by: Wendong-Fan <[email protected]>
  • Loading branch information
4 people authored Jan 19, 2025
1 parent 9f74dbb commit bfdc89f
Show file tree
Hide file tree
Showing 15 changed files with 147 additions and 81 deletions.
20 changes: 17 additions & 3 deletions camel/agents/chat_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,13 @@ class FunctionCallingRecord(BaseModel):
args (Dict[str, Any]): The dictionary of arguments passed to
the function.
result (Any): The execution result of calling this function.
tool_call_id (str): The ID of the tool call, if available.
"""

func_name: str
args: Dict[str, Any]
result: Any
tool_call_id: str

def __str__(self) -> str:
r"""Overridden version of the string function.
Expand All @@ -109,7 +111,7 @@ def __str__(self) -> str:
return (
f"Function Execution: {self.func_name}\n"
f"\tArgs: {self.args}\n"
f"\tResult: {self.result}"
f"\tResult: {self.result}\n"
)

def as_dict(self) -> dict[str, Any]:
Expand Down Expand Up @@ -1384,6 +1386,7 @@ def _step_tool_call(

tool = self.tool_dict[func_name]
result = tool(**args)
tool_call_id = choice.message.tool_calls[0].id

assist_msg = FunctionCallingMessage(
role_name=self.role_name,
Expand All @@ -1392,6 +1395,7 @@ def _step_tool_call(
content="",
func_name=func_name,
args=args,
tool_call_id=tool_call_id,
)
func_msg = FunctionCallingMessage(
role_name=self.role_name,
Expand All @@ -1400,11 +1404,15 @@ def _step_tool_call(
content="",
func_name=func_name,
result=result,
tool_call_id=tool_call_id,
)

# Record information about this function call
func_record = FunctionCallingRecord(
func_name=func_name, args=args, result=result
func_name=func_name,
args=args,
result=result,
tool_call_id=tool_call_id,
)
return assist_msg, func_msg, func_record

Expand Down Expand Up @@ -1448,6 +1456,7 @@ async def step_tool_call_async(
args = json.loads(choice.message.tool_calls[0].function.arguments)
tool = self.tool_dict[func_name]
result = await tool(**args)
tool_call_id = choice.message.tool_calls[0].id

assist_msg = FunctionCallingMessage(
role_name=self.role_name,
Expand All @@ -1456,6 +1465,7 @@ async def step_tool_call_async(
content="",
func_name=func_name,
args=args,
tool_call_id=tool_call_id,
)
func_msg = FunctionCallingMessage(
role_name=self.role_name,
Expand All @@ -1464,11 +1474,15 @@ async def step_tool_call_async(
content="",
func_name=func_name,
result=result,
tool_call_id=tool_call_id,
)

# Record information about this function call
func_record = FunctionCallingRecord(
func_name=func_name, args=args, result=result
func_name=func_name,
args=args,
result=result,
tool_call_id=tool_call_id,
)
return assist_msg, func_msg, func_record

Expand Down
10 changes: 5 additions & 5 deletions camel/data_collector/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(
self,
id: UUID,
name: str,
role: Literal["user", "assistant", "system", "function"],
role: Literal["user", "assistant", "system", "tool"],
message: Optional[str] = None,
function_call: Optional[Dict[str, Any]] = None,
) -> None:
Expand All @@ -52,7 +52,7 @@ def __init__(
ValueError: If neither message nor function call is provided.
"""
if role not in ["user", "assistant", "system", "function"]:
if role not in ["user", "assistant", "system", "tool"]:
raise ValueError(f"Role {role} not supported")
if role == "system" and function_call:
raise ValueError("System role cannot have function call")
Expand Down Expand Up @@ -82,7 +82,7 @@ def from_context(name, context: Dict[str, Any]) -> "CollectorData":
name=name,
role=context["role"],
message=context["content"],
function_call=context.get("function_call", None),
function_call=context.get("tool_calls", None),
)


Expand All @@ -98,15 +98,15 @@ def __init__(self) -> None:

def step(
self,
role: Literal["user", "assistant", "system", "function"],
role: Literal["user", "assistant", "system", "tool"],
name: Optional[str] = None,
message: Optional[str] = None,
function_call: Optional[Dict[str, Any]] = None,
) -> Self:
r"""Record a message.
Args:
role (Literal["user", "assistant", "system", "function"]):
role (Literal["user", "assistant", "system", "tool"]):
The role of the message.
name (Optional[str], optional): The name of the agent.
(default: :obj:`None`)
Expand Down
4 changes: 2 additions & 2 deletions camel/data_collector/sharegpt_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def convert(self) -> Dict[str, Any]:
conversations.append(
{"from": "gpt", "value": message.message}
)
elif role == "function":
elif role == "function" or role == "tool":
conversations.append(
{
"from": "observation",
Expand Down Expand Up @@ -182,7 +182,7 @@ def llm_convert(
if message.function_call:
context.append(prefix + json.dumps(message.function_call))

elif role == "function":
elif role == "function" or role == "tool":
context.append(prefix + json.dumps(message.message)) # type: ignore[attr-defined]
else:
context.append(prefix + str(message.message))
Expand Down
14 changes: 10 additions & 4 deletions camel/messages/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Union

from camel.types import (
ChatCompletionAssistantMessageParam,
ChatCompletionFunctionMessageParam,
ChatCompletionMessageParam,
ChatCompletionSystemMessageParam,
ChatCompletionToolMessageParam,
ChatCompletionUserMessageParam,
)

Expand All @@ -32,9 +34,13 @@
)

OpenAISystemMessage = ChatCompletionSystemMessageParam
OpenAIAssistantMessage = ChatCompletionAssistantMessageParam
OpenAIAssistantMessage = Union[
ChatCompletionAssistantMessageParam,
ChatCompletionToolMessageParam,
]
OpenAIUserMessage = ChatCompletionUserMessageParam
OpenAIFunctionMessage = ChatCompletionFunctionMessageParam
OpenAIToolMessageParam = ChatCompletionToolMessageParam

OpenAIMessage = ChatCompletionMessageParam


Expand All @@ -45,7 +51,7 @@
'OpenAISystemMessage',
'OpenAIAssistantMessage',
'OpenAIUserMessage',
'OpenAIFunctionMessage',
'OpenAIToolMessageParam',
'OpenAIMessage',
'FunctionCallFormatter',
'HermesFunctionFormatter',
Expand Down
5 changes: 5 additions & 0 deletions camel/messages/conversion/conversation_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ def validate_conversation_flow(self) -> 'ShareGPTConversation':
for i in range(1, len(messages)):
curr, prev = messages[i], messages[i - 1]

print("@@@@")
print(curr)
print(prev)
print("@@@@")

if curr.from_ == "tool":
if prev.from_ != "gpt" or "<tool_call>" not in prev.value:
raise ValueError(
Expand Down
52 changes: 30 additions & 22 deletions camel/messages/func_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import json
from dataclasses import dataclass
from typing import Any, Dict, Optional

from camel.messages import (
BaseMessage,
HermesFunctionFormatter,
OpenAIAssistantMessage,
OpenAIFunctionMessage,
OpenAIMessage,
OpenAIToolMessageParam,
)
from camel.messages.conversion import (
ShareGPTMessage,
Expand All @@ -44,11 +45,14 @@ class FunctionCallingMessage(BaseMessage):
function. (default: :obj:`None`)
result (Optional[Any]): The result of function execution.
(default: :obj:`None`)
tool_call_id (Optional[str]): The ID of the tool call, if available.
(default: :obj:`None`)
"""

func_name: Optional[str] = None
args: Optional[Dict] = None
result: Optional[Any] = None
tool_call_id: Optional[str] = None

def to_openai_message(
self,
Expand All @@ -66,7 +70,7 @@ def to_openai_message(
if role_at_backend == OpenAIBackendRole.ASSISTANT:
return self.to_openai_assistant_message()
elif role_at_backend == OpenAIBackendRole.FUNCTION:
return self.to_openai_function_message()
return self.to_openai_tool_message()
else:
raise ValueError(f"Unsupported role: {role_at_backend}.")

Expand Down Expand Up @@ -120,36 +124,40 @@ def to_openai_assistant_message(self) -> OpenAIAssistantMessage:
" due to missing function name or arguments."
)

msg_dict: OpenAIAssistantMessage = {
return {
"role": "assistant",
"content": self.content,
"function_call": {
"name": self.func_name,
"arguments": str(self.args),
},
"content": self.content or "",
"tool_calls": [
{
"id": self.tool_call_id or "",
"type": "function",
"function": {
"name": self.func_name,
"arguments": json.dumps(self.args),
},
}
],
}

return msg_dict

def to_openai_function_message(self) -> OpenAIFunctionMessage:
r"""Converts the message to an :obj:`OpenAIMessage` object
with the role being "function".
def to_openai_tool_message(self) -> OpenAIToolMessageParam:
r"""Converts the message to an :obj:`OpenAIToolMessageParam` object
with the role being "tool".
Returns:
OpenAIMessage: The converted :obj:`OpenAIMessage` object
with its role being "function".
OpenAIToolMessageParam: The converted
:obj:`OpenAIToolMessageParam` object with its role being
"tool".
"""
if not self.func_name:
raise ValueError(
"Invalid request for converting into function message"
" due to missing function name."
)

result_content = {"result": {str(self.result)}}
msg_dict: OpenAIFunctionMessage = {
"role": "function",
"name": self.func_name,
"content": f'{result_content}',
}
result_content = json.dumps(self.result)

return msg_dict
return {
"role": "tool",
"content": result_content,
"tool_call_id": self.tool_call_id or "",
}
8 changes: 8 additions & 0 deletions camel/models/cohere_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,14 @@ def run(self, messages: List[OpenAIMessage]) -> ChatCompletion:

cohere_messages = self._to_cohere_chatmessage(messages)

# Removing 'strict': True from the dictionary for
# cohere client
if self.model_config_dict.get('tools') is not None:
for tool in self.model_config_dict.get('tools', []):
function_dict = tool.get('function', {})
if 'strict' in function_dict:
del function_dict['strict']

try:
response = self._client.chat(
messages=cohere_messages,
Expand Down
21 changes: 14 additions & 7 deletions camel/models/mistral_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,18 +147,25 @@ def _to_mistral_chatmessage(
new_messages = []
for msg in messages:
tool_id = uuid.uuid4().hex[:9]
tool_call_id = uuid.uuid4().hex[:9]
tool_call_id = msg.get("tool_call_id") or uuid.uuid4().hex[:9]

role = msg.get("role")
function_call = msg.get("function_call")
tool_calls = msg.get("tool_calls")
content = msg.get("content")

mistral_function_call = None
if function_call:
mistral_function_call = FunctionCall(
name=function_call.get("name"), # type: ignore[attr-defined]
arguments=function_call.get("arguments"), # type: ignore[attr-defined]
if tool_calls:
# Ensure tool_calls is treated as a list
tool_calls_list = (
tool_calls
if isinstance(tool_calls, list)
else [tool_calls]
)
for tool_call in tool_calls_list:
mistral_function_call = FunctionCall(
name=tool_call["function"].get("name"), # type: ignore[attr-defined]
arguments=tool_call["function"].get("arguments"), # type: ignore[attr-defined]
)

tool_calls = None
if mistral_function_call:
Expand All @@ -178,7 +185,7 @@ def _to_mistral_chatmessage(
new_messages.append(
ToolMessage(
content=content, # type: ignore[arg-type]
tool_call_id=tool_call_id,
tool_call_id=tool_call_id, # type: ignore[arg-type]
name=msg.get("name"), # type: ignore[arg-type]
)
)
Expand Down
6 changes: 4 additions & 2 deletions camel/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,11 @@
ChatCompletion,
ChatCompletionAssistantMessageParam,
ChatCompletionChunk,
ChatCompletionFunctionMessageParam,
ChatCompletionMessage,
ChatCompletionMessageParam,
ChatCompletionMessageToolCall,
ChatCompletionSystemMessageParam,
ChatCompletionToolMessageParam,
ChatCompletionUserMessageParam,
Choice,
CompletionUsage,
Expand All @@ -62,7 +63,8 @@
'ChatCompletionSystemMessageParam',
'ChatCompletionUserMessageParam',
'ChatCompletionAssistantMessageParam',
'ChatCompletionFunctionMessageParam',
'ChatCompletionToolMessageParam',
'ChatCompletionMessageToolCall',
'CompletionUsage',
'OpenAIImageType',
'OpenAIVisionDetailType',
Expand Down
Loading

0 comments on commit bfdc89f

Please sign in to comment.