1
1
"""
2
2
Support for Django model fields built from enumeration types.
3
3
"""
4
+ from enum import Enum
4
5
from typing import (
5
6
TYPE_CHECKING ,
6
7
Any ,
16
17
from django .db .models import (
17
18
BigIntegerField ,
18
19
CharField ,
19
- Choices ,
20
20
Field ,
21
21
FloatField ,
22
22
IntegerField ,
27
27
SmallIntegerField ,
28
28
)
29
29
from django .db .models .query_utils import DeferredAttribute
30
+ from django_enum .choices import choices , values
30
31
from django_enum .forms import EnumChoiceField , NonStrictSelect
31
32
32
33
T = TypeVar ('T' ) # pylint: disable=C0103
@@ -79,25 +80,25 @@ class EnumMixin(
79
80
field type.
80
81
"""
81
82
82
- enum : Optional [Type [Choices ]] = None
83
+ enum : Optional [Type [Enum ]] = None
83
84
strict : bool = True
84
85
coerce : bool = True
85
86
86
87
descriptor_class = ToPythonDeferredAttribute
87
88
88
- def _coerce_to_value_type (self , value : Any ) -> Choices :
89
+ def _coerce_to_value_type (self , value : Any ) -> Enum :
89
90
"""Coerce the value to the enumerations value type"""
90
91
# note if enum type is int and a floating point is passed we could get
91
92
# situations like X.xxx == X - this is acceptable
92
93
if self .enum :
93
- return type (self .enum . values [0 ])(value )
94
+ return type (values ( self .enum ) [0 ])(value )
94
95
# can't ever reach this - just here to make type checker happy
95
96
return value # pragma: no cover
96
97
97
98
def __init__ (
98
99
self ,
99
100
* args ,
100
- enum : Optional [Type [Choices ]] = None ,
101
+ enum : Optional [Type [Enum ]] = None ,
101
102
strict : bool = strict ,
102
103
coerce : bool = coerce ,
103
104
** kwargs
@@ -106,14 +107,14 @@ def __init__(
106
107
self .strict = strict if enum else False
107
108
self .coerce = coerce if enum else False
108
109
if self .enum is not None :
109
- kwargs .setdefault ('choices' , enum . choices if enum else [] )
110
+ kwargs .setdefault ('choices' , choices ( enum ) )
110
111
super ().__init__ (* args , ** kwargs )
111
112
112
113
def _try_coerce (
113
114
self ,
114
115
value : Any ,
115
116
force : bool = False
116
- ) -> Union [Choices , Any ]:
117
+ ) -> Union [Enum , Any ]:
117
118
"""
118
119
Attempt coercion of value to enumeration type instance, if unsuccessful
119
120
and non-strict, coercion to enum's primitive type will be done,
@@ -130,15 +131,18 @@ def _try_coerce(
130
131
try :
131
132
value = self ._coerce_to_value_type (value )
132
133
value = self .enum (value )
133
- except (TypeError , ValueError ) as err :
134
- if self .strict or not isinstance (
135
- value ,
136
- type (self .enum .values [0 ])
137
- ):
138
- raise ValueError (
139
- f"'{ value } ' is not a valid { self .enum .__name__ } "
140
- f"required by field { self .name } ."
141
- ) from err
134
+ except (TypeError , ValueError ):
135
+ try :
136
+ value = self .enum [value ]
137
+ except KeyError as err :
138
+ if self .strict or not isinstance (
139
+ value ,
140
+ type (values (self .enum )[0 ])
141
+ ):
142
+ raise ValueError (
143
+ f"'{ value } ' is not a valid { self .enum .__name__ } "
144
+ f"required by field { self .name } ."
145
+ ) from err
142
146
return value
143
147
144
148
def deconstruct (self ) -> Tuple [str , str , List , dict ]:
@@ -159,7 +163,7 @@ def deconstruct(self) -> Tuple[str, str, List, dict]:
159
163
"""
160
164
name , path , args , kwargs = super ().deconstruct ()
161
165
if self .enum is not None :
162
- kwargs ['choices' ] = self .enum . choices
166
+ kwargs ['choices' ] = choices ( self .enum )
163
167
164
168
if 'default' in kwargs :
165
169
# ensure default in deconstructed fields is always the primitive
@@ -216,7 +220,7 @@ def from_db_value(
216
220
return value
217
221
return self ._try_coerce (value )
218
222
219
- def to_python (self , value : Any ) -> Union [Choices , Any ]:
223
+ def to_python (self , value : Any ) -> Union [Enum , Any ]:
220
224
"""
221
225
Converts the value in the enumeration type.
222
226
@@ -301,10 +305,12 @@ class EnumCharField(EnumMixin, CharField):
301
305
"""
302
306
303
307
def __init__ (self , * args , enum = None , ** kwargs ):
304
- choices = kwargs .get ('choices' , enum .choices if enum else [])
305
308
kwargs .setdefault (
306
309
'max_length' ,
307
- max ((len (choice [0 ]) for choice in choices ))
310
+ max ((
311
+ len (choice [0 ])
312
+ for choice in kwargs .get ('choices' , choices (enum ))
313
+ ))
308
314
)
309
315
super ().__init__ (* args , enum = enum , ** kwargs )
310
316
@@ -361,7 +367,7 @@ class _EnumFieldMetaClass(type):
361
367
362
368
def __new__ ( # pylint: disable=R0911
363
369
mcs ,
364
- enum : Type [Choices ]
370
+ enum : Type [Enum ]
365
371
) -> Type [EnumMixin ]:
366
372
"""
367
373
Construct a new Django Field class given the Enumeration class. The
@@ -370,23 +376,21 @@ def __new__( # pylint: disable=R0911
370
376
371
377
:param enum: The class of the Enumeration to build a field class for
372
378
"""
373
- assert issubclass (enum , Choices ), \
374
- f'{ enum } must inherit from { Choices } !'
375
379
primitives = mcs .SUPPORTED_PRIMITIVES .intersection (set (enum .__mro__ ))
376
- assert len ( primitives ) == 1 , f' { enum } must inherit from exactly one ' \
377
- f'supported primitive type ' \
378
- f' { mcs . SUPPORTED_PRIMITIVES } , ' \
379
- f'encountered: { primitives } .'
380
-
381
- primitive = list ( primitives )[ 0 ]
380
+ primitive = (
381
+ list ( primitives )[ 0 ] if primitives else type ( values ( enum )[ 0 ])
382
+ )
383
+ assert primitive in mcs . SUPPORTED_PRIMITIVES , \
384
+ f'Enum { enum } has values of an unnsupported primitive type: ' \
385
+ f' { primitive } '
382
386
383
387
if primitive is float :
384
388
return EnumFloatField
385
389
386
390
if primitive is int :
387
- values = [define .value for define in enum ]
388
- min_value = min (values )
389
- max_value = max (values )
391
+ vals = [define .value for define in enum ]
392
+ min_value = min (vals )
393
+ max_value = max (vals )
390
394
if min_value < 0 :
391
395
if min_value < - 2147483648 or max_value > 2147483647 :
392
396
return EnumBigIntegerField
@@ -404,7 +408,7 @@ def __new__( # pylint: disable=R0911
404
408
405
409
406
410
def EnumField ( # pylint: disable=C0103
407
- enum : Type [Choices ],
411
+ enum : Type [Enum ],
408
412
* field_args ,
409
413
** field_kwargs
410
414
) -> EnumMixin :
0 commit comments