Improved Relay implementation

This commit is contained in:
Syrus Akbary 2016-08-13 17:36:11 -07:00
parent fd16de8748
commit 0ffdd8d9ab
8 changed files with 56 additions and 393 deletions

View File

@ -1,10 +1,10 @@
from .node import Node
from .mutation import ClientIDMutation
from .connection import Connection, ConnectionField
# from .mutation import ClientIDMutation
# from .connection import Connection, ConnectionField
__all__ = [
'Node',
'ClientIDMutation',
'Connection',
'ConnectionField',
# 'ClientIDMutation',
# 'Connection',
# 'ConnectionField',
]

View File

@ -1,131 +0,0 @@
import copy
import re
from collections import Iterable
from functools import partial
import six
from graphql_relay import connection_definitions, connection_from_list
from graphql_relay.connection.connection import connection_args
from ..types.field import Field
from ..types.objecttype import ObjectType, ObjectTypeMeta
from ..types.options import Options
from ..utils.copy_fields import copy_fields
from ..utils.get_fields import get_fields
from ..utils.is_base_type import is_base_type
from ..utils.props import props
class ConnectionMeta(ObjectTypeMeta):
def __new__(cls, name, bases, attrs):
super_new = type.__new__
# Also ensure initialization is only performed for subclasses of Model
# (excluding Model class itself).
if not is_base_type(bases, ConnectionMeta):
return super_new(cls, name, bases, attrs)
options = Options(
attrs.pop('Meta', None),
name=None,
description=None,
node=None,
interfaces=[],
)
edge_class = attrs.pop('Edge', None)
edge_fields = props(edge_class) if edge_class else {}
edge_fields = get_fields(ObjectType, edge_fields, ())
connection_fields = copy_fields(Field, get_fields(ObjectType, attrs, bases))
cls = super_new(cls, name, bases, dict(attrs, _meta=options))
assert options.node, 'You have to provide a node in {}.Meta'.format(cls.__name__)
from ..utils.get_graphql_type import get_graphql_type
edge, connection = connection_definitions(
name=options.name or re.sub('Connection$', '', cls.__name__),
node_type=get_graphql_type(options.node),
resolve_node=cls.resolve_node,
resolve_cursor=cls.resolve_cursor,
edge_fields=edge_fields,
connection_fields=connection_fields,
)
cls.Edge = type(edge.name, (ObjectType, ), {'Meta': type('Meta', (object,), {'graphql_type': edge})})
cls._meta.graphql_type = connection
fields = copy_fields(Field, options.graphql_type.get_fields(), parent=cls)
cls._meta.get_fields = lambda: fields
return cls
class Connection(six.with_metaclass(ConnectionMeta, ObjectType)):
resolve_node = None
resolve_cursor = None
def __init__(self, *args, **kwargs):
kwargs['pageInfo'] = kwargs.pop('pageInfo', kwargs.pop('page_info'))
super(Connection, self).__init__(*args, **kwargs)
class IterableConnectionField(Field):
def __init__(self, type, *other_args, **kwargs):
args = kwargs.pop('args', {})
if not args:
args = connection_args
else:
args = copy.copy(args)
args.update(connection_args)
super(IterableConnectionField, self).__init__(type, args=args, *other_args, **kwargs)
@property
def type(self):
from ..utils.get_graphql_type import get_graphql_type
return get_graphql_type(self.connection)
@type.setter
def type(self, value):
self._type = value
@property
def connection(self):
from .node import Node
if Node in self._type._meta.interfaces:
connection_type = self._type.Connection
else:
connection_type = self._type
assert issubclass(connection_type, Connection), '{} type have to be a subclass of Connection'.format(str(self))
return connection_type
@staticmethod
def connection_resolver(resolver, connection, root, args, context, info):
iterable = resolver(root, args, context, info)
# if isinstance(resolved, self.type.graphene)
assert isinstance(
iterable, Iterable), 'Resolved value from the connection field have to be iterable. Received "{}"'.format(
iterable)
connection = connection_from_list(
iterable,
args,
connection_type=connection,
edge_type=connection.Edge,
)
return connection
@property
def resolver(self):
resolver = super(ConnectionField, self).resolver
connection = self.connection
return partial(self.connection_resolver, resolver, connection)
@resolver.setter
def resolver(self, resolver):
self._resolver = resolver
ConnectionField = IterableConnectionField

View File

