From 03898007a52d1facb5cc20157544c0cc0f97a8fe Mon Sep 17 00:00:00 2001 From: Tony Angerilli Date: Thu, 10 Nov 2016 16:48:54 -0800 Subject: [PATCH] allow unions to be used in connections --- graphene/relay/connection.py | 3 ++- graphene/types/union.py | 16 +++++++++++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/graphene/relay/connection.py b/graphene/relay/connection.py index 46dbba98..1f332bdd 100644 --- a/graphene/relay/connection.py +++ b/graphene/relay/connection.py @@ -12,6 +12,7 @@ from ..types import (AbstractType, Boolean, Enum, Int, Interface, List, NonNull, from ..types.field import Field from ..types.objecttype import ObjectType, ObjectTypeMeta from ..types.options import Options +from ..types import Union from ..utils.is_base_type import is_base_type from ..utils.props import props from .node import is_node @@ -109,7 +110,7 @@ class IterableConnectionField(Field): @property def type(self): type = super(IterableConnectionField, self).type - if is_node(type): + if issubclass(type, Union) or is_node(type): connection_type = type.Connection else: connection_type = type diff --git a/graphene/types/union.py b/graphene/types/union.py index 622f465e..3b2b0d6b 100644 --- a/graphene/types/union.py +++ b/graphene/types/union.py @@ -24,7 +24,21 @@ class UnionMeta(type): len(options.types) > 0 ), 'Must provide types for Union {}.'.format(options.name) - return type.__new__(cls, name, bases, dict(attrs, _meta=options)) + cls = type.__new__(cls, name, bases, dict(attrs, _meta=options)) + + get_connection = getattr(cls, 'get_connection', None) + if not get_connection: + from graphene.relay.connection import Connection + + class DefaultUnionConnection(Connection): + class Meta: + node = cls + + cls.Connection = DefaultUnionConnection + else: + cls.Connection = get_connection() + + return cls def __str__(cls): # noqa: N805 return cls._meta.name