From 51e97510c00b7cd83450d325e53eb6b34d17b89c Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Fri, 10 Jun 2016 00:23:31 -0700 Subject: [PATCH] Improved relay nodes and field copies --- examples/starwars_relay/schema.py | 4 +- graphene/relay/mutation.py | 7 +-- graphene/relay/node.py | 25 +++++---- graphene/relay/tests/test_node.py | 11 ++-- graphene/relay/tests/test_node_custom.py | 6 +-- graphene/types/field.py | 64 ++++++++++++++++++------ graphene/types/interface.py | 5 +- graphene/types/objecttype.py | 1 + graphene/types/tests/test_objecttype.py | 1 - 9 files changed, 80 insertions(+), 44 deletions(-) diff --git a/examples/starwars_relay/schema.py b/examples/starwars_relay/schema.py index 0e17b5f2..be410dac 100644 --- a/examples/starwars_relay/schema.py +++ b/examples/starwars_relay/schema.py @@ -4,7 +4,7 @@ from graphene import implements, relay, resolve_only_args from .data import create_ship, get_empire, get_faction, get_rebels, get_ship -class Ship(graphene.ObjectType, relay.Node): +class Ship(relay.Node, graphene.ObjectType): '''A ship in the Star Wars saga''' name = graphene.String(description='The name of the ship.') @@ -13,7 +13,7 @@ class Ship(graphene.ObjectType, relay.Node): return get_ship(id) -class Faction(graphene.ObjectType, relay.Node): +class Faction(relay.Node, graphene.ObjectType): '''A faction in the Star Wars saga''' name = graphene.String(description='The name of the faction.') # ships = relay.ConnectionField( diff --git a/graphene/relay/mutation.py b/graphene/relay/mutation.py index 74a3cd52..b5ab4c3c 100644 --- a/graphene/relay/mutation.py +++ b/graphene/relay/mutation.py @@ -39,14 +39,9 @@ class ClientIDMutationMeta(MutationMeta): input_fields=input_local_fields, output_fields=cls._fields(bases, attrs, local_fields), mutate_and_get_payload=cls.mutate_and_get_payload, - - input_type_class=partial(GrapheneInputObjectType, graphene_type=cls), - input_field_class=InputField, - output_type_class=partial(GrapheneObjectType, graphene_type=cls), - field_class=Field, ) cls._meta.graphql_type = field.type - cls.Field = partial(Field.copy_and_extend, field, type=None, _creation_counter=None) + cls.Field = partial(Field.copy_and_extend, field, type=field.type, _creation_counter=None) constructed = super(ClientIDMutationMeta, cls).construct(bases, attrs) return constructed diff --git a/graphene/relay/node.py b/graphene/relay/node.py index 2f2d78d3..db959d7a 100644 --- a/graphene/relay/node.py +++ b/graphene/relay/node.py @@ -11,25 +11,30 @@ from ..types.interface import GrapheneInterfaceType, Interface, InterfaceTypeMet class NodeMeta(ObjectTypeMeta): def construct(cls, bases, attrs): - if not cls.is_object_type(): - cls.get_node = attrs.pop('get_node') + is_object_type = cls.is_object_type() + cls = super(NodeMeta, cls).construct(bases, attrs) + if not is_object_type: + get_node_from_global_id = getattr(cls, 'get_node_from_global_id', None) + assert get_node_from_global_id, '{}.get_node method is required by the Node interface.'.format(cls.__name__) cls.id_resolver = attrs.pop('id_resolver', None) node_interface, node_field = node_definitions( - cls.get_node, + get_node_from_global_id, id_resolver=cls.id_resolver, - interface_class=partial(GrapheneInterfaceType, graphene_type=cls), - field_class=Field, ) cls._meta.graphql_type = node_interface - cls.Field = partial(Field.copy_and_extend, node_field, type=None, _creation_counter=None) - return super(NodeMeta, cls).construct(bases, attrs) - + cls.Field = partial(Field.copy_and_extend, node_field, type=node_field.type, _creation_counter=None) + else: + # The interface provided by node_definitions is not an instance + # of GrapheneInterfaceType, so it will have no graphql_type, + # so will not trigger Node.implements + cls.implements(cls) + return cls class Node(six.with_metaclass(NodeMeta, Interface)): @classmethod def require_get_node(cls): - return cls == Node + return Node._meta.graphql_type in cls._meta.graphql_type._provided_interfaces @classmethod def from_global_id(cls, global_id): @@ -44,7 +49,7 @@ class Node(six.with_metaclass(NodeMeta, Interface)): return cls.to_global_id(info.parent_type.name, getattr(root, 'id', None)) @classmethod - def get_node(cls, global_id, context, info): + def get_node_from_global_id(cls, global_id, context, info): try: _type, _id = cls.from_global_id(global_id) except: diff --git a/graphene/relay/tests/test_node.py b/graphene/relay/tests/test_node.py index 1420fe8c..58c0eb55 100644 --- a/graphene/relay/tests/test_node.py +++ b/graphene/relay/tests/test_node.py @@ -7,7 +7,7 @@ from ...types import ObjectType, Schema, implements from ...types.scalars import String -class MyNode(ObjectType, Node): +class MyNode(Node, ObjectType): name = String() @@ -25,9 +25,7 @@ schema = Schema(query=RootQuery, types=[MyNode]) def test_node_no_get_node(): with pytest.raises(AssertionError) as excinfo: - class MyNode(ObjectType): - class Meta: - interfaces = [Node] + class MyNode(Node, ObjectType): pass assert "MyNode.get_node method is required by the Node interface." == str(excinfo.value) @@ -35,9 +33,8 @@ def test_node_no_get_node(): def test_node_no_get_node_with_meta(): with pytest.raises(AssertionError) as excinfo: - class MyNode(ObjectType): - class Meta: - interfaces = [Node] + class MyNode(Node, ObjectType): + pass assert "MyNode.get_node method is required by the Node interface." == str(excinfo.value) diff --git a/graphene/relay/tests/test_node_custom.py b/graphene/relay/tests/test_node_custom.py index 133c6db0..591c1200 100644 --- a/graphene/relay/tests/test_node_custom.py +++ b/graphene/relay/tests/test_node_custom.py @@ -7,7 +7,7 @@ from ...types.scalars import String, Int class CustomNode(Node): @staticmethod - def get_node(id, context, info): + def get_node_from_global_id(id, context, info): assert info.schema == schema if id in user_data: return user_data.get(id) @@ -15,11 +15,11 @@ class CustomNode(Node): return photo_data.get(id) -class User(ObjectType, CustomNode): +class User(CustomNode, ObjectType): name = String() -class Photo(ObjectType, CustomNode): +class Photo(CustomNode, ObjectType): width = Int() diff --git a/graphene/types/field.py b/graphene/types/field.py index 9e7db8b1..b7abaead 100644 --- a/graphene/types/field.py +++ b/graphene/types/field.py @@ -116,20 +116,40 @@ class Field(AbstractField, GraphQLField, OrderedType): @classmethod def copy_and_extend(cls, field, type=None, args=None, resolver=None, source=None, deprecation_reason=None, name=None, description=None, required=False, _creation_counter=False, **extra_args): + if isinstance(field, Field): + type = type or field._type + resolver = resolver or field._resolver + source = source or field.source + name = name or field._name + required = required or field.required + _creation_counter = field.creation_counter if _creation_counter is False else None + attname = field.attname + parent = field.parent + else: + # If is a GraphQLField + type = type or field.type + resolver = resolver or field.resolver + source = None + name = field.name + required = None + _creation_counter = None + attname = name + parent = None + new_field = cls( - type=type or field._type, + type=type, args=to_arguments(args, field.args), - resolver=field._resolver, - source=source or getattr(field, 'source', None), + resolver=resolver, + source=source, deprecation_reason=field.deprecation_reason, - name=field._name, - required=required or getattr(field, 'required', False), + name=name, + required=required, description=field.description, - _creation_counter=getattr(field, 'creation_counter', None) if _creation_counter is False else None, + _creation_counter=_creation_counter, **extra_args ) - new_field.attname = field.attname - new_field.parent = field.parent + new_field.attname = attname + new_field.parent = parent return new_field def __str__(self): @@ -168,14 +188,30 @@ class InputField(AbstractField, GraphQLInputObjectField, OrderedType): @classmethod def copy_and_extend(cls, field, type=None, default_value=None, description=None, name=None, required=False, _creation_counter=False): + if isinstance(field, Field): + type = type or field._type + name = name or field._name + required = required or field.required + _creation_counter = field.creation_counter if _creation_counter is False else None + attname = field.attname + parent = field.parent + else: + # If is a GraphQLField + type = type or field.type + name = field.name + required = None + _creation_counter = None + attname = None + parent = None + new_field = cls( - type=type or field._type, - name=name or field._name, - required=required or field.required, + type=type, + name=name, + required=required, default_value=default_value or field.default_value, description=description or field.description, - _creation_counter=getattr(field, 'creation_counter', None) if _creation_counter is False else None, + _creation_counter=_creation_counter, ) - new_field.attname = field.attname - new_field.parent = field.parent + new_field.attname = attname + new_field.parent = parent return new_field diff --git a/graphene/types/interface.py b/graphene/types/interface.py index 9843f9dc..6d51f238 100644 --- a/graphene/types/interface.py +++ b/graphene/types/interface.py @@ -40,7 +40,10 @@ class Interface(six.with_metaclass(InterfaceTypeMeta)): abstract = True def __init__(self, *args, **kwargs): - raise Exception("An interface cannot be intitialized") + from .objecttype import ObjectType + if not isinstance(self, ObjectType): + raise Exception("An interface cannot be intitialized") + super(Interface, self).__init__(*args, **kwargs) @classmethod def implements(cls, object_type): diff --git a/graphene/types/objecttype.py b/graphene/types/objecttype.py index a5e83f1d..607c8524 100644 --- a/graphene/types/objecttype.py +++ b/graphene/types/objecttype.py @@ -71,6 +71,7 @@ class ObjectTypeMeta(InterfaceTypeMeta): cls._meta.interfaces, cls.get_interfaces(bases), ))) + local_fields = cls._extract_local_fields(attrs) if not cls._meta.graphql_type: cls = super(ObjectTypeMeta, cls).construct(bases, attrs) diff --git a/graphene/types/tests/test_objecttype.py b/graphene/types/tests/test_objecttype.py index 65a801d0..9379cc6a 100644 --- a/graphene/types/tests/test_objecttype.py +++ b/graphene/types/tests/test_objecttype.py @@ -93,7 +93,6 @@ def test_objecttype_as_container_get_fields(): def test_parent_container_get_fields(): fields = Container._meta.graphql_type.get_fields() - print [(f.creation_counter, f.name) for f in fields.values()] assert fields.keys() == ['field1', 'field2']