Fixed get_resolver args, relay client_id_mutation ad allow GlobalID in mutation.

This commit is contained in:
Markus Padourek 2016-09-12 11:47:21 +01:00
parent 513b3e46c3
commit 4b6066fd5b
9 changed files with 43 additions and 24 deletions

View File

@ -46,7 +46,7 @@ class DjangoConnectionField(ConnectionField):
connection.length = _len connection.length = _len
return connection return connection
def get_resolver(self, parent_resolver, _): def get_resolver(self, parent_resolver):
return partial(self.connection_resolver, parent_resolver, self.type, self.get_manager()) return partial(self.connection_resolver, parent_resolver, self.type, self.get_manager())

View File

@ -34,6 +34,6 @@ class DjangoFilterConnectionField(DjangoConnectionField):
return DjangoConnectionField.connection_resolver(resolver, connection, qs, root, args, context, info) return DjangoConnectionField.connection_resolver(resolver, connection, qs, root, args, context, info)
def get_resolver(self, parent_resolver, _): def get_resolver(self, parent_resolver):
return partial(self.connection_resolver, parent_resolver, self.type, self.get_manager(), return partial(self.connection_resolver, parent_resolver, self.type, self.get_manager(),
self.filterset_class, self.filtering_args) self.filterset_class, self.filtering_args)

View File

@ -33,5 +33,5 @@ class SQLAlchemyConnectionField(ConnectionField):
edge_type=connection.Edge, edge_type=connection.Edge,
) )
def get_resolver(self, parent_resolver, _): def get_resolver(self, parent_resolver):
return partial(self.connection_resolver, parent_resolver, self.type, self.model) return partial(self.connection_resolver, parent_resolver, self.type, self.model)

View File

@ -133,8 +133,8 @@ class IterableConnectionField(Field):
connection.iterable = iterable connection.iterable = iterable
return connection return connection
def get_resolver(self, parent_resolver, _): def get_resolver(self, parent_resolver):
resolver = super(IterableConnectionField, self).get_resolver(parent_resolver, None) resolver = super(IterableConnectionField, self).get_resolver(parent_resolver)
return partial(self.connection_resolver, resolver, self.type) return partial(self.connection_resolver, resolver, self.type)
ConnectionField = IterableConnectionField ConnectionField = IterableConnectionField

View File

@ -20,6 +20,8 @@ class ClientIDMutationMeta(ObjectTypeMeta):
input_class = attrs.pop('Input', None) input_class = attrs.pop('Input', None)
base_name = re.sub('Payload$', '', name) base_name = re.sub('Payload$', '', name)
default_client_mutation_id = String(name='clientMutationId')
attrs['client_mutation_id'] = attrs.get('client_mutation_id', default_client_mutation_id)
cls = ObjectTypeMeta.__new__(cls, '{}Payload'.format(base_name), bases, attrs) cls = ObjectTypeMeta.__new__(cls, '{}Payload'.format(base_name), bases, attrs)
mutate_and_get_payload = getattr(cls, 'mutate_and_get_payload', None) mutate_and_get_payload = getattr(cls, 'mutate_and_get_payload', None)
if cls.mutate and cls.mutate.__func__ == ClientIDMutation.mutate.__func__: if cls.mutate and cls.mutate.__func__ == ClientIDMutation.mutate.__func__:
@ -35,7 +37,7 @@ class ClientIDMutationMeta(ObjectTypeMeta):
input_attrs = props(input_class) input_attrs = props(input_class)
else: else:
bases += (input_class, ) bases += (input_class, )
input_attrs['client_mutation_id'] = String(name='clientMutationId') input_attrs['client_mutation_id'] = default_client_mutation_id
cls.Input = type('{}Input'.format(base_name), bases + (InputObjectType,), input_attrs) cls.Input = type('{}Input'.format(base_name), bases + (InputObjectType,), input_attrs)
cls.Field = partial(Field, cls, resolver=cls.mutate, input=Argument(cls.Input, required=True)) cls.Field = partial(Field, cls, resolver=cls.mutate, input=Argument(cls.Input, required=True))
return cls return cls

View File

