This commit is contained in:
Tony Angerilli 2017-01-11 14:02:54 +00:00 committed by GitHub
commit 54bc03f93e
4 changed files with 53 additions and 26 deletions

View File

@ -12,6 +12,7 @@ from ..types import (AbstractType, Boolean, Enum, Int, Interface, List, NonNull,
from ..types.field import Field
from ..types.objecttype import ObjectType, ObjectTypeMeta
from ..types.options import Options
from ..types import Union
from ..utils.is_base_type import is_base_type
from ..utils.props import props
from .node import is_node
@ -109,7 +110,7 @@ class IterableConnectionField(Field):
@property
def type(self):
type = super(IterableConnectionField, self).type
if is_node(type):
if issubclass(type, Union) or is_node(type):
connection_type = type.Connection
else:
connection_type = type

View File

@ -1,24 +0,0 @@
# https://github.com/graphql-python/graphene/issues/356
import pytest
import graphene
from graphene import relay
class SomeTypeOne(graphene.ObjectType):
pass
class SomeTypeTwo(graphene.ObjectType):
pass
class MyUnion(graphene.Union):
class Meta:
types = (SomeTypeOne, SomeTypeTwo)
def test_issue():
with pytest.raises(Exception) as exc_info:
class Query(graphene.ObjectType):
things = relay.ConnectionField(MyUnion)
schema = graphene.Schema(query=Query)
assert str(exc_info.value) == 'IterableConnectionField type have to be a subclass of Connection. Received "MyUnion".'

View File

@ -1,7 +1,9 @@
import pytest
from ..objecttype import ObjectType
from ..schema import Schema
from ..union import Union
from graphene.relay.connection import ConnectionField
class MyObjectType1(ObjectType):
@ -41,3 +43,33 @@ def test_generate_union_with_no_types():
pass
assert str(exc_info.value) == 'Must provide types for Union MyUnion.'
def test_union_as_connection():
class MyUnion(Union):
class Meta:
types = (MyObjectType1, MyObjectType2)
class Query(ObjectType):
objects = ConnectionField(MyUnion)
def resolve_objects(self, args, context, info):
return [MyObjectType1(), MyObjectType2()]
query = '''
query {
objects {
edges {
node {
__typename
}
}
}
}
'''
schema = Schema(query=Query)
result = schema.execute(query)
assert not result.errors
assert len(result.data['objects']['edges']) == 2
assert result.data['objects']['edges'][0]['node']['__typename'] == 'MyObjectType1'
assert result.data['objects']['edges'][1]['node']['__typename'] == 'MyObjectType2'

View File

@ -1,9 +1,19 @@
import six
from functools import partial
from ..utils.is_base_type import is_base_type
from .options import Options
def get_default_connection(cls):
from graphene.relay.connection import Connection
class Meta:
node = cls
return type('{}Connection'.format(cls.__name__), (Connection,), {'Meta': Meta})
class UnionMeta(type):
def __new__(cls, name, bases, attrs):
@ -24,7 +34,15 @@ class UnionMeta(type):
len(options.types) > 0
), 'Must provide types for Union {}.'.format(options.name)
return type.__new__(cls, name, bases, dict(attrs, _meta=options))
cls = type.__new__(cls, name, bases, dict(attrs, _meta=options))
get_connection = getattr(cls, 'get_connection', None)
if not get_connection:
get_connection = partial(get_default_connection, cls)
cls.Connection = get_connection()
return cls
def __str__(cls): # noqa: N805
return cls._meta.name