diff --git a/graphene/new_types/abstracttype.py b/graphene/new_types/abstracttype.py index 93243b05..3de3b290 100644 --- a/graphene/new_types/abstracttype.py +++ b/graphene/new_types/abstracttype.py @@ -1,33 +1,30 @@ import six -from collections import OrderedDict from ..utils.is_base_type import is_base_type from .options import Options -from .utils import get_fields_in_type, attrs_without_fields - - -def merge_fields_in_attrs(bases, attrs): - for base in bases: - if not issubclass(base, AbstractType): - continue - for name, field in base._meta.fields.items(): - if name in attrs: - continue - attrs[name] = field - return attrs +from .utils import get_fields_in_type, yank_fields_from_attrs, merge_fields_in_attrs class AbstractTypeMeta(type): def __new__(cls, name, bases, attrs): - options = attrs.get('_meta', Options()) + # Also ensure initialization is only performed for subclasses of + # ObjectType + if not is_base_type(bases, AbstractTypeMeta): + return type.__new__(cls, name, bases, attrs) + + for base in bases: + if not issubclass(base, AbstractType) and issubclass(type(base), AbstractTypeMeta): + # raise Exception('You can only') + return type.__new__(cls, name, bases, attrs) attrs = merge_fields_in_attrs(bases, attrs) fields = get_fields_in_type(cls, attrs) - options.fields = OrderedDict(sorted(fields, key=lambda f: f[1])) + yank_fields_from_attrs(attrs, fields) - attrs = attrs_without_fields(attrs, fields) + options = attrs.get('_meta', Options()) + options.fields = fields cls = type.__new__(cls, name, bases, dict(attrs, _meta=options)) return cls diff --git a/graphene/new_types/interface.py b/graphene/new_types/interface.py index 418f76e6..2c46f814 100644 --- a/graphene/new_types/interface.py +++ b/graphene/new_types/interface.py @@ -1,17 +1,17 @@ import six -from collections import OrderedDict from ..utils.is_base_type import is_base_type from .options import Options -from .utils import get_fields_in_type, attrs_without_fields +from .abstracttype import AbstractTypeMeta +from .utils import get_fields_in_type, yank_fields_from_attrs, merge_fields_in_attrs -class InterfaceMeta(type): +class InterfaceMeta(AbstractTypeMeta): def __new__(cls, name, bases, attrs): # Also ensure initialization is only performed for subclasses of - # ObjectType + # Interface if not is_base_type(bases, InterfaceMeta): return type.__new__(cls, name, bases, attrs) @@ -19,15 +19,14 @@ class InterfaceMeta(type): attrs.pop('Meta', None), name=name, description=attrs.get('__doc__'), + interfaces=(), ) - fields = get_fields_in_type(Interface, attrs) - options.fields = OrderedDict(sorted(fields, key=lambda f: f[1])) + attrs = merge_fields_in_attrs(bases, attrs) + options.fields = get_fields_in_type(cls, attrs) + yank_fields_from_attrs(attrs, options.fields) - attrs = attrs_without_fields(attrs, fields) - cls = super(InterfaceMeta, cls).__new__(cls, name, bases, dict(attrs, _meta=options)) - - return cls + return type.__new__(cls, name, bases, dict(attrs, _meta=options)) class Interface(six.with_metaclass(InterfaceMeta)): diff --git a/graphene/new_types/objecttype.py b/graphene/new_types/objecttype.py index cbff92fd..fa5e6303 100644 --- a/graphene/new_types/objecttype.py +++ b/graphene/new_types/objecttype.py @@ -1,13 +1,13 @@ import six -from collections import OrderedDict from ..utils.is_base_type import is_base_type from .options import Options -from .utils import get_fields_in_type, attrs_without_fields +from .abstracttype import AbstractTypeMeta +from .utils import get_fields_in_type, yank_fields_from_attrs, merge_fields_in_attrs -class ObjectTypeMeta(type): +class ObjectTypeMeta(AbstractTypeMeta): def __new__(cls, name, bases, attrs): # Also ensure initialization is only performed for subclasses of @@ -22,13 +22,11 @@ class ObjectTypeMeta(type): interfaces=(), ) - fields = get_fields_in_type(ObjectType, attrs) - options.fields = OrderedDict(sorted(fields, key=lambda f: f[1])) + attrs = merge_fields_in_attrs(bases, attrs) + options.fields = get_fields_in_type(cls, attrs) + yank_fields_from_attrs(attrs, options.fields) - attrs = attrs_without_fields(attrs, fields) - cls = super(ObjectTypeMeta, cls).__new__(cls, name, bases, dict(attrs, _meta=options)) - - return cls + return type.__new__(cls, name, bases, dict(attrs, _meta=options)) class ObjectType(six.with_metaclass(ObjectTypeMeta)): diff --git a/graphene/new_types/tests/test_abstracttype.py b/graphene/new_types/tests/test_abstracttype.py index 6be6eac3..d6e9069c 100644 --- a/graphene/new_types/tests/test_abstracttype.py +++ b/graphene/new_types/tests/test_abstracttype.py @@ -38,7 +38,8 @@ def test_generate_abstracttype_inheritance(): field2 = UnmountedType(MyType) assert MyAbstractType2._meta.fields.keys() == ['field1', 'field2'] - + assert not hasattr(MyAbstractType1, 'field1') + assert not hasattr(MyAbstractType2, 'field2') # def test_ordered_fields_in_objecttype(): # class MyObjectType(ObjectType): diff --git a/graphene/new_types/tests/test_objecttype.py b/graphene/new_types/tests/test_objecttype.py index d652b165..aefa9292 100644 --- a/graphene/new_types/tests/test_objecttype.py +++ b/graphene/new_types/tests/test_objecttype.py @@ -3,6 +3,7 @@ import pytest from ..field import Field from ..objecttype import ObjectType from ..unmountedtype import UnmountedType +from ..abstracttype import AbstractType class MyType(object): @@ -59,6 +60,27 @@ def test_ordered_fields_in_objecttype(): assert list(MyObjectType._meta.fields.keys()) == ['b', 'a', 'field', 'asa'] +def test_generate_objecttype_inherit_abstracttype(): + class MyAbstractType(AbstractType): + field1 = MyScalar(MyType) + + class MyObjectType(ObjectType, MyAbstractType): + field2 = MyScalar(MyType) + + assert MyObjectType._meta.fields.keys() == ['field1', 'field2'] + assert [type(x) for x in MyObjectType._meta.fields.values()] == [Field, Field] + +def test_generate_objecttype_inherit_abstracttype_reversed(): + class MyAbstractType(AbstractType): + field1 = MyScalar(MyType) + + class MyObjectType(MyAbstractType, ObjectType): + field2 = MyScalar(MyType) + + assert MyObjectType._meta.fields.keys() == ['field1', 'field2'] + assert [type(x) for x in MyObjectType._meta.fields.values()] == [Field, Field] + + def test_generate_objecttype_unmountedtype(): class MyObjectType(ObjectType): field = MyScalar(MyType) diff --git a/graphene/new_types/utils.py b/graphene/new_types/utils.py index 1a1ed78f..6ffbb8d0 100644 --- a/graphene/new_types/utils.py +++ b/graphene/new_types/utils.py @@ -1,7 +1,21 @@ +from collections import OrderedDict + from .unmountedtype import UnmountedType from .field import Field +def merge_fields_in_attrs(bases, attrs): + from ..new_types.abstracttype import AbstractType + for base in bases: + if base == AbstractType or not issubclass(base, AbstractType): + continue + for name, field in base._meta.fields.items(): + if name in attrs: + continue + attrs[name] = field + return attrs + + def unmounted_field_in_type(attname, unmounted_field, type): ''' Mount the UnmountedType dinamically as Field or InputField @@ -12,10 +26,10 @@ def unmounted_field_in_type(attname, unmounted_field, type): ''' # from ..types.inputobjecttype import InputObjectType from ..new_types.objecttype import ObjectTypeMeta - from ..new_types.interface import Interface + from ..new_types.interface import InterfaceMeta from ..new_types.abstracttype import AbstractTypeMeta - if issubclass(type, (ObjectTypeMeta, Interface)): + if issubclass(type, (ObjectTypeMeta, InterfaceMeta)): return unmounted_field.as_field() elif issubclass(type, (AbstractTypeMeta)): @@ -31,12 +45,23 @@ def unmounted_field_in_type(attname, unmounted_field, type): def get_fields_in_type(in_type, attrs): + fields_with_names = [] for attname, value in list(attrs.items()): if isinstance(value, (Field)): # , InputField - yield attname, value + fields_with_names.append( + (attname, value) + ) elif isinstance(value, UnmountedType): - yield attname, unmounted_field_in_type(attname, value, in_type) + fields_with_names.append( + (attname, unmounted_field_in_type(attname, value, in_type)) + ) + + return OrderedDict(sorted(fields_with_names, key=lambda f: f[1])) -def attrs_without_fields(attrs, fields): - return {k: v for k, v in attrs.items() if k not in fields} +def yank_fields_from_attrs(attrs, fields): + for name, field in fields.items(): + # attrs.pop(name, None) + del attrs[name] + # return attrs + # return {k: v for k, v in attrs.items() if k not in fields}