mirror of
https://github.com/graphql-python/graphene.git
synced 2025-02-02 20:54:16 +03:00
Improved Relay implementation
This commit is contained in:
parent
fd16de8748
commit
0ffdd8d9ab
|
@ -1,10 +1,10 @@
|
|||
from .node import Node
|
||||
from .mutation import ClientIDMutation
|
||||
from .connection import Connection, ConnectionField
|
||||
# from .mutation import ClientIDMutation
|
||||
# from .connection import Connection, ConnectionField
|
||||
|
||||
__all__ = [
|
||||
'Node',
|
||||
'ClientIDMutation',
|
||||
'Connection',
|
||||
'ConnectionField',
|
||||
# 'ClientIDMutation',
|
||||
# 'Connection',
|
||||
# 'ConnectionField',
|
||||
]
|
||||
|
|
|
@ -1,131 +0,0 @@
|
|||
import copy
|
||||
import re
|
||||
from collections import Iterable
|
||||
from functools import partial
|
||||
|
||||
import six
|
||||
|
||||
from graphql_relay import connection_definitions, connection_from_list
|
||||
from graphql_relay.connection.connection import connection_args
|
||||
|
||||
from ..types.field import Field
|
||||
from ..types.objecttype import ObjectType, ObjectTypeMeta
|
||||
from ..types.options import Options
|
||||
from ..utils.copy_fields import copy_fields
|
||||
from ..utils.get_fields import get_fields
|
||||
from ..utils.is_base_type import is_base_type
|
||||
from ..utils.props import props
|
||||
|
||||
|
||||
class ConnectionMeta(ObjectTypeMeta):
|
||||
|
||||
def __new__(cls, name, bases, attrs):
|
||||
super_new = type.__new__
|
||||
|
||||
# Also ensure initialization is only performed for subclasses of Model
|
||||
# (excluding Model class itself).
|
||||
if not is_base_type(bases, ConnectionMeta):
|
||||
return super_new(cls, name, bases, attrs)
|
||||
|
||||
options = Options(
|
||||
attrs.pop('Meta', None),
|
||||
name=None,
|
||||
description=None,
|
||||
node=None,
|
||||
interfaces=[],
|
||||
)
|
||||
|
||||
edge_class = attrs.pop('Edge', None)
|
||||
edge_fields = props(edge_class) if edge_class else {}
|
||||
edge_fields = get_fields(ObjectType, edge_fields, ())
|
||||
|
||||
connection_fields = copy_fields(Field, get_fields(ObjectType, attrs, bases))
|
||||
|
||||
cls = super_new(cls, name, bases, dict(attrs, _meta=options))
|
||||
|
||||
assert options.node, 'You have to provide a node in {}.Meta'.format(cls.__name__)
|
||||
from ..utils.get_graphql_type import get_graphql_type
|
||||
edge, connection = connection_definitions(
|
||||
name=options.name or re.sub('Connection$', '', cls.__name__),
|
||||
node_type=get_graphql_type(options.node),
|
||||
resolve_node=cls.resolve_node,
|
||||
resolve_cursor=cls.resolve_cursor,
|
||||
|
||||
edge_fields=edge_fields,
|
||||
connection_fields=connection_fields,
|
||||
)
|
||||
cls.Edge = type(edge.name, (ObjectType, ), {'Meta': type('Meta', (object,), {'graphql_type': edge})})
|
||||
cls._meta.graphql_type = connection
|
||||
fields = copy_fields(Field, options.graphql_type.get_fields(), parent=cls)
|
||||
|
||||
cls._meta.get_fields = lambda: fields
|
||||
|
||||
return cls
|
||||
|
||||
|
||||
class Connection(six.with_metaclass(ConnectionMeta, ObjectType)):
|
||||
resolve_node = None
|
||||
resolve_cursor = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs['pageInfo'] = kwargs.pop('pageInfo', kwargs.pop('page_info'))
|
||||
super(Connection, self).__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class IterableConnectionField(Field):
|
||||
|
||||
def __init__(self, type, *other_args, **kwargs):
|
||||
args = kwargs.pop('args', {})
|
||||
if not args:
|
||||
args = connection_args
|
||||
else:
|
||||
args = copy.copy(args)
|
||||
args.update(connection_args)
|
||||
|
||||
super(IterableConnectionField, self).__init__(type, args=args, *other_args, **kwargs)
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
from ..utils.get_graphql_type import get_graphql_type
|
||||
return get_graphql_type(self.connection)
|
||||
|
||||
@type.setter
|
||||
def type(self, value):
|
||||
self._type = value
|
||||
|
||||
@property
|
||||
def connection(self):
|
||||
from .node import Node
|
||||
if Node in self._type._meta.interfaces:
|
||||
connection_type = self._type.Connection
|
||||
else:
|
||||
connection_type = self._type
|
||||
assert issubclass(connection_type, Connection), '{} type have to be a subclass of Connection'.format(str(self))
|
||||
return connection_type
|
||||
|
||||
@staticmethod
|
||||
def connection_resolver(resolver, connection, root, args, context, info):
|
||||
iterable = resolver(root, args, context, info)
|
||||
# if isinstance(resolved, self.type.graphene)
|
||||
assert isinstance(
|
||||
iterable, Iterable), 'Resolved value from the connection field have to be iterable. Received "{}"'.format(
|
||||
iterable)
|
||||
connection = connection_from_list(
|
||||
iterable,
|
||||
args,
|
||||
connection_type=connection,
|
||||
edge_type=connection.Edge,
|
||||
)
|
||||
return connection
|
||||
|
||||
@property
|
||||
def resolver(self):
|
||||
resolver = super(ConnectionField, self).resolver
|
||||
connection = self.connection
|
||||
return partial(self.connection_resolver, resolver, connection)
|
||||
|
||||
@resolver.setter
|
||||
def resolver(self, resolver):
|
||||
self._resolver = resolver
|
||||
|
||||
ConnectionField = IterableConnectionField
|
|
@ -1,62 +0,0 @@
|
|||
from functools import partial
|
||||
|
||||
import six
|
||||
|
||||
from graphql_relay import mutation_with_client_mutation_id
|
||||
|
||||
from ..types.field import Field, InputField
|
||||
from ..types.inputobjecttype import InputObjectType
|
||||
from ..types.mutation import Mutation, MutationMeta
|
||||
from ..types.objecttype import ObjectType
|
||||
from ..types.options import Options
|
||||
from ..utils.copy_fields import copy_fields
|
||||
from ..utils.get_fields import get_fields
|
||||
from ..utils.is_base_type import is_base_type
|
||||
from ..utils.props import props
|
||||
|
||||
|
||||
class ClientIDMutationMeta(MutationMeta):
|
||||
|
||||
def __new__(cls, name, bases, attrs):
|
||||
super_new = type.__new__
|
||||
|
||||
# Also ensure initialization is only performed for subclasses of Model
|
||||
# (excluding Model class itself).
|
||||
if not is_base_type(bases, ClientIDMutationMeta):
|
||||
return super_new(cls, name, bases, attrs)
|
||||
|
||||
options = Options(
|
||||
attrs.pop('Meta', None),
|
||||
name=None,
|
||||
description=None,
|
||||
)
|
||||
|
||||
input_class = attrs.pop('Input', None)
|
||||
|
||||
cls = super_new(cls, name, bases, dict(attrs, _meta=options))
|
||||
|
||||
input_fields = props(input_class) if input_class else {}
|
||||
input_local_fields = copy_fields(InputField, get_fields(InputObjectType, input_fields, ()))
|
||||
output_fields = copy_fields(Field, get_fields(ObjectType, attrs, bases))
|
||||
|
||||
mutate_and_get_payload = getattr(cls, 'mutate_and_get_payload', None)
|
||||
assert mutate_and_get_payload, (
|
||||
"{}.mutate_and_get_payload method is required"
|
||||
" in a ClientIDMutation ObjectType.".format(cls.__name__)
|
||||
)
|
||||
|
||||
field = mutation_with_client_mutation_id(
|
||||
name=options.name or cls.__name__,
|
||||
input_fields=input_local_fields,
|
||||
output_fields=output_fields,
|
||||
mutate_and_get_payload=cls.mutate_and_get_payload,
|
||||
)
|
||||
options.graphql_type = field.type
|
||||
options.get_fields = lambda: output_fields
|
||||
|
||||
cls.Field = partial(Field.copy_and_extend, field, type=field.type, _creation_counter=None)
|
||||
return cls
|
||||
|
||||
|
||||
class ClientIDMutation(six.with_metaclass(ClientIDMutationMeta, Mutation)):
|
||||
pass
|
|
@ -1,69 +1,39 @@
|
|||
from functools import partial
|
||||
|
||||
import six
|
||||
|
||||
from graphql_relay import from_global_id, node_definitions, to_global_id
|
||||
|
||||
from ..types.field import Field
|
||||
from ..types.interface import Interface
|
||||
from ..types.objecttype import ObjectType, ObjectTypeMeta
|
||||
from ..types.options import Options
|
||||
from ..utils.copy_fields import copy_fields
|
||||
from .connection import Connection
|
||||
from graphql_relay import from_global_id, to_global_id
|
||||
from ..types import Interface, ID, Field
|
||||
|
||||
|
||||
def get_default_connection(cls):
|
||||
assert issubclass(cls, ObjectType), 'Can only get connection type on implemented Nodes.'
|
||||
class Node(Interface):
|
||||
'''An object with an ID'''
|
||||
|
||||
class Meta:
|
||||
node = cls
|
||||
id = ID(required=True, description='The ID of the object.')
|
||||
|
||||
return type('{}Connection'.format(cls.__name__), (Connection,), {'Meta': Meta})
|
||||
@classmethod
|
||||
def resolve_id(cls, root, args, context, info):
|
||||
type_name = root._meta.name # info.parent_type.name
|
||||
return cls.to_global_id(type_name, getattr(root, 'id', None))
|
||||
|
||||
@classmethod
|
||||
def Field(cls):
|
||||
def resolve_node(root, args, context, info):
|
||||
return cls.get_node_from_global_id(args.get('id'), context, info)
|
||||
|
||||
# We inherit from ObjectTypeMeta as we want to allow
|
||||
# inheriting from Node, and also ObjectType.
|
||||
# Like class MyNode(Node): pass
|
||||
# And class MyNodeImplementation(Node, ObjectType): pass
|
||||
class NodeMeta(ObjectTypeMeta):
|
||||
|
||||
@staticmethod
|
||||
def _get_interface_options(meta):
|
||||
return Options(
|
||||
meta,
|
||||
return Field(
|
||||
cls,
|
||||
description='The ID of the object',
|
||||
id=ID(required=True),
|
||||
resolver=resolve_node
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _create_interface(cls, name, bases, attrs):
|
||||
options = cls._get_interface_options(attrs.pop('Meta', None))
|
||||
cls = type.__new__(cls, name, bases, dict(attrs, _meta=options))
|
||||
get_node_from_global_id = getattr(cls, 'get_node_from_global_id', None)
|
||||
id_resolver = getattr(cls, 'id_resolver', None)
|
||||
assert get_node_from_global_id, '{}.get_node_from_global_id method is required by the Node interface.'.format(
|
||||
cls.__name__)
|
||||
node_interface, node_field = node_definitions(
|
||||
get_node_from_global_id,
|
||||
id_resolver=id_resolver,
|
||||
type_resolver=cls.resolve_type,
|
||||
)
|
||||
options.graphql_type = node_interface
|
||||
|
||||
fields = copy_fields(Field, options.graphql_type.get_fields(), parent=cls)
|
||||
options.get_fields = lambda: fields
|
||||
|
||||
cls.Field = partial(
|
||||
Field.copy_and_extend,
|
||||
node_field,
|
||||
type=node_field.type,
|
||||
parent=cls,
|
||||
_creation_counter=None)
|
||||
|
||||
return cls
|
||||
|
||||
|
||||
class Node(six.with_metaclass(NodeMeta, Interface)):
|
||||
_connection = None
|
||||
resolve_type = None
|
||||
@classmethod
|
||||
def get_node_from_global_id(cls, global_id, context, info):
|
||||
try:
|
||||
_type, _id = cls.from_global_id(global_id)
|
||||
graphene_type = info.schema.get_type(_type).graphene_type
|
||||
# We make sure the ObjectType implements the "Node" interface
|
||||
assert cls in graphene_type._meta.interfaces
|
||||
except:
|
||||
return None
|
||||
return graphene_type.get_node(_id, context, info)
|
||||
|
||||
@classmethod
|
||||
def from_global_id(cls, global_id):
|
||||
|
@ -73,33 +43,14 @@ class Node(six.with_metaclass(NodeMeta, Interface)):
|
|||
def to_global_id(cls, type, id):
|
||||
return to_global_id(type, id)
|
||||
|
||||
@classmethod
|
||||
def resolve_id(cls, root, args, context, info):
|
||||
return getattr(root, 'id', None)
|
||||
|
||||
@classmethod
|
||||
def id_resolver(cls, root, args, context, info):
|
||||
return cls.to_global_id(info.parent_type.name, cls.resolve_id(root, args, context, info))
|
||||
|
||||
@classmethod
|
||||
def get_node_from_global_id(cls, global_id, context, info):
|
||||
try:
|
||||
_type, _id = cls.from_global_id(global_id)
|
||||
except:
|
||||
return None
|
||||
graphql_type = info.schema.get_type(_type)
|
||||
if cls._meta.graphql_type not in graphql_type.get_interfaces():
|
||||
return
|
||||
return graphql_type.graphene_type.get_node(_id, context, info)
|
||||
|
||||
@classmethod
|
||||
def implements(cls, objecttype):
|
||||
require_get_node = Node._meta.graphql_type in objecttype._meta.get_interfaces
|
||||
get_connection = getattr(objecttype, 'get_connection', None)
|
||||
if not get_connection:
|
||||
get_connection = partial(get_default_connection, objecttype)
|
||||
require_get_node = Node in objecttype._meta.interfaces
|
||||
# get_connection = getattr(objecttype, 'get_connection', None)
|
||||
# if not get_connection:
|
||||
# get_connection = partial(get_default_connection, objecttype)
|
||||
|
||||
objecttype.Connection = get_connection()
|
||||
# objecttype.Connection = get_connection()
|
||||
if require_get_node:
|
||||
assert hasattr(
|
||||
objecttype, 'get_node'), '{}.get_node method is required by the Node interface.'.format(
|
||||
|
|
|
@ -1,43 +0,0 @@
|
|||
|
||||
from ...types import ObjectType, Schema
|
||||
from ...types.field import Field
|
||||
from ...types.scalars import String
|
||||
from ..connection import Connection
|
||||
from ..node import Node
|
||||
|
||||
|
||||
class MyObject(ObjectType):
|
||||
class Meta:
|
||||
interfaces = [Node]
|
||||
field = String()
|
||||
|
||||
@classmethod
|
||||
def get_node(cls, id):
|
||||
pass
|
||||
|
||||
|
||||
class MyObjectConnection(Connection):
|
||||
|
||||
class Meta:
|
||||
node = MyObject
|
||||
|
||||
class Edge:
|
||||
other = String()
|
||||
|
||||
|
||||
class RootQuery(ObjectType):
|
||||
my_connection = Field(MyObjectConnection)
|
||||
|
||||
|
||||
schema = Schema(query=RootQuery)
|
||||
|
||||
|
||||
def test_node_good():
|
||||
graphql_type = MyObjectConnection._meta.graphql_type
|
||||
fields = graphql_type.get_fields()
|
||||
assert 'edges' in fields
|
||||
assert 'pageInfo' in fields
|
||||
edge_fields = fields['edges'].type.of_type.get_fields()
|
||||
assert 'node' in edge_fields
|
||||
assert edge_fields['node'].type == MyObject._meta.graphql_type
|
||||
assert 'other' in edge_fields
|
|
@ -1,55 +0,0 @@
|
|||
import pytest
|
||||
|
||||
from ...types import ObjectType, Schema
|
||||
from ...types.scalars import String
|
||||
from ..mutation import ClientIDMutation
|
||||
|
||||
|
||||
class SaySomething(ClientIDMutation):
|
||||
|
||||
class Input:
|
||||
what = String()
|
||||
phrase = String()
|
||||
|
||||
@staticmethod
|
||||
def mutate_and_get_payload(args, context, info):
|
||||
what = args.get('what')
|
||||
return SaySomething(phrase=str(what))
|
||||
|
||||
|
||||
class RootQuery(ObjectType):
|
||||
something = String()
|
||||
|
||||
|
||||
class Mutation(ObjectType):
|
||||
say = SaySomething.Field()
|
||||
|
||||
schema = Schema(query=RootQuery, mutation=Mutation)
|
||||
|
||||
|
||||
def test_no_mutate_and_get_payload():
|
||||
with pytest.raises(AssertionError) as excinfo:
|
||||
class MyMutation(ClientIDMutation):
|
||||
pass
|
||||
|
||||
assert "MyMutation.mutate_and_get_payload method is required in a ClientIDMutation ObjectType." == str(
|
||||
excinfo.value)
|
||||
|
||||
|
||||
def test_node_good():
|
||||
graphql_type = SaySomething._meta.graphql_type
|
||||
fields = graphql_type.get_fields()
|
||||
assert 'phrase' in fields
|
||||
graphql_field = SaySomething.Field()
|
||||
assert graphql_field.type == SaySomething._meta.graphql_type
|
||||
assert 'input' in graphql_field.args
|
||||
input = graphql_field.args['input']
|
||||
assert 'clientMutationId' in input.type.of_type.get_fields()
|
||||
|
||||
|
||||
def test_node_query():
|
||||
executed = schema.execute(
|
||||
'mutation a { say(input: {what:"hello", clientMutationId:"1"}) { phrase } }'
|
||||
)
|
||||
assert not executed.errors
|
||||
assert executed.data == {'say': {'phrase': 'hello'}}
|
|
@ -2,15 +2,15 @@ import pytest
|
|||
|
||||
from graphql_relay import to_global_id
|
||||
|
||||
from ...types import ObjectType, Schema
|
||||
from ...types.scalars import String
|
||||
from ..connection import Connection
|
||||
from ...types import ObjectType, Schema, String
|
||||
# from ...types.scalars import String
|
||||
# from ..connection import Connection
|
||||
from ..node import Node
|
||||
|
||||
|
||||
class MyNode(ObjectType):
|
||||
class Meta:
|
||||
interfaces = [Node]
|
||||
interfaces = (Node, )
|
||||
name = String()
|
||||
|
||||
@staticmethod
|
||||
|
@ -27,32 +27,33 @@ schema = Schema(query=RootQuery, types=[MyNode])
|
|||
|
||||
def test_node_no_get_node():
|
||||
with pytest.raises(AssertionError) as excinfo:
|
||||
class MyNode(Node, ObjectType):
|
||||
pass
|
||||
class MyNode(ObjectType):
|
||||
class Meta:
|
||||
interfaces = (Node, )
|
||||
|
||||
assert "MyNode.get_node method is required by the Node interface." == str(excinfo.value)
|
||||
|
||||
|
||||
def test_node_no_get_node_with_meta():
|
||||
with pytest.raises(AssertionError) as excinfo:
|
||||
class MyNode(Node, ObjectType):
|
||||
pass
|
||||
class MyNode(ObjectType):
|
||||
class Meta:
|
||||
interfaces = (Node, )
|
||||
|
||||
assert "MyNode.get_node method is required by the Node interface." == str(excinfo.value)
|
||||
|
||||
|
||||
def test_node_good():
|
||||
graphql_type = MyNode._meta.graphql_type
|
||||
assert 'id' in graphql_type.get_fields()
|
||||
assert 'id' in MyNode._meta.fields
|
||||
|
||||
|
||||
def test_node_get_connection():
|
||||
connection = MyNode.Connection
|
||||
assert issubclass(connection, Connection)
|
||||
# def test_node_get_connection():
|
||||
# connection = MyNode.Connection
|
||||
# assert issubclass(connection, Connection)
|
||||
|
||||
|
||||
def test_node_get_connection_dont_duplicate():
|
||||
assert MyNode.Connection == MyNode.Connection
|
||||
# def test_node_get_connection_dont_duplicate():
|
||||
# assert MyNode.Connection == MyNode.Connection
|
||||
|
||||
|
||||
def test_node_query():
|
||||
|
|
|
@ -6,6 +6,8 @@ from ..node import Node
|
|||
|
||||
|
||||
class CustomNode(Node):
|
||||
class Meta:
|
||||
name = 'Node'
|
||||
|
||||
@staticmethod
|
||||
def to_global_id(type, id):
|
||||
|
|
Loading…
Reference in New Issue
Block a user