mirror of
https://github.com/graphql-python/graphene.git
synced 2024-11-22 17:46:57 +03:00
Improved ObjectType fields
This commit is contained in:
parent
568718d573
commit
c87d87d1ea
|
@ -1,9 +0,0 @@
|
|||
|
||||
from ..schema import Droid
|
||||
|
||||
|
||||
def test_query_types():
|
||||
graphql_type = Droid._meta.graphql_type
|
||||
fields = graphql_type.get_fields()
|
||||
assert fields['friends'].parent == Droid
|
||||
assert fields
|
|
@ -146,9 +146,13 @@ class ObjectType(six.with_metaclass(ObjectTypeMeta)):
|
|||
def __init__(self, *args, **kwargs):
|
||||
# GraphQL ObjectType acting as container
|
||||
args_len = len(args)
|
||||
fields = self._meta.graphql_type.get_fields().values()
|
||||
for f in fields:
|
||||
setattr(self, getattr(f, 'attname', f.name), None)
|
||||
_fields = self._meta.graphql_type._fields
|
||||
if callable(_fields):
|
||||
_fields = _fields()
|
||||
|
||||
fields = _fields.items()
|
||||
for name, f in fields:
|
||||
setattr(self, getattr(f, 'attname', name), None)
|
||||
|
||||
if args_len > len(fields):
|
||||
# Daft, but matches old exception sans the err msg.
|
||||
|
@ -156,18 +160,18 @@ class ObjectType(six.with_metaclass(ObjectTypeMeta)):
|
|||
fields_iter = iter(fields)
|
||||
|
||||
if not kwargs:
|
||||
for val, field in zip(args, fields_iter):
|
||||
attname = getattr(field, 'attname', field.name)
|
||||
for val, (name, field) in zip(args, fields_iter):
|
||||
attname = getattr(field, 'attname', name)
|
||||
setattr(self, attname, val)
|
||||
else:
|
||||
for val, field in zip(args, fields_iter):
|
||||
attname = getattr(field, 'attname', field.name)
|
||||
for val, (name, field) in zip(args, fields_iter):
|
||||
attname = getattr(field, 'attname', name)
|
||||
setattr(self, attname, val)
|
||||
kwargs.pop(attname, None)
|
||||
|
||||
for field in fields_iter:
|
||||
for name, field in fields_iter:
|
||||
try:
|
||||
attname = getattr(field, 'attname', field.name)
|
||||
attname = getattr(field, 'attname', name)
|
||||
val = kwargs.pop(attname)
|
||||
setattr(self, attname, val)
|
||||
except KeyError:
|
||||
|
|
|
@ -16,10 +16,24 @@ def get_fields_from_attrs(in_type, attrs):
|
|||
yield attname, field
|
||||
|
||||
|
||||
def get_fields_from_types(bases):
|
||||
def get_fields_from_bases_and_types(bases, types):
|
||||
fields = set()
|
||||
for _class in bases:
|
||||
for attname, field in get_graphql_type(_class).get_fields().items():
|
||||
if not is_graphene_type(_class):
|
||||
continue
|
||||
_fields = get_graphql_type(_class)._fields
|
||||
if callable(_fields):
|
||||
_fields = _fields()
|
||||
|
||||
for default_attname, field in _fields.items():
|
||||
attname = getattr(field, 'attname', default_attname)
|
||||
if attname in fields:
|
||||
continue
|
||||
fields.add(attname)
|
||||
yield attname, field
|
||||
|
||||
for grapqhl_type in types:
|
||||
for attname, field in get_graphql_type(grapqhl_type).get_fields().items():
|
||||
if attname in fields:
|
||||
continue
|
||||
fields.add(attname)
|
||||
|
@ -29,11 +43,7 @@ def get_fields_from_types(bases):
|
|||
def get_fields(in_type, attrs, bases, graphql_types=()):
|
||||
fields = []
|
||||
|
||||
graphene_bases = tuple(
|
||||
base._meta.graphql_type for base in bases if is_graphene_type(base)
|
||||
) + graphql_types
|
||||
|
||||
extended_fields = list(get_fields_from_types(graphene_bases))
|
||||
extended_fields = list(get_fields_from_bases_and_types(bases, graphql_types))
|
||||
local_fields = list(get_fields_from_attrs(in_type, attrs))
|
||||
# We asume the extended fields are already sorted, so we only
|
||||
# have to sort the local fields, that are get from attrs
|
||||
|
|
|
@ -4,7 +4,7 @@ from graphql import (GraphQLField, GraphQLFloat, GraphQLInt,
|
|||
GraphQLInterfaceType, GraphQLString)
|
||||
|
||||
from ...types import Argument, Field, ObjectType, String
|
||||
from ..get_fields import get_fields_from_attrs, get_fields_from_types
|
||||
from ..get_fields import get_fields_from_attrs, get_fields_from_bases_and_types
|
||||
|
||||
|
||||
def test_get_fields_from_attrs():
|
||||
|
@ -31,8 +31,8 @@ def test_get_fields_from_types():
|
|||
('extra', GraphQLField(GraphQLFloat))
|
||||
]))
|
||||
|
||||
bases = (int_base, float_base)
|
||||
base_fields = OrderedDict(get_fields_from_types(bases))
|
||||
_types = (int_base, float_base)
|
||||
base_fields = OrderedDict(get_fields_from_bases_and_types((), _types))
|
||||
assert [f for f in base_fields.keys()] == ['int', 'num', 'extra', 'float']
|
||||
assert [f.type for f in base_fields.values()] == [
|
||||
GraphQLInt,
|
||||
|
|
Loading…
Reference in New Issue
Block a user