-
-
Notifications
You must be signed in to change notification settings - Fork 227
/
Copy pathenum_property.py
214 lines (185 loc) · 8.47 KB
/
enum_property.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
from __future__ import annotations
# Names this module exports for `from ... import *`.
__all__ = ["EnumProperty", "ValueType"]
from typing import Any, ClassVar, Union, cast
from attr import evolve
from attrs import define
from ... import Config, utils
from ... import schema as oai
from ...schema import DataType
from ..errors import PropertyError
from .none import NoneProperty
from .protocol import PropertyProtocol, Value
from .schemas import Class, Schemas
from .union import UnionProperty
# The Python types an enum member's raw value may have (used to annotate `EnumProperty.values`).
ValueType = Union[str, int]
@define
class EnumProperty(PropertyProtocol):
    """A property that should use an enum"""

    name: str
    required: bool
    default: Value | None
    python_name: utils.PythonIdentifier
    description: str | None
    example: str | None
    # Maps each generated member name to the raw schema value it represents.
    values: dict[str, ValueType]
    # Name information for the generated enum class (see `Class.from_string`).
    class_info: Class
    # The single Python type shared by every value in `values`.
    value_type: type[ValueType]

    template: ClassVar[str] = "enum_property.py.jinja"

    _allowed_locations: ClassVar[set[oai.ParameterLocation]] = {
        oai.ParameterLocation.QUERY,
        oai.ParameterLocation.PATH,
        oai.ParameterLocation.COOKIE,
        oai.ParameterLocation.HEADER,
    }

    @classmethod
    def build(  # noqa: PLR0911
        cls,
        *,
        data: oai.Schema,
        name: str,
        required: bool,
        schemas: Schemas,
        parent_name: str,
        config: Config,
    ) -> tuple[EnumProperty | NoneProperty | UnionProperty | PropertyError, Schemas]:
        """
        Create an EnumProperty from schema data.

        Depending on the enum's contents this may instead produce a NoneProperty (the enum only
        contains null), a UnionProperty (the enum mixes null with real values), or a PropertyError
        (mixed or unsupported value types, a conflicting existing class, or an invalid default).

        Args:
            data: The OpenAPI Schema which defines this enum. Its `oneOf` and `enum` attributes are
                rewritten in place when the enum mixes null with other values.
            name: The name to use for variables which receive this Enum's value (e.g. model property name)
            required: Whether or not this Property is required in the calling context
            schemas: The Schemas which have been defined so far (used to prevent naming collisions)
            parent_name: The context in which this EnumProperty is defined, used to create more specific class names.
            config: The global config for this run of the generator

        Returns:
            A tuple containing either the created property or a PropertyError, and the (possibly updated) schemas.

        Raises:
            ValueError: Propagated from `values_from_list` when two values map to the same member name.
        """
        enum = data.enum or []  # The outer function checks for this, but mypy doesn't know that

        # OpenAPI allows for null as an enum value, but it doesn't make sense with how enums are constructed in Python.
        # So instead, if null is a possible value, make the property nullable.
        # Mypy is not smart enough to know that the type is right though
        unchecked_value_list = [value for value in enum if value is not None]  # type: ignore

        # It's legal to have an enum that only contains null as a value, we don't bother constructing an enum for that
        if not unchecked_value_list:
            return (
                NoneProperty.build(
                    name=name,
                    required=required,
                    default="None",
                    python_name=utils.PythonIdentifier(value=name, prefix=config.field_prefix),
                    description=None,
                    example=None,
                ),
                schemas,
            )

        # Every remaining value must share one exact type, and that type must be str or int.
        value_types = {type(value) for value in unchecked_value_list}
        if len(value_types) > 1:
            return PropertyError(
                header="Enum values must all be the same type", detail=f"Got {value_types}", data=data
            ), schemas
        value_type = next(iter(value_types))
        if value_type not in (str, int):
            return PropertyError(header=f"Unsupported enum type {value_type}", data=data), schemas
        value_list = cast(
            Union[list[int], list[str]], unchecked_value_list
        )  # We checked this with all the value_types stuff

        if len(value_list) < len(enum):  # At least one value was null, so this becomes a union with null
            data.oneOf = [
                oai.Schema(type=DataType.NULL),
                data.model_copy(update={"enum": value_list, "default": data.default}),
            ]
            data.enum = None
            return UnionProperty.build(
                data=data,
                name=name,
                required=required,
                schemas=schemas,
                parent_name=parent_name,
                config=config,
            )

        class_name = data.title or name
        if parent_name:
            class_name = f"{utils.pascal_case(parent_name)}{utils.pascal_case(class_name)}"

        # If the candidate name is already taken, append a numbered suffix until it is free.
        original_class_name = class_name
        name_index = 0
        while class_name in schemas.classes_by_name:
            name_index += 1
            class_name = f"{original_class_name}PropertyEnum_{name_index}"

        class_info = Class.from_string(string=class_name, config=config)
        values = EnumProperty.values_from_list(value_list, class_info)

        # The final class name may still be registered; that is only acceptable when the
        # existing entry is an enum with exactly the same members.
        if class_info.name in schemas.classes_by_name:
            existing = schemas.classes_by_name[class_info.name]
            if not isinstance(existing, EnumProperty) or values != existing.values:
                return (
                    PropertyError(
                        detail=f"Found conflicting enums named {class_info.name} with incompatible values.", data=data
                    ),
                    schemas,
                )

        # Build the property without a default first: `convert_value` needs the instance's
        # `values` and `class_info` in order to validate the schema's default.
        prop = EnumProperty(
            name=name,
            required=required,
            class_info=class_info,
            values=values,
            value_type=value_type,
            default=None,
            python_name=utils.PythonIdentifier(value=name, prefix=config.field_prefix),
            description=data.description,
            example=data.example,
        )
        checked_default = prop.convert_value(data.default)
        if isinstance(checked_default, PropertyError):
            checked_default.data = data
            return checked_default, schemas
        prop = evolve(prop, default=checked_default)

        schemas = evolve(schemas, classes_by_name={**schemas.classes_by_name, class_info.name: prop})
        return prop, schemas

    def convert_value(self, value: Any) -> Value | PropertyError | None:
        """
        Convert a raw default value into a `Value` referencing the matching enum member.

        Args:
            value: The raw value to convert. `None` and existing `Value` instances are returned unchanged.

        Returns:
            A `Value` whose code is `<EnumClass>.<MEMBER>`, the input itself when it is `None` or already a
            `Value`, or a PropertyError when the value has the wrong type or is not a member of this enum.
        """
        if value is None or isinstance(value, Value):
            return value
        if isinstance(value, self.value_type):
            # `values` maps member name -> raw value; invert it to look the member up by raw value.
            inverse_values = {v: k for k, v in self.values.items()}
            try:
                return Value(python_code=f"{self.class_info.name}.{inverse_values[value]}", raw_value=value)
            except KeyError:
                return PropertyError(detail=f"Value {value} is not valid for enum {self.name}")
        return PropertyError(detail=f"Cannot convert {value} to enum {self.name} of type {self.value_type}")

    def get_base_type_string(self, *, quoted: bool = False) -> str:
        """Return the name of the generated enum class (`quoted` is accepted but not used here)."""
        return self.class_info.name

    def get_base_json_type_string(self, *, quoted: bool = False) -> str:
        """Return the name of the members' underlying Python type (`quoted` is accepted but not used here)."""
        return self.value_type.__name__

    def get_imports(self, *, prefix: str) -> set[str]:
        """
        Get a set of import strings that should be included when this property is used somewhere

        Args:
            prefix: A prefix to put before any relative (local) module names. This should be the number of . to get
            back to the root of the generated client.
        """
        imports = super().get_imports(prefix=prefix)
        imports.add(f"from {prefix}models.{self.class_info.module_name} import {self.class_info.name}")
        return imports

    @staticmethod
    def values_from_list(values: list[str] | list[int], class_info: Class) -> dict[str, ValueType]:
        """
        Convert a list of values into a dict of {member name: value}.

        Integers become `VALUE_<n>` (or `VALUE_NEGATIVE_<n>` for negatives). Strings that start with a
        letter use their upper-cased, sanitized text as the name; any other string is named after its
        position in the list (`VALUE_<index>`).

        Args:
            values: The enum's raw values, all of one type.
            class_info: The enum class being generated; only used to build the error message.

        Returns:
            The mapping of member name to value, in input order.

        Raises:
            ValueError: If two string values produce the same member name.
        """
        output: dict[str, ValueType] = {}

        for i, value in enumerate(values):
            value = cast(Union[str, int], value)
            if isinstance(value, int):
                if value < 0:
                    output[f"VALUE_NEGATIVE_{-value}"] = value
                else:
                    output[f"VALUE_{value}"] = value
                continue
            if value and value[0].isalpha():
                key = value.upper()
            else:
                key = f"VALUE_{i}"
            sanitized_key = utils.snake_case(key).upper()
            # Check the name that is actually stored: distinct raw values can sanitize to the same
            # identifier, and silently overwriting would drop a member from the generated enum.
            if sanitized_key in output:
                raise ValueError(
                    f"Duplicate key {sanitized_key} in enum {class_info.module_name}.{class_info.name}; "
                    f"consider setting literal_enums in your config"
                )
            output[sanitized_key] = utils.remove_string_escapes(value)
        return output