Skip to content

Commit bfdc89f

Browse files
ZackYuleZackWendong-Fan
authored
feat: Update function call result message format (#1436)
Co-authored-by: Zack <[email protected]> Co-authored-by: Wendong <[email protected]> Co-authored-by: Wendong-Fan <[email protected]>
1 parent 9f74dbb commit bfdc89f

File tree

15 files changed

+147
-81
lines changed

15 files changed

+147
-81
lines changed

camel/agents/chat_agent.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,13 @@ class FunctionCallingRecord(BaseModel):
9494
args (Dict[str, Any]): The dictionary of arguments passed to
9595
the function.
9696
result (Any): The execution result of calling this function.
97+
tool_call_id (str): The ID of the tool call, if available.
9798
"""
9899

99100
func_name: str
100101
args: Dict[str, Any]
101102
result: Any
103+
tool_call_id: str
102104

103105
def __str__(self) -> str:
104106
r"""Overridden version of the string function.
@@ -109,7 +111,7 @@ def __str__(self) -> str:
109111
return (
110112
f"Function Execution: {self.func_name}\n"
111113
f"\tArgs: {self.args}\n"
112-
f"\tResult: {self.result}"
114+
f"\tResult: {self.result}\n"
113115
)
114116

115117
def as_dict(self) -> dict[str, Any]:
@@ -1384,6 +1386,7 @@ def _step_tool_call(
13841386

13851387
tool = self.tool_dict[func_name]
13861388
result = tool(**args)
1389+
tool_call_id = choice.message.tool_calls[0].id
13871390

13881391
assist_msg = FunctionCallingMessage(
13891392
role_name=self.role_name,
@@ -1392,6 +1395,7 @@ def _step_tool_call(
13921395
content="",
13931396
func_name=func_name,
13941397
args=args,
1398+
tool_call_id=tool_call_id,
13951399
)
13961400
func_msg = FunctionCallingMessage(
13971401
role_name=self.role_name,
@@ -1400,11 +1404,15 @@ def _step_tool_call(
14001404
content="",
14011405
func_name=func_name,
14021406
result=result,
1407+
tool_call_id=tool_call_id,
14031408
)
14041409

14051410
# Record information about this function call
14061411
func_record = FunctionCallingRecord(
1407-
func_name=func_name, args=args, result=result
1412+
func_name=func_name,
1413+
args=args,
1414+
result=result,
1415+
tool_call_id=tool_call_id,
14081416
)
14091417
return assist_msg, func_msg, func_record
14101418

@@ -1448,6 +1456,7 @@ async def step_tool_call_async(
14481456
args = json.loads(choice.message.tool_calls[0].function.arguments)
14491457
tool = self.tool_dict[func_name]
14501458
result = await tool(**args)
1459+
tool_call_id = choice.message.tool_calls[0].id
14511460

14521461
assist_msg = FunctionCallingMessage(
14531462
role_name=self.role_name,
@@ -1456,6 +1465,7 @@ async def step_tool_call_async(
14561465
content="",
14571466
func_name=func_name,
14581467
args=args,
1468+
tool_call_id=tool_call_id,
14591469
)
14601470
func_msg = FunctionCallingMessage(
14611471
role_name=self.role_name,
@@ -1464,11 +1474,15 @@ async def step_tool_call_async(
14641474
content="",
14651475
func_name=func_name,
14661476
result=result,
1477+
tool_call_id=tool_call_id,
14671478
)
14681479

14691480
# Record information about this function call
14701481
func_record = FunctionCallingRecord(
1471-
func_name=func_name, args=args, result=result
1482+
func_name=func_name,
1483+
args=args,
1484+
result=result,
1485+
tool_call_id=tool_call_id,
14721486
)
14731487
return assist_msg, func_msg, func_record
14741488

camel/data_collector/base.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def __init__(
2727
self,
2828
id: UUID,
2929
name: str,
30-
role: Literal["user", "assistant", "system", "function"],
30+
role: Literal["user", "assistant", "system", "tool"],
3131
message: Optional[str] = None,
3232
function_call: Optional[Dict[str, Any]] = None,
3333
) -> None:
@@ -52,7 +52,7 @@ def __init__(
5252
ValueError: If neither message nor function call is provided.
5353
5454
"""
55-
if role not in ["user", "assistant", "system", "function"]:
55+
if role not in ["user", "assistant", "system", "tool"]:
5656
raise ValueError(f"Role {role} not supported")
5757
if role == "system" and function_call:
5858
raise ValueError("System role cannot have function call")
@@ -82,7 +82,7 @@ def from_context(name, context: Dict[str, Any]) -> "CollectorData":
8282
name=name,
8383
role=context["role"],
8484
message=context["content"],
85-
function_call=context.get("function_call", None),
85+
function_call=context.get("tool_calls", None),
8686
)
8787

8888

@@ -98,15 +98,15 @@ def __init__(self) -> None:
9898

9999
def step(
100100
self,
101-
role: Literal["user", "assistant", "system", "function"],
101+
role: Literal["user", "assistant", "system", "tool"],
102102
name: Optional[str] = None,
103103
message: Optional[str] = None,
104104
function_call: Optional[Dict[str, Any]] = None,
105105
) -> Self:
106106
r"""Record a message.
107107
108108
Args:
109-
role (Literal["user", "assistant", "system", "function"]):
109+
role (Literal["user", "assistant", "system", "tool"]):
110110
The role of the message.
111111
name (Optional[str], optional): The name of the agent.
112112
(default: :obj:`None`)

camel/data_collector/sharegpt_collector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def convert(self) -> Dict[str, Any]:
131131
conversations.append(
132132
{"from": "gpt", "value": message.message}
133133
)
134-
elif role == "function":
134+
elif role == "function" or role == "tool":
135135
conversations.append(
136136
{
137137
"from": "observation",
@@ -182,7 +182,7 @@ def llm_convert(
182182
if message.function_call:
183183
context.append(prefix + json.dumps(message.function_call))
184184

185-
elif role == "function":
185+
elif role == "function" or role == "tool":
186186
context.append(prefix + json.dumps(message.message)) # type: ignore[attr-defined]
187187
else:
188188
context.append(prefix + str(message.message))

camel/messages/__init__.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14+
from typing import Union
15+
1416
from camel.types import (
1517
ChatCompletionAssistantMessageParam,
16-
ChatCompletionFunctionMessageParam,
1718
ChatCompletionMessageParam,
1819
ChatCompletionSystemMessageParam,
20+
ChatCompletionToolMessageParam,
1921
ChatCompletionUserMessageParam,
2022
)
2123

@@ -32,9 +34,13 @@
3234
)
3335

3436
OpenAISystemMessage = ChatCompletionSystemMessageParam
35-
OpenAIAssistantMessage = ChatCompletionAssistantMessageParam
37+
OpenAIAssistantMessage = Union[
38+
ChatCompletionAssistantMessageParam,
39+
ChatCompletionToolMessageParam,
40+
]
3641
OpenAIUserMessage = ChatCompletionUserMessageParam
37-
OpenAIFunctionMessage = ChatCompletionFunctionMessageParam
42+
OpenAIToolMessageParam = ChatCompletionToolMessageParam
43+
3844
OpenAIMessage = ChatCompletionMessageParam
3945

4046

@@ -45,7 +51,7 @@
4551
'OpenAISystemMessage',
4652
'OpenAIAssistantMessage',
4753
'OpenAIUserMessage',
48-
'OpenAIFunctionMessage',
54+
'OpenAIToolMessageParam',
4955
'OpenAIMessage',
5056
'FunctionCallFormatter',
5157
'HermesFunctionFormatter',

camel/messages/conversion/conversation_models.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,11 @@ def validate_conversation_flow(self) -> 'ShareGPTConversation':
6969
for i in range(1, len(messages)):
7070
curr, prev = messages[i], messages[i - 1]
7171

72+
print("@@@@")
73+
print(curr)
74+
print(prev)
75+
print("@@@@")
76+
7277
if curr.from_ == "tool":
7378
if prev.from_ != "gpt" or "<tool_call>" not in prev.value:
7479
raise ValueError(

camel/messages/func_message.py

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,16 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14+
import json
1415
from dataclasses import dataclass
1516
from typing import Any, Dict, Optional
1617

1718
from camel.messages import (
1819
BaseMessage,
1920
HermesFunctionFormatter,
2021
OpenAIAssistantMessage,
21-
OpenAIFunctionMessage,
2222
OpenAIMessage,
23+
OpenAIToolMessageParam,
2324
)
2425
from camel.messages.conversion import (
2526
ShareGPTMessage,
@@ -44,11 +45,14 @@ class FunctionCallingMessage(BaseMessage):
4445
function. (default: :obj:`None`)
4546
result (Optional[Any]): The result of function execution.
4647
(default: :obj:`None`)
48+
tool_call_id (Optional[str]): The ID of the tool call, if available.
49+
(default: :obj:`None`)
4750
"""
4851

4952
func_name: Optional[str] = None
5053
args: Optional[Dict] = None
5154
result: Optional[Any] = None
55+
tool_call_id: Optional[str] = None
5256

5357
def to_openai_message(
5458
self,
@@ -66,7 +70,7 @@ def to_openai_message(
6670
if role_at_backend == OpenAIBackendRole.ASSISTANT:
6771
return self.to_openai_assistant_message()
6872
elif role_at_backend == OpenAIBackendRole.FUNCTION:
69-
return self.to_openai_function_message()
73+
return self.to_openai_tool_message()
7074
else:
7175
raise ValueError(f"Unsupported role: {role_at_backend}.")
7276

@@ -120,36 +124,40 @@ def to_openai_assistant_message(self) -> OpenAIAssistantMessage:
120124
" due to missing function name or arguments."
121125
)
122126

123-
msg_dict: OpenAIAssistantMessage = {
127+
return {
124128
"role": "assistant",
125-
"content": self.content,
126-
"function_call": {
127-
"name": self.func_name,
128-
"arguments": str(self.args),
129-
},
129+
"content": self.content or "",
130+
"tool_calls": [
131+
{
132+
"id": self.tool_call_id or "",
133+
"type": "function",
134+
"function": {
135+
"name": self.func_name,
136+
"arguments": json.dumps(self.args),
137+
},
138+
}
139+
],
130140
}
131141

132-
return msg_dict
133-
134-
def to_openai_function_message(self) -> OpenAIFunctionMessage:
135-
r"""Converts the message to an :obj:`OpenAIMessage` object
136-
with the role being "function".
142+
def to_openai_tool_message(self) -> OpenAIToolMessageParam:
143+
r"""Converts the message to an :obj:`OpenAIToolMessageParam` object
144+
with the role being "tool".
137145
138146
Returns:
139-
OpenAIMessage: The converted :obj:`OpenAIMessage` object
140-
with its role being "function".
147+
OpenAIToolMessageParam: The converted
148+
:obj:`OpenAIToolMessageParam` object with its role being
149+
"tool".
141150
"""
142151
if not self.func_name:
143152
raise ValueError(
144153
"Invalid request for converting into function message"
145154
" due to missing function name."
146155
)
147156

148-
result_content = {"result": {str(self.result)}}
149-
msg_dict: OpenAIFunctionMessage = {
150-
"role": "function",
151-
"name": self.func_name,
152-
"content": f'{result_content}',
153-
}
157+
result_content = json.dumps(self.result)
154158

155-
return msg_dict
159+
return {
160+
"role": "tool",
161+
"content": result_content,
162+
"tool_call_id": self.tool_call_id or "",
163+
}

camel/models/cohere_model.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,14 @@ def run(self, messages: List[OpenAIMessage]) -> ChatCompletion:
228228

229229
cohere_messages = self._to_cohere_chatmessage(messages)
230230

231+
# Removing 'strict': True from the dictionary for
232+
# cohere client
233+
if self.model_config_dict.get('tools') is not None:
234+
for tool in self.model_config_dict.get('tools', []):
235+
function_dict = tool.get('function', {})
236+
if 'strict' in function_dict:
237+
del function_dict['strict']
238+
231239
try:
232240
response = self._client.chat(
233241
messages=cohere_messages,

camel/models/mistral_model.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -147,18 +147,25 @@ def _to_mistral_chatmessage(
147147
new_messages = []
148148
for msg in messages:
149149
tool_id = uuid.uuid4().hex[:9]
150-
tool_call_id = uuid.uuid4().hex[:9]
150+
tool_call_id = msg.get("tool_call_id") or uuid.uuid4().hex[:9]
151151

152152
role = msg.get("role")
153-
function_call = msg.get("function_call")
153+
tool_calls = msg.get("tool_calls")
154154
content = msg.get("content")
155155

156156
mistral_function_call = None
157-
if function_call:
158-
mistral_function_call = FunctionCall(
159-
name=function_call.get("name"), # type: ignore[attr-defined]
160-
arguments=function_call.get("arguments"), # type: ignore[attr-defined]
157+
if tool_calls:
158+
# Ensure tool_calls is treated as a list
159+
tool_calls_list = (
160+
tool_calls
161+
if isinstance(tool_calls, list)
162+
else [tool_calls]
161163
)
164+
for tool_call in tool_calls_list:
165+
mistral_function_call = FunctionCall(
166+
name=tool_call["function"].get("name"), # type: ignore[attr-defined]
167+
arguments=tool_call["function"].get("arguments"), # type: ignore[attr-defined]
168+
)
162169

163170
tool_calls = None
164171
if mistral_function_call:
@@ -178,7 +185,7 @@ def _to_mistral_chatmessage(
178185
new_messages.append(
179186
ToolMessage(
180187
content=content, # type: ignore[arg-type]
181-
tool_call_id=tool_call_id,
188+
tool_call_id=tool_call_id, # type: ignore[arg-type]
182189
name=msg.get("name"), # type: ignore[arg-type]
183190
)
184191
)

camel/types/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,11 @@
3333
ChatCompletion,
3434
ChatCompletionAssistantMessageParam,
3535
ChatCompletionChunk,
36-
ChatCompletionFunctionMessageParam,
3736
ChatCompletionMessage,
3837
ChatCompletionMessageParam,
38+
ChatCompletionMessageToolCall,
3939
ChatCompletionSystemMessageParam,
40+
ChatCompletionToolMessageParam,
4041
ChatCompletionUserMessageParam,
4142
Choice,
4243
CompletionUsage,
@@ -62,7 +63,8 @@
6263
'ChatCompletionSystemMessageParam',
6364
'ChatCompletionUserMessageParam',
6465
'ChatCompletionAssistantMessageParam',
65-
'ChatCompletionFunctionMessageParam',
66+
'ChatCompletionToolMessageParam',
67+
'ChatCompletionMessageToolCall',
6668
'CompletionUsage',
6769
'OpenAIImageType',
6870
'OpenAIVisionDetailType',

0 commit comments

Comments
 (0)