Improved relay integration

This commit is contained in:
Syrus Akbary 2015-11-10 22:57:22 -08:00
parent 28d89a44f1
commit 41648b5a94
6 changed files with 59 additions and 53 deletions

View File

@ -21,7 +21,8 @@ class Field(OrderedType):
self.name = name self.name = name
if isinstance(type, six.string_types): if isinstance(type, six.string_types):
type = LazyType(type) type = LazyType(type)
if required: self.required = required
if self.required:
type = NonNull(type) type = NonNull(type)
self.type = type self.type = type
self.description = description self.description = description
@ -68,7 +69,7 @@ class Field(OrderedType):
type_objecttype = schema.objecttype(type) type_objecttype = schema.objecttype(type)
if type_objecttype and type_objecttype._meta.is_mutation: if type_objecttype and type_objecttype._meta.is_mutation:
assert len(arguments) == 0 assert len(arguments) == 0
arguments = type_objecttype.arguments arguments = type_objecttype.get_arguments()
resolver = getattr(type_objecttype, 'mutate') resolver = getattr(type_objecttype, 'mutate')
assert type, 'Internal type for field %s is None' % str(self) assert type, 'Internal type for field %s is None' % str(self)

View File

@ -52,10 +52,6 @@ class ObjectTypeMeta(type):
assert not ( assert not (
new_class._meta.is_interface and new_class._meta.is_mutation) new_class._meta.is_interface and new_class._meta.is_mutation)
input_class = None
if new_class._meta.is_mutation:
input_class = attrs.pop('Input', None)
# Add all attributes to the class. # Add all attributes to the class.
for obj_name, obj in attrs.items(): for obj_name, obj in attrs.items():
new_class.add_to_class(obj_name, obj) new_class.add_to_class(obj_name, obj)
@ -64,14 +60,6 @@ class ObjectTypeMeta(type):
assert hasattr( assert hasattr(
new_class, 'mutate'), "All mutations must implement mutate method" new_class, 'mutate'), "All mutations must implement mutate method"
if input_class:
items = dict(input_class.__dict__)
items.pop('__dict__', None)
items.pop('__doc__', None)
items.pop('__module__', None)
arguments = ArgumentsGroup(**items)
new_class.add_to_class('arguments', arguments)
new_class.add_extra_fields() new_class.add_extra_fields()
new_fields = new_class._meta.local_fields new_fields = new_class._meta.local_fields
@ -215,7 +203,21 @@ class ObjectType(six.with_metaclass(ObjectTypeMeta, BaseObjectType)):
class Mutation(six.with_metaclass(ObjectTypeMeta, BaseObjectType)): class Mutation(six.with_metaclass(ObjectTypeMeta, BaseObjectType)):
pass @classmethod
def _prepare_class(cls):
input_class = getattr(cls, 'Input', None)
if input_class:
items = dict(input_class.__dict__)
items.pop('__dict__', None)
items.pop('__doc__', None)
items.pop('__module__', None)
arguments = ArgumentsGroup(**items)
cls.add_to_class('arguments', arguments)
delattr(cls, 'Input')
@classmethod
def get_arguments(cls):
return cls.arguments
class InputObjectType(ObjectType): class InputObjectType(ObjectType):

View File

@ -1,18 +1,21 @@
from collections import Iterable from collections import Iterable
from graphene.core.fields import Field, IDField from graphene.core.fields import Field, IDField
from graphene.core.types.scalars import String, ID
from graphql.core.type import GraphQLArgument, GraphQLID, GraphQLNonNull from graphql.core.type import GraphQLArgument, GraphQLID, GraphQLNonNull
from graphql_relay.connection.arrayconnection import connection_from_list from graphql_relay.connection.arrayconnection import connection_from_list
from graphql_relay.connection.connection import connection_args
from graphql_relay.node.node import from_global_id from graphql_relay.node.node import from_global_id
class ConnectionField(Field): class ConnectionField(Field):
def __init__(self, field_type, resolve=None, description='', def __init__(self, field_type, resolver=None, description='',
connection_type=None, edge_type=None, **kwargs): connection_type=None, edge_type=None, **kwargs):
super(ConnectionField, self).__init__(field_type, resolve=resolve, super(ConnectionField, self).__init__(field_type, resolver=resolver,
args=connection_args, before=String(),
after=String(),
first=String(),
last=String(),
description=description, **kwargs) description=description, **kwargs)
self.connection_type = connection_type self.connection_type = connection_type
self.edge_type = edge_type self.edge_type = edge_type
@ -60,12 +63,9 @@ class NodeField(Field):
def __init__(self, object_type=None, *args, **kwargs): def __init__(self, object_type=None, *args, **kwargs):
from graphene.relay.types import Node from graphene.relay.types import Node
kwargs['id'] = ID(description='The ID of an object')
super(NodeField, self).__init__(object_type or Node, *args, **kwargs) super(NodeField, self).__init__(object_type or Node, *args, **kwargs)
self.field_object_type = object_type self.field_object_type = object_type
self.args['id'] = GraphQLArgument(
GraphQLNonNull(GraphQLID),
description='The ID of an object'
)
def id_fetcher(self, global_id, info): def id_fetcher(self, global_id, info):
from graphene.relay.utils import is_node from graphene.relay.utils import is_node
@ -88,11 +88,11 @@ class GlobalIDField(IDField):
'''The ID of an object''' '''The ID of an object'''
required = True required = True
def contribute_to_class(self, cls, name, add=True): def contribute_to_class(self, cls, name):
from graphene.relay.utils import is_node, is_node_type from graphene.relay.utils import is_node, is_node_type
in_node = is_node(cls) or is_node_type(cls) in_node = is_node(cls) or is_node_type(cls)
assert in_node, 'GlobalIDField could only be inside a Node, but got %r' % cls assert in_node, 'GlobalIDField could only be inside a Node, but got %r' % cls
super(GlobalIDField, self).contribute_to_class(cls, name, add) super(GlobalIDField, self).contribute_to_class(cls, name)
def resolve(self, instance, args, info): def resolve(self, instance, args, info):
return self.object_type.to_global_id(instance, args, info) return self.object_type.to_global_id(instance, args, info)

