From 513b3e46c392cad15700072098802ccd57adaf3d Mon Sep 17 00:00:00 2001 From: Markus Padourek Date: Mon, 12 Sep 2016 10:38:09 +0100 Subject: [PATCH] Making node argument optional for GlobalID. --- graphene-django/graphene_django/fields.py | 2 +- .../graphene_django/filter/fields.py | 2 +- .../graphene_sqlalchemy/fields.py | 2 +- graphene/relay/connection.py | 4 +-- graphene/relay/node.py | 10 +++---- graphene/relay/tests/test_node.py | 29 +++++++++++++++++-- graphene/types/field.py | 2 +- graphene/types/typemap.py | 2 +- 8 files changed, 39 insertions(+), 14 deletions(-) diff --git a/graphene-django/graphene_django/fields.py b/graphene-django/graphene_django/fields.py index b23a4b65..18637115 100644 --- a/graphene-django/graphene_django/fields.py +++ b/graphene-django/graphene_django/fields.py @@ -46,7 +46,7 @@ class DjangoConnectionField(ConnectionField): connection.length = _len return connection - def get_resolver(self, parent_resolver): + def get_resolver(self, parent_resolver, _): return partial(self.connection_resolver, parent_resolver, self.type, self.get_manager()) diff --git a/graphene-django/graphene_django/filter/fields.py b/graphene-django/graphene_django/filter/fields.py index f4f84e29..cbd4e81e 100644 --- a/graphene-django/graphene_django/filter/fields.py +++ b/graphene-django/graphene_django/filter/fields.py @@ -34,6 +34,6 @@ class DjangoFilterConnectionField(DjangoConnectionField): return DjangoConnectionField.connection_resolver(resolver, connection, qs, root, args, context, info) - def get_resolver(self, parent_resolver): + def get_resolver(self, parent_resolver, _): return partial(self.connection_resolver, parent_resolver, self.type, self.get_manager(), self.filterset_class, self.filtering_args) diff --git a/graphene-sqlalchemy/graphene_sqlalchemy/fields.py b/graphene-sqlalchemy/graphene_sqlalchemy/fields.py index d97d2295..43215e26 100644 --- a/graphene-sqlalchemy/graphene_sqlalchemy/fields.py +++ b/graphene-sqlalchemy/graphene_sqlalchemy/fields.py @@ -33,5 +33,5 @@ class SQLAlchemyConnectionField(ConnectionField): edge_type=connection.Edge, ) - def get_resolver(self, parent_resolver): + def get_resolver(self, parent_resolver, _): return partial(self.connection_resolver, parent_resolver, self.type, self.model) diff --git a/graphene/relay/connection.py b/graphene/relay/connection.py index 17d9854e..77e7a2ff 100644 --- a/graphene/relay/connection.py +++ b/graphene/relay/connection.py @@ -133,8 +133,8 @@ class IterableConnectionField(Field): connection.iterable = iterable return connection - def get_resolver(self, parent_resolver): - resolver = super(IterableConnectionField, self).get_resolver(parent_resolver) + def get_resolver(self, parent_resolver, _): + resolver = super(IterableConnectionField, self).get_resolver(parent_resolver, None) return partial(self.connection_resolver, resolver, self.type) ConnectionField = IterableConnectionField diff --git a/graphene/relay/node.py b/graphene/relay/node.py index f73f6a53..7e8934c4 100644 --- a/graphene/relay/node.py +++ b/graphene/relay/node.py @@ -35,8 +35,8 @@ def get_default_connection(cls): class GlobalID(Field): - def __init__(self, node, *args, **kwargs): - super(GlobalID, self).__init__(ID, *args, **kwargs) + def __init__(self, node=None, required=True, *args, **kwargs): + super(GlobalID, self).__init__(ID, required=required, *args, **kwargs) self.node = node @staticmethod @@ -44,15 +44,15 @@ class GlobalID(Field): id = parent_resolver(root, args, context, info) return node.to_global_id(info.parent_type.name, id) # root._meta.name - def get_resolver(self, parent_resolver): - return partial(self.id_resolver, parent_resolver, self.node) + def get_resolver(self, parent_resolver, parent_type): + return partial(self.id_resolver, parent_resolver, self.node or parent_type) class NodeMeta(InterfaceMeta): def __new__(cls, name, bases, attrs): cls = InterfaceMeta.__new__(cls, name, bases, attrs) - cls._meta.fields['id'] = GlobalID(cls, required=True, description='The ID of the object.') + cls._meta.fields['id'] = GlobalID(cls, description='The ID of the object.') return cls diff --git a/graphene/relay/tests/test_node.py b/graphene/relay/tests/test_node.py index ef35f409..dc2d3e43 100644 --- a/graphene/relay/tests/test_node.py +++ b/graphene/relay/tests/test_node.py @@ -5,7 +5,7 @@ from graphql_relay import to_global_id from ...types import ObjectType, Schema, String, AbstractType from ..connection import Connection -from ..node import Node +from ..node import Node, GlobalID class SharedNodeFields(AbstractType): @@ -28,6 +28,18 @@ class MyNode(ObjectType): return MyNode(name=str(id)) +class MyNodeImplementedId(ObjectType): + + class Meta: + interfaces = (Node, ) + id = GlobalID() + name = String() + + @staticmethod + def get_node(id, *_): + return MyNodeImplementedId(name=str(id) + '!') + + class MyOtherNode(SharedNodeFields, ObjectType): extra_field = String() @@ -46,7 +58,7 @@ class RootQuery(ObjectType): first = String() node = Node.Field() -schema = Schema(query=RootQuery, types=[MyNode, MyOtherNode]) +schema = Schema(query=RootQuery, types=[MyNode, MyOtherNode, MyNodeImplementedId]) def test_node_good(): @@ -78,6 +90,14 @@ def test_subclassed_node_query(): assert executed.data == OrderedDict({'node': OrderedDict([('shared', '1'), ('extraField', 'extra field info.'), ('somethingElse', '----')])}) +def test_node_query_implemented_id(): + executed = schema.execute( + '{ node(id:"%s") { ... on MyNodeImplementedId { name } } }' % to_global_id("MyNodeImplementedId", 1) + ) + assert not executed.errors + assert executed.data == {'node': {'name': '1!'}} + + def test_node_query_incorrect_id(): executed = schema.execute( '{ node(id:"%s") { ... on MyNode { name } } }' % "something:2" @@ -97,6 +117,11 @@ type MyNode implements Node { name: String } +type MyNodeImplementedId implements Node { + id: ID! + name: String +} + type MyOtherNode implements Node { id: ID! shared: String diff --git a/graphene/types/field.py b/graphene/types/field.py index b1122696..d6526531 100644 --- a/graphene/types/field.py +++ b/graphene/types/field.py @@ -45,5 +45,5 @@ class Field(OrderedType): return self._type() return self._type - def get_resolver(self, parent_resolver): + def get_resolver(self, parent_resolver, _): return self.resolver or parent_resolver diff --git a/graphene/types/typemap.py b/graphene/types/typemap.py index 360939f2..0bb2c3c5 100644 --- a/graphene/types/typemap.py +++ b/graphene/types/typemap.py @@ -219,7 +219,7 @@ class TypeMap(GraphQLTypeMap): _field = GraphQLField( field_type, args=args, - resolver=field.get_resolver(self.get_resolver_for_type(type, name)), + resolver=field.get_resolver(self.get_resolver_for_type(type, name), type), deprecation_reason=field.deprecation_reason, description=field.description )