diff --git a/graphene/types/tests/test_union.py b/graphene/types/tests/test_union.py index c6e6825c..d7ba2f31 100644 --- a/graphene/types/tests/test_union.py +++ b/graphene/types/tests/test_union.py @@ -1,7 +1,9 @@ import pytest +from ..field import Field from ..objecttype import ObjectType from ..union import Union +from ..unmountedtype import UnmountedType class MyObjectType1(ObjectType): @@ -41,3 +43,15 @@ def test_generate_union_with_no_types(): pass assert str(exc_info.value) == 'Must provide types for Union MyUnion.' + + +def test_union_can_be_mounted(): + class MyUnion(Union): + class Meta: + types = (MyObjectType1, MyObjectType2) + + my_union_instance = MyUnion() + assert isinstance(my_union_instance, UnmountedType) + my_union_field = my_union_instance.mount_as(Field) + assert isinstance(my_union_field, Field) + assert my_union_field.type == MyUnion diff --git a/graphene/types/union.py b/graphene/types/union.py index 3d236000..e36086d0 100644 --- a/graphene/types/union.py +++ b/graphene/types/union.py @@ -2,6 +2,7 @@ import six from ..utils.is_base_type import is_base_type from .options import Options +from .unmountedtype import UnmountedType class UnionMeta(type): @@ -30,7 +31,7 @@ class UnionMeta(type): return cls._meta.name -class Union(six.with_metaclass(UnionMeta)): +class Union(six.with_metaclass(UnionMeta, UnmountedType)): ''' Union Type Definition @@ -39,11 +40,16 @@ class Union(six.with_metaclass(UnionMeta)): to determine which type is actually used when the field is resolved. ''' + @classmethod + def get_type(cls): + ''' + This function is called when the unmounted type (Union instance) + is mounted (as a Field, InputField or Argument) + ''' + return cls + @classmethod def resolve_type(cls, instance, context, info): from .objecttype import ObjectType if isinstance(instance, ObjectType): return type(instance) - - def __init__(self, *args, **kwargs): - raise Exception("A Union cannot be intitialized")