diff --git a/graphene/relay/connection.py b/graphene/relay/connection.py index f85b675f..3e7950bb 100644 --- a/graphene/relay/connection.py +++ b/graphene/relay/connection.py @@ -12,6 +12,7 @@ from ..types import (AbstractType, Boolean, Enum, Int, Interface, List, NonNull, from ..types.field import Field from ..types.objecttype import ObjectType, ObjectTypeMeta from ..types.options import Options +from ..types import Union from ..utils.is_base_type import is_base_type from ..utils.props import props from .node import is_node @@ -109,7 +110,7 @@ class IterableConnectionField(Field): @property def type(self): type = super(IterableConnectionField, self).type - if is_node(type): + if issubclass(type, Union) or is_node(type): connection_type = type.Connection else: connection_type = type diff --git a/graphene/tests/issues/test_356.py b/graphene/tests/issues/test_356.py deleted file mode 100644 index 605594e1..00000000 --- a/graphene/tests/issues/test_356.py +++ /dev/null @@ -1,24 +0,0 @@ -# https://github.com/graphql-python/graphene/issues/356 - -import pytest -import graphene -from graphene import relay - -class SomeTypeOne(graphene.ObjectType): - pass - -class SomeTypeTwo(graphene.ObjectType): - pass - -class MyUnion(graphene.Union): - class Meta: - types = (SomeTypeOne, SomeTypeTwo) - -def test_issue(): - with pytest.raises(Exception) as exc_info: - class Query(graphene.ObjectType): - things = relay.ConnectionField(MyUnion) - - schema = graphene.Schema(query=Query) - - assert str(exc_info.value) == 'IterableConnectionField type have to be a subclass of Connection. Received "MyUnion".' diff --git a/graphene/types/tests/test_union.py b/graphene/types/tests/test_union.py index c6e6825c..3b07cebf 100644 --- a/graphene/types/tests/test_union.py +++ b/graphene/types/tests/test_union.py @@ -1,7 +1,9 @@ import pytest from ..objecttype import ObjectType +from ..schema import Schema from ..union import Union +from graphene.relay.connection import ConnectionField class MyObjectType1(ObjectType): @@ -41,3 +43,33 @@ def test_generate_union_with_no_types(): pass assert str(exc_info.value) == 'Must provide types for Union MyUnion.' + + +def test_union_as_connection(): + class MyUnion(Union): + class Meta: + types = (MyObjectType1, MyObjectType2) + + class Query(ObjectType): + objects = ConnectionField(MyUnion) + + def resolve_objects(self, args, context, info): + return [MyObjectType1(), MyObjectType2()] + + query = ''' + query { + objects { + edges { + node { + __typename + } + } + } + } + ''' + schema = Schema(query=Query) + result = schema.execute(query) + assert not result.errors + assert len(result.data['objects']['edges']) == 2 + assert result.data['objects']['edges'][0]['node']['__typename'] == 'MyObjectType1' + assert result.data['objects']['edges'][1]['node']['__typename'] == 'MyObjectType2' diff --git a/graphene/types/union.py b/graphene/types/union.py index 3d236000..20cb38c9 100644 --- a/graphene/types/union.py +++ b/graphene/types/union.py @@ -1,9 +1,19 @@ import six +from functools import partial from ..utils.is_base_type import is_base_type from .options import Options +def get_default_connection(cls): + from graphene.relay.connection import Connection + + class Meta: + node = cls + + return type('{}Connection'.format(cls.__name__), (Connection,), {'Meta': Meta}) + + class UnionMeta(type): def __new__(cls, name, bases, attrs): @@ -24,7 +34,15 @@ class UnionMeta(type): len(options.types) > 0 ), 'Must provide types for Union {}.'.format(options.name) - return type.__new__(cls, name, bases, dict(attrs, _meta=options)) + cls = type.__new__(cls, name, bases, dict(attrs, _meta=options)) + + get_connection = getattr(cls, 'get_connection', None) + if not get_connection: + get_connection = partial(get_default_connection, cls) + + cls.Connection = get_connection() + + return cls def __str__(cls): # noqa: N805 return cls._meta.name