Skip to content

Commit d15a80e

Browse files
authored
enhance: gemini stream mode support (#3443)
1 parent 6c2116b commit d15a80e

File tree

10 files changed

+451
-197
lines changed

10 files changed

+451
-197
lines changed

.github/ISSUE_TEMPLATE/bug_report.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ body:
2626
attributes:
2727
label: What version of camel are you using?
2828
description: Run command `python3 -c 'print(__import__("camel").__version__)'` in your shell and paste the output here.
29-
placeholder: E.g., 0.2.80a2
29+
placeholder: E.g., 0.2.80a3
3030
validations:
3131
required: true
3232

camel/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from camel.logger import disable_logging, enable_logging, set_log_level
1616

17-
__version__ = '0.2.80a2'
17+
__version__ = '0.2.80a3'
1818

1919
__all__ = [
2020
'__version__',

camel/agents/_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class ToolCallRequest(BaseModel):
2727
tool_name: str
2828
args: Dict[str, Any]
2929
tool_call_id: str
30+
extra_content: Optional[Dict[str, Any]] = None
3031

3132

3233
class ModelResponse(BaseModel):

camel/agents/chat_agent.py

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3692,8 +3692,13 @@ def _handle_batch_response(
36923692
tool_name = tool_call.function.name # type: ignore[union-attr]
36933693
tool_call_id = tool_call.id
36943694
args = json.loads(tool_call.function.arguments) # type: ignore[union-attr]
3695+
extra_content = getattr(tool_call, 'extra_content', None)
3696+
36953697
tool_call_request = ToolCallRequest(
3696-
tool_name=tool_name, args=args, tool_call_id=tool_call_id
3698+
tool_name=tool_name,
3699+
args=args,
3700+
tool_call_id=tool_call_id,
3701+
extra_content=extra_content,
36973702
)
36983703
tool_call_requests.append(tool_call_request)
36993704

@@ -3786,7 +3791,12 @@ def _execute_tool(
37863791
logger.warning(f"{error_msg} with result: {result}")
37873792

37883793
return self._record_tool_calling(
3789-
func_name, args, result, tool_call_id, mask_output=mask_flag
3794+
func_name,
3795+
args,
3796+
result,
3797+
tool_call_id,
3798+
mask_output=mask_flag,
3799+
extra_content=tool_call_request.extra_content,
37903800
)
37913801

37923802
async def _aexecute_tool(
@@ -3828,7 +3838,13 @@ async def _aexecute_tool(
38283838
error_msg = f"Error executing async tool '{func_name}': {e!s}"
38293839
result = f"Tool execution failed: {error_msg}"
38303840
logger.warning(error_msg)
3831-
return self._record_tool_calling(func_name, args, result, tool_call_id)
3841+
return self._record_tool_calling(
3842+
func_name,
3843+
args,
3844+
result,
3845+
tool_call_id,
3846+
extra_content=tool_call_request.extra_content,
3847+
)
38323848

38333849
def _record_tool_calling(
38343850
self,
@@ -3837,6 +3853,7 @@ def _record_tool_calling(
38373853
result: Any,
38383854
tool_call_id: str,
38393855
mask_output: bool = False,
3856+
extra_content: Optional[Dict[str, Any]] = None,
38403857
):
38413858
r"""Record the tool calling information in the memory, and return the
38423859
tool calling record.
@@ -3849,6 +3866,9 @@ def _record_tool_calling(
38493866
mask_output (bool, optional): Whether to return a sanitized
38503867
placeholder instead of the raw tool output.
38513868
(default: :obj:`False`)
3869+
extra_content (Optional[Dict[str, Any]], optional): Additional
3870+
content associated with the tool call.
3871+
(default: :obj:`None`)
38523872
38533873
Returns:
38543874
ToolCallingRecord: A struct containing information about
@@ -3862,6 +3882,7 @@ def _record_tool_calling(
38623882
func_name=func_name,
38633883
args=args,
38643884
tool_call_id=tool_call_id,
3885+
extra_content=extra_content,
38653886
)
38663887
func_msg = FunctionCallingMessage(
38673888
role_name=self.role_name,
@@ -3872,6 +3893,7 @@ def _record_tool_calling(
38723893
result=result,
38733894
tool_call_id=tool_call_id,
38743895
mask_output=mask_output,
3896+
extra_content=extra_content,
38753897
)
38763898

38773899
# Use precise timestamps to ensure correct ordering
@@ -4006,7 +4028,7 @@ def _stream_response(
40064028
return
40074029

40084030
# Handle streaming response
4009-
if isinstance(response, Stream):
4031+
if isinstance(response, Stream) or inspect.isgenerator(response):
40104032
(
40114033
stream_completed,
40124034
tool_calls_complete,
@@ -4346,6 +4368,7 @@ def _accumulate_tool_calls(
43464368
'id': '',
43474369
'type': 'function',
43484370
'function': {'name': '', 'arguments': ''},
4371+
'extra_content': None,
43494372
'complete': False,
43504373
}
43514374

@@ -4369,6 +4392,14 @@ def _accumulate_tool_calls(
43694392
tool_call_entry['function']['arguments'] += (
43704393
delta_tool_call.function.arguments
43714394
)
4395+
# Handle extra_content if present
4396+
if (
4397+
hasattr(delta_tool_call, 'extra_content')
4398+
and delta_tool_call.extra_content
4399+
):
4400+
tool_call_entry['extra_content'] = (
4401+
delta_tool_call.extra_content
4402+
)
43724403

43734404
# Check if any tool calls are complete
43744405
any_complete = False
@@ -4473,6 +4504,7 @@ def _execute_tool_from_stream_data(
44734504
function_name = tool_call_data['function']['name']
44744505
args = json.loads(tool_call_data['function']['arguments'])
44754506
tool_call_id = tool_call_data['id']
4507+
extra_content = tool_call_data.get('extra_content')
44764508

44774509
if function_name in self._internal_tools:
44784510
tool = self._internal_tools[function_name]
@@ -4488,6 +4520,7 @@ def _execute_tool_from_stream_data(
44884520
func_name=function_name,
44894521
args=args,
44904522
tool_call_id=tool_call_id,
4523+
extra_content=extra_content,
44914524
)
44924525

44934526
# Then create the tool response message
@@ -4499,6 +4532,7 @@ def _execute_tool_from_stream_data(
44994532
func_name=function_name,
45004533
result=result,
45014534
tool_call_id=tool_call_id,
4535+
extra_content=extra_content,
45024536
)
45034537

45044538
# Record both messages with precise timestamps to ensure
@@ -4544,6 +4578,7 @@ def _execute_tool_from_stream_data(
45444578
func_name=function_name,
45454579
result=result,
45464580
tool_call_id=tool_call_id,
4581+
extra_content=extra_content,
45474582
)
45484583

45494584
self.update_memory(func_msg, OpenAIBackendRole.FUNCTION)
@@ -4575,6 +4610,7 @@ async def _aexecute_tool_from_stream_data(
45754610
function_name = tool_call_data['function']['name']
45764611
args = json.loads(tool_call_data['function']['arguments'])
45774612
tool_call_id = tool_call_data['id']
4613+
extra_content = tool_call_data.get('extra_content')
45784614

45794615
if function_name in self._internal_tools:
45804616
# Create the tool call message
@@ -4586,6 +4622,7 @@ async def _aexecute_tool_from_stream_data(
45864622
func_name=function_name,
45874623
args=args,
45884624
tool_call_id=tool_call_id,
4625+
extra_content=extra_content,
45894626
)
45904627
assist_ts = time.time_ns() / 1_000_000_000
45914628
self.update_memory(
@@ -4632,6 +4669,7 @@ async def _aexecute_tool_from_stream_data(
46324669
func_name=function_name,
46334670
result=result,
46344671
tool_call_id=tool_call_id,
4672+
extra_content=extra_content,
46354673
)
46364674
func_ts = time.time_ns() / 1_000_000_000
46374675
self.update_memory(
@@ -4665,6 +4703,7 @@ async def _aexecute_tool_from_stream_data(
46654703
func_name=function_name,
46664704
result=result,
46674705
tool_call_id=tool_call_id,
4706+
extra_content=extra_content,
46684707
)
46694708
func_ts = time.time_ns() / 1_000_000_000
46704709
self.update_memory(
@@ -4979,6 +5018,11 @@ def _record_assistant_tool_calls_message(
49795018
"arguments": tool_call_data["function"]["arguments"],
49805019
},
49815020
}
5021+
# Include extra_content if present
5022+
if tool_call_data.get('extra_content'):
5023+
tool_call_dict["extra_content"] = tool_call_data[
5024+
"extra_content"
5025+
]
49825026
tool_calls_list.append(tool_call_dict)
49835027

49845028
# Create an assistant message with tool calls

camel/messages/func_message.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,17 @@ class FunctionCallingMessage(BaseMessage):
5050
mask_output (Optional[bool]): Whether to return a sanitized placeholder
5151
instead of the raw tool output.
5252
(default: :obj:`False`)
53+
extra_content (Optional[Dict[str, Any]]): Additional content
54+
associated with the tool call.
55+
(default: :obj:`None`)
5356
"""
5457

5558
func_name: Optional[str] = None
5659
args: Optional[Dict] = None
5760
result: Optional[Any] = None
5861
tool_call_id: Optional[str] = None
5962
mask_output: Optional[bool] = False
63+
extra_content: Optional[Dict[str, Any]] = None
6064

6165
def to_openai_message(
6266
self,
@@ -131,19 +135,23 @@ def to_openai_assistant_message(self) -> OpenAIAssistantMessage:
131135
" due to missing function name or arguments."
132136
)
133137

138+
tool_call = {
139+
"id": self.tool_call_id or "null",
140+
"type": "function",
141+
"function": {
142+
"name": self.func_name,
143+
"arguments": json.dumps(self.args, ensure_ascii=False),
144+
},
145+
}
146+
147+
# Include extra_content if available
148+
if self.extra_content is not None:
149+
tool_call["extra_content"] = self.extra_content
150+
134151
return {
135152
"role": "assistant",
136153
"content": self.content or "",
137-
"tool_calls": [
138-
{
139-
"id": self.tool_call_id or "null",
140-
"type": "function",
141-
"function": {
142-
"name": self.func_name,
143-
"arguments": json.dumps(self.args, ensure_ascii=False),
144-
},
145-
}
146-
],
154+
"tool_calls": [tool_call], # type: ignore[list-item]
147155
}
148156

149157
def to_openai_tool_message(self) -> OpenAIToolMessageParam:
@@ -187,4 +195,6 @@ def to_dict(self) -> Dict:
187195
if self.tool_call_id is not None:
188196
base["tool_call_id"] = self.tool_call_id
189197
base["mask_output"] = self.mask_output
198+
if self.extra_content is not None:
199+
base["extra_content"] = self.extra_content
190200
return base

0 commit comments

Comments
 (0)