Correctly propagate NonNull to inner connection type

This commit is contained in:
Alexandre Kirszenberg 2019-03-30 19:38:20 +01:00
parent 3d493c3bd9
commit 8beadc759f

View File

@ -1,6 +1,7 @@
from functools import partial from functools import partial
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
from graphene import NonNull
from promise import Promise from promise import Promise
@ -45,17 +46,31 @@ class DjangoConnectionField(ConnectionField):
from .types import DjangoObjectType from .types import DjangoObjectType
_type = super(ConnectionField, self).type _type = super(ConnectionField, self).type
non_null = False
if isinstance(_type, NonNull):
_type = _type.of_type
non_null = True
assert issubclass( assert issubclass(
_type, DjangoObjectType _type, DjangoObjectType
), "DjangoConnectionField only accepts DjangoObjectType types" ), "DjangoConnectionField only accepts DjangoObjectType types"
assert _type._meta.connection, "The type {} doesn't have a connection".format( assert _type._meta.connection, "The type {} doesn't have a connection".format(
_type.__name__ _type.__name__
) )
return _type._meta.connection connection_type = _type._meta.connection
if non_null:
return NonNull(connection_type)
return connection_type
@property
def connection_type(self):
type = self.type
if isinstance(type, NonNull):
return type.of_type
return type
@property @property
def node_type(self): def node_type(self):
return self.type._meta.node return self.connection_type._meta.node
@property @property
def model(self): def model(self):
@ -103,15 +118,15 @@ class DjangoConnectionField(ConnectionField):
@classmethod @classmethod
def connection_resolver( def connection_resolver(
cls, cls,
resolver, resolver,
connection, connection,
default_manager, default_manager,
max_limit, max_limit,
enforce_first_or_last, enforce_first_or_last,
root, root,
info, info,
**args **args
): ):
first = args.get("first") first = args.get("first")
last = args.get("last") last = args.get("last")
@ -146,7 +161,7 @@ class DjangoConnectionField(ConnectionField):
return partial( return partial(
self.connection_resolver, self.connection_resolver,
parent_resolver, parent_resolver,
self.type, self.connection_type,
self.get_manager(), self.get_manager(),
self.max_limit, self.max_limit,
self.enforce_first_or_last, self.enforce_first_or_last,