Fixed get_resolver args, relay client_id_mutation ad allow GlobalID in mutation.

This commit is contained in:
Markus Padourek 2016-09-12 11:47:21 +01:00
parent 513b3e46c3
commit 4b6066fd5b
9 changed files with 43 additions and 24 deletions

View File

@ -46,7 +46,7 @@ class DjangoConnectionField(ConnectionField):
connection.length = _len
return connection
def get_resolver(self, parent_resolver, _):
def get_resolver(self, parent_resolver):
return partial(self.connection_resolver, parent_resolver, self.type, self.get_manager())

View File

@ -34,6 +34,6 @@ class DjangoFilterConnectionField(DjangoConnectionField):
return DjangoConnectionField.connection_resolver(resolver, connection, qs, root, args, context, info)
def get_resolver(self, parent_resolver, _):
def get_resolver(self, parent_resolver):
return partial(self.connection_resolver, parent_resolver, self.type, self.get_manager(),
self.filterset_class, self.filtering_args)

View File

@ -33,5 +33,5 @@ class SQLAlchemyConnectionField(ConnectionField):
edge_type=connection.Edge,
)
def get_resolver(self, parent_resolver, _):
def get_resolver(self, parent_resolver):
return partial(self.connection_resolver, parent_resolver, self.type, self.model)

View File

@ -133,8 +133,8 @@ class IterableConnectionField(Field):
connection.iterable = iterable
return connection
def get_resolver(self, parent_resolver, _):
resolver = super(IterableConnectionField, self).get_resolver(parent_resolver, None)
def get_resolver(self, parent_resolver):
resolver = super(IterableConnectionField, self).get_resolver(parent_resolver)
return partial(self.connection_resolver, resolver, self.type)
ConnectionField = IterableConnectionField

View File

@ -20,6 +20,8 @@ class ClientIDMutationMeta(ObjectTypeMeta):
input_class = attrs.pop('Input', None)
base_name = re.sub('Payload$', '', name)
default_client_mutation_id = String(name='clientMutationId')
attrs['client_mutation_id'] = attrs.get('client_mutation_id', default_client_mutation_id)
cls = ObjectTypeMeta.__new__(cls, '{}Payload'.format(base_name), bases, attrs)
mutate_and_get_payload = getattr(cls, 'mutate_and_get_payload', None)
if cls.mutate and cls.mutate.__func__ == ClientIDMutation.mutate.__func__:
@ -35,7 +37,7 @@ class ClientIDMutationMeta(ObjectTypeMeta):
input_attrs = props(input_class)
else:
bases += (input_class, )
input_attrs['client_mutation_id'] = String(name='clientMutationId')
input_attrs['client_mutation_id'] = default_client_mutation_id
cls.Input = type('{}Input'.format(base_name), bases + (InputObjectType,), input_attrs)
cls.Field = partial(Field, cls, resolver=cls.mutate, input=Argument(cls.Input, required=True))
return cls

View File

@ -35,17 +35,19 @@ def get_default_connection(cls):
class GlobalID(Field):
def __init__(self, node=None, required=True, *args, **kwargs):
def __init__(self, node=None, parent_type=None, required=True, *args, **kwargs):
super(GlobalID, self).__init__(ID, required=required, *args, **kwargs)
self.node = node
self._node = node or Node
self._parent_type_name = parent_type._meta.name if parent_type else None
@staticmethod
def id_resolver(parent_resolver, node, root, args, context, info):
id = parent_resolver(root, args, context, info)
return node.to_global_id(info.parent_type.name, id) # root._meta.name
def id_resolver(parent_resolver, node, root, args, context, info, parent_type_name=None):
type_id = parent_resolver(root, args, context, info)
parent_type_name = parent_type_name or info.parent_type.name # root._meta.name
return node.to_global_id(parent_type_name, type_id)
def get_resolver(self, parent_resolver, parent_type):
return partial(self.id_resolver, parent_resolver, self.node or parent_type)
def get_resolver(self, parent_resolver):
return partial(self.id_resolver, parent_resolver, self._node, parent_type_name=self._parent_type_name)
class NodeMeta(InterfaceMeta):

View File

@ -1,25 +1,39 @@
from collections import OrderedDict
import pytest
from graphql_relay import to_global_id
from ...types import (Argument, Field, InputField, InputObjectType, ObjectType,
Schema, AbstractType, NonNull)
from ...types.scalars import String
from ..mutation import ClientIDMutation
from ..node import GlobalID, Node
class SharedFields(AbstractType):
shared = String()
class MyNode(ObjectType):
class Meta:
interfaces = (Node, )
name = String()
class SaySomething(ClientIDMutation):
class Input:
what = String()
phrase = String()
my_node_id = GlobalID(parent_type=MyNode)
@staticmethod
def mutate_and_get_payload(args, context, info):
what = args.get('what')
return SaySomething(phrase=str(what))
return SaySomething(phrase=str(what), my_node_id=1)
class OtherMutation(ClientIDMutation):
@ -58,8 +72,9 @@ def test_no_mutate_and_get_payload():
def test_mutation():
fields = SaySomething._meta.fields
assert list(fields.keys()) == ['phrase']
assert list(fields.keys()) == ['phrase', 'my_node_id', 'client_mutation_id']
assert isinstance(fields['phrase'], Field)
assert isinstance(fields['my_node_id'], GlobalID)
field = SaySomething.Field()
assert field.type == SaySomething
assert list(field.args.keys()) == ['input']
@ -81,7 +96,7 @@ def test_mutation_input():
def test_subclassed_mutation():
fields = OtherMutation._meta.fields
assert list(fields.keys()) == ['name']
assert list(fields.keys()) == ['name', 'client_mutation_id']
assert isinstance(fields['name'], Field)
field = OtherMutation.Field()
assert field.type == OtherMutation
@ -104,9 +119,9 @@ def test_subclassed_mutation_input():
assert fields['client_mutation_id'].type == String
# def test_node_query():
# executed = schema.execute(
# 'mutation a { say(input: {what:"hello", clientMutationId:"1"}) { phrase } }'
# )
# assert not executed.errors
# assert executed.data == {'say': {'phrase': 'hello'}}
def test_node_query():
executed = schema.execute(
'mutation a { say(input: {what:"hello", clientMutationId:"1"}) { phrase, clientMutationId, myNodeId} }'
)
assert not executed.errors
assert executed.data == OrderedDict({'say': OrderedDict({'phrase': 'hello', 'clientMutationId': '1', 'myNodeId': to_global_id('MyNode', '1')})})

View File

@ -45,5 +45,5 @@ class Field(OrderedType):
return self._type()
return self._type
def get_resolver(self, parent_resolver, _):
def get_resolver(self, parent_resolver):
return self.resolver or parent_resolver

View File

@ -219,7 +219,7 @@ class TypeMap(GraphQLTypeMap):
_field = GraphQLField(
field_type,
args=args,
resolver=field.get_resolver(self.get_resolver_for_type(type, name), type),
resolver=field.get_resolver(self.get_resolver_for_type(type, name)),
deprecation_reason=field.deprecation_reason,
description=field.description
)