diff --git a/graphene_django/__init__.py b/graphene_django/__init__.py index 6cc8734..abc81c4 100644 --- a/graphene_django/__init__.py +++ b/graphene_django/__init__.py @@ -1,5 +1,5 @@ from .fields import DjangoConnectionField, DjangoListField -from .types import DjangoObjectType +from .types import DjangoObjectType, DjangoUnionType from .utils import bypass_get_queryset __version__ = "3.2.3" @@ -7,6 +7,7 @@ __version__ = "3.2.3" __all__ = [ "__version__", "DjangoObjectType", + "DjangoUnionType", "DjangoListField", "DjangoConnectionField", "bypass_get_queryset", diff --git a/graphene_django/fields.py b/graphene_django/fields.py index 9e9f457..125b0a9 100644 --- a/graphene_django/fields.py +++ b/graphene_django/fields.py @@ -92,7 +92,7 @@ class DjangoConnectionField(ConnectionField): @property def type(self): - from .types import DjangoObjectType + from .types import DjangoObjectType, DjangoUnionType _type = super(ConnectionField, self).type non_null = False @@ -100,8 +100,8 @@ class DjangoConnectionField(ConnectionField): _type = _type.of_type non_null = True assert issubclass( - _type, DjangoObjectType - ), "DjangoConnectionField only accepts DjangoObjectType types" + _type, (DjangoObjectType, DjangoUnionType) + ), "DjangoConnectionField only accepts DjangoObjectType or DjangoUnionType types" assert _type._meta.connection, "The type {} doesn't have a connection".format( _type.__name__ ) diff --git a/graphene_django/registry.py b/graphene_django/registry.py index 900feeb..83230c7 100644 --- a/graphene_django/registry.py +++ b/graphene_django/registry.py @@ -4,11 +4,11 @@ class Registry: self._field_registry = {} def register(self, cls): - from .types import DjangoObjectType + from .types import DjangoObjectType, DjangoUnionType assert issubclass( - cls, DjangoObjectType - ), f'Only DjangoObjectTypes can be registered, received "{cls.__name__}"' + cls, (DjangoObjectType, DjangoUnionType) + ), f'Only DjangoObjectTypes or DjangoUnionType can be registered, received "{cls.__name__}"' assert cls._meta.registry == self, "Registry for a Model have to match." # assert self.get_type_for_model(cls._meta.model) == cls, ( # 'Multiple DjangoObjectTypes registered for "{}"'.format(cls._meta.model) diff --git a/graphene_django/tests/test_types.py b/graphene_django/tests/test_types.py index 0491bcd..ec17fc2 100644 --- a/graphene_django/tests/test_types.py +++ b/graphene_django/tests/test_types.py @@ -11,9 +11,15 @@ from graphene.relay import Node from .. import registry from ..filter import DjangoFilterConnectionField -from ..types import DjangoObjectType, DjangoObjectTypeOptions +from ..types import ( + DjangoObjectType, + DjangoObjectTypeOptions, + DjangoUnionType, +) from .models import ( + APNewsReporter as APNewsReporterModel, Article as ArticleModel, + CNNReporter as CNNReporterModel, Reporter as ReporterModel, ) @@ -832,3 +838,47 @@ def test_django_objecttype_name_connection_propagation(): assert "type Reporter implements Node {" not in schema assert "type ReporterConnection {" not in schema assert "type ReporterEdge {" not in schema + + +@with_local_registry +def test_django_uniontype_name_connection_propagation(): + class CNNReporter(DjangoObjectType): + class Meta: + model = CNNReporterModel + name = "CNNReporter" + fields = "__all__" + filter_fields = ["email"] + interfaces = (Node,) + + class APNewsReporter(DjangoObjectType): + class Meta: + model = APNewsReporterModel + name = "APNewsReporter" + fields = "__all__" + filter_fields = ["email"] + interfaces = (Node,) + + class ReporterUnion(DjangoUnionType): + class Meta: + model = ReporterModel + types = (CNNReporter, APNewsReporter) + interfaces = (Node,) + filter_fields = ("id", "first_name", "last_name") + + @classmethod + def resolve_type(cls, instance, info): + if isinstance(instance, CNNReporterModel): + return CNNReporter + elif isinstance(instance, APNewsReporterModel): + return APNewsReporter + return None + + class Query(ObjectType): + reporter = Node.Field(ReporterUnion) + reporters = DjangoFilterConnectionField(ReporterUnion) + + schema = str(Schema(query=Query)) + + assert "union ReporterUnion = CNNReporter | APNewsReporter" in schema + assert "CNNReporter implements Node" in schema + assert "ReporterUnionConnection" in schema diff --git a/graphene_django/types.py b/graphene_django/types.py index e310fe4..329ecf9 100644 --- a/graphene_django/types.py +++ b/graphene_django/types.py @@ -7,6 +7,7 @@ from django.db.models import Model # noqa: F401 import graphene from graphene.relay import Connection, Node from graphene.types.objecttype import ObjectType, ObjectTypeOptions +from graphene.types.union import Union, UnionOptions from graphene.types.utils import yank_fields_from_attrs from .converter import convert_django_field_with_choices @@ -293,6 +294,137 @@ class DjangoObjectType(ObjectType): return None +class DjangoUnionTypeOptions(UnionOptions, ObjectTypeOptions): + model = None # type: Type[Model] + registry = None # type: Registry + connection = None # type: Type[Connection] + + filter_fields = () + filterset_class = None + + +class DjangoUnionType(Union): + """ + A Django specific Union type that allows to map multiple Django object types + One use case is to handle polymorphic relationships for a Django model, using a library like django-polymorphic. + + Can be used in combination with DjangoConnectionField and DjangoFilterConnectionField + + Args: + Meta (class): The meta class of the union type + model (Model): The Django model that represents the union type + types (tuple): A tuple of DjangoObjectType classes that represent the possible types of the union + + Example: + ```python + from graphene_django.types import DjangoObjectType, DjangoUnionType + + class AssessmentUnion(DjangoUnionType): + class Meta: + model = Assessment + types = (HomeworkAssessmentNode, QuizAssessmentNode) + interfaces = (graphene.relay.Node,) + filter_fields = ("id", "title", "description") + + @classmethod + def resolve_type(cls, instance, info): + if isinstance(instance, HomeworkAssessment): + return HomeworkAssessmentNode + elif isinstance(instance, QuizAssessment): + return QuizAssessmentNode + + class Query(graphene.ObjectType): + all_assessments = DjangoFilterConnectionField(AssessmentUnion) + ``` + """ + + class Meta: + abstract = True + + @classmethod + def __init_subclass_with_meta__( + cls, + model=None, + types=None, + registry=None, + skip_registry=False, + _meta=None, + fields=None, + exclude=None, + convert_choices_to_enum=None, + filter_fields=None, + filterset_class=None, + connection=None, + connection_class=None, + use_connection=None, + interfaces=(), + **options, + ): + django_fields = yank_fields_from_attrs( + construct_fields(model, registry, fields, exclude, convert_choices_to_enum), + _as=graphene.Field, + ) + + if use_connection is None and interfaces: + use_connection = any( + issubclass(interface, Node) for interface in interfaces + ) + + if not registry: + registry = get_global_registry() + + assert isinstance(registry, Registry), ( + f"The attribute registry in {cls.__name__} needs to be an instance of " + f'Registry, received "{registry}".' + ) + + if filter_fields and filterset_class: + raise Exception("Can't set both filter_fields and filterset_class") + + if not DJANGO_FILTER_INSTALLED and (filter_fields or filterset_class): + raise Exception( + "Can only set filter_fields or filterset_class if " + "Django-Filter is installed" + ) + + if not _meta: + _meta = DjangoUnionTypeOptions(cls) + + _meta.model = model + _meta.types = types + _meta.fields = django_fields + _meta.filter_fields = filter_fields + _meta.filterset_class = filterset_class + _meta.registry = registry + + if use_connection and not connection: + # We create the connection automatically + if not connection_class: + connection_class = Connection + + connection = connection_class.create_type( + "{}Connection".format(options.get("name") or cls.__name__), node=cls + ) + + if connection is not None: + assert issubclass( + connection, Connection + ), f"The connection must be a Connection. Received {connection.__name__}" + + _meta.connection = connection + + super().__init_subclass_with_meta__( + types=types, _meta=_meta, interfaces=interfaces, **options + ) + + if not skip_registry: + registry.register(cls) + + @classmethod + def get_queryset(cls, queryset, info): + return queryset + + class ErrorType(ObjectType): field = graphene.String(required=True) messages = graphene.List(graphene.NonNull(graphene.String), required=True)