Added Union. Improved testing and code

This commit is contained in:
Syrus Akbary 2016-08-13 10:35:31 -07:00
parent 0802aaced0
commit 84c1da60dd
6 changed files with 159 additions and 153 deletions

View File

@ -16,7 +16,7 @@ class AbstractTypeMeta(type):
for base in bases: for base in bases:
if not issubclass(base, AbstractType) and issubclass(type(base), AbstractTypeMeta): if not issubclass(base, AbstractType) and issubclass(type(base), AbstractTypeMeta):
# raise Exception('You can only') # raise Exception('You can only extend AbstractTypes after the base definition.')
return type.__new__(cls, name, bases, attrs) return type.__new__(cls, name, bases, attrs)
attrs = merge_fields_in_attrs(bases, attrs) attrs = merge_fields_in_attrs(bases, attrs)

View File

@ -2,15 +2,17 @@ from collections import OrderedDict
from py.test import raises from py.test import raises
from ..abstracttype import AbstractType
from ..objecttype import ObjectType from ..objecttype import ObjectType
from ..interface import Interface from ..interface import Interface
from ..union import Union from ..union import Union
from ..scalars import String, Int, Boolean from ..scalars import String, Int, Boolean
from ..field import Field from ..field import Field
from ..structures import List from ..inputfield import InputField
from ..enum import Enum
from ..inputobjecttype import InputObjectType
from ..structures import List, NonNull from ..structures import List, NonNull
from ..enum import Enum
from ..argument import Argument
from ..inputobjecttype import InputObjectType
from ..schema import Schema from ..schema import Schema
@ -116,91 +118,87 @@ def test_defines_a_subscription_schema():
assert subscription.type._meta.name == 'Article' assert subscription.type._meta.name == 'Article'
# def test_includes_nested_input_objects_in_the_map(): def test_includes_nested_input_objects_in_the_map():
# NestedInputObject = GraphQLInputObjectType( class NestedInputObject(InputObjectType):
# name='NestedInputObject', value = String()
# fields={'value': GraphQLInputObjectField(GraphQLString)}
# )
# SomeInputObject = GraphQLInputObjectType( class SomeInputObject(InputObjectType):
# name='SomeInputObject', nested = InputField(NestedInputObject)
# fields={'nested': GraphQLInputObjectField(NestedInputObject)}
# )
# SomeMutation = GraphQLObjectType( class SomeMutation(Mutation):
# name='SomeMutation', mutate_something = Field(Article, input=Argument(SomeInputObject))
# fields={
# 'mutateSomething': GraphQLField(
# type=BlogArticle,
# args={
# 'input': GraphQLArgument(SomeInputObject)
# }
# )
# }
# )
# SomeSubscription = GraphQLObjectType(
# name='SomeSubscription',
# fields={
# 'subscribeToSomething': GraphQLField(
# type=BlogArticle,
# args={
# 'input': GraphQLArgument(SomeInputObject)
# }
# )
# }
# )
# schema = GraphQLSchema( class SomeSubscription(Mutation):
# query=BlogQuery, subscribe_to_something = Field(Article, input=Argument(SomeInputObject))
# mutation=SomeMutation,
# subscription=SomeSubscription
# )
# assert schema.get_type_map()['NestedInputObject'] is NestedInputObject schema = Schema(
query=Query,
mutation=SomeMutation,
subscription=SomeSubscription
)
print schema.get_type_map()
assert schema.get_type_map()['NestedInputObject'].graphene_type is NestedInputObject
# def test_includes_interfaces_thunk_subtypes_in_the_type_map(): def test_includes_interfaces_thunk_subtypes_in_the_type_map():
# SomeInterface = GraphQLInterfaceType( class SomeInterface(Interface):
# name='SomeInterface', f = Int()
# fields={
# 'f': GraphQLField(GraphQLInt)
# }
# )
# SomeSubtype = GraphQLObjectType( class SomeSubtype(ObjectType):
# name='SomeSubtype', class Meta:
# fields={ interfaces = (SomeInterface, )
# 'f': GraphQLField(GraphQLInt)
# },
# interfaces=lambda: [SomeInterface],
# is_type_of=lambda: True
# )
# schema = GraphQLSchema(query=GraphQLObjectType( class Query(ObjectType):
# name='Query', iface = Field(lambda: SomeInterface)
# fields={
# 'iface': GraphQLField(SomeInterface)
# }
# ), types=[SomeSubtype])
# assert schema.get_type_map()['SomeSubtype'] is SomeSubtype schema = Schema(
query=Query,
types=[SomeSubtype]
)
assert schema.get_type_map()['SomeSubtype'].graphene_type is SomeSubtype
# def test_includes_interfaces_subtypes_in_the_type_map(): def test_includes_types_in_union():
# SomeInterface = GraphQLInterfaceType('SomeInterface', fields={'f': GraphQLField(GraphQLInt)}) class SomeType(ObjectType):
# SomeSubtype = GraphQLObjectType( a = String()
# name='SomeSubtype',
# fields={'f': GraphQLField(GraphQLInt)}, class OtherType(ObjectType):
# interfaces=[SomeInterface], b = String()
# is_type_of=lambda: None
# ) class MyUnion(Union):
# schema = GraphQLSchema( class Meta:
# query=GraphQLObjectType( types = (SomeType, OtherType)
# name='Query',
# fields={ class Query(ObjectType):
# 'iface': GraphQLField(SomeInterface)}), union = Field(MyUnion)
# types=[SomeSubtype])
# assert schema.get_type_map()['SomeSubtype'] == SomeSubtype schema = Schema(
query=Query,
)
assert schema.get_type_map()['OtherType'].graphene_type is OtherType
assert schema.get_type_map()['SomeType'].graphene_type is SomeType
def test_includes_interfaces_subtypes_in_the_type_map():
class SomeInterface(Interface):
f = Int()
class SomeSubtype(ObjectType):
class Meta:
interfaces = (SomeInterface, )
class Query(ObjectType):
iface = Field(SomeInterface)
schema = Schema(
query=Query,
types=[SomeSubtype]
)
assert schema.get_type_map()['SomeSubtype'].graphene_type is SomeSubtype
def test_stringifies_simple_types(): def test_stringifies_simple_types():
@ -274,74 +272,36 @@ def test_stringifies_simple_types():
# == str(excinfo.value) # == str(excinfo.value)
# def test_does_not_mutate_passed_field_definitions(): def test_does_not_mutate_passed_field_definitions():
# fields = { class CommonFields(AbstractType):
# 'field1': GraphQLField(GraphQLString), field1 = String()
# 'field2': GraphQLField(GraphQLString, args={'id': GraphQLArgument(GraphQLString)}), field2 = String(id=String())
# }
# TestObject1 = GraphQLObjectType(name='Test1', fields=fields) class TestObject1(CommonFields, ObjectType):
# TestObject2 = GraphQLObjectType(name='Test1', fields=fields) pass
# assert TestObject1.get_fields() == TestObject2.get_fields() class TestObject2(CommonFields, ObjectType):
# assert fields == { pass
# 'field1': GraphQLField(GraphQLString),
# 'field2': GraphQLField(GraphQLString, args={'id': GraphQLArgument(GraphQLString)}),
# }
# input_fields = { assert TestObject1._meta.fields == TestObject2._meta.fields
# 'field1': GraphQLInputObjectField(GraphQLString), assert CommonFields._meta.fields == {
# 'field2': GraphQLInputObjectField(GraphQLString), 'field1': String(),
# } 'field2': String(id=String()),
}
# TestInputObject1 = GraphQLInputObjectType(name='Test1', fields=input_fields) class CommonFields(AbstractType):
# TestInputObject2 = GraphQLInputObjectType(name='Test2', fields=input_fields) field1 = String()
field2 = String()
# assert TestInputObject1.get_fields() == TestInputObject2.get_fields() class TestInputObject1(CommonFields, InputObjectType):
pass
# assert input_fields == { class TestInputObject2(CommonFields, InputObjectType):
# 'field1': GraphQLInputObjectField(GraphQLString), pass
# 'field2': GraphQLInputObjectField(GraphQLString),
# }
assert TestInputObject1._meta.fields == TestInputObject2._meta.fields
# def test_sorts_fields_and_argument_keys_if_not_using_ordered_dict(): assert CommonFields._meta.fields == {
# fields = { 'field1': String(),
# 'b': GraphQLField(GraphQLString), 'field2': String(),
# 'c': GraphQLField(GraphQLString), }
# 'a': GraphQLField(GraphQLString),
# 'd': GraphQLField(GraphQLString, args={
# 'q': GraphQLArgument(GraphQLString),
# 'x': GraphQLArgument(GraphQLString),
# 'v': GraphQLArgument(GraphQLString),
# 'a': GraphQLArgument(GraphQLString),
# 'n': GraphQLArgument(GraphQLString)
# })
# }
# test_object = GraphQLObjectType(name='Test', fields=fields)
# ordered_fields = test_object.get_fields()
# assert list(ordered_fields.keys()) == ['a', 'b', 'c', 'd']
# field_with_args = test_object.get_fields().get('d')
# assert [a.name for a in field_with_args.args] == ['a', 'n', 'q', 'v', 'x']
# def test_does_not_sort_fields_and_argument_keys_when_using_ordered_dict():
# fields = OrderedDict([
# ('b', GraphQLField(GraphQLString)),
# ('c', GraphQLField(GraphQLString)),
# ('a', GraphQLField(GraphQLString)),
# ('d', GraphQLField(GraphQLString, args=OrderedDict([
# ('q', GraphQLArgument(GraphQLString)),
# ('x', GraphQLArgument(GraphQLString)),
# ('v', GraphQLArgument(GraphQLString)),
# ('a', GraphQLArgument(GraphQLString)),
# ('n', GraphQLArgument(GraphQLString))
# ])))
# ])
# test_object = GraphQLObjectType(name='Test', fields=fields)
# ordered_fields = test_object.get_fields()
# assert list(ordered_fields.keys()) == ['b', 'c', 'a', 'd']
# field_with_args = test_object.get_fields().get('d')
# assert [a.name for a in field_with_args.args] == ['q', 'x', 'v', 'a', 'n']

