Refactored fields getter to be immutable

This commit is contained in:
Syrus Akbary 2016-06-08 22:23:28 -07:00
parent 25e967200b
commit b24e9a1051
23 changed files with 242 additions and 275 deletions

View File

@ -21,13 +21,17 @@ class Character(graphene.Interface):
return [get_character(f) for f in self.friends] return [get_character(f) for f in self.friends]
@graphene.implements(Character) # @graphene.implements(Character)
class Human(graphene.ObjectType): class Human(graphene.ObjectType):
class Meta:
interfaces = [Character]
home_planet = graphene.String() home_planet = graphene.String()
@graphene.implements(Character) # @graphene.implements(Character)
class Droid(graphene.ObjectType): class Droid(graphene.ObjectType):
class Meta:
interfaces = [Character]
primary_function = graphene.String() primary_function = graphene.String()

View File

@ -0,0 +1,2 @@
from .node import Node
# from .mutation import ClientIDMutation

View File

@ -7,8 +7,10 @@ from ...types import ObjectType, Schema, implements
from ...types.scalars import String from ...types.scalars import String
@implements(Node)
class MyNode(ObjectType): class MyNode(ObjectType):
class Meta:
interfaces = [Node]
name = String() name = String()
@staticmethod @staticmethod
@ -25,8 +27,9 @@ schema = Schema(query=RootQuery, types=[MyNode])
def test_node_no_get_node(): def test_node_no_get_node():
with pytest.raises(AssertionError) as excinfo: with pytest.raises(AssertionError) as excinfo:
@implements(Node)
class MyNode(ObjectType): class MyNode(ObjectType):
class Meta:
interfaces = [Node]
pass pass
assert "MyNode.get_node method is required by the Node interface." == str(excinfo.value) assert "MyNode.get_node method is required by the Node interface." == str(excinfo.value)

View File

@ -15,13 +15,17 @@ class CustomNode(Node):
return photo_data.get(id) return photo_data.get(id)
@implements(CustomNode) # @implements(CustomNode)
class User(ObjectType): class User(ObjectType):
class Meta:
interfaces = [CustomNode]
name = String() name = String()
@implements(CustomNode) # @implements(CustomNode)
class Photo(ObjectType): class Photo(ObjectType):
class Meta:
interfaces = [CustomNode]
width = Int() width = Int()

View File

