diff --git a/examples/starwars_relay/schema.py b/examples/starwars_relay/schema.py index e4fe0bc0..294e4b88 100644 --- a/examples/starwars_relay/schema.py +++ b/examples/starwars_relay/schema.py @@ -16,13 +16,12 @@ class Ship(relay.Node, graphene.ObjectType): class Faction(relay.Node, graphene.ObjectType): '''A faction in the Star Wars saga''' name = graphene.String(description='The name of the faction.') - # ships = relay.ConnectionField( - # Ship, description='The ships used by the faction.') - ships = graphene.List(graphene.String) - # @resolve_only_args - # def resolve_ships(self, **args): - # # Transform the instance ship_ids into real instances - # return [get_ship(ship_id) for ship_id in self.ships] + ships = relay.ConnectionField(Ship, description='The ships used by the faction.') + + @resolve_only_args + def resolve_ships(self, **args): + # Transform the instance ship_ids into real instances + return [get_ship(ship_id) for ship_id in self.ships] @classmethod def get_node(cls, id, context, info): diff --git a/examples/starwars_relay/tests/test_connections.py b/examples/starwars_relay/tests/test_connections.py index 0ef14311..17b6865f 100644 --- a/examples/starwars_relay/tests/test_connections.py +++ b/examples/starwars_relay/tests/test_connections.py @@ -1,38 +1,38 @@ -# from ..data import setup -# from ..schema import schema +from ..data import setup +from ..schema import schema -# setup() +setup() -# def test_correct_fetch_first_ship_rebels(): -# query = ''' -# query RebelsShipsQuery { -# rebels { -# name, -# ships(first: 1) { -# edges { -# node { -# name -# } -# } -# } -# } -# } -# ''' -# expected = { -# 'rebels': { -# 'name': 'Alliance to Restore the Republic', -# 'ships': { -# 'edges': [ -# { -# 'node': { -# 'name': 'X-Wing' -# } -# } -# ] -# } -# } -# } -# result = schema.execute(query) -# assert not result.errors -# assert result.data == expected +def test_correct_fetch_first_ship_rebels(): + query = ''' + query RebelsShipsQuery { + rebels { + name, + ships(first: 1) { + edges { + node { + name + } + } + } + } + } + ''' + expected = { + 'rebels': { + 'name': 'Alliance to Restore the Republic', + 'ships': { + 'edges': [ + { + 'node': { + 'name': 'X-Wing' + } + } + ] + } + } + } + result = schema.execute(query) + assert not result.errors + assert result.data == expected diff --git a/examples/starwars_relay/tests/test_mutation.py b/examples/starwars_relay/tests/test_mutation.py index c5d63139..e5ae2716 100644 --- a/examples/starwars_relay/tests/test_mutation.py +++ b/examples/starwars_relay/tests/test_mutation.py @@ -14,14 +14,14 @@ def test_mutations(): } faction { name - # ships { - # edges { - # node { - # id - # name - # } - # } - # } + ships { + edges { + node { + id + name + } + } + } } } } @@ -34,39 +34,39 @@ def test_mutations(): }, 'faction': { 'name': 'Alliance to Restore the Republic', - # 'ships': { - # 'edges': [{ - # 'node': { - # 'id': 'U2hpcDox', - # 'name': 'X-Wing' - # } - # }, { - # 'node': { - # 'id': 'U2hpcDoy', - # 'name': 'Y-Wing' - # } - # }, { - # 'node': { - # 'id': 'U2hpcDoz', - # 'name': 'A-Wing' - # } - # }, { - # 'node': { - # 'id': 'U2hpcDo0', - # 'name': 'Millenium Falcon' - # } - # }, { - # 'node': { - # 'id': 'U2hpcDo1', - # 'name': 'Home One' - # } - # }, { - # 'node': { - # 'id': 'U2hpcDo5', - # 'name': 'Peter' - # } - # }] - # }, + 'ships': { + 'edges': [{ + 'node': { + 'id': 'U2hpcDox', + 'name': 'X-Wing' + } + }, { + 'node': { + 'id': 'U2hpcDoy', + 'name': 'Y-Wing' + } + }, { + 'node': { + 'id': 'U2hpcDoz', + 'name': 'A-Wing' + } + }, { + 'node': { + 'id': 'U2hpcDo0', + 'name': 'Millenium Falcon' + } + }, { + 'node': { + 'id': 'U2hpcDo1', + 'name': 'Home One' + } + }, { + 'node': { + 'id': 'U2hpcDo5', + 'name': 'Peter' + } + }] + }, } } } diff --git a/graphene/relay/__init__.py b/graphene/relay/__init__.py index 576b696c..1a20ec41 100644 --- a/graphene/relay/__init__.py +++ b/graphene/relay/__init__.py @@ -1,9 +1,10 @@ from .node import Node from .mutation import ClientIDMutation -from .connection import Connection +from .connection import Connection, ConnectionField __all__ = [ 'Node', 'ClientIDMutation', 'Connection', + 'ConnectionField', ] diff --git a/graphene/relay/connection.py b/graphene/relay/connection.py index 16a4d441..55f7ae49 100644 --- a/graphene/relay/connection.py +++ b/graphene/relay/connection.py @@ -4,6 +4,7 @@ from collections import Iterable import six from graphql_relay import connection_definitions, connection_from_list +from graphql_relay.connection.connection import connection_args from ..types.field import Field from ..types.objecttype import ObjectType, ObjectTypeMeta @@ -60,24 +61,57 @@ class Connection(six.with_metaclass(ConnectionMeta, ObjectType)): resolve_node = None resolve_cursor = None + def __init__(self, *args, **kwargs): + kwargs['pageInfo'] = kwargs.pop('pageInfo', kwargs.pop('page_info')) + super(Connection, self).__init__(*args, **kwargs) + class IterableConnectionField(Field): - # def __init__(self, type, *args, **kwargs): - # if - def resolver(self, root, args, context, info): - iterable = super(ConnectionField, self).resolver(root, args, context, info) - # if isinstance(resolved, self.type.graphene) - assert isinstance( - iterable, Iterable), 'Resolved value from the connection field have to be iterable' - connection = connection_from_list( - iterable, - args, - connection_type=None, - edge_type=None, - pageinfo_type=None - ) - return connection + def __init__(self, type, args={}, *other_args, **kwargs): + super(IterableConnectionField, self).__init__(type, args=connection_args, *other_args, **kwargs) + @property + def type(self): + from ..utils.get_graphql_type import get_graphql_type + return get_graphql_type(self.connection) + + @type.setter + def type(self, value): + self._type = value + + @property + def connection(self): + from .node import Node + graphql_type = super(IterableConnectionField, self).type + if issubclass(graphql_type.graphene_type, Node): + connection_type = graphql_type.graphene_type.get_default_connection() + else: + connection_type = graphql_type.graphene_type + assert issubclass(connection_type, Connection), '{} type have to be a subclass of Connection'.format(str(self)) + return connection_type + + @property + def resolver(self): + super_resolver = super(ConnectionField, self).resolver + + def resolver(root, args, context, info): + iterable = super_resolver(root, args, context, info) + # if isinstance(resolved, self.type.graphene) + assert isinstance( + iterable, Iterable), 'Resolved value from the connection field have to be iterable' + connection = connection_from_list( + iterable, + args, + connection_type=self.connection, + edge_type=self.connection.Edge, + pageinfo_type=None + ) + return connection + return resolver + + @resolver.setter + def resolver(self, resolver): + self._resolver = resolver ConnectionField = IterableConnectionField diff --git a/graphene/relay/node.py b/graphene/relay/node.py index 778f8295..52869fea 100644 --- a/graphene/relay/node.py +++ b/graphene/relay/node.py @@ -4,9 +4,10 @@ import six from graphql_relay import from_global_id, node_definitions, to_global_id +from .connection import Connection from ..types.field import Field from ..types.interface import Interface -from ..types.objecttype import ObjectTypeMeta +from ..types.objecttype import ObjectTypeMeta, ObjectType from ..types.options import Options @@ -39,6 +40,7 @@ class NodeMeta(ObjectTypeMeta): class Node(six.with_metaclass(NodeMeta, Interface)): + _connection = None @classmethod def require_get_node(cls): @@ -71,6 +73,15 @@ class Node(six.with_metaclass(NodeMeta, Interface)): return return graphql_type.graphene_type.get_node(_id, context, info) + @classmethod + def get_default_connection(cls): + assert issubclass(cls, ObjectType), 'Can only get connection type on implemented Nodes.' + if not cls._connection: + class Meta: + node = cls + cls._connection = type('{}Connection'.format(cls.__name__), (Connection,), {'Meta': Meta}) + return cls._connection + @classmethod def implements(cls, object_type): ''' diff --git a/graphene/relay/tests/test_node.py b/graphene/relay/tests/test_node.py index 2a57539a..54266af6 100644 --- a/graphene/relay/tests/test_node.py +++ b/graphene/relay/tests/test_node.py @@ -5,6 +5,7 @@ from graphql_relay import to_global_id from ...types import ObjectType, Schema from ...types.scalars import String from ..node import Node +from ..connection import Connection class MyNode(Node, ObjectType): @@ -44,6 +45,15 @@ def test_node_good(): assert 'id' in graphql_type.get_fields() +def test_node_get_connection(): + connection = MyNode.get_default_connection() + assert issubclass(connection, Connection) + + +def test_node_get_connection_dont_duplicate(): + assert MyNode.get_default_connection() == MyNode.get_default_connection() + + def test_node_query(): executed = schema.execute( '{ node(id:"%s") { ... on MyNode { name } } }' % to_global_id("MyNode", 1) diff --git a/graphene/types/field.py b/graphene/types/field.py index 2b4331ea..fc6fcad6 100644 --- a/graphene/types/field.py +++ b/graphene/types/field.py @@ -76,8 +76,6 @@ class Field(AbstractField, GraphQLField, OrderedType): @property def resolver(self): - pass - resolver = getattr(self.parent, 'resolve_{}'.format(self.attname), None) # We try to get the resolver from the interfaces diff --git a/graphene/utils/copy_fields.py b/graphene/utils/copy_fields.py index 53570250..9d63a0ef 100644 --- a/graphene/utils/copy_fields.py +++ b/graphene/utils/copy_fields.py @@ -1,10 +1,15 @@ from collections import OrderedDict +from ..types.field import Field, InputField def copy_fields(like, fields, **extra): _fields = [] for attname, field in fields.items(): - field = like.copy_and_extend(field, attname=getattr(field, 'attname', None) or attname, **extra) + if isinstance(field, (Field, InputField)): + copy_and_extend = field.copy_and_extend + else: + copy_and_extend = like.copy_and_extend + field = copy_and_extend(field, attname=getattr(field, 'attname', None) or attname, **extra) _fields.append(field) return OrderedDict((f.name, f) for f in _fields)