Support with_context on ClientIDMutation.mutate_and_get_payload.

This commit is contained in:
Marc Tamlyn 2016-05-18 11:50:03 +01:00
parent 398088a0c4
commit 61e7beee7b
2 changed files with 61 additions and 4 deletions

View File

@ -1,10 +1,11 @@
from graphql.type import GraphQLInputObjectField from graphql.type import GraphQLInputObjectField
import graphene import graphene
from graphene import relay from graphene import relay, with_context
from graphene.core.schema import Schema from graphene.core.schema import Schema
my_id = 0 my_id = 0
my_id_context = 0
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
@ -25,8 +26,24 @@ class ChangeNumber(relay.ClientIDMutation):
return ChangeNumber(result=my_id) return ChangeNumber(result=my_id)
class ChangeNumberContext(relay.ClientIDMutation):
'''Result mutation'''
class Input:
to = graphene.Int()
result = graphene.String()
@classmethod
@with_context
def mutate_and_get_payload(cls, input, context, info):
global my_id_context
my_id_context = input.get('to', my_id_context + context)
return ChangeNumber(result=my_id_context)
class MyResultMutation(graphene.ObjectType): class MyResultMutation(graphene.ObjectType):
change_number = graphene.Field(ChangeNumber) change_number = graphene.Field(ChangeNumber)
change_number_context = graphene.Field(ChangeNumberContext)
schema = Schema(query=Query, mutation=MyResultMutation) schema = Schema(query=Query, mutation=MyResultMutation)
@ -79,3 +96,39 @@ def test_execute_mutations():
result = schema.execute(query, root_value=object()) result = schema.execute(query, root_value=object())
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
def test_context_mutations():
query = '''
mutation M{
first: changeNumberContext(input: {clientMutationId: "mutation1"}) {
clientMutationId
result
},
second: changeNumberContext(input: {clientMutationId: "mutation2"}) {
clientMutationId
result
}
third: changeNumberContext(input: {clientMutationId: "mutation3", to: 5}) {
result
clientMutationId
}
}
'''
expected = {
'first': {
'clientMutationId': 'mutation1',
'result': '-1',
},
'second': {
'clientMutationId': 'mutation2',
'result': '-2',
},
'third': {
'clientMutationId': 'mutation3',
'result': '5',
}
}
result = schema.execute(query, root_value=object(), context_value=-1)
assert not result.errors
assert result.data == expected

View File

@ -15,7 +15,7 @@ from ..core.types import Boolean, Field, List, String
from ..core.types.argument import ArgumentsGroup from ..core.types.argument import ArgumentsGroup
from ..core.types.definitions import NonNull from ..core.types.definitions import NonNull
from ..utils import memoize from ..utils import memoize
from ..utils.wrap_resolver_function import has_context from ..utils.wrap_resolver_function import has_context, with_context
from .fields import GlobalIDField from .fields import GlobalIDField
@ -192,9 +192,13 @@ class ClientIDMutation(six.with_metaclass(RelayMutationMeta, Mutation)):
abstract = True abstract = True
@classmethod @classmethod
def mutate(cls, instance, args, info): @with_context
def mutate(cls, instance, args, context, info):
input = args.get('input') input = args.get('input')
payload = cls.mutate_and_get_payload(input, info) if has_context(cls.mutate_and_get_payload):
payload = cls.mutate_and_get_payload(input, context, info)
else:
payload = cls.mutate_and_get_payload(input, info)
client_mutation_id = input.get('clientMutationId') or input.get('client_mutation_id') client_mutation_id = input.get('clientMutationId') or input.get('client_mutation_id')
setattr(payload, 'clientMutationId', client_mutation_id) setattr(payload, 'clientMutationId', client_mutation_id)
return payload return payload