From c80d1317d606b3b298ba4da1be8f96f0cea580bf Mon Sep 17 00:00:00 2001 From: Muhammed Aldulaimi Date: Sat, 9 Nov 2024 17:29:03 +0300 Subject: [PATCH] add DjangoUnionType test --- graphene_django/tests/test_types.py | 52 ++++++++++++++++++++++++++++- 1 file changed, 51 insertions(+), 1 deletion(-) diff --git a/graphene_django/tests/test_types.py b/graphene_django/tests/test_types.py index 72514d2..4ad92c1 100644 --- a/graphene_django/tests/test_types.py +++ b/graphene_django/tests/test_types.py @@ -11,9 +11,15 @@ from graphene.relay import Node from .. import registry from ..filter import DjangoFilterConnectionField -from ..types import DjangoObjectType, DjangoObjectTypeOptions +from ..types import ( + DjangoObjectType, + DjangoObjectTypeOptions, + DjangoUnionType, +) from .models import ( + APNewsReporter as APNewsReporterModel, Article as ArticleModel, + CNNReporter as CNNReporterModel, Reporter as ReporterModel, ) @@ -799,3 +805,47 @@ def test_django_objecttype_name_connection_propagation(): assert "type Reporter implements Node {" not in schema assert "type ReporterConnection {" not in schema assert "type ReporterEdge {" not in schema + + +@with_local_registry +def test_django_uniontype_name_connection_propagation(): + class CNNReporter(DjangoObjectType): + class Meta: + model = CNNReporterModel + name = "CNNReporter" + fields = "__all__" + filter_fields = ["email"] + interfaces = (Node,) + + class APNewsReporter(DjangoObjectType): + class Meta: + model = APNewsReporterModel + name = "APNewsReporter" + fields = "__all__" + filter_fields = ["email"] + interfaces = (Node,) + + class ReporterUnion(DjangoUnionType): + class Meta: + model = ReporterModel + types = (CNNReporter, APNewsReporter) + interfaces = (Node,) + filter_fields = ("id", "first_name", "last_name") + + @classmethod + def resolve_type(cls, instance, info): + if isinstance(instance, CNNReporterModel): + return CNNReporter + elif isinstance(instance, APNewsReporterModel): + return APNewsReporter + return None + + class Query(ObjectType): + reporter = Node.Field(ReporterUnion) + reporters = DjangoFilterConnectionField(ReporterUnion) + + schema = str(Schema(query=Query)) + + assert "union ReporterUnion = CNNReporter | APNewsReporter" in schema + assert "CNNReporter implements Node" in schema + assert "ReporterUnionConnection" in schema