diff --git a/graphene/types/interface.py b/graphene/types/interface.py index d6c1fe37..6c73ba06 100644 --- a/graphene/types/interface.py +++ b/graphene/types/interface.py @@ -32,11 +32,13 @@ class InterfaceMeta(AbstractTypeMeta): class Interface(six.with_metaclass(InterfaceMeta)): - resolve_type = None + @classmethod + def resolve_type(cls, root, args, info): + return type(root) def __init__(self, *args, **kwargs): raise Exception("An Interface cannot be intitialized") - # @classmethod - # def implements(cls, objecttype): - # pass + @classmethod + def implements(cls, objecttype): + pass diff --git a/graphene/types/json.py b/graphene/types/json.py index f4bc632b..4c37e11b 100644 --- a/graphene/types/json.py +++ b/graphene/types/json.py @@ -21,4 +21,4 @@ class JSONString(Scalar): @staticmethod def parse_value(value): - return json.dumps(value) + return json.loads(value) diff --git a/graphene/types/objecttype.py b/graphene/types/objecttype.py index 3c1c7791..7ddd5783 100644 --- a/graphene/types/objecttype.py +++ b/graphene/types/objecttype.py @@ -1,3 +1,4 @@ +from collections import OrderedDict import six from ..utils.is_base_type import is_base_type @@ -5,6 +6,7 @@ from .options import Options from .abstracttype import AbstractTypeMeta from .utils import get_fields_in_type, yank_fields_from_attrs, merge_fields_in_attrs +from .interface import Interface class ObjectTypeMeta(AbstractTypeMeta): @@ -23,10 +25,22 @@ class ObjectTypeMeta(AbstractTypeMeta): ) attrs = merge_fields_in_attrs(bases, attrs) - options.fields = get_fields_in_type(ObjectType, attrs) - yank_fields_from_attrs(attrs, options.fields) + options.local_fields = get_fields_in_type(ObjectType, attrs) + yank_fields_from_attrs(attrs, options.local_fields) + options.interface_fields = OrderedDict() + for interface in options.interfaces: + assert issubclass(interface, Interface), ( + 'All interfaces of {} must be a subclass of Interface. Received "{}".' + ).format(name, interface) + options.interface_fields.update(interface._meta.fields) + options.fields = OrderedDict(options.interface_fields) + options.fields.update(options.local_fields) - return type.__new__(cls, name, bases, dict(attrs, _meta=options)) + cls = type.__new__(cls, name, bases, dict(attrs, _meta=options)) + for interface in options.interfaces: + interface.implements(cls) + + return cls def __str__(cls): return cls._meta.name diff --git a/graphene/types/schema.py b/graphene/types/schema.py index 624ce9e1..6145d4e5 100644 --- a/graphene/types/schema.py +++ b/graphene/types/schema.py @@ -62,6 +62,8 @@ class Schema(GraphQLSchema): return self.get_graphql_type(self._subscription) def get_graphql_type(self, _type): + if not _type: + return _type if is_type(_type): return _type if is_graphene_type(_type): diff --git a/graphene/types/tests/test_interface.py b/graphene/types/tests/test_interface.py index 97804c5c..b15c85c5 100644 --- a/graphene/types/tests/test_interface.py +++ b/graphene/types/tests/test_interface.py @@ -71,6 +71,18 @@ def test_generate_interface_inherit_abstracttype(): assert [type(x) for x in MyInterface._meta.fields.values()] == [Field, Field] +def test_generate_interface_inherit_interface(): + class MyBaseInterface(Interface): + field1 = MyScalar() + + class MyInterface(MyBaseInterface): + field2 = MyScalar() + + assert MyInterface._meta.name == 'MyInterface' + assert MyInterface._meta.fields.keys() == ['field1', 'field2'] + assert [type(x) for x in MyInterface._meta.fields.values()] == [Field, Field] + + def test_generate_interface_inherit_abstracttype_reversed(): class MyAbstractType(AbstractType): field1 = MyScalar() diff --git a/graphene/types/tests/test_objecttype.py b/graphene/types/tests/test_objecttype.py index 9b8b3913..3da688f5 100644 --- a/graphene/types/tests/test_objecttype.py +++ b/graphene/types/tests/test_objecttype.py @@ -4,9 +4,10 @@ from ..field import Field from ..objecttype import ObjectType from ..unmountedtype import UnmountedType from ..abstracttype import AbstractType +from ..interface import Interface -class MyType(object): +class MyType(Interface): pass @@ -15,6 +16,17 @@ class Container(ObjectType): field2 = Field(MyType) +class MyInterface(Interface): + ifield = Field(MyType) + + +class ContainerWithInterface(ObjectType): + class Meta: + interfaces = (MyInterface, ) + field1 = Field(MyType) + field2 = Field(MyType) + + class MyScalar(UnmountedType): def get_type(self): return MyType @@ -94,6 +106,10 @@ def test_parent_container_get_fields(): assert list(Container._meta.fields.keys()) == ['field1', 'field2'] +def test_parent_container_interface_get_fields(): + assert list(ContainerWithInterface._meta.fields.keys()) == ['ifield', 'field1', 'field2'] + + def test_objecttype_as_container_only_args(): container = Container("1", "2") assert container.field1 == "1" diff --git a/graphene/types/typemap.py b/graphene/types/typemap.py index 1f09871c..506900b6 100644 --- a/graphene/types/typemap.py +++ b/graphene/types/typemap.py @@ -1,4 +1,5 @@ import inspect +from functools import partial from collections import OrderedDict from graphql.type.typemap import GraphQLTypeMap @@ -14,6 +15,8 @@ from .scalars import Scalar, String, Boolean, Int, Float, ID from graphql import GraphQLString, GraphQLField, GraphQLList, GraphQLBoolean, GraphQLInt, GraphQLFloat, GraphQLID, GraphQLNonNull, GraphQLInputObjectField, GraphQLArgument from graphql.type import GraphQLEnumValue +from ..utils.str_converters import to_camel_case + def is_graphene_type(_type): if isinstance(_type, (List, NonNull)): @@ -22,13 +25,26 @@ def is_graphene_type(_type): return True +def resolve_type(resolve_type_func, map, root, args, info): + _type = resolve_type_func(root, args, info) + # assert inspect.isclass(_type) and issubclass(_type, ObjectType), ( + # 'Received incompatible type "{}".'.format(_type) + # ) + if inspect.isclass(_type) and issubclass(_type, ObjectType): + graphql_type = map.get(_type._meta.name) + assert graphql_type and graphql_type.graphene_type == _type + return graphql_type + return _type + + class TypeMap(GraphQLTypeMap): @classmethod def reducer(cls, map, type): if not type: return map - + if inspect.isfunction(type): + type = type() if is_graphene_type(type): return cls.graphene_reducer(map, type) return super(TypeMap, cls).reducer(map, type) @@ -112,10 +128,11 @@ class TypeMap(GraphQLTypeMap): ) interfaces = [] for i in type._meta.interfaces: - map = cls.construct_interface(map, i) + map = cls.reducer(map, i) interfaces.append(map[i._meta.name]) - map[type._meta.name].interfaces = interfaces + map[type._meta.name]._provided_interfaces = interfaces map[type._meta.name]._fields = cls.construct_fields_for_type(map, type) + # cls.reducer(map, map[type._meta.name]) return map @classmethod @@ -126,9 +143,10 @@ class TypeMap(GraphQLTypeMap): name=type._meta.name, description=type._meta.description, fields=None, - resolve_type=type.resolve_type, + resolve_type=partial(resolve_type, type.resolve_type, map), ) map[type._meta.name]._fields = cls.construct_fields_for_type(map, type) + # cls.reducer(map, map[type._meta.name]) return map @classmethod @@ -159,6 +177,14 @@ class TypeMap(GraphQLTypeMap): map[type._meta.name].types = types return map + @classmethod + def process_field_name(cls, name): + return to_camel_case(name) + + @classmethod + def default_resolver(cls, attname, root, *_): + return getattr(root, attname, None) + @classmethod def construct_fields_for_type(cls, map, type, is_input_type=False): fields = OrderedDict() @@ -181,25 +207,42 @@ class TypeMap(GraphQLTypeMap): description=arg.description, default_value=arg.default_value ) - resolver = field.resolver - resolver_type = getattr(type, 'resolve_{}'.format(name), None) - if resolver_type: - resolver = resolver_type.__func__ - _field = GraphQLField( field_type, args=args, - resolver=resolver, + resolver=field.resolver or cls.get_resolver_for_type(type, name), deprecation_reason=field.deprecation_reason, description=field.description ) - fields[name] = _field + processed_name = cls.process_field_name(name) + fields[processed_name] = _field return fields + @classmethod + def get_resolver_for_type(cls, type, name): + if not issubclass(type, ObjectType): + return + resolver = getattr(type, 'resolve_{}'.format(name), None) + if not resolver: + # If we don't find the resolver in the ObjectType class, then try to + # find it in each of the interfaces + interface_resolver = None + for interface in type._meta.interfaces: + interface_resolver = getattr(interface, 'resolve_{}'.format(name), None) + if interface_resolver: + break + resolver = interface_resolver + # Only if is not decorated with classmethod + if resolver and not getattr(resolver, '__self__', True): + return resolver.__func__ + return partial(cls.default_resolver, name) + @classmethod def get_field_type(self, map, type): if isinstance(type, List): return GraphQLList(self.get_field_type(map, type.of_type)) if isinstance(type, NonNull): return GraphQLNonNull(self.get_field_type(map, type.of_type)) + if inspect.isfunction(type): + type = type() return map.get(type._meta.name) diff --git a/graphene/types/utils.py b/graphene/types/utils.py index 74eb52dc..63a4dfc8 100644 --- a/graphene/types/utils.py +++ b/graphene/types/utils.py @@ -6,9 +6,10 @@ from .inputfield import InputField def merge_fields_in_attrs(bases, attrs): - from ..types.abstracttype import AbstractType + from ..types import AbstractType, Interface + inherited_bases = (AbstractType, Interface) for base in bases: - if base == AbstractType or not issubclass(base, AbstractType): + if base in inherited_bases or not issubclass(base, inherited_bases): continue for name, field in base._meta.fields.items(): if name in attrs: