Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add discriminator property support #1154

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions .changeset/discriminators.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
---
default: minor
---

# Add discriminator property support

The optional `discriminator` field, when used in a schema with `anyOf` or `oneOf` as described in [OpenAPI 3.1.0](https://spec.openapis.org/oas/v3.1.0.html#discriminator-object), now correctly produces deserialization logic for using the specified property value to determine the appropriate type.

In this example, `PolymorphicModel.thing` will be deserialized as a `ThingA` if the value of the `modelType` property is `"ThingA"`, or as a `ThingB` if the value is `"ThingB"`:

```yaml
ThingA:
type: object
properties:
thingType:
type: string
name:
type: string

ThingB:
type: object
properties:
thingType:
type: string
name:
type: string

PolymorphicModel:
type: object
properties:
thing:
anyOf:
- "#/components/schemas/ThingA"
- "#/components/schemas/ThingB"
discriminator:
propertyName: modelType
```

If you want to use property values that are not the same as the schema names, you can add a `mapping`. In this example, the value is expected to be `"A"` or `"B"`, instead of `"ThingA"` or `"ThingB"`:

```yaml
discriminator:
propertyName: modelType
mapping:
A: "#/components/schemas/ThingA"
B: "#/components/schemas/ThingB"
```

This could also be written more concisely as:

```yaml
discriminator:
propertyName: modelType
mapping:
A: "ThingA"
B: "ThingB"
```

If you specify a property name that does not exist in all of the variant schemas, the behavior is undefined.
9 changes: 6 additions & 3 deletions end_to_end_tests/baseline_openapi_3.0.json
Original file line number Diff line number Diff line change
Expand Up @@ -2864,7 +2864,8 @@
"propertyName": "modelType",
"mapping": {
"type1": "#/components/schemas/ADiscriminatedUnionType1",
"type2": "#/components/schemas/ADiscriminatedUnionType2"
"type2": "#/components/schemas/ADiscriminatedUnionType2",
"type2-another-value": "#/components/schemas/ADiscriminatedUnionType2"
}
},
"oneOf": [
Expand All @@ -2882,15 +2883,17 @@
"modelType": {
"type": "string"
}
}
},
"required": ["modelType"]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is necessary in order for the discriminator in the spec to really be valid: a discriminator property must be a required property in all of the variant schemas. My current implementation wouldn't actually catch a mistake like this, but I figured it was best to have the test spec be valid.

Copy link
Collaborator Author

@eli-bl eli-bl Oct 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At least, that's my reading of what OpenAPI says. I'm basing it on the statement "The expectation now is that a property with name petType MUST be present in the response payload, and the value will correspond to the name of a schema defined in the OAS document."

However, this is a case (one of many) of the OpenAPI specification using a mix of normative and non-normative language. The MUST seems unambiguous, but it's in a paragraph that's describing this specific example with types of pets, where the schemas do say the property is required. So do they mean a discriminator property must always be required, or just that that's the behavior this example is illustrating?

Logically I feel like it makes sense for it to be required, and if it isn't, then I'm not sure what the client behavior should be: (a) refuse to parse the object if the property is missing, or (b) just fall back to "try parsing it as each of these types until one of them works"? What I've implemented so far in the generated code is (a).

},
"ADiscriminatedUnionType2": {
"type": "object",
"properties": {
"modelType": {
"type": "string"
}
}
},
"required": ["modelType"]
}
},
"parameters": {
Expand Down
9 changes: 6 additions & 3 deletions end_to_end_tests/baseline_openapi_3.1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2858,7 +2858,8 @@ info:
"propertyName": "modelType",
"mapping": {
"type1": "#/components/schemas/ADiscriminatedUnionType1",
"type2": "#/components/schemas/ADiscriminatedUnionType2"
"type2": "#/components/schemas/ADiscriminatedUnionType2",
"type2-another-value": "#/components/schemas/ADiscriminatedUnionType2"
}
},
"oneOf": [
Expand All @@ -2876,15 +2877,17 @@ info:
"modelType": {
"type": "string"
}
}
},
"required": ["modelType"]
},
"ADiscriminatedUnionType2": {
"type": "object",
"properties": {
"modelType": {
"type": "string"
}
}
},
"required": ["modelType"]
}
}
"parameters": {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,38 +1,38 @@
from typing import Any, TypeVar, Union
from typing import Any, TypeVar

from attrs import define as _attrs_define
from attrs import field as _attrs_field

from ..types import UNSET, Unset

T = TypeVar("T", bound="ADiscriminatedUnionType1")


@_attrs_define
class ADiscriminatedUnionType1:
"""
Attributes:
model_type (Union[Unset, str]):
model_type (str):
"""

model_type: Union[Unset, str] = UNSET
model_type: str
additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)

def to_dict(self) -> dict[str, Any]:
model_type = self.model_type

field_dict: dict[str, Any] = {}
field_dict.update(self.additional_properties)
field_dict.update({})
if model_type is not UNSET:
field_dict["modelType"] = model_type
field_dict.update(
{
"modelType": model_type,
}
)

return field_dict

@classmethod
def from_dict(cls: type[T], src_dict: dict[str, Any]) -> T:
d = src_dict.copy()
model_type = d.pop("modelType", UNSET)
model_type = d.pop("modelType")

