1
1
from __future__ import annotations
2
2
3
3
from itertools import chain
4
- from typing import Any , ClassVar , cast
4
+ from typing import Any , ClassVar , OrderedDict , cast
5
5
6
6
from attr import define , evolve
7
7
15
15
16
16
@define
17
17
class DiscriminatorDefinition :
18
+ """Represents a discriminator that can optionally be specified for a union type.
19
+
20
+ Normally, a UnionProperty has either zero or one of these. However, a nested union
21
+ could have more than one, as we accumulate all the discriminators when we flatten
22
+ out the nested schemas. For example:
23
+
24
+ anyOf:
25
+ - anyOf:
26
+ - $ref: "#/components/schemas/Cat"
27
+ - $ref: "#/components/schemas/Dog"
28
+ discriminator:
29
+ propertyName: mammalType
30
+ - anyOf:
31
+ - $ref: "#/components/schemas/Condor"
32
+ - $ref: "#/components/schemas/Chicken"
33
+ discriminator:
34
+ propertyName: birdType
35
+
36
+ In this example there are four schemas and two discriminators. The deserializer
37
+ logic will check for the mammalType property first, then birdType.
38
+ """
18
39
property_name : str
19
40
value_to_model_map : dict [str , PropertyProtocol ]
20
41
# Every value in the map is really a ModelProperty, but this avoids circular imports
@@ -75,7 +96,7 @@ def build(
75
96
return PropertyError (detail = f"Invalid property in union { name } " , data = sub_prop_data ), schemas
76
97
sub_properties .append (sub_prop )
77
98
78
- sub_properties , discriminators_list = _flatten_union_properties (sub_properties )
99
+ sub_properties , discriminators_from_nested_unions = _flatten_union_properties (sub_properties )
79
100
80
101
prop = UnionProperty (
81
102
name = name ,
@@ -92,15 +113,14 @@ def build(
92
113
return default_or_error , schemas
93
114
prop = evolve (prop , default = default_or_error )
94
115
116
+ all_discriminators = discriminators_from_nested_unions
95
117
if data .discriminator :
96
118
discriminator_or_error = _parse_discriminator (data .discriminator , sub_properties , schemas )
97
119
if isinstance (discriminator_or_error , PropertyError ):
98
120
return discriminator_or_error , schemas
99
- discriminators_list = [discriminator_or_error , * discriminators_list ]
100
- if discriminators_list :
101
- if error := _validate_discriminators (discriminators_list ):
102
- return error , schemas
103
- prop = evolve (prop , discriminators = discriminators_list )
121
+ all_discriminators = [discriminator_or_error , * all_discriminators ]
122
+ if all_discriminators :
123
+ prop = evolve (prop , discriminators = all_discriminators )
104
124
105
125
return prop , schemas
106
126
@@ -227,15 +247,33 @@ def _parse_discriminator(
227
247
228
248
# See: https://spec.openapis.org/oas/v3.1.0.html#discriminator-object
229
249
230
- def _find_top_level_model (matching_model : ModelProperty ) -> ModelProperty | None :
231
- # This is needed because, when we built the union list, $refs were changed into a copy of
232
- # the type they referred to, without preserving the original name. We need to know that
233
- # every type in the discriminator is a $ref to a top-level type and we need its name.
234
- for prop in schemas .classes_by_reference .values ():
235
- if isinstance (prop , ModelProperty ):
236
- if prop .class_info == matching_model .class_info :
237
- return prop
238
- return None
250
+ # Conditions that must be true when there is a discriminator:
251
+ # 1. Every type in the anyOf/oneOf list must be a $ref to a named schema, such as
252
+ # #/components/schemas/X, rather than an inline schema. This is important because
253
+ # we may need to use the schema's simple name (X).
254
+ # 2. There must be a propertyName, representing a property that exists in every
255
+ # schema in that list (although we can't currently enforce the latter condition,
256
+ # because those properties haven't been parsed yet at this point.)
257
+ #
258
+ # There *may* also be a mapping of lookup values (the possible values of the property)
259
+ # to schemas. Schemas can be referenced either by a full path or a name:
260
+ # mapping:
261
+ # value_for_a: "#/components/schemas/ModelA"
262
+ # value_for_b: ModelB # equivalent to "#/components/schemas/ModelB"
263
+ #
264
+ # For any type that isn't specified in the mapping (or if the whole mapping is omitted)
265
+ # the default lookup value for each schema is the same as the schema name. So this--
266
+ # mapping:
267
+ # value_for_a: "#/components/schemas/ModelA"
268
+ # --is exactly equivalent to this:
269
+ # discriminator:
270
+ # propertyName: modelType
271
+ # mapping:
272
+ # value_for_a: "#/components/schemas/ModelA"
273
+ # ModelB: "#/components/schemas/ModelB"
274
+
275
+ def _get_model_name (model : ModelProperty ) -> str | None :
276
+ return get_reference_simple_name (model .ref_path ) if model .ref_path else None
239
277
240
278
model_types_by_name : dict [str , PropertyProtocol ] = {}
241
279
for model in subtypes :
@@ -245,59 +283,32 @@ def _find_top_level_model(matching_model: ModelProperty) -> ModelProperty | None
245
283
return PropertyError (
246
284
detail = "All schema variants must be objects when using a discriminator" ,
247
285
)
248
- top_level_model = _find_top_level_model (model )
249
- if not top_level_model :
286
+ name = _get_model_name (model )
287
+ if not name :
250
288
return PropertyError (
251
289
detail = "Inline schema declarations are not allowed when using a discriminator" ,
252
290
)
253
- name = top_level_model .name
254
- if name .startswith ("/components/schemas/" ):
255
- name = get_reference_simple_name (name )
256
- model_types_by_name [name ] = top_level_model
257
-
258
- # The discriminator can specify an explicit mapping of values to types, but it doesn't
259
- # have to; the default behavior is that the value for each type is simply its name.
260
- mapping : dict [str , PropertyProtocol ] = model_types_by_name .copy ()
291
+ model_types_by_name [name ] = model
292
+
293
+ mapping : dict [str , PropertyProtocol ] = OrderedDict () # use ordered dict for test determinacy
294
+ unspecified_models = list (model_types_by_name .values ())
261
295
if data .mapping :
262
296
for discriminator_value , model_ref in data .mapping .items ():
263
- ref_path = parse_reference_path (
264
- model_ref if model_ref .startswith ("#/components/schemas/" ) else f"#/components/schemas/{ model_ref } "
265
- )
266
- if isinstance (ref_path , ParseError ) or ref_path not in schemas .classes_by_reference :
267
- return PropertyError (detail = f'Invalid reference "{ model_ref } " in discriminator mapping' )
268
- name = get_reference_simple_name (ref_path )
269
- if not (lookup_model := model_types_by_name .get (name )):
297
+ if "/" in model_ref :
298
+ ref_path = parse_reference_path (model_ref )
299
+ if isinstance (ref_path , ParseError ) or ref_path not in schemas .classes_by_reference :
300
+ return PropertyError (detail = f'Invalid reference "{ model_ref } " in discriminator mapping' )
301
+ name = get_reference_simple_name (ref_path )
302
+ else :
303
+ name = model_ref
304
+ model = model_types_by_name .get (name )
305
+ if not model :
270
306
return PropertyError (
271
- detail = f'Discriminator mapping referred to "{ model_ref } " which is not one of the schema variants' ,
307
+ detail = f'Discriminator mapping referred to "{ name } " which is not one of the schema variants' ,
272
308
)
273
- for original_value in (name for name , m in model_types_by_name .items () if m == lookup_model ):
274
- mapping .pop (original_value )
275
- mapping [discriminator_value ] = lookup_model
276
- else :
277
- mapping = model_types_by_name
278
-
309
+ mapping [discriminator_value ] = model
310
+ unspecified_models .remove (model )
311
+ for model in unspecified_models :
312
+ if name := _get_model_name (model ):
313
+ mapping [name ] = model
279
314
return DiscriminatorDefinition (property_name = data .propertyName , value_to_model_map = mapping )
280
-
281
-
282
- def _validate_discriminators (
283
- discriminators : list [DiscriminatorDefinition ],
284
- ) -> PropertyError | None :
285
- from .model_property import ModelProperty
286
-
287
- prop_names_values_classes = [
288
- (discriminator .property_name , key , cast (ModelProperty , model ).class_info .name )
289
- for discriminator in discriminators
290
- for key , model in discriminator .value_to_model_map .items ()
291
- ]
292
- for p , v in {(p , v ) for p , v , _ in prop_names_values_classes }:
293
- if len ({c for p1 , v1 , c in prop_names_values_classes if (p1 , v1 ) == (p , v )}) > 1 :
294
- return PropertyError (f'Discriminator property "{ p } " had more than one schema for value "{ v } "' )
295
- return None
296
-
297
- # TODO: We should also validate that property_name refers to a property that 1. exists,
298
- # 2. is required, 3. is a string (in all of these models). However, currently we can't
299
- # do that because, at the time this function is called, the ModelProperties within the
300
- # union haven't yet been post-processed and so we don't have full information about
301
- # their properties. To fix this, we may need to generalize the post-processing phase so
302
- # that any Property type, not just ModelProperty, can say it needs post-processing; then
303
- # we can defer _validate_discriminators till that phase.
0 commit comments