@ -1,62 +0,0 @@
from functools import partial
import six
from graphql_relay import mutation_with_client_mutation_id
from ..types.field import Field, InputField
from ..types.inputobjecttype import InputObjectType
from ..types.mutation import Mutation, MutationMeta
from ..types.objecttype import ObjectType
from ..types.options import Options
from ..utils.copy_fields import copy_fields
from ..utils.get_fields import get_fields
from ..utils.is_base_type import is_base_type
from ..utils.props import props
class ClientIDMutationMeta(MutationMeta):
def __new__(cls, name, bases, attrs):
super_new = type.__new__
# Also ensure initialization is only performed for subclasses of Model
# (excluding Model class itself).
if not is_base_type(bases, ClientIDMutationMeta):
return super_new(cls, name, bases, attrs)
options = Options(
attrs.pop('Meta', None),
name=None,
description=None,
)
input_class = attrs.pop('Input', None)
cls = super_new(cls, name, bases, dict(attrs, _meta=options))
input_fields = props(input_class) if input_class else {}
input_local_fields = copy_fields(InputField, get_fields(InputObjectType, input_fields, ()))
output_fields = copy_fields(Field, get_fields(ObjectType, attrs, bases))
mutate_and_get_payload = getattr(cls, 'mutate_and_get_payload', None)
assert mutate_and_get_payload, (
"{}.mutate_and_get_payload method is required"
" in a ClientIDMutation ObjectType.".format(cls.__name__)
)
field = mutation_with_client_mutation_id(
name=options.name or cls.__name__,
input_fields=input_local_fields,
output_fields=output_fields,
mutate_and_get_payload=cls.mutate_and_get_payload,
)
options.graphql_type = field.type
options.get_fields = lambda: output_fields
cls.Field = partial(Field.copy_and_extend, field, type=field.type, _creation_counter=None)
return cls
class ClientIDMutation(six.with_metaclass(ClientIDMutationMeta, Mutation)):
pass

View File

