Merge pull request #237 from Globegitter/connection-fields-object-types

Allow ConnectionFields to have ObjectTypes as per relay spec
This commit is contained in:
Syrus Akbary 2016-07-29 19:24:46 -07:00 committed by GitHub
commit bd1fbb6a33
7 changed files with 195 additions and 115 deletions

View File

@ -4,7 +4,8 @@ import six
from django.db import models from django.db import models
from ...core.classtypes.objecttype import ObjectType, ObjectTypeMeta from ...core.classtypes.objecttype import ObjectType, ObjectTypeMeta
from ...relay.types import Connection, Node, NodeMeta from ...relay.types import Node, NodeMeta
from ...relay.connection import Connection
from .converter import convert_django_field_with_choices from .converter import convert_django_field_with_choices
from .options import DjangoOptions from .options import DjangoOptions
from .utils import get_reverse_fields from .utils import get_reverse_fields

View File

@ -5,7 +5,8 @@ from sqlalchemy.inspection import inspect as sqlalchemyinspect
from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.orm.exc import NoResultFound
from ...core.classtypes.objecttype import ObjectType, ObjectTypeMeta from ...core.classtypes.objecttype import ObjectType, ObjectTypeMeta
from ...relay.types import Connection, Node, NodeMeta from ...relay.types import Node, NodeMeta
from ...relay.connection import Connection
from .converter import (convert_sqlalchemy_column, from .converter import (convert_sqlalchemy_column,
convert_sqlalchemy_relationship) convert_sqlalchemy_relationship)
from .options import SQLAlchemyOptions from .options import SQLAlchemyOptions

View File

