Skip to content

Commit aba41d7

Browse files
authored
feat: add structured output to model clients (microsoft#5936)
1 parent 9bde5ef commit aba41d7

File tree

27 files changed

+1627
-402
lines changed

27 files changed

+1627
-402
lines changed

python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,7 @@ async def get_messages(self) -> List[LLMMessage]:
574574
"function_calling": False,
575575
"json_output": False,
576576
"family": ModelFamily.R1,
577+
"structured_output": True,
577578
},
578579
)
579580

python/packages/autogen-agentchat/tests/test_assistant_agent.py

Lines changed: 63 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,13 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
6868
"pass",
6969
"TERMINATE",
7070
],
71-
model_info={"function_calling": True, "vision": True, "json_output": True, "family": ModelFamily.GPT_4O},
71+
model_info={
72+
"function_calling": True,
73+
"vision": True,
74+
"json_output": True,
75+
"family": ModelFamily.GPT_4O,
76+
"structured_output": True,
77+
},
7278
)
7379
agent = AssistantAgent(
7480
"tool_use_agent",
@@ -150,7 +156,13 @@ async def test_run_with_tools_and_reflection() -> None:
150156
cached=False,
151157
),
152158
],
153-
model_info={"function_calling": True, "vision": True, "json_output": True, "family": ModelFamily.GPT_4O},
159+
model_info={
160+
"function_calling": True,
161+
"vision": True,
162+
"json_output": True,
163+
"family": ModelFamily.GPT_4O,
164+
"structured_output": True,
165+
},
154166
)
155167
agent = AssistantAgent(
156168
"tool_use_agent",
@@ -236,7 +248,13 @@ async def test_run_with_parallel_tools() -> None:
236248
"pass",
237249
"TERMINATE",
238250
],
239-
model_info={"function_calling": True, "vision": True, "json_output": True, "family": ModelFamily.GPT_4O},
251+
model_info={
252+
"function_calling": True,
253+
"vision": True,
254+
"json_output": True,
255+
"family": ModelFamily.GPT_4O,
256+
"structured_output": True,
257+
},
240258
)
241259
agent = AssistantAgent(
242260
"tool_use_agent",
@@ -315,7 +333,13 @@ async def test_run_with_parallel_tools_with_empty_call_ids() -> None:
315333
"pass",
316334
"TERMINATE",
317335
],
318-
model_info={"function_calling": True, "vision": True, "json_output": True, "family": ModelFamily.GPT_4O},
336+
model_info={
337+
"function_calling": True,
338+
"vision": True,
339+
"json_output": True,
340+
"family": ModelFamily.GPT_4O,
341+
"structured_output": True,
342+
},
319343
)
320344
agent = AssistantAgent(
321345
"tool_use_agent",
@@ -389,7 +413,13 @@ async def test_handoffs() -> None:
389413
cached=False,
390414
)
391415
],
392-
model_info={"function_calling": True, "vision": True, "json_output": True, "family": ModelFamily.GPT_4O},
416+
model_info={
417+
"function_calling": True,
418+
"vision": True,
419+
"json_output": True,
420+
"family": ModelFamily.GPT_4O,
421+
"structured_output": True,
422+
},
393423
)
394424
tool_use_agent = AssistantAgent(
395425
"tool_use_agent",
@@ -447,7 +477,13 @@ async def test_invalid_model_capabilities() -> None:
447477
model_client = OpenAIChatCompletionClient(
448478
model=model,
449479
api_key="",
450-
model_info={"vision": False, "function_calling": False, "json_output": False, "family": ModelFamily.UNKNOWN},
480+
model_info={
481+
"vision": False,
482+
"function_calling": False,
483+
"json_output": False,
484+
"family": ModelFamily.UNKNOWN,
485+
"structured_output": False,
486+
},
451487
)
452488

453489
with pytest.raises(ValueError):
@@ -473,12 +509,24 @@ async def test_remove_images() -> None:
473509
model_client_1 = OpenAIChatCompletionClient(
474510
model=model,
475511
api_key="",
476-
model_info={"vision": False, "function_calling": False, "json_output": False, "family": ModelFamily.UNKNOWN},
512+
model_info={
513+
"vision": False,
514+
"function_calling": False,
515+
"json_output": False,
516+
"family": ModelFamily.UNKNOWN,
517+
"structured_output": False,
518+
},
477519
)
478520
model_client_2 = OpenAIChatCompletionClient(
479521
model=model,
480522
api_key="",
481-
model_info={"vision": True, "function_calling": False, "json_output": False, "family": ModelFamily.UNKNOWN},
523+
model_info={
524+
"vision": True,
525+
"function_calling": False,
526+
"json_output": False,
527+
"family": ModelFamily.UNKNOWN,
528+
"structured_output": False,
529+
},
482530
)
483531

484532
img_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC"
@@ -642,7 +690,13 @@ class BadMemory:
642690
async def test_assistant_agent_declarative() -> None:
643691
model_client = ReplayChatCompletionClient(
644692
["Response to message 3"],
645-
model_info={"function_calling": True, "vision": True, "json_output": True, "family": ModelFamily.GPT_4O},
693+
model_info={
694+
"function_calling": True,
695+
"vision": True,
696+
"json_output": True,
697+
"family": ModelFamily.GPT_4O,
698+
"structured_output": True,
699+
},
646700
)
647701
model_context = BufferedChatCompletionContext(buffer_size=2)
648702
agent = AssistantAgent(

python/packages/autogen-agentchat/tests/test_group_chat.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,13 @@ async def test_round_robin_group_chat_with_tools(runtime: AgentRuntime | None) -
259259
"Hello",
260260
"TERMINATE",
261261
],
262-
model_info={"family": "gpt-4o", "function_calling": True, "json_output": True, "vision": True},
262+
model_info={
263+
"family": "gpt-4o",
264+
"function_calling": True,
265+
"json_output": True,
266+
"vision": True,
267+
"structured_output": True,
268+
},
263269
)
264270
tool = FunctionTool(_pass_function, name="pass", description="pass function")
265271
tool_use_agent = AssistantAgent("tool_use_agent", model_client=model_client, tools=[tool])
@@ -805,7 +811,13 @@ async def test_swarm_handoff_using_tool_calls(runtime: AgentRuntime | None) -> N
805811
"Hello",
806812
"TERMINATE",
807813
],
808-
model_info={"family": "gpt-4o", "function_calling": True, "json_output": True, "vision": True},
814+
model_info={
815+
"family": "gpt-4o",
816+
"function_calling": True,
817+
"json_output": True,
818+
"vision": True,
819+
"structured_output": True,
820+
},
809821
)
810822
agent1 = AssistantAgent(
811823
"agent1",
@@ -889,7 +901,13 @@ async def test_swarm_with_parallel_tool_calls(runtime: AgentRuntime | None) -> N
889901
"Hello",
890902
"TERMINATE",
891903
],
892-
model_info={"family": "gpt-4o", "function_calling": True, "json_output": True, "vision": True},
904+
model_info={
905+
"family": "gpt-4o",
906+
"function_calling": True,
907+
"json_output": True,
908+
"vision": True,
909+
"structured_output": True,
910+
},
893911
)
894912

