diff --git a/graphene/relay/mutation.py b/graphene/relay/mutation.py index 8db462d6..ab1e9eb4 100644 --- a/graphene/relay/mutation.py +++ b/graphene/relay/mutation.py @@ -5,14 +5,14 @@ import six from promise import Promise -from ..types import AbstractType, Argument, Field, InputObjectType, String -from ..types.objecttype import ObjectType, ObjectTypeMeta +from ..types import Field, AbstractType, Argument, InputObjectType, String +from ..types.mutation import Mutation, MutationMeta +from ..types.objecttype import ObjectTypeMeta from ..utils.is_base_type import is_base_type from ..utils.props import props -class ClientIDMutationMeta(ObjectTypeMeta): - +class ClientIDMutationMeta(MutationMeta): def __new__(cls, name, bases, attrs): # Also ensure initialization is only performed for subclasses of # Mutation @@ -23,13 +23,13 @@ class ClientIDMutationMeta(ObjectTypeMeta): base_name = re.sub('Payload$', '', name) if 'client_mutation_id' not in attrs: attrs['client_mutation_id'] = String(name='clientMutationId') - cls = ObjectTypeMeta.__new__(cls, '{}Payload'.format(base_name), bases, attrs) + cls = ObjectTypeMeta.__new__(cls, '{}Payload'.format(base_name), bases, + attrs) mutate_and_get_payload = getattr(cls, 'mutate_and_get_payload', None) if cls.mutate and cls.mutate.__func__ == ClientIDMutation.mutate.__func__: assert mutate_and_get_payload, ( "{}.mutate_and_get_payload method is required" - " in a ClientIDMutation." - ).format(name) + " in a ClientIDMutation.").format(name) input_attrs = {} bases = () if not input_class: @@ -39,13 +39,18 @@ class ClientIDMutationMeta(ObjectTypeMeta): else: bases += (input_class, ) input_attrs['client_mutation_id'] = String(name='clientMutationId') - cls.Input = type('{}Input'.format(base_name), bases + (InputObjectType,), input_attrs) - cls.Field = partial(Field, cls, resolver=cls.mutate, input=Argument(cls.Input, required=True)) + cls.Input = type('{}Input'.format(base_name), + bases + (InputObjectType, ), input_attrs) + output_class = getattr(cls, 'Output', cls) + cls.Field = partial( + Field, + output_class, + resolver=cls.mutate, + input=Argument(cls.Input, required=True)) return cls -class ClientIDMutation(six.with_metaclass(ClientIDMutationMeta, ObjectType)): - +class ClientIDMutation(six.with_metaclass(ClientIDMutationMeta, Mutation)): @classmethod def mutate(cls, root, args, context, info): input = args.get('input') @@ -54,11 +59,10 @@ class ClientIDMutation(six.with_metaclass(ClientIDMutationMeta, ObjectType)): try: payload.client_mutation_id = input.get('clientMutationId') except: - raise Exception(( - 'Cannot set client_mutation_id in the payload object {}' - ).format(repr(payload))) + raise Exception( + ('Cannot set client_mutation_id in the payload object {}' + ).format(repr(payload))) return payload return Promise.resolve( - cls.mutate_and_get_payload(input, context, info) - ).then(on_resolve) + cls.mutate_and_get_payload(input, context, info)).then(on_resolve) diff --git a/graphene/types/mutation.py b/graphene/types/mutation.py index f6f5b19b..2e8e120a 100644 --- a/graphene/types/mutation.py +++ b/graphene/types/mutation.py @@ -10,7 +10,6 @@ from .objecttype import ObjectType, ObjectTypeMeta class MutationMeta(ObjectTypeMeta): - def __new__(cls, name, bases, attrs): # Also ensure initialization is only performed for subclasses of # Mutation @@ -21,10 +20,12 @@ class MutationMeta(ObjectTypeMeta): cls = ObjectTypeMeta.__new__(cls, name, bases, attrs) field_args = props(input_class) if input_class else {} + output_class = getattr(cls, 'Output', cls) resolver = getattr(cls, 'mutate', None) assert resolver, 'All mutations must define a mutate method in it' resolver = get_unbound_function(resolver) - cls.Field = partial(Field, cls, args=field_args, resolver=resolver) + cls.Field = partial( + Field, output_class, args=field_args, resolver=resolver) return cls diff --git a/graphene/types/tests/test_mutation.py b/graphene/types/tests/test_mutation.py index 8ff8773f..0f6d8900 100644 --- a/graphene/types/tests/test_mutation.py +++ b/graphene/types/tests/test_mutation.py @@ -3,6 +3,7 @@ import pytest from ..mutation import Mutation from ..objecttype import ObjectType from ..schema import Schema +from ..argument import Argument from ..scalars import String from ..dynamic import Dynamic @@ -10,6 +11,7 @@ from ..dynamic import Dynamic def test_generate_mutation_no_args(): class MyMutation(Mutation): '''Documentation''' + @classmethod def mutate(cls, *args, **kwargs): pass @@ -22,7 +24,6 @@ def test_generate_mutation_no_args(): def test_generate_mutation_with_meta(): class MyMutation(Mutation): - class Meta: name = 'MyOtherMutation' description = 'Documentation' @@ -38,10 +39,33 @@ def test_generate_mutation_with_meta(): def test_mutation_raises_exception_if_no_mutate(): with pytest.raises(AssertionError) as excinfo: + class MyMutation(Mutation): pass - assert "All mutations must define a mutate method in it" == str(excinfo.value) + assert "All mutations must define a mutate method in it" == str( + excinfo.value) + + +def test_mutation_custom_output_type(): + class User(ObjectType): + name = String() + + class CreateUser(Mutation): + class Input: + name = String() + + Output = User + + @classmethod + def mutate(cls, args, context, info): + name = args.get('name') + return User(name=name) + + field = CreateUser.Field() + assert field.type == User + assert field.args == {'name': Argument(String)} + assert field.resolver == CreateUser.mutate def test_mutation_execution():