mirror of
https://github.com/graphql-python/graphene.git
synced 2024-11-30 05:23:57 +03:00
Improved interface fields getter
This commit is contained in:
parent
01190fb6ff
commit
9f655d9416
|
@ -12,17 +12,16 @@ class NodeMeta(InterfaceTypeMeta):
|
||||||
def construct_graphql_type(cls, bases):
|
def construct_graphql_type(cls, bases):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def construct(cls, *args, **kwargs):
|
def construct(cls, bases, attrs):
|
||||||
constructed = super(NodeMeta, cls).construct(*args, **kwargs)
|
cls.get_node = attrs.pop('get_node')
|
||||||
if not cls._meta.graphql_type:
|
node_interface, node_field = node_definitions(
|
||||||
node_interface, node_field = node_definitions(
|
cls.get_node,
|
||||||
cls.get_node,
|
interface_class=partial(GrapheneInterfaceType, graphene_type=cls),
|
||||||
interface_class=partial(GrapheneInterfaceType, graphene_type=cls),
|
field_class=Field
|
||||||
field_class=Field
|
)
|
||||||
)
|
cls._meta.graphql_type = node_interface
|
||||||
cls._meta.graphql_type = node_interface
|
cls._Field = node_field
|
||||||
cls._Field = node_field
|
return super(NodeMeta, cls).construct(bases, attrs)
|
||||||
return constructed
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def Field(cls):
|
def Field(cls):
|
||||||
|
|
|
@ -1,3 +1,6 @@
|
||||||
|
from itertools import chain
|
||||||
|
from functools import partial
|
||||||
|
from collections import OrderedDict
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from graphql import GraphQLInterfaceType
|
from graphql import GraphQLInterfaceType
|
||||||
|
@ -20,18 +23,33 @@ class InterfaceTypeMeta(ClassTypeMeta):
|
||||||
)
|
)
|
||||||
|
|
||||||
def construct_graphql_type(cls, bases):
|
def construct_graphql_type(cls, bases):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _build_field_map(cls, local_fields, bases):
|
||||||
|
from ..utils.extract_fields import get_base_fields
|
||||||
|
extended_fields = get_base_fields(bases)
|
||||||
|
fields = chain(extended_fields, local_fields)
|
||||||
|
return OrderedDict((f.name, f) for f in fields)
|
||||||
|
|
||||||
|
def construct(cls, bases, attrs):
|
||||||
if not cls._meta.graphql_type and not cls._meta.abstract:
|
if not cls._meta.graphql_type and not cls._meta.abstract:
|
||||||
from ..utils.is_graphene_type import is_graphene_type
|
from ..utils.is_graphene_type import is_graphene_type
|
||||||
|
from ..utils.extract_fields import extract_fields
|
||||||
|
|
||||||
inherited_types = [
|
inherited_types = [
|
||||||
base._meta.graphql_type for base in bases if is_graphene_type(base)
|
base._meta.graphql_type for base in bases if is_graphene_type(base)
|
||||||
]
|
]
|
||||||
|
inherited_types = filter(None, inherited_types)
|
||||||
|
|
||||||
|
local_fields = list(extract_fields(attrs))
|
||||||
|
|
||||||
cls._meta.graphql_type = GrapheneInterfaceType(
|
cls._meta.graphql_type = GrapheneInterfaceType(
|
||||||
graphene_type=cls,
|
graphene_type=cls,
|
||||||
name=cls._meta.name or cls.__name__,
|
name=cls._meta.name or cls.__name__,
|
||||||
description=cls._meta.description or cls.__doc__,
|
description=cls._meta.description or cls.__doc__,
|
||||||
fields=FieldMap(cls, bases=filter(None, inherited_types)),
|
fields=partial(cls._build_field_map, local_fields, inherited_types),
|
||||||
)
|
)
|
||||||
|
return super(InterfaceTypeMeta, cls).construct(bases, attrs)
|
||||||
|
|
||||||
|
|
||||||
class Interface(six.with_metaclass(InterfaceTypeMeta)):
|
class Interface(six.with_metaclass(InterfaceTypeMeta)):
|
||||||
|
|
35
graphene/utils/extract_fields.py
Normal file
35
graphene/utils/extract_fields.py
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
import copy
|
||||||
|
from .get_graphql_type import get_graphql_type
|
||||||
|
|
||||||
|
from ..types.field import Field
|
||||||
|
from ..types.proxy import TypeProxy
|
||||||
|
|
||||||
|
|
||||||
|
def extract_fields(attrs):
|
||||||
|
fields = set()
|
||||||
|
_fields = list()
|
||||||
|
for attname, value in list(attrs.items()):
|
||||||
|
is_field = isinstance(value, Field)
|
||||||
|
is_field_proxy = isinstance(value, TypeProxy)
|
||||||
|
if not (is_field or is_field_proxy):
|
||||||
|
continue
|
||||||
|
|
||||||
|
field = value.as_field() if is_field_proxy else copy.copy(value)
|
||||||
|
field.attname = attname
|
||||||
|
fields.add(attname)
|
||||||
|
del attrs[attname]
|
||||||
|
_fields.append(field)
|
||||||
|
|
||||||
|
return sorted(_fields)
|
||||||
|
|
||||||
|
|
||||||
|
def get_base_fields(bases):
|
||||||
|
fields = set()
|
||||||
|
for _class in bases:
|
||||||
|
for attname, field in get_graphql_type(_class).get_fields().items():
|
||||||
|
if attname in fields:
|
||||||
|
continue
|
||||||
|
field = copy.copy(field)
|
||||||
|
field.name = attname
|
||||||
|
fields.add(attname)
|
||||||
|
yield field
|
41
graphene/utils/tests/test_extract_fields.py
Normal file
41
graphene/utils/tests/test_extract_fields.py
Normal file
|
@ -0,0 +1,41 @@
|
||||||
|
from collections import OrderedDict
|
||||||
|
from graphql import GraphQLField, GraphQLString, GraphQLInterfaceType, GraphQLInt, GraphQLFloat
|
||||||
|
from ..extract_fields import extract_fields, get_base_fields
|
||||||
|
|
||||||
|
from ...types import Field, String, Argument
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_fields_attrs():
|
||||||
|
attrs = {
|
||||||
|
'field_string': Field(String),
|
||||||
|
'string': String(),
|
||||||
|
'other': None,
|
||||||
|
'argument': Argument(String),
|
||||||
|
'graphql_field': GraphQLField(GraphQLString)
|
||||||
|
}
|
||||||
|
extracted_fields = list(extract_fields(attrs))
|
||||||
|
assert [f.name for f in extracted_fields] == ['fieldString', 'string']
|
||||||
|
assert sorted(attrs.keys()) == ['argument', 'graphql_field', 'other']
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_fields():
|
||||||
|
int_base = GraphQLInterfaceType('IntInterface', fields=OrderedDict([
|
||||||
|
('int', GraphQLField(GraphQLInt)),
|
||||||
|
('num', GraphQLField(GraphQLInt)),
|
||||||
|
('extra', GraphQLField(GraphQLInt))
|
||||||
|
]))
|
||||||
|
float_base = GraphQLInterfaceType('IntInterface', fields=OrderedDict([
|
||||||
|
('float', GraphQLField(GraphQLFloat)),
|
||||||
|
('num', GraphQLField(GraphQLFloat)),
|
||||||
|
('extra', GraphQLField(GraphQLFloat))
|
||||||
|
]))
|
||||||
|
|
||||||
|
bases = (int_base, float_base)
|
||||||
|
base_fields = list(get_base_fields(bases))
|
||||||
|
assert [f.name for f in base_fields] == ['int', 'num', 'extra', 'float']
|
||||||
|
assert [f.type for f in base_fields] == [
|
||||||
|
GraphQLInt,
|
||||||
|
GraphQLInt,
|
||||||
|
GraphQLInt,
|
||||||
|
GraphQLFloat,
|
||||||
|
]
|
Loading…
Reference in New Issue
Block a user