@ -6,12 +6,15 @@ from .fields import (
from .types import ( from .types import (
Node, Node,
PageInfo,
Edge,
Connection,
ClientIDMutation ClientIDMutation
) )
from .connection import (
PageInfo,
Connection,
Edge,
)
from .utils import is_node from .utils import is_node
__all__ = ['ConnectionField', 'NodeField', 'GlobalIDField', 'Node', __all__ = ['ConnectionField', 'NodeField', 'GlobalIDField', 'Node',

View File

@ -0,0 +1,87 @@
from collections import Iterable
from graphql_relay.connection.arrayconnection import connection_from_list
from ..core.classtypes import ObjectType
from ..core.types import Field, Boolean, String, List
from ..utils import memoize
class PageInfo(ObjectType):
def __init__(self, start_cursor="", end_cursor="",
has_previous_page=False, has_next_page=False, **kwargs):
super(PageInfo, self).__init__(**kwargs)
self.startCursor = start_cursor
self.endCursor = end_cursor
self.hasPreviousPage = has_previous_page
self.hasNextPage = has_next_page
hasNextPage = Boolean(
required=True,
description='When paginating forwards, are there more items?')
hasPreviousPage = Boolean(
required=True,
description='When paginating backwards, are there more items?')
startCursor = String(
description='When paginating backwards, the cursor to continue.')
endCursor = String(
description='When paginating forwards, the cursor to continue.')
class Edge(ObjectType):
'''An edge in a connection.'''
cursor = String(
required=True, description='A cursor for use in pagination')
@classmethod
@memoize
def for_node(cls, node):
node_field = Field(node, description='The item at the end of the edge')
return type(
'%s%s' % (node._meta.type_name, cls._meta.type_name),
(cls,),
{'node_type': node, 'node': node_field})
class Connection(ObjectType):
'''A connection to a list of items.'''
def __init__(self, edges, page_info, **kwargs):
super(Connection, self).__init__(**kwargs)
self.edges = edges
self.pageInfo = page_info
class Meta:
type_name = 'DefaultConnection'
pageInfo = Field(PageInfo, required=True,
description='The Information to aid in pagination')
_connection_data = None
@classmethod
@memoize
def for_node(cls, node, edge_type=None):
edge_type = edge_type or Edge.for_node(node)
edges = List(edge_type, description='Information to aid in pagination.')
return type(
'%s%s' % (node._meta.type_name, cls._meta.type_name),
(cls,),
{'edge_type': edge_type, 'edges': edges})
@classmethod
def from_list(cls, iterable, args, context, info):
assert isinstance(
iterable, Iterable), 'Resolved value from the connection field have to be iterable'
connection = connection_from_list(
iterable, args, connection_type=cls,
edge_type=cls.edge_type, pageinfo_type=PageInfo)
connection.set_connection_data(iterable)
return connection
def set_connection_data(self, data):
self._connection_data = data
def get_connection_data(self):
return self._connection_data

View File

@ -6,6 +6,7 @@ from ..core.fields import Field
from ..core.types.definitions import NonNull from ..core.types.definitions import NonNull
from ..core.types.scalars import ID, Int, String from ..core.types.scalars import ID, Int, String
from ..utils.wrap_resolver_function import has_context, with_context from ..utils.wrap_resolver_function import has_context, with_context
from .connection import Connection, Edge
class ConnectionField(Field): class ConnectionField(Field):
@ -23,8 +24,8 @@ class ConnectionField(Field):
last=Int(), last=Int(),
description=description, description=description,
**kwargs) **kwargs)
self.connection_type = connection_type self.connection_type = connection_type or Connection
self.edge_type = edge_type self.edge_type = edge_type or Edge
@with_context @with_context
def resolver(self, instance, args, context, info): def resolver(self, instance, args, context, info):
@ -37,7 +38,7 @@ class ConnectionField(Field):
else: else:
resolved = super(ConnectionField, self).resolver(instance, args, info) resolved = super(ConnectionField, self).resolver(instance, args, info)
if isinstance(resolved, connection_type): if isinstance(resolved, self.connection_type):
return resolved return resolved
return self.from_list(connection_type, resolved, args, context, info) return self.from_list(connection_type, resolved, args, context, info)
@ -45,19 +46,14 @@ class ConnectionField(Field):
return connection_type.from_list(resolved, args, context, info) return connection_type.from_list(resolved, args, context, info)
def get_connection_type(self, node): def get_connection_type(self, node):
connection_type = self.connection_type or node.get_connection_type() return self.connection_type.for_node(node)
edge_type = self.get_edge_type(node)
return connection_type.for_node(node, edge_type=edge_type)
def get_edge_type(self, node): def get_edge_type(self, node):
edge_type = self.edge_type or node.get_edge_type() return self.edge_type.for_node(node)
return edge_type.for_node(node)
def get_type(self, schema): def get_type(self, schema):
from graphene.relay.utils import is_node
type = schema.T(self.type) type = schema.T(self.type)
node = schema.objecttype(type) node = schema.objecttype(type)
assert is_node(node), 'Only nodes have connections.'
schema.register(node) schema.register(node)
connection_type = self.get_connection_type(node) connection_type = self.get_connection_type(node)

View File

@ -19,6 +19,9 @@ class MyNode(relay.Node):
def get_node(cls, id, info): def get_node(cls, id, info):
return MyNode(id=id, name='mo') return MyNode(id=id, name='mo')
class MyObject(graphene.ObjectType):
name = graphene.String()
class SpecialNode(relay.Node): class SpecialNode(relay.Node):
value = graphene.String() value = graphene.String()
@ -29,6 +32,9 @@ class SpecialNode(relay.Node):
value = "!!!" if context.get('is_special') else "???" value = "!!!" if context.get('is_special') else "???"
return SpecialNode(id=id, value=value) return SpecialNode(id=id, value=value)
def _create_my_node_edge(myNode):
edge_type = relay.Edge.for_node(MyNode)
return edge_type(node=myNode, cursor=str(myNode.id))
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
my_node = relay.NodeField(MyNode) my_node = relay.NodeField(MyNode)
@ -40,6 +46,12 @@ class Query(graphene.ObjectType):
context_nodes = relay.ConnectionField( context_nodes = relay.ConnectionField(
MyNode, connection_type=MyConnection, customArg=graphene.String()) MyNode, connection_type=MyConnection, customArg=graphene.String())
connection_type_nodes = relay.ConnectionField(
MyNode, connection_type=MyConnection)
all_my_objects = relay.ConnectionField(
MyObject, connection_type=MyConnection)
def resolve_all_my_nodes(self, args, info): def resolve_all_my_nodes(self, args, info):
custom_arg = args.get('customArg') custom_arg = args.get('customArg')
assert custom_arg == "1" assert custom_arg == "1"
@ -51,6 +63,16 @@ class Query(graphene.ObjectType):
assert custom_arg == "1" assert custom_arg == "1"
return [MyNode(name='my')] return [MyNode(name='my')]
def resolve_connection_type_nodes(self, args, info):
edges = [_create_my_node_edge(n) for n in [MyNode(id='1', name='my')]]
connection_type = MyConnection.for_node(MyNode)
return connection_type(
edges=edges, page_info=relay.PageInfo(has_next_page=True))
def resolve_all_my_objects(self, args, info):
return [MyObject(name='my_object')]
schema.query = Query schema.query = Query
@ -135,6 +157,74 @@ def test_connectionfield_context_query():
assert result.data == expected assert result.data == expected
def test_connectionfield_resolve_returns_connection_type_directly():
query = '''
query RebelsShipsQuery {
connectionTypeNodes {
edges {
node {
name
}
},
myCustomField
pageInfo {
hasNextPage
}
}
}
'''
expected = {
'connectionTypeNodes': {
'edges': [{
'node': {
'name': 'my'
}
}],
'myCustomField': 'Custom',
'pageInfo': {
'hasNextPage': True,
}
}
}
result = schema.execute(query)
assert not result.errors
assert result.data == expected
def test_connectionfield_resolve_returning_objects():
query = '''
query RebelsShipsQuery {
allMyObjects {
edges {
node {
name
}
},
myCustomField
pageInfo {
hasNextPage
}
}
}
'''
expected = {
'allMyObjects': {
'edges': [{
'node': {
'name': 'my_object'
}
}],
'myCustomField': 'Custom',
'pageInfo': {
'hasNextPage': False,
}
}
}
result = schema.execute(query)
assert not result.errors
assert result.data == expected
@pytest.mark.parametrize('specialness,value', [(True, '!!!'), (False, '???')]) @pytest.mark.parametrize('specialness,value', [(True, '!!!'), (False, '???')])
def test_get_node_info(specialness, value): def test_get_node_info(specialness, value):
query = ''' query = '''

View File

@ -1,108 +1,21 @@
import inspect import inspect
import warnings import warnings
from collections import Iterable
from functools import wraps from functools import wraps
import six import six
from graphql_relay.connection.arrayconnection import connection_from_list
from graphql_relay.node.node import to_global_id from graphql_relay.node.node import to_global_id
from ..core.classtypes import InputObjectType, Interface, Mutation, ObjectType from ..core.classtypes import InputObjectType, Interface, Mutation
from ..core.classtypes.interface import InterfaceMeta from ..core.classtypes.interface import InterfaceMeta
from ..core.classtypes.mutation import MutationMeta from ..core.classtypes.mutation import MutationMeta
from ..core.types import Boolean, Field, List, String from ..core.types import String
from ..core.types.argument import ArgumentsGroup from ..core.types.argument import ArgumentsGroup
from ..core.types.definitions import NonNull from ..core.types.definitions import NonNull
from ..utils import memoize
from ..utils.wrap_resolver_function import has_context, with_context from ..utils.wrap_resolver_function import has_context, with_context
from .fields import GlobalIDField from .fields import GlobalIDField
class PageInfo(ObjectType):
def __init__(self, start_cursor="", end_cursor="",
has_previous_page=False, has_next_page=False, **kwargs):
super(PageInfo, self).__init__(**kwargs)
self.startCursor = start_cursor
self.endCursor = end_cursor
self.hasPreviousPage = has_previous_page
self.hasNextPage = has_next_page
hasNextPage = Boolean(
required=True,
description='When paginating forwards, are there more items?')
hasPreviousPage = Boolean(
required=True,
description='When paginating backwards, are there more items?')
startCursor = String(
description='When paginating backwards, the cursor to continue.')
endCursor = String(
description='When paginating forwards, the cursor to continue.')
class Edge(ObjectType):
'''An edge in a connection.'''
cursor = String(
required=True, description='A cursor for use in pagination')
@classmethod
@memoize
def for_node(cls, node):
from graphene.relay.utils import is_node
assert is_node(node), 'ObjectTypes in a edge have to be Nodes'
node_field = Field(node, description='The item at the end of the edge')
return type(
'%s%s' % (node._meta.type_name, cls._meta.type_name),
(cls,),
{'node_type': node, 'node': node_field})
class Connection(ObjectType):
'''A connection to a list of items.'''
def __init__(self, edges, page_info, **kwargs):
super(Connection, self).__init__(**kwargs)
self.edges = edges
self.pageInfo = page_info
class Meta:
type_name = 'DefaultConnection'
pageInfo = Field(PageInfo, required=True,
description='The Information to aid in pagination')
_connection_data = None
@classmethod
@memoize
def for_node(cls, node, edge_type=None):
from graphene.relay.utils import is_node
edge_type = edge_type or Edge.for_node(node)
assert is_node(node), 'ObjectTypes in a connection have to be Nodes'
edges = List(edge_type, description='Information to aid in pagination.')
return type(
'%s%s' % (node._meta.type_name, cls._meta.type_name),
(cls,),
{'edge_type': edge_type, 'edges': edges})
@classmethod
def from_list(cls, iterable, args, context, info):
assert isinstance(
iterable, Iterable), 'Resolved value from the connection field have to be iterable'
connection = connection_from_list(
iterable, args, connection_type=cls,
edge_type=cls.edge_type, pageinfo_type=PageInfo)
connection.set_connection_data(iterable)
return connection
def set_connection_data(self, data):
self._connection_data = data
def get_connection_data(self):
return self._connection_data
class NodeMeta(InterfaceMeta): class NodeMeta(InterfaceMeta):
def construct_get_node(cls): def construct_get_node(cls):
@ -153,17 +66,6 @@ class Node(six.with_metaclass(NodeMeta, Interface)):
def to_global_id(self): def to_global_id(self):
return self.global_id(self.id) return self.global_id(self.id)
connection_type = Connection
edge_type = Edge
@classmethod
def get_connection_type(cls):
return cls.connection_type
@classmethod
def get_edge_type(cls):
return cls.edge_type
class MutationInputType(InputObjectType): class MutationInputType(InputObjectType):
clientMutationId = String(required=True) clientMutationId = String(required=True)