diff --git a/graphene/relay/fields.py b/graphene/relay/fields.py index da9d4af7..6796d1a7 100644 --- a/graphene/relay/fields.py +++ b/graphene/relay/fields.py @@ -5,7 +5,7 @@ from graphql_relay.node.node import from_global_id from ..core.fields import Field from ..core.types.definitions import NonNull from ..core.types.scalars import ID, Int, String -from ..utils import with_context +from ..utils.wrap_resolver_function import has_context, with_context class ConnectionField(Field): @@ -25,17 +25,24 @@ class ConnectionField(Field): **kwargs) self.connection_type = connection_type self.edge_type = edge_type - - def resolver(self, instance, args, info): + + @with_context + def resolver(self, instance, args, context, info): schema = info.schema.graphene_schema connection_type = self.get_type(schema) - resolved = super(ConnectionField, self).resolver(instance, args, info) + + resolver = super(ConnectionField, self).resolver + if has_context(resolver): + resolved = super(ConnectionField, self).resolver(instance, args, context, info) + else: + resolved = super(ConnectionField, self).resolver(instance, args, info) + if isinstance(resolved, connection_type): return resolved - return self.from_list(connection_type, resolved, args, info) + return self.from_list(connection_type, resolved, args, context, info) def from_list(self, connection_type, resolved, args, info): - return connection_type.from_list(resolved, args, info) + return connection_type.from_list(resolved, args, context, info) def get_connection_type(self, node): connection_type = self.connection_type or node.get_connection_type()