Add the info parameter (ResolveInfo) to get_node() calls.

This allows request_context (or any other ResolveInfo data) to be used while getting nodes.
For example, some data might need to be hidden based on the user's authorization; you would use info.request_context for this.

Fixes #34.
This commit is contained in:
Jon Rosebaugh 2015-11-17 23:43:20 -05:00
parent a970d99b69
commit a79a76d3b9
5 changed files with 41 additions and 8 deletions

View File

@ -17,7 +17,7 @@ class Ship(DjangoNode):
model = ShipModel model = ShipModel
@classmethod @classmethod
def get_node(cls, id): def get_node(cls, id, info):
return Ship(get_ship(id)) return Ship(get_ship(id))
@ -33,7 +33,7 @@ class Faction(DjangoNode):
model = FactionModel model = FactionModel
@classmethod @classmethod
def get_node(cls, id): def get_node(cls, id, info):
return Faction(get_faction(id)) return Faction(get_faction(id))

View File

@ -11,7 +11,7 @@ class Ship(relay.Node):
name = graphene.String(description='The name of the ship.') name = graphene.String(description='The name of the ship.')
@classmethod @classmethod
def get_node(cls, id): def get_node(cls, id, info):
return get_ship(id) return get_ship(id)
@ -27,7 +27,7 @@ class Faction(relay.Node):
return [get_ship(ship_id) for ship_id in self.ships] return [get_ship(ship_id) for ship_id in self.ships]
@classmethod @classmethod
def get_node(cls, id): def get_node(cls, id, info):
return get_faction(id) return get_faction(id)

View File

@ -66,7 +66,7 @@ def test_should_node():
model = Reporter model = Reporter
@classmethod @classmethod
def get_node(cls, id): def get_node(cls, id, info):
return ReporterNode(Reporter(id=2, first_name='Cookie Monster')) return ReporterNode(Reporter(id=2, first_name='Cookie Monster'))
def resolve_articles(self, *args, **kwargs): def resolve_articles(self, *args, **kwargs):
@ -78,7 +78,7 @@ def test_should_node():
model = Article model = Article
@classmethod @classmethod
def get_node(cls, id): def get_node(cls, id, info):
return ArticleNode(Article(id=1, headline='Article node')) return ArticleNode(Article(id=1, headline='Article node'))
class Query(graphene.ObjectType): class Query(graphene.ObjectType):

View File

@ -92,7 +92,7 @@ class NodeField(Field):
object_type != self.field_object_type): object_type != self.field_object_type):
return return
return object_type.get_node(_id) return object_type.get_node(_id, info)
def resolver(self, instance, args, info): def resolver(self, instance, args, info):
global_id = args.get('id') global_id = args.get('id')

View File

@ -1,3 +1,4 @@
import pytest
from graphql.core.type import GraphQLID, GraphQLNonNull from graphql.core.type import GraphQLID, GraphQLNonNull
import graphene import graphene
@ -15,12 +16,22 @@ class MyNode(relay.Node):
name = graphene.String() name = graphene.String()
@classmethod @classmethod
def get_node(cls, id): def get_node(cls, id, info):
return MyNode(id=id, name='mo') return MyNode(id=id, name='mo')
class SpecialNode(relay.Node):
value = graphene.String()
@classmethod
def get_node(cls, id, info):
value = "!!!" if info.request_context.get('is_special') else "???"
return SpecialNode(id=id, value=value)
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
my_node = relay.NodeField(MyNode) my_node = relay.NodeField(MyNode)
special_node = relay.NodeField(SpecialNode)
all_my_nodes = relay.ConnectionField( all_my_nodes = relay.ConnectionField(
MyNode, connection_type=MyConnection, customArg=graphene.String()) MyNode, connection_type=MyConnection, customArg=graphene.String())
@ -79,6 +90,28 @@ def test_nodefield_query():
assert result.data == expected assert result.data == expected
@pytest.mark.parametrize('specialness,value', [(True, '!!!'), (False, '???')])
def test_get_node_info(specialness, value):
query = '''
query ValueQuery {
specialNode(id:"U3BlY2lhbE5vZGU6Mg==") {
id
value
}
}
'''
expected = {
'specialNode': {
'id': 'U3BlY2lhbE5vZGU6Mg==',
'value': value,
},
}
result = schema.execute(query, request_context={'is_special': specialness})
assert not result.errors
assert result.data == expected
def test_nodeidfield(): def test_nodeidfield():
id_field = MyNode._meta.fields_map['id'] id_field = MyNode._meta.fields_map['id']
id_field_type = schema.T(id_field) id_field_type = schema.T(id_field)