@ -1,11 +1,9 @@
from collections import OrderedDict from collections import OrderedDict
import inspect import inspect
import copy
from itertools import chain from itertools import chain
from functools import partial from functools import partial
from graphql.utils.assert_valid_name import assert_valid_name from graphql.utils.assert_valid_name import assert_valid_name
from graphql.type.definition import GraphQLObjectType
from .options import Options from .options import Options
@ -28,34 +26,22 @@ class ClassTypeMeta(type):
else: else:
meta = attr_meta meta = attr_meta
new_class.add_to_class('_meta', new_class.get_options(meta)) new_class._meta = new_class.get_options(meta)
new_class._meta.parent = new_class
new_class._meta.validate_attrs()
if new_class._meta.name: if new_class._meta.name:
assert_valid_name(new_class._meta.name) assert_valid_name(new_class._meta.name)
new_class.construct_graphql_type(bases)
return mcs.construct(new_class, bases, attrs) return mcs.construct(new_class, bases, attrs)
def get_options(cls, meta): def get_options(cls, meta):
raise NotImplementedError("get_options is not implemented") raise NotImplementedError("get_options is not implemented")
def construct_graphql_type(cls, bases):
raise NotImplementedError("construct_graphql_type is not implemented")
def add_to_class(cls, name, value):
# We should call the contribute_to_class method only if it's bound
if not inspect.isclass(value) and hasattr(
value, 'contribute_to_class'):
value.contribute_to_class(cls, name)
else:
setattr(cls, name, value)
def construct(cls, bases, attrs): def construct(cls, bases, attrs):
# Add all attributes to the class. # Add all attributes to the class.
for obj_name, obj in attrs.items(): for name, value in attrs.items():
cls.add_to_class(obj_name, obj) setattr(cls, name, value)
# if not cls._meta.abstract:
# from ..types import List, NonNull
return cls return cls
@ -64,19 +50,29 @@ class FieldsMeta(type):
def _build_field_map(cls, bases, local_fields): def _build_field_map(cls, bases, local_fields):
from ..utils.extract_fields import get_base_fields from ..utils.extract_fields import get_base_fields
extended_fields = get_base_fields(cls, bases) extended_fields = list(get_base_fields(cls, bases))
fields = chain(extended_fields, local_fields)
fields = []
field_names = set(f.name for f in local_fields)
for extended_field in extended_fields:
if extended_field.name in field_names:
continue
fields.append(extended_field)
field_names.add(extended_field.name)
fields.extend(local_fields)
return OrderedDict((f.name, f) for f in fields) return OrderedDict((f.name, f) for f in fields)
def _fields(cls, bases, attrs): def _extract_local_fields(cls, attrs):
from ..utils.is_graphene_type import is_graphene_type
from ..utils.extract_fields import extract_fields from ..utils.extract_fields import extract_fields
return extract_fields(cls, attrs)
inherited_types = [ def _fields(cls, bases, attrs, local_fields, extra_types=()):
from ..utils.is_graphene_type import is_graphene_type
inherited_types = tuple(
base._meta.graphql_type for base in bases if is_graphene_type(base) and not base._meta.abstract base._meta.graphql_type for base in bases if is_graphene_type(base) and not base._meta.abstract
] ) + extra_types
local_fields = extract_fields(cls, attrs)
return partial(cls._build_field_map, inherited_types, local_fields) return partial(cls._build_field_map, inherited_types, local_fields)
@ -84,58 +80,3 @@ class GrapheneGraphQLType(object):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.graphene_type = kwargs.pop('graphene_type') self.graphene_type = kwargs.pop('graphene_type')
super(GrapheneGraphQLType, self).__init__(*args, **kwargs) super(GrapheneGraphQLType, self).__init__(*args, **kwargs)
class GrapheneFieldsType(GrapheneGraphQLType):
def __init__(self, *args, **kwargs):
self._fields = None
self._field_map = None
super(GrapheneFieldsType, self).__init__(*args, **kwargs)
def add_field(self, field):
# We clear the cached fields
self._field_map = None
self._fields.add(field)
class FieldMap(object):
def __init__(self, parent, bases=None, fields=None):
self.parent = parent
self.fields = fields or []
self.bases = bases or []
def add(self, field):
self.fields.append(field)
def __call__(self):
# It's in a call function for assuring that if a field is added
# in runtime then it will be reflected in the Class type fields
# If we add the field in the class type creation, then we
# would not be able to change it later.
from .field import Field
prev_fields = []
graphql_type = self.parent._meta.graphql_type
# We collect the fields from the interfaces
if isinstance(graphql_type, GraphQLObjectType):
interfaces = graphql_type.get_interfaces()
for interface in interfaces:
prev_fields += interface.get_fields().items()
# We collect the fields from the bases
for base in self.bases:
prev_fields += base.get_fields().items()
fields = prev_fields + [
(field.name, field) for field in sorted(self.fields)
]
# Then we copy all the fields and assign the parent
new_fields = []
for field_name, field in fields:
field = copy.copy(field)
if isinstance(field, Field):
field.parent = self.parent
new_fields.append((field_name, field))
return OrderedDict(new_fields)

View File

@ -42,7 +42,7 @@ class EnumTypeMeta(ClassTypeMeta):
abstract=False abstract=False
) )
def construct_graphql_type(cls, bases): def construct(cls, bases, attrs):
if not cls._meta.graphql_type and not cls._meta.abstract: if not cls._meta.graphql_type and not cls._meta.abstract:
cls._meta.graphql_type = GrapheneEnumType( cls._meta.graphql_type = GrapheneEnumType(
graphene_type=cls, graphene_type=cls,
@ -50,7 +50,6 @@ class EnumTypeMeta(ClassTypeMeta):
description=cls._meta.description or cls.__doc__, description=cls._meta.description or cls.__doc__,
) )
def construct(cls, bases, attrs):
if not cls._meta.enum: if not cls._meta.enum:
cls._meta.enum = type(cls.__name__, (PyEnum,), attrs) cls._meta.enum = type(cls.__name__, (PyEnum,), attrs)

View File

