Improved Mutations Input args

This commit is contained in:
Syrus Akbary 2015-10-25 22:30:35 -07:00
parent 2648a2300b
commit bc3d176b4e
3 changed files with 34 additions and 5 deletions

View File

@ -124,6 +124,13 @@ class Field(object):
','.join(extra_args.keys()) ','.join(extra_args.keys())
)) ))
args = self.args
object_type = self.get_object_type(schema)
if object_type and object_type._meta.mutation:
assert not self.args, 'Arguments provided for mutations are defined in Input class in Mutation'
args = object_type.input_type.fields_as_arguments(schema)
internal_type = self.internal_type(schema) internal_type = self.internal_type(schema)
if not internal_type: if not internal_type:
raise Exception("Internal type for field %s is None" % self) raise Exception("Internal type for field %s is None" % self)
@ -141,7 +148,7 @@ class Field(object):
return GraphQLField( return GraphQLField(
internal_type, internal_type,
description=description, description=description,
args=self.args, args=args,
resolver=resolver, resolver=resolver,
) )

View File

@ -5,7 +5,8 @@ from collections import OrderedDict
from graphql.core.type import ( from graphql.core.type import (
GraphQLObjectType, GraphQLObjectType,
GraphQLInterfaceType GraphQLInterfaceType,
GraphQLArgument
) )
from graphene import signals from graphene import signals
@ -58,6 +59,10 @@ class ObjectTypeMeta(type):
if new_class._meta.mutation: if new_class._meta.mutation:
assert hasattr(new_class, 'mutate'), "All mutations must implement mutate method" assert hasattr(new_class, 'mutate'), "All mutations must implement mutate method"
Input = getattr(new_class, 'Input', None)
if Input:
input_type = type('{}Input'.format(new_class._meta.type_name), (Input, ObjectType), Input.__dict__)
setattr(new_class, 'input_type', input_type)
new_class.add_extra_fields() new_class.add_extra_fields()
@ -141,6 +146,11 @@ class BaseObjectType(object):
if self.instance: if self.instance:
return getattr(self.instance, name) return getattr(self.instance, name)
@classmethod
def fields_as_arguments(cls, schema):
return OrderedDict([(f.field_name, GraphQLArgument(f.internal_type(schema)))
for f in cls._meta.fields])
@classmethod @classmethod
def resolve_objecttype(cls, schema, instance, *_): def resolve_objecttype(cls, schema, instance, *_):
return instance return instance

View File

@ -12,14 +12,14 @@ class Query(graphene.ObjectType):
class ChangeNumber(graphene.Mutation): class ChangeNumber(graphene.Mutation):
'''Result mutation''' '''Result mutation'''
class Input: class Input:
id = graphene.IntField(required=True) to = graphene.IntField()
result = graphene.StringField() result = graphene.StringField()
@classmethod @classmethod
def mutate(cls, instance, args, info): def mutate(cls, instance, args, info):
global my_id global my_id
my_id = my_id + 1 my_id = args.get('to', my_id + 1)
return ChangeNumber(result=my_id) return ChangeNumber(result=my_id)
@ -30,7 +30,13 @@ class MyResultMutation(graphene.ObjectType):
schema = Schema(query=Query, mutation=MyResultMutation) schema = Schema(query=Query, mutation=MyResultMutation)
def test_mutate(): def test_mutation_input():
assert ChangeNumber.input_type
assert ChangeNumber.input_type._meta.type_name == 'ChangeNumberInput'
assert list(ChangeNumber.input_type._meta.fields_map.keys()) == ['to']
def test_execute_mutations():
query = ''' query = '''
mutation M{ mutation M{
first: changeNumber { first: changeNumber {
@ -39,6 +45,9 @@ def test_mutate():
second: changeNumber { second: changeNumber {
result result
} }
third: changeNumber(to: 5) {
result
}
} }
''' '''
expected = { expected = {
@ -47,6 +56,9 @@ def test_mutate():
}, },
'second': { 'second': {
'result': '2', 'result': '2',
},
'third': {
'result': '5',
} }
} }
result = schema.execute(query, root=object()) result = schema.execute(query, root=object())