mirror of
https://github.com/graphql-python/graphene.git
synced 2024-11-11 12:16:58 +03:00
Improved fields mounting
This commit is contained in:
parent
9f655d9416
commit
25e967200b
|
@ -1,6 +1,8 @@
|
|||
from collections import OrderedDict
|
||||
import inspect
|
||||
import copy
|
||||
from itertools import chain
|
||||
from functools import partial
|
||||
|
||||
from graphql.utils.assert_valid_name import assert_valid_name
|
||||
from graphql.type.definition import GraphQLObjectType
|
||||
|
@ -58,6 +60,26 @@ class ClassTypeMeta(type):
|
|||
return cls
|
||||
|
||||
|
||||
class FieldsMeta(type):
|
||||
|
||||
def _build_field_map(cls, bases, local_fields):
|
||||
from ..utils.extract_fields import get_base_fields
|
||||
extended_fields = get_base_fields(cls, bases)
|
||||
fields = chain(extended_fields, local_fields)
|
||||
return OrderedDict((f.name, f) for f in fields)
|
||||
|
||||
def _fields(cls, bases, attrs):
|
||||
from ..utils.is_graphene_type import is_graphene_type
|
||||
from ..utils.extract_fields import extract_fields
|
||||
|
||||
inherited_types = [
|
||||
base._meta.graphql_type for base in bases if is_graphene_type(base) and not base._meta.abstract
|
||||
]
|
||||
|
||||
local_fields = extract_fields(cls, attrs)
|
||||
return partial(cls._build_field_map, inherited_types, local_fields)
|
||||
|
||||
|
||||
class GrapheneGraphQLType(object):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.graphene_type = kwargs.pop('graphene_type')
|
||||
|
|
|
@ -2,7 +2,7 @@ import six
|
|||
|
||||
from graphql import GraphQLInputObjectType
|
||||
|
||||
from .definitions import ClassTypeMeta, GrapheneFieldsType, FieldMap
|
||||
from .definitions import FieldsMeta, ClassTypeMeta, GrapheneFieldsType
|
||||
from .proxy import TypeProxy
|
||||
|
||||
|
||||
|
@ -10,7 +10,7 @@ class GrapheneInputObjectType(GrapheneFieldsType, GraphQLInputObjectType):
|
|||
pass
|
||||
|
||||
|
||||
class InputObjectTypeMeta(ClassTypeMeta):
|
||||
class InputObjectTypeMeta(FieldsMeta, ClassTypeMeta):
|
||||
|
||||
def get_options(cls, meta):
|
||||
return cls.options_class(
|
||||
|
@ -22,20 +22,17 @@ class InputObjectTypeMeta(ClassTypeMeta):
|
|||
)
|
||||
|
||||
def construct_graphql_type(cls, bases):
|
||||
pass
|
||||
|
||||
def construct(cls, bases, attrs):
|
||||
if not cls._meta.graphql_type and not cls._meta.abstract:
|
||||
from ..utils.get_graphql_type import get_graphql_type
|
||||
from ..utils.is_graphene_type import is_graphene_type
|
||||
|
||||
inherited_types = [
|
||||
base._meta.graphql_type for base in bases if is_graphene_type(base)
|
||||
]
|
||||
|
||||
cls._meta.graphql_type = GrapheneInputObjectType(
|
||||
graphene_type=cls,
|
||||
name=cls._meta.name or cls.__name__,
|
||||
description=cls._meta.description or cls.__doc__,
|
||||
fields=FieldMap(cls, bases=filter(None, inherited_types)),
|
||||
fields=cls._fields(bases, attrs),
|
||||
)
|
||||
return super(InputObjectTypeMeta, cls).construct(bases, attrs)
|
||||
|
||||
|
||||
class InputObjectType(six.with_metaclass(InputObjectTypeMeta, TypeProxy)):
|
||||
|
|
|
@ -1,17 +1,14 @@
|
|||
from itertools import chain
|
||||
from functools import partial
|
||||
from collections import OrderedDict
|
||||
import six
|
||||
|
||||
from graphql import GraphQLInterfaceType
|
||||
from .definitions import ClassTypeMeta, GrapheneFieldsType, FieldMap
|
||||
from .definitions import FieldsMeta, ClassTypeMeta, GrapheneFieldsType
|
||||
|
||||
|
||||
class GrapheneInterfaceType(GrapheneFieldsType, GraphQLInterfaceType):
|
||||
pass
|
||||
|
||||
|
||||
class InterfaceTypeMeta(ClassTypeMeta):
|
||||
class InterfaceTypeMeta(FieldsMeta, ClassTypeMeta):
|
||||
|
||||
def get_options(cls, meta):
|
||||
return cls.options_class(
|
||||
|
@ -25,29 +22,14 @@ class InterfaceTypeMeta(ClassTypeMeta):
|
|||
def construct_graphql_type(cls, bases):
|
||||
pass
|
||||
|
||||
def _build_field_map(cls, local_fields, bases):
|
||||
from ..utils.extract_fields import get_base_fields
|
||||
extended_fields = get_base_fields(bases)
|
||||
fields = chain(extended_fields, local_fields)
|
||||
return OrderedDict((f.name, f) for f in fields)
|
||||
|
||||
def construct(cls, bases, attrs):
|
||||
if not cls._meta.graphql_type and not cls._meta.abstract:
|
||||
from ..utils.is_graphene_type import is_graphene_type
|
||||
from ..utils.extract_fields import extract_fields
|
||||
|
||||
inherited_types = [
|
||||
base._meta.graphql_type for base in bases if is_graphene_type(base)
|
||||
]
|
||||
inherited_types = filter(None, inherited_types)
|
||||
|
||||
local_fields = list(extract_fields(attrs))
|
||||
|
||||
cls._meta.graphql_type = GrapheneInterfaceType(
|
||||
graphene_type=cls,
|
||||
name=cls._meta.name or cls.__name__,
|
||||
description=cls._meta.description or cls.__doc__,
|
||||
fields=partial(cls._build_field_map, local_fields, inherited_types),
|
||||
fields=cls._fields(bases, attrs),
|
||||
)
|
||||
return super(InterfaceTypeMeta, cls).construct(bases, attrs)
|
||||
|
||||
|
|
|
@ -64,3 +64,17 @@ class TypeProxy(OrderedType):
|
|||
raise Exception('TypedProxy "{}" cannot be mounted in {}'.format(self.get_type(), cls))
|
||||
|
||||
return inner.contribute_to_class(cls, attname)
|
||||
|
||||
def as_mounted(self, cls):
|
||||
from .inputobjecttype import InputObjectType
|
||||
from .objecttype import ObjectType
|
||||
from .interface import Interface
|
||||
|
||||
if issubclass(cls, (ObjectType, Interface)):
|
||||
inner = self.as_field()
|
||||
elif issubclass(cls, (InputObjectType)):
|
||||
inner = self.as_inputfield()
|
||||
else:
|
||||
raise Exception('TypedProxy "{}" cannot be mounted in {}'.format(self.get_type(), cls))
|
||||
|
||||
return inner
|
||||
|
|
|
@ -1,20 +1,19 @@
|
|||
import copy
|
||||
from .get_graphql_type import get_graphql_type
|
||||
|
||||
from ..types.field import Field
|
||||
from ..types.field import Field, InputField
|
||||
from ..types.proxy import TypeProxy
|
||||
|
||||
|
||||
def extract_fields(attrs):
|
||||
def extract_fields(cls, attrs):
|
||||
fields = set()
|
||||
_fields = list()
|
||||
for attname, value in list(attrs.items()):
|
||||
is_field = isinstance(value, Field)
|
||||
is_field = isinstance(value, (Field, InputField))
|
||||
is_field_proxy = isinstance(value, TypeProxy)
|
||||
if not (is_field or is_field_proxy):
|
||||
continue
|
||||
|
||||
field = value.as_field() if is_field_proxy else copy.copy(value)
|
||||
field = value.as_mounted(cls) if is_field_proxy else copy.copy(value)
|
||||
field.attname = attname
|
||||
fields.add(attname)
|
||||
del attrs[attname]
|
||||
|
@ -23,13 +22,12 @@ def extract_fields(attrs):
|
|||
return sorted(_fields)
|
||||
|
||||
|
||||
def get_base_fields(bases):
|
||||
def get_base_fields(cls, bases):
|
||||
fields = set()
|
||||
for _class in bases:
|
||||
for attname, field in get_graphql_type(_class).get_fields().items():
|
||||
if attname in fields:
|
||||
continue
|
||||
field = copy.copy(field)
|
||||
field.name = attname
|
||||
fields.add(attname)
|
||||
yield field
|
||||
|
|
|
@ -2,7 +2,7 @@ from collections import OrderedDict
|
|||
from graphql import GraphQLField, GraphQLString, GraphQLInterfaceType, GraphQLInt, GraphQLFloat
|
||||
from ..extract_fields import extract_fields, get_base_fields
|
||||
|
||||
from ...types import Field, String, Argument
|
||||
from ...types import Field, String, Argument, ObjectType
|
||||
|
||||
|
||||
def test_extract_fields_attrs():
|
||||
|
@ -13,7 +13,7 @@ def test_extract_fields_attrs():
|
|||
'argument': Argument(String),
|
||||
'graphql_field': GraphQLField(GraphQLString)
|
||||
}
|
||||
extracted_fields = list(extract_fields(attrs))
|
||||
extracted_fields = list(extract_fields(ObjectType, attrs))
|
||||
assert [f.name for f in extracted_fields] == ['fieldString', 'string']
|
||||
assert sorted(attrs.keys()) == ['argument', 'graphql_field', 'other']
|
||||
|
||||
|
@ -31,7 +31,7 @@ def test_extract_fields():
|
|||
]))
|
||||
|
||||
bases = (int_base, float_base)
|
||||
base_fields = list(get_base_fields(bases))
|
||||
base_fields = list(get_base_fields(ObjectType, bases))
|
||||
assert [f.name for f in base_fields] == ['int', 'num', 'extra', 'float']
|
||||
assert [f.type for f in base_fields] == [
|
||||
GraphQLInt,
|
||||
|
|
Loading…
Reference in New Issue
Block a user