Improved interface fields getter

This commit is contained in:
Syrus Akbary 2016-06-07 21:58:06 -07:00
parent 01190fb6ff
commit 9f655d9416
4 changed files with 105 additions and 12 deletions

View File

@ -12,17 +12,16 @@ class NodeMeta(InterfaceTypeMeta):
def construct_graphql_type(cls, bases): def construct_graphql_type(cls, bases):
pass pass
def construct(cls, *args, **kwargs): def construct(cls, bases, attrs):
constructed = super(NodeMeta, cls).construct(*args, **kwargs) cls.get_node = attrs.pop('get_node')
if not cls._meta.graphql_type: node_interface, node_field = node_definitions(
node_interface, node_field = node_definitions( cls.get_node,
cls.get_node, interface_class=partial(GrapheneInterfaceType, graphene_type=cls),
interface_class=partial(GrapheneInterfaceType, graphene_type=cls), field_class=Field
field_class=Field )
) cls._meta.graphql_type = node_interface
cls._meta.graphql_type = node_interface cls._Field = node_field
cls._Field = node_field return super(NodeMeta, cls).construct(bases, attrs)
return constructed
@property @property
def Field(cls): def Field(cls):

View File

@ -1,3 +1,6 @@
from itertools import chain
from functools import partial
from collections import OrderedDict
import six import six
from graphql import GraphQLInterfaceType from graphql import GraphQLInterfaceType
@ -20,18 +23,33 @@ class InterfaceTypeMeta(ClassTypeMeta):
) )
def construct_graphql_type(cls, bases): def construct_graphql_type(cls, bases):
pass
def _build_field_map(cls, local_fields, bases):
from ..utils.extract_fields import get_base_fields
extended_fields = get_base_fields(bases)
fields = chain(extended_fields, local_fields)
return OrderedDict((f.name, f) for f in fields)
def construct(cls, bases, attrs):
if not cls._meta.graphql_type and not cls._meta.abstract: if not cls._meta.graphql_type and not cls._meta.abstract:
from ..utils.is_graphene_type import is_graphene_type from ..utils.is_graphene_type import is_graphene_type
from ..utils.extract_fields import extract_fields
inherited_types = [ inherited_types = [
base._meta.graphql_type for base in bases if is_graphene_type(base) base._meta.graphql_type for base in bases if is_graphene_type(base)
] ]
inherited_types = filter(None, inherited_types)
local_fields = list(extract_fields(attrs))
cls._meta.graphql_type = GrapheneInterfaceType( cls._meta.graphql_type = GrapheneInterfaceType(
graphene_type=cls, graphene_type=cls,
name=cls._meta.name or cls.__name__, name=cls._meta.name or cls.__name__,
description=cls._meta.description or cls.__doc__, description=cls._meta.description or cls.__doc__,
fields=FieldMap(cls, bases=filter(None, inherited_types)), fields=partial(cls._build_field_map, local_fields, inherited_types),
) )
return super(InterfaceTypeMeta, cls).construct(bases, attrs)
class Interface(six.with_metaclass(InterfaceTypeMeta)): class Interface(six.with_metaclass(InterfaceTypeMeta)):

View File

@ -0,0 +1,35 @@
import copy
from .get_graphql_type import get_graphql_type
from ..types.field import Field
from ..types.proxy import TypeProxy
def extract_fields(attrs):
fields = set()
_fields = list()
for attname, value in list(attrs.items()):
is_field = isinstance(value, Field)
is_field_proxy = isinstance(value, TypeProxy)
if not (is_field or is_field_proxy):
continue
field = value.as_field() if is_field_proxy else copy.copy(value)
field.attname = attname
fields.add(attname)
del attrs[attname]
_fields.append(field)
return sorted(_fields)
def get_base_fields(bases):
fields = set()
for _class in bases:
for attname, field in get_graphql_type(_class).get_fields().items():
if attname in fields:
continue
field = copy.copy(field)
field.name = attname
fields.add(attname)
yield field

View File

@ -0,0 +1,41 @@
from collections import OrderedDict
from graphql import GraphQLField, GraphQLString, GraphQLInterfaceType, GraphQLInt, GraphQLFloat
from ..extract_fields import extract_fields, get_base_fields
from ...types import Field, String, Argument
def test_extract_fields_attrs():
attrs = {
'field_string': Field(String),
'string': String(),
'other': None,
'argument': Argument(String),
'graphql_field': GraphQLField(GraphQLString)
}
extracted_fields = list(extract_fields(attrs))
assert [f.name for f in extracted_fields] == ['fieldString', 'string']
assert sorted(attrs.keys()) == ['argument', 'graphql_field', 'other']
def test_extract_fields():
int_base = GraphQLInterfaceType('IntInterface', fields=OrderedDict([
('int', GraphQLField(GraphQLInt)),
('num', GraphQLField(GraphQLInt)),
('extra', GraphQLField(GraphQLInt))
]))
float_base = GraphQLInterfaceType('IntInterface', fields=OrderedDict([
('float', GraphQLField(GraphQLFloat)),
('num', GraphQLField(GraphQLFloat)),
('extra', GraphQLField(GraphQLFloat))
]))
bases = (int_base, float_base)
base_fields = list(get_base_fields(bases))
assert [f.name for f in base_fields] == ['int', 'num', 'extra', 'float']
assert [f.type for f in base_fields] == [
GraphQLInt,
GraphQLInt,
GraphQLInt,
GraphQLFloat,
]