Merge pull request #177 from graphql-python/bugfixes/context-in-connectionfield

Added context to connectionfield resolver
This commit is contained in:
Syrus Akbary 2016-05-21 00:47:56 -07:00
commit 89074492e0
5 changed files with 61 additions and 11 deletions

View File

@ -27,10 +27,10 @@ class DjangoConnectionField(ConnectionField):
def get_queryset(self, resolved_qs, args, info): def get_queryset(self, resolved_qs, args, info):
return resolved_qs return resolved_qs
def from_list(self, connection_type, resolved, args, info): def from_list(self, connection_type, resolved, args, context, info):
resolved_qs = maybe_queryset(resolved) resolved_qs = maybe_queryset(resolved)
qs = self.get_queryset(resolved_qs, args, info) qs = self.get_queryset(resolved_qs, args, info)
return super(DjangoConnectionField, self).from_list(connection_type, qs, args, info) return super(DjangoConnectionField, self).from_list(connection_type, qs, args, context, info)
class ConnectionOrListField(Field): class ConnectionOrListField(Field):

View File

@ -21,11 +21,11 @@ class SQLAlchemyConnectionField(ConnectionField):
def model(self): def model(self):
return self.type._meta.model return self.type._meta.model
def from_list(self, connection_type, resolved, args, info): def from_list(self, connection_type, resolved, args, context, info):
if resolved is DefaultQuery: if resolved is DefaultQuery:
resolved = get_query(self.model, info) resolved = get_query(self.model, info)
query = maybe_query(resolved) query = maybe_query(resolved)
return super(SQLAlchemyConnectionField, self).from_list(connection_type, query, args, info) return super(SQLAlchemyConnectionField, self).from_list(connection_type, query, args, context, info)
class ConnectionOrListField(Field): class ConnectionOrListField(Field):

View File

@ -5,7 +5,7 @@ from graphql_relay.node.node import from_global_id
from ..core.fields import Field from ..core.fields import Field
from ..core.types.definitions import NonNull from ..core.types.definitions import NonNull
from ..core.types.scalars import ID, Int, String from ..core.types.scalars import ID, Int, String
from ..utils import with_context from ..utils.wrap_resolver_function import has_context, with_context
class ConnectionField(Field): class ConnectionField(Field):
@ -26,16 +26,23 @@ class ConnectionField(Field):
self.connection_type = connection_type self.connection_type = connection_type
self.edge_type = edge_type self.edge_type = edge_type
def resolver(self, instance, args, info): @with_context
def resolver(self, instance, args, context, info):
schema = info.schema.graphene_schema schema = info.schema.graphene_schema
connection_type = self.get_type(schema) connection_type = self.get_type(schema)
resolved = super(ConnectionField, self).resolver(instance, args, info)
resolver = super(ConnectionField, self).resolver
if has_context(resolver):
resolved = super(ConnectionField, self).resolver(instance, args, context, info)
else:
resolved = super(ConnectionField, self).resolver(instance, args, info)
if isinstance(resolved, connection_type): if isinstance(resolved, connection_type):
return resolved return resolved
return self.from_list(connection_type, resolved, args, info) return self.from_list(connection_type, resolved, args, context, info)
def from_list(self, connection_type, resolved, args, info): def from_list(self, connection_type, resolved, args, context, info):
return connection_type.from_list(resolved, args, info) return connection_type.from_list(resolved, args, context, info)
def get_connection_type(self, node): def get_connection_type(self, node):
connection_type = self.connection_type or node.get_connection_type() connection_type = self.connection_type or node.get_connection_type()

View File

@ -37,15 +37,58 @@ class Query(graphene.ObjectType):
all_my_nodes = relay.ConnectionField( all_my_nodes = relay.ConnectionField(
MyNode, connection_type=MyConnection, customArg=graphene.String()) MyNode, connection_type=MyConnection, customArg=graphene.String())
context_nodes = relay.ConnectionField(
MyNode, connection_type=MyConnection, customArg=graphene.String())
def resolve_all_my_nodes(self, args, info): def resolve_all_my_nodes(self, args, info):
custom_arg = args.get('customArg') custom_arg = args.get('customArg')
assert custom_arg == "1" assert custom_arg == "1"
return [MyNode(name='my')] return [MyNode(name='my')]
@with_context
def resolve_context_nodes(self, args, context, info):
custom_arg = args.get('customArg')
assert custom_arg == "1"
return [MyNode(name='my')]
schema.query = Query schema.query = Query
def test_nodefield_query(): def test_nodefield_query():
query = '''
query RebelsShipsQuery {
contextNodes (customArg:"1") {
edges {
node {
name
}
},
myCustomField
pageInfo {
hasNextPage
}
}
}
'''
expected = {
'contextNodes': {
'edges': [{
'node': {
'name': 'my'
}
}],
'myCustomField': 'Custom',
'pageInfo': {
'hasNextPage': False,
}
}
}
result = schema.execute(query)
assert not result.errors
assert result.data == expected
def test_connectionfield_context_query():
query = ''' query = '''
query RebelsShipsQuery { query RebelsShipsQuery {
myNode(id:"TXlOb2RlOjE=") { myNode(id:"TXlOb2RlOjE=") {

View File

@ -87,7 +87,7 @@ class Connection(ObjectType):
{'edge_type': edge_type, 'edges': edges}) {'edge_type': edge_type, 'edges': edges})
@classmethod @classmethod
def from_list(cls, iterable, args, info): def from_list(cls, iterable, args, context, info):
assert isinstance( assert isinstance(
iterable, Iterable), 'Resolved value from the connection field have to be iterable' iterable, Iterable), 'Resolved value from the connection field have to be iterable'
connection = connection_from_list( connection = connection_from_list(