diff --git a/graphene/contrib/django/fields.py b/graphene/contrib/django/fields.py index f70e35ff..3f4681a3 100644 --- a/graphene/contrib/django/fields.py +++ b/graphene/contrib/django/fields.py @@ -27,10 +27,10 @@ class DjangoConnectionField(ConnectionField): def get_queryset(self, resolved_qs, args, info): return resolved_qs - def from_list(self, connection_type, resolved, args, info): + def from_list(self, connection_type, resolved, args, context, info): resolved_qs = maybe_queryset(resolved) qs = self.get_queryset(resolved_qs, args, info) - return super(DjangoConnectionField, self).from_list(connection_type, qs, args, info) + return super(DjangoConnectionField, self).from_list(connection_type, qs, args, context, info) class ConnectionOrListField(Field): diff --git a/graphene/contrib/sqlalchemy/fields.py b/graphene/contrib/sqlalchemy/fields.py index dc3eb66b..598cd341 100644 --- a/graphene/contrib/sqlalchemy/fields.py +++ b/graphene/contrib/sqlalchemy/fields.py @@ -21,11 +21,11 @@ class SQLAlchemyConnectionField(ConnectionField): def model(self): return self.type._meta.model - def from_list(self, connection_type, resolved, args, info): + def from_list(self, connection_type, resolved, args, context, info): if resolved is DefaultQuery: resolved = get_query(self.model, info) query = maybe_query(resolved) - return super(SQLAlchemyConnectionField, self).from_list(connection_type, query, args, info) + return super(SQLAlchemyConnectionField, self).from_list(connection_type, query, args, context, info) class ConnectionOrListField(Field): diff --git a/graphene/relay/fields.py b/graphene/relay/fields.py index da9d4af7..e79b9592 100644 --- a/graphene/relay/fields.py +++ b/graphene/relay/fields.py @@ -5,7 +5,7 @@ from graphql_relay.node.node import from_global_id from ..core.fields import Field from ..core.types.definitions import NonNull from ..core.types.scalars import ID, Int, String -from ..utils import with_context +from ..utils.wrap_resolver_function import has_context, with_context class ConnectionField(Field): @@ -26,16 +26,23 @@ class ConnectionField(Field): self.connection_type = connection_type self.edge_type = edge_type - def resolver(self, instance, args, info): + @with_context + def resolver(self, instance, args, context, info): schema = info.schema.graphene_schema connection_type = self.get_type(schema) - resolved = super(ConnectionField, self).resolver(instance, args, info) + + resolver = super(ConnectionField, self).resolver + if has_context(resolver): + resolved = super(ConnectionField, self).resolver(instance, args, context, info) + else: + resolved = super(ConnectionField, self).resolver(instance, args, info) + if isinstance(resolved, connection_type): return resolved - return self.from_list(connection_type, resolved, args, info) + return self.from_list(connection_type, resolved, args, context, info) - def from_list(self, connection_type, resolved, args, info): - return connection_type.from_list(resolved, args, info) + def from_list(self, connection_type, resolved, args, context, info): + return connection_type.from_list(resolved, args, context, info) def get_connection_type(self, node): connection_type = self.connection_type or node.get_connection_type() diff --git a/graphene/relay/tests/test_query.py b/graphene/relay/tests/test_query.py index 4e94d9fd..08e57226 100644 --- a/graphene/relay/tests/test_query.py +++ b/graphene/relay/tests/test_query.py @@ -37,15 +37,58 @@ class Query(graphene.ObjectType): all_my_nodes = relay.ConnectionField( MyNode, connection_type=MyConnection, customArg=graphene.String()) + context_nodes = relay.ConnectionField( + MyNode, connection_type=MyConnection, customArg=graphene.String()) + def resolve_all_my_nodes(self, args, info): custom_arg = args.get('customArg') assert custom_arg == "1" return [MyNode(name='my')] + @with_context + def resolve_context_nodes(self, args, context, info): + custom_arg = args.get('customArg') + assert custom_arg == "1" + return [MyNode(name='my')] + schema.query = Query def test_nodefield_query(): + query = ''' + query RebelsShipsQuery { + contextNodes (customArg:"1") { + edges { + node { + name + } + }, + myCustomField + pageInfo { + hasNextPage + } + } + } + ''' + expected = { + 'contextNodes': { + 'edges': [{ + 'node': { + 'name': 'my' + } + }], + 'myCustomField': 'Custom', + 'pageInfo': { + 'hasNextPage': False, + } + } + } + result = schema.execute(query) + assert not result.errors + assert result.data == expected + + +def test_connectionfield_context_query(): query = ''' query RebelsShipsQuery { myNode(id:"TXlOb2RlOjE=") { diff --git a/graphene/relay/types.py b/graphene/relay/types.py index 29ba779a..2e0391ed 100644 --- a/graphene/relay/types.py +++ b/graphene/relay/types.py @@ -87,7 +87,7 @@ class Connection(ObjectType): {'edge_type': edge_type, 'edges': edges}) @classmethod - def from_list(cls, iterable, args, info): + def from_list(cls, iterable, args, context, info): assert isinstance( iterable, Iterable), 'Resolved value from the connection field have to be iterable' connection = connection_from_list(