Improved interface fields getter

This commit is contained in:
Syrus Akbary 2016-06-07 21:58:06 -07:00
parent 01190fb6ff
commit 9f655d9416
4 changed files with 105 additions and 12 deletions

View File

@ -12,17 +12,16 @@ class NodeMeta(InterfaceTypeMeta):
def construct_graphql_type(cls, bases):
pass
def construct(cls, *args, **kwargs):
constructed = super(NodeMeta, cls).construct(*args, **kwargs)
if not cls._meta.graphql_type:
node_interface, node_field = node_definitions(
cls.get_node,
interface_class=partial(GrapheneInterfaceType, graphene_type=cls),
field_class=Field
)
cls._meta.graphql_type = node_interface
cls._Field = node_field
return constructed
def construct(cls, bases, attrs):
cls.get_node = attrs.pop('get_node')
node_interface, node_field = node_definitions(
cls.get_node,
interface_class=partial(GrapheneInterfaceType, graphene_type=cls),
field_class=Field
)
cls._meta.graphql_type = node_interface
cls._Field = node_field
return super(NodeMeta, cls).construct(bases, attrs)
@property
def Field(cls):

View File

@ -1,3 +1,6 @@
from itertools import chain
from functools import partial
from collections import OrderedDict
import six
from graphql import GraphQLInterfaceType
@ -20,18 +23,33 @@ class InterfaceTypeMeta(ClassTypeMeta):
)
def construct_graphql_type(cls, bases):
pass
def _build_field_map(cls, local_fields, bases):
from ..utils.extract_fields import get_base_fields
extended_fields = get_base_fields(bases)
fields = chain(extended_fields, local_fields)
return OrderedDict((f.name, f) for f in fields)
def construct(cls, bases, attrs):
if not cls._meta.graphql_type and not cls._meta.abstract:
from ..utils.is_graphene_type import is_graphene_type
from ..utils.extract_fields import extract_fields
inherited_types = [
base._meta.graphql_type for base in bases if is_graphene_type(base)
]
inherited_types = filter(None, inherited_types)
local_fields = list(extract_fields(attrs))
cls._meta.graphql_type = GrapheneInterfaceType(
graphene_type=cls,
name=cls._meta.name or cls.__name__,
description=cls._meta.description or cls.__doc__,
fields=FieldMap(cls, bases=filter(None, inherited_types)),
fields=partial(cls._build_field_map, local_fields, inherited_types),
)
return super(InterfaceTypeMeta, cls).construct(bases, attrs)
class Interface(six.with_metaclass(InterfaceTypeMeta)):

View File

@ -0,0 +1,35 @@
import copy
from .get_graphql_type import get_graphql_type
from ..types.field import Field
from ..types.proxy import TypeProxy
def extract_fields(attrs):
fields = set()
_fields = list()
for attname, value in list(attrs.items()):
is_field = isinstance(value, Field)
is_field_proxy = isinstance(value, TypeProxy)
if not (is_field or is_field_proxy):
continue
field = value.as_field() if is_field_proxy else copy.copy(value)
field.attname = attname
fields.add(attname)
del attrs[attname]
_fields.append(field)
return sorted(_fields)
def get_base_fields(bases):
fields = set()
for _class in bases:
for attname, field in get_graphql_type(_class).get_fields().items():
if attname in fields:
continue
field = copy.copy(field)
field.name = attname
fields.add(attname)
yield field

View File

@ -0,0 +1,41 @@
from collections import OrderedDict
from graphql import GraphQLField, GraphQLString, GraphQLInterfaceType, GraphQLInt, GraphQLFloat
from ..extract_fields import extract_fields, get_base_fields
from ...types import Field, String, Argument
def test_extract_fields_attrs():
attrs = {
'field_string': Field(String),
'string': String(),
'other': None,
'argument': Argument(String),
'graphql_field': GraphQLField(GraphQLString)
}
extracted_fields = list(extract_fields(attrs))
assert [f.name for f in extracted_fields] == ['fieldString', 'string']
assert sorted(attrs.keys()) == ['argument', 'graphql_field', 'other']
def test_extract_fields():
int_base = GraphQLInterfaceType('IntInterface', fields=OrderedDict([
('int', GraphQLField(GraphQLInt)),
('num', GraphQLField(GraphQLInt)),
('extra', GraphQLField(GraphQLInt))
]))
float_base = GraphQLInterfaceType('IntInterface', fields=OrderedDict([
('float', GraphQLField(GraphQLFloat)),
('num', GraphQLField(GraphQLFloat)),
('extra', GraphQLField(GraphQLFloat))
]))
bases = (int_base, float_base)
base_fields = list(get_base_fields(bases))
assert [f.name for f in base_fields] == ['int', 'num', 'extra', 'float']
assert [f.type for f in base_fields] == [
GraphQLInt,
GraphQLInt,
GraphQLInt,
GraphQLFloat,
]