@ -1,69 +1,39 @@
from functools import partial
import six
from graphql_relay import from_global_id, node_definitions, to_global_id
from ..types.field import Field
from ..types.interface import Interface
from ..types.objecttype import ObjectType, ObjectTypeMeta
from ..types.options import Options
from ..utils.copy_fields import copy_fields
from .connection import Connection
from graphql_relay import from_global_id, to_global_id
from ..types import Interface, ID, Field
def get_default_connection(cls):
assert issubclass(cls, ObjectType), 'Can only get connection type on implemented Nodes.'
class Node(Interface):
'''An object with an ID'''
class Meta:
node = cls
id = ID(required=True, description='The ID of the object.')
return type('{}Connection'.format(cls.__name__), (Connection,), {'Meta': Meta})
@classmethod
def resolve_id(cls, root, args, context, info):
type_name = root._meta.name # info.parent_type.name
return cls.to_global_id(type_name, getattr(root, 'id', None))
@classmethod
def Field(cls):
def resolve_node(root, args, context, info):
return cls.get_node_from_global_id(args.get('id'), context, info)
# We inherit from ObjectTypeMeta as we want to allow
# inheriting from Node, and also ObjectType.
# Like class MyNode(Node): pass
# And class MyNodeImplementation(Node, ObjectType): pass
class NodeMeta(ObjectTypeMeta):
@staticmethod
def _get_interface_options(meta):
return Options(
meta,
return Field(
cls,
description='The ID of the object',
id=ID(required=True),
resolver=resolve_node
)
@staticmethod
def _create_interface(cls, name, bases, attrs):
options = cls._get_interface_options(attrs.pop('Meta', None))
cls = type.__new__(cls, name, bases, dict(attrs, _meta=options))
get_node_from_global_id = getattr(cls, 'get_node_from_global_id', None)
id_resolver = getattr(cls, 'id_resolver', None)
assert get_node_from_global_id, '{}.get_node_from_global_id method is required by the Node interface.'.format(
cls.__name__)
node_interface, node_field = node_definitions(
get_node_from_global_id,
id_resolver=id_resolver,
type_resolver=cls.resolve_type,
)
options.graphql_type = node_interface
fields = copy_fields(Field, options.graphql_type.get_fields(), parent=cls)
options.get_fields = lambda: fields
cls.Field = partial(
Field.copy_and_extend,
node_field,
type=node_field.type,
parent=cls,
_creation_counter=None)
return cls
class Node(six.with_metaclass(NodeMeta, Interface)):
_connection = None
resolve_type = None
@classmethod
def get_node_from_global_id(cls, global_id, context, info):
try:
_type, _id = cls.from_global_id(global_id)
graphene_type = info.schema.get_type(_type).graphene_type
# We make sure the ObjectType implements the "Node" interface
assert cls in graphene_type._meta.interfaces
except:
return None
return graphene_type.get_node(_id, context, info)
@classmethod
def from_global_id(cls, global_id):
@ -73,33 +43,14 @@ class Node(six.with_metaclass(NodeMeta, Interface)):
def to_global_id(cls, type, id):
return to_global_id(type, id)
@classmethod
def resolve_id(cls, root, args, context, info):
return getattr(root, 'id', None)
@classmethod
def id_resolver(cls, root, args, context, info):
return cls.to_global_id(info.parent_type.name, cls.resolve_id(root, args, context, info))
@classmethod
def get_node_from_global_id(cls, global_id, context, info):
try:
_type, _id = cls.from_global_id(global_id)
except:
return None
graphql_type = info.schema.get_type(_type)
if cls._meta.graphql_type not in graphql_type.get_interfaces():
return
return graphql_type.graphene_type.get_node(_id, context, info)
@classmethod
def implements(cls, objecttype):
require_get_node = Node._meta.graphql_type in objecttype._meta.get_interfaces
get_connection = getattr(objecttype, 'get_connection', None)
if not get_connection:
get_connection = partial(get_default_connection, objecttype)
require_get_node = Node in objecttype._meta.interfaces
# get_connection = getattr(objecttype, 'get_connection', None)
# if not get_connection:
# get_connection = partial(get_default_connection, objecttype)
objecttype.Connection = get_connection()
# objecttype.Connection = get_connection()
if require_get_node:
assert hasattr(
objecttype, 'get_node'), '{}.get_node method is required by the Node interface.'.format(

View File

@ -1,43 +0,0 @@
from ...types import ObjectType, Schema
from ...types.field import Field
from ...types.scalars import String
from ..connection import Connection
from ..node import Node
class MyObject(ObjectType):
class Meta:
interfaces = [Node]
field = String()
@classmethod
def get_node(cls, id):
pass
class MyObjectConnection(Connection):
class Meta:
node = MyObject
class Edge:
other = String()
class RootQuery(ObjectType):
my_connection = Field(MyObjectConnection)
schema = Schema(query=RootQuery)
def test_node_good():
graphql_type = MyObjectConnection._meta.graphql_type
fields = graphql_type.get_fields()
assert 'edges' in fields
assert 'pageInfo' in fields
edge_fields = fields['edges'].type.of_type.get_fields()
assert 'node' in edge_fields
assert edge_fields['node'].type == MyObject._meta.graphql_type
assert 'other' in edge_fields

View File

@ -1,55 +0,0 @@
import pytest
from ...types import ObjectType, Schema
from ...types.scalars import String
from ..mutation import ClientIDMutation
class SaySomething(ClientIDMutation):
class Input:
what = String()
phrase = String()
@staticmethod
def mutate_and_get_payload(args, context, info):
what = args.get('what')
return SaySomething(phrase=str(what))
class RootQuery(ObjectType):
something = String()
class Mutation(ObjectType):
say = SaySomething.Field()
schema = Schema(query=RootQuery, mutation=Mutation)
def test_no_mutate_and_get_payload():
with pytest.raises(AssertionError) as excinfo:
class MyMutation(ClientIDMutation):
pass
assert "MyMutation.mutate_and_get_payload method is required in a ClientIDMutation ObjectType." == str(
excinfo.value)
def test_node_good():
graphql_type = SaySomething._meta.graphql_type
fields = graphql_type.get_fields()
assert 'phrase' in fields
graphql_field = SaySomething.Field()
assert graphql_field.type == SaySomething._meta.graphql_type
assert 'input' in graphql_field.args
input = graphql_field.args['input']
assert 'clientMutationId' in input.type.of_type.get_fields()
def test_node_query():
executed = schema.execute(
'mutation a { say(input: {what:"hello", clientMutationId:"1"}) { phrase } }'
)
assert not executed.errors
assert executed.data == {'say': {'phrase': 'hello'}}

View File

@ -2,15 +2,15 @@ import pytest
from graphql_relay import to_global_id
from ...types import ObjectType, Schema
from ...types.scalars import String
from ..connection import Connection
from ...types import ObjectType, Schema, String
# from ...types.scalars import String
# from ..connection import Connection
from ..node import Node
class MyNode(ObjectType):
class Meta:
interfaces = [Node]
interfaces = (Node, )
name = String()
@staticmethod
@ -27,32 +27,33 @@ schema = Schema(query=RootQuery, types=[MyNode])
def test_node_no_get_node():
with pytest.raises(AssertionError) as excinfo:
class MyNode(Node, ObjectType):
pass
class MyNode(ObjectType):
class Meta:
interfaces = (Node, )
assert "MyNode.get_node method is required by the Node interface." == str(excinfo.value)
def test_node_no_get_node_with_meta():
with pytest.raises(AssertionError) as excinfo:
class MyNode(Node, ObjectType):
pass
class MyNode(ObjectType):
class Meta:
interfaces = (Node, )
assert "MyNode.get_node method is required by the Node interface." == str(excinfo.value)
def test_node_good():
graphql_type = MyNode._meta.graphql_type
assert 'id' in graphql_type.get_fields()
assert 'id' in MyNode._meta.fields
def test_node_get_connection():
connection = MyNode.Connection
assert issubclass(connection, Connection)
# def test_node_get_connection():
# connection = MyNode.Connection
# assert issubclass(connection, Connection)
def test_node_get_connection_dont_duplicate():
assert MyNode.Connection == MyNode.Connection
# def test_node_get_connection_dont_duplicate():
# assert MyNode.Connection == MyNode.Connection
def test_node_query():

View File

@ -6,6 +6,8 @@ from ..node import Node
class CustomNode(Node):
class Meta:
name = 'Node'
@staticmethod
def to_global_id(type, id):