View File

@ -4,6 +4,8 @@ from collections import OrderedDict
from graphql.type.typemap import GraphQLTypeMap from graphql.type.typemap import GraphQLTypeMap
from .objecttype import ObjectType from .objecttype import ObjectType
from .interface import Interface
from .union import Union
from .inputobjecttype import InputObjectType from .inputobjecttype import InputObjectType
from .structures import List, NonNull from .structures import List, NonNull
from .scalars import Scalar, String, Boolean, Int, Float, ID from .scalars import Scalar, String, Boolean, Int, Float, ID
@ -14,7 +16,7 @@ from graphql import GraphQLString, GraphQLField, GraphQLList, GraphQLBoolean, Gr
def is_graphene_type(_type): def is_graphene_type(_type):
if isinstance(_type, (List, NonNull)): if isinstance(_type, (List, NonNull)):
return True return True
if inspect.isclass(_type) and issubclass(_type, (ObjectType, InputObjectType, Scalar)): if inspect.isclass(_type) and issubclass(_type, (ObjectType, InputObjectType, Scalar, Interface, Union)):
return True return True
@ -42,8 +44,12 @@ class TypeMap(GraphQLTypeMap):
return cls.construct_objecttype(map, type) return cls.construct_objecttype(map, type)
if issubclass(type, InputObjectType): if issubclass(type, InputObjectType):
return cls.construct_inputobjecttype(map, type) return cls.construct_inputobjecttype(map, type)
if issubclass(type, Interface):
return cls.construct_interface(map, type)
if issubclass(type, Scalar): if issubclass(type, Scalar):
return cls.construct_scalar(map, type) return cls.construct_scalar(map, type)
if issubclass(type, Union):
return cls.construct_union(map, type)
return map return map
@classmethod @classmethod
@ -79,7 +85,41 @@ class TypeMap(GraphQLTypeMap):
description=type._meta.description, description=type._meta.description,
fields=None, fields=None,
is_type_of=type.is_type_of, is_type_of=type.is_type_of,
interfaces=type._meta.interfaces interfaces=None
)
interfaces = []
for i in type._meta.interfaces:
map = cls.construct_interface(map, i)
interfaces.append(map[i._meta.name])
map[type._meta.name].interfaces = interfaces
map[type._meta.name].fields = cls.construct_fields_for_type(map, type)
return map
@classmethod
def construct_union(cls, map, type):
from ..generators.definitions import GrapheneUnionType
types = []
for i in type._meta.types:
map = cls.construct_objecttype(map, i)
types.append(map[i._meta.name])
map[type._meta.name] = GrapheneUnionType(
graphene_type=type,
name=type._meta.name,
types=types,
resolve_type=type.resolve_type,
)
map[type._meta.name].types = types
return map
@classmethod
def construct_interface(cls, map, type):
from ..generators.definitions import GrapheneInterfaceType
map[type._meta.name] = GrapheneInterfaceType(
graphene_type=type,
name=type._meta.name,
description=type._meta.description,
fields=None,
resolve_type=type.resolve_type,
) )
map[type._meta.name].fields = cls.construct_fields_for_type(map, type) map[type._meta.name].fields = cls.construct_fields_for_type(map, type)
return map return map
@ -93,16 +133,16 @@ class TypeMap(GraphQLTypeMap):
description=type._meta.description, description=type._meta.description,
fields=None, fields=None,
) )
map[type._meta.name].fields = cls.construct_fields_for_type(map, type, is_input=True) map[type._meta.name].fields = cls.construct_fields_for_type(map, type, is_input_type=True)
return map return map
@classmethod @classmethod
def construct_fields_for_type(cls, map, type, is_input=False): def construct_fields_for_type(cls, map, type, is_input_type=False):
fields = OrderedDict() fields = OrderedDict()
for name, field in type._meta.fields.items(): for name, field in type._meta.fields.items():
map = cls.reducer(map, field.type) map = cls.reducer(map, field.type)
field_type = cls.get_field_type(map, field.type) field_type = cls.get_field_type(map, field.type)
if is_input: if is_input_type:
_field = GraphQLInputObjectField( _field = GraphQLInputObjectField(
field_type, field_type,
default_value=field.default_value, default_value=field.default_value,

View File

@ -31,7 +31,8 @@ class UnionMeta(type):
class Union(six.with_metaclass(UnionMeta)): class Union(six.with_metaclass(UnionMeta)):
resolve_type = None def resolve_type(self, _type):
return type(_type)
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
raise Exception("An Union cannot be intitialized") raise Exception("An Union cannot be intitialized")

View File

@ -1,6 +1,4 @@
from ..utils.orderedtype import OrderedType from ..utils.orderedtype import OrderedType
# from .argument import Argument
class UnmountedType(OrderedType): class UnmountedType(OrderedType):
@ -60,3 +58,13 @@ class UnmountedType(OrderedType):
_creation_counter=self.creation_counter, _creation_counter=self.creation_counter,
**self.kwargs **self.kwargs
) )
def __eq__(self, other):
return (
self is other or (
isinstance(other, UnmountedType) and
self.get_type() == other.get_type() and
self.args == other.args and
self.kwargs == other.kwargs
)
)

View File

@ -66,8 +66,5 @@ def get_fields_in_type(in_type, attrs):
def yank_fields_from_attrs(attrs, fields): def yank_fields_from_attrs(attrs, fields):
for name, field in fields.items(): for name in fields.keys():
# attrs.pop(name, None)
del attrs[name] del attrs[name]
# return attrs
# return {k: v for k, v in attrs.items() if k not in fields}