Added ConnectionField

This commit is contained in:
Syrus Akbary 2016-06-14 23:48:25 -07:00
parent c74a75133e
commit 5ccd815fbd
9 changed files with 161 additions and 103 deletions

View File

@ -16,13 +16,12 @@ class Ship(relay.Node, graphene.ObjectType):
class Faction(relay.Node, graphene.ObjectType): class Faction(relay.Node, graphene.ObjectType):
'''A faction in the Star Wars saga''' '''A faction in the Star Wars saga'''
name = graphene.String(description='The name of the faction.') name = graphene.String(description='The name of the faction.')
# ships = relay.ConnectionField( ships = relay.ConnectionField(Ship, description='The ships used by the faction.')
# Ship, description='The ships used by the faction.')
ships = graphene.List(graphene.String) @resolve_only_args
# @resolve_only_args def resolve_ships(self, **args):
# def resolve_ships(self, **args): # Transform the instance ship_ids into real instances
# # Transform the instance ship_ids into real instances return [get_ship(ship_id) for ship_id in self.ships]
# return [get_ship(ship_id) for ship_id in self.ships]
@classmethod @classmethod
def get_node(cls, id, context, info): def get_node(cls, id, context, info):

View File

@ -1,38 +1,38 @@
# from ..data import setup from ..data import setup
# from ..schema import schema from ..schema import schema
# setup() setup()
# def test_correct_fetch_first_ship_rebels(): def test_correct_fetch_first_ship_rebels():
# query = ''' query = '''
# query RebelsShipsQuery { query RebelsShipsQuery {
# rebels { rebels {
# name, name,
# ships(first: 1) { ships(first: 1) {
# edges { edges {
# node { node {
# name name
# } }
# } }
# } }
# } }
# } }
# ''' '''
# expected = { expected = {
# 'rebels': { 'rebels': {
# 'name': 'Alliance to Restore the Republic', 'name': 'Alliance to Restore the Republic',
# 'ships': { 'ships': {
# 'edges': [ 'edges': [
# { {
# 'node': { 'node': {
# 'name': 'X-Wing' 'name': 'X-Wing'
# } }
# } }
# ] ]
# } }
# } }
# } }
# result = schema.execute(query) result = schema.execute(query)
# assert not result.errors assert not result.errors
# assert result.data == expected assert result.data == expected

View File

@ -14,14 +14,14 @@ def test_mutations():
} }
faction { faction {
name name
# ships { ships {
# edges { edges {
# node { node {
# id id
# name name
# } }
# } }
# } }
} }
} }
} }
@ -34,39 +34,39 @@ def test_mutations():
}, },
'faction': { 'faction': {
'name': 'Alliance to Restore the Republic', 'name': 'Alliance to Restore the Republic',
# 'ships': { 'ships': {
# 'edges': [{ 'edges': [{
# 'node': { 'node': {
# 'id': 'U2hpcDox', 'id': 'U2hpcDox',
# 'name': 'X-Wing' 'name': 'X-Wing'
# } }
# }, { }, {
# 'node': { 'node': {
# 'id': 'U2hpcDoy', 'id': 'U2hpcDoy',
# 'name': 'Y-Wing' 'name': 'Y-Wing'
# } }
# }, { }, {
# 'node': { 'node': {
# 'id': 'U2hpcDoz', 'id': 'U2hpcDoz',
# 'name': 'A-Wing' 'name': 'A-Wing'
# } }
# }, { }, {
# 'node': { 'node': {
# 'id': 'U2hpcDo0', 'id': 'U2hpcDo0',
# 'name': 'Millenium Falcon' 'name': 'Millenium Falcon'
# } }
# }, { }, {
# 'node': { 'node': {
# 'id': 'U2hpcDo1', 'id': 'U2hpcDo1',
# 'name': 'Home One' 'name': 'Home One'
# } }
# }, { }, {
# 'node': { 'node': {
# 'id': 'U2hpcDo5', 'id': 'U2hpcDo5',
# 'name': 'Peter' 'name': 'Peter'
# } }
# }] }]
# }, },
} }
} }
} }

View File

@ -1,9 +1,10 @@
from .node import Node from .node import Node
from .mutation import ClientIDMutation from .mutation import ClientIDMutation
from .connection import Connection from .connection import Connection, ConnectionField
__all__ = [ __all__ = [
'Node', 'Node',
'ClientIDMutation', 'ClientIDMutation',
'Connection', 'Connection',
'ConnectionField',
] ]

View File