@ -61,16 +61,13 @@ class Field(AbstractField, GraphQLField, OrderedType):
where.__name__ where.__name__
) )
def contribute_to_class(self, cls, attname): def mount(self, parent, attname=None):
from .objecttype import ObjectType from .objecttype import ObjectType
from .interface import Interface from .interface import Interface
assert issubclass(parent, (ObjectType, Interface)), self.mount_error_message(parent)
assert issubclass(cls, (ObjectType, Interface)), self.mount_error_message(cls)
self.attname = attname self.attname = attname
self.parent = cls self.parent = parent
add_field = getattr(cls._meta.graphql_type, "add_field", None)
assert add_field, self.mount_error_message(cls)
add_field(self)
@property @property
def resolver(self): def resolver(self):
@ -149,15 +146,12 @@ class InputField(AbstractField, GraphQLInputObjectField, OrderedType):
where.__name__ where.__name__
) )
def contribute_to_class(self, cls, attname): def mount(self, parent, attname):
from .inputobjecttype import InputObjectType from .inputobjecttype import InputObjectType
assert issubclass(cls, (InputObjectType)), self.mount_error_message(cls) assert issubclass(parent, (InputObjectType)), self.mount_error_message(parent)
self.attname = attname self.attname = attname
self.parent = cls self.parent = parent
add_field = getattr(cls._meta.graphql_type, "add_field", None)
assert add_field, self.mount_error_message(cls)
add_field(self)
def __copy__(self): def __copy__(self):
return InputField( return InputField(

View File

@ -2,11 +2,11 @@ import six
from graphql import GraphQLInputObjectType from graphql import GraphQLInputObjectType
from .definitions import FieldsMeta, ClassTypeMeta, GrapheneFieldsType from .definitions import FieldsMeta, ClassTypeMeta, GrapheneGraphQLType
from .proxy import TypeProxy from .proxy import TypeProxy
class GrapheneInputObjectType(GrapheneFieldsType, GraphQLInputObjectType): class GrapheneInputObjectType(GrapheneGraphQLType, GraphQLInputObjectType):
pass pass
@ -21,17 +21,18 @@ class InputObjectTypeMeta(FieldsMeta, ClassTypeMeta):
abstract=False abstract=False
) )
def construct_graphql_type(cls, bases):
pass
def construct(cls, bases, attrs): def construct(cls, bases, attrs):
if not cls._meta.graphql_type and not cls._meta.abstract: if not cls._meta.abstract:
local_fields = cls._extract_local_fields(attrs)
if not cls._meta.graphql_type:
cls._meta.graphql_type = GrapheneInputObjectType( cls._meta.graphql_type = GrapheneInputObjectType(
graphene_type=cls, graphene_type=cls,
name=cls._meta.name or cls.__name__, name=cls._meta.name or cls.__name__,
description=cls._meta.description or cls.__doc__, description=cls._meta.description or cls.__doc__,
fields=cls._fields(bases, attrs), fields=cls._fields(bases, attrs, local_fields),
) )
else:
assert not local_fields, "Can't mount Fields in an InputObjectType with a defined graphql_type"
return super(InputObjectTypeMeta, cls).construct(bases, attrs) return super(InputObjectTypeMeta, cls).construct(bases, attrs)

View File

@ -1,10 +1,10 @@
import six import six
from graphql import GraphQLInterfaceType from graphql import GraphQLInterfaceType
from .definitions import FieldsMeta, ClassTypeMeta, GrapheneFieldsType from .definitions import FieldsMeta, ClassTypeMeta, GrapheneGraphQLType
class GrapheneInterfaceType(GrapheneFieldsType, GraphQLInterfaceType): class GrapheneInterfaceType(GrapheneGraphQLType, GraphQLInterfaceType):
pass pass
@ -19,18 +19,19 @@ class InterfaceTypeMeta(FieldsMeta, ClassTypeMeta):
abstract=False abstract=False
) )
def construct_graphql_type(cls, bases):
pass
def construct(cls, bases, attrs): def construct(cls, bases, attrs):
if not cls._meta.graphql_type and not cls._meta.abstract: if not cls._meta.abstract:
local_fields = cls._extract_local_fields(attrs)
if not cls._meta.graphql_type:
cls._meta.graphql_type = GrapheneInterfaceType( cls._meta.graphql_type = GrapheneInterfaceType(
graphene_type=cls, graphene_type=cls,
name=cls._meta.name or cls.__name__, name=cls._meta.name or cls.__name__,
description=cls._meta.description or cls.__doc__, description=cls._meta.description or cls.__doc__,
fields=cls._fields(bases, attrs), fields=cls._fields(bases, attrs, local_fields),
) )
else:
assert not local_fields, "Can't mount Fields in an Interface with a defined graphql_type"
return super(InterfaceTypeMeta, cls).construct(bases, attrs) return super(InterfaceTypeMeta, cls).construct(bases, attrs)