View File

@ -1,6 +1,8 @@
from graphene.core.fields import BooleanField, Field, ListField, StringField from graphene.core.fields import BooleanField, Field, ListField, StringField
from graphene.core.types import (InputObjectType, Interface, Mutation, from graphene.core.types import (InputObjectType, Interface, Mutation,
ObjectType) ObjectType)
from graphene.core.types.argument import ArgumentsGroup
from graphene.core.types.definitions import NonNull
from graphene.relay.fields import GlobalIDField from graphene.relay.fields import GlobalIDField
from graphene.utils import memoize from graphene.utils import memoize
from graphql_relay.node.node import to_global_id from graphql_relay.node.node import to_global_id
@ -90,7 +92,7 @@ class BaseNode(object):
class Node(BaseNode, Interface): class Node(BaseNode, Interface):
'''An object with an ID''' '''An object with an ID'''
id = GlobalIDField() id = GlobalIDField(required=True)
class MutationInputType(InputObjectType): class MutationInputType(InputObjectType):
@ -102,19 +104,19 @@ class ClientIDMutation(Mutation):
@classmethod @classmethod
def _prepare_class(cls): def _prepare_class(cls):
input_type = getattr(cls, 'input_type', None) Input = getattr(cls, 'Input', None)
if input_type: if Input:
assert hasattr( assert hasattr(
cls, 'mutate_and_get_payload'), 'You have to implement mutate_and_get_payload' cls, 'mutate_and_get_payload'), 'You have to implement mutate_and_get_payload'
new_input_inner_type = type('{}InnerInput'.format(
cls._meta.type_name), (MutationInputType, input_type, ), {}) items = dict(Input.__dict__)
items = { items.pop('__dict__', None)
'input': Field(new_input_inner_type) new_input_type = type('{}Input'.format(
} cls._meta.type_name), (MutationInputType, ), items)
assert issubclass(new_input_inner_type, InputObjectType) cls.add_to_class('input_type', new_input_type)
input_type = type('{}Input'.format( arguments = ArgumentsGroup(input=NonNull(new_input_type))
cls._meta.type_name), (ObjectType, ), items) cls.add_to_class('arguments', arguments)
setattr(cls, 'input_type', input_type) delattr(cls, 'Input')
@classmethod @classmethod
def mutate(cls, instance, args, info): def mutate(cls, instance, args, info):

View File

@ -32,19 +32,19 @@ class MyResultMutation(graphene.ObjectType):
schema = Schema(query=Query, mutation=MyResultMutation) schema = Schema(query=Query, mutation=MyResultMutation)
def test_mutation_input(): def test_mutation_arguments():
assert ChangeNumber.input_type assert ChangeNumber.arguments
assert ChangeNumber.input_type._meta.type_name == 'ChangeNumberInput' assert list(ChangeNumber.arguments) == ['input']
assert list(ChangeNumber.input_type._meta.fields_map.keys()) == ['input'] _input = ChangeNumber.arguments['input']
_input = ChangeNumber.input_type._meta.fields_map['input']
inner_type = _input.get_object_type(schema) # inner_type = _input.get_object_type(schema)
client_mutation_id_field = inner_type._meta.fields_map[ # client_mutation_id_field = inner_type._meta.fields_map[
'client_mutation_id'] # 'client_mutation_id']
assert issubclass(inner_type, InputObjectType) # assert issubclass(inner_type, InputObjectType)
assert isinstance(client_mutation_id_field, graphene.StringField) # assert isinstance(client_mutation_id_field, graphene.StringField)
assert client_mutation_id_field.object_type == inner_type # assert client_mutation_id_field.object_type == inner_type
assert isinstance(client_mutation_id_field.internal_field( # assert isinstance(client_mutation_id_field.internal_field(
schema), GraphQLInputObjectField) # schema), GraphQLInputObjectField)
def test_execute_mutations(): def test_execute_mutations():

View File

@ -21,7 +21,7 @@ class MyNode(relay.Node):
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
my_node = relay.NodeField(MyNode) my_node = relay.NodeField(MyNode)
all_my_nodes = relay.ConnectionField( all_my_nodes = relay.ConnectionField(
MyNode, connection_type=MyConnection, customArg=graphene.Argument(graphene.String)) MyNode, connection_type=MyConnection, customArg=graphene.String())
def resolve_all_my_nodes(self, args, info): def resolve_all_my_nodes(self, args, info):
custom_arg = args.get('customArg') custom_arg = args.get('customArg')
@ -73,5 +73,6 @@ def test_nodefield_query():
def test_nodeidfield(): def test_nodeidfield():
id_field = MyNode._meta.fields_map['id'] id_field = MyNode._meta.fields_map['id']
assert isinstance(id_field.internal_field(schema).type, GraphQLNonNull) id_field_type = schema.T(id_field)
assert id_field.internal_field(schema).type.of_type == GraphQLID assert isinstance(id_field_type.type, GraphQLNonNull)
assert id_field_type.type.of_type == GraphQLID