895913
expected_handoff_context: List[LLMMessage] = [

python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/migration-guide.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ custom_model_client = OpenAIChatCompletionClient(
154154
"function_calling": True,
155155
"json_output": True,
156156
"family": "unknown",
157+
"structured_output": True,
157158
},
158159
)
159160
```

python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/models.ipynb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@
251251
" \"function_calling\": False,\n",
252252
" \"vision\": False,\n",
253253
" \"family\": \"unknown\",\n",
254+
" \"structured_output\": False,\n",
254255
" },\n",
255256
")\n",
256257
"\n",

python/packages/autogen-core/docs/src/user-guide/autogenstudio-user-guide/faq.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ print(anthropic_client.dump_component().model_dump_json())
4848
mistral_vllm_model = OpenAIChatCompletionClient(
4949
model="TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
5050
base_url="http://localhost:1234/v1",
51-
model_info=ModelInfo(vision=False, function_calling=True, json_output=False, family="unknown"),
51+
model_info=ModelInfo(vision=False, function_calling=True, json_output=False, family="unknown", structured_output=True),
5252
)
5353
print(mistral_vllm_model.dump_component().model_dump_json())
5454
```
@@ -122,7 +122,8 @@ Have a local model server like Ollama, vLLM or LMStudio that provide an OpenAI c
122122
"vision": false,
123123
"function_calling": true,
124124
"json_output": false,
125-
"family": "unknown"
125+
"family": "unknown",
126+
"structured_output": true
126127
},
127128
"base_url": "http://localhost:1234/v1"
128129
}

