Added Union. Improved testing and code

This commit is contained in:
Syrus Akbary 2016-08-13 10:35:31 -07:00
parent 0802aaced0
commit 84c1da60dd
6 changed files with 159 additions and 153 deletions

View File

@ -16,7 +16,7 @@ class AbstractTypeMeta(type):
for base in bases:
if not issubclass(base, AbstractType) and issubclass(type(base), AbstractTypeMeta):
# raise Exception('You can only')
# raise Exception('You can only extend AbstractTypes after the base definition.')
return type.__new__(cls, name, bases, attrs)
attrs = merge_fields_in_attrs(bases, attrs)

View File

@ -2,15 +2,17 @@ from collections import OrderedDict
from py.test import raises
from ..abstracttype import AbstractType
from ..objecttype import ObjectType
from ..interface import Interface
from ..union import Union
from ..scalars import String, Int, Boolean
from ..field import Field
from ..structures import List
from ..enum import Enum
from ..inputobjecttype import InputObjectType
from ..inputfield import InputField
from ..structures import List, NonNull
from ..enum import Enum
from ..argument import Argument
from ..inputobjecttype import InputObjectType
from ..schema import Schema
@ -116,91 +118,87 @@ def test_defines_a_subscription_schema():
assert subscription.type._meta.name == 'Article'
# def test_includes_nested_input_objects_in_the_map():
# NestedInputObject = GraphQLInputObjectType(
# name='NestedInputObject',
# fields={'value': GraphQLInputObjectField(GraphQLString)}
# )
def test_includes_nested_input_objects_in_the_map():
class NestedInputObject(InputObjectType):
value = String()
# SomeInputObject = GraphQLInputObjectType(
# name='SomeInputObject',
# fields={'nested': GraphQLInputObjectField(NestedInputObject)}
# )
class SomeInputObject(InputObjectType):
nested = InputField(NestedInputObject)
# SomeMutation = GraphQLObjectType(
# name='SomeMutation',
# fields={
# 'mutateSomething': GraphQLField(
# type=BlogArticle,
# args={
# 'input': GraphQLArgument(SomeInputObject)
# }
# )
# }
# )
# SomeSubscription = GraphQLObjectType(
# name='SomeSubscription',
# fields={
# 'subscribeToSomething': GraphQLField(
# type=BlogArticle,
# args={
# 'input': GraphQLArgument(SomeInputObject)
# }
# )
# }
# )
class SomeMutation(Mutation):
mutate_something = Field(Article, input=Argument(SomeInputObject))
# schema = GraphQLSchema(
# query=BlogQuery,
# mutation=SomeMutation,
# subscription=SomeSubscription
# )
class SomeSubscription(Mutation):
subscribe_to_something = Field(Article, input=Argument(SomeInputObject))
# assert schema.get_type_map()['NestedInputObject'] is NestedInputObject
schema = Schema(
query=Query,
mutation=SomeMutation,
subscription=SomeSubscription
)
print schema.get_type_map()
assert schema.get_type_map()['NestedInputObject'].graphene_type is NestedInputObject
# def test_includes_interfaces_thunk_subtypes_in_the_type_map():
# SomeInterface = GraphQLInterfaceType(
# name='SomeInterface',
# fields={
# 'f': GraphQLField(GraphQLInt)
# }
# )
def test_includes_interfaces_thunk_subtypes_in_the_type_map():
class SomeInterface(Interface):
f = Int()
# SomeSubtype = GraphQLObjectType(
# name='SomeSubtype',
# fields={
# 'f': GraphQLField(GraphQLInt)
# },
# interfaces=lambda: [SomeInterface],
# is_type_of=lambda: True
# )
class SomeSubtype(ObjectType):
class Meta:
interfaces = (SomeInterface, )
# schema = GraphQLSchema(query=GraphQLObjectType(
# name='Query',
# fields={
# 'iface': GraphQLField(SomeInterface)
# }
# ), types=[SomeSubtype])
class Query(ObjectType):
iface = Field(lambda: SomeInterface)
# assert schema.get_type_map()['SomeSubtype'] is SomeSubtype
schema = Schema(
query=Query,
types=[SomeSubtype]
)
assert schema.get_type_map()['SomeSubtype'].graphene_type is SomeSubtype
# def test_includes_interfaces_subtypes_in_the_type_map():
# SomeInterface = GraphQLInterfaceType('SomeInterface', fields={'f': GraphQLField(GraphQLInt)})
# SomeSubtype = GraphQLObjectType(
# name='SomeSubtype',
# fields={'f': GraphQLField(GraphQLInt)},
# interfaces=[SomeInterface],
# is_type_of=lambda: None
# )
# schema = GraphQLSchema(
# query=GraphQLObjectType(
# name='Query',
# fields={
# 'iface': GraphQLField(SomeInterface)}),
# types=[SomeSubtype])
# assert schema.get_type_map()['SomeSubtype'] == SomeSubtype
def test_includes_types_in_union():
class SomeType(ObjectType):
a = String()
class OtherType(ObjectType):
b = String()
class MyUnion(Union):
class Meta:
types = (SomeType, OtherType)
class Query(ObjectType):
union = Field(MyUnion)
schema = Schema(
query=Query,
)
assert schema.get_type_map()['OtherType'].graphene_type is OtherType
assert schema.get_type_map()['SomeType'].graphene_type is SomeType
def test_includes_interfaces_subtypes_in_the_type_map():
class SomeInterface(Interface):
f = Int()
class SomeSubtype(ObjectType):
class Meta:
interfaces = (SomeInterface, )
class Query(ObjectType):
iface = Field(SomeInterface)
schema = Schema(
query=Query,
types=[SomeSubtype]
)
assert schema.get_type_map()['SomeSubtype'].graphene_type is SomeSubtype
def test_stringifies_simple_types():
@ -274,74 +272,36 @@ def test_stringifies_simple_types():
# == str(excinfo.value)
# def test_does_not_mutate_passed_field_definitions():
# fields = {
# 'field1': GraphQLField(GraphQLString),
# 'field2': GraphQLField(GraphQLString, args={'id': GraphQLArgument(GraphQLString)}),
# }
def test_does_not_mutate_passed_field_definitions():
class CommonFields(AbstractType):
field1 = String()
field2 = String(id=String())
# TestObject1 = GraphQLObjectType(name='Test1', fields=fields)
# TestObject2 = GraphQLObjectType(name='Test1', fields=fields)
class TestObject1(CommonFields, ObjectType):
pass
# assert TestObject1.get_fields() == TestObject2.get_fields()
# assert fields == {
# 'field1': GraphQLField(GraphQLString),
# 'field2': GraphQLField(GraphQLString, args={'id': GraphQLArgument(GraphQLString)}),
# }
class TestObject2(CommonFields, ObjectType):
pass
# input_fields = {
# 'field1': GraphQLInputObjectField(GraphQLString),
# 'field2': GraphQLInputObjectField(GraphQLString),
# }
assert TestObject1._meta.fields == TestObject2._meta.fields
assert CommonFields._meta.fields == {
'field1': String(),
'field2': String(id=String()),
}
# TestInputObject1 = GraphQLInputObjectType(name='Test1', fields=input_fields)
# TestInputObject2 = GraphQLInputObjectType(name='Test2', fields=input_fields)
class CommonFields(AbstractType):
field1 = String()
field2 = String()
# assert TestInputObject1.get_fields() == TestInputObject2.get_fields()
class TestInputObject1(CommonFields, InputObjectType):
pass
# assert input_fields == {
# 'field1': GraphQLInputObjectField(GraphQLString),
# 'field2': GraphQLInputObjectField(GraphQLString),
# }
class TestInputObject2(CommonFields, InputObjectType):
pass
assert TestInputObject1._meta.fields == TestInputObject2._meta.fields
# def test_sorts_fields_and_argument_keys_if_not_using_ordered_dict():
# fields = {
# 'b': GraphQLField(GraphQLString),
# 'c': GraphQLField(GraphQLString),
# 'a': GraphQLField(GraphQLString),
# 'd': GraphQLField(GraphQLString, args={
# 'q': GraphQLArgument(GraphQLString),
# 'x': GraphQLArgument(GraphQLString),
# 'v': GraphQLArgument(GraphQLString),
# 'a': GraphQLArgument(GraphQLString),
# 'n': GraphQLArgument(GraphQLString)
# })
# }
# test_object = GraphQLObjectType(name='Test', fields=fields)
# ordered_fields = test_object.get_fields()
# assert list(ordered_fields.keys()) == ['a', 'b', 'c', 'd']
# field_with_args = test_object.get_fields().get('d')
# assert [a.name for a in field_with_args.args] == ['a', 'n', 'q', 'v', 'x']
# def test_does_not_sort_fields_and_argument_keys_when_using_ordered_dict():
# fields = OrderedDict([
# ('b', GraphQLField(GraphQLString)),
# ('c', GraphQLField(GraphQLString)),
# ('a', GraphQLField(GraphQLString)),
# ('d', GraphQLField(GraphQLString, args=OrderedDict([
# ('q', GraphQLArgument(GraphQLString)),
# ('x', GraphQLArgument(GraphQLString)),
# ('v', GraphQLArgument(GraphQLString)),
# ('a', GraphQLArgument(GraphQLString)),
# ('n', GraphQLArgument(GraphQLString))
# ])))
# ])
# test_object = GraphQLObjectType(name='Test', fields=fields)
# ordered_fields = test_object.get_fields()
# assert list(ordered_fields.keys()) == ['b', 'c', 'a', 'd']
# field_with_args = test_object.get_fields().get('d')
# assert [a.name for a in field_with_args.args] == ['q', 'x', 'v', 'a', 'n']
assert CommonFields._meta.fields == {
'field1': String(),
'field2': String(),
}

