diff --git a/graphene/types/definitions.py b/graphene/types/definitions.py index 992bb43a..0980f655 100644 --- a/graphene/types/definitions.py +++ b/graphene/types/definitions.py @@ -1,6 +1,8 @@ from collections import OrderedDict import inspect import copy +from itertools import chain +from functools import partial from graphql.utils.assert_valid_name import assert_valid_name from graphql.type.definition import GraphQLObjectType @@ -58,6 +60,26 @@ class ClassTypeMeta(type): return cls +class FieldsMeta(type): + + def _build_field_map(cls, bases, local_fields): + from ..utils.extract_fields import get_base_fields + extended_fields = get_base_fields(cls, bases) + fields = chain(extended_fields, local_fields) + return OrderedDict((f.name, f) for f in fields) + + def _fields(cls, bases, attrs): + from ..utils.is_graphene_type import is_graphene_type + from ..utils.extract_fields import extract_fields + + inherited_types = [ + base._meta.graphql_type for base in bases if is_graphene_type(base) and not base._meta.abstract + ] + + local_fields = extract_fields(cls, attrs) + return partial(cls._build_field_map, inherited_types, local_fields) + + class GrapheneGraphQLType(object): def __init__(self, *args, **kwargs): self.graphene_type = kwargs.pop('graphene_type') diff --git a/graphene/types/inputobjecttype.py b/graphene/types/inputobjecttype.py index 2fe41e7d..671611a5 100644 --- a/graphene/types/inputobjecttype.py +++ b/graphene/types/inputobjecttype.py @@ -2,7 +2,7 @@ import six from graphql import GraphQLInputObjectType -from .definitions import ClassTypeMeta, GrapheneFieldsType, FieldMap +from .definitions import FieldsMeta, ClassTypeMeta, GrapheneFieldsType from .proxy import TypeProxy @@ -10,7 +10,7 @@ class GrapheneInputObjectType(GrapheneFieldsType, GraphQLInputObjectType): pass -class InputObjectTypeMeta(ClassTypeMeta): +class InputObjectTypeMeta(FieldsMeta, ClassTypeMeta): def get_options(cls, meta): return cls.options_class( @@ -22,20 +22,17 @@ class InputObjectTypeMeta(ClassTypeMeta): ) def construct_graphql_type(cls, bases): + pass + + def construct(cls, bases, attrs): if not cls._meta.graphql_type and not cls._meta.abstract: - from ..utils.get_graphql_type import get_graphql_type - from ..utils.is_graphene_type import is_graphene_type - - inherited_types = [ - base._meta.graphql_type for base in bases if is_graphene_type(base) - ] - cls._meta.graphql_type = GrapheneInputObjectType( graphene_type=cls, name=cls._meta.name or cls.__name__, description=cls._meta.description or cls.__doc__, - fields=FieldMap(cls, bases=filter(None, inherited_types)), + fields=cls._fields(bases, attrs), ) + return super(InputObjectTypeMeta, cls).construct(bases, attrs) class InputObjectType(six.with_metaclass(InputObjectTypeMeta, TypeProxy)): diff --git a/graphene/types/interface.py b/graphene/types/interface.py index 6c24bdd5..5400dc06 100644 --- a/graphene/types/interface.py +++ b/graphene/types/interface.py @@ -1,17 +1,14 @@ -from itertools import chain -from functools import partial -from collections import OrderedDict import six from graphql import GraphQLInterfaceType -from .definitions import ClassTypeMeta, GrapheneFieldsType, FieldMap +from .definitions import FieldsMeta, ClassTypeMeta, GrapheneFieldsType class GrapheneInterfaceType(GrapheneFieldsType, GraphQLInterfaceType): pass -class InterfaceTypeMeta(ClassTypeMeta): +class InterfaceTypeMeta(FieldsMeta, ClassTypeMeta): def get_options(cls, meta): return cls.options_class( @@ -25,29 +22,14 @@ class InterfaceTypeMeta(ClassTypeMeta): def construct_graphql_type(cls, bases): pass - def _build_field_map(cls, local_fields, bases): - from ..utils.extract_fields import get_base_fields - extended_fields = get_base_fields(bases) - fields = chain(extended_fields, local_fields) - return OrderedDict((f.name, f) for f in fields) - def construct(cls, bases, attrs): if not cls._meta.graphql_type and not cls._meta.abstract: - from ..utils.is_graphene_type import is_graphene_type - from ..utils.extract_fields import extract_fields - - inherited_types = [ - base._meta.graphql_type for base in bases if is_graphene_type(base) - ] - inherited_types = filter(None, inherited_types) - - local_fields = list(extract_fields(attrs)) cls._meta.graphql_type = GrapheneInterfaceType( graphene_type=cls, name=cls._meta.name or cls.__name__, description=cls._meta.description or cls.__doc__, - fields=partial(cls._build_field_map, local_fields, inherited_types), + fields=cls._fields(bases, attrs), ) return super(InterfaceTypeMeta, cls).construct(bases, attrs) diff --git a/graphene/types/proxy.py b/graphene/types/proxy.py index f66cdde9..b56158fc 100644 --- a/graphene/types/proxy.py +++ b/graphene/types/proxy.py @@ -64,3 +64,17 @@ class TypeProxy(OrderedType): raise Exception('TypedProxy "{}" cannot be mounted in {}'.format(self.get_type(), cls)) return inner.contribute_to_class(cls, attname) + + def as_mounted(self, cls): + from .inputobjecttype import InputObjectType + from .objecttype import ObjectType + from .interface import Interface + + if issubclass(cls, (ObjectType, Interface)): + inner = self.as_field() + elif issubclass(cls, (InputObjectType)): + inner = self.as_inputfield() + else: + raise Exception('TypedProxy "{}" cannot be mounted in {}'.format(self.get_type(), cls)) + + return inner diff --git a/graphene/utils/extract_fields.py b/graphene/utils/extract_fields.py index 0dfe79e7..014c156e 100644 --- a/graphene/utils/extract_fields.py +++ b/graphene/utils/extract_fields.py @@ -1,20 +1,19 @@ import copy from .get_graphql_type import get_graphql_type -from ..types.field import Field +from ..types.field import Field, InputField from ..types.proxy import TypeProxy -def extract_fields(attrs): +def extract_fields(cls, attrs): fields = set() _fields = list() for attname, value in list(attrs.items()): - is_field = isinstance(value, Field) + is_field = isinstance(value, (Field, InputField)) is_field_proxy = isinstance(value, TypeProxy) if not (is_field or is_field_proxy): continue - - field = value.as_field() if is_field_proxy else copy.copy(value) + field = value.as_mounted(cls) if is_field_proxy else copy.copy(value) field.attname = attname fields.add(attname) del attrs[attname] @@ -23,13 +22,12 @@ def extract_fields(attrs): return sorted(_fields) -def get_base_fields(bases): +def get_base_fields(cls, bases): fields = set() for _class in bases: for attname, field in get_graphql_type(_class).get_fields().items(): if attname in fields: continue field = copy.copy(field) - field.name = attname fields.add(attname) yield field diff --git a/graphene/utils/tests/test_extract_fields.py b/graphene/utils/tests/test_extract_fields.py index b73f565b..7ccbf730 100644 --- a/graphene/utils/tests/test_extract_fields.py +++ b/graphene/utils/tests/test_extract_fields.py @@ -2,7 +2,7 @@ from collections import OrderedDict from graphql import GraphQLField, GraphQLString, GraphQLInterfaceType, GraphQLInt, GraphQLFloat from ..extract_fields import extract_fields, get_base_fields -from ...types import Field, String, Argument +from ...types import Field, String, Argument, ObjectType def test_extract_fields_attrs(): @@ -13,7 +13,7 @@ def test_extract_fields_attrs(): 'argument': Argument(String), 'graphql_field': GraphQLField(GraphQLString) } - extracted_fields = list(extract_fields(attrs)) + extracted_fields = list(extract_fields(ObjectType, attrs)) assert [f.name for f in extracted_fields] == ['fieldString', 'string'] assert sorted(attrs.keys()) == ['argument', 'graphql_field', 'other'] @@ -31,7 +31,7 @@ def test_extract_fields(): ])) bases = (int_base, float_base) - base_fields = list(get_base_fields(bases)) + base_fields = list(get_base_fields(ObjectType, bases)) assert [f.name for f in base_fields] == ['int', 'num', 'extra', 'float'] assert [f.type for f in base_fields] == [ GraphQLInt,