Added relay.Connection and relay.ConnectionField

This commit is contained in:
Syrus Akbary 2016-08-13 21:05:45 -07:00
parent b19bca7f3b
commit e2036da75f
7 changed files with 234 additions and 8 deletions

View File

@ -1,10 +1,10 @@
from .node import Node
# from .mutation import ClientIDMutation
# from .connection import Connection, ConnectionField
from .connection import Connection, ConnectionField
__all__ = [
'Node',
# 'ClientIDMutation',
# 'Connection',
# 'ConnectionField',
'Connection',
'ConnectionField',
]

View File

@ -0,0 +1,155 @@
import re
from collections import Iterable, OrderedDict
from functools import partial
import six
from graphql_relay import connection_from_list
from graphql_relay.connection.connection import connection_args
from ..types import Boolean, String, List
from ..types.field import Field
from ..types.objecttype import ObjectType, ObjectTypeMeta
from ..types.options import Options
from ..utils.is_base_type import is_base_type
from ..utils.props import props
from .node import Node
from ..types.utils import get_fields_in_type, yank_fields_from_attrs, merge_fields_in_attrs
def is_node(objecttype):
for i in objecttype._meta.interfaces:
if issubclass(i, Node):
return True
return False
class PageInfo(ObjectType):
has_next_page = Boolean(
required=True,
name='hasNextPage',
description='When paginating forwards, are there more items?',
)
has_previous_page = Boolean(
required=True,
name='hasPreviousPage',
description='When paginating backwards, are there more items?',
)
start_cursor = String(
name='startCursor',
description='When paginating backwards, the cursor to continue.',
)
end_cursor = String(
name='endCursor',
description='When paginating forwards, the cursor to continue.',
)
class ConnectionMeta(ObjectTypeMeta):
def __new__(cls, name, bases, attrs):
# Also ensure initialization is only performed for subclasses of Model
# (excluding Model class itself).
if not is_base_type(bases, ConnectionMeta):
return type.__new__(cls, name, bases, attrs)
options = Options(
attrs.pop('Meta', None),
name=None,
description=None,
node=None,
)
options.interfaces = ()
assert options.node, 'You have to provide a node in {}.Meta'.format(cls.__name__)
assert issubclass(options.node, (Node, ObjectType)), (
'Received incompatible node "{}" for Connection {}.'
).format(options.node, name)
base_name = re.sub('Connection$', '', name)
if not options.name:
options.name = '{}Connection'.format(base_name)
edge_class = attrs.pop('Edge', None)
edge_fields = OrderedDict([
('node', Field(options.node, description='The item at the end of the edge')),
('cursor', Field(String, required=True, description='A cursor for use in pagination'))
])
edge_attrs = props(edge_class) if edge_class else OrderedDict()
edge_fields.update(get_fields_in_type(ObjectType, edge_attrs))
EdgeMeta = type('Meta', (object, ), {'fields': edge_fields})
Edge = type('{}Edge'.format(base_name), (ObjectType,), {'Meta': EdgeMeta})
options.local_fields = OrderedDict([
('page_info', Field(PageInfo, name='pageInfo', required=True)),
('edges', Field(List(Edge)))
])
typed_fields = get_fields_in_type(ObjectType, attrs)
options.local_fields.update(typed_fields)
options.fields = options.local_fields
yank_fields_from_attrs(attrs, typed_fields)
return type.__new__(cls, name, bases, dict(attrs, _meta=options, Edge=Edge))
class Connection(six.with_metaclass(ConnectionMeta, ObjectType)):
resolve_node = None
resolve_cursor = None
def __init__(self, *args, **kwargs):
kwargs['pageInfo'] = kwargs.pop('pageInfo', kwargs.pop('page_info'))
super(Connection, self).__init__(*args, **kwargs)
class IterableConnectionField(Field):
def __init__(self, type, *args, **kwargs):
# arguments = kwargs.pop('args', {})
# if not arguments:
# arguments = connection_args
# else:
# arguments = copy.copy(arguments)
# arguments.update(connection_args)
super(IterableConnectionField, self).__init__(
type,
args=connection_args,
*args,
**kwargs
)
@property
def connection(self):
if is_node(self.type):
connection_type = self.type.Connection
else:
connection_type = self.type
assert issubclass(connection_type, Connection), (
'{} type have to be a subclass of Connection'
).format(str(self))
return connection_type
@staticmethod
def connection_resolver(resolver, connection, root, args, context, info):
iterable = resolver(root, args, context, info)
assert isinstance(iterable, Iterable), (
'Resolved value from the connection field have to be iterable. '
'Received "{}"'
).format(iterable)
connection = connection_from_list(
iterable,
args,
connection_type=connection,
edge_type=connection.Edge,
)
return connection
def get_resolver(self, parent_resolver):
return partial(self.connection_resolver, parent_resolver, self.connection)
ConnectionField = IterableConnectionField

