Merge pull request #673 from jkimbo/relay-connection-required

Fix bug when setting a Relay ConnectionField to be required
This commit is contained in:
Syrus Akbary 2018-02-17 14:06:39 -08:00 committed by GitHub
commit 8c7ca74c6f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 3 deletions

View File

@ -99,7 +99,10 @@ class IterableConnectionField(Field):
def type(self): def type(self):
type = super(IterableConnectionField, self).type type = super(IterableConnectionField, self).type
connection_type = type connection_type = type
if is_node(type): if isinstance(type, NonNull):
connection_type = type.of_type
if is_node(connection_type):
raise Exception( raise Exception(
"ConnectionField's now need a explicit ConnectionType for Nodes.\n" "ConnectionField's now need a explicit ConnectionType for Nodes.\n"
"Read more: https://github.com/graphql-python/graphene/blob/v2.0.0/UPGRADE-v2.0.md#node-connections" "Read more: https://github.com/graphql-python/graphene/blob/v2.0.0/UPGRADE-v2.0.md#node-connections"
@ -108,7 +111,7 @@ class IterableConnectionField(Field):
assert issubclass(connection_type, Connection), ( assert issubclass(connection_type, Connection), (
'{} type have to be a subclass of Connection. Received "{}".' '{} type have to be a subclass of Connection. Received "{}".'
).format(self.__class__.__name__, connection_type) ).format(self.__class__.__name__, connection_type)
return connection_type return type
@classmethod @classmethod
def resolve_connection(cls, connection_type, args, resolved): def resolve_connection(cls, connection_type, args, resolved):
@ -133,6 +136,9 @@ class IterableConnectionField(Field):
def connection_resolver(cls, resolver, connection_type, root, info, **args): def connection_resolver(cls, resolver, connection_type, root, info, **args):
resolved = resolver(root, info, **args) resolved = resolver(root, info, **args)
if isinstance(connection_type, NonNull):
connection_type = connection_type.of_type
on_resolve = partial(cls.resolve_connection, connection_type, args) on_resolve = partial(cls.resolve_connection, connection_type, args)
if is_thenable(resolved): if is_thenable(resolved):
return Promise.resolve(resolved).then(on_resolve) return Promise.resolve(resolved).then(on_resolve)

View File

@ -1,6 +1,6 @@
import pytest import pytest
from ...types import Argument, Field, Int, List, NonNull, ObjectType, String from ...types import Argument, Field, Int, List, NonNull, ObjectType, String, Schema
from ..connection import Connection, ConnectionField, PageInfo from ..connection import Connection, ConnectionField, PageInfo
from ..node import Node from ..node import Node
@ -155,3 +155,23 @@ def test_connectionfield_custom_args():
'last': Argument(Int), 'last': Argument(Int),
'extra': Argument(String), 'extra': Argument(String),
} }
def test_connectionfield_required():
class MyObjectConnection(Connection):
class Meta:
node = MyObject
class Query(ObjectType):
test_connection = ConnectionField(MyObjectConnection, required=True)
def resolve_test_connection(root, info, **args):
return []
schema = Schema(query=Query)
executed = schema.execute(
'{ testConnection { edges { cursor } } }'
)
assert not executed.errors
assert executed.data == {'testConnection': {'edges': []}}