diff --git a/examples/starwars_relay/schema.py b/examples/starwars_relay/schema.py index 294e4b88..26c74160 100644 --- a/examples/starwars_relay/schema.py +++ b/examples/starwars_relay/schema.py @@ -4,7 +4,9 @@ from graphene import relay, resolve_only_args from .data import create_ship, get_empire, get_faction, get_rebels, get_ship -class Ship(relay.Node, graphene.ObjectType): +class Ship(graphene.ObjectType): + class Meta: + interfaces = [relay.Node] '''A ship in the Star Wars saga''' name = graphene.String(description='The name of the ship.') diff --git a/graphene/relay/connection.py b/graphene/relay/connection.py index d29e9b2a..07f5e3f5 100644 --- a/graphene/relay/connection.py +++ b/graphene/relay/connection.py @@ -96,8 +96,8 @@ class IterableConnectionField(Field): @property def connection(self): from .node import Node - if issubclass(self._type, Node): - connection_type = self._type.get_default_connection() + if Node in self._type._meta.interfaces: + connection_type = self._type.Connection else: connection_type = self._type assert issubclass(connection_type, Connection), '{} type have to be a subclass of Connection'.format(str(self)) diff --git a/graphene/relay/node.py b/graphene/relay/node.py index 793cd738..59645757 100644 --- a/graphene/relay/node.py +++ b/graphene/relay/node.py @@ -12,6 +12,15 @@ from ..utils.copy_fields import copy_fields from .connection import Connection +def get_default_connection(cls): + assert issubclass(cls, ObjectType), 'Can only get connection type on implemented Nodes.' + + class Meta: + node = cls + + return type('{}Connection'.format(cls.__name__), (Connection,), {'Meta': Meta}) + + # We inherit from ObjectTypeMeta as we want to allow # inheriting from Node, and also ObjectType. # Like class MyNode(Node): pass @@ -24,17 +33,6 @@ class NodeMeta(ObjectTypeMeta): meta, ) - @staticmethod - def _create_objecttype(cls, name, bases, attrs): - cls = super(NodeMeta, cls)._create_objecttype(cls, name, bases, attrs) - require_get_node = Node._meta.graphql_type in cls._meta.graphql_type._provided_interfaces - if require_get_node: - assert hasattr( - cls, 'get_node'), '{}.get_node method is required by the Node interface.'.format( - cls.__name__) - - return cls - @staticmethod def _create_interface(cls, name, bases, attrs): options = cls._get_interface_options(attrs.pop('Meta', None)) @@ -95,10 +93,14 @@ class Node(six.with_metaclass(NodeMeta, Interface)): return graphql_type.graphene_type.get_node(_id, context, info) @classmethod - def get_default_connection(cls): - assert issubclass(cls, ObjectType), 'Can only get connection type on implemented Nodes.' - if not cls._connection: - class Meta: - node = cls - cls._connection = type('{}Connection'.format(cls.__name__), (Connection,), {'Meta': Meta}) - return cls._connection + def implements(cls, objecttype): + require_get_node = Node._meta.graphql_type in objecttype._meta.get_interfaces + get_connection = getattr(objecttype, 'get_connection', None) + if not get_connection: + get_connection = partial(get_default_connection, objecttype) + + objecttype.Connection = get_connection() + if require_get_node: + assert hasattr( + objecttype, 'get_node'), '{}.get_node method is required by the Node interface.'.format( + objecttype.__name__) diff --git a/graphene/relay/tests/test_connection.py b/graphene/relay/tests/test_connection.py index fc14d9da..9b2ce03f 100644 --- a/graphene/relay/tests/test_connection.py +++ b/graphene/relay/tests/test_connection.py @@ -6,7 +6,9 @@ from ..connection import Connection from ..node import Node -class MyObject(Node, ObjectType): +class MyObject(ObjectType): + class Meta: + interfaces = [Node] field = String() @classmethod diff --git a/graphene/relay/tests/test_node.py b/graphene/relay/tests/test_node.py index 04474024..5fe7b132 100644 --- a/graphene/relay/tests/test_node.py +++ b/graphene/relay/tests/test_node.py @@ -8,8 +8,9 @@ from ..connection import Connection from ..node import Node -class MyNode(Node, ObjectType): - +class MyNode(ObjectType): + class Meta: + interfaces = [Node] name = String() @staticmethod @@ -46,12 +47,12 @@ def test_node_good(): def test_node_get_connection(): - connection = MyNode.get_default_connection() + connection = MyNode.Connection assert issubclass(connection, Connection) def test_node_get_connection_dont_duplicate(): - assert MyNode.get_default_connection() == MyNode.get_default_connection() + assert MyNode.Connection == MyNode.Connection def test_node_query(): diff --git a/graphene/relay/tests/test_node_custom.py b/graphene/relay/tests/test_node_custom.py index f6020381..cfe572d5 100644 --- a/graphene/relay/tests/test_node_custom.py +++ b/graphene/relay/tests/test_node_custom.py @@ -24,12 +24,15 @@ class BasePhoto(Interface): width = Int() -class User(CustomNode, ObjectType): +class User(ObjectType): + class Meta: + interfaces = [CustomNode] name = String() -class Photo(CustomNode, BasePhoto, ObjectType): - pass +class Photo(ObjectType): + class Meta: + interfaces = [CustomNode, BasePhoto] user_data = { @@ -50,7 +53,6 @@ schema = Schema(query=RootQuery, types=[User, Photo]) def test_str_schema_correct(): - print str(schema) assert str(schema) == '''schema { query: RootQuery } diff --git a/graphene/types/objecttype.py b/graphene/types/objecttype.py index 49be24f0..acee8265 100644 --- a/graphene/types/objecttype.py +++ b/graphene/types/objecttype.py @@ -1,4 +1,4 @@ - +import inspect import six from ..utils.copy_fields import copy_fields @@ -97,12 +97,17 @@ class ObjectTypeMeta(type): if not options.graphql_type: fields = copy_fields(Field, fields, parent=cls) - base_interfaces = tuple(b for b in bases if issubclass(b, Interface)) + inherited_interfaces = tuple(b for b in bases if issubclass(b, Interface)) options.get_fields = lambda: fields - options.get_interfaces = tuple(get_interfaces(interfaces + base_interfaces)) + options.interfaces = interfaces + inherited_interfaces + options.get_interfaces = tuple(get_interfaces(options.interfaces)) options.graphql_type = generate_objecttype(cls) + for i in options.interfaces: + if inspect.isclass(i) and issubclass(i, Interface): + i.implements(cls) else: assert not fields, "Can't mount Fields in an ObjectType with a defined graphql_type" + assert not interfaces, "Can't have extra interfaces with a defined graphql_type" fields = copy_fields(Field, options.graphql_type.get_fields(), parent=cls) options.get_fields = lambda: fields @@ -175,3 +180,7 @@ class Interface(six.with_metaclass(ObjectTypeMeta)): if not isinstance(self, ObjectType): raise Exception("An interface cannot be intitialized") super(Interface, self).__init__(*args, **kwargs) + + @classmethod + def implements(cls, objecttype): + pass