From 11a5ee189ef6f99ff361755dedad6b567f426f35 Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Sat, 21 May 2016 00:40:26 -0700 Subject: [PATCH] Fixed issues --- graphene/relay/fields.py | 6 +++--- graphene/relay/tests/test_query.py | 7 ++++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/graphene/relay/fields.py b/graphene/relay/fields.py index 8ca12fef..e79b9592 100644 --- a/graphene/relay/fields.py +++ b/graphene/relay/fields.py @@ -25,18 +25,18 @@ class ConnectionField(Field): **kwargs) self.connection_type = connection_type self.edge_type = edge_type - + @with_context def resolver(self, instance, args, context, info): schema = info.schema.graphene_schema connection_type = self.get_type(schema) - + resolver = super(ConnectionField, self).resolver if has_context(resolver): resolved = super(ConnectionField, self).resolver(instance, args, context, info) else: resolved = super(ConnectionField, self).resolver(instance, args, info) - + if isinstance(resolved, connection_type): return resolved return self.from_list(connection_type, resolved, args, context, info) diff --git a/graphene/relay/tests/test_query.py b/graphene/relay/tests/test_query.py index 70b31936..08e57226 100644 --- a/graphene/relay/tests/test_query.py +++ b/graphene/relay/tests/test_query.py @@ -36,7 +36,7 @@ class Query(graphene.ObjectType): special_node = relay.NodeField(SpecialNode) all_my_nodes = relay.ConnectionField( MyNode, connection_type=MyConnection, customArg=graphene.String()) - + context_nodes = relay.ConnectionField( MyNode, connection_type=MyConnection, customArg=graphene.String()) @@ -44,7 +44,7 @@ class Query(graphene.ObjectType): custom_arg = args.get('customArg') assert custom_arg == "1" return [MyNode(name='my')] - + @with_context def resolve_context_nodes(self, args, context, info): custom_arg = args.get('customArg') @@ -71,7 +71,7 @@ def test_nodefield_query(): } ''' expected = { - 'allMyNodes': { + 'contextNodes': { 'edges': [{ 'node': { 'name': 'my' @@ -87,6 +87,7 @@ def test_nodefield_query(): assert not result.errors assert result.data == expected + def test_connectionfield_context_query(): query = ''' query RebelsShipsQuery {