diff --git a/graphene/relay/connection.py b/graphene/relay/connection.py index af0eaf4d..02b3dbe5 100644 --- a/graphene/relay/connection.py +++ b/graphene/relay/connection.py @@ -91,23 +91,22 @@ class IterableConnectionField(Field): assert issubclass(connection_type, Connection), '{} type have to be a subclass of Connection'.format(str(self)) return connection_type + def connection_resolver(self, root, args, context, info): + iterable = super(ConnectionField, self).resolver(root, args, context, info) + # if isinstance(resolved, self.type.graphene) + assert isinstance( + iterable, Iterable), 'Resolved value from the connection field have to be iterable. Received "{}"'.format(iterable) + connection = connection_from_list( + iterable, + args, + connection_type=self.connection, + edge_type=self.connection.Edge, + ) + return connection + @property def resolver(self): - super_resolver = super(ConnectionField, self).resolver - - def resolver(root, args, context, info): - iterable = super_resolver(root, args, context, info) - # if isinstance(resolved, self.type.graphene) - assert isinstance( - iterable, Iterable), 'Resolved value from the connection field have to be iterable' - connection = connection_from_list( - iterable, - args, - connection_type=self.connection, - edge_type=self.connection.Edge, - ) - return connection - return resolver + return self.connection_resolver @resolver.setter def resolver(self, resolver): diff --git a/graphene/relay/node.py b/graphene/relay/node.py index 8394bfb7..4467746d 100644 --- a/graphene/relay/node.py +++ b/graphene/relay/node.py @@ -6,41 +6,59 @@ from graphql_relay import from_global_id, node_definitions, to_global_id from ..types.field import Field from ..types.interface import Interface -from ..types.objecttype import ObjectType, ObjectTypeMeta +from ..types.objecttype import ObjectType, ObjectTypeMeta, is_objecttype from ..types.options import Options from .connection import Connection +# We inherit from ObjectTypeMeta as we want to allow +# inheriting from Node, and also ObjectType. +# Like class MyNode(Node): pass +# And class MyNodeImplementation(Node, ObjectType): pass class NodeMeta(ObjectTypeMeta): + @staticmethod + def _get_interface_options(meta): + return Options( + meta, + ) + def __new__(cls, name, bases, attrs): - cls = super(NodeMeta, cls).__new__(cls, name, bases, attrs) - is_object_type = cls.is_object_type() - if not is_object_type: - get_node_from_global_id = getattr(cls, 'get_node_from_global_id', None) - id_resolver = getattr(cls, 'id_resolver', None) - assert get_node_from_global_id, '{}.get_node_from_global_id method is required by the Node interface.'.format(cls.__name__) - node_interface, node_field = node_definitions( - get_node_from_global_id, - id_resolver=id_resolver, - ) - cls._meta = Options(None, graphql_type=node_interface) - cls.Field = partial( - Field.copy_and_extend, - node_field, - type=node_field.type, - parent=cls, - _creation_counter=None) - else: + + if is_objecttype(bases): + cls = super(NodeMeta, cls).__new__(cls, name, bases, attrs) # The interface provided by node_definitions is not an instance # of GrapheneInterfaceType, so it will have no graphql_type, # so will not trigger Node.implements cls.implements(cls) + return cls + + options = cls._get_interface_options(attrs.pop('Meta', None)) + cls = type.__new__(cls, name, bases, dict(attrs, _meta=options)) + + get_node_from_global_id = getattr(cls, 'get_node_from_global_id', None) + id_resolver = getattr(cls, 'id_resolver', None) + assert get_node_from_global_id, '{}.get_node_from_global_id method is required by the Node interface.'.format(cls.__name__) + node_interface, node_field = node_definitions( + get_node_from_global_id, + id_resolver=id_resolver, + type_resolver=cls.resolve_type, + ) + options.graphql_type = node_interface + cls.Field = partial( + Field.copy_and_extend, + node_field, + type=node_field.type, + parent=cls, + _creation_counter=None) + return cls class Node(six.with_metaclass(NodeMeta, Interface)): _connection = None + resolve_type = None + use_global_id = True @classmethod def require_get_node(cls): @@ -48,15 +66,17 @@ class Node(six.with_metaclass(NodeMeta, Interface)): @classmethod def from_global_id(cls, global_id): - if cls is Node: - return from_global_id(global_id) - raise NotImplementedError("You need to implement {}.from_global_id".format(cls.__name__)) + return from_global_id(global_id) + # if cls is Node: + # return from_global_id(global_id) + # raise NotImplementedError("You need to implement {}.from_global_id".format(cls.__name__)) @classmethod def to_global_id(cls, type, id): - if cls is Node: - return to_global_id(type, id) - raise NotImplementedError("You need to implement {}.to_global_id".format(cls.__name__)) + return to_global_id(type, id) + # if cls is Node: + # return to_global_id(type, id) + # raise NotImplementedError("You need to implement {}.to_global_id".format(cls.__name__)) @classmethod def id_resolver(cls, root, args, context, info):