Improved ObjectType fields

This commit is contained in:
Syrus Akbary 2016-06-21 13:09:14 -07:00
parent 568718d573
commit c87d87d1ea
4 changed files with 33 additions and 28 deletions

View File

@ -1,9 +0,0 @@
from ..schema import Droid
def test_query_types():
graphql_type = Droid._meta.graphql_type
fields = graphql_type.get_fields()
assert fields['friends'].parent == Droid
assert fields

View File

@ -146,9 +146,13 @@ class ObjectType(six.with_metaclass(ObjectTypeMeta)):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
# GraphQL ObjectType acting as container # GraphQL ObjectType acting as container
args_len = len(args) args_len = len(args)
fields = self._meta.graphql_type.get_fields().values() _fields = self._meta.graphql_type._fields
for f in fields: if callable(_fields):
setattr(self, getattr(f, 'attname', f.name), None) _fields = _fields()
fields = _fields.items()
for name, f in fields:
setattr(self, getattr(f, 'attname', name), None)
if args_len > len(fields): if args_len > len(fields):
# Daft, but matches old exception sans the err msg. # Daft, but matches old exception sans the err msg.
@ -156,18 +160,18 @@ class ObjectType(six.with_metaclass(ObjectTypeMeta)):
fields_iter = iter(fields) fields_iter = iter(fields)
if not kwargs: if not kwargs:
for val, field in zip(args, fields_iter): for val, (name, field) in zip(args, fields_iter):
attname = getattr(field, 'attname', field.name) attname = getattr(field, 'attname', name)
setattr(self, attname, val) setattr(self, attname, val)
else: else:
for val, field in zip(args, fields_iter): for val, (name, field) in zip(args, fields_iter):
attname = getattr(field, 'attname', field.name) attname = getattr(field, 'attname', name)
setattr(self, attname, val) setattr(self, attname, val)
kwargs.pop(attname, None) kwargs.pop(attname, None)
for field in fields_iter: for name, field in fields_iter:
try: try:
attname = getattr(field, 'attname', field.name) attname = getattr(field, 'attname', name)
val = kwargs.pop(attname) val = kwargs.pop(attname)
setattr(self, attname, val) setattr(self, attname, val)
except KeyError: except KeyError:

View File

@ -16,10 +16,24 @@ def get_fields_from_attrs(in_type, attrs):
yield attname, field yield attname, field
def get_fields_from_types(bases): def get_fields_from_bases_and_types(bases, types):
fields = set() fields = set()
for _class in bases: for _class in bases:
for attname, field in get_graphql_type(_class).get_fields().items(): if not is_graphene_type(_class):
continue
_fields = get_graphql_type(_class)._fields
if callable(_fields):
_fields = _fields()
for default_attname, field in _fields.items():
attname = getattr(field, 'attname', default_attname)
if attname in fields:
continue
fields.add(attname)
yield attname, field
for grapqhl_type in types:
for attname, field in get_graphql_type(grapqhl_type).get_fields().items():
if attname in fields: if attname in fields:
continue continue
fields.add(attname) fields.add(attname)
@ -29,11 +43,7 @@ def get_fields_from_types(bases):
def get_fields(in_type, attrs, bases, graphql_types=()): def get_fields(in_type, attrs, bases, graphql_types=()):
fields = [] fields = []
graphene_bases = tuple( extended_fields = list(get_fields_from_bases_and_types(bases, graphql_types))
base._meta.graphql_type for base in bases if is_graphene_type(base)
) + graphql_types
extended_fields = list(get_fields_from_types(graphene_bases))
local_fields = list(get_fields_from_attrs(in_type, attrs)) local_fields = list(get_fields_from_attrs(in_type, attrs))
# We asume the extended fields are already sorted, so we only # We asume the extended fields are already sorted, so we only
# have to sort the local fields, that are get from attrs # have to sort the local fields, that are get from attrs

View File

@ -4,7 +4,7 @@ from graphql import (GraphQLField, GraphQLFloat, GraphQLInt,
GraphQLInterfaceType, GraphQLString) GraphQLInterfaceType, GraphQLString)
from ...types import Argument, Field, ObjectType, String from ...types import Argument, Field, ObjectType, String
from ..get_fields import get_fields_from_attrs, get_fields_from_types from ..get_fields import get_fields_from_attrs, get_fields_from_bases_and_types
def test_get_fields_from_attrs(): def test_get_fields_from_attrs():
@ -31,8 +31,8 @@ def test_get_fields_from_types():
('extra', GraphQLField(GraphQLFloat)) ('extra', GraphQLField(GraphQLFloat))
])) ]))
bases = (int_base, float_base) _types = (int_base, float_base)
base_fields = OrderedDict(get_fields_from_types(bases)) base_fields = OrderedDict(get_fields_from_bases_and_types((), _types))
assert [f for f in base_fields.keys()] == ['int', 'num', 'extra', 'float'] assert [f for f in base_fields.keys()] == ['int', 'num', 'extra', 'float']
assert [f.type for f in base_fields.values()] == [ assert [f.type for f in base_fields.values()] == [
GraphQLInt, GraphQLInt,