View File

@ -3,11 +3,11 @@ import six
from graphql import GraphQLObjectType from graphql import GraphQLObjectType
from .definitions import ClassTypeMeta, GrapheneFieldsType, FieldMap from .definitions import FieldsMeta, ClassTypeMeta, GrapheneGraphQLType
from .interface import GrapheneInterfaceType from .interface import GrapheneInterfaceType
class GrapheneObjectType(GrapheneFieldsType, GraphQLObjectType): class GrapheneObjectType(GrapheneGraphQLType, GraphQLObjectType):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(GrapheneObjectType, self).__init__(*args, **kwargs) super(GrapheneObjectType, self).__init__(*args, **kwargs)
@ -43,7 +43,7 @@ def get_interfaces(cls, interfaces):
yield graphql_type yield graphql_type
class ObjectTypeMeta(ClassTypeMeta): class ObjectTypeMeta(FieldsMeta, ClassTypeMeta):
def get_options(cls, meta): def get_options(cls, meta):
return cls.options_class( return cls.options_class(
@ -55,23 +55,24 @@ class ObjectTypeMeta(ClassTypeMeta):
abstract=False abstract=False
) )
def get_interfaces(cls): def construct(cls, bases, attrs):
return get_interfaces(cls, cls._meta.interfaces) if not cls._meta.abstract:
interfaces = tuple(get_interfaces(cls, cls._meta.interfaces))
def construct_graphql_type(cls, bases): local_fields = cls._extract_local_fields(attrs)
if not cls._meta.graphql_type and not cls._meta.abstract: if not cls._meta.graphql_type:
from ..utils.is_graphene_type import is_graphene_type cls = super(ObjectTypeMeta, cls).construct(bases, attrs)
inherited_types = [
base._meta.graphql_type for base in bases if is_graphene_type(base)
]
cls._meta.graphql_type = GrapheneObjectType( cls._meta.graphql_type = GrapheneObjectType(
graphene_type=cls, graphene_type=cls,
name=cls._meta.name or cls.__name__, name=cls._meta.name or cls.__name__,
description=cls._meta.description or cls.__doc__, description=cls._meta.description or cls.__doc__,
fields=FieldMap(cls, bases=filter(None, inherited_types)), fields=cls._fields(bases, attrs, local_fields, interfaces),
interfaces=tuple(cls.get_interfaces()), interfaces=interfaces,
) )
return cls
else:
assert not local_fields, "Can't mount Fields in an ObjectType with a defined graphql_type"
return super(ObjectTypeMeta, cls).construct(bases, attrs)
def implements(*interfaces): def implements(*interfaces):
@ -80,8 +81,10 @@ def implements(*interfaces):
def wrap_class(cls): def wrap_class(cls):
interface_types = get_interfaces(cls, interfaces) interface_types = get_interfaces(cls, interfaces)
graphql_type = cls._meta.graphql_type graphql_type = cls._meta.graphql_type
# fields = cls._build_field_map(interface_types, graphql_type.get_fields().values())
new_type = copy.copy(graphql_type) new_type = copy.copy(graphql_type)
new_type._provided_interfaces = tuple(graphql_type._provided_interfaces) + tuple(interface_types) new_type._provided_interfaces = tuple(graphql_type._provided_interfaces) + tuple(interface_types)
new_type._fields = graphql_type._fields
cls._meta.graphql_type = new_type cls._meta.graphql_type = new_type
cls._meta.graphql_type.check_interfaces() cls._meta.graphql_type.check_interfaces()
return cls return cls

View File

@ -8,11 +8,6 @@ class Options(object):
setattr(self, name, value) setattr(self, name, value)
self.valid_attrs = defaults.keys() self.valid_attrs = defaults.keys()
def contribute_to_class(self, cls, name):
cls._meta = self
self.parent = cls
self.validate_attrs()
def validate_attrs(self): def validate_attrs(self):
# Store the original user-defined values for each option, # Store the original user-defined values for each option,
# for use when serializing the model definition # for use when serializing the model definition