@ -35,17 +35,19 @@ def get_default_connection(cls):
class GlobalID(Field): class GlobalID(Field):
def __init__(self, node=None, required=True, *args, **kwargs): def __init__(self, node=None, parent_type=None, required=True, *args, **kwargs):
super(GlobalID, self).__init__(ID, required=required, *args, **kwargs) super(GlobalID, self).__init__(ID, required=required, *args, **kwargs)
self.node = node self._node = node or Node
self._parent_type_name = parent_type._meta.name if parent_type else None
@staticmethod @staticmethod
def id_resolver(parent_resolver, node, root, args, context, info): def id_resolver(parent_resolver, node, root, args, context, info, parent_type_name=None):
id = parent_resolver(root, args, context, info) type_id = parent_resolver(root, args, context, info)
return node.to_global_id(info.parent_type.name, id) # root._meta.name parent_type_name = parent_type_name or info.parent_type.name # root._meta.name
return node.to_global_id(parent_type_name, type_id)
def get_resolver(self, parent_resolver, parent_type): def get_resolver(self, parent_resolver):
return partial(self.id_resolver, parent_resolver, self.node or parent_type) return partial(self.id_resolver, parent_resolver, self._node, parent_type_name=self._parent_type_name)
class NodeMeta(InterfaceMeta): class NodeMeta(InterfaceMeta):

View File

@ -1,25 +1,39 @@
from collections import OrderedDict
import pytest import pytest
from graphql_relay import to_global_id
from ...types import (Argument, Field, InputField, InputObjectType, ObjectType, from ...types import (Argument, Field, InputField, InputObjectType, ObjectType,
Schema, AbstractType, NonNull) Schema, AbstractType, NonNull)
from ...types.scalars import String from ...types.scalars import String
from ..mutation import ClientIDMutation from ..mutation import ClientIDMutation
from ..node import GlobalID, Node
class SharedFields(AbstractType): class SharedFields(AbstractType):
shared = String() shared = String()
class MyNode(ObjectType):
class Meta:
interfaces = (Node, )
name = String()
class SaySomething(ClientIDMutation): class SaySomething(ClientIDMutation):
class Input: class Input:
what = String() what = String()
phrase = String() phrase = String()
my_node_id = GlobalID(parent_type=MyNode)
@staticmethod @staticmethod
def mutate_and_get_payload(args, context, info): def mutate_and_get_payload(args, context, info):
what = args.get('what') what = args.get('what')
return SaySomething(phrase=str(what)) return SaySomething(phrase=str(what), my_node_id=1)
class OtherMutation(ClientIDMutation): class OtherMutation(ClientIDMutation):
@ -58,8 +72,9 @@ def test_no_mutate_and_get_payload():
def test_mutation(): def test_mutation():
fields = SaySomething._meta.fields fields = SaySomething._meta.fields
assert list(fields.keys()) == ['phrase'] assert list(fields.keys()) == ['phrase', 'my_node_id', 'client_mutation_id']
assert isinstance(fields['phrase'], Field) assert isinstance(fields['phrase'], Field)
assert isinstance(fields['my_node_id'], GlobalID)
field = SaySomething.Field() field = SaySomething.Field()
assert field.type == SaySomething assert field.type == SaySomething
assert list(field.args.keys()) == ['input'] assert list(field.args.keys()) == ['input']
@ -81,7 +96,7 @@ def test_mutation_input():
def test_subclassed_mutation(): def test_subclassed_mutation():
fields = OtherMutation._meta.fields fields = OtherMutation._meta.fields
assert list(fields.keys()) == ['name'] assert list(fields.keys()) == ['name', 'client_mutation_id']
assert isinstance(fields['name'], Field) assert isinstance(fields['name'], Field)
field = OtherMutation.Field() field = OtherMutation.Field()
assert field.type == OtherMutation assert field.type == OtherMutation
@ -104,9 +119,9 @@ def test_subclassed_mutation_input():
assert fields['client_mutation_id'].type == String assert fields['client_mutation_id'].type == String
# def test_node_query(): def test_node_query():
# executed = schema.execute( executed = schema.execute(
# 'mutation a { say(input: {what:"hello", clientMutationId:"1"}) { phrase } }' 'mutation a { say(input: {what:"hello", clientMutationId:"1"}) { phrase, clientMutationId, myNodeId} }'
# ) )
# assert not executed.errors assert not executed.errors
# assert executed.data == {'say': {'phrase': 'hello'}} assert executed.data == OrderedDict({'say': OrderedDict({'phrase': 'hello', 'clientMutationId': '1', 'myNodeId': to_global_id('MyNode', '1')})})

View File

@ -45,5 +45,5 @@ class Field(OrderedType):
return self._type() return self._type()
return self._type return self._type
def get_resolver(self, parent_resolver, _): def get_resolver(self, parent_resolver):
return self.resolver or parent_resolver return self.resolver or parent_resolver

View File

@ -219,7 +219,7 @@ class TypeMap(GraphQLTypeMap):
_field = GraphQLField( _field = GraphQLField(
field_type, field_type,
args=args, args=args,
resolver=field.get_resolver(self.get_resolver_for_type(type, name), type), resolver=field.get_resolver(self.get_resolver_for_type(type, name)),
deprecation_reason=field.deprecation_reason, deprecation_reason=field.deprecation_reason,
description=field.description description=field.description
) )