View File

@ -0,0 +1,66 @@
from ...types import ObjectType, Schema, List, Field, String, NonNull
from ..connection import Connection, PageInfo
from ..node import Node
class MyObject(ObjectType):
class Meta:
interfaces = [Node]
field = String()
@classmethod
def get_node(cls, id):
pass
class MyObjectConnection(Connection):
extra = String()
class Meta:
node = MyObject
class Edge:
other = String()
class RootQuery(ObjectType):
my_connection = Field(MyObjectConnection)
schema = Schema(query=RootQuery)
def test_connection():
assert MyObjectConnection._meta.name == 'MyObjectConnection'
fields = MyObjectConnection._meta.fields
assert fields.keys() == ['page_info', 'edges', 'extra']
edge_field = fields['edges']
pageinfo_field = fields['page_info']
assert isinstance(edge_field, Field)
assert isinstance(edge_field.type, List)
assert edge_field.type.of_type == MyObjectConnection.Edge
assert isinstance(pageinfo_field, Field)
assert isinstance(pageinfo_field.type, NonNull)
assert pageinfo_field.type.of_type == PageInfo
def test_edge():
Edge = MyObjectConnection.Edge
assert Edge._meta.name == 'MyObjectEdge'
edge_fields = Edge._meta.fields
assert edge_fields.keys() == ['node', 'cursor', 'other']
assert isinstance(edge_fields['node'], Field)
assert edge_fields['node'].type == MyObject
assert isinstance(edge_fields['other'], Field)
assert edge_fields['other'].type == String
def test_pageinfo():
assert PageInfo._meta.name == 'PageInfo'
fields = PageInfo._meta.fields
assert fields.keys() == ['has_next_page', 'has_previous_page', 'start_cursor', 'end_cursor']

View File

@ -44,3 +44,6 @@ class Field(OrderedType):
if inspect.isfunction(self._type):
return self._type()
return self._type
def get_resolver(self, parent_resolver):
return self.resolver or parent_resolver

View File

@ -22,6 +22,7 @@ class ObjectTypeMeta(AbstractTypeMeta):
name=name,
description=attrs.get('__doc__'),
interfaces=(),
fields=OrderedDict(),
)
attrs = merge_fields_in_attrs(bases, attrs)
@ -33,7 +34,7 @@ class ObjectTypeMeta(AbstractTypeMeta):
'All interfaces of {} must be a subclass of Interface. Received "{}".'
).format(name, interface)
options.interface_fields.update(interface._meta.fields)
options.fields = OrderedDict(options.interface_fields)
options.fields.update(options.interface_fields)
options.fields.update(options.local_fields)
cls = type.__new__(cls, name, bases, dict(attrs, _meta=options))

View File

@ -42,6 +42,7 @@ def test_objecttype():
class MyObjectType(ObjectType):
'''Description'''
foo = String(bar=String(description='Argument description', default_value='x'), description='Field description')
bar = String(name='gizmo')
def resolve_foo(self, args, info):
return args.get('bar')
@ -54,7 +55,7 @@ def test_objecttype():
assert graphql_type.description == 'Description'
fields = graphql_type.get_fields()
assert 'foo' in fields
assert fields.keys() == ['foo', 'gizmo']
foo_field = fields['foo']
assert isinstance(foo_field, GraphQLField)
assert foo_field.description == 'Field description'

View File

@ -210,12 +210,12 @@ class TypeMap(GraphQLTypeMap):
_field = GraphQLField(
field_type,
args=args,
resolver=field.resolver or cls.get_resolver_for_type(type, name),
resolver=field.get_resolver(cls.get_resolver_for_type(type, name)),
deprecation_reason=field.deprecation_reason,
description=field.description
)
processed_name = cls.process_field_name(name)
fields[processed_name] = _field
field_name = field.name or cls.process_field_name(name)
fields[field_name] = _field
return fields
@classmethod