Making node argument optional for GlobalID.

This commit is contained in:
Markus Padourek 2016-09-12 10:38:09 +01:00
parent 94d46f7960
commit 513b3e46c3
8 changed files with 39 additions and 14 deletions

View File

@ -46,7 +46,7 @@ class DjangoConnectionField(ConnectionField):
connection.length = _len connection.length = _len
return connection return connection
def get_resolver(self, parent_resolver): def get_resolver(self, parent_resolver, _):
return partial(self.connection_resolver, parent_resolver, self.type, self.get_manager()) return partial(self.connection_resolver, parent_resolver, self.type, self.get_manager())

View File

@ -34,6 +34,6 @@ class DjangoFilterConnectionField(DjangoConnectionField):
return DjangoConnectionField.connection_resolver(resolver, connection, qs, root, args, context, info) return DjangoConnectionField.connection_resolver(resolver, connection, qs, root, args, context, info)
def get_resolver(self, parent_resolver): def get_resolver(self, parent_resolver, _):
return partial(self.connection_resolver, parent_resolver, self.type, self.get_manager(), return partial(self.connection_resolver, parent_resolver, self.type, self.get_manager(),
self.filterset_class, self.filtering_args) self.filterset_class, self.filtering_args)

View File

@ -33,5 +33,5 @@ class SQLAlchemyConnectionField(ConnectionField):
edge_type=connection.Edge, edge_type=connection.Edge,
) )
def get_resolver(self, parent_resolver): def get_resolver(self, parent_resolver, _):
return partial(self.connection_resolver, parent_resolver, self.type, self.model) return partial(self.connection_resolver, parent_resolver, self.type, self.model)

View File

@ -133,8 +133,8 @@ class IterableConnectionField(Field):
connection.iterable = iterable connection.iterable = iterable
return connection return connection
def get_resolver(self, parent_resolver): def get_resolver(self, parent_resolver, _):
resolver = super(IterableConnectionField, self).get_resolver(parent_resolver) resolver = super(IterableConnectionField, self).get_resolver(parent_resolver, None)
return partial(self.connection_resolver, resolver, self.type) return partial(self.connection_resolver, resolver, self.type)
ConnectionField = IterableConnectionField ConnectionField = IterableConnectionField

View File

@ -35,8 +35,8 @@ def get_default_connection(cls):
class GlobalID(Field): class GlobalID(Field):
def __init__(self, node, *args, **kwargs): def __init__(self, node=None, required=True, *args, **kwargs):
super(GlobalID, self).__init__(ID, *args, **kwargs) super(GlobalID, self).__init__(ID, required=required, *args, **kwargs)
self.node = node self.node = node
@staticmethod @staticmethod
@ -44,15 +44,15 @@ class GlobalID(Field):
id = parent_resolver(root, args, context, info) id = parent_resolver(root, args, context, info)
return node.to_global_id(info.parent_type.name, id) # root._meta.name return node.to_global_id(info.parent_type.name, id) # root._meta.name
def get_resolver(self, parent_resolver): def get_resolver(self, parent_resolver, parent_type):
return partial(self.id_resolver, parent_resolver, self.node) return partial(self.id_resolver, parent_resolver, self.node or parent_type)
class NodeMeta(InterfaceMeta): class NodeMeta(InterfaceMeta):
def __new__(cls, name, bases, attrs): def __new__(cls, name, bases, attrs):
cls = InterfaceMeta.__new__(cls, name, bases, attrs) cls = InterfaceMeta.__new__(cls, name, bases, attrs)
cls._meta.fields['id'] = GlobalID(cls, required=True, description='The ID of the object.') cls._meta.fields['id'] = GlobalID(cls, description='The ID of the object.')
return cls return cls

View File

@ -5,7 +5,7 @@ from graphql_relay import to_global_id
from ...types import ObjectType, Schema, String, AbstractType from ...types import ObjectType, Schema, String, AbstractType
from ..connection import Connection from ..connection import Connection
from ..node import Node from ..node import Node, GlobalID
class SharedNodeFields(AbstractType): class SharedNodeFields(AbstractType):
@ -28,6 +28,18 @@ class MyNode(ObjectType):
return MyNode(name=str(id)) return MyNode(name=str(id))
class MyNodeImplementedId(ObjectType):
class Meta:
interfaces = (Node, )
id = GlobalID()
name = String()
@staticmethod
def get_node(id, *_):
return MyNodeImplementedId(name=str(id) + '!')
class MyOtherNode(SharedNodeFields, ObjectType): class MyOtherNode(SharedNodeFields, ObjectType):
extra_field = String() extra_field = String()
@ -46,7 +58,7 @@ class RootQuery(ObjectType):
first = String() first = String()
node = Node.Field() node = Node.Field()
schema = Schema(query=RootQuery, types=[MyNode, MyOtherNode]) schema = Schema(query=RootQuery, types=[MyNode, MyOtherNode, MyNodeImplementedId])
def test_node_good(): def test_node_good():
@ -78,6 +90,14 @@ def test_subclassed_node_query():
assert executed.data == OrderedDict({'node': OrderedDict([('shared', '1'), ('extraField', 'extra field info.'), ('somethingElse', '----')])}) assert executed.data == OrderedDict({'node': OrderedDict([('shared', '1'), ('extraField', 'extra field info.'), ('somethingElse', '----')])})
def test_node_query_implemented_id():
executed = schema.execute(
'{ node(id:"%s") { ... on MyNodeImplementedId { name } } }' % to_global_id("MyNodeImplementedId", 1)
)
assert not executed.errors
assert executed.data == {'node': {'name': '1!'}}
def test_node_query_incorrect_id(): def test_node_query_incorrect_id():
executed = schema.execute( executed = schema.execute(
'{ node(id:"%s") { ... on MyNode { name } } }' % "something:2" '{ node(id:"%s") { ... on MyNode { name } } }' % "something:2"
@ -97,6 +117,11 @@ type MyNode implements Node {
name: String name: String
} }
type MyNodeImplementedId implements Node {
id: ID!
name: String
}
type MyOtherNode implements Node { type MyOtherNode implements Node {
id: ID! id: ID!
shared: String shared: String

View File

@ -45,5 +45,5 @@ class Field(OrderedType):
return self._type() return self._type()
return self._type return self._type
def get_resolver(self, parent_resolver): def get_resolver(self, parent_resolver, _):
return self.resolver or parent_resolver return self.resolver or parent_resolver

View File

@ -219,7 +219,7 @@ class TypeMap(GraphQLTypeMap):
_field = GraphQLField( _field = GraphQLField(
field_type, field_type,
args=args, args=args,
resolver=field.get_resolver(self.get_resolver_for_type(type, name)), resolver=field.get_resolver(self.get_resolver_for_type(type, name), type),
deprecation_reason=field.deprecation_reason, deprecation_reason=field.deprecation_reason,
description=field.description description=field.description
) )