mirror of
https://github.com/graphql-python/graphene.git
synced 2024-11-23 01:56:54 +03:00
Improved interface fields getter
This commit is contained in:
parent
01190fb6ff
commit
9f655d9416
|
@ -12,17 +12,16 @@ class NodeMeta(InterfaceTypeMeta):
|
|||
def construct_graphql_type(cls, bases):
|
||||
pass
|
||||
|
||||
def construct(cls, *args, **kwargs):
|
||||
constructed = super(NodeMeta, cls).construct(*args, **kwargs)
|
||||
if not cls._meta.graphql_type:
|
||||
node_interface, node_field = node_definitions(
|
||||
cls.get_node,
|
||||
interface_class=partial(GrapheneInterfaceType, graphene_type=cls),
|
||||
field_class=Field
|
||||
)
|
||||
cls._meta.graphql_type = node_interface
|
||||
cls._Field = node_field
|
||||
return constructed
|
||||
def construct(cls, bases, attrs):
|
||||
cls.get_node = attrs.pop('get_node')
|
||||
node_interface, node_field = node_definitions(
|
||||
cls.get_node,
|
||||
interface_class=partial(GrapheneInterfaceType, graphene_type=cls),
|
||||
field_class=Field
|
||||
)
|
||||
cls._meta.graphql_type = node_interface
|
||||
cls._Field = node_field
|
||||
return super(NodeMeta, cls).construct(bases, attrs)
|
||||
|
||||
@property
|
||||
def Field(cls):
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
from itertools import chain
|
||||
from functools import partial
|
||||
from collections import OrderedDict
|
||||
import six
|
||||
|
||||
from graphql import GraphQLInterfaceType
|
||||
|
@ -20,18 +23,33 @@ class InterfaceTypeMeta(ClassTypeMeta):
|
|||
)
|
||||
|
||||
def construct_graphql_type(cls, bases):
|
||||
pass
|
||||
|
||||
def _build_field_map(cls, local_fields, bases):
|
||||
from ..utils.extract_fields import get_base_fields
|
||||
extended_fields = get_base_fields(bases)
|
||||
fields = chain(extended_fields, local_fields)
|
||||
return OrderedDict((f.name, f) for f in fields)
|
||||
|
||||
def construct(cls, bases, attrs):
|
||||
if not cls._meta.graphql_type and not cls._meta.abstract:
|
||||
from ..utils.is_graphene_type import is_graphene_type
|
||||
from ..utils.extract_fields import extract_fields
|
||||
|
||||
inherited_types = [
|
||||
base._meta.graphql_type for base in bases if is_graphene_type(base)
|
||||
]
|
||||
inherited_types = filter(None, inherited_types)
|
||||
|
||||
local_fields = list(extract_fields(attrs))
|
||||
|
||||
cls._meta.graphql_type = GrapheneInterfaceType(
|
||||
graphene_type=cls,
|
||||
name=cls._meta.name or cls.__name__,
|
||||
description=cls._meta.description or cls.__doc__,
|
||||
fields=FieldMap(cls, bases=filter(None, inherited_types)),
|
||||
fields=partial(cls._build_field_map, local_fields, inherited_types),
|
||||
)
|
||||
return super(InterfaceTypeMeta, cls).construct(bases, attrs)
|
||||
|
||||
|
||||
class Interface(six.with_metaclass(InterfaceTypeMeta)):
|
||||
|
|
35
graphene/utils/extract_fields.py
Normal file
35
graphene/utils/extract_fields.py
Normal file
|
@ -0,0 +1,35 @@
|
|||
import copy
|
||||
from .get_graphql_type import get_graphql_type
|
||||
|
||||
from ..types.field import Field
|
||||
from ..types.proxy import TypeProxy
|
||||
|
||||
|
||||
def extract_fields(attrs):
|
||||
fields = set()
|
||||
_fields = list()
|
||||
for attname, value in list(attrs.items()):
|
||||
is_field = isinstance(value, Field)
|
||||
is_field_proxy = isinstance(value, TypeProxy)
|
||||
if not (is_field or is_field_proxy):
|
||||
continue
|
||||
|
||||
field = value.as_field() if is_field_proxy else copy.copy(value)
|
||||
field.attname = attname
|
||||
fields.add(attname)
|
||||
del attrs[attname]
|
||||
_fields.append(field)
|
||||
|
||||
return sorted(_fields)
|
||||
|
||||
|
||||
def get_base_fields(bases):
|
||||
fields = set()
|
||||
for _class in bases:
|
||||
for attname, field in get_graphql_type(_class).get_fields().items():
|
||||
if attname in fields:
|
||||
continue
|
||||
field = copy.copy(field)
|
||||
field.name = attname
|
||||
fields.add(attname)
|
||||
yield field
|
41
graphene/utils/tests/test_extract_fields.py
Normal file
41
graphene/utils/tests/test_extract_fields.py
Normal file
|
@ -0,0 +1,41 @@
|
|||
from collections import OrderedDict
|
||||
from graphql import GraphQLField, GraphQLString, GraphQLInterfaceType, GraphQLInt, GraphQLFloat
|
||||
from ..extract_fields import extract_fields, get_base_fields
|
||||
|
||||
from ...types import Field, String, Argument
|
||||
|
||||
|
||||
def test_extract_fields_attrs():
|
||||
attrs = {
|
||||
'field_string': Field(String),
|
||||
'string': String(),
|
||||
'other': None,
|
||||
'argument': Argument(String),
|
||||
'graphql_field': GraphQLField(GraphQLString)
|
||||
}
|
||||
extracted_fields = list(extract_fields(attrs))
|
||||
assert [f.name for f in extracted_fields] == ['fieldString', 'string']
|
||||
assert sorted(attrs.keys()) == ['argument', 'graphql_field', 'other']
|
||||
|
||||
|
||||
def test_extract_fields():
|
||||
int_base = GraphQLInterfaceType('IntInterface', fields=OrderedDict([
|
||||
('int', GraphQLField(GraphQLInt)),
|
||||
('num', GraphQLField(GraphQLInt)),
|
||||
('extra', GraphQLField(GraphQLInt))
|
||||
]))
|
||||
float_base = GraphQLInterfaceType('IntInterface', fields=OrderedDict([
|
||||
('float', GraphQLField(GraphQLFloat)),
|
||||
('num', GraphQLField(GraphQLFloat)),
|
||||
('extra', GraphQLField(GraphQLFloat))
|
||||
]))
|
||||
|
||||
bases = (int_base, float_base)
|
||||
base_fields = list(get_base_fields(bases))
|
||||
assert [f.name for f in base_fields] == ['int', 'num', 'extra', 'float']
|
||||
assert [f.type for f in base_fields] == [
|
||||
GraphQLInt,
|
||||
GraphQLInt,
|
||||
GraphQLInt,
|
||||
GraphQLFloat,
|
||||
]
|
Loading…
Reference in New Issue
Block a user