Merge pull request #177 from graphql-python/bugfixes/context-in-connectionfield

Added context to connectionfield resolver
This commit is contained in:
Syrus Akbary 2016-05-21 00:47:56 -07:00
commit 89074492e0
5 changed files with 61 additions and 11 deletions

View File

@ -27,10 +27,10 @@ class DjangoConnectionField(ConnectionField):
def get_queryset(self, resolved_qs, args, info):
return resolved_qs
def from_list(self, connection_type, resolved, args, info):
def from_list(self, connection_type, resolved, args, context, info):
resolved_qs = maybe_queryset(resolved)
qs = self.get_queryset(resolved_qs, args, info)
return super(DjangoConnectionField, self).from_list(connection_type, qs, args, info)
return super(DjangoConnectionField, self).from_list(connection_type, qs, args, context, info)
class ConnectionOrListField(Field):

View File

@ -21,11 +21,11 @@ class SQLAlchemyConnectionField(ConnectionField):
def model(self):
return self.type._meta.model
def from_list(self, connection_type, resolved, args, info):
def from_list(self, connection_type, resolved, args, context, info):
if resolved is DefaultQuery:
resolved = get_query(self.model, info)
query = maybe_query(resolved)
return super(SQLAlchemyConnectionField, self).from_list(connection_type, query, args, info)
return super(SQLAlchemyConnectionField, self).from_list(connection_type, query, args, context, info)
class ConnectionOrListField(Field):

View File

@ -5,7 +5,7 @@ from graphql_relay.node.node import from_global_id
from ..core.fields import Field
from ..core.types.definitions import NonNull
from ..core.types.scalars import ID, Int, String
from ..utils import with_context
from ..utils.wrap_resolver_function import has_context, with_context
class ConnectionField(Field):
@ -26,16 +26,23 @@ class ConnectionField(Field):
self.connection_type = connection_type
self.edge_type = edge_type
def resolver(self, instance, args, info):
@with_context
def resolver(self, instance, args, context, info):
schema = info.schema.graphene_schema
connection_type = self.get_type(schema)
resolved = super(ConnectionField, self).resolver(instance, args, info)
resolver = super(ConnectionField, self).resolver
if has_context(resolver):
resolved = super(ConnectionField, self).resolver(instance, args, context, info)
else:
resolved = super(ConnectionField, self).resolver(instance, args, info)
if isinstance(resolved, connection_type):
return resolved
return self.from_list(connection_type, resolved, args, info)
return self.from_list(connection_type, resolved, args, context, info)
def from_list(self, connection_type, resolved, args, info):
return connection_type.from_list(resolved, args, info)
def from_list(self, connection_type, resolved, args, context, info):
return connection_type.from_list(resolved, args, context, info)
def get_connection_type(self, node):
connection_type = self.connection_type or node.get_connection_type()

View File

@ -37,15 +37,58 @@ class Query(graphene.ObjectType):
all_my_nodes = relay.ConnectionField(
MyNode, connection_type=MyConnection, customArg=graphene.String())
context_nodes = relay.ConnectionField(
MyNode, connection_type=MyConnection, customArg=graphene.String())
def resolve_all_my_nodes(self, args, info):
custom_arg = args.get('customArg')
assert custom_arg == "1"
return [MyNode(name='my')]
@with_context
def resolve_context_nodes(self, args, context, info):
custom_arg = args.get('customArg')
assert custom_arg == "1"
return [MyNode(name='my')]
schema.query = Query
def test_nodefield_query():
query = '''
query RebelsShipsQuery {
contextNodes (customArg:"1") {
edges {
node {
name
}
},
myCustomField
pageInfo {
hasNextPage
}
}
}
'''
expected = {
'contextNodes': {
'edges': [{
'node': {
'name': 'my'
}
}],
'myCustomField': 'Custom',
'pageInfo': {
'hasNextPage': False,
}
}
}
result = schema.execute(query)
assert not result.errors
assert result.data == expected
def test_connectionfield_context_query():
query = '''
query RebelsShipsQuery {
myNode(id:"TXlOb2RlOjE=") {

View File

@ -87,7 +87,7 @@ class Connection(ObjectType):
{'edge_type': edge_type, 'edges': edges})
@classmethod
def from_list(cls, iterable, args, info):
def from_list(cls, iterable, args, context, info):
assert isinstance(
iterable, Iterable), 'Resolved value from the connection field have to be iterable'
connection = connection_from_list(