Added context to connectionfield resolver

This commit is contained in:
Markus Padourek 2016-05-20 13:59:39 +01:00
parent b431bfe477
commit c4f29f050b

View File

@ -5,7 +5,7 @@ from graphql_relay.node.node import from_global_id
from ..core.fields import Field from ..core.fields import Field
from ..core.types.definitions import NonNull from ..core.types.definitions import NonNull
from ..core.types.scalars import ID, Int, String from ..core.types.scalars import ID, Int, String
from ..utils import with_context from ..utils.wrap_resolver_function import has_context, with_context
class ConnectionField(Field): class ConnectionField(Field):
@ -25,17 +25,24 @@ class ConnectionField(Field):
**kwargs) **kwargs)
self.connection_type = connection_type self.connection_type = connection_type
self.edge_type = edge_type self.edge_type = edge_type
def resolver(self, instance, args, info): @with_context
def resolver(self, instance, args, context, info):
schema = info.schema.graphene_schema schema = info.schema.graphene_schema
connection_type = self.get_type(schema) connection_type = self.get_type(schema)
resolved = super(ConnectionField, self).resolver(instance, args, info)
resolver = super(ConnectionField, self).resolver
if has_context(resolver):
resolved = super(ConnectionField, self).resolver(instance, args, context, info)
else:
resolved = super(ConnectionField, self).resolver(instance, args, info)
if isinstance(resolved, connection_type): if isinstance(resolved, connection_type):
return resolved return resolved
return self.from_list(connection_type, resolved, args, info) return self.from_list(connection_type, resolved, args, context, info)
def from_list(self, connection_type, resolved, args, info): def from_list(self, connection_type, resolved, args, info):
return connection_type.from_list(resolved, args, info) return connection_type.from_list(resolved, args, context, info)
def get_connection_type(self, node): def get_connection_type(self, node):
connection_type = self.connection_type or node.get_connection_type() connection_type = self.connection_type or node.get_connection_type()