mirror of
https://github.com/graphql-python/graphene.git
synced 2025-02-02 20:54:16 +03:00
Improved Relay implementation
This commit is contained in:
parent
fd16de8748
commit
0ffdd8d9ab
|
@ -1,10 +1,10 @@
|
||||||
from .node import Node
|
from .node import Node
|
||||||
from .mutation import ClientIDMutation
|
# from .mutation import ClientIDMutation
|
||||||
from .connection import Connection, ConnectionField
|
# from .connection import Connection, ConnectionField
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'Node',
|
'Node',
|
||||||
'ClientIDMutation',
|
# 'ClientIDMutation',
|
||||||
'Connection',
|
# 'Connection',
|
||||||
'ConnectionField',
|
# 'ConnectionField',
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,131 +0,0 @@
|
||||||
import copy
|
|
||||||
import re
|
|
||||||
from collections import Iterable
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
import six
|
|
||||||
|
|
||||||
from graphql_relay import connection_definitions, connection_from_list
|
|
||||||
from graphql_relay.connection.connection import connection_args
|
|
||||||
|
|
||||||
from ..types.field import Field
|
|
||||||
from ..types.objecttype import ObjectType, ObjectTypeMeta
|
|
||||||
from ..types.options import Options
|
|
||||||
from ..utils.copy_fields import copy_fields
|
|
||||||
from ..utils.get_fields import get_fields
|
|
||||||
from ..utils.is_base_type import is_base_type
|
|
||||||
from ..utils.props import props
|
|
||||||
|
|
||||||
|
|
||||||
class ConnectionMeta(ObjectTypeMeta):
|
|
||||||
|
|
||||||
def __new__(cls, name, bases, attrs):
|
|
||||||
super_new = type.__new__
|
|
||||||
|
|
||||||
# Also ensure initialization is only performed for subclasses of Model
|
|
||||||
# (excluding Model class itself).
|
|
||||||
if not is_base_type(bases, ConnectionMeta):
|
|
||||||
return super_new(cls, name, bases, attrs)
|
|
||||||
|
|
||||||
options = Options(
|
|
||||||
attrs.pop('Meta', None),
|
|
||||||
name=None,
|
|
||||||
description=None,
|
|
||||||
node=None,
|
|
||||||
interfaces=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
edge_class = attrs.pop('Edge', None)
|
|
||||||
edge_fields = props(edge_class) if edge_class else {}
|
|
||||||
edge_fields = get_fields(ObjectType, edge_fields, ())
|
|
||||||
|
|
||||||
connection_fields = copy_fields(Field, get_fields(ObjectType, attrs, bases))
|
|
||||||
|
|
||||||
cls = super_new(cls, name, bases, dict(attrs, _meta=options))
|
|
||||||
|
|
||||||
assert options.node, 'You have to provide a node in {}.Meta'.format(cls.__name__)
|
|
||||||
from ..utils.get_graphql_type import get_graphql_type
|
|
||||||
edge, connection = connection_definitions(
|
|
||||||
name=options.name or re.sub('Connection$', '', cls.__name__),
|
|
||||||
node_type=get_graphql_type(options.node),
|
|
||||||
resolve_node=cls.resolve_node,
|
|
||||||
resolve_cursor=cls.resolve_cursor,
|
|
||||||
|
|
||||||
edge_fields=edge_fields,
|
|
||||||
connection_fields=connection_fields,
|
|
||||||
)
|
|
||||||
cls.Edge = type(edge.name, (ObjectType, ), {'Meta': type('Meta', (object,), {'graphql_type': edge})})
|
|
||||||
cls._meta.graphql_type = connection
|
|
||||||
fields = copy_fields(Field, options.graphql_type.get_fields(), parent=cls)
|
|
||||||
|
|
||||||
cls._meta.get_fields = lambda: fields
|
|
||||||
|
|
||||||
return cls
|
|
||||||
|
|
||||||
|
|
||||||
class Connection(six.with_metaclass(ConnectionMeta, ObjectType)):
|
|
||||||
resolve_node = None
|
|
||||||
resolve_cursor = None
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
kwargs['pageInfo'] = kwargs.pop('pageInfo', kwargs.pop('page_info'))
|
|
||||||
super(Connection, self).__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class IterableConnectionField(Field):
|
|
||||||
|
|
||||||
def __init__(self, type, *other_args, **kwargs):
|
|
||||||
args = kwargs.pop('args', {})
|
|
||||||
if not args:
|
|
||||||
args = connection_args
|
|
||||||
else:
|
|
||||||
args = copy.copy(args)
|
|
||||||
args.update(connection_args)
|
|
||||||
|
|
||||||
super(IterableConnectionField, self).__init__(type, args=args, *other_args, **kwargs)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def type(self):
|
|
||||||
from ..utils.get_graphql_type import get_graphql_type
|
|
||||||
return get_graphql_type(self.connection)
|
|
||||||
|
|
||||||
@type.setter
|
|
||||||
def type(self, value):
|
|
||||||
self._type = value
|
|
||||||
|
|
||||||
@property
|
|
||||||
def connection(self):
|
|
||||||
from .node import Node
|
|
||||||
if Node in self._type._meta.interfaces:
|
|
||||||
connection_type = self._type.Connection
|
|
||||||
else:
|
|
||||||
connection_type = self._type
|
|
||||||
assert issubclass(connection_type, Connection), '{} type have to be a subclass of Connection'.format(str(self))
|
|
||||||
return connection_type
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def connection_resolver(resolver, connection, root, args, context, info):
|
|
||||||
iterable = resolver(root, args, context, info)
|
|
||||||
# if isinstance(resolved, self.type.graphene)
|
|
||||||
assert isinstance(
|
|
||||||
iterable, Iterable), 'Resolved value from the connection field have to be iterable. Received "{}"'.format(
|
|
||||||
iterable)
|
|
||||||
connection = connection_from_list(
|
|
||||||
iterable,
|
|
||||||
args,
|
|
||||||
connection_type=connection,
|
|
||||||
edge_type=connection.Edge,
|
|
||||||
)
|
|
||||||
return connection
|
|
||||||
|
|
||||||
@property
|
|
||||||
def resolver(self):
|
|
||||||
resolver = super(ConnectionField, self).resolver
|
|
||||||
connection = self.connection
|
|
||||||
return partial(self.connection_resolver, resolver, connection)
|
|
||||||
|
|
||||||
@resolver.setter
|
|
||||||
def resolver(self, resolver):
|
|
||||||
self._resolver = resolver
|
|
||||||
|
|
||||||
ConnectionField = IterableConnectionField
|
|
|
@ -1,62 +0,0 @@
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
import six
|
|
||||||
|
|
||||||
from graphql_relay import mutation_with_client_mutation_id
|
|
||||||
|
|
||||||
from ..types.field import Field, InputField
|
|
||||||
from ..types.inputobjecttype import InputObjectType
|
|
||||||
from ..types.mutation import Mutation, MutationMeta
|
|
||||||
from ..types.objecttype import ObjectType
|
|
||||||
from ..types.options import Options
|
|
||||||
from ..utils.copy_fields import copy_fields
|
|
||||||
from ..utils.get_fields import get_fields
|
|
||||||
from ..utils.is_base_type import is_base_type
|
|
||||||
from ..utils.props import props
|
|
||||||
|
|
||||||
|
|
||||||
class ClientIDMutationMeta(MutationMeta):
|
|
||||||
|
|
||||||
def __new__(cls, name, bases, attrs):
|
|
||||||
super_new = type.__new__
|
|
||||||
|
|
||||||
# Also ensure initialization is only performed for subclasses of Model
|
|
||||||
# (excluding Model class itself).
|
|
||||||
if not is_base_type(bases, ClientIDMutationMeta):
|
|
||||||
return super_new(cls, name, bases, attrs)
|
|
||||||
|
|
||||||
options = Options(
|
|
||||||
attrs.pop('Meta', None),
|
|
||||||
name=None,
|
|
||||||
description=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
input_class = attrs.pop('Input', None)
|
|
||||||
|
|
||||||
cls = super_new(cls, name, bases, dict(attrs, _meta=options))
|
|
||||||
|
|
||||||
input_fields = props(input_class) if input_class else {}
|
|
||||||
input_local_fields = copy_fields(InputField, get_fields(InputObjectType, input_fields, ()))
|
|
||||||
output_fields = copy_fields(Field, get_fields(ObjectType, attrs, bases))
|
|
||||||
|
|
||||||
mutate_and_get_payload = getattr(cls, 'mutate_and_get_payload', None)
|
|
||||||
assert mutate_and_get_payload, (
|
|
||||||
"{}.mutate_and_get_payload method is required"
|
|
||||||
" in a ClientIDMutation ObjectType.".format(cls.__name__)
|
|
||||||
)
|
|
||||||
|
|
||||||
field = mutation_with_client_mutation_id(
|
|
||||||
name=options.name or cls.__name__,
|
|
||||||
input_fields=input_local_fields,
|
|
||||||
output_fields=output_fields,
|
|
||||||
mutate_and_get_payload=cls.mutate_and_get_payload,
|
|
||||||
)
|
|
||||||
options.graphql_type = field.type
|
|
||||||
options.get_fields = lambda: output_fields
|
|
||||||
|
|
||||||
cls.Field = partial(Field.copy_and_extend, field, type=field.type, _creation_counter=None)
|
|
||||||
return cls
|
|
||||||
|
|
||||||
|
|
||||||
class ClientIDMutation(six.with_metaclass(ClientIDMutationMeta, Mutation)):
|
|
||||||
pass
|
|
|
@ -1,69 +1,39 @@
|
||||||
from functools import partial
|
from graphql_relay import from_global_id, to_global_id
|
||||||
|
from ..types import Interface, ID, Field
|
||||||
import six
|
|
||||||
|
|
||||||
from graphql_relay import from_global_id, node_definitions, to_global_id
|
|
||||||
|
|
||||||
from ..types.field import Field
|
|
||||||
from ..types.interface import Interface
|
|
||||||
from ..types.objecttype import ObjectType, ObjectTypeMeta
|
|
||||||
from ..types.options import Options
|
|
||||||
from ..utils.copy_fields import copy_fields
|
|
||||||
from .connection import Connection
|
|
||||||
|
|
||||||
|
|
||||||
def get_default_connection(cls):
|
class Node(Interface):
|
||||||
assert issubclass(cls, ObjectType), 'Can only get connection type on implemented Nodes.'
|
'''An object with an ID'''
|
||||||
|
|
||||||
class Meta:
|
id = ID(required=True, description='The ID of the object.')
|
||||||
node = cls
|
|
||||||
|
|
||||||
return type('{}Connection'.format(cls.__name__), (Connection,), {'Meta': Meta})
|
@classmethod
|
||||||
|
def resolve_id(cls, root, args, context, info):
|
||||||
|
type_name = root._meta.name # info.parent_type.name
|
||||||
|
return cls.to_global_id(type_name, getattr(root, 'id', None))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def Field(cls):
|
||||||
|
def resolve_node(root, args, context, info):
|
||||||
|
return cls.get_node_from_global_id(args.get('id'), context, info)
|
||||||
|
|
||||||
# We inherit from ObjectTypeMeta as we want to allow
|
return Field(
|
||||||
# inheriting from Node, and also ObjectType.
|
cls,
|
||||||
# Like class MyNode(Node): pass
|
description='The ID of the object',
|
||||||
# And class MyNodeImplementation(Node, ObjectType): pass
|
id=ID(required=True),
|
||||||
class NodeMeta(ObjectTypeMeta):
|
resolver=resolve_node
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_interface_options(meta):
|
|
||||||
return Options(
|
|
||||||
meta,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def _create_interface(cls, name, bases, attrs):
|
def get_node_from_global_id(cls, global_id, context, info):
|
||||||
options = cls._get_interface_options(attrs.pop('Meta', None))
|
try:
|
||||||
cls = type.__new__(cls, name, bases, dict(attrs, _meta=options))
|
_type, _id = cls.from_global_id(global_id)
|
||||||
get_node_from_global_id = getattr(cls, 'get_node_from_global_id', None)
|
graphene_type = info.schema.get_type(_type).graphene_type
|
||||||
id_resolver = getattr(cls, 'id_resolver', None)
|
# We make sure the ObjectType implements the "Node" interface
|
||||||
assert get_node_from_global_id, '{}.get_node_from_global_id method is required by the Node interface.'.format(
|
assert cls in graphene_type._meta.interfaces
|
||||||
cls.__name__)
|
except:
|
||||||
node_interface, node_field = node_definitions(
|
return None
|
||||||
get_node_from_global_id,
|
return graphene_type.get_node(_id, context, info)
|
||||||
id_resolver=id_resolver,
|
|
||||||
type_resolver=cls.resolve_type,
|
|
||||||
)
|
|
||||||
options.graphql_type = node_interface
|
|
||||||
|
|
||||||
fields = copy_fields(Field, options.graphql_type.get_fields(), parent=cls)
|
|
||||||
options.get_fields = lambda: fields
|
|
||||||
|
|
||||||
cls.Field = partial(
|
|
||||||
Field.copy_and_extend,
|
|
||||||
node_field,
|
|
||||||
type=node_field.type,
|
|
||||||
parent=cls,
|
|
||||||
_creation_counter=None)
|
|
||||||
|
|
||||||
return cls
|
|
||||||
|
|
||||||
|
|
||||||
class Node(six.with_metaclass(NodeMeta, Interface)):
|
|
||||||
_connection = None
|
|
||||||
resolve_type = None
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_global_id(cls, global_id):
|
def from_global_id(cls, global_id):
|
||||||
|
@ -73,33 +43,14 @@ class Node(six.with_metaclass(NodeMeta, Interface)):
|
||||||
def to_global_id(cls, type, id):
|
def to_global_id(cls, type, id):
|
||||||
return to_global_id(type, id)
|
return to_global_id(type, id)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def resolve_id(cls, root, args, context, info):
|
|
||||||
return getattr(root, 'id', None)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def id_resolver(cls, root, args, context, info):
|
|
||||||
return cls.to_global_id(info.parent_type.name, cls.resolve_id(root, args, context, info))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_node_from_global_id(cls, global_id, context, info):
|
|
||||||
try:
|
|
||||||
_type, _id = cls.from_global_id(global_id)
|
|
||||||
except:
|
|
||||||
return None
|
|
||||||
graphql_type = info.schema.get_type(_type)
|
|
||||||
if cls._meta.graphql_type not in graphql_type.get_interfaces():
|
|
||||||
return
|
|
||||||
return graphql_type.graphene_type.get_node(_id, context, info)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def implements(cls, objecttype):
|
def implements(cls, objecttype):
|
||||||
require_get_node = Node._meta.graphql_type in objecttype._meta.get_interfaces
|
require_get_node = Node in objecttype._meta.interfaces
|
||||||
get_connection = getattr(objecttype, 'get_connection', None)
|
# get_connection = getattr(objecttype, 'get_connection', None)
|
||||||
if not get_connection:
|
# if not get_connection:
|
||||||
get_connection = partial(get_default_connection, objecttype)
|
# get_connection = partial(get_default_connection, objecttype)
|
||||||
|
|
||||||
objecttype.Connection = get_connection()
|
# objecttype.Connection = get_connection()
|
||||||
if require_get_node:
|
if require_get_node:
|
||||||
assert hasattr(
|
assert hasattr(
|
||||||
objecttype, 'get_node'), '{}.get_node method is required by the Node interface.'.format(
|
objecttype, 'get_node'), '{}.get_node method is required by the Node interface.'.format(
|
||||||
|
|
|
@ -1,43 +0,0 @@
|
||||||
|
|
||||||
from ...types import ObjectType, Schema
|
|
||||||
from ...types.field import Field
|
|
||||||
from ...types.scalars import String
|
|
||||||
from ..connection import Connection
|
|
||||||
from ..node import Node
|
|
||||||
|
|
||||||
|
|
||||||
class MyObject(ObjectType):
|
|
||||||
class Meta:
|
|
||||||
interfaces = [Node]
|
|
||||||
field = String()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_node(cls, id):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class MyObjectConnection(Connection):
|
|
||||||
|
|
||||||
class Meta:
|
|
||||||
node = MyObject
|
|
||||||
|
|
||||||
class Edge:
|
|
||||||
other = String()
|
|
||||||
|
|
||||||
|
|
||||||
class RootQuery(ObjectType):
|
|
||||||
my_connection = Field(MyObjectConnection)
|
|
||||||
|
|
||||||
|
|
||||||
schema = Schema(query=RootQuery)
|
|
||||||
|
|
||||||
|
|
||||||
def test_node_good():
|
|
||||||
graphql_type = MyObjectConnection._meta.graphql_type
|
|
||||||
fields = graphql_type.get_fields()
|
|
||||||
assert 'edges' in fields
|
|
||||||
assert 'pageInfo' in fields
|
|
||||||
edge_fields = fields['edges'].type.of_type.get_fields()
|
|
||||||
assert 'node' in edge_fields
|
|
||||||
assert edge_fields['node'].type == MyObject._meta.graphql_type
|
|
||||||
assert 'other' in edge_fields
|
|
|
@ -1,55 +0,0 @@
|
||||||
import pytest
|
|
||||||
|
|
||||||
from ...types import ObjectType, Schema
|
|
||||||
from ...types.scalars import String
|
|
||||||
from ..mutation import ClientIDMutation
|
|
||||||
|
|
||||||
|
|
||||||
class SaySomething(ClientIDMutation):
|
|
||||||
|
|
||||||
class Input:
|
|
||||||
what = String()
|
|
||||||
phrase = String()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def mutate_and_get_payload(args, context, info):
|
|
||||||
what = args.get('what')
|
|
||||||
return SaySomething(phrase=str(what))
|
|
||||||
|
|
||||||
|
|
||||||
class RootQuery(ObjectType):
|
|
||||||
something = String()
|
|
||||||
|
|
||||||
|
|
||||||
class Mutation(ObjectType):
|
|
||||||
say = SaySomething.Field()
|
|
||||||
|
|
||||||
schema = Schema(query=RootQuery, mutation=Mutation)
|
|
||||||
|
|
||||||
|
|
||||||
def test_no_mutate_and_get_payload():
|
|
||||||
with pytest.raises(AssertionError) as excinfo:
|
|
||||||
class MyMutation(ClientIDMutation):
|
|
||||||
pass
|
|
||||||
|
|
||||||
assert "MyMutation.mutate_and_get_payload method is required in a ClientIDMutation ObjectType." == str(
|
|
||||||
excinfo.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_node_good():
|
|
||||||
graphql_type = SaySomething._meta.graphql_type
|
|
||||||
fields = graphql_type.get_fields()
|
|
||||||
assert 'phrase' in fields
|
|
||||||
graphql_field = SaySomething.Field()
|
|
||||||
assert graphql_field.type == SaySomething._meta.graphql_type
|
|
||||||
assert 'input' in graphql_field.args
|
|
||||||
input = graphql_field.args['input']
|
|
||||||
assert 'clientMutationId' in input.type.of_type.get_fields()
|
|
||||||
|
|
||||||
|
|
||||||
def test_node_query():
|
|
||||||
executed = schema.execute(
|
|
||||||
'mutation a { say(input: {what:"hello", clientMutationId:"1"}) { phrase } }'
|
|
||||||
)
|
|
||||||
assert not executed.errors
|
|
||||||
assert executed.data == {'say': {'phrase': 'hello'}}
|
|
|
@ -2,15 +2,15 @@ import pytest
|
||||||
|
|
||||||
from graphql_relay import to_global_id
|
from graphql_relay import to_global_id
|
||||||
|
|
||||||
from ...types import ObjectType, Schema
|
from ...types import ObjectType, Schema, String
|
||||||
from ...types.scalars import String
|
# from ...types.scalars import String
|
||||||
from ..connection import Connection
|
# from ..connection import Connection
|
||||||
from ..node import Node
|
from ..node import Node
|
||||||
|
|
||||||
|
|
||||||
class MyNode(ObjectType):
|
class MyNode(ObjectType):
|
||||||
class Meta:
|
class Meta:
|
||||||
interfaces = [Node]
|
interfaces = (Node, )
|
||||||
name = String()
|
name = String()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -27,32 +27,33 @@ schema = Schema(query=RootQuery, types=[MyNode])
|
||||||
|
|
||||||
def test_node_no_get_node():
|
def test_node_no_get_node():
|
||||||
with pytest.raises(AssertionError) as excinfo:
|
with pytest.raises(AssertionError) as excinfo:
|
||||||
class MyNode(Node, ObjectType):
|
class MyNode(ObjectType):
|
||||||
pass
|
class Meta:
|
||||||
|
interfaces = (Node, )
|
||||||
|
|
||||||
assert "MyNode.get_node method is required by the Node interface." == str(excinfo.value)
|
assert "MyNode.get_node method is required by the Node interface." == str(excinfo.value)
|
||||||
|
|
||||||
|
|
||||||
def test_node_no_get_node_with_meta():
|
def test_node_no_get_node_with_meta():
|
||||||
with pytest.raises(AssertionError) as excinfo:
|
with pytest.raises(AssertionError) as excinfo:
|
||||||
class MyNode(Node, ObjectType):
|
class MyNode(ObjectType):
|
||||||
pass
|
class Meta:
|
||||||
|
interfaces = (Node, )
|
||||||
|
|
||||||
assert "MyNode.get_node method is required by the Node interface." == str(excinfo.value)
|
assert "MyNode.get_node method is required by the Node interface." == str(excinfo.value)
|
||||||
|
|
||||||
|
|
||||||
def test_node_good():
|
def test_node_good():
|
||||||
graphql_type = MyNode._meta.graphql_type
|
assert 'id' in MyNode._meta.fields
|
||||||
assert 'id' in graphql_type.get_fields()
|
|
||||||
|
|
||||||
|
|
||||||
def test_node_get_connection():
|
# def test_node_get_connection():
|
||||||
connection = MyNode.Connection
|
# connection = MyNode.Connection
|
||||||
assert issubclass(connection, Connection)
|
# assert issubclass(connection, Connection)
|
||||||
|
|
||||||
|
|
||||||
def test_node_get_connection_dont_duplicate():
|
# def test_node_get_connection_dont_duplicate():
|
||||||
assert MyNode.Connection == MyNode.Connection
|
# assert MyNode.Connection == MyNode.Connection
|
||||||
|
|
||||||
|
|
||||||
def test_node_query():
|
def test_node_query():
|
||||||
|
|
|
@ -6,6 +6,8 @@ from ..node import Node
|
||||||
|
|
||||||
|
|
||||||
class CustomNode(Node):
|
class CustomNode(Node):
|
||||||
|
class Meta:
|
||||||
|
name = 'Node'
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def to_global_id(type, id):
|
def to_global_id(type, id):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user