Skip to content

Commit 5bfecc8

Browse files
committed
closes #34
1 parent d9bec6e commit 5bfecc8

23 files changed

+475
-128
lines changed

README.rst

+1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ was to:
4343

4444
* Always automatically coerce fields to instances of the Enum type.
4545
* Allow strict adherence to Enum values to be disabled.
46+
* Be compatible with Enum classes that do not derive from Django's Choices.
4647
* Handle migrations appropriately. (See `migrations <https://django-enum.readthedocs.io/en/latest/usage.html#migrations>`_)
4748
* Integrate as fully as possible with Django_'s existing level of enum support.
4849
* Integrate with `enum-properties <https://pypi.org/project/enum-properties/>`_ to enable richer enumeration types.

django_enum/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
'EnumFilter'
4848
]
4949

50-
VERSION = (1, 1, 2)
50+
VERSION = (1, 2, 0)
5151

5252
__title__ = 'Django Enum'
5353
__version__ = '.'.join(str(i) for i in VERSION)

django_enum/choices.py

+75-1
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,86 @@
33
types. These choices types are drop in replacements for the Django
44
IntegerChoices and TextChoices.
55
"""
6+
from enum import Enum
7+
from typing import Any, List, Optional, Tuple, Type
8+
69
from django.db.models import Choices
710
from django.db.models import IntegerChoices as DjangoIntegerChoices
811
from django.db.models import TextChoices as DjangoTextChoices
912
from django.db.models.enums import ChoicesMeta
1013

14+
15+
def choices(enum: Optional[Type[Enum]]) -> List[Tuple[Any, str]]:
16+
"""
17+
Get the Django choices for an enumeration type. If the enum type has a
18+
choices attribute, it will be used. Otherwise, the choices will be derived
19+
from value, label pairs if the enumeration type has a label attribute, or
20+
the name attribute if it does not.
21+
22+
This is used for compat with enums that do not inherit from Django's
23+
Choices type.
24+
25+
:param enum: The enumeration type
26+
:return: A list of (value, label) pairs
27+
"""
28+
return getattr(
29+
enum,
30+
'choices', [
31+
*([(None, enum.__empty__)] if hasattr(enum, '__empty__') else []),
32+
*[
33+
(
34+
member.value,
35+
getattr(member, 'label', getattr(member, 'name'))
36+
)
37+
for member in enum
38+
]
39+
]
40+
) if enum else []
41+
42+
43+
def names(enum: Optional[Type[Enum]]) -> List[Any]:
44+
"""
45+
Return a list of names to use for the enumeration type. This is used
46+
for compat with enums that do not inherit from Django's Choices type.
47+
48+
:param enum: The enumeration type
49+
:return: A list of labels
50+
"""
51+
return getattr(
52+
enum,
53+
'names', [
54+
*(['__empty__'] if hasattr(enum, '__empty__') else []),
55+
*[member.name for member in enum]
56+
]
57+
) if enum else []
58+
59+
60+
def labels(enum: Optional[Type[Enum]]) -> List[Any]:
61+
"""
62+
Return a list of labels to use for the enumeration type. See choices.
63+
64+
This is used for compat with enums that do not inherit from Django's
65+
Choices type.
66+
67+
:param enum: The enumeration type
68+
:return: A list of labels
69+
"""
70+
return getattr(enum, 'labels', [label for _, label in choices(enum)])
71+
72+
73+
def values(enum: Optional[Type[Enum]]) -> List[Any]:
74+
"""
75+
Return a list of the values of an enumeration type.
76+
77+
This is used for compat with enums that do not inherit from Django's
78+
Choices type.
79+
80+
:param enum: The enumeration type
81+
:return: A list of values
82+
"""
83+
return getattr(enum, 'values', [value for value, _ in choices(enum)])
84+
85+
1186
try:
1287
from enum_properties import EnumPropertiesMeta, SymmetricMixin
1388

@@ -63,7 +138,6 @@ class FloatChoices(
63138
"""
64139

65140
except (ImportError, ModuleNotFoundError):
66-
from enum import Enum
67141

68142
# 3.11 - extend from Enum so base type check does not throw type error
69143
class MissingEnumProperties(Enum):

django_enum/drf.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
__all__ = ['EnumField']
44

55
try:
6+
from enum import Enum
67
from typing import Any, Type, Union
78

8-
from django.db.models import Choices
9+
from django_enum.choices import choices, values
910
from rest_framework.fields import ChoiceField
1011

