Improved fields mounting

This commit is contained in:
Syrus Akbary 2016-06-07 22:39:29 -07:00
parent 9f655d9416
commit 25e967200b
6 changed files with 54 additions and 41 deletions

View File

@ -1,6 +1,8 @@
from collections import OrderedDict from collections import OrderedDict
import inspect import inspect
import copy import copy
from itertools import chain
from functools import partial
from graphql.utils.assert_valid_name import assert_valid_name from graphql.utils.assert_valid_name import assert_valid_name
from graphql.type.definition import GraphQLObjectType from graphql.type.definition import GraphQLObjectType
@ -58,6 +60,26 @@ class ClassTypeMeta(type):
return cls return cls
class FieldsMeta(type):
def _build_field_map(cls, bases, local_fields):
from ..utils.extract_fields import get_base_fields
extended_fields = get_base_fields(cls, bases)
fields = chain(extended_fields, local_fields)
return OrderedDict((f.name, f) for f in fields)
def _fields(cls, bases, attrs):
from ..utils.is_graphene_type import is_graphene_type
from ..utils.extract_fields import extract_fields
inherited_types = [
base._meta.graphql_type for base in bases if is_graphene_type(base) and not base._meta.abstract
]
local_fields = extract_fields(cls, attrs)
return partial(cls._build_field_map, inherited_types, local_fields)
class GrapheneGraphQLType(object): class GrapheneGraphQLType(object):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.graphene_type = kwargs.pop('graphene_type') self.graphene_type = kwargs.pop('graphene_type')

View File

@ -2,7 +2,7 @@ import six
from graphql import GraphQLInputObjectType from graphql import GraphQLInputObjectType
from .definitions import ClassTypeMeta, GrapheneFieldsType, FieldMap from .definitions import FieldsMeta, ClassTypeMeta, GrapheneFieldsType
from .proxy import TypeProxy from .proxy import TypeProxy
@ -10,7 +10,7 @@ class GrapheneInputObjectType(GrapheneFieldsType, GraphQLInputObjectType):
pass pass
class InputObjectTypeMeta(ClassTypeMeta): class InputObjectTypeMeta(FieldsMeta, ClassTypeMeta):
def get_options(cls, meta): def get_options(cls, meta):
return cls.options_class( return cls.options_class(
@ -22,20 +22,17 @@ class InputObjectTypeMeta(ClassTypeMeta):
) )
def construct_graphql_type(cls, bases): def construct_graphql_type(cls, bases):
pass
def construct(cls, bases, attrs):
if not cls._meta.graphql_type and not cls._meta.abstract: if not cls._meta.graphql_type and not cls._meta.abstract:
from ..utils.get_graphql_type import get_graphql_type
from ..utils.is_graphene_type import is_graphene_type
inherited_types = [
base._meta.graphql_type for base in bases if is_graphene_type(base)
]
cls._meta.graphql_type = GrapheneInputObjectType( cls._meta.graphql_type = GrapheneInputObjectType(
graphene_type=cls, graphene_type=cls,
name=cls._meta.name or cls.__name__, name=cls._meta.name or cls.__name__,
description=cls._meta.description or cls.__doc__, description=cls._meta.description or cls.__doc__,
fields=FieldMap(cls, bases=filter(None, inherited_types)), fields=cls._fields(bases, attrs),
) )
return super(InputObjectTypeMeta, cls).construct(bases, attrs)
class InputObjectType(six.with_metaclass(InputObjectTypeMeta, TypeProxy)): class InputObjectType(six.with_metaclass(InputObjectTypeMeta, TypeProxy)):

View File