@ -4,6 +4,7 @@ from collections import Iterable
import six import six
from graphql_relay import connection_definitions, connection_from_list from graphql_relay import connection_definitions, connection_from_list
from graphql_relay.connection.connection import connection_args
from ..types.field import Field from ..types.field import Field
from ..types.objecttype import ObjectType, ObjectTypeMeta from ..types.objecttype import ObjectType, ObjectTypeMeta
@ -60,24 +61,57 @@ class Connection(six.with_metaclass(ConnectionMeta, ObjectType)):
resolve_node = None resolve_node = None
resolve_cursor = None resolve_cursor = None
def __init__(self, *args, **kwargs):
kwargs['pageInfo'] = kwargs.pop('pageInfo', kwargs.pop('page_info'))
super(Connection, self).__init__(*args, **kwargs)
class IterableConnectionField(Field): class IterableConnectionField(Field):
# def __init__(self, type, *args, **kwargs):
# if
def resolver(self, root, args, context, info): def __init__(self, type, args={}, *other_args, **kwargs):
iterable = super(ConnectionField, self).resolver(root, args, context, info) super(IterableConnectionField, self).__init__(type, args=connection_args, *other_args, **kwargs)
# if isinstance(resolved, self.type.graphene)
assert isinstance(
iterable, Iterable), 'Resolved value from the connection field have to be iterable'
connection = connection_from_list(
iterable,
args,
connection_type=None,
edge_type=None,
pageinfo_type=None
)
return connection
@property
def type(self):
from ..utils.get_graphql_type import get_graphql_type
return get_graphql_type(self.connection)
@type.setter
def type(self, value):
self._type = value
@property
def connection(self):
from .node import Node
graphql_type = super(IterableConnectionField, self).type
if issubclass(graphql_type.graphene_type, Node):
connection_type = graphql_type.graphene_type.get_default_connection()
else:
connection_type = graphql_type.graphene_type
assert issubclass(connection_type, Connection), '{} type have to be a subclass of Connection'.format(str(self))
return connection_type
@property
def resolver(self):
super_resolver = super(ConnectionField, self).resolver
def resolver(root, args, context, info):
iterable = super_resolver(root, args, context, info)
# if isinstance(resolved, self.type.graphene)
assert isinstance(
iterable, Iterable), 'Resolved value from the connection field have to be iterable'
connection = connection_from_list(
iterable,
args,
connection_type=self.connection,
edge_type=self.connection.Edge,
pageinfo_type=None
)
return connection
return resolver
@resolver.setter
def resolver(self, resolver):
self._resolver = resolver
ConnectionField = IterableConnectionField ConnectionField = IterableConnectionField

View File

@ -4,9 +4,10 @@ import six
from graphql_relay import from_global_id, node_definitions, to_global_id from graphql_relay import from_global_id, node_definitions, to_global_id
from .connection import Connection
from ..types.field import Field from ..types.field import Field
from ..types.interface import Interface from ..types.interface import Interface
from ..types.objecttype import ObjectTypeMeta from ..types.objecttype import ObjectTypeMeta, ObjectType
from ..types.options import Options from ..types.options import Options
@ -39,6 +40,7 @@ class NodeMeta(ObjectTypeMeta):
class Node(six.with_metaclass(NodeMeta, Interface)): class Node(six.with_metaclass(NodeMeta, Interface)):
_connection = None
@classmethod @classmethod
def require_get_node(cls): def require_get_node(cls):
@ -71,6 +73,15 @@ class Node(six.with_metaclass(NodeMeta, Interface)):
return return
return graphql_type.graphene_type.get_node(_id, context, info) return graphql_type.graphene_type.get_node(_id, context, info)
@classmethod
def get_default_connection(cls):
assert issubclass(cls, ObjectType), 'Can only get connection type on implemented Nodes.'
if not cls._connection:
class Meta:
node = cls
cls._connection = type('{}Connection'.format(cls.__name__), (Connection,), {'Meta': Meta})
return cls._connection
@classmethod @classmethod
def implements(cls, object_type): def implements(cls, object_type):
''' '''

View File

@ -5,6 +5,7 @@ from graphql_relay import to_global_id
from ...types import ObjectType, Schema from ...types import ObjectType, Schema
from ...types.scalars import String from ...types.scalars import String
from ..node import Node from ..node import Node
from ..connection import Connection
class MyNode(Node, ObjectType): class MyNode(Node, ObjectType):
@ -44,6 +45,15 @@ def test_node_good():
assert 'id' in graphql_type.get_fields() assert 'id' in graphql_type.get_fields()
def test_node_get_connection():
connection = MyNode.get_default_connection()
assert issubclass(connection, Connection)
def test_node_get_connection_dont_duplicate():
assert MyNode.get_default_connection() == MyNode.get_default_connection()
def test_node_query(): def test_node_query():
executed = schema.execute( executed = schema.execute(
'{ node(id:"%s") { ... on MyNode { name } } }' % to_global_id("MyNode", 1) '{ node(id:"%s") { ... on MyNode { name } } }' % to_global_id("MyNode", 1)

View File

@ -76,8 +76,6 @@ class Field(AbstractField, GraphQLField, OrderedType):
@property @property
def resolver(self): def resolver(self):
pass
resolver = getattr(self.parent, 'resolve_{}'.format(self.attname), None) resolver = getattr(self.parent, 'resolve_{}'.format(self.attname), None)
# We try to get the resolver from the interfaces # We try to get the resolver from the interfaces

View File

@ -1,10 +1,15 @@
from collections import OrderedDict from collections import OrderedDict
from ..types.field import Field, InputField
def copy_fields(like, fields, **extra): def copy_fields(like, fields, **extra):
_fields = [] _fields = []
for attname, field in fields.items(): for attname, field in fields.items():
field = like.copy_and_extend(field, attname=getattr(field, 'attname', None) or attname, **extra) if isinstance(field, (Field, InputField)):
copy_and_extend = field.copy_and_extend
else:
copy_and_extend = like.copy_and_extend
field = copy_and_extend(field, attname=getattr(field, 'attname', None) or attname, **extra)
_fields.append(field) _fields.append(field)
return OrderedDict((f.name, f) for f in _fields) return OrderedDict((f.name, f) for f in _fields)