diff --git a/graphene/relay/mutation.py b/graphene/relay/mutation.py index 67930db7..2d89e7b6 100644 --- a/graphene/relay/mutation.py +++ b/graphene/relay/mutation.py @@ -4,7 +4,7 @@ from functools import partial import six from promise import Promise -from ..types import Argument, Field, InputObjectType, String +from ..types import Argument, Field, InputObjectType, String, AbstractType from ..types.objecttype import ObjectType, ObjectTypeMeta from ..utils.is_base_type import is_base_type from ..utils.props import props @@ -27,9 +27,14 @@ class ClientIDMutationMeta(ObjectTypeMeta): "{}.mutate_and_get_payload method is required" " in a ClientIDMutation." ).format(name) - input_attrs = props(input_class) if input_class else {} + input_attrs = {} + bases = () + if not issubclass(input_class, AbstractType): + input_attrs = props(input_class) if input_class else {} + else: + bases += (input_class, ) input_attrs['client_mutation_id'] = String(name='clientMutationId') - cls.Input = type('{}Input'.format(base_name), (InputObjectType,), input_attrs) + cls.Input = type('{}Input'.format(base_name), bases + (InputObjectType,), input_attrs) cls.Field = partial(Field, cls, resolver=cls.mutate, input=Argument(cls.Input)) return cls diff --git a/graphene/relay/tests/test_mutation.py b/graphene/relay/tests/test_mutation.py index a7480075..0a958c5c 100644 --- a/graphene/relay/tests/test_mutation.py +++ b/graphene/relay/tests/test_mutation.py @@ -1,12 +1,12 @@ import pytest from ...types import (Argument, Field, InputField, InputObjectType, ObjectType, - Schema) + Schema, AbstractType) from ...types.scalars import String from ..mutation import ClientIDMutation -class SharedFields(object): +class SharedFields(AbstractType): shared = String() diff --git a/graphene/relay/tests/test_node.py b/graphene/relay/tests/test_node.py index 97d2936b..421ef695 100644 --- a/graphene/relay/tests/test_node.py +++ b/graphene/relay/tests/test_node.py @@ -10,9 +10,6 @@ from ..node import Node class SharedNodeFields(AbstractType): - class Meta: - interfaces = (Node, ) - shared = String() something_else = String() @@ -34,6 +31,9 @@ class MyNode(ObjectType): class MyOtherNode(SharedNodeFields, ObjectType): extra_field = String() + class Meta: + interfaces = (Node, ) + def resolve_extra_field(self, *_): return 'extra field info.'