Improved fields mounting

This commit is contained in:
Syrus Akbary 2016-06-07 22:39:29 -07:00
parent 9f655d9416
commit 25e967200b
6 changed files with 54 additions and 41 deletions

View File

@ -1,6 +1,8 @@
from collections import OrderedDict
import inspect
import copy
from itertools import chain
from functools import partial
from graphql.utils.assert_valid_name import assert_valid_name
from graphql.type.definition import GraphQLObjectType
@ -58,6 +60,26 @@ class ClassTypeMeta(type):
return cls
class FieldsMeta(type):
def _build_field_map(cls, bases, local_fields):
from ..utils.extract_fields import get_base_fields
extended_fields = get_base_fields(cls, bases)
fields = chain(extended_fields, local_fields)
return OrderedDict((f.name, f) for f in fields)
def _fields(cls, bases, attrs):
from ..utils.is_graphene_type import is_graphene_type
from ..utils.extract_fields import extract_fields
inherited_types = [
base._meta.graphql_type for base in bases if is_graphene_type(base) and not base._meta.abstract
]
local_fields = extract_fields(cls, attrs)
return partial(cls._build_field_map, inherited_types, local_fields)
class GrapheneGraphQLType(object):
def __init__(self, *args, **kwargs):
self.graphene_type = kwargs.pop('graphene_type')

View File

@ -2,7 +2,7 @@ import six
from graphql import GraphQLInputObjectType
from .definitions import ClassTypeMeta, GrapheneFieldsType, FieldMap
from .definitions import FieldsMeta, ClassTypeMeta, GrapheneFieldsType
from .proxy import TypeProxy
@ -10,7 +10,7 @@ class GrapheneInputObjectType(GrapheneFieldsType, GraphQLInputObjectType):
pass
class InputObjectTypeMeta(ClassTypeMeta):
class InputObjectTypeMeta(FieldsMeta, ClassTypeMeta):
def get_options(cls, meta):
return cls.options_class(
@ -22,20 +22,17 @@ class InputObjectTypeMeta(ClassTypeMeta):
)
def construct_graphql_type(cls, bases):
pass
def construct(cls, bases, attrs):
if not cls._meta.graphql_type and not cls._meta.abstract:
from ..utils.get_graphql_type import get_graphql_type
from ..utils.is_graphene_type import is_graphene_type
inherited_types = [
base._meta.graphql_type for base in bases if is_graphene_type(base)
]
cls._meta.graphql_type = GrapheneInputObjectType(
graphene_type=cls,
name=cls._meta.name or cls.__name__,
description=cls._meta.description or cls.__doc__,
fields=FieldMap(cls, bases=filter(None, inherited_types)),
fields=cls._fields(bases, attrs),
)
return super(InputObjectTypeMeta, cls).construct(bases, attrs)
class InputObjectType(six.with_metaclass(InputObjectTypeMeta, TypeProxy)):

View File

@ -1,17 +1,14 @@
from itertools import chain
from functools import partial
from collections import OrderedDict
import six
from graphql import GraphQLInterfaceType
from .definitions import ClassTypeMeta, GrapheneFieldsType, FieldMap
from .definitions import FieldsMeta, ClassTypeMeta, GrapheneFieldsType
class GrapheneInterfaceType(GrapheneFieldsType, GraphQLInterfaceType):
pass
class InterfaceTypeMeta(ClassTypeMeta):
class InterfaceTypeMeta(FieldsMeta, ClassTypeMeta):
def get_options(cls, meta):
return cls.options_class(
@ -25,29 +22,14 @@ class InterfaceTypeMeta(ClassTypeMeta):
def construct_graphql_type(cls, bases):
pass
def _build_field_map(cls, local_fields, bases):
from ..utils.extract_fields import get_base_fields
extended_fields = get_base_fields(bases)
fields = chain(extended_fields, local_fields)
return OrderedDict((f.name, f) for f in fields)
def construct(cls, bases, attrs):
if not cls._meta.graphql_type and not cls._meta.abstract:
from ..utils.is_graphene_type import is_graphene_type
from ..utils.extract_fields import extract_fields
inherited_types = [
base._meta.graphql_type for base in bases if is_graphene_type(base)
]
inherited_types = filter(None, inherited_types)
local_fields = list(extract_fields(attrs))
cls._meta.graphql_type = GrapheneInterfaceType(
graphene_type=cls,
name=cls._meta.name or cls.__name__,
description=cls._meta.description or cls.__doc__,
fields=partial(cls._build_field_map, local_fields, inherited_types),
fields=cls._fields(bases, attrs),
)
return super(InterfaceTypeMeta, cls).construct(bases, attrs)

View File

@ -64,3 +64,17 @@ class TypeProxy(OrderedType):
raise Exception('TypedProxy "{}" cannot be mounted in {}'.format(self.get_type(), cls))
return inner.contribute_to_class(cls, attname)
def as_mounted(self, cls):
from .inputobjecttype import InputObjectType
from .objecttype import ObjectType
from .interface import Interface
if issubclass(cls, (ObjectType, Interface)):
inner = self.as_field()
elif issubclass(cls, (InputObjectType)):
inner = self.as_inputfield()
else:
raise Exception('TypedProxy "{}" cannot be mounted in {}'.format(self.get_type(), cls))
return inner

View File

@ -1,20 +1,19 @@
import copy
from .get_graphql_type import get_graphql_type
from ..types.field import Field
from ..types.field import Field, InputField
from ..types.proxy import TypeProxy
def extract_fields(attrs):
def extract_fields(cls, attrs):
fields = set()
_fields = list()
for attname, value in list(attrs.items()):
is_field = isinstance(value, Field)
is_field = isinstance(value, (Field, InputField))
is_field_proxy = isinstance(value, TypeProxy)
if not (is_field or is_field_proxy):
continue
field = value.as_field() if is_field_proxy else copy.copy(value)
field = value.as_mounted(cls) if is_field_proxy else copy.copy(value)
field.attname = attname
fields.add(attname)
del attrs[attname]
@ -23,13 +22,12 @@ def extract_fields(attrs):
return sorted(_fields)
def get_base_fields(bases):
def get_base_fields(cls, bases):
fields = set()
for _class in bases:
for attname, field in get_graphql_type(_class).get_fields().items():
if attname in fields:
continue
field = copy.copy(field)
field.name = attname
fields.add(attname)
yield field

View File

@ -2,7 +2,7 @@ from collections import OrderedDict
from graphql import GraphQLField, GraphQLString, GraphQLInterfaceType, GraphQLInt, GraphQLFloat
from ..extract_fields import extract_fields, get_base_fields
from ...types import Field, String, Argument
from ...types import Field, String, Argument, ObjectType
def test_extract_fields_attrs():
@ -13,7 +13,7 @@ def test_extract_fields_attrs():
'argument': Argument(String),
'graphql_field': GraphQLField(GraphQLString)
}
extracted_fields = list(extract_fields(attrs))
extracted_fields = list(extract_fields(ObjectType, attrs))
assert [f.name for f in extracted_fields] == ['fieldString', 'string']
assert sorted(attrs.keys()) == ['argument', 'graphql_field', 'other']
@ -31,7 +31,7 @@ def test_extract_fields():
]))
bases = (int_base, float_base)
base_fields = list(get_base_fields(bases))
base_fields = list(get_base_fields(ObjectType, bases))
assert [f.name for f in base_fields] == ['int', 'num', 'extra', 'float']
assert [f.type for f in base_fields] == [
GraphQLInt,