Skip to content

Commit b75adda

Browse files
committed
simply the discriminator logic a bit
1 parent 0496ff8 commit b75adda

File tree

7 files changed

+104
-117
lines changed

7 files changed

+104
-117
lines changed

end_to_end_tests/golden-record/my_test_api_client/models/model_with_discriminated_union.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -64,23 +64,23 @@ def _parse_discriminated_union(
6464
if "modelType" in data:
6565
_discriminator_value = data["modelType"]
6666

67-
def _parse_componentsschemas_a_discriminated_union_type_1(data: object) -> ADiscriminatedUnionType1:
67+
def _parse_1(data: object) -> ADiscriminatedUnionType1:
6868
if not isinstance(data, dict):
6969
raise TypeError()
70-
componentsschemas_a_discriminated_union_type_1 = ADiscriminatedUnionType1.from_dict(data)
70+
componentsschemas_a_discriminated_union_type_0 = ADiscriminatedUnionType1.from_dict(data)
7171

72-
return componentsschemas_a_discriminated_union_type_1
72+
return componentsschemas_a_discriminated_union_type_0
7373

74-
def _parse_componentsschemas_a_discriminated_union_type_2(data: object) -> ADiscriminatedUnionType2:
74+
def _parse_2(data: object) -> ADiscriminatedUnionType2:
7575
if not isinstance(data, dict):
7676
raise TypeError()
77-
componentsschemas_a_discriminated_union_type_2 = ADiscriminatedUnionType2.from_dict(data)
77+
componentsschemas_a_discriminated_union_type_1 = ADiscriminatedUnionType2.from_dict(data)
7878

79-
return componentsschemas_a_discriminated_union_type_2
79+
return componentsschemas_a_discriminated_union_type_1
8080

8181
_discriminator_mapping = {
82-
"type1": _parse_componentsschemas_a_discriminated_union_type_1,
83-
"type2": _parse_componentsschemas_a_discriminated_union_type_2,
82+
"type1": _parse_1,
83+
"type2": _parse_2,
8484
}
8585
if _parse_fn := _discriminator_mapping.get(_discriminator_value):
8686
return cast(

openapi_python_client/parser/properties/model_property.py

+1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class ModelProperty(PropertyProtocol):
3232
relative_imports: set[str] | None
3333
lazy_imports: set[str] | None
3434
additional_properties: Property | None
35+
ref_path: ReferencePath | None = None
3536
_json_type_string: ClassVar[str] = "Dict[str, Any]"
3637

3738
template: ClassVar[str] = "model_property.py.jinja"

openapi_python_client/parser/properties/protocol.py

+5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from openapi_python_client.parser.properties.schemas import ReferencePath
4+
35
__all__ = ["PropertyProtocol", "Value"]
46

57
from abc import abstractmethod
@@ -185,3 +187,6 @@ def is_base_type(self) -> bool:
185187
ListProperty.__name__,
186188
UnionProperty.__name__,
187189
}
190+
191+
def get_ref_path(self) -> ReferencePath | None:
192+
return self.ref_path if hasattr(self, "ref_path") else None

openapi_python_client/parser/properties/schemas.py

+9
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,15 @@ def update_schemas_with_data(
142142
)
143143
return prop
144144

145+
# Save the original path (/components/schemas/X) in the property. This is important because:
146+
# 1. There are some contexts (such as a union with a discriminator) where we have a Property
147+
# instance and we want to know what its path is, instead of the other way round.
148+
# 2. Even though we did set prop.name to be the same as ref_path when we created it above,
149+
# whenever there's a $ref to this property, we end up making a copy of it and changing
150+
# the name. So we can't rely on prop.name always being the path.
151+
if hasattr(prop, "ref_path"):
152+
prop.ref_path = ref_path
153+
145154
schemas = evolve(schemas, classes_by_reference={ref_path: prop, **schemas.classes_by_reference})
146155
return schemas
147156

openapi_python_client/parser/properties/union.py

+75-64
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from itertools import chain
4-
from typing import Any, ClassVar, cast
4+
from typing import Any, ClassVar, OrderedDict, cast
55

66
from attr import define, evolve
77

@@ -15,6 +15,27 @@
1515

1616
@define
1717
class DiscriminatorDefinition:
18+
"""Represents a discriminator that can optionally be specified for a union type.
19+
20+
Normally, a UnionProperty has either zero or one of these. However, a nested union
21+
could have more than one, as we accumulate all the discriminators when we flatten
22+
out the nested schemas. For example:
23+
24+
anyOf:
25+
- anyOf:
26+
- $ref: "#/components/schemas/Cat"
27+
- $ref: "#/components/schemas/Dog"
28+
discriminator:
29+
propertyName: mammalType
30+
- anyOf:
31+
- $ref: "#/components/schemas/Condor"
32+
- $ref: "#/components/schemas/Chicken"
33+
discriminator:
34+
propertyName: birdType
35+
36+
In this example there are four schemas and two discriminators. The deserializer
37+
logic will check for the mammalType property first, then birdType.
38+
"""
1839
property_name: str
1940
value_to_model_map: dict[str, PropertyProtocol]
2041
# Every value in the map is really a ModelProperty, but this avoids circular imports
@@ -75,7 +96,7 @@ def build(
7596
return PropertyError(detail=f"Invalid property in union {name}", data=sub_prop_data), schemas
7697
sub_properties.append(sub_prop)
7798

78-
sub_properties, discriminators_list = _flatten_union_properties(sub_properties)
99+
sub_properties, discriminators_from_nested_unions = _flatten_union_properties(sub_properties)
79100

80101
prop = UnionProperty(
81102
name=name,
@@ -92,15 +113,14 @@ def build(
92113
return default_or_error, schemas
93114
prop = evolve(prop, default=default_or_error)
94115

116+
all_discriminators = discriminators_from_nested_unions
95117
if data.discriminator:
96118
discriminator_or_error = _parse_discriminator(data.discriminator, sub_properties, schemas)
97119
if isinstance(discriminator_or_error, PropertyError):
98120
return discriminator_or_error, schemas
99-
discriminators_list = [discriminator_or_error, *discriminators_list]
100-
if discriminators_list:
101-
if error := _validate_discriminators(discriminators_list):
102-
return error, schemas
103-
prop = evolve(prop, discriminators=discriminators_list)
121+
all_discriminators = [discriminator_or_error, *all_discriminators]
122+
if all_discriminators:
123+
prop = evolve(prop, discriminators=all_discriminators)
104124

105125
return prop, schemas
106126

@@ -227,15 +247,33 @@ def _parse_discriminator(
227247

228248
# See: https://spec.openapis.org/oas/v3.1.0.html#discriminator-object
229249

230-
def _find_top_level_model(matching_model: ModelProperty) -> ModelProperty | None:
231-
# This is needed because, when we built the union list, $refs were changed into a copy of
232-
# the type they referred to, without preserving the original name. We need to know that
233-
# every type in the discriminator is a $ref to a top-level type and we need its name.
234-
for prop in schemas.classes_by_reference.values():
235-
if isinstance(prop, ModelProperty):
236-
if prop.class_info == matching_model.class_info:
237-
return prop
238-
return None
250+
# Conditions that must be true when there is a discriminator:
251+
# 1. Every type in the anyOf/oneOf list must be a $ref to a named schema, such as
252+
# #/components/schemas/X, rather than an inline schema. This is important because
253+
# we may need to use the schema's simple name (X).
254+
# 2. There must be a propertyName, representing a property that exists in every
255+
# schema in that list (although we can't currently enforce the latter condition,
256+
# because those properties haven't been parsed yet at this point.)
257+
#
258+
# There *may* also be a mapping of lookup values (the possible values of the property)
259+
# to schemas. Schemas can be referenced either by a full path or a name:
260+
# mapping:
261+
# value_for_a: "#/components/schemas/ModelA"
262+
# value_for_b: ModelB # equivalent to "#/components/schemas/ModelB"
263+
#
264+
# For any type that isn't specified in the mapping (or if the whole mapping is omitted)
265+
# the default lookup value for each schema is the same as the schema name. So this--
266+
# mapping:
267+
# value_for_a: "#/components/schemas/ModelA"
268+
# --is exactly equivalent to this:
269+
# discriminator:
270+
# propertyName: modelType
271+
# mapping:
272+
# value_for_a: "#/components/schemas/ModelA"
273+
# ModelB: "#/components/schemas/ModelB"
274+
275+
def _get_model_name(model: ModelProperty) -> str | None:
276+
return get_reference_simple_name(model.ref_path) if model.ref_path else None
239277

240278
model_types_by_name: dict[str, PropertyProtocol] = {}
241279
for model in subtypes:
@@ -245,59 +283,32 @@ def _find_top_level_model(matching_model: ModelProperty) -> ModelProperty | None
245283
return PropertyError(
246284
detail="All schema variants must be objects when using a discriminator",
247285
)
248-
top_level_model = _find_top_level_model(model)
249-
if not top_level_model:
286+
name = _get_model_name(model)
287+
if not name:
250288
return PropertyError(
251289
detail="Inline schema declarations are not allowed when using a discriminator",
252290
)
253-
name = top_level_model.name
254-
if name.startswith("/components/schemas/"):
255-
name = get_reference_simple_name(name)
256-
model_types_by_name[name] = top_level_model
257-
258-
# The discriminator can specify an explicit mapping of values to types, but it doesn't
259-
# have to; the default behavior is that the value for each type is simply its name.
260-
mapping: dict[str, PropertyProtocol] = model_types_by_name.copy()
291+
model_types_by_name[name] = model
292+
293+
mapping: dict[str, PropertyProtocol] = OrderedDict() # use ordered dict for test determinacy
294+
unspecified_models = list(model_types_by_name.values())
261295
if data.mapping:
262296
for discriminator_value, model_ref in data.mapping.items():
263-
ref_path = parse_reference_path(
264-
model_ref if model_ref.startswith("#/components/schemas/") else f"#/components/schemas/{model_ref}"
265-
)
266-
if isinstance(ref_path, ParseError) or ref_path not in schemas.classes_by_reference:
267-
return PropertyError(detail=f'Invalid reference "{model_ref}" in discriminator mapping')
268-
name = get_reference_simple_name(ref_path)
269-
if not (lookup_model := model_types_by_name.get(name)):
297+
if "/" in model_ref:
298+
ref_path = parse_reference_path(model_ref)
299+
if isinstance(ref_path, ParseError) or ref_path not in schemas.classes_by_reference:
300+
return PropertyError(detail=f'Invalid reference "{model_ref}" in discriminator mapping')
301+
name = get_reference_simple_name(ref_path)
302+
else:
303+
name = model_ref
304+
model = model_types_by_name.get(name)
305+
if not model:
270306
return PropertyError(
271-
detail=f'Discriminator mapping referred to "{model_ref}" which is not one of the schema variants',
307+
detail=f'Discriminator mapping referred to "{name}" which is not one of the schema variants',
272308
)
273-
for original_value in (name for name, m in model_types_by_name.items() if m == lookup_model):
274-
mapping.pop(original_value)
275-
mapping[discriminator_value] = lookup_model
276-
else:
277-
mapping = model_types_by_name
278-
309+
mapping[discriminator_value] = model
310+
unspecified_models.remove(model)
311+
for model in unspecified_models:
312+
if name := _get_model_name(model):
313+
mapping[name] = model
279314
return DiscriminatorDefinition(property_name=data.propertyName, value_to_model_map=mapping)
280-
281-
282-
def _validate_discriminators(
283-
discriminators: list[DiscriminatorDefinition],
284-
) -> PropertyError | None:
285-
from .model_property import ModelProperty
286-
287-
prop_names_values_classes = [
288-
(discriminator.property_name, key, cast(ModelProperty, model).class_info.name)
289-
for discriminator in discriminators
290-
for key, model in discriminator.value_to_model_map.items()
291-
]
292-
for p, v in {(p, v) for p, v, _ in prop_names_values_classes}:
293-
if len({c for p1, v1, c in prop_names_values_classes if (p1, v1) == (p, v)}) > 1:
294-
return PropertyError(f'Discriminator property "{p}" had more than one schema for value "{v}"')
295-
return None
296-
297-
# TODO: We should also validate that property_name refers to a property that 1. exists,
298-
# 2. is required, 3. is a string (in all of these models). However, currently we can't
299-
# do that because, at the time this function is called, the ModelProperties within the
300-
# union haven't yet been post-processed and so we don't have full information about
301-
# their properties. To fix this, we may need to generalize the post-processing phase so
302-
# that any Property type, not just ModelProperty, can say it needs post-processing; then
303-
# we can defer _validate_discriminators till that phase.

openapi_python_client/templates/property_templates/union_property.py.jinja

+4-4
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@ if not isinstance(data, dict):
1616
raise TypeError()
1717
if "{{ discriminator.property_name }}" in data:
1818
_discriminator_value = data["{{ discriminator.property_name }}"]
19-
{% for model in discriminator.value_to_model_map.values() %}
20-
def _parse_{{ model.python_name }}(data: object) -> {{ model.get_type_string() }}:
21-
{{ construct_inner_property(model) | indent(12, True) }}
19+
{% for value, model in discriminator.value_to_model_map.items() %}
20+
def _parse_{{ loop.index }}(data: object) -> {{ model.get_type_string() }}:
21+
{{ construct_inner_property(model) | indent(8, True) }}
2222
{% endfor %}
2323
_discriminator_mapping = {
2424
{% for value, model in discriminator.value_to_model_map.items() %}
25-
"{{ value }}": _parse_{{ model.python_name }},
25+
"{{ value }}": _parse_{{ loop.index }},
2626
{% endfor %}
2727
}
2828
if _parse_fn := _discriminator_mapping.get(_discriminator_value):

tests/test_parser/test_properties/test_union.py

+2-41
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from openapi_python_client.parser.properties.model_property import ModelProperty
1111
from openapi_python_client.parser.properties.property import Property
1212
from openapi_python_client.parser.properties.protocol import Value
13-
from openapi_python_client.parser.properties.schemas import Class
13+
from openapi_python_client.parser.properties.schemas import Class, ReferencePath
1414
from openapi_python_client.schema import DataType, ParameterLocation
1515
from tests.test_parser.test_properties.properties_test_helpers import assert_prop_error
1616

@@ -66,6 +66,7 @@ def _make_basic_model(
6666
)
6767
assert isinstance(model, ModelProperty)
6868
if name:
69+
model.ref_path = ReferencePath(f"/components/schemas/{name}")
6970
schemas = evolve(
7071
schemas, classes_by_reference={**schemas.classes_by_reference, f"/components/schemas/{name}": model}
7172
)
@@ -404,43 +405,3 @@ def test_discriminator_invalid_inline_schema_variant(config, string_property_fac
404405
name="MyUnion", required=False, data=data, schemas=schemas, parent_name="parent", config=config
405406
)
406407
assert_prop_error(p, "Inline schema")
407-
408-
409-
def test_conflicting_discriminator_mappings(config):
410-
from openapi_python_client.parser.properties import Schemas, property_from_data
411-
412-
schemas = Schemas()
413-
props = {"type": oai.Schema.model_construct(type="string")}
414-
model1, schemas = _make_basic_model("Model1", props, "type", schemas, config)
415-
model2, schemas = _make_basic_model("Model2", props, "type", schemas, config)
416-
model3, schemas = _make_basic_model("Model3", props, "type", schemas, config)
417-
model4, schemas = _make_basic_model("Model4", props, "type", schemas, config)
418-
data = oai.Schema.model_construct(
419-
oneOf=[
420-
oai.Schema.model_construct(
421-
oneOf=[
422-
oai.Reference(ref="#/components/schemas/Model1"),
423-
oai.Reference(ref="#/components/schemas/Model2"),
424-
],
425-
discriminator=oai.Discriminator.model_construct(
426-
propertyName="type",
427-
mapping={"a": "Model1", "b": "Model2"},
428-
),
429-
),
430-
oai.Schema.model_construct(
431-
oneOf=[
432-
oai.Reference(ref="#/components/schemas/Model3"),
433-
oai.Reference(ref="#/components/schemas/Model4"),
434-
],
435-
discriminator=oai.Discriminator.model_construct(
436-
propertyName="type",
437-
mapping={"a": "Model3", "x": "Model4"},
438-
),
439-
),
440-
],
441-
)
442-
443-
p, schemas = property_from_data(
444-
name="MyUnion", required=False, data=data, schemas=schemas, parent_name="parent", config=config
445-
)
446-
assert_prop_error(p, '"type" had more than one schema for value "a"')

0 commit comments

Comments
 (0)