mirror of
https://github.com/graphql-python/graphene.git
synced 2024-11-11 12:16:58 +03:00
Added relay.Connection and relay.ConnectionField
This commit is contained in:
parent
b19bca7f3b
commit
e2036da75f
|
@ -1,10 +1,10 @@
|
|||
from .node import Node
|
||||
# from .mutation import ClientIDMutation
|
||||
# from .connection import Connection, ConnectionField
|
||||
from .connection import Connection, ConnectionField
|
||||
|
||||
__all__ = [
|
||||
'Node',
|
||||
# 'ClientIDMutation',
|
||||
# 'Connection',
|
||||
# 'ConnectionField',
|
||||
'Connection',
|
||||
'ConnectionField',
|
||||
]
|
||||
|
|
155
graphene/relay/connection.py
Normal file
155
graphene/relay/connection.py
Normal file
|
@ -0,0 +1,155 @@
|
|||
import re
|
||||
from collections import Iterable, OrderedDict
|
||||
from functools import partial
|
||||
|
||||
import six
|
||||
|
||||
from graphql_relay import connection_from_list
|
||||
from graphql_relay.connection.connection import connection_args
|
||||
|
||||
from ..types import Boolean, String, List
|
||||
from ..types.field import Field
|
||||
from ..types.objecttype import ObjectType, ObjectTypeMeta
|
||||
from ..types.options import Options
|
||||
from ..utils.is_base_type import is_base_type
|
||||
from ..utils.props import props
|
||||
from .node import Node
|
||||
|
||||
from ..types.utils import get_fields_in_type, yank_fields_from_attrs, merge_fields_in_attrs
|
||||
|
||||
|
||||
def is_node(objecttype):
|
||||
for i in objecttype._meta.interfaces:
|
||||
if issubclass(i, Node):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class PageInfo(ObjectType):
|
||||
has_next_page = Boolean(
|
||||
required=True,
|
||||
name='hasNextPage',
|
||||
description='When paginating forwards, are there more items?',
|
||||
)
|
||||
|
||||
has_previous_page = Boolean(
|
||||
required=True,
|
||||
name='hasPreviousPage',
|
||||
description='When paginating backwards, are there more items?',
|
||||
)
|
||||
|
||||
start_cursor = String(
|
||||
name='startCursor',
|
||||
description='When paginating backwards, the cursor to continue.',
|
||||
)
|
||||
|
||||
end_cursor = String(
|
||||
name='endCursor',
|
||||
description='When paginating forwards, the cursor to continue.',
|
||||
)
|
||||
|
||||
|
||||
class ConnectionMeta(ObjectTypeMeta):
|
||||
|
||||
def __new__(cls, name, bases, attrs):
|
||||
# Also ensure initialization is only performed for subclasses of Model
|
||||
# (excluding Model class itself).
|
||||
if not is_base_type(bases, ConnectionMeta):
|
||||
return type.__new__(cls, name, bases, attrs)
|
||||
|
||||
options = Options(
|
||||
attrs.pop('Meta', None),
|
||||
name=None,
|
||||
description=None,
|
||||
node=None,
|
||||
)
|
||||
options.interfaces = ()
|
||||
|
||||
assert options.node, 'You have to provide a node in {}.Meta'.format(cls.__name__)
|
||||
assert issubclass(options.node, (Node, ObjectType)), (
|
||||
'Received incompatible node "{}" for Connection {}.'
|
||||
).format(options.node, name)
|
||||
|
||||
base_name = re.sub('Connection$', '', name)
|
||||
if not options.name:
|
||||
options.name = '{}Connection'.format(base_name)
|
||||
|
||||
|
||||
edge_class = attrs.pop('Edge', None)
|
||||
edge_fields = OrderedDict([
|
||||
('node', Field(options.node, description='The item at the end of the edge')),
|
||||
('cursor', Field(String, required=True, description='A cursor for use in pagination'))
|
||||
])
|
||||
edge_attrs = props(edge_class) if edge_class else OrderedDict()
|
||||
edge_fields.update(get_fields_in_type(ObjectType, edge_attrs))
|
||||
EdgeMeta = type('Meta', (object, ), {'fields': edge_fields})
|
||||
Edge = type('{}Edge'.format(base_name), (ObjectType,), {'Meta': EdgeMeta})
|
||||
|
||||
options.local_fields = OrderedDict([
|
||||
('page_info', Field(PageInfo, name='pageInfo', required=True)),
|
||||
('edges', Field(List(Edge)))
|
||||
])
|
||||
typed_fields = get_fields_in_type(ObjectType, attrs)
|
||||
options.local_fields.update(typed_fields)
|
||||
options.fields = options.local_fields
|
||||
yank_fields_from_attrs(attrs, typed_fields)
|
||||
|
||||
return type.__new__(cls, name, bases, dict(attrs, _meta=options, Edge=Edge))
|
||||
|
||||
|
||||
class Connection(six.with_metaclass(ConnectionMeta, ObjectType)):
|
||||
resolve_node = None
|
||||
resolve_cursor = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs['pageInfo'] = kwargs.pop('pageInfo', kwargs.pop('page_info'))
|
||||
super(Connection, self).__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class IterableConnectionField(Field):
|
||||
|
||||
def __init__(self, type, *args, **kwargs):
|
||||
# arguments = kwargs.pop('args', {})
|
||||
# if not arguments:
|
||||
# arguments = connection_args
|
||||
# else:
|
||||
# arguments = copy.copy(arguments)
|
||||
# arguments.update(connection_args)
|
||||
|
||||
super(IterableConnectionField, self).__init__(
|
||||
type,
|
||||
args=connection_args,
|
||||
*args,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@property
|
||||
def connection(self):
|
||||
if is_node(self.type):
|
||||
connection_type = self.type.Connection
|
||||
else:
|
||||
connection_type = self.type
|
||||
assert issubclass(connection_type, Connection), (
|
||||
'{} type have to be a subclass of Connection'
|
||||
).format(str(self))
|
||||
return connection_type
|
||||
|
||||
@staticmethod
|
||||
def connection_resolver(resolver, connection, root, args, context, info):
|
||||
iterable = resolver(root, args, context, info)
|
||||
assert isinstance(iterable, Iterable), (
|
||||
'Resolved value from the connection field have to be iterable. '
|
||||
'Received "{}"'
|
||||
).format(iterable)
|
||||
connection = connection_from_list(
|
||||
iterable,
|
||||
args,
|
||||
connection_type=connection,
|
||||
edge_type=connection.Edge,
|
||||
)
|
||||
return connection
|
||||
|
||||
def get_resolver(self, parent_resolver):
|
||||
return partial(self.connection_resolver, parent_resolver, self.connection)
|
||||
|
||||
ConnectionField = IterableConnectionField
|
66
graphene/relay/tests/test_connection.py
Normal file
66
graphene/relay/tests/test_connection.py
Normal file
|
@ -0,0 +1,66 @@
|
|||
|
||||
from ...types import ObjectType, Schema, List, Field, String, NonNull
|
||||
from ..connection import Connection, PageInfo
|
||||
from ..node import Node
|
||||
|
||||
|
||||
class MyObject(ObjectType):
|
||||
class Meta:
|
||||
interfaces = [Node]
|
||||
field = String()
|
||||
|
||||
@classmethod
|
||||
def get_node(cls, id):
|
||||
pass
|
||||
|
||||
|
||||
class MyObjectConnection(Connection):
|
||||
extra = String()
|
||||
|
||||
class Meta:
|
||||
node = MyObject
|
||||
|
||||
class Edge:
|
||||
other = String()
|
||||
|
||||
|
||||
class RootQuery(ObjectType):
|
||||
my_connection = Field(MyObjectConnection)
|
||||
|
||||
|
||||
schema = Schema(query=RootQuery)
|
||||
|
||||
|
||||
def test_connection():
|
||||
assert MyObjectConnection._meta.name == 'MyObjectConnection'
|
||||
fields = MyObjectConnection._meta.fields
|
||||
assert fields.keys() == ['page_info', 'edges', 'extra']
|
||||
edge_field = fields['edges']
|
||||
pageinfo_field = fields['page_info']
|
||||
|
||||
assert isinstance(edge_field, Field)
|
||||
assert isinstance(edge_field.type, List)
|
||||
assert edge_field.type.of_type == MyObjectConnection.Edge
|
||||
|
||||
assert isinstance(pageinfo_field, Field)
|
||||
assert isinstance(pageinfo_field.type, NonNull)
|
||||
assert pageinfo_field.type.of_type == PageInfo
|
||||
|
||||
|
||||
def test_edge():
|
||||
Edge = MyObjectConnection.Edge
|
||||
assert Edge._meta.name == 'MyObjectEdge'
|
||||
edge_fields = Edge._meta.fields
|
||||
assert edge_fields.keys() == ['node', 'cursor', 'other']
|
||||
|
||||
assert isinstance(edge_fields['node'], Field)
|
||||
assert edge_fields['node'].type == MyObject
|
||||
|
||||
assert isinstance(edge_fields['other'], Field)
|
||||
assert edge_fields['other'].type == String
|
||||
|
||||
|
||||
def test_pageinfo():
|
||||
assert PageInfo._meta.name == 'PageInfo'
|
||||
fields = PageInfo._meta.fields
|
||||
assert fields.keys() == ['has_next_page', 'has_previous_page', 'start_cursor', 'end_cursor']
|
|
@ -44,3 +44,6 @@ class Field(OrderedType):
|
|||
if inspect.isfunction(self._type):
|
||||
return self._type()
|
||||
return self._type
|
||||
|
||||
def get_resolver(self, parent_resolver):
|
||||
return self.resolver or parent_resolver
|
||||
|
|
|
@ -22,6 +22,7 @@ class ObjectTypeMeta(AbstractTypeMeta):
|
|||
name=name,
|
||||
description=attrs.get('__doc__'),
|
||||
interfaces=(),
|
||||
fields=OrderedDict(),
|
||||
)
|
||||
|
||||
attrs = merge_fields_in_attrs(bases, attrs)
|
||||
|
@ -33,7 +34,7 @@ class ObjectTypeMeta(AbstractTypeMeta):
|
|||
'All interfaces of {} must be a subclass of Interface. Received "{}".'
|
||||
).format(name, interface)
|
||||
options.interface_fields.update(interface._meta.fields)
|
||||
options.fields = OrderedDict(options.interface_fields)
|
||||
options.fields.update(options.interface_fields)
|
||||
options.fields.update(options.local_fields)
|
||||
|
||||
cls = type.__new__(cls, name, bases, dict(attrs, _meta=options))
|
||||
|
|
|
@ -42,6 +42,7 @@ def test_objecttype():
|
|||
class MyObjectType(ObjectType):
|
||||
'''Description'''
|
||||
foo = String(bar=String(description='Argument description', default_value='x'), description='Field description')
|
||||
bar = String(name='gizmo')
|
||||
|
||||
def resolve_foo(self, args, info):
|
||||
return args.get('bar')
|
||||
|
@ -54,7 +55,7 @@ def test_objecttype():
|
|||
assert graphql_type.description == 'Description'
|
||||
|
||||
fields = graphql_type.get_fields()
|
||||
assert 'foo' in fields
|
||||
assert fields.keys() == ['foo', 'gizmo']
|
||||
foo_field = fields['foo']
|
||||
assert isinstance(foo_field, GraphQLField)
|
||||
assert foo_field.description == 'Field description'
|
||||
|
|
|
@ -210,12 +210,12 @@ class TypeMap(GraphQLTypeMap):
|
|||
_field = GraphQLField(
|
||||
field_type,
|
||||
args=args,
|
||||
resolver=field.resolver or cls.get_resolver_for_type(type, name),
|
||||
resolver=field.get_resolver(cls.get_resolver_for_type(type, name)),
|
||||
deprecation_reason=field.deprecation_reason,
|
||||
description=field.description
|
||||
)
|
||||
processed_name = cls.process_field_name(name)
|
||||
fields[processed_name] = _field
|
||||
field_name = field.name or cls.process_field_name(name)
|
||||
fields[field_name] = _field
|
||||
return fields
|
||||
|
||||
@classmethod
|
||||
|
|
Loading…
Reference in New Issue
Block a user