View File

@ -3,6 +3,8 @@ from .argument import Argument
from ..utils.orderedtype import OrderedType from ..utils.orderedtype import OrderedType
# UnmountedType ?
class TypeProxy(OrderedType): class TypeProxy(OrderedType):
''' '''
This class acts a proxy for a Graphene Type, so it can be mounted This class acts a proxy for a Graphene Type, so it can be mounted
@ -51,20 +53,6 @@ class TypeProxy(OrderedType):
**self.kwargs **self.kwargs
) )
def contribute_to_class(self, cls, attname):
from .inputobjecttype import InputObjectType
from .objecttype import ObjectType
from .interface import Interface
if issubclass(cls, (ObjectType, Interface)):
inner = self.as_field()
elif issubclass(cls, (InputObjectType)):
inner = self.as_inputfield()
else:
raise Exception('TypedProxy "{}" cannot be mounted in {}'.format(self.get_type(), cls))
return inner.contribute_to_class(cls, attname)
def as_mounted(self, cls): def as_mounted(self, cls):
from .inputobjecttype import InputObjectType from .inputobjecttype import InputObjectType
from .objecttype import ObjectType from .objecttype import ObjectType

View File

@ -20,9 +20,6 @@ class ScalarTypeMeta(ClassTypeMeta):
abstract=False abstract=False
) )
def construct_graphql_type(cls, bases):
pass
def construct(cls, *args, **kwargs): def construct(cls, *args, **kwargs):
constructed = super(ScalarTypeMeta, cls).construct(*args, **kwargs) constructed = super(ScalarTypeMeta, cls).construct(*args, **kwargs)
if not cls._meta.graphql_type and not cls._meta.abstract: if not cls._meta.graphql_type and not cls._meta.abstract:

View File

@ -38,25 +38,6 @@ def test_not_source_and_resolver():
assert "You cannot have a source and a resolver at the same time" == str(excinfo.value) assert "You cannot have a source and a resolver at the same time" == str(excinfo.value)
def test_contributed_field_objecttype():
class MyObject(ObjectType):
pass
field = Field(GraphQLString)
field.contribute_to_class(MyObject, 'field_name')
assert field.name == 'fieldName'
def test_contributed_field_non_objecttype():
class MyObject(object):
pass
field = Field(GraphQLString)
with pytest.raises(AssertionError):
field.contribute_to_class(MyObject, 'field_name')
def test_copy_field_works(): def test_copy_field_works():
field = Field(GraphQLString) field = Field(GraphQLString)
copy.copy(field) copy.copy(field)

View File

@ -81,4 +81,4 @@ def test_interface_add_fields_in_reused_graphql_type():
class Meta: class Meta:
graphql_type = MyGraphQLType graphql_type = MyGraphQLType
assert """Field "MyGraphQLType.field" can only be mounted in ObjectType or Interface, received GrapheneInterface.""" == str(excinfo.value) assert """Can't mount Fields in an Interface with a defined graphql_type""" == str(excinfo.value)

View File

