diff --git a/graphene/relay/mutation.py b/graphene/relay/mutation.py index dde96fff..eccfce93 100644 --- a/graphene/relay/mutation.py +++ b/graphene/relay/mutation.py @@ -4,8 +4,9 @@ from graphql_relay import mutation_with_client_mutation_id from ..types.mutation import Mutation, MutationMeta from ..types.inputobjecttype import InputObjectType -from ..types.field import Field - +from ..types.field import Field, InputField +from ..utils.get_fields import get_fields +from ..utils.copy_fields import copy_fields from ..utils.props import props @@ -29,7 +30,7 @@ class ClientIDMutationMeta(MutationMeta): cls.mutate_and_get_payload = attrs.pop('mutate_and_get_payload', None) - input_local_fields = {f.name: f for f in InputObjectType._extract_local_fields(input_fields)} + input_local_fields = copy_fields(InputField, get_fields(InputObjectType, input_fields, ())) local_fields = cls._extract_local_fields(attrs) assert cls.mutate_and_get_payload, "{}.mutate_and_get_payload method is required in a ClientIDMutation ObjectType.".format(cls.__name__) field = mutation_with_client_mutation_id( diff --git a/graphene/types/field.py b/graphene/types/field.py index 25d29de8..2ae51f69 100644 --- a/graphene/types/field.py +++ b/graphene/types/field.py @@ -129,9 +129,7 @@ class Field(AbstractField, GraphQLField, OrderedType): # If is a GraphQLField type = type or field.type resolver = resolver or field.resolver - source = None name = field.name - required = None _creation_counter = None attname = attname or name parent = parent @@ -187,22 +185,19 @@ class InputField(AbstractField, GraphQLInputObjectField, OrderedType): return self.copy_and_extend(self) @classmethod - def copy_and_extend(cls, field, type=None, default_value=None, description=None, name=None, required=False, _creation_counter=False): + def copy_and_extend(cls, field, type=None, default_value=None, description=None, name=None, required=False, parent=None, attname=None, _creation_counter=False): if isinstance(field, Field): type = type or field._type name = name or field._name required = required or field.required _creation_counter = field.creation_counter if _creation_counter is False else None - attname = field.attname - parent = field.parent + attname = attname or field.attname + parent = parent or field.parent else: # If is a GraphQLField type = type or field.type name = field.name - required = None _creation_counter = None - attname = None - parent = None new_field = cls( type=type, diff --git a/graphene/types/inputobjecttype.py b/graphene/types/inputobjecttype.py index 57b65b9d..032190f6 100644 --- a/graphene/types/inputobjecttype.py +++ b/graphene/types/inputobjecttype.py @@ -3,39 +3,58 @@ import six from graphql import GraphQLInputObjectType from .definitions import FieldsMeta, ClassTypeMeta, GrapheneGraphQLType +from .interface import attrs_without_fields from .unmountedtype import UnmountedType +from .options import Options +from ..utils.is_base_type import is_base_type +from ..utils.get_fields import get_fields +from ..utils.copy_fields import copy_fields +from .field import InputField class GrapheneInputObjectType(GrapheneGraphQLType, GraphQLInputObjectType): pass -class InputObjectTypeMeta(FieldsMeta, ClassTypeMeta): +class InputObjectTypeMeta(type): - def get_options(cls, meta): - return cls.options_class( - meta, + def __new__(cls, name, bases, attrs): + super_new = super(InputObjectTypeMeta, cls).__new__ + + # Also ensure initialization is only performed for subclasses of Model + # (excluding Model class itself). + if not is_base_type(bases, InputObjectTypeMeta): + return super_new(cls, name, bases, attrs) + + options = Options( + attrs.pop('Meta', None), name=None, description=None, graphql_type=None, abstract=False ) - def construct(cls, bases, attrs): - if not cls._meta.abstract: - local_fields = cls._extract_local_fields(attrs) - if not cls._meta.graphql_type: - cls._meta.graphql_type = GrapheneInputObjectType( - graphene_type=cls, - name=cls._meta.name or cls.__name__, - description=cls._meta.description or cls.__doc__, - fields=cls._fields(bases, attrs, local_fields), - ) - else: - assert not local_fields, "Can't mount Fields in an InputObjectType with a defined graphql_type" - return super(InputObjectTypeMeta, cls).construct(bases, attrs) + fields = get_fields(InputObjectType, attrs, bases) + attrs = attrs_without_fields(attrs, fields) + cls = super_new(cls, name, bases, dict(attrs, _meta=options)) + + if not options.graphql_type: + fields = copy_fields(InputField, fields, parent=cls) + options.graphql_type = GrapheneInputObjectType( + graphene_type=cls, + name=options.name or cls.__name__, + description=options.description or cls.__doc__, + fields=fields, + ) + else: + assert not fields, "Can't mount InputFields in an InputObjectType with a defined graphql_type" + fields = copy_fields(options.graphql_type.get_fields(), parent=cls) + + for name, field in fields.items(): + setattr(cls, field.attname or name, field) + + return cls class InputObjectType(six.with_metaclass(InputObjectTypeMeta, UnmountedType)): - class Meta: - abstract = True + pass diff --git a/graphene/types/interface.py b/graphene/types/interface.py index 5606b0f5..d54ab991 100644 --- a/graphene/types/interface.py +++ b/graphene/types/interface.py @@ -6,6 +6,7 @@ from .options import Options from ..utils.is_base_type import is_base_type from ..utils.get_fields import get_fields from ..utils.copy_fields import copy_fields +from .field import Field class GrapheneInterfaceType(GrapheneGraphQLType, GraphQLInterfaceType): @@ -35,7 +36,7 @@ class InterfaceTypeMeta(type): cls = super_new(cls, name, bases, dict(attrs, _meta=options)) if not options.graphql_type: - fields = copy_fields(fields, parent=cls) + fields = copy_fields(Field, fields, parent=cls) options.graphql_type = GrapheneInterfaceType( graphene_type=cls, name=options.name or cls.__name__, diff --git a/graphene/utils/copy_fields.py b/graphene/utils/copy_fields.py index aea1c17c..59239a12 100644 --- a/graphene/utils/copy_fields.py +++ b/graphene/utils/copy_fields.py @@ -1,11 +1,11 @@ from collections import OrderedDict -from ..types.field import Field +from ..types.field import Field, InputField -def copy_fields(fields, **extra): +def copy_fields(like, fields, **extra): _fields = [] for attname, field in fields.items(): - field = Field.copy_and_extend(field, attname=attname, **extra) + field = like.copy_and_extend(field, attname=attname, **extra) _fields.append(field) return OrderedDict((f.name, f) for f in _fields) diff --git a/graphene/utils/is_graphene_type.py b/graphene/utils/is_graphene_type.py index e32bc806..44663c7b 100644 --- a/graphene/utils/is_graphene_type.py +++ b/graphene/utils/is_graphene_type.py @@ -8,7 +8,7 @@ def is_graphene_type(_type): from ..types.scalars import Scalar from ..types.enum import Enum - if _type in [Interface]: + if _type in [Interface, InputObjectType]: return False return inspect.isclass(_type) and issubclass(_type, ( Interface,