diff --git a/graphene/types/field.py b/graphene/types/field.py index b7abaead..25d29de8 100644 --- a/graphene/types/field.py +++ b/graphene/types/field.py @@ -115,7 +115,7 @@ class Field(AbstractField, GraphQLField, OrderedType): return self.copy_and_extend(self) @classmethod - def copy_and_extend(cls, field, type=None, args=None, resolver=None, source=None, deprecation_reason=None, name=None, description=None, required=False, _creation_counter=False, **extra_args): + def copy_and_extend(cls, field, type=None, args=None, resolver=None, source=None, deprecation_reason=None, name=None, description=None, required=False, _creation_counter=False, parent=None, attname=None, **extra_args): if isinstance(field, Field): type = type or field._type resolver = resolver or field._resolver @@ -123,8 +123,8 @@ class Field(AbstractField, GraphQLField, OrderedType): name = name or field._name required = required or field.required _creation_counter = field.creation_counter if _creation_counter is False else None - attname = field.attname - parent = field.parent + attname = attname or field.attname + parent = parent or field.parent else: # If is a GraphQLField type = type or field.type @@ -133,8 +133,8 @@ class Field(AbstractField, GraphQLField, OrderedType): name = field.name required = None _creation_counter = None - attname = name - parent = None + attname = attname or name + parent = parent new_field = cls( type=type, diff --git a/graphene/types/interface.py b/graphene/types/interface.py index 6d51f238..58c3e49e 100644 --- a/graphene/types/interface.py +++ b/graphene/types/interface.py @@ -2,42 +2,74 @@ import six from graphql import GraphQLInterfaceType from .definitions import FieldsMeta, ClassTypeMeta, GrapheneGraphQLType - +from .options import Options +from ..utils.is_base_type import is_base_type class GrapheneInterfaceType(GrapheneGraphQLType, GraphQLInterfaceType): pass -class InterfaceTypeMeta(FieldsMeta, ClassTypeMeta): +class InterfaceTypeMeta(type): - def get_options(cls, meta): - return cls.options_class( - meta, + def __new__(cls, name, bases, attrs): + super_new = super(InterfaceTypeMeta, cls).__new__ + + # Also ensure initialization is only performed for subclasses of Model + # (excluding Model class itself). + if not is_base_type(bases, InterfaceTypeMeta): + return super_new(cls, name, bases, attrs) + + options = Options( + attrs.pop('Meta', None), name=None, description=None, graphql_type=None, abstract=False ) - def construct(cls, bases, attrs): - if not cls._meta.abstract: - local_fields = cls._extract_local_fields(attrs) - if not cls._meta.graphql_type: - cls._meta.graphql_type = GrapheneInterfaceType( - graphene_type=cls, - name=cls._meta.name or cls.__name__, - description=cls._meta.description or cls.__doc__, - fields=cls._fields(bases, attrs, local_fields), - ) - else: - assert not local_fields, "Can't mount Fields in an Interface with a defined graphql_type" + from ..utils.get_fields import get_fields - return super(InterfaceTypeMeta, cls).construct(bases, attrs) + fields = get_fields(Interface, attrs, bases) + attrs = attrs_without_fields(attrs, fields) + cls = super_new(cls, name, bases, dict(attrs, _meta=options)) + + if not options.graphql_type: + fields = copy_fields(fields) + options.graphql_type = GrapheneInterfaceType( + graphene_type=cls, + name=options.name or cls.__name__, + description=options.description or cls.__doc__, + fields=fields, + ) + else: + assert not fields, "Can't mount Fields in an Interface with a defined graphql_type" + fields = copy_fields(options.graphql_type.get_fields()) + + for attname, field in fields.items(): + field.parent = cls + # setattr(cls, field.name, field) + + return cls + + +def attrs_without_fields(attrs, fields): + fields_names = fields.keys() + return {k: v for k, v in attrs.items() if k not in fields_names} + + +def copy_fields(fields): + from collections import OrderedDict + from .field import Field + + _fields = [] + for attname, field in fields.items(): + field = Field.copy_and_extend(field, attname=attname) + _fields.append(field) + + return OrderedDict((f.name, f) for f in sorted(_fields)) class Interface(six.with_metaclass(InterfaceTypeMeta)): - class Meta: - abstract = True def __init__(self, *args, **kwargs): from .objecttype import ObjectType diff --git a/graphene/types/objecttype.py b/graphene/types/objecttype.py index 607c8524..d461aac6 100644 --- a/graphene/types/objecttype.py +++ b/graphene/types/objecttype.py @@ -46,7 +46,7 @@ def get_interfaces(cls, interfaces): yield graphql_type -class ObjectTypeMeta(InterfaceTypeMeta): +class ObjectTypeMeta(FieldsMeta, ClassTypeMeta, InterfaceTypeMeta): def get_options(cls, meta): return cls.options_class( diff --git a/graphene/utils/get_fields.py b/graphene/utils/get_fields.py new file mode 100644 index 00000000..080e6084 --- /dev/null +++ b/graphene/utils/get_fields.py @@ -0,0 +1,48 @@ +from collections import OrderedDict + +from .get_graphql_type import get_graphql_type +from .is_graphene_type import is_graphene_type +from ..types.field import Field, InputField +from ..types.unmountedtype import UnmountedType + + +def get_fields_from_attrs(in_type, attrs): + for attname, value in list(attrs.items()): + is_field = isinstance(value, (Field, InputField)) + is_field_proxy = isinstance(value, UnmountedType) + if not (is_field or is_field_proxy): + continue + field = value.as_mounted(in_type) if is_field_proxy else value + yield attname, field + + +def get_fields_from_types(bases): + fields = set() + for _class in bases: + for attname, field in get_graphql_type(_class).get_fields().items(): + if attname in fields: + continue + fields.add(attname) + yield attname, field + + +def get_fields(in_type, attrs, bases): + fields = [] + + graphene_bases = tuple( + base._meta.graphql_type for base in bases if is_graphene_type(base) and not base._meta.abstract + ) + + extended_fields = list(get_fields_from_types(graphene_bases)) + local_fields = list(get_fields_from_attrs(in_type, attrs)) + + field_names = set(f[0] for f in local_fields) + for name, extended_field in extended_fields: + if name in field_names: + continue + fields.append((name, extended_field)) + field_names.add(name) + + fields.extend(local_fields) + + return OrderedDict(fields) diff --git a/graphene/utils/is_base_type.py b/graphene/utils/is_base_type.py new file mode 100644 index 00000000..0e079f54 --- /dev/null +++ b/graphene/utils/is_base_type.py @@ -0,0 +1,2 @@ +def is_base_type(bases, _type): + return any(b for b in bases if isinstance(b, _type)) diff --git a/graphene/utils/is_graphene_type.py b/graphene/utils/is_graphene_type.py index 89a34232..3a2c9caa 100644 --- a/graphene/utils/is_graphene_type.py +++ b/graphene/utils/is_graphene_type.py @@ -7,6 +7,8 @@ from ..types.enum import Enum def is_graphene_type(_type): + if _type in [Interface]: + return False return inspect.isclass(_type) and issubclass(_type, ( Interface, ObjectType,