@ -10,8 +10,8 @@ from ..field import Field
class Container(ObjectType): class Container(ObjectType):
field1 = Field(GraphQLString) field1 = Field(GraphQLString, name='field1')
field2 = Field(GraphQLString) field2 = Field(GraphQLString, name='field2')
def test_generate_objecttype(): def test_generate_objecttype():
@ -53,116 +53,152 @@ def test_generate_objecttype_with_fields():
assert 'field' in fields assert 'field' in fields
def test_ordered_fields_in_objecttype():
class MyObjectType(ObjectType):
b = Field(GraphQLString)
a = Field(GraphQLString)
field = Field(GraphQLString)
asa = Field(GraphQLString)
graphql_type = MyObjectType._meta.graphql_type
fields = graphql_type.get_fields()
assert fields.keys() == ['b', 'a', 'field', 'asa']
def test_objecttype_inheritance(): def test_objecttype_inheritance():
class MyInheritedObjectType(ObjectType): class MyInheritedObjectType(ObjectType):
inherited = Field(GraphQLString) inherited = Field(GraphQLString)
class MyObjectType(MyInheritedObjectType): class MyObjectType(MyInheritedObjectType):
field = Field(GraphQLString) field1 = Field(GraphQLString)
field2 = Field(GraphQLString)
graphql_type = MyObjectType._meta.graphql_type graphql_type = MyObjectType._meta.graphql_type
fields = graphql_type.get_fields() fields = graphql_type.get_fields()
assert 'field' in fields assert 'field1' in fields
assert 'field2' in fields
assert 'inherited' in fields assert 'inherited' in fields
assert fields['field'] > fields['inherited'] assert fields['field1'] > fields['inherited']
assert fields['field2'] > fields['field1']
def test_objecttype_as_container_only_args(): def test_objecttype_as_container_get_fields():
container = Container("1", "2")
assert container.field1 == "1" class Container(ObjectType):
assert container.field2 == "2" field1 = Field(GraphQLString)
field2 = Field(GraphQLString)
assert Container._meta.graphql_type.get_fields().keys() == ['field1', 'field2']
def test_objecttype_as_container_args_kwargs(): def test_parent_container_get_fields():
container = Container("1", field2="2") fields = Container._meta.graphql_type.get_fields()
assert container.field1 == "1" print [(f.creation_counter, f.name) for f in fields.values()]
assert container.field2 == "2" assert fields.keys() == ['field1', 'field2']
def test_objecttype_as_container_few_kwargs(): # def test_objecttype_as_container_only_args():
container = Container(field2="2") # container = Container("1", "2")
assert container.field2 == "2" # assert container.field1 == "1"
# assert container.field2 == "2"
def test_objecttype_as_container_all_kwargs(): # def test_objecttype_as_container_args_kwargs():
container = Container(field1="1", field2="2") # container = Container("1", field2="2")
assert container.field1 == "1" # assert container.field1 == "1"
assert container.field2 == "2" # assert container.field2 == "2"
def test_objecttype_as_container_extra_args(): # def test_objecttype_as_container_few_kwargs():
with pytest.raises(IndexError) as excinfo: # container = Container(field2="2")
Container("1", "2", "3") # assert container.field2 == "2"
assert "Number of args exceeds number of fields" == str(excinfo.value)
def test_objecttype_as_container_invalid_kwargs(): # def test_objecttype_as_container_all_kwargs():
with pytest.raises(TypeError) as excinfo: # container = Container(field1="1", field2="2")
Container(unexisting_field="3") # assert container.field1 == "1"
# assert container.field2 == "2"
assert "'unexisting_field' is an invalid keyword argument for this function" == str(excinfo.value)
def test_objecttype_reuse_graphql_type(): # def test_objecttype_as_container_extra_args():
MyGraphQLType = GraphQLObjectType('MyGraphQLType', fields={ # with pytest.raises(IndexError) as excinfo:
'field': GraphQLField(GraphQLString) # Container("1", "2", "3")
})
class GrapheneObjectType(ObjectType): # assert "Number of args exceeds number of fields" == str(excinfo.value)
class Meta:
graphql_type = MyGraphQLType
graphql_type = GrapheneObjectType._meta.graphql_type
assert graphql_type == MyGraphQLType
instance = GrapheneObjectType(field="A")
assert instance.field == "A"
def test_objecttype_add_fields_in_reused_graphql_type(): # def test_objecttype_as_container_invalid_kwargs():
MyGraphQLType = GraphQLObjectType('MyGraphQLType', fields={ # with pytest.raises(TypeError) as excinfo:
'field': GraphQLField(GraphQLString) # Container(unexisting_field="3")
})
with pytest.raises(AssertionError) as excinfo: # assert "'unexisting_field' is an invalid keyword argument for this function" == str(excinfo.value)
class GrapheneObjectType(ObjectType):
field = Field(GraphQLString)
class Meta:
graphql_type = MyGraphQLType
assert """Field "MyGraphQLType.field" can only be mounted in ObjectType or Interface, received GrapheneObjectType.""" == str(excinfo.value)
def test_objecttype_graphql_interface(): # def test_objecttype_reuse_graphql_type():
MyInterface = GraphQLInterfaceType('MyInterface', fields={ # MyGraphQLType = GraphQLObjectType('MyGraphQLType', fields={
'field': GraphQLField(GraphQLString) # 'field': GraphQLField(GraphQLString)
}) # })
class GrapheneObjectType(ObjectType): # class GrapheneObjectType(ObjectType):
class Meta: # class Meta:
interfaces = [MyInterface] # graphql_type = MyGraphQLType
graphql_type = GrapheneObjectType._meta.graphql_type # graphql_type = GrapheneObjectType._meta.graphql_type
assert graphql_type.get_interfaces() == (MyInterface, ) # assert graphql_type == MyGraphQLType
# assert graphql_type.is_type_of(MyInterface, None, None) # instance = GrapheneObjectType(field="A")
fields = graphql_type.get_fields() # assert instance.field == "A"
assert 'field' in fields
def test_objecttype_graphene_interface(): # def test_objecttype_add_fields_in_reused_graphql_type():
class GrapheneInterface(Interface): # MyGraphQLType = GraphQLObjectType('MyGraphQLType', fields={
field = Field(GraphQLString) # 'field': GraphQLField(GraphQLString)
# })
class GrapheneObjectType(ObjectType): # with pytest.raises(AssertionError) as excinfo:
class Meta: # class GrapheneObjectType(ObjectType):
interfaces = [GrapheneInterface] # field = Field(GraphQLString)
graphql_type = GrapheneObjectType._meta.graphql_type # class Meta:
assert graphql_type.get_interfaces() == (GrapheneInterface._meta.graphql_type, ) # graphql_type = MyGraphQLType
assert graphql_type.is_type_of(GrapheneObjectType(), None, None)
fields = graphql_type.get_fields() # assert """Field "MyGraphQLType.field" can only be mounted in ObjectType or Interface, received GrapheneObjectType.""" == str(excinfo.value)
assert 'field' in fields
# def test_objecttype_graphql_interface():
# MyInterface = GraphQLInterfaceType('MyInterface', fields={
# 'field': GraphQLField(GraphQLString)
# })
# class GrapheneObjectType(ObjectType):
# class Meta:
# interfaces = [MyInterface]
# graphql_type = GrapheneObjectType._meta.graphql_type
# assert graphql_type.get_interfaces() == (MyInterface, )
# # assert graphql_type.is_type_of(MyInterface, None, None)
# fields = graphql_type.get_fields()
# assert 'field' in fields
# def test_objecttype_graphene_interface():
# class GrapheneInterface(Interface):
# name = Field(GraphQLString)
# extended = Field(GraphQLString)
# class GrapheneObjectType(ObjectType):
# class Meta:
# interfaces = [GrapheneInterface]
# field = Field(GraphQLString)
# graphql_type = GrapheneObjectType._meta.graphql_type
# assert graphql_type.get_interfaces() == (GrapheneInterface._meta.graphql_type, )
# assert graphql_type.is_type_of(GrapheneObjectType(), None, None)
# fields = graphql_type.get_fields()
# assert 'field' in fields
# assert 'extended' in fields
# assert 'name' in fields
# assert fields['field'] > fields['extended'] > fields['name']
# def test_objecttype_graphene_interface_extended(): # def test_objecttype_graphene_interface_extended():

