Skip to content

Commit a06d98e

Browse files
Li YinLi Yin
authored andcommitted
added thinking model support which directly use thinking model's reasoning instead of cot
1 parent 664ee01 commit a06d98e

4 files changed

Lines changed: 94 additions & 80 deletions

File tree

adalflow/adalflow/components/agent/agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ def __init__(
260260
# default agent parameters
261261
answer_data_type: Optional[Type[T]] = str, # the data type of the final answer
262262
max_steps: Optional[int] = 10,
263-
is_thinking_model: Optional[bool] = False, # support thinking model in agent
263+
is_thinking_model: Optional[bool] = False, # when thinking model turned on, it disables the CoT field in the output
264264
# for fully customize the agent
265265
tool_manager: Optional[
266266
ToolManager

adalflow/adalflow/components/agent/runner.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,9 @@ def __init__(
154154
self._current_task = None # Track the current running task
155155
self._current_streaming_result = None # Track the current streaming result
156156

157+
# support thinking model
158+
self.is_thinking_model = agent.is_thinking_model if hasattr(agent, 'is_thinking_model') else False
159+
157160
def _init_permission_manager(self):
158161
"""Initialize the permission manager and register tools that require approval."""
159162
if self.permission_manager and hasattr(self.agent, "tool_manager"):
@@ -546,8 +549,11 @@ def call(
546549
break
547550

548551
function = output.data
552+
thinking = output.thinking if hasattr(output, 'thinking') else None
549553
if function is not None:
550-
function.id = str(uuid.uuid4()) # add function id
554+
function.id = str(uuid.uuid4()) # add function id
555+
if thinking is not None and self.is_thinking_model:
556+
function.thought = thinking
551557
printc(f"function: {function}", color="yellow")
552558
if function is None:
553559
error_msg = output.error
@@ -832,9 +838,12 @@ async def acall(
832838
break
833839

834840
function = output.data
841+
thinking = output.thinking if hasattr(output, 'thinking') else None
835842
if function is not None:
836843
# add a function id
837844
function.id = str(uuid.uuid4())
845+
if thinking is not None and self.is_thinking_model:
846+
function.thought = thinking
838847
printc(f"function: {function}", color="yellow")
839848

840849
if self._check_last_step(function):
@@ -1175,6 +1184,7 @@ async def impl_astream(
11751184
# handle function output
11761185

11771186
function = output.data # here are the recoverable errors, should continue to step output
1187+
thinking = output.thinking # check the reasoning model response
11781188
function.id = str(uuid.uuid4()) # add function id
11791189
function_result = None
11801190
function_output_observation = None
@@ -1201,6 +1211,9 @@ async def impl_astream(
12011211
# for normal function
12021212
function.id = str(uuid.uuid4())
12031213

1214+
if thinking is not None and self.is_thinking_model:
1215+
function.thought = thinking
1216+
12041217
# TODO: simplify this
12051218
tool_call_id = function.id
12061219
tool_call_name = function.name

0 commit comments

Comments
 (0)