diff --git a/graphene/generators/__init__.py b/graphene/generators/__init__.py index c9626f08..1f46ebc6 100644 --- a/graphene/generators/__init__.py +++ b/graphene/generators/__init__.py @@ -1,4 +1,4 @@ -from .definitions import GrapheneInterfaceType, GrapheneObjectType, GrapheneScalarType, GrapheneEnumType +from .definitions import GrapheneInterfaceType, GrapheneObjectType, GrapheneScalarType, GrapheneEnumType, GrapheneInputObjectType from .utils import values_from_enum @@ -42,3 +42,12 @@ def generate_enum(enum): name=enum._meta.name or enum.__name__, description=enum._meta.description or enum.__doc__, ) + + +def generate_inputobjecttype(inputobjecttype): + return GrapheneInputObjectType( + graphene_type=inputobjecttype, + name=inputobjecttype._meta.name or inputobjecttype.__name__, + description=inputobjecttype._meta.description or inputobjecttype.__doc__, + fields=inputobjecttype._meta.get_fields, + ) diff --git a/graphene/generators/definitions.py b/graphene/generators/definitions.py index 423c14e3..70eba044 100644 --- a/graphene/generators/definitions.py +++ b/graphene/generators/definitions.py @@ -1,4 +1,4 @@ -from graphql import GraphQLObjectType, GraphQLInterfaceType, GraphQLScalarType, GraphQLEnumType +from graphql import GraphQLObjectType, GraphQLInterfaceType, GraphQLScalarType, GraphQLEnumType, GraphQLInputObjectType class GrapheneGraphQLType(object): @@ -35,3 +35,7 @@ class GrapheneScalarType(GrapheneGraphQLType, GraphQLScalarType): class GrapheneEnumType(GrapheneGraphQLType, GraphQLEnumType): pass + + +class GrapheneInputObjectType(GrapheneGraphQLType, GraphQLInputObjectType): + pass diff --git a/graphene/types/inputobjecttype.py b/graphene/types/inputobjecttype.py index b5398497..40549994 100644 --- a/graphene/types/inputobjecttype.py +++ b/graphene/types/inputobjecttype.py @@ -1,19 +1,15 @@ import six -from graphql import GraphQLInputObjectType - from ..utils.copy_fields import copy_fields from ..utils.get_fields import get_fields from ..utils.is_base_type import is_base_type -from .definitions import GrapheneGraphQLType from .field import InputField from .objecttype import attrs_without_fields from .options import Options from .unmountedtype import UnmountedType -class GrapheneInputObjectType(GrapheneGraphQLType, GraphQLInputObjectType): - pass +from ..generators import generate_inputobjecttype class InputObjectTypeMeta(type): @@ -39,12 +35,8 @@ class InputObjectTypeMeta(type): if not options.graphql_type: fields = copy_fields(InputField, fields, parent=cls) - options.graphql_type = GrapheneInputObjectType( - graphene_type=cls, - name=options.name or cls.__name__, - description=options.description or cls.__doc__, - fields=fields, - ) + options.get_fields = lambda: fields + options.graphql_type = generate_inputobjecttype(cls) else: assert not fields, "Can't mount InputFields in an InputObjectType with a defined graphql_type" fields = copy_fields(InputField, options.graphql_type.get_fields(), parent=cls)