From 6817761a0892f73f6b837115ff5ec0821392cf46 Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Tue, 15 Nov 2016 22:31:46 -0800 Subject: [PATCH] Fixed Union and Interface resolve_type when the field is a List/NonNull --- graphene/types/tests/test_query.py | 95 ++++++++++++++++++++++++++++++ graphene/types/typemap.py | 14 ++--- 2 files changed, 102 insertions(+), 7 deletions(-) diff --git a/graphene/types/tests/test_query.py b/graphene/types/tests/test_query.py index b1c5d112..daeb63e8 100644 --- a/graphene/types/tests/test_query.py +++ b/graphene/types/tests/test_query.py @@ -4,12 +4,14 @@ from functools import partial from graphql import Source, execute, parse, GraphQLError from ..field import Field +from ..interface import Interface from ..inputfield import InputField from ..inputobjecttype import InputObjectType from ..objecttype import ObjectType from ..scalars import Int, String from ..schema import Schema from ..structures import List +from ..union import Union from ..dynamic import Dynamic @@ -24,6 +26,99 @@ def test_query(): assert executed.data == {'hello': 'World'} +def test_query_union(): + class one_object(object): + pass + + class two_object(object): + pass + + class One(ObjectType): + one = String() + + @classmethod + def is_type_of(cls, root, context, info): + return isinstance(root, one_object) + + class Two(ObjectType): + two = String() + + @classmethod + def is_type_of(cls, root, context, info): + return isinstance(root, two_object) + + class MyUnion(Union): + class Meta: + types = (One, Two) + + class Query(ObjectType): + unions = List(MyUnion) + + def resolve_unions(self, args, context, info): + return [one_object(), two_object()] + + hello_schema = Schema(Query) + + executed = hello_schema.execute('{ unions { __typename } }') + assert not executed.errors + assert executed.data == { + 'unions': [{ + '__typename': 'One' + }, { + '__typename': 'Two' + }] + } + + +def test_query_interface(): + class one_object(object): + pass + + class two_object(object): + pass + + class MyInterface(Interface): + base = String() + + class One(ObjectType): + class Meta: + interfaces = (MyInterface, ) + + one = String() + + @classmethod + def is_type_of(cls, root, context, info): + return isinstance(root, one_object) + + class Two(ObjectType): + class Meta: + interfaces = (MyInterface, ) + + two = String() + + @classmethod + def is_type_of(cls, root, context, info): + return isinstance(root, two_object) + + class Query(ObjectType): + interfaces = List(MyInterface) + + def resolve_interfaces(self, args, context, info): + return [one_object(), two_object()] + + hello_schema = Schema(Query, types=[One, Two]) + + executed = hello_schema.execute('{ interfaces { __typename } }') + assert not executed.errors + assert executed.data == { + 'interfaces': [{ + '__typename': 'One' + }, { + '__typename': 'Two' + }] + } + + def test_query_dynamic(): class Query(ObjectType): hello = Dynamic(lambda: String(resolver=lambda *_: 'World')) diff --git a/graphene/types/typemap.py b/graphene/types/typemap.py index d7a7f23e..3aa80145 100644 --- a/graphene/types/typemap.py +++ b/graphene/types/typemap.py @@ -30,18 +30,18 @@ def is_graphene_type(_type): return True -def resolve_type(resolve_type_func, map, root, context, info): +def resolve_type(resolve_type_func, map, type_name, root, context, info): _type = resolve_type_func(root, context, info) - # assert inspect.isclass(_type) and issubclass(_type, ObjectType), ( - # 'Received incompatible type "{}".'.format(_type) - # ) + if not _type: - return get_default_resolve_type_fn(root, context, info, info.return_type) + return_type = map[type_name] + return get_default_resolve_type_fn(root, context, info, return_type) if inspect.isclass(_type) and issubclass(_type, ObjectType): graphql_type = map.get(_type._meta.name) assert graphql_type and graphql_type.graphene_type == _type return graphql_type + return _type @@ -151,7 +151,7 @@ class TypeMap(GraphQLTypeMap): from .definitions import GrapheneInterfaceType _resolve_type = None if type.resolve_type: - _resolve_type = partial(resolve_type, type.resolve_type, map) + _resolve_type = partial(resolve_type, type.resolve_type, map, type._meta.name) map[type._meta.name] = GrapheneInterfaceType( graphene_type=type, name=type._meta.name, @@ -178,7 +178,7 @@ class TypeMap(GraphQLTypeMap): from .definitions import GrapheneUnionType _resolve_type = None if type.resolve_type: - _resolve_type = partial(resolve_type, type.resolve_type, map) + _resolve_type = partial(resolve_type, type.resolve_type, map, type._meta.name) types = [] for i in type._meta.types: map = self.construct_objecttype(map, i)