a_discriminated_union_type_1 = cls(
model_type=model_type,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,38 +1,38 @@
from typing import Any, TypeVar, Union
from typing import Any, TypeVar

from attrs import define as _attrs_define
from attrs import field as _attrs_field

from ..types import UNSET, Unset

T = TypeVar("T", bound="ADiscriminatedUnionType2")


@_attrs_define
class ADiscriminatedUnionType2:
"""
Attributes:
model_type (Union[Unset, str]):
model_type (str):
"""

model_type: Union[Unset, str] = UNSET
model_type: str
additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)

def to_dict(self) -> dict[str, Any]:
model_type = self.model_type

field_dict: dict[str, Any] = {}
field_dict.update(self.additional_properties)
field_dict.update({})
if model_type is not UNSET:
field_dict["modelType"] = model_type
field_dict.update(
{
"modelType": model_type,
}
)

return field_dict

@classmethod
def from_dict(cls: type[T], src_dict: dict[str, Any]) -> T:
d = src_dict.copy()
model_type = d.pop("modelType", UNSET)
model_type = d.pop("modelType")

a_discriminated_union_type_2 = cls(
model_type=model_type,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,23 +59,42 @@ def _parse_discriminated_union(
return data
if isinstance(data, Unset):
return data
try:
if not isinstance(data, dict):
raise TypeError()
componentsschemas_a_discriminated_union_type_0 = ADiscriminatedUnionType1.from_dict(data)

return componentsschemas_a_discriminated_union_type_0
except: # noqa: E722
pass
try:
if not isinstance(data, dict):
raise TypeError()
componentsschemas_a_discriminated_union_type_1 = ADiscriminatedUnionType2.from_dict(data)

return componentsschemas_a_discriminated_union_type_1
except: # noqa: E722
pass
return cast(Union["ADiscriminatedUnionType1", "ADiscriminatedUnionType2", None, Unset], data)
if not isinstance(data, dict):
raise TypeError()
if "modelType" in data:
_discriminator_value = data["modelType"]

def _parse_1(data: object) -> ADiscriminatedUnionType1:
if not isinstance(data, dict):
raise TypeError()
componentsschemas_a_discriminated_union_type_0 = ADiscriminatedUnionType1.from_dict(data)

return componentsschemas_a_discriminated_union_type_0

def _parse_2(data: object) -> ADiscriminatedUnionType2:
if not isinstance(data, dict):
raise TypeError()
componentsschemas_a_discriminated_union_type_1 = ADiscriminatedUnionType2.from_dict(data)

return componentsschemas_a_discriminated_union_type_1

def _parse_3(data: object) -> ADiscriminatedUnionType2:
if not isinstance(data, dict):
raise TypeError()
componentsschemas_a_discriminated_union_type_1 = ADiscriminatedUnionType2.from_dict(data)

return componentsschemas_a_discriminated_union_type_1

_discriminator_mapping = {
"type1": _parse_1,
"type2": _parse_2,
"type2-another-value": _parse_3,
}
if _parse_fn := _discriminator_mapping.get(_discriminator_value):
return cast(
Union["ADiscriminatedUnionType1", "ADiscriminatedUnionType2", None, Unset], _parse_fn(data)
)
raise TypeError("unrecognized value for property modelType")

discriminated_union = _parse_discriminated_union(d.pop("discriminated_union", UNSET))

Expand Down
1 change: 1 addition & 0 deletions openapi_python_client/parser/properties/model_property.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class ModelProperty(PropertyProtocol):
relative_imports: set[str] | None
lazy_imports: set[str] | None
additional_properties: Property | None
ref_path: ReferencePath | None = None
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_json_type_string: ClassVar[str] = "dict[str, Any]"

template: ClassVar[str] = "model_property.py.jinja"
Expand Down
5 changes: 5 additions & 0 deletions openapi_python_client/parser/properties/protocol.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from openapi_python_client.parser.properties.schemas import ReferencePath

__all__ = ["PropertyProtocol", "Value"]

from abc import abstractmethod
Expand Down Expand Up @@ -185,3 +187,6 @@ def is_base_type(self) -> bool:
ListProperty.__name__,
UnionProperty.__name__,
}

def get_ref_path(self) -> ReferencePath | None:
return self.ref_path if hasattr(self, "ref_path") else None
16 changes: 16 additions & 0 deletions openapi_python_client/parser/properties/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ def parse_reference_path(ref_path_raw: str) -> Union[ReferencePath, ParseError]:
return cast(ReferencePath, parsed.fragment)


def get_reference_simple_name(ref_path: str) -> str:
"""
Takes a path like `/components/schemas/NameOfThing` and returns a string like `NameOfThing`.
"""
return ref_path.split("/", 3)[-1]


@define
class Class:
"""Represents Python class which will be generated from an OpenAPI schema"""
Expand Down Expand Up @@ -135,6 +142,15 @@ def update_schemas_with_data(
)
return prop

# Save the original path (/components/schemas/X) in the property. This is important because:
# 1. There are some contexts (such as a union with a discriminator) where we have a Property
# instance and we want to know what its path is, instead of the other way round.
# 2. Even though we did set prop.name to be the same as ref_path when we created it above,
# whenever there's a $ref to this property, we end up making a copy of it and changing
# the name. So we can't rely on prop.name always being the path.
if hasattr(prop, "ref_path"):
prop.ref_path = ref_path

schemas = evolve(schemas, classes_by_reference={ref_path: prop, **schemas.classes_by_reference})
return schemas

Expand Down
Loading
Loading