@ -1,17 +1,14 @@
from itertools import chain
from functools import partial
from collections import OrderedDict
import six import six
from graphql import GraphQLInterfaceType from graphql import GraphQLInterfaceType
from .definitions import ClassTypeMeta, GrapheneFieldsType, FieldMap from .definitions import FieldsMeta, ClassTypeMeta, GrapheneFieldsType
class GrapheneInterfaceType(GrapheneFieldsType, GraphQLInterfaceType): class GrapheneInterfaceType(GrapheneFieldsType, GraphQLInterfaceType):
pass pass
class InterfaceTypeMeta(ClassTypeMeta): class InterfaceTypeMeta(FieldsMeta, ClassTypeMeta):
def get_options(cls, meta): def get_options(cls, meta):
return cls.options_class( return cls.options_class(
@ -25,29 +22,14 @@ class InterfaceTypeMeta(ClassTypeMeta):
def construct_graphql_type(cls, bases): def construct_graphql_type(cls, bases):
pass pass
def _build_field_map(cls, local_fields, bases):
from ..utils.extract_fields import get_base_fields
extended_fields = get_base_fields(bases)
fields = chain(extended_fields, local_fields)
return OrderedDict((f.name, f) for f in fields)
def construct(cls, bases, attrs): def construct(cls, bases, attrs):
if not cls._meta.graphql_type and not cls._meta.abstract: if not cls._meta.graphql_type and not cls._meta.abstract:
from ..utils.is_graphene_type import is_graphene_type
from ..utils.extract_fields import extract_fields
inherited_types = [
base._meta.graphql_type for base in bases if is_graphene_type(base)
]
inherited_types = filter(None, inherited_types)
local_fields = list(extract_fields(attrs))
cls._meta.graphql_type = GrapheneInterfaceType( cls._meta.graphql_type = GrapheneInterfaceType(
graphene_type=cls, graphene_type=cls,
name=cls._meta.name or cls.__name__, name=cls._meta.name or cls.__name__,
description=cls._meta.description or cls.__doc__, description=cls._meta.description or cls.__doc__,
fields=partial(cls._build_field_map, local_fields, inherited_types), fields=cls._fields(bases, attrs),
) )
return super(InterfaceTypeMeta, cls).construct(bases, attrs) return super(InterfaceTypeMeta, cls).construct(bases, attrs)

View File

@ -64,3 +64,17 @@ class TypeProxy(OrderedType):
raise Exception('TypedProxy "{}" cannot be mounted in {}'.format(self.get_type(), cls)) raise Exception('TypedProxy "{}" cannot be mounted in {}'.format(self.get_type(), cls))
return inner.contribute_to_class(cls, attname) return inner.contribute_to_class(cls, attname)
def as_mounted(self, cls):
from .inputobjecttype import InputObjectType
from .objecttype import ObjectType
from .interface import Interface
if issubclass(cls, (ObjectType, Interface)):
inner = self.as_field()
elif issubclass(cls, (InputObjectType)):
inner = self.as_inputfield()
else:
raise Exception('TypedProxy "{}" cannot be mounted in {}'.format(self.get_type(), cls))
return inner

View File

@ -1,20 +1,19 @@
import copy import copy
from .get_graphql_type import get_graphql_type from .get_graphql_type import get_graphql_type
from ..types.field import Field from ..types.field import Field, InputField
from ..types.proxy import TypeProxy from ..types.proxy import TypeProxy
def extract_fields(attrs): def extract_fields(cls, attrs):
fields = set() fields = set()
_fields = list() _fields = list()
for attname, value in list(attrs.items()): for attname, value in list(attrs.items()):
is_field = isinstance(value, Field) is_field = isinstance(value, (Field, InputField))
is_field_proxy = isinstance(value, TypeProxy) is_field_proxy = isinstance(value, TypeProxy)
if not (is_field or is_field_proxy): if not (is_field or is_field_proxy):
continue continue
field = value.as_mounted(cls) if is_field_proxy else copy.copy(value)
field = value.as_field() if is_field_proxy else copy.copy(value)
field.attname = attname field.attname = attname
fields.add(attname) fields.add(attname)
del attrs[attname] del attrs[attname]
@ -23,13 +22,12 @@ def extract_fields(attrs):
return sorted(_fields) return sorted(_fields)
def get_base_fields(bases): def get_base_fields(cls, bases):
fields = set() fields = set()
for _class in bases: for _class in bases:
for attname, field in get_graphql_type(_class).get_fields().items(): for attname, field in get_graphql_type(_class).get_fields().items():
if attname in fields: if attname in fields:
continue continue
field = copy.copy(field) field = copy.copy(field)
field.name = attname
fields.add(attname) fields.add(attname)
yield field yield field

View File

@ -2,7 +2,7 @@ from collections import OrderedDict
from graphql import GraphQLField, GraphQLString, GraphQLInterfaceType, GraphQLInt, GraphQLFloat from graphql import GraphQLField, GraphQLString, GraphQLInterfaceType, GraphQLInt, GraphQLFloat
from ..extract_fields import extract_fields, get_base_fields from ..extract_fields import extract_fields, get_base_fields
from ...types import Field, String, Argument from ...types import Field, String, Argument, ObjectType
def test_extract_fields_attrs(): def test_extract_fields_attrs():
@ -13,7 +13,7 @@ def test_extract_fields_attrs():
'argument': Argument(String), 'argument': Argument(String),
'graphql_field': GraphQLField(GraphQLString) 'graphql_field': GraphQLField(GraphQLString)
} }
extracted_fields = list(extract_fields(attrs)) extracted_fields = list(extract_fields(ObjectType, attrs))
assert [f.name for f in extracted_fields] == ['fieldString', 'string'] assert [f.name for f in extracted_fields] == ['fieldString', 'string']
assert sorted(attrs.keys()) == ['argument', 'graphql_field', 'other'] assert sorted(attrs.keys()) == ['argument', 'graphql_field', 'other']
@ -31,7 +31,7 @@ def test_extract_fields():
])) ]))
bases = (int_base, float_base) bases = (int_base, float_base)
base_fields = list(get_base_fields(bases)) base_fields = list(get_base_fields(ObjectType, bases))
assert [f.name for f in base_fields] == ['int', 'num', 'extra', 'float'] assert [f.name for f in base_fields] == ['int', 'num', 'extra', 'float']
assert [f.type for f in base_fields] == [ assert [f.type for f in base_fields] == [
GraphQLInt, GraphQLInt,