diff --git a/graphene/relay/tests/test_mutations.py b/graphene/relay/tests/test_mutations.py index d569bd7e..0874be65 100644 --- a/graphene/relay/tests/test_mutations.py +++ b/graphene/relay/tests/test_mutations.py @@ -1,10 +1,11 @@ from graphql.type import GraphQLInputObjectField import graphene -from graphene import relay +from graphene import relay, with_context from graphene.core.schema import Schema my_id = 0 +my_id_context = 0 class Query(graphene.ObjectType): @@ -25,8 +26,24 @@ class ChangeNumber(relay.ClientIDMutation): return ChangeNumber(result=my_id) +class ChangeNumberContext(relay.ClientIDMutation): + '''Result mutation''' + class Input: + to = graphene.Int() + + result = graphene.String() + + @classmethod + @with_context + def mutate_and_get_payload(cls, input, context, info): + global my_id_context + my_id_context = input.get('to', my_id_context + context) + return ChangeNumber(result=my_id_context) + + class MyResultMutation(graphene.ObjectType): change_number = graphene.Field(ChangeNumber) + change_number_context = graphene.Field(ChangeNumberContext) schema = Schema(query=Query, mutation=MyResultMutation) @@ -79,3 +96,39 @@ def test_execute_mutations(): result = schema.execute(query, root_value=object()) assert not result.errors assert result.data == expected + + +def test_context_mutations(): + query = ''' + mutation M{ + first: changeNumberContext(input: {clientMutationId: "mutation1"}) { + clientMutationId + result + }, + second: changeNumberContext(input: {clientMutationId: "mutation2"}) { + clientMutationId + result + } + third: changeNumberContext(input: {clientMutationId: "mutation3", to: 5}) { + result + clientMutationId + } + } + ''' + expected = { + 'first': { + 'clientMutationId': 'mutation1', + 'result': '-1', + }, + 'second': { + 'clientMutationId': 'mutation2', + 'result': '-2', + }, + 'third': { + 'clientMutationId': 'mutation3', + 'result': '5', + } + } + result = schema.execute(query, root_value=object(), context_value=-1) + assert not result.errors + assert result.data == expected diff --git a/graphene/relay/types.py b/graphene/relay/types.py index 6df9d7f6..29ba779a 100644 --- a/graphene/relay/types.py +++ b/graphene/relay/types.py @@ -15,7 +15,7 @@ from ..core.types import Boolean, Field, List, String from ..core.types.argument import ArgumentsGroup from ..core.types.definitions import NonNull from ..utils import memoize -from ..utils.wrap_resolver_function import has_context +from ..utils.wrap_resolver_function import has_context, with_context from .fields import GlobalIDField @@ -192,9 +192,13 @@ class ClientIDMutation(six.with_metaclass(RelayMutationMeta, Mutation)): abstract = True @classmethod - def mutate(cls, instance, args, info): + @with_context + def mutate(cls, instance, args, context, info): input = args.get('input') - payload = cls.mutate_and_get_payload(input, info) + if has_context(cls.mutate_and_get_payload): + payload = cls.mutate_and_get_payload(input, context, info) + else: + payload = cls.mutate_and_get_payload(input, info) client_mutation_id = input.get('clientMutationId') or input.get('client_mutation_id') setattr(payload, 'clientMutationId', client_mutation_id) return payload