mirror of
https://github.com/graphql-python/graphene.git
synced 2025-02-08 23:50:38 +03:00
Improved fields mounting
This commit is contained in:
parent
9f655d9416
commit
25e967200b
|
@ -1,6 +1,8 @@
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import inspect
|
import inspect
|
||||||
import copy
|
import copy
|
||||||
|
from itertools import chain
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
from graphql.utils.assert_valid_name import assert_valid_name
|
from graphql.utils.assert_valid_name import assert_valid_name
|
||||||
from graphql.type.definition import GraphQLObjectType
|
from graphql.type.definition import GraphQLObjectType
|
||||||
|
@ -58,6 +60,26 @@ class ClassTypeMeta(type):
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
|
|
||||||
|
class FieldsMeta(type):
|
||||||
|
|
||||||
|
def _build_field_map(cls, bases, local_fields):
|
||||||
|
from ..utils.extract_fields import get_base_fields
|
||||||
|
extended_fields = get_base_fields(cls, bases)
|
||||||
|
fields = chain(extended_fields, local_fields)
|
||||||
|
return OrderedDict((f.name, f) for f in fields)
|
||||||
|
|
||||||
|
def _fields(cls, bases, attrs):
|
||||||
|
from ..utils.is_graphene_type import is_graphene_type
|
||||||
|
from ..utils.extract_fields import extract_fields
|
||||||
|
|
||||||
|
inherited_types = [
|
||||||
|
base._meta.graphql_type for base in bases if is_graphene_type(base) and not base._meta.abstract
|
||||||
|
]
|
||||||
|
|
||||||
|
local_fields = extract_fields(cls, attrs)
|
||||||
|
return partial(cls._build_field_map, inherited_types, local_fields)
|
||||||
|
|
||||||
|
|
||||||
class GrapheneGraphQLType(object):
|
class GrapheneGraphQLType(object):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
self.graphene_type = kwargs.pop('graphene_type')
|
self.graphene_type = kwargs.pop('graphene_type')
|
||||||
|
|
|
@ -2,7 +2,7 @@ import six
|
||||||
|
|
||||||
from graphql import GraphQLInputObjectType
|
from graphql import GraphQLInputObjectType
|
||||||
|
|
||||||
from .definitions import ClassTypeMeta, GrapheneFieldsType, FieldMap
|
from .definitions import FieldsMeta, ClassTypeMeta, GrapheneFieldsType
|
||||||
from .proxy import TypeProxy
|
from .proxy import TypeProxy
|
||||||
|
|
||||||
|
|
||||||
|
@ -10,7 +10,7 @@ class GrapheneInputObjectType(GrapheneFieldsType, GraphQLInputObjectType):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class InputObjectTypeMeta(ClassTypeMeta):
|
class InputObjectTypeMeta(FieldsMeta, ClassTypeMeta):
|
||||||
|
|
||||||
def get_options(cls, meta):
|
def get_options(cls, meta):
|
||||||
return cls.options_class(
|
return cls.options_class(
|
||||||
|
@ -22,20 +22,17 @@ class InputObjectTypeMeta(ClassTypeMeta):
|
||||||
)
|
)
|
||||||
|
|
||||||
def construct_graphql_type(cls, bases):
|
def construct_graphql_type(cls, bases):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def construct(cls, bases, attrs):
|
||||||
if not cls._meta.graphql_type and not cls._meta.abstract:
|
if not cls._meta.graphql_type and not cls._meta.abstract:
|
||||||
from ..utils.get_graphql_type import get_graphql_type
|
|
||||||
from ..utils.is_graphene_type import is_graphene_type
|
|
||||||
|
|
||||||
inherited_types = [
|
|
||||||
base._meta.graphql_type for base in bases if is_graphene_type(base)
|
|
||||||
]
|
|
||||||
|
|
||||||
cls._meta.graphql_type = GrapheneInputObjectType(
|
cls._meta.graphql_type = GrapheneInputObjectType(
|
||||||
graphene_type=cls,
|
graphene_type=cls,
|
||||||
name=cls._meta.name or cls.__name__,
|
name=cls._meta.name or cls.__name__,
|
||||||
description=cls._meta.description or cls.__doc__,
|
description=cls._meta.description or cls.__doc__,
|
||||||
fields=FieldMap(cls, bases=filter(None, inherited_types)),
|
fields=cls._fields(bases, attrs),
|
||||||
)
|
)
|
||||||
|
return super(InputObjectTypeMeta, cls).construct(bases, attrs)
|
||||||
|
|
||||||
|
|
||||||
class InputObjectType(six.with_metaclass(InputObjectTypeMeta, TypeProxy)):
|
class InputObjectType(six.with_metaclass(InputObjectTypeMeta, TypeProxy)):
|
||||||
|
|
|
@ -1,17 +1,14 @@
|
||||||
from itertools import chain
|
|
||||||
from functools import partial
|
|
||||||
from collections import OrderedDict
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from graphql import GraphQLInterfaceType
|
from graphql import GraphQLInterfaceType
|
||||||
from .definitions import ClassTypeMeta, GrapheneFieldsType, FieldMap
|
from .definitions import FieldsMeta, ClassTypeMeta, GrapheneFieldsType
|
||||||
|
|
||||||
|
|
||||||
class GrapheneInterfaceType(GrapheneFieldsType, GraphQLInterfaceType):
|
class GrapheneInterfaceType(GrapheneFieldsType, GraphQLInterfaceType):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class InterfaceTypeMeta(ClassTypeMeta):
|
class InterfaceTypeMeta(FieldsMeta, ClassTypeMeta):
|
||||||
|
|
||||||
def get_options(cls, meta):
|
def get_options(cls, meta):
|
||||||
return cls.options_class(
|
return cls.options_class(
|
||||||
|
@ -25,29 +22,14 @@ class InterfaceTypeMeta(ClassTypeMeta):
|
||||||
def construct_graphql_type(cls, bases):
|
def construct_graphql_type(cls, bases):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _build_field_map(cls, local_fields, bases):
|
|
||||||
from ..utils.extract_fields import get_base_fields
|
|
||||||
extended_fields = get_base_fields(bases)
|
|
||||||
fields = chain(extended_fields, local_fields)
|
|
||||||
return OrderedDict((f.name, f) for f in fields)
|
|
||||||
|
|
||||||
def construct(cls, bases, attrs):
|
def construct(cls, bases, attrs):
|
||||||
if not cls._meta.graphql_type and not cls._meta.abstract:
|
if not cls._meta.graphql_type and not cls._meta.abstract:
|
||||||
from ..utils.is_graphene_type import is_graphene_type
|
|
||||||
from ..utils.extract_fields import extract_fields
|
|
||||||
|
|
||||||
inherited_types = [
|
|
||||||
base._meta.graphql_type for base in bases if is_graphene_type(base)
|
|
||||||
]
|
|
||||||
inherited_types = filter(None, inherited_types)
|
|
||||||
|
|
||||||
local_fields = list(extract_fields(attrs))
|
|
||||||
|
|
||||||
cls._meta.graphql_type = GrapheneInterfaceType(
|
cls._meta.graphql_type = GrapheneInterfaceType(
|
||||||
graphene_type=cls,
|
graphene_type=cls,
|
||||||
name=cls._meta.name or cls.__name__,
|
name=cls._meta.name or cls.__name__,
|
||||||
description=cls._meta.description or cls.__doc__,
|
description=cls._meta.description or cls.__doc__,
|
||||||
fields=partial(cls._build_field_map, local_fields, inherited_types),
|
fields=cls._fields(bases, attrs),
|
||||||
)
|
)
|
||||||
return super(InterfaceTypeMeta, cls).construct(bases, attrs)
|
return super(InterfaceTypeMeta, cls).construct(bases, attrs)
|
||||||
|
|
||||||
|
|
|
@ -64,3 +64,17 @@ class TypeProxy(OrderedType):
|
||||||
raise Exception('TypedProxy "{}" cannot be mounted in {}'.format(self.get_type(), cls))
|
raise Exception('TypedProxy "{}" cannot be mounted in {}'.format(self.get_type(), cls))
|
||||||
|
|
||||||
return inner.contribute_to_class(cls, attname)
|
return inner.contribute_to_class(cls, attname)
|
||||||
|
|
||||||
|
def as_mounted(self, cls):
|
||||||
|
from .inputobjecttype import InputObjectType
|
||||||
|
from .objecttype import ObjectType
|
||||||
|
from .interface import Interface
|
||||||
|
|
||||||
|
if issubclass(cls, (ObjectType, Interface)):
|
||||||
|
inner = self.as_field()
|
||||||
|
elif issubclass(cls, (InputObjectType)):
|
||||||
|
inner = self.as_inputfield()
|
||||||
|
else:
|
||||||
|
raise Exception('TypedProxy "{}" cannot be mounted in {}'.format(self.get_type(), cls))
|
||||||
|
|
||||||
|
return inner
|
||||||
|
|
|
@ -1,20 +1,19 @@
|
||||||
import copy
|
import copy
|
||||||
from .get_graphql_type import get_graphql_type
|
from .get_graphql_type import get_graphql_type
|
||||||
|
|
||||||
from ..types.field import Field
|
from ..types.field import Field, InputField
|
||||||
from ..types.proxy import TypeProxy
|
from ..types.proxy import TypeProxy
|
||||||
|
|
||||||
|
|
||||||
def extract_fields(attrs):
|
def extract_fields(cls, attrs):
|
||||||
fields = set()
|
fields = set()
|
||||||
_fields = list()
|
_fields = list()
|
||||||
for attname, value in list(attrs.items()):
|
for attname, value in list(attrs.items()):
|
||||||
is_field = isinstance(value, Field)
|
is_field = isinstance(value, (Field, InputField))
|
||||||
is_field_proxy = isinstance(value, TypeProxy)
|
is_field_proxy = isinstance(value, TypeProxy)
|
||||||
if not (is_field or is_field_proxy):
|
if not (is_field or is_field_proxy):
|
||||||
continue
|
continue
|
||||||
|
field = value.as_mounted(cls) if is_field_proxy else copy.copy(value)
|
||||||
field = value.as_field() if is_field_proxy else copy.copy(value)
|
|
||||||
field.attname = attname
|
field.attname = attname
|
||||||
fields.add(attname)
|
fields.add(attname)
|
||||||
del attrs[attname]
|
del attrs[attname]
|
||||||
|
@ -23,13 +22,12 @@ def extract_fields(attrs):
|
||||||
return sorted(_fields)
|
return sorted(_fields)
|
||||||
|
|
||||||
|
|
||||||
def get_base_fields(bases):
|
def get_base_fields(cls, bases):
|
||||||
fields = set()
|
fields = set()
|
||||||
for _class in bases:
|
for _class in bases:
|
||||||
for attname, field in get_graphql_type(_class).get_fields().items():
|
for attname, field in get_graphql_type(_class).get_fields().items():
|
||||||
if attname in fields:
|
if attname in fields:
|
||||||
continue
|
continue
|
||||||
field = copy.copy(field)
|
field = copy.copy(field)
|
||||||
field.name = attname
|
|
||||||
fields.add(attname)
|
fields.add(attname)
|
||||||
yield field
|
yield field
|
||||||
|
|
|
@ -2,7 +2,7 @@ from collections import OrderedDict
|
||||||
from graphql import GraphQLField, GraphQLString, GraphQLInterfaceType, GraphQLInt, GraphQLFloat
|
from graphql import GraphQLField, GraphQLString, GraphQLInterfaceType, GraphQLInt, GraphQLFloat
|
||||||
from ..extract_fields import extract_fields, get_base_fields
|
from ..extract_fields import extract_fields, get_base_fields
|
||||||
|
|
||||||
from ...types import Field, String, Argument
|
from ...types import Field, String, Argument, ObjectType
|
||||||
|
|
||||||
|
|
||||||
def test_extract_fields_attrs():
|
def test_extract_fields_attrs():
|
||||||
|
@ -13,7 +13,7 @@ def test_extract_fields_attrs():
|
||||||
'argument': Argument(String),
|
'argument': Argument(String),
|
||||||
'graphql_field': GraphQLField(GraphQLString)
|
'graphql_field': GraphQLField(GraphQLString)
|
||||||
}
|
}
|
||||||
extracted_fields = list(extract_fields(attrs))
|
extracted_fields = list(extract_fields(ObjectType, attrs))
|
||||||
assert [f.name for f in extracted_fields] == ['fieldString', 'string']
|
assert [f.name for f in extracted_fields] == ['fieldString', 'string']
|
||||||
assert sorted(attrs.keys()) == ['argument', 'graphql_field', 'other']
|
assert sorted(attrs.keys()) == ['argument', 'graphql_field', 'other']
|
||||||
|
|
||||||
|
@ -31,7 +31,7 @@ def test_extract_fields():
|
||||||
]))
|
]))
|
||||||
|
|
||||||
bases = (int_base, float_base)
|
bases = (int_base, float_base)
|
||||||
base_fields = list(get_base_fields(bases))
|
base_fields = list(get_base_fields(ObjectType, bases))
|
||||||
assert [f.name for f in base_fields] == ['int', 'num', 'extra', 'float']
|
assert [f.name for f in base_fields] == ['int', 'num', 'extra', 'float']
|
||||||
assert [f.type for f in base_fields] == [
|
assert [f.type for f in base_fields] == [
|
||||||
GraphQLInt,
|
GraphQLInt,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user