Skip to content

Commit 996bc5a

Browse files
committed
add DjangoUnionType test
1 parent 9c1e4f2 commit 996bc5a

File tree

1 file changed

+51
-1
lines changed

1 file changed

+51
-1
lines changed

Diff for: graphene_django/tests/test_types.py

+51-1
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,15 @@
1111

1212
from .. import registry
1313
from ..filter import DjangoFilterConnectionField
14-
from ..types import DjangoObjectType, DjangoObjectTypeOptions
14+
from ..types import (
15+
DjangoObjectType,
16+
DjangoObjectTypeOptions,
17+
DjangoUnionType,
18+
)
1519
from .models import (
20+
APNewsReporter as APNewsReporterModel,
1621
Article as ArticleModel,
22+
CNNReporter as CNNReporterModel,
1723
Reporter as ReporterModel,
1824
)
1925

@@ -799,3 +805,47 @@ class Query(ObjectType):
799805
assert "type Reporter implements Node {" not in schema
800806
assert "type ReporterConnection {" not in schema
801807
assert "type ReporterEdge {" not in schema
808+
809+
810+
@with_local_registry
811+
def test_django_uniontype_name_connection_propagation():
812+
class CNNReporter(DjangoObjectType):
813+
class Meta:
814+
model = CNNReporterModel
815+
name = "CNNReporter"
816+
fields = "__all__"
817+
filter_fields = ["email"]
818+
interfaces = (Node,)
819+
820+
class APNewsReporter(DjangoObjectType):
821+
class Meta:
822+
model = APNewsReporterModel
823+
name = "APNewsReporter"
824+
fields = "__all__"
825+
filter_fields = ["email"]
826+
interfaces = (Node,)
827+
828+
class ReporterUnion(DjangoUnionType):
829+
class Meta:
830+
model = ReporterModel
831+
types = (CNNReporter, APNewsReporter)
832+
interfaces = (Node,)
833+
filter_fields = ("id", "first_name", "last_name")
834+
835+
@classmethod
836+
def resolve_type(cls, instance, info):
837+
if isinstance(instance, CNNReporterModel):
838+
return CNNReporter
839+
elif isinstance(instance, APNewsReporterModel):
840+
return APNewsReporter
841+
return None
842+
843+
class Query(ObjectType):
844+
reporter = Node.Field(ReporterUnion)
845+
reporters = DjangoFilterConnectionField(ReporterUnion)
846+
847+
schema = str(Schema(query=Query))
848+
849+
assert "union ReporterUnion = CNNReporter | APNewsReporter" in schema
850+
assert "CNNReporter implements Node" in schema
851+
assert "ReporterUnionConnection" in schema

0 commit comments

Comments
 (0)