Skip to content

Commit 9c1e4f2

Browse files
committed
Feature: DjangoUnionType
1 parent fbe5603 commit 9c1e4f2

File tree

3 files changed

+138
-6
lines changed

3 files changed

+138
-6
lines changed

Diff for: graphene_django/fields.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -92,16 +92,16 @@ def __init__(self, *args, **kwargs):
9292

9393
@property
9494
def type(self):
95-
from .types import DjangoObjectType
95+
from .types import DjangoObjectType, DjangoUnionType
9696

9797
_type = super(ConnectionField, self).type
9898
non_null = False
9999
if isinstance(_type, NonNull):
100100
_type = _type.of_type
101101
non_null = True
102102
assert issubclass(
103-
_type, DjangoObjectType
104-
), "DjangoConnectionField only accepts DjangoObjectType types"
103+
_type, (DjangoObjectType, DjangoUnionType)
104+
), "DjangoConnectionField only accepts DjangoObjectType or DjangoUnionType types"
105105
assert _type._meta.connection, "The type {} doesn't have a connection".format(
106106
_type.__name__
107107
)

Diff for: graphene_django/registry.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ def __init__(self):
44
self._field_registry = {}
55

66
def register(self, cls):
7-
from .types import DjangoObjectType
7+
from .types import DjangoObjectType, DjangoUnionType
88

99
assert issubclass(
10-
cls, DjangoObjectType
11-
), f'Only DjangoObjectTypes can be registered, received "{cls.__name__}"'
10+
cls, (DjangoObjectType, DjangoUnionType)
11+
), f'Only DjangoObjectTypes or DjangoUnionType can be registered, received "{cls.__name__}"'
1212
assert cls._meta.registry == self, "Registry for a Model have to match."
1313
# assert self.get_type_for_model(cls._meta.model) == cls, (
1414
# 'Multiple DjangoObjectTypes registered for "{}"'.format(cls._meta.model)

Diff for: graphene_django/types.py

+132
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import graphene
88
from graphene.relay import Connection, Node
99
from graphene.types.objecttype import ObjectType, ObjectTypeOptions
10+
from graphene.types.union import Union
1011
from graphene.types.utils import yank_fields_from_attrs
1112

1213
from .converter import convert_django_field_with_choices
@@ -293,6 +294,137 @@ def get_node(cls, info, id):
293294
return None
294295

295296

297+
class DjangoUnionTypeOptions(ObjectTypeOptions):
298+
model = None # type: Type[Model]
299+
registry = None # type: Registry
300+
connection = None # type: Type[Connection]
301+
302+
filter_fields = ()
303+
filterset_class = None
304+
305+
306+
class DjangoUnionType(Union):
307+
"""
308+
A Django specific Union type that allows to map multiple Django object types
309+
One use case is to handle polymorphic relationships for a Django model, using a library like django-polymorphic.
310+
311+
Can be used in combination with DjangoConnectionField and DjangoFilterConnectionField
312+
313+
Args:
314+
Meta (class): The meta class of the union type
315+
model (Model): The Django model that represents the union type
316+
types (tuple): A tuple of DjangoObjectType classes that represent the possible types of the union
317+
318+
Example:
319+
```python
320+
from graphene_django.types import DjangoObjectType, DjangoUnionType
321+
322+
class AssessmentUnion(DjangoUnionType):
323+
class Meta:
324+
model = Assessment
325+
types = (HomeworkAssessmentNode, QuizAssessmentNode)
326+
interfaces = (graphene.relay.Node,)
327+
filter_fields = ("id", "title", "description")
328+
329+
@classmethod
330+
def resolve_type(cls, instance, info):
331+
if isinstance(instance, HomeworkAssessment):
332+
return HomeworkAssessmentNode
333+
elif isinstance(instance, QuizAssessment):
334+
return QuizAssessmentNode
335+
336+
class Query(graphene.ObjectType):
337+
all_assessments = DjangoFilterConnectionField(AssessmentUnion)
338+
```
339+
"""
340+
341+
class Meta:
342+
abstract = True
343+
344+
@classmethod
345+
def __init_subclass_with_meta__(
346+
cls,
347+
model=None,
348+
types=None,
349+
registry=None,
350+
skip_registry=False,
351+
_meta=None,
352+
fields=None,
353+
exclude=None,
354+
convert_choices_to_enum=None,
355+
filter_fields=None,
356+
filterset_class=None,
357+
connection=None,
358+
connection_class=None,
359+
use_connection=None,
360+
interfaces=(),
361+
**options,
362+
):
363+
django_fields = yank_fields_from_attrs(
364+
construct_fields(model, registry, fields, exclude, convert_choices_to_enum),
365+
_as=graphene.Field,
366+
)
367+
368+
if use_connection is None and interfaces:
369+
use_connection = any(
370+
issubclass(interface, Node) for interface in interfaces
371+
)
372+
373+
if not registry:
374+
registry = get_global_registry()
375+
376+
assert isinstance(registry, Registry), (
377+
f"The attribute registry in {cls.__name__} needs to be an instance of "
378+
f'Registry, received "{registry}".'
379+
)
380+
381+
if filter_fields and filterset_class:
382+
raise Exception("Can't set both filter_fields and filterset_class")
383+
384+
if not DJANGO_FILTER_INSTALLED and (filter_fields or filterset_class):
385+
raise Exception(
386+
"Can only set filter_fields or filterset_class if "
387+
"Django-Filter is installed"
388+
)
389+
390+
if not _meta:
391+
_meta = DjangoUnionTypeOptions(cls)
392+
393+
_meta.model = model
394+
_meta.types = types
395+
_meta.fields = django_fields
396+
_meta.filter_fields = filter_fields
397+
_meta.filterset_class = filterset_class
398+
_meta.registry = registry
399+
400+
if use_connection and not connection:
401+
# We create the connection automatically
402+
if not connection_class:
403+
connection_class = Connection
404+
405+
connection = connection_class.create_type(
406+
"{}Connection".format(options.get("name") or cls.__name__), node=cls
407+
)
408+
409+
if connection is not None:
410+
assert issubclass(
411+
connection, Connection
412+
), f"The connection must be a Connection. Received {connection.__name__}"
413+
414+
_meta.connection = connection
415+
416+
super().__init_subclass_with_meta__(
417+
types=types, _meta=_meta, interfaces=interfaces, **options
418+
)
419+
420+
if not skip_registry:
421+
registry.register(cls)
422+
423+
@classmethod
424+
def get_queryset(cls, queryset, info):
425+
return queryset
426+
427+
296428
class ErrorType(ObjectType):
297429
field = graphene.String(required=True)
298430
messages = graphene.List(graphene.NonNull(graphene.String), required=True)

0 commit comments

Comments
 (0)