View File

@ -21,16 +21,11 @@ def test_options_contribute_to_class():
overwritten = True overwritten = True
accepted = True accepted = True
class MyObject(object):
pass
options = Options(Meta, attr=True, overwritten=False) options = Options(Meta, attr=True, overwritten=False)
options.valid_attrs = ['accepted', 'overwritten'] options.valid_attrs = ['accepted', 'overwritten']
assert options.attr assert options.attr
assert not options.overwritten assert not options.overwritten
options.contribute_to_class(MyObject, '_meta')
assert MyObject._meta == options
assert options.parent == MyObject
def test_options_invalid_attrs(): def test_options_invalid_attrs():
@ -41,9 +36,10 @@ def test_options_invalid_attrs():
pass pass
options = Options(Meta, valid=True) options = Options(Meta, valid=True)
options.parent = MyObject
options.valid_attrs = ['valid'] options.valid_attrs = ['valid']
assert options.valid assert options.valid
with pytest.raises(TypeError) as excinfo: with pytest.raises(TypeError) as excinfo:
options.contribute_to_class(MyObject, '_meta') options.validate_attrs()
assert "MyObject.Meta got invalid attributes: invalid" == str(excinfo.value) assert "MyObject.Meta got invalid attributes: invalid" == str(excinfo.value)

View File

