This commit is contained in:
Syrus Akbary 2016-11-12 03:40:25 +00:00 committed by GitHub
commit b0426a33b4
9 changed files with 67 additions and 88 deletions

View File

@ -17,6 +17,11 @@ class Ship(graphene.ObjectType):
return get_ship(id) return get_ship(id)
class ShipConnection(graphene.Connection):
class Meta:
node = Ship
class Faction(graphene.ObjectType): class Faction(graphene.ObjectType):
'''A faction in the Star Wars saga''' '''A faction in the Star Wars saga'''
@ -24,7 +29,7 @@ class Faction(graphene.ObjectType):
interfaces = (relay.Node, ) interfaces = (relay.Node, )
name = graphene.String(description='The name of the faction.') name = graphene.String(description='The name of the faction.')
ships = relay.ConnectionField(Ship, description='The ships used by the faction.') ships = ShipConnection.Field(description='The ships used by the faction.')
@resolve_only_args @resolve_only_args
def resolve_ships(self, **args): def resolve_ships(self, **args):

View File

@ -5,7 +5,7 @@ from functools import partial
import six import six
from graphql_relay import connection_from_list from graphql_relay import connection_from_list
from promise import Promise from promise import Promise, is_thenable, promisify
from ..types import (AbstractType, Boolean, Enum, Int, Interface, List, NonNull, Scalar, String, from ..types import (AbstractType, Boolean, Enum, Int, Interface, List, NonNull, Scalar, String,
Union) Union)
@ -90,65 +90,61 @@ class ConnectionMeta(ObjectTypeMeta):
class Connection(six.with_metaclass(ConnectionMeta, ObjectType)): class Connection(six.with_metaclass(ConnectionMeta, ObjectType)):
pass
@classmethod
def Field(cls, *args, **kwargs): # noqa: N802
return ConnectionField(cls, *args, **kwargs)
@classmethod
def is_type_of(cls, root, context, info):
return isinstance(root, cls)
@classmethod
def connection_resolver(cls, resolved, args, context, info):
assert isinstance(resolved, Iterable), (
'Resolved value from the connection field have to be iterable or instance of {}. '
'Received "{}"'
).format(cls, resolved)
connection = connection_from_list(
resolved,
args,
connection_type=cls,
edge_type=cls.Edge,
pageinfo_type=PageInfo
)
connection.iterable = resolved
return connection
class IterableConnectionField(Field): class ConnectionField(Field):
def __init__(self, type, *args, **kwargs): def __init__(self, type, *args, **kwargs):
kwargs.setdefault('before', String()) kwargs.setdefault('before', String())
kwargs.setdefault('after', String()) kwargs.setdefault('after', String())
kwargs.setdefault('first', Int()) kwargs.setdefault('first', Int())
kwargs.setdefault('last', Int()) kwargs.setdefault('last', Int())
super(IterableConnectionField, self).__init__( assert issubclass(type, Connection), (
'{} type have to be a subclass of Connection. Received "{}".'
).format(str(self), type)
super(ConnectionField, self).__init__(
type, type,
*args, *args,
**kwargs **kwargs
) )
@property
def type(self):
type = super(IterableConnectionField, self).type
if is_node(type):
connection_type = type.Connection
else:
connection_type = type
assert issubclass(connection_type, Connection), (
'{} type have to be a subclass of Connection. Received "{}".'
).format(str(self), connection_type)
return connection_type
@classmethod
def resolve_connection(cls, connection_type, args, resolved):
if isinstance(resolved, connection_type):
return resolved
assert isinstance(resolved, Iterable), (
'Resolved value from the connection field have to be iterable or instance of {}. '
'Received "{}"'
).format(connection_type, resolved)
connection = connection_from_list(
resolved,
args,
connection_type=connection_type,
edge_type=connection_type.Edge,
pageinfo_type=PageInfo
)
connection.iterable = resolved
return connection
@classmethod @classmethod
def connection_resolver(cls, resolver, connection_type, root, args, context, info): def connection_resolver(cls, resolver, connection_type, root, args, context, info):
resolved = resolver(root, args, context, info) resolved = resolver(root, args, context, info)
on_resolve = partial(cls.resolve_connection, connection_type, args) if connection_type.is_type_of(resolved, context, info):
if isinstance(resolved, Promise): return resolved
return resolved.then(on_resolve)
on_resolve = partial(connection_type.connection_resolver, args=args, context=context, info=info)
if is_thenable(resolved):
return promisify(resolved).then(on_resolve)
return on_resolve(resolved) return on_resolve(resolved)
def get_resolver(self, parent_resolver): def get_resolver(self, parent_resolver):
resolver = super(IterableConnectionField, self).get_resolver(parent_resolver) resolver = super(ConnectionField, self).get_resolver(parent_resolver)
return partial(self.connection_resolver, resolver, self.type) return partial(self.connection_resolver, resolver, self.type)
ConnectionField = IterableConnectionField

View File

@ -21,18 +21,6 @@ def is_node(objecttype):
return False return False
def get_default_connection(cls):
from .connection import Connection
assert issubclass(cls, ObjectType), (
'Can only get connection type on implemented Nodes.'
)
class Meta:
node = cls
return type('{}Connection'.format(cls.__name__), (Connection,), {'Meta': Meta})
class GlobalID(Field): class GlobalID(Field):
def __init__(self, node=None, required=True, *args, **kwargs): def __init__(self, node=None, required=True, *args, **kwargs):
@ -101,11 +89,3 @@ class Node(six.with_metaclass(NodeMeta, Interface)):
@classmethod @classmethod
def to_global_id(cls, type, id): def to_global_id(cls, type, id):
return to_global_id(type, id) return to_global_id(type, id)
@classmethod
def implements(cls, objecttype):
get_connection = getattr(objecttype, 'get_connection', None)
if not get_connection:
get_connection = partial(get_default_connection, objecttype)
objecttype.Connection = get_connection()

View File

@ -97,7 +97,11 @@ def test_edge_with_bases():
def test_edge_on_node(): def test_edge_on_node():
Edge = MyObject.Connection.Edge class MyObjectConnection(Connection):
class Meta:
node = MyObject
Edge = MyObjectConnection.Edge
assert Edge._meta.name == 'MyObjectEdge' assert Edge._meta.name == 'MyObjectEdge'
edge_fields = Edge._meta.fields edge_fields = Edge._meta.fields
assert list(edge_fields.keys()) == ['node', 'cursor'] assert list(edge_fields.keys()) == ['node', 'cursor']

View File

@ -4,7 +4,7 @@ from graphql_relay.utils import base64
from promise import Promise from promise import Promise
from ...types import ObjectType, Schema, String from ...types import ObjectType, Schema, String
from ..connection import ConnectionField, PageInfo from ..connection import ConnectionField, Connection, PageInfo
from ..node import Node from ..node import Node
letter_chars = ['A', 'B', 'C', 'D', 'E'] letter_chars = ['A', 'B', 'C', 'D', 'E']
@ -18,10 +18,15 @@ class Letter(ObjectType):
letter = String() letter = String()
class LetterConnection(Connection):
class Meta:
node = Letter
class Query(ObjectType): class Query(ObjectType):
letters = ConnectionField(Letter) letters = LetterConnection.Field()
connection_letters = ConnectionField(Letter) connection_letters = LetterConnection.Field()
promise_letters = ConnectionField(Letter) promise_letters = LetterConnection.Field()
node = Node.Field() node = Node.Field()
@ -32,13 +37,13 @@ class Query(ObjectType):
return Promise.resolve(list(letters.values())) return Promise.resolve(list(letters.values()))
def resolve_connection_letters(self, args, context, info): def resolve_connection_letters(self, args, context, info):
return Letter.Connection( return LetterConnection(
page_info=PageInfo( page_info=PageInfo(
has_next_page=True, has_next_page=True,
has_previous_page=False has_previous_page=False
), ),
edges=[ edges=[
Letter.Connection.Edge( LetterConnection.Edge(
node=Letter(id=0, letter='A'), node=Letter(id=0, letter='A'),
cursor='a-cursor' cursor='a-cursor'
), ),

View File

@ -3,6 +3,7 @@ import pytest
from ...types import (AbstractType, Argument, Field, InputField, from ...types import (AbstractType, Argument, Field, InputField,
InputObjectType, NonNull, ObjectType, Schema) InputObjectType, NonNull, ObjectType, Schema)
from ...types.scalars import String from ...types.scalars import String
from ..connection import Connection
from ..mutation import ClientIDMutation from ..mutation import ClientIDMutation
from ..node import Node from ..node import Node
@ -19,6 +20,11 @@ class MyNode(ObjectType):
name = String() name = String()
class MyNodeConnection(Connection):
class Meta:
node = MyNode
class SaySomething(ClientIDMutation): class SaySomething(ClientIDMutation):
class Input: class Input:
@ -37,13 +43,13 @@ class OtherMutation(ClientIDMutation):
additional_field = String() additional_field = String()
name = String() name = String()
my_node_edge = Field(MyNode.Connection.Edge) my_node_edge = Field(MyNodeConnection.Edge)
@classmethod @classmethod
def mutate_and_get_payload(cls, args, context, info): def mutate_and_get_payload(cls, args, context, info):
shared = args.get('shared', '') shared = args.get('shared', '')
additionalField = args.get('additionalField', '') additionalField = args.get('additionalField', '')
edge_type = MyNode.Connection.Edge edge_type = MyNodeConnection.Edge
return OtherMutation(name=shared + additionalField, return OtherMutation(name=shared + additionalField,
my_node_edge=edge_type( my_node_edge=edge_type(
cursor='1', node=MyNode(name='name'))) cursor='1', node=MyNode(name='name')))

View File

@ -52,15 +52,6 @@ def test_node_good():
assert 'id' in MyNode._meta.fields assert 'id' in MyNode._meta.fields
def test_node_get_connection():
connection = MyNode.Connection
assert issubclass(connection, Connection)
def test_node_get_connection_dont_duplicate():
assert MyNode.Connection == MyNode.Connection
def test_node_query(): def test_node_query():
executed = schema.execute( executed = schema.execute(
'{ node(id:"%s") { ... on MyNode { name } } }' % to_global_id("MyNode", 1) '{ node(id:"%s") { ... on MyNode { name } } }' % to_global_id("MyNode", 1)

View File

@ -56,7 +56,3 @@ class Interface(six.with_metaclass(InterfaceMeta)):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
raise Exception("An Interface cannot be intitialized") raise Exception("An Interface cannot be intitialized")
@classmethod
def implements(cls, objecttype):
pass

View File

@ -45,10 +45,6 @@ class ObjectTypeMeta(AbstractTypeMeta):
) )
cls = type.__new__(cls, name, bases, dict(attrs, _meta=options)) cls = type.__new__(cls, name, bases, dict(attrs, _meta=options))
for interface in options.interfaces:
interface.implements(cls)
return cls return cls
def __str__(cls): # noqa: N802 def __str__(cls): # noqa: N802