1112
class EnumField(ChoiceField):
@@ -23,21 +24,21 @@ class EnumField(ChoiceField):
2324
will be passed up to the base classes.
2425
"""
2526

26-
enum: Type[Choices]
27+
enum: Type[Enum]
2728
strict: bool = True
2829

2930
def __init__(
3031
self,
31-
enum: Type[Choices],
32+
enum: Type[Enum],
3233
strict: bool = strict,
3334
**kwargs
3435
):
3536
self.enum = enum
3637
self.strict = strict
37-
self.choices = kwargs.pop('choices', enum.choices)
38+
self.choices = kwargs.pop('choices', choices(enum))
3839
super().__init__(choices=self.choices, **kwargs)
3940

40-
def to_internal_value(self, data: Any) -> Union[Choices, Any]:
41+
def to_internal_value(self, data: Any) -> Union[Enum, Any]:
4142
"""
4243
Transform the *incoming* primitive data into an enum instance.
4344
"""
@@ -49,12 +50,12 @@ def to_internal_value(self, data: Any) -> Union[Choices, Any]:
4950
data = self.enum(data)
5051
except (TypeError, ValueError):
5152
try:
52-
data = type(self.enum.values[0])(data)
53+
data = type(values(self.enum)[0])(data)
5354
data = self.enum(data)
5455
except (TypeError, ValueError):
5556
if self.strict or not isinstance(
5657
data,
57-
type(self.enum.values[0])
58+
type(values(self.enum)[0])
5859
):
5960
self.fail('invalid_choice', input=data)
6061
return data

django_enum/fields.py

+37-33
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Support for Django model fields built from enumeration types.
33
"""
4+
from enum import Enum
45
from typing import (
56
TYPE_CHECKING,
67
Any,
@@ -16,7 +17,6 @@
1617
from django.db.models import (
1718
BigIntegerField,
1819
CharField,
19-
Choices,
2020
Field,
2121
FloatField,
2222
IntegerField,
@@ -27,6 +27,7 @@
2727
SmallIntegerField,
2828
)
2929
from django.db.models.query_utils import DeferredAttribute
30+
from django_enum.choices import choices, values
3031
from django_enum.forms import EnumChoiceField, NonStrictSelect
3132

3233
T = TypeVar('T') # pylint: disable=C0103
@@ -79,25 +80,25 @@ class EnumMixin(
7980
field type.
8081
"""
8182

82-
enum: Optional[Type[Choices]] = None
83+
enum: Optional[Type[Enum]] = None
8384
strict: bool = True
8485
coerce: bool = True
8586

8687
descriptor_class = ToPythonDeferredAttribute
8788

88-
def _coerce_to_value_type(self, value: Any) -> Choices:
89+
def _coerce_to_value_type(self, value: Any) -> Enum:
8990
"""Coerce the value to the enumerations value type"""
9091
# note if enum type is int and a floating point is passed we could get
9192
# situations like X.xxx == X - this is acceptable
9293
if self.enum:
93-
return type(self.enum.values[0])(value)
94+
return type(values(self.enum)[0])(value)
9495
# can't ever reach this - just here to make type checker happy
9596
return value # pragma: no cover
9697

9798
def __init__(
9899
self,
99100
*args,
100-
enum: Optional[Type[Choices]] = None,
101+
enum: Optional[Type[Enum]] = None,
101102
strict: bool = strict,
102103
coerce: bool = coerce,
103104
**kwargs
@@ -106,14 +107,14 @@ def __init__(
106107
self.strict = strict if enum else False
107108
self.coerce = coerce if enum else False
108109
if self.enum is not None:
109-
kwargs.setdefault('choices', enum.choices if enum else [])
110+
kwargs.setdefault('choices', choices(enum))
110111
super().__init__(*args, **kwargs)
111112

112113
def _try_coerce(
113114
self,
114115
value: Any,
115116
force: bool = False
116-
) -> Union[Choices, Any]:
117+
) -> Union[Enum, Any]:
117118
"""
118119
Attempt coercion of value to enumeration type instance, if unsuccessful
119120
and non-strict, coercion to enum's primitive type will be done,
@@ -130,15 +131,18 @@ def _try_coerce(
130131
try:
131132
value = self._coerce_to_value_type(value)
132133
value = self.enum(value)
133-
except (TypeError, ValueError) as err:
134-
if self.strict or not isinstance(
135-
value,
136-
type(self.enum.values[0])
137-
):
138-
raise ValueError(
139-
f"'{value}' is not a valid {self.enum.__name__} "
140-
f"required by field {self.name}."
141-
) from err
134+
except (TypeError, ValueError):
135+
try:
136+
value = self.enum[value]
137+
except KeyError as err:
138+
if self.strict or not isinstance(
139+
value,
140+
type(values(self.enum)[0])
141+
):
142+
raise ValueError(
143+
f"'{value}' is not a valid {self.enum.__name__} "
144+
f"required by field {self.name}."
145+
) from err
142146
return value
143147

144148
def deconstruct(self) -> Tuple[str, str, List, dict]:
@@ -159,7 +163,7 @@ def deconstruct(self) -> Tuple[str, str, List, dict]:
159163
"""
160164
name, path, args, kwargs = super().deconstruct()
161165
if self.enum is not None:
162-
kwargs['choices'] = self.enum.choices
166+
kwargs['choices'] = choices(self.enum)
163167

164168
if 'default' in kwargs:
165169
# ensure default in deconstructed fields is always the primitive
@@ -216,7 +220,7 @@ def from_db_value(
216220
return value
217221
return self._try_coerce(value)
218222

219-
def to_python(self, value: Any) -> Union[Choices, Any]:
223+
def to_python(self, value: Any) -> Union[Enum, Any]:
220224
"""
221225
Converts the value in the enumeration type.
222226
@@ -301,10 +305,12 @@ class EnumCharField(EnumMixin, CharField):
301305
"""
302306

303307
def __init__(self, *args, enum=None, **kwargs):
304-
choices = kwargs.get('choices', enum.choices if enum else [])
305308
kwargs.setdefault(
306309
'max_length',
307-
max((len(choice[0]) for choice in choices))
310+
max((
311+
len(choice[0])
312+
for choice in kwargs.get('choices', choices(enum))
313+
))
308314
)
309315
super().__init__(*args, enum=enum, **kwargs)
310316

@@ -361,7 +367,7 @@ class _EnumFieldMetaClass(type):
361367

362368
def __new__( # pylint: disable=R0911
363369
mcs,
364-
enum: Type[Choices]
370+
enum: Type[Enum]
365371
) -> Type[EnumMixin]:
366372
"""
367373
Construct a new Django Field class given the Enumeration class. The
@@ -370,23 +376,21 @@ def __new__( # pylint: disable=R0911
370376
371377
:param enum: The class of the Enumeration to build a field class for
372378
"""
373-
assert issubclass(enum, Choices), \
374-
f'{enum} must inherit from {Choices}!'
375379
primitives = mcs.SUPPORTED_PRIMITIVES.intersection(set(enum.__mro__))
376-
assert len(primitives) == 1, f'{enum} must inherit from exactly one ' \
377-
f'supported primitive type ' \
378-
f'{mcs.SUPPORTED_PRIMITIVES}, ' \
379-
f'encountered: {primitives}.'
380-
381-
primitive = list(primitives)[0]
380+
primitive = (
381+
list(primitives)[0] if primitives else type(values(enum)[0])
382+
)
383+
assert primitive in mcs.SUPPORTED_PRIMITIVES, \
384+
f'Enum {enum} has values of an unnsupported primitive type: ' \
385+
f'{primitive}'
382386

383387
if primitive is float:
384388
return EnumFloatField
385389

386390
if primitive is int:
387-
values = [define.value for define in enum]
388-
min_value = min(values)
389-
max_value = max(values)
391+
vals = [define.value for define in enum]
392+
min_value = min(vals)
393+
max_value = max(vals)
390394
if min_value < 0:
391395
if min_value < -2147483648 or max_value > 2147483647:
392396
return EnumBigIntegerField
@@ -404,7 +408,7 @@ def __new__( # pylint: disable=R0911
404408

405409

406410
def EnumField( # pylint: disable=C0103
407-
enum: Type[Choices],
411+
enum: Type[Enum],
408412
*field_args,
409413
**field_kwargs
410414
) -> EnumMixin:

django_enum/filters.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from django.db.models import Field as ModelField
66
from django.forms.fields import Field as FormField
7+
from django_enum.choices import choices
78
from django_enum.forms import EnumChoiceField
89

910
try:
@@ -43,7 +44,7 @@ def __init__(self, *, enum, strict=False, **kwargs):
4344
self.enum = enum
4445
super().__init__(
4546
enum=enum,
46-
choices=kwargs.pop('choices', self.enum.choices),
47+
choices=kwargs.pop('choices', choices(self.enum)),
4748
strict=strict,
4849
**kwargs
4950
)
@@ -61,7 +62,7 @@ def filter_for_lookup(
6162
lookup_type: str
6263
) -> Tuple[Type[Filter], dict]:
6364
"""For EnumFields use the EnumFilter class by default"""
64-
if hasattr(field, 'enum') and hasattr(field.enum, 'choices'):
65+
if hasattr(field, 'enum'):
6566
return EnumFilter, {
6667
'enum': field.enum,
6768
'strict': getattr(field, 'strict', False)

0 commit comments

Comments
 (0)