This commit is contained in:
Tony Angerilli 2017-01-11 14:02:54 +00:00 committed by GitHub
commit 54bc03f93e
4 changed files with 53 additions and 26 deletions

View File

@ -12,6 +12,7 @@ from ..types import (AbstractType, Boolean, Enum, Int, Interface, List, NonNull,
from ..types.field import Field from ..types.field import Field
from ..types.objecttype import ObjectType, ObjectTypeMeta from ..types.objecttype import ObjectType, ObjectTypeMeta
from ..types.options import Options from ..types.options import Options
from ..types import Union
from ..utils.is_base_type import is_base_type from ..utils.is_base_type import is_base_type
from ..utils.props import props from ..utils.props import props
from .node import is_node from .node import is_node
@ -109,7 +110,7 @@ class IterableConnectionField(Field):
@property @property
def type(self): def type(self):
type = super(IterableConnectionField, self).type type = super(IterableConnectionField, self).type
if is_node(type): if issubclass(type, Union) or is_node(type):
connection_type = type.Connection connection_type = type.Connection
else: else:
connection_type = type connection_type = type

View File

@ -1,24 +0,0 @@
# https://github.com/graphql-python/graphene/issues/356
import pytest
import graphene
from graphene import relay
class SomeTypeOne(graphene.ObjectType):
pass
class SomeTypeTwo(graphene.ObjectType):
pass
class MyUnion(graphene.Union):
class Meta:
types = (SomeTypeOne, SomeTypeTwo)
def test_issue():
with pytest.raises(Exception) as exc_info:
class Query(graphene.ObjectType):
things = relay.ConnectionField(MyUnion)
schema = graphene.Schema(query=Query)
assert str(exc_info.value) == 'IterableConnectionField type have to be a subclass of Connection. Received "MyUnion".'

View File

@ -1,7 +1,9 @@
import pytest import pytest
from ..objecttype import ObjectType from ..objecttype import ObjectType
from ..schema import Schema
from ..union import Union from ..union import Union
from graphene.relay.connection import ConnectionField
class MyObjectType1(ObjectType): class MyObjectType1(ObjectType):
@ -41,3 +43,33 @@ def test_generate_union_with_no_types():
pass pass
assert str(exc_info.value) == 'Must provide types for Union MyUnion.' assert str(exc_info.value) == 'Must provide types for Union MyUnion.'
def test_union_as_connection():
class MyUnion(Union):
class Meta:
types = (MyObjectType1, MyObjectType2)
class Query(ObjectType):
objects = ConnectionField(MyUnion)
def resolve_objects(self, args, context, info):
return [MyObjectType1(), MyObjectType2()]
query = '''
query {
objects {
edges {
node {
__typename
}
}
}
}
'''
schema = Schema(query=Query)
result = schema.execute(query)
assert not result.errors
assert len(result.data['objects']['edges']) == 2
assert result.data['objects']['edges'][0]['node']['__typename'] == 'MyObjectType1'
assert result.data['objects']['edges'][1]['node']['__typename'] == 'MyObjectType2'

View File

@ -1,9 +1,19 @@
import six import six
from functools import partial
from ..utils.is_base_type import is_base_type from ..utils.is_base_type import is_base_type
from .options import Options from .options import Options
def get_default_connection(cls):
from graphene.relay.connection import Connection
class Meta:
node = cls
return type('{}Connection'.format(cls.__name__), (Connection,), {'Meta': Meta})
class UnionMeta(type): class UnionMeta(type):
def __new__(cls, name, bases, attrs): def __new__(cls, name, bases, attrs):
@ -24,7 +34,15 @@ class UnionMeta(type):
len(options.types) > 0 len(options.types) > 0
), 'Must provide types for Union {}.'.format(options.name) ), 'Must provide types for Union {}.'.format(options.name)
return type.__new__(cls, name, bases, dict(attrs, _meta=options)) cls = type.__new__(cls, name, bases, dict(attrs, _meta=options))
get_connection = getattr(cls, 'get_connection', None)
if not get_connection:
get_connection = partial(get_default_connection, cls)
cls.Connection = get_connection()
return cls
def __str__(cls): # noqa: N805 def __str__(cls): # noqa: N805
return cls._meta.name return cls._meta.name