@ -23,22 +23,27 @@ scalar_classes = {
@pytest.mark.parametrize("scalar_class,expected_graphql_type", scalar_classes.items()) @pytest.mark.parametrize("scalar_class,expected_graphql_type", scalar_classes.items())
def test_scalar_as_field(scalar_class, expected_graphql_type): def test_scalar_as_field(scalar_class, expected_graphql_type):
field_before = Field(None)
scalar = scalar_class() scalar = scalar_class()
field = scalar.as_field() field = scalar.as_field()
graphql_type = get_graphql_type(scalar_class) graphql_type = get_graphql_type(scalar_class)
field_after = Field(None)
assert isinstance(field, Field) assert isinstance(field, Field)
assert field.type == graphql_type assert field.type == graphql_type
assert graphql_type == expected_graphql_type assert graphql_type == expected_graphql_type
assert field_before < field < field_after
@pytest.mark.parametrize("scalar_class,graphql_type", scalar_classes.items()) @pytest.mark.parametrize("scalar_class,graphql_type", scalar_classes.items())
def test_scalar_in_objecttype(scalar_class, graphql_type): def test_scalar_in_objecttype(scalar_class, graphql_type):
class MyObjectType(ObjectType): class MyObjectType(ObjectType):
before = Field(scalar_class)
field = scalar_class() field = scalar_class()
after = Field(scalar_class)
graphql_type = get_graphql_type(MyObjectType) graphql_type = get_graphql_type(MyObjectType)
fields = graphql_type.get_fields() fields = graphql_type.get_fields()
assert 'field' in fields assert fields.keys() == ['before', 'field', 'after']
assert isinstance(fields['field'], Field) assert isinstance(fields['field'], Field)

View File

@ -16,8 +16,11 @@ class Pet(ObjectType):
type = String() type = String()
@implements(Character) # @implements(Character)
class Human(ObjectType): class Human(ObjectType):
class Meta:
interfaces = [Character]
pet = Field(Pet) pet = Field(Pet)
def resolve_pet(self, *args): def resolve_pet(self, *args):

View File

@ -4,6 +4,7 @@ from graphql import GraphQLString, GraphQLList, GraphQLNonNull
from ..structures import List, NonNull from ..structures import List, NonNull
from ..scalars import String from ..scalars import String
from ..field import Field
def test_list(): def test_list():
@ -42,3 +43,10 @@ def test_nonnull_list():
assert isinstance(list_instance, GraphQLNonNull) assert isinstance(list_instance, GraphQLNonNull)
assert isinstance(list_instance.of_type, GraphQLList) assert isinstance(list_instance.of_type, GraphQLList)
assert list_instance.of_type.of_type == GraphQLString assert list_instance.of_type.of_type == GraphQLString
def test_preserve_order():
field1 = List(lambda: None)
field2 = Field(lambda: None)
assert field1 < field2

View File

View File

@ -15,6 +15,7 @@ def extract_fields(cls, attrs):
continue continue
field = value.as_mounted(cls) if is_field_proxy else copy.copy(value) field = value.as_mounted(cls) if is_field_proxy else copy.copy(value)
field.attname = attname field.attname = attname
field.parent = cls
fields.add(attname) fields.add(attname)
del attrs[attname] del attrs[attname]
_fields.append(field) _fields.append(field)
@ -24,10 +25,15 @@ def extract_fields(cls, attrs):
def get_base_fields(cls, bases): def get_base_fields(cls, bases):
fields = set() fields = set()
_fields = list()
for _class in bases: for _class in bases:
for attname, field in get_graphql_type(_class).get_fields().items(): for attname, field in get_graphql_type(_class).get_fields().items():
if attname in fields: if attname in fields:
continue continue
field = copy.copy(field) field = copy.copy(field)
if isinstance(field, Field):
field.parent = cls
fields.add(attname) fields.add(attname)
yield field _fields.append(field)
return sorted(_fields)

View File

@ -3,7 +3,7 @@ from functools import total_ordering
@total_ordering @total_ordering
class OrderedType(object): class OrderedType(object):
creation_counter = 0 creation_counter = 1
def __init__(self, _creation_counter=None): def __init__(self, _creation_counter=None):
self.creation_counter = _creation_counter or self.gen_counter() self.creation_counter = _creation_counter or self.gen_counter()