python/packages/autogen-core/src/autogen_core/models/_model_client.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@ class ModelInfo(TypedDict, total=False):
107107
"""True if the model supports json output, otherwise False. Note: this is different to structured json."""
108108
family: Required[ModelFamily.ANY | str]
109109
"""Model family should be one of the constants from :py:class:`ModelFamily` or a string representing an unknown model family."""
110+
structured_output: Required[bool]
111+
"""True if the model supports structured output, otherwise False. This is different to json_output."""
110112

111113

112114
def validate_model_info(model_info: ModelInfo) -> None:
@@ -122,6 +124,15 @@ def validate_model_info(model_info: ModelInfo) -> None:
122124
f"Missing required field '{field}' in ModelInfo. "
123125
"Starting in v0.4.7, the required fields are enforced."
124126
)
127+
new_required_fields = ["structured_output"]
128+
for field in new_required_fields:
129+
if field not in model_info:
130+
warnings.warn(
131+
f"Missing required field '{field}' in ModelInfo. "
132+
"This field will be required in a future version of AutoGen.",
133+
UserWarning,
134+
stacklevel=2,
135+
)
125136

126137

127138
class ChatCompletionClient(ComponentBase[BaseModel], ABC):
@@ -134,10 +145,24 @@ async def create(
134145
tools: Sequence[Tool | ToolSchema] = [],
135146
# None means do not override the default
136147
# A value means to override the client default - often specified in the constructor
137-
json_output: Optional[bool] = None,
148+
json_output: Optional[bool | type[BaseModel]] = None,
138149
extra_create_args: Mapping[str, Any] = {},
139150
cancellation_token: Optional[CancellationToken] = None,
140-
) -> CreateResult: ...
151+
) -> CreateResult:
152+
"""Creates a single response from the model.
153+
154+
Args:
155+
messages (Sequence[LLMMessage]): The messages to send to the model.
156+
tools (Sequence[Tool | ToolSchema], optional): The tools to use with the model. Defaults to [].
157+
json_output (Optional[bool | type[BaseModel]], optional): Whether to use JSON mode, structured output, or neither. Defaults to None. If set to a type, it will be used as the output type
158+
for structured output. If set to a boolean, it will be used to determine whether to use JSON mode or not.
159+
extra_create_args (Mapping[str, Any], optional): Extra arguments to pass to the underlying client. Defaults to {}.
160+
cancellation_token (Optional[CancellationToken], optional): A token for cancellation. Defaults to None.
161+
162+
Returns:
163+
CreateResult: The result of the model call.
164+
"""
165+
...
141166

142167
@abstractmethod
143168
def create_stream(
@@ -147,10 +172,24 @@ def create_stream(
147172
tools: Sequence[Tool | ToolSchema] = [],
148173
# None means do not override the default
149174
# A value means to override the client default - often specified in the constructor
150-
json_output: Optional[bool] = None,
175+
json_output: Optional[bool | type[BaseModel]] = None,
151176
extra_create_args: Mapping[str, Any] = {},
152177
cancellation_token: Optional[CancellationToken] = None,
153-
) -> AsyncGenerator[Union[str, CreateResult], None]: ...
178+
) -> AsyncGenerator[Union[str, CreateResult], None]:
179+
"""Creates a stream of string chunks from the model ending with a CreateResult.
180+
181+
Args:
182+
messages (Sequence[LLMMessage]): The messages to send to the model.
183+
tools (Sequence[Tool | ToolSchema], optional): The tools to use with the model. Defaults to [].
184+
json_output (Optional[bool | type[BaseModel]], optional): Whether to use JSON mode, structured output, or neither. Defaults to None. If set to a type, it will be used as the output type
185+
for structured output. If set to a boolean, it will be used to determine whether to use JSON mode or not.
186+
extra_create_args (Mapping[str, Any], optional): Extra arguments to pass to the underlying client. Defaults to {}.
187+
cancellation_token (Optional[CancellationToken], optional): A token for cancellation. Defaults to None.
188+
189+
Returns:
190+
AsyncGenerator[Union[str, CreateResult], None]: A generator that yields string chunks and ends with a :py:class:`CreateResult`.
191+
"""
192+
...
154193

155194
@abstractmethod
156195
async def close(self) -> None: ...

python/packages/autogen-core/tests/test_models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ def test_model_info() -> None:
99
"vision": True,
1010
"function_calling": True,
1111
"json_output": True,
12+
"structured_output": True,
1213
}
1314
validate_model_info(info)
1415

python/packages/autogen-core/tests/test_tool_agent.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
tool_agent_caller_loop,
2626
)
2727
from autogen_core.tools import FunctionTool, Tool, ToolSchema
28+
from pydantic import BaseModel
2829

2930
logging.getLogger(EVENT_LOGGER_NAME).setLevel(logging.INFO)
3031

@@ -101,7 +102,7 @@ async def create(
101102
messages: Sequence[LLMMessage],
102103
*,
103104
tools: Sequence[Tool | ToolSchema] = [],
104-
json_output: Optional[bool] = None,
105+
json_output: Optional[bool | type[BaseModel]] = None,
105106
extra_create_args: Mapping[str, Any] = {},
106107
cancellation_token: Optional[CancellationToken] = None,
107108
) -> CreateResult:
@@ -126,7 +127,7 @@ def create_stream(
126127
messages: Sequence[LLMMessage],
127128
*,
128129
tools: Sequence[Tool | ToolSchema] = [],
129-
json_output: Optional[bool] = None,
130+
json_output: Optional[bool | type[BaseModel]] = None,
130131
extra_create_args: Mapping[str, Any] = {},
131132
cancellation_token: Optional[CancellationToken] = None,
132133
) -> AsyncGenerator[Union[str, CreateResult], None]:
@@ -153,7 +154,13 @@ def capabilities(self) -> ModelCapabilities: # type: ignore
153154

154155
@property
155156
def model_info(self) -> ModelInfo:
156-
return ModelInfo(vision=False, function_calling=True, json_output=False, family=ModelFamily.UNKNOWN)
157+
return ModelInfo(
158+
vision=False,
159+
function_calling=True,
160+
json_output=False,
161+
family=ModelFamily.UNKNOWN,
162+
structured_output=False,
163+
)
157164

158165
client = MockChatCompletionClient()
159166
tools: List[Tool] = [FunctionTool(_pass_function, name="pass", description="Pass function")]

python/packages/autogen-ext/src/autogen_ext/experimental/task_centric_memory/utils/chat_completion_client_recorder.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
RequestUsage,
1414
)
1515
from autogen_core.tools import Tool, ToolSchema
16+
from pydantic import BaseModel
1617

1718
from .page_logger import PageLogger
1819

@@ -87,7 +88,7 @@ async def create(
8788
messages: Sequence[LLMMessage],
8889
*,
8990
tools: Sequence[Tool | ToolSchema] = [],
90-
json_output: Optional[bool] = None,
91+
json_output: Optional[bool | type[BaseModel]] = None,
9192
extra_create_args: Mapping[str, Any] = {},
9293
cancellation_token: Optional[CancellationToken] = None,
9394
) -> CreateResult:
@@ -154,7 +155,7 @@ def create_stream(
154155
messages: Sequence[LLMMessage],
155156
*,
156157
tools: Sequence[Tool | ToolSchema] = [],
157-
json_output: Optional[bool] = None,
158+
json_output: Optional[bool | type[BaseModel]] = None,
158159
extra_create_args: Mapping[str, Any] = {},
159160
cancellation_token: Optional[CancellationToken] = None,
160161
) -> AsyncGenerator[Union[str, CreateResult], None]:

0 commit comments

Comments
 (0)