diff --git a/graphene/core/fields.py b/graphene/core/fields.py index 677c7b65..6bf211e9 100644 --- a/graphene/core/fields.py +++ b/graphene/core/fields.py @@ -124,6 +124,13 @@ class Field(object): ','.join(extra_args.keys()) )) + args = self.args + + object_type = self.get_object_type(schema) + if object_type and object_type._meta.mutation: + assert not self.args, 'Arguments provided for mutations are defined in Input class in Mutation' + args = object_type.input_type.fields_as_arguments(schema) + internal_type = self.internal_type(schema) if not internal_type: raise Exception("Internal type for field %s is None" % self) @@ -141,7 +148,7 @@ class Field(object): return GraphQLField( internal_type, description=description, - args=self.args, + args=args, resolver=resolver, ) diff --git a/graphene/core/types.py b/graphene/core/types.py index 0de9ead4..34e17c0d 100644 --- a/graphene/core/types.py +++ b/graphene/core/types.py @@ -5,7 +5,8 @@ from collections import OrderedDict from graphql.core.type import ( GraphQLObjectType, - GraphQLInterfaceType + GraphQLInterfaceType, + GraphQLArgument ) from graphene import signals @@ -58,6 +59,10 @@ class ObjectTypeMeta(type): if new_class._meta.mutation: assert hasattr(new_class, 'mutate'), "All mutations must implement mutate method" + Input = getattr(new_class, 'Input', None) + if Input: + input_type = type('{}Input'.format(new_class._meta.type_name), (Input, ObjectType), Input.__dict__) + setattr(new_class, 'input_type', input_type) new_class.add_extra_fields() @@ -141,6 +146,11 @@ class BaseObjectType(object): if self.instance: return getattr(self.instance, name) + @classmethod + def fields_as_arguments(cls, schema): + return OrderedDict([(f.field_name, GraphQLArgument(f.internal_type(schema))) + for f in cls._meta.fields]) + @classmethod def resolve_objecttype(cls, schema, instance, *_): return instance diff --git a/tests/core/test_mutations.py b/tests/core/test_mutations.py index 25b94161..ad87a228 100644 --- a/tests/core/test_mutations.py +++ b/tests/core/test_mutations.py @@ -12,14 +12,14 @@ class Query(graphene.ObjectType): class ChangeNumber(graphene.Mutation): '''Result mutation''' class Input: - id = graphene.IntField(required=True) + to = graphene.IntField() result = graphene.StringField() @classmethod def mutate(cls, instance, args, info): global my_id - my_id = my_id + 1 + my_id = args.get('to', my_id + 1) return ChangeNumber(result=my_id) @@ -30,7 +30,13 @@ class MyResultMutation(graphene.ObjectType): schema = Schema(query=Query, mutation=MyResultMutation) -def test_mutate(): +def test_mutation_input(): + assert ChangeNumber.input_type + assert ChangeNumber.input_type._meta.type_name == 'ChangeNumberInput' + assert list(ChangeNumber.input_type._meta.fields_map.keys()) == ['to'] + + +def test_execute_mutations(): query = ''' mutation M{ first: changeNumber { @@ -39,6 +45,9 @@ def test_mutate(): second: changeNumber { result } + third: changeNumber(to: 5) { + result + } } ''' expected = { @@ -47,6 +56,9 @@ def test_mutate(): }, 'second': { 'result': '2', + }, + 'third': { + 'result': '5', } } result = schema.execute(query, root=object())