Improved Connection abstractions. Removed default Node.Connection

This commit is contained in:
Syrus Akbary 2016-11-03 00:07:28 -07:00
parent bd07303971
commit 072b2f3dd0
8 changed files with 30 additions and 56 deletions

View File

@ -97,9 +97,6 @@ class Connection(six.with_metaclass(ConnectionMeta, ObjectType)):
@classmethod @classmethod
def connection_resolver(cls, resolved, args, context, info): def connection_resolver(cls, resolved, args, context, info):
if isinstance(resolved, cls):
return resolved
assert isinstance(resolved, Iterable), ( assert isinstance(resolved, Iterable), (
'Resolved value from the connection field have to be iterable or instance of {}. ' 'Resolved value from the connection field have to be iterable or instance of {}. '
'Received "{}"' 'Received "{}"'
@ -131,19 +128,18 @@ class ConnectionField(Field):
@property @property
def type(self): def type(self):
type = super(ConnectionField, self).type type = super(ConnectionField, self).type
if is_node(type): assert issubclass(type, Connection), (
connection_type = type.Connection
else:
connection_type = type
assert issubclass(connection_type, Connection), (
'{} type have to be a subclass of Connection. Received "{}".' '{} type have to be a subclass of Connection. Received "{}".'
).format(str(self), connection_type) ).format(str(self), type)
return connection_type return type
@classmethod @classmethod
def connection_resolver(cls, resolver, connection_type, root, args, context, info): def connection_resolver(cls, resolver, connection_type, root, args, context, info):
resolved = resolver(root, args, context, info) resolved = resolver(root, args, context, info)
if isinstance(resolved, connection_type):
return resolved
on_resolve = partial(connection_type.connection_resolver, args=args, context=context, info=info) on_resolve = partial(connection_type.connection_resolver, args=args, context=context, info=info)
if isinstance(resolved, Promise): if isinstance(resolved, Promise):
return resolved.then(on_resolve) return resolved.then(on_resolve)

View File

@ -21,18 +21,6 @@ def is_node(objecttype):
return False return False
def get_default_connection(cls):
from .connection import Connection
assert issubclass(cls, ObjectType), (
'Can only get connection type on implemented Nodes.'
)
class Meta:
node = cls
return type('{}Connection'.format(cls.__name__), (Connection,), {'Meta': Meta})
class GlobalID(Field): class GlobalID(Field):
def __init__(self, node, *args, **kwargs): def __init__(self, node, *args, **kwargs):
@ -101,11 +89,3 @@ class Node(six.with_metaclass(NodeMeta, Interface)):
@classmethod @classmethod
def to_global_id(cls, type, id): def to_global_id(cls, type, id):
return to_global_id(type, id) return to_global_id(type, id)
@classmethod
def implements(cls, objecttype):
get_connection = getattr(objecttype, 'get_connection', None)
if not get_connection:
get_connection = partial(get_default_connection, objecttype)
objecttype.Connection = get_connection()

View File

@ -97,7 +97,11 @@ def test_edge_with_bases():
def test_edge_on_node(): def test_edge_on_node():
Edge = MyObject.Connection.Edge class MyObjectConnection(Connection):
class Meta:
node = MyObject
Edge = MyObjectConnection.Edge
assert Edge._meta.name == 'MyObjectEdge' assert Edge._meta.name == 'MyObjectEdge'
edge_fields = Edge._meta.fields edge_fields = Edge._meta.fields
assert list(edge_fields.keys()) == ['node', 'cursor'] assert list(edge_fields.keys()) == ['node', 'cursor']

View File

@ -4,7 +4,7 @@ from graphql_relay.utils import base64
from promise import Promise from promise import Promise
from ...types import ObjectType, Schema, String from ...types import ObjectType, Schema, String
from ..connection import ConnectionField, PageInfo from ..connection import ConnectionField, Connection, PageInfo
from ..node import Node from ..node import Node
letter_chars = ['A', 'B', 'C', 'D', 'E'] letter_chars = ['A', 'B', 'C', 'D', 'E']
@ -18,10 +18,15 @@ class Letter(ObjectType):
letter = String() letter = String()
class LetterConnection(Connection):
class Meta:
node = Letter
class Query(ObjectType): class Query(ObjectType):
letters = ConnectionField(Letter) letters = LetterConnection.Field()
connection_letters = ConnectionField(Letter) connection_letters = LetterConnection.Field()
promise_letters = ConnectionField(Letter) promise_letters = LetterConnection.Field()
node = Node.Field() node = Node.Field()
@ -32,13 +37,13 @@ class Query(ObjectType):
return Promise.resolve(list(letters.values())) return Promise.resolve(list(letters.values()))
def resolve_connection_letters(self, args, context, info): def resolve_connection_letters(self, args, context, info):
return Letter.Connection( return LetterConnection(
page_info=PageInfo( page_info=PageInfo(
has_next_page=True, has_next_page=True,
has_previous_page=False has_previous_page=False
), ),
edges=[ edges=[
Letter.Connection.Edge( LetterConnection.Edge(
node=Letter(id=0, letter='A'), node=Letter(id=0, letter='A'),
cursor='a-cursor' cursor='a-cursor'
), ),

View File

@ -3,6 +3,7 @@ import pytest
from ...types import (AbstractType, Argument, Field, InputField, from ...types import (AbstractType, Argument, Field, InputField,
InputObjectType, NonNull, ObjectType, Schema) InputObjectType, NonNull, ObjectType, Schema)
from ...types.scalars import String from ...types.scalars import String
from ..connection import Connection
from ..mutation import ClientIDMutation from ..mutation import ClientIDMutation
from ..node import Node from ..node import Node
@ -19,6 +20,11 @@ class MyNode(ObjectType):
name = String() name = String()
class MyNodeConnection(Connection):
class Meta:
node = MyNode
class SaySomething(ClientIDMutation): class SaySomething(ClientIDMutation):
class Input: class Input:
@ -37,13 +43,13 @@ class OtherMutation(ClientIDMutation):
additional_field = String() additional_field = String()
name = String() name = String()
my_node_edge = Field(MyNode.Connection.Edge) my_node_edge = Field(MyNodeConnection.Edge)
@classmethod @classmethod
def mutate_and_get_payload(cls, args, context, info): def mutate_and_get_payload(cls, args, context, info):
shared = args.get('shared', '') shared = args.get('shared', '')
additionalField = args.get('additionalField', '') additionalField = args.get('additionalField', '')
edge_type = MyNode.Connection.Edge edge_type = MyNodeConnection.Edge
return OtherMutation(name=shared + additionalField, return OtherMutation(name=shared + additionalField,
my_node_edge=edge_type( my_node_edge=edge_type(
cursor='1', node=MyNode(name='name'))) cursor='1', node=MyNode(name='name')))

View File

@ -52,15 +52,6 @@ def test_node_good():
assert 'id' in MyNode._meta.fields assert 'id' in MyNode._meta.fields
def test_node_get_connection():
connection = MyNode.Connection
assert issubclass(connection, Connection)
def test_node_get_connection_dont_duplicate():
assert MyNode.Connection == MyNode.Connection
def test_node_query(): def test_node_query():
executed = schema.execute( executed = schema.execute(
'{ node(id:"%s") { ... on MyNode { name } } }' % to_global_id("MyNode", 1) '{ node(id:"%s") { ... on MyNode { name } } }' % to_global_id("MyNode", 1)

View File

@ -56,7 +56,3 @@ class Interface(six.with_metaclass(InterfaceMeta)):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
raise Exception("An Interface cannot be intitialized") raise Exception("An Interface cannot be intitialized")
@classmethod
def implements(cls, objecttype):
pass

View File

@ -45,10 +45,6 @@ class ObjectTypeMeta(AbstractTypeMeta):
) )
cls = type.__new__(cls, name, bases, dict(attrs, _meta=options)) cls = type.__new__(cls, name, bases, dict(attrs, _meta=options))
for interface in options.interfaces:
interface.implements(cls)
return cls return cls
def __str__(cls): # noqa: N802 def __str__(cls): # noqa: N802