diff --git a/src/openai/_models.py b/src/openai/_models.py index 5148d5a7b3..2d4bd2d13e 100644 --- a/src/openai/_models.py +++ b/src/openai/_models.py @@ -46,6 +46,7 @@ extract_type_arg, is_annotated_type, strip_annotated_type, + wrap_in_annotated_type, ) from ._compat import ( PYDANTIC_V2, @@ -356,7 +357,10 @@ def _construct_field(value: object, field: FieldInfo, key: str) -> object: return field_get_default(field) if PYDANTIC_V2: - type_ = field.annotation + if field.metadata: + type_ = wrap_in_annotated_type(field) + else: + type_ = field.annotation else: type_ = cast(type, field.outer_type_) # type: ignore @@ -609,8 +613,13 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, # Note: if one variant defines an alias then they all should discriminator_alias = field_info.alias - if field_info.annotation and is_literal_type(field_info.annotation): - for entry in get_args(field_info.annotation): + if hasattr(field_info, "annotation"): + field_annotation = cast(type, field_info.annotation) + else: + # pydantic==1.9 + field_annotation = cast(type, field_info.outer_type_) # type: ignore + if field_annotation and is_literal_type(field_annotation): + for entry in get_args(field_annotation): if isinstance(entry, str): mapping[entry] = variant diff --git a/src/openai/_utils/__init__.py b/src/openai/_utils/__init__.py index 3efe66c8e8..b0163e80e7 100644 --- a/src/openai/_utils/__init__.py +++ b/src/openai/_utils/__init__.py @@ -39,6 +39,7 @@ is_required_type as is_required_type, is_annotated_type as is_annotated_type, strip_annotated_type as strip_annotated_type, + wrap_in_annotated_type as wrap_in_annotated_type, extract_type_var_from_base as extract_type_var_from_base, ) from ._streams import consume_sync_iterator as consume_sync_iterator, consume_async_iterator as consume_async_iterator diff --git a/src/openai/_utils/_typing.py b/src/openai/_utils/_typing.py index c036991f04..fd1c8c849c 100644 --- a/src/openai/_utils/_typing.py +++ b/src/openai/_utils/_typing.py @@ -4,6 +4,8 @@ from collections import abc as _c_abc from typing_extensions import Required, Annotated, get_args, get_origin +from pydantic.fields import FieldInfo + from .._types import InheritsGeneric from .._compat import is_union as _is_union @@ -44,6 +46,10 @@ def strip_annotated_type(typ: type) -> type: return typ +def wrap_in_annotated_type(typ: FieldInfo) -> object: + return Annotated[cast(type, typ.annotation), typ.metadata[0]] + + def extract_type_arg(typ: type, index: int) -> type: args = get_args(typ) try: diff --git a/src/openai/types/beta/__init__.py b/src/openai/types/beta/__init__.py index 9c5ddfdbe0..368db4bdff 100644 --- a/src/openai/types/beta/__init__.py +++ b/src/openai/types/beta/__init__.py @@ -6,7 +6,7 @@ from .assistant import Assistant as Assistant from .vector_store import VectorStore as VectorStore from .function_tool import FunctionTool as FunctionTool -from .assistant_tool import AssistantTool as AssistantTool +from .assistant_tool import BaseTool as BaseTool, AssistantTool as AssistantTool from .thread_deleted import ThreadDeleted as ThreadDeleted from .file_search_tool import FileSearchTool as FileSearchTool from .assistant_deleted import AssistantDeleted as AssistantDeleted diff --git a/src/openai/types/beta/assistant_tool.py b/src/openai/types/beta/assistant_tool.py index 1bde6858b1..831d0b36a2 100644 --- a/src/openai/types/beta/assistant_tool.py +++ b/src/openai/types/beta/assistant_tool.py @@ -1,15 +1,33 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from typing import Union -from typing_extensions import Annotated, TypeAlias +from typing_extensions import Literal, Annotated, TypeAlias from ..._utils import PropertyInfo +from ..._compat import PYDANTIC_V2 +from ..._models import BaseModel from .function_tool import FunctionTool from .file_search_tool import FileSearchTool from .code_interpreter_tool import CodeInterpreterTool -__all__ = ["AssistantTool"] +if PYDANTIC_V2: + from pydantic import field_serializer + + +__all__ = ["AssistantTool", "BaseTool"] + + +class BaseTool(BaseModel): + type: Literal["unknown"] + """A tool type""" + + if PYDANTIC_V2: + + @field_serializer("type", when_used="always") # type: ignore + def serialize_unknown_type(self, type_: str) -> str: + return type_ + AssistantTool: TypeAlias = Annotated[ - Union[CodeInterpreterTool, FileSearchTool, FunctionTool], PropertyInfo(discriminator="type") + Union[BaseTool, CodeInterpreterTool, FileSearchTool, FunctionTool], PropertyInfo(discriminator="type") ] diff --git a/src/openai/types/beta/threads/__init__.py b/src/openai/types/beta/threads/__init__.py index 70853177bd..642e3a6f39 100644 --- a/src/openai/types/beta/threads/__init__.py +++ b/src/openai/types/beta/threads/__init__.py @@ -6,17 +6,17 @@ from .text import Text as Text from .message import Message as Message from .image_url import ImageURL as ImageURL -from .annotation import Annotation as Annotation +from .annotation import Annotation as Annotation, BaseAnnotation as BaseAnnotation from .image_file import ImageFile as ImageFile from .run_status import RunStatus as RunStatus from .text_delta import TextDelta as TextDelta from .message_delta import MessageDelta as MessageDelta from .image_url_delta import ImageURLDelta as ImageURLDelta from .image_url_param import ImageURLParam as ImageURLParam -from .message_content import MessageContent as MessageContent +from .message_content import MessageContent as MessageContent, BaseContentBlock as BaseContentBlock from .message_deleted import MessageDeleted as MessageDeleted from .run_list_params import RunListParams as RunListParams -from .annotation_delta import AnnotationDelta as AnnotationDelta +from .annotation_delta import AnnotationDelta as AnnotationDelta, BaseDeltaAnnotation as BaseDeltaAnnotation from .image_file_delta import ImageFileDelta as ImageFileDelta from .image_file_param import ImageFileParam as ImageFileParam from .text_delta_block import TextDeltaBlock as TextDeltaBlock @@ -28,7 +28,7 @@ from .refusal_delta_block import RefusalDeltaBlock as RefusalDeltaBlock from .file_path_annotation import FilePathAnnotation as FilePathAnnotation from .image_url_delta_block import ImageURLDeltaBlock as ImageURLDeltaBlock -from .message_content_delta import MessageContentDelta as MessageContentDelta +from .message_content_delta import BaseDeltaBlock as BaseDeltaBlock, MessageContentDelta as MessageContentDelta from .message_create_params import MessageCreateParams as MessageCreateParams from .message_update_params import MessageUpdateParams as MessageUpdateParams from .refusal_content_block import RefusalContentBlock as RefusalContentBlock diff --git a/src/openai/types/beta/threads/annotation.py b/src/openai/types/beta/threads/annotation.py index 13c10abf4d..b01f2eb687 100644 --- a/src/openai/types/beta/threads/annotation.py +++ b/src/openai/types/beta/threads/annotation.py @@ -1,12 +1,34 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from typing import Union -from typing_extensions import Annotated, TypeAlias +from typing_extensions import Literal, Annotated, TypeAlias from ...._utils import PropertyInfo +from ...._compat import PYDANTIC_V2 +from ...._models import BaseModel from .file_path_annotation import FilePathAnnotation from .file_citation_annotation import FileCitationAnnotation -__all__ = ["Annotation"] +if PYDANTIC_V2: + from pydantic import field_serializer -Annotation: TypeAlias = Annotated[Union[FileCitationAnnotation, FilePathAnnotation], PropertyInfo(discriminator="type")] +__all__ = ["Annotation", "BaseAnnotation"] + + +class BaseAnnotation(BaseModel): + text: str + """The index of the annotation in the text content part.""" + + type: Literal["unknown"] + """The type of annotation""" + + if PYDANTIC_V2: + + @field_serializer("type", when_used="always") # type: ignore + def serialize_unknown_type(self, type_: str) -> str: + return type_ + + +Annotation: TypeAlias = Annotated[ + Union[BaseAnnotation, FileCitationAnnotation, FilePathAnnotation], PropertyInfo(discriminator="type") +] diff --git a/src/openai/types/beta/threads/annotation_delta.py b/src/openai/types/beta/threads/annotation_delta.py index c7c6c89837..8ef1e5d72d 100644 --- a/src/openai/types/beta/threads/annotation_delta.py +++ b/src/openai/types/beta/threads/annotation_delta.py @@ -1,14 +1,35 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from typing import Union -from typing_extensions import Annotated, TypeAlias +from typing_extensions import Literal, Annotated, TypeAlias from ...._utils import PropertyInfo +from ...._compat import PYDANTIC_V2 +from ...._models import BaseModel from .file_path_delta_annotation import FilePathDeltaAnnotation from .file_citation_delta_annotation import FileCitationDeltaAnnotation -__all__ = ["AnnotationDelta"] +if PYDANTIC_V2: + from pydantic import field_serializer + + +__all__ = ["AnnotationDelta", "BaseDeltaAnnotation"] + + +class BaseDeltaAnnotation(BaseModel): + index: int + """The index of the annotation in the text content part.""" + + type: Literal["unknown"] + """The type of annotation""" + + if PYDANTIC_V2: + + @field_serializer("type", when_used="always") # type: ignore + def serialize_unknown_type(self, type_: str) -> str: + return type_ + AnnotationDelta: TypeAlias = Annotated[ - Union[FileCitationDeltaAnnotation, FilePathDeltaAnnotation], PropertyInfo(discriminator="type") + Union[BaseDeltaAnnotation, FileCitationDeltaAnnotation, FilePathDeltaAnnotation], PropertyInfo(discriminator="type") ] diff --git a/src/openai/types/beta/threads/message_content.py b/src/openai/types/beta/threads/message_content.py index 9523c1e1b9..0156f76123 100644 --- a/src/openai/types/beta/threads/message_content.py +++ b/src/openai/types/beta/threads/message_content.py @@ -1,18 +1,35 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from typing import Union -from typing_extensions import Annotated, TypeAlias +from typing_extensions import Literal, Annotated, TypeAlias from ...._utils import PropertyInfo +from ...._compat import PYDANTIC_V2 +from ...._models import BaseModel from .text_content_block import TextContentBlock from .refusal_content_block import RefusalContentBlock from .image_url_content_block import ImageURLContentBlock from .image_file_content_block import ImageFileContentBlock -__all__ = ["MessageContent"] +if PYDANTIC_V2: + from pydantic import field_serializer + + +__all__ = ["MessageContent", "BaseContentBlock"] + + +class BaseContentBlock(BaseModel): + type: Literal["unknown"] + """The type of content part""" + + if PYDANTIC_V2: + + @field_serializer("type", when_used="always") # type: ignore + def serialize_unknown_type(self, type_: str) -> str: + return type_ MessageContent: TypeAlias = Annotated[ - Union[ImageFileContentBlock, ImageURLContentBlock, TextContentBlock, RefusalContentBlock], + Union[BaseContentBlock, ImageFileContentBlock, ImageURLContentBlock, TextContentBlock, RefusalContentBlock], PropertyInfo(discriminator="type"), ] diff --git a/src/openai/types/beta/threads/message_content_delta.py b/src/openai/types/beta/threads/message_content_delta.py index b6e7dfa45a..3095dcff72 100644 --- a/src/openai/types/beta/threads/message_content_delta.py +++ b/src/openai/types/beta/threads/message_content_delta.py @@ -1,17 +1,38 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from typing import Union -from typing_extensions import Annotated, TypeAlias +from typing_extensions import Literal, Annotated, TypeAlias from ...._utils import PropertyInfo +from ...._compat import PYDANTIC_V2 +from ...._models import BaseModel from .text_delta_block import TextDeltaBlock from .refusal_delta_block import RefusalDeltaBlock from .image_url_delta_block import ImageURLDeltaBlock from .image_file_delta_block import ImageFileDeltaBlock -__all__ = ["MessageContentDelta"] +if PYDANTIC_V2: + from pydantic import field_serializer + + +__all__ = ["MessageContentDelta", "BaseDeltaBlock"] + + +class BaseDeltaBlock(BaseModel): + index: int + """The index of the content part in the message.""" + + type: Literal["unknown"] + """The type of content part""" + + if PYDANTIC_V2: + + @field_serializer("type", when_used="always") # type: ignore + def serialize_unknown_type(self, type_: str) -> str: + return type_ + MessageContentDelta: TypeAlias = Annotated[ - Union[ImageFileDeltaBlock, TextDeltaBlock, RefusalDeltaBlock, ImageURLDeltaBlock], + Union[BaseDeltaBlock, ImageFileDeltaBlock, TextDeltaBlock, RefusalDeltaBlock, ImageURLDeltaBlock], PropertyInfo(discriminator="type"), ] diff --git a/src/openai/types/beta/threads/runs/__init__.py b/src/openai/types/beta/threads/runs/__init__.py index a312ce3df2..402d439947 100644 --- a/src/openai/types/beta/threads/runs/__init__.py +++ b/src/openai/types/beta/threads/runs/__init__.py @@ -3,9 +3,9 @@ from __future__ import annotations from .run_step import RunStep as RunStep -from .tool_call import ToolCall as ToolCall +from .tool_call import ToolCall as ToolCall, BaseToolCall as BaseToolCall from .run_step_delta import RunStepDelta as RunStepDelta -from .tool_call_delta import ToolCallDelta as ToolCallDelta +from .tool_call_delta import ToolCallDelta as ToolCallDelta, BaseToolCallDelta as BaseToolCallDelta from .step_list_params import StepListParams as StepListParams from .function_tool_call import FunctionToolCall as FunctionToolCall from .run_step_delta_event import RunStepDeltaEvent as RunStepDeltaEvent diff --git a/src/openai/types/beta/threads/runs/tool_call.py b/src/openai/types/beta/threads/runs/tool_call.py index 565e3109be..ec35747790 100644 --- a/src/openai/types/beta/threads/runs/tool_call.py +++ b/src/openai/types/beta/threads/runs/tool_call.py @@ -1,15 +1,38 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from typing import Union -from typing_extensions import Annotated, TypeAlias +from typing_extensions import Literal, Annotated, TypeAlias from ....._utils import PropertyInfo +from ....._compat import PYDANTIC_V2 +from ....._models import BaseModel from .function_tool_call import FunctionToolCall from .file_search_tool_call import FileSearchToolCall from .code_interpreter_tool_call import CodeInterpreterToolCall -__all__ = ["ToolCall"] +if PYDANTIC_V2: + from pydantic import field_serializer + + +__all__ = ["ToolCall", "BaseToolCall"] + + +class BaseToolCall(BaseModel): + id: str + """The ID of the tool call.""" + + type: Literal["unknown"] + """The type of tool call. + """ + + if PYDANTIC_V2: + + @field_serializer("type", when_used="always") # type: ignore + def serialize_unknown_type(self, type_: str) -> str: + return type_ + ToolCall: TypeAlias = Annotated[ - Union[CodeInterpreterToolCall, FileSearchToolCall, FunctionToolCall], PropertyInfo(discriminator="type") + Union[BaseToolCall, CodeInterpreterToolCall, FileSearchToolCall, FunctionToolCall], + PropertyInfo(discriminator="type"), ] diff --git a/src/openai/types/beta/threads/runs/tool_call_delta.py b/src/openai/types/beta/threads/runs/tool_call_delta.py index f0b8070c97..6ed45bcb69 100644 --- a/src/openai/types/beta/threads/runs/tool_call_delta.py +++ b/src/openai/types/beta/threads/runs/tool_call_delta.py @@ -1,16 +1,38 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from typing import Union -from typing_extensions import Annotated, TypeAlias +from typing_extensions import Literal, Annotated, TypeAlias from ....._utils import PropertyInfo +from ....._compat import PYDANTIC_V2 +from ....._models import BaseModel from .function_tool_call_delta import FunctionToolCallDelta from .file_search_tool_call_delta import FileSearchToolCallDelta from .code_interpreter_tool_call_delta import CodeInterpreterToolCallDelta -__all__ = ["ToolCallDelta"] +if PYDANTIC_V2: + from pydantic import field_serializer + + +__all__ = ["ToolCallDelta", "BaseToolCallDelta"] + + +class BaseToolCallDelta(BaseModel): + index: int + """The index of the tool call in the tool calls array.""" + + type: Literal["unknown"] + """The type of tool call. + """ + + if PYDANTIC_V2: + + @field_serializer("type", when_used="always") # type: ignore + def serialize_unknown_type(self, type_: str) -> str: + return type_ + ToolCallDelta: TypeAlias = Annotated[ - Union[CodeInterpreterToolCallDelta, FileSearchToolCallDelta, FunctionToolCallDelta], + Union[BaseToolCallDelta, CodeInterpreterToolCallDelta, FileSearchToolCallDelta, FunctionToolCallDelta], PropertyInfo(discriminator="type"), ] diff --git a/tests/lib/test_assistants.py b/tests/lib/test_assistants.py index 67d021ec35..a1c1d995aa 100644 --- a/tests/lib/test_assistants.py +++ b/tests/lib/test_assistants.py @@ -1,9 +1,25 @@ from __future__ import annotations +from typing import Any + import pytest from openai import OpenAI, AsyncOpenAI from openai._utils import assert_signatures_in_sync +from openai._models import construct_type +from openai.types.beta import BaseTool +from openai.types.beta.threads import ( + Text, + Message, + TextDelta, + MessageDelta, + BaseAnnotation, + BaseDeltaBlock, + BaseContentBlock, + BaseDeltaAnnotation, +) +from openai.types.beta.assistant import Assistant +from openai.types.beta.threads.runs import RunStep, BaseToolCall, RunStepDelta, BaseToolCallDelta @pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) @@ -48,3 +64,170 @@ def test_create_and_poll_method_definition_in_sync(sync: bool, client: OpenAI, a checking_client.beta.threads.runs.create_and_poll, exclude_params={"stream"}, ) + + +def test_assistants_unknown_create_tool_response() -> None: + response: dict[str, Any] = { + "id": "asst_xxx", + "created_at": 1722882650, + "description": None, + "instructions": "", + "metadata": {}, + "model": "gpt-4", + "name": "xxx", + "object": "assistant", + "tools": [{"type": "tool_unknown", "unknown": {}}], + "response_format": "auto", + "temperature": 1.0, + "tool_resources": {}, + "top_p": 1.0, + } + assistant = construct_type(type_=Assistant, value=response) + assert isinstance(assistant, Assistant) + assert isinstance(assistant.tools[0], BaseTool) + assert assistant.tools[0].type == "tool_unknown" # type: ignore[comparison-overlap] + a = assistant.model_dump() # type: ignore[unreachable] + assert a["tools"][0]["type"] == "tool_unknown" + assert a["tools"][0]["unknown"] == {} + + +def test_assistants_unknown_annotation_response() -> None: + response: dict[str, Any] = { + "annotations": [ + { + "text": "text", + "type": "unknown_citation", + "start_index": 150, + "end_index": 162, + "unknown_citation": {}, + }, + { + "text": "text", + "type": "unknown_citation", + "start_index": 150, + "end_index": 162, + "unknown_citation": {}, + }, + ], + "value": "", + } + text = construct_type(type_=Text, value=response) + assert isinstance(text, Text) + assert isinstance(text.annotations[0], BaseAnnotation) + assert text.annotations[0].type == "unknown_citation" # type: ignore[comparison-overlap] + t = text.model_dump() # type: ignore[unreachable] + assert t["annotations"][0]["type"] == "unknown_citation" + + +def test_assistants_unknown_annotation_delta_response() -> None: + response: dict[str, Any] = { + "annotations": [ + { + "index": 0, + "type": "unknown_citation", + "start_index": 150, + "end_index": 162, + "unknown_citation": {}, + }, + { + "index": 1, + "type": "unknown_citation", + "start_index": 150, + "end_index": 162, + "unknown_citation": {}, + }, + ], + "value": "", + } + text_delta = construct_type(type_=TextDelta, value=response) + assert isinstance(text_delta, TextDelta) + assert text_delta.annotations + assert isinstance(text_delta.annotations[0], BaseDeltaAnnotation) + assert text_delta.annotations[0].type == "unknown_citation" # type: ignore[comparison-overlap] + td = text_delta.model_dump() # type: ignore[unreachable] + assert td["annotations"][0]["type"] == "unknown_citation" + + +def test_assistants_unknown_message_content_response() -> None: + response: dict[str, Any] = { + "id": "msg_xxx", + "assistant_id": None, + "attachments": [], + "content": [{"unknown_content": {}, "type": "unknown_content"}], + "created_at": 1722885796, + "metadata": {}, + "object": "thread.message", + "role": "user", + "run_id": None, + "thread_id": "thread_xxx", + } + message = construct_type(type_=Message, value=response) + assert isinstance(message, Message) + assert isinstance(message.content[0], BaseContentBlock) + assert message.content[0].type == "unknown_content" # type: ignore[comparison-overlap] + msg = message.model_dump() # type: ignore[unreachable] + assert msg["content"][0]["type"] == "unknown_content" + + +def test_assistants_unknown_message_content_delta_response() -> None: + response: dict[str, Any] = { + "content": [{"index": 1, "unknown_content": {}, "type": "unknown_content"}], + "role": "user", + } + message_delta = construct_type(type_=MessageDelta, value=response) + assert isinstance(message_delta, MessageDelta) + assert message_delta.content + assert isinstance(message_delta.content[0], BaseDeltaBlock) + assert message_delta.content[0].type == "unknown_content" # type: ignore[comparison-overlap] + md = message_delta.model_dump() # type: ignore[unreachable] + assert md["content"][0]["type"] == "unknown_content" + + +def test_assistants_unknown_tool_call_response() -> None: + response: dict[str, Any] = { + "id": "step_xxx", + "assistant_id": "asst_xxx", + "cancelled_at": None, + "completed_at": None, + "created_at": 1722644003, + "failed_at": None, + "last_error": None, + "object": "thread.run.step", + "run_id": "run_xxx", + "status": "in_progress", + "step_details": { + "tool_calls": [{"type": "tool_unknown", "id": "call_xxx", "unknown": {}}], + "type": "tool_calls", + }, + "thread_id": "thread_xxx", + "type": "tool_calls", + "usage": None, + "expires_at": 1722644600, + } + run_step = construct_type(type_=RunStep, value=response) + assert isinstance(run_step, RunStep) + assert run_step.step_details + assert run_step.step_details.type == "tool_calls" + assert run_step.step_details.tool_calls + assert isinstance(run_step.step_details.tool_calls[0], BaseToolCall) + assert run_step.step_details.tool_calls[0].type == "tool_unknown" # type: ignore[comparison-overlap] + rs = run_step.model_dump() # type: ignore[unreachable] + assert rs["step_details"]["tool_calls"][0]["type"] == "tool_unknown" + + +def test_assistants_unknown_tool_call_delta_response() -> None: + response: dict[str, Any] = { + "step_details": { + "tool_calls": [{"index": 0, "type": "tool_unknown", "id": "call_xxx", "unknown": {}}], + "type": "tool_calls", + }, + } + run_step_delta = construct_type(type_=RunStepDelta, value=response) + assert isinstance(run_step_delta, RunStepDelta) + assert run_step_delta.step_details + assert run_step_delta.step_details.type == "tool_calls" + assert run_step_delta.step_details.tool_calls + assert isinstance(run_step_delta.step_details.tool_calls[0], BaseToolCallDelta) + assert run_step_delta.step_details.tool_calls[0].type == "tool_unknown" # type: ignore[comparison-overlap] + rsd = run_step_delta.model_dump() # type: ignore[unreachable] + assert rsd["step_details"]["tool_calls"][0]["type"] == "tool_unknown" diff --git a/tests/test_models.py b/tests/test_models.py index b703444248..ab23d8a32d 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -827,3 +827,42 @@ class B(BaseModel): # if the discriminator details object stays the same between invocations then # we hit the cache assert UnionType.__discriminator__ is discriminator + + +def test_discriminated_unions_nested_unknown_variant() -> None: + + class Code(BaseModel): + type: Literal["code"] + + name: str + + class Function(BaseModel): + type: Literal["function"] + + name: str + + class Message(BaseModel): + type: Literal["message"] + + message: str + + class Tool(BaseModel): + type: Literal["tool"] + + tools: list[Annotated[Union[Code, Function], PropertyInfo(discriminator="type")]] + + class Model(BaseModel): + + data: str + + result: Annotated[Union[Message, Tool], PropertyInfo(discriminator="type")] + + # should construct a Tool object regardless of unknown data in tools + m = construct_type( + value={"data": "foo", "result": {"type": "tool", "tools": [{"type": "unknown", "name": "bar"}]}}, + type_=Model, + ) + m = cast(Model, m) + assert isinstance(m.result, Tool) + assert m.result.type == "tool" + assert m.result.tools[0].type == "unknown" # type: ignore[comparison-overlap]