diff --git a/graphene/types/objecttype.py b/graphene/types/objecttype.py index f4a0f5a0..c4d969b4 100644 --- a/graphene/types/objecttype.py +++ b/graphene/types/objecttype.py @@ -1,3 +1,5 @@ +import inspect + from .base import BaseOptions, BaseType, BaseTypeMeta from .field import Field from .interface import Interface @@ -137,7 +139,7 @@ class ObjectType(BaseType, metaclass=ObjectTypeMeta): fields = {} for interface in interfaces: - assert issubclass( + assert inspect.isclass(interface) and issubclass( interface, Interface ), f'All interfaces of {cls.__name__} must be a subclass of Interface. Received "{interface}".' fields.update(interface._meta.fields) diff --git a/graphene/types/schema.py b/graphene/types/schema.py index 2d1eaf6b..f40975b6 100644 --- a/graphene/types/schema.py +++ b/graphene/types/schema.py @@ -258,7 +258,8 @@ class TypeMap(dict): union_types = [] for graphene_objecttype in graphene_type._meta.types: object_type = create_graphql_type(graphene_objecttype) - assert object_type.graphene_type == graphene_objecttype + if hasattr(object_type, "graphene_type"): + assert object_type.graphene_type == graphene_objecttype union_types.append(object_type) return union_types diff --git a/graphene/types/tests/test_type_map.py b/graphene/types/tests/test_type_map.py index 41327211..25974ca4 100644 --- a/graphene/types/tests/test_type_map.py +++ b/graphene/types/tests/test_type_map.py @@ -1,5 +1,6 @@ from textwrap import dedent +import pytest from graphql.type import ( GraphQLArgument, GraphQLEnumType, @@ -22,6 +23,7 @@ from ..objecttype import ObjectType from ..scalars import Int, String from ..structures import List, NonNull from ..schema import Schema +from ..union import Union def create_type_map(types, auto_camelcase=True): @@ -313,3 +315,80 @@ def test_graphql_type(): ) assert not results.errors assert results.data == {"graphqlType": {"hello": "world"}} + + +def test_graphql_type_interface(): + MyGraphQLInterface = GraphQLInterfaceType( + name="MyGraphQLType", + fields={ + "hello": GraphQLField(GraphQLString, resolve=lambda obj, info: "world") + }, + ) + + with pytest.raises(AssertionError) as error: + + class MyGrapheneType(ObjectType): + class Meta: + interfaces = (MyGraphQLInterface,) + + assert str(error.value) == ( + "All interfaces of MyGrapheneType must be a subclass of Interface. " + 'Received "MyGraphQLType".' + ) + + +def test_graphql_type_union(): + MyGraphQLType = GraphQLObjectType( + name="MyGraphQLType", + fields={ + "hello": GraphQLField(GraphQLString, resolve=lambda obj, info: "world") + }, + ) + + class MyGrapheneType(ObjectType): + hi = String(default_value="world") + + class MyUnion(Union): + class Meta: + types = (MyGraphQLType, MyGrapheneType) + + @classmethod + def resolve_type(cls, instance, info): + return MyGraphQLType + + class Query(ObjectType): + my_union = Field(MyUnion) + + def resolve_my_union(root, info): + return {} + + schema = Schema(query=Query) + assert str(schema) == dedent( + """\ + type Query { + myUnion: MyUnion + } + + union MyUnion = MyGraphQLType | MyGrapheneType + + type MyGraphQLType { + hello: String + } + + type MyGrapheneType { + hi: String + } + """ + ) + + results = schema.execute( + """ + query { + myUnion { + __typename + } + } + """ + ) + assert not results.errors + assert results.data == {"myUnion": {"__typename": "MyGraphQLType"}}