View File

@ -4,6 +4,8 @@ from collections import OrderedDict
from graphql.type.typemap import GraphQLTypeMap
from .objecttype import ObjectType
from .interface import Interface
from .union import Union
from .inputobjecttype import InputObjectType
from .structures import List, NonNull
from .scalars import Scalar, String, Boolean, Int, Float, ID
@ -14,7 +16,7 @@ from graphql import GraphQLString, GraphQLField, GraphQLList, GraphQLBoolean, Gr
def is_graphene_type(_type):
if isinstance(_type, (List, NonNull)):
return True
if inspect.isclass(_type) and issubclass(_type, (ObjectType, InputObjectType, Scalar)):
if inspect.isclass(_type) and issubclass(_type, (ObjectType, InputObjectType, Scalar, Interface, Union)):
return True
@ -42,8 +44,12 @@ class TypeMap(GraphQLTypeMap):
return cls.construct_objecttype(map, type)
if issubclass(type, InputObjectType):
return cls.construct_inputobjecttype(map, type)
if issubclass(type, Interface):
return cls.construct_interface(map, type)
if issubclass(type, Scalar):
return cls.construct_scalar(map, type)
if issubclass(type, Union):
return cls.construct_union(map, type)
return map
@classmethod
@ -79,7 +85,41 @@ class TypeMap(GraphQLTypeMap):
description=type._meta.description,
fields=None,
is_type_of=type.is_type_of,
interfaces=type._meta.interfaces
interfaces=None
)
interfaces = []
for i in type._meta.interfaces:
map = cls.construct_interface(map, i)
interfaces.append(map[i._meta.name])
map[type._meta.name].interfaces = interfaces
map[type._meta.name].fields = cls.construct_fields_for_type(map, type)
return map
@classmethod
def construct_union(cls, map, type):
from ..generators.definitions import GrapheneUnionType
types = []
for i in type._meta.types:
map = cls.construct_objecttype(map, i)
types.append(map[i._meta.name])
map[type._meta.name] = GrapheneUnionType(
graphene_type=type,
name=type._meta.name,
types=types,
resolve_type=type.resolve_type,
)
map[type._meta.name].types = types
return map
@classmethod
def construct_interface(cls, map, type):
from ..generators.definitions import GrapheneInterfaceType
map[type._meta.name] = GrapheneInterfaceType(
graphene_type=type,
name=type._meta.name,
description=type._meta.description,
fields=None,
resolve_type=type.resolve_type,
)
map[type._meta.name].fields = cls.construct_fields_for_type(map, type)
return map
@ -93,16 +133,16 @@ class TypeMap(GraphQLTypeMap):
description=type._meta.description,
fields=None,
)
map[type._meta.name].fields = cls.construct_fields_for_type(map, type, is_input=True)
map[type._meta.name].fields = cls.construct_fields_for_type(map, type, is_input_type=True)
return map
@classmethod
def construct_fields_for_type(cls, map, type, is_input=False):
def construct_fields_for_type(cls, map, type, is_input_type=False):
fields = OrderedDict()
for name, field in type._meta.fields.items():
map = cls.reducer(map, field.type)
field_type = cls.get_field_type(map, field.type)
if is_input:
if is_input_type:
_field = GraphQLInputObjectField(
field_type,
default_value=field.default_value,

View File

@ -31,7 +31,8 @@ class UnionMeta(type):
class Union(six.with_metaclass(UnionMeta)):
resolve_type = None
def resolve_type(self, _type):
return type(_type)
def __init__(self, *args, **kwargs):
raise Exception("An Union cannot be intitialized")

View File

@ -1,6 +1,4 @@
from ..utils.orderedtype import OrderedType
# from .argument import Argument
class UnmountedType(OrderedType):
@ -60,3 +58,13 @@ class UnmountedType(OrderedType):
_creation_counter=self.creation_counter,
**self.kwargs
)
def __eq__(self, other):
return (
self is other or (
isinstance(other, UnmountedType) and
self.get_type() == other.get_type() and
self.args == other.args and
self.kwargs == other.kwargs
)
)

View File

@ -66,8 +66,5 @@ def get_fields_in_type(in_type, attrs):
def yank_fields_from_attrs(attrs, fields):
for name, field in fields.items():
# attrs.pop(name, None)
for name in fields.keys():
del attrs[name]
# return attrs
# return {k: v for k, v in attrs.items() if k not in fields}