diff --git a/examples/starwars/tests/test_schema.py b/examples/starwars/tests/test_schema.py index b7ae49e4..e69de29b 100644 --- a/examples/starwars/tests/test_schema.py +++ b/examples/starwars/tests/test_schema.py @@ -1,9 +0,0 @@ - -from ..schema import Droid - - -def test_query_types(): - graphql_type = Droid._meta.graphql_type - fields = graphql_type.get_fields() - assert fields['friends'].parent == Droid - assert fields diff --git a/graphene/types/objecttype.py b/graphene/types/objecttype.py index a5fdf169..34bc12e2 100644 --- a/graphene/types/objecttype.py +++ b/graphene/types/objecttype.py @@ -146,9 +146,13 @@ class ObjectType(six.with_metaclass(ObjectTypeMeta)): def __init__(self, *args, **kwargs): # GraphQL ObjectType acting as container args_len = len(args) - fields = self._meta.graphql_type.get_fields().values() - for f in fields: - setattr(self, getattr(f, 'attname', f.name), None) + _fields = self._meta.graphql_type._fields + if callable(_fields): + _fields = _fields() + + fields = _fields.items() + for name, f in fields: + setattr(self, getattr(f, 'attname', name), None) if args_len > len(fields): # Daft, but matches old exception sans the err msg. @@ -156,18 +160,18 @@ class ObjectType(six.with_metaclass(ObjectTypeMeta)): fields_iter = iter(fields) if not kwargs: - for val, field in zip(args, fields_iter): - attname = getattr(field, 'attname', field.name) + for val, (name, field) in zip(args, fields_iter): + attname = getattr(field, 'attname', name) setattr(self, attname, val) else: - for val, field in zip(args, fields_iter): - attname = getattr(field, 'attname', field.name) + for val, (name, field) in zip(args, fields_iter): + attname = getattr(field, 'attname', name) setattr(self, attname, val) kwargs.pop(attname, None) - for field in fields_iter: + for name, field in fields_iter: try: - attname = getattr(field, 'attname', field.name) + attname = getattr(field, 'attname', name) val = kwargs.pop(attname) setattr(self, attname, val) except KeyError: diff --git a/graphene/utils/get_fields.py b/graphene/utils/get_fields.py index 29eb0397..13ba2dae 100644 --- a/graphene/utils/get_fields.py +++ b/graphene/utils/get_fields.py @@ -16,10 +16,24 @@ def get_fields_from_attrs(in_type, attrs): yield attname, field -def get_fields_from_types(bases): +def get_fields_from_bases_and_types(bases, types): fields = set() for _class in bases: - for attname, field in get_graphql_type(_class).get_fields().items(): + if not is_graphene_type(_class): + continue + _fields = get_graphql_type(_class)._fields + if callable(_fields): + _fields = _fields() + + for default_attname, field in _fields.items(): + attname = getattr(field, 'attname', default_attname) + if attname in fields: + continue + fields.add(attname) + yield attname, field + + for grapqhl_type in types: + for attname, field in get_graphql_type(grapqhl_type).get_fields().items(): if attname in fields: continue fields.add(attname) @@ -29,11 +43,7 @@ def get_fields_from_types(bases): def get_fields(in_type, attrs, bases, graphql_types=()): fields = [] - graphene_bases = tuple( - base._meta.graphql_type for base in bases if is_graphene_type(base) - ) + graphql_types - - extended_fields = list(get_fields_from_types(graphene_bases)) + extended_fields = list(get_fields_from_bases_and_types(bases, graphql_types)) local_fields = list(get_fields_from_attrs(in_type, attrs)) # We asume the extended fields are already sorted, so we only # have to sort the local fields, that are get from attrs diff --git a/graphene/utils/tests/test_get_fields.py b/graphene/utils/tests/test_get_fields.py index 0c188e75..65aae160 100644 --- a/graphene/utils/tests/test_get_fields.py +++ b/graphene/utils/tests/test_get_fields.py @@ -4,7 +4,7 @@ from graphql import (GraphQLField, GraphQLFloat, GraphQLInt, GraphQLInterfaceType, GraphQLString) from ...types import Argument, Field, ObjectType, String -from ..get_fields import get_fields_from_attrs, get_fields_from_types +from ..get_fields import get_fields_from_attrs, get_fields_from_bases_and_types def test_get_fields_from_attrs(): @@ -31,8 +31,8 @@ def test_get_fields_from_types(): ('extra', GraphQLField(GraphQLFloat)) ])) - bases = (int_base, float_base) - base_fields = OrderedDict(get_fields_from_types(bases)) + _types = (int_base, float_base) + base_fields = OrderedDict(get_fields_from_bases_and_types((), _types)) assert [f for f in base_fields.keys()] == ['int', 'num', 'extra', 'float'] assert [f.type for f in base_fields.values()] == [ GraphQLInt,