From 9f655d9416d74b6c85b396e7234e6f917df9a499 Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Tue, 7 Jun 2016 21:58:06 -0700 Subject: [PATCH] Improved interface fields getter --- graphene/relay/node.py | 21 +++++------ graphene/types/interface.py | 20 +++++++++- graphene/utils/extract_fields.py | 35 ++++++++++++++++++ graphene/utils/tests/test_extract_fields.py | 41 +++++++++++++++++++++ 4 files changed, 105 insertions(+), 12 deletions(-) create mode 100644 graphene/utils/extract_fields.py create mode 100644 graphene/utils/tests/test_extract_fields.py diff --git a/graphene/relay/node.py b/graphene/relay/node.py index c081e223..ea3d9e5f 100644 --- a/graphene/relay/node.py +++ b/graphene/relay/node.py @@ -12,17 +12,16 @@ class NodeMeta(InterfaceTypeMeta): def construct_graphql_type(cls, bases): pass - def construct(cls, *args, **kwargs): - constructed = super(NodeMeta, cls).construct(*args, **kwargs) - if not cls._meta.graphql_type: - node_interface, node_field = node_definitions( - cls.get_node, - interface_class=partial(GrapheneInterfaceType, graphene_type=cls), - field_class=Field - ) - cls._meta.graphql_type = node_interface - cls._Field = node_field - return constructed + def construct(cls, bases, attrs): + cls.get_node = attrs.pop('get_node') + node_interface, node_field = node_definitions( + cls.get_node, + interface_class=partial(GrapheneInterfaceType, graphene_type=cls), + field_class=Field + ) + cls._meta.graphql_type = node_interface + cls._Field = node_field + return super(NodeMeta, cls).construct(bases, attrs) @property def Field(cls): diff --git a/graphene/types/interface.py b/graphene/types/interface.py index 9462db30..6c24bdd5 100644 --- a/graphene/types/interface.py +++ b/graphene/types/interface.py @@ -1,3 +1,6 @@ +from itertools import chain +from functools import partial +from collections import OrderedDict import six from graphql import GraphQLInterfaceType @@ -20,18 +23,33 @@ class InterfaceTypeMeta(ClassTypeMeta): ) def construct_graphql_type(cls, bases): + pass + + def _build_field_map(cls, local_fields, bases): + from ..utils.extract_fields import get_base_fields + extended_fields = get_base_fields(bases) + fields = chain(extended_fields, local_fields) + return OrderedDict((f.name, f) for f in fields) + + def construct(cls, bases, attrs): if not cls._meta.graphql_type and not cls._meta.abstract: from ..utils.is_graphene_type import is_graphene_type + from ..utils.extract_fields import extract_fields + inherited_types = [ base._meta.graphql_type for base in bases if is_graphene_type(base) ] + inherited_types = filter(None, inherited_types) + + local_fields = list(extract_fields(attrs)) cls._meta.graphql_type = GrapheneInterfaceType( graphene_type=cls, name=cls._meta.name or cls.__name__, description=cls._meta.description or cls.__doc__, - fields=FieldMap(cls, bases=filter(None, inherited_types)), + fields=partial(cls._build_field_map, local_fields, inherited_types), ) + return super(InterfaceTypeMeta, cls).construct(bases, attrs) class Interface(six.with_metaclass(InterfaceTypeMeta)): diff --git a/graphene/utils/extract_fields.py b/graphene/utils/extract_fields.py new file mode 100644 index 00000000..0dfe79e7 --- /dev/null +++ b/graphene/utils/extract_fields.py @@ -0,0 +1,35 @@ +import copy +from .get_graphql_type import get_graphql_type + +from ..types.field import Field +from ..types.proxy import TypeProxy + + +def extract_fields(attrs): + fields = set() + _fields = list() + for attname, value in list(attrs.items()): + is_field = isinstance(value, Field) + is_field_proxy = isinstance(value, TypeProxy) + if not (is_field or is_field_proxy): + continue + + field = value.as_field() if is_field_proxy else copy.copy(value) + field.attname = attname + fields.add(attname) + del attrs[attname] + _fields.append(field) + + return sorted(_fields) + + +def get_base_fields(bases): + fields = set() + for _class in bases: + for attname, field in get_graphql_type(_class).get_fields().items(): + if attname in fields: + continue + field = copy.copy(field) + field.name = attname + fields.add(attname) + yield field diff --git a/graphene/utils/tests/test_extract_fields.py b/graphene/utils/tests/test_extract_fields.py new file mode 100644 index 00000000..b73f565b --- /dev/null +++ b/graphene/utils/tests/test_extract_fields.py @@ -0,0 +1,41 @@ +from collections import OrderedDict +from graphql import GraphQLField, GraphQLString, GraphQLInterfaceType, GraphQLInt, GraphQLFloat +from ..extract_fields import extract_fields, get_base_fields + +from ...types import Field, String, Argument + + +def test_extract_fields_attrs(): + attrs = { + 'field_string': Field(String), + 'string': String(), + 'other': None, + 'argument': Argument(String), + 'graphql_field': GraphQLField(GraphQLString) + } + extracted_fields = list(extract_fields(attrs)) + assert [f.name for f in extracted_fields] == ['fieldString', 'string'] + assert sorted(attrs.keys()) == ['argument', 'graphql_field', 'other'] + + +def test_extract_fields(): + int_base = GraphQLInterfaceType('IntInterface', fields=OrderedDict([ + ('int', GraphQLField(GraphQLInt)), + ('num', GraphQLField(GraphQLInt)), + ('extra', GraphQLField(GraphQLInt)) + ])) + float_base = GraphQLInterfaceType('IntInterface', fields=OrderedDict([ + ('float', GraphQLField(GraphQLFloat)), + ('num', GraphQLField(GraphQLFloat)), + ('extra', GraphQLField(GraphQLFloat)) + ])) + + bases = (int_base, float_base) + base_fields = list(get_base_fields(bases)) + assert [f.name for f in base_fields] == ['int', 'num', 'extra', 'float'] + assert [f.type for f in base_fields] == [ + GraphQLInt, + GraphQLInt, + GraphQLInt, + GraphQLFloat, + ]