From 9c27db7ed55f90a4d348ca76943c8fa6eb0eb2f7 Mon Sep 17 00:00:00 2001 From: Nathaniel Parrish Date: Tue, 7 Nov 2017 09:06:36 -0800 Subject: [PATCH] Handle complex input types --- graphene/types/inputobjecttype.py | 25 +++++++++++++++++++++++-- graphene/types/tests/test_typemap.py | 24 +++++++++++++++++++++--- 2 files changed, 44 insertions(+), 5 deletions(-) diff --git a/graphene/types/inputobjecttype.py b/graphene/types/inputobjecttype.py index 9b9b646b..a162ca3c 100644 --- a/graphene/types/inputobjecttype.py +++ b/graphene/types/inputobjecttype.py @@ -1,11 +1,14 @@ from collections import OrderedDict from .base import BaseOptions, BaseType +from .field import Field from .inputfield import InputField +from .objecttype import ObjectType +from .scalars import Scalar +from .structures import List, NonNull from .unmountedtype import UnmountedType from .utils import yank_fields_from_attrs - # For static type checking with Mypy MYPY = False if MYPY: @@ -24,11 +27,29 @@ class InputObjectTypeContainer(dict, BaseType): def __init__(self, *args, **kwargs): dict.__init__(self, *args, **kwargs) for key in self._meta.fields.keys(): - setattr(self, key, self.get(key, None)) + field = getattr(self, key, None) + if field is None or self.get(key, None) is None: + value = None + else: + value = InputObjectTypeContainer._get_typed_field_value(field, self[key]) + setattr(self, key, value) def __init_subclass__(cls, *args, **kwargs): pass + @staticmethod + def _get_typed_field_value(field_or_type, value): + if isinstance(field_or_type, NonNull): + return InputObjectTypeContainer._get_typed_field_value(field_or_type.of_type, value) + elif isinstance(field_or_type, List): + return [ + InputObjectTypeContainer._get_typed_field_value(field_or_type.of_type, v) + for v in value + ] + elif hasattr(field_or_type, '_meta') and hasattr(field_or_type._meta, 'container'): + return field_or_type._meta.container(value) + else: + return value class InputObjectType(UnmountedType, BaseType): ''' diff --git a/graphene/types/tests/test_typemap.py b/graphene/types/tests/test_typemap.py index 082f25bd..fe8e99b6 100644 --- a/graphene/types/tests/test_typemap.py +++ b/graphene/types/tests/test_typemap.py @@ -4,6 +4,7 @@ from graphql.type import (GraphQLArgument, GraphQLEnumType, GraphQLEnumValue, GraphQLInputObjectType, GraphQLInterfaceType, GraphQLObjectType, GraphQLString) +from ..structures import List, NonNull from ..dynamic import Dynamic from ..enum import Enum from ..field import Field @@ -11,7 +12,7 @@ from ..inputfield import InputField from ..inputobjecttype import InputObjectType from ..interface import Interface from ..objecttype import ObjectType -from ..scalars import String +from ..scalars import String, Int from ..typemap import TypeMap @@ -119,10 +120,18 @@ def test_interface(): def test_inputobject(): + class OtherObjectType(InputObjectType): + thingy = NonNull(Int) + + class MyInnerObjectType(InputObjectType): + some_field = String() + some_other_field = List(OtherObjectType) + class MyInputObjectType(InputObjectType): '''Description''' foo_bar = String(description='Field description') bar = String(name='gizmo') + baz = NonNull(MyInnerObjectType) own = InputField(lambda: MyInputObjectType) def resolve_foo_bar(self, args, info): @@ -136,14 +145,23 @@ def test_inputobject(): assert graphql_type.description == 'Description' # Container - container = graphql_type.create_container({'bar': 'oh!'}) + container = graphql_type.create_container({ + 'bar': 'oh!', + 'baz': { + 'some_other_field': [{'thingy': 1}, {'thingy': 2}] + } + }) assert isinstance(container, MyInputObjectType) assert 'bar' in container assert container.bar == 'oh!' assert 'foo_bar' not in container + assert container.foo_bar is None + assert container.baz.some_field is None + assert container.baz.some_other_field[0].thingy == 1 + assert container.baz.some_other_field[1].thingy == 2 fields = graphql_type.fields - assert list(fields.keys()) == ['fooBar', 'gizmo', 'own'] + assert list(fields.keys()) == ['fooBar', 'gizmo', 'baz', 'own'] own_field = fields['own'] assert own_field.type == graphql_type foo_field = fields['fooBar']