diff --git a/graphene/types/schema.py b/graphene/types/schema.py index 8066de3e..7fd513b2 100644 --- a/graphene/types/schema.py +++ b/graphene/types/schema.py @@ -1,6 +1,6 @@ import inspect -from graphql import GraphQLSchema, graphql, is_type +from graphql import GraphQLSchema, graphql, is_type, GraphQLObjectType from graphql.type.directives import (GraphQLDirective, GraphQLIncludeDirective, GraphQLSkipDirective) from graphql.type.introspection import IntrospectionSchema @@ -12,6 +12,17 @@ from .objecttype import ObjectType from .typemap import TypeMap, is_graphene_type +def assert_valid_root_type(_type): + if _type is None: + return + is_graphene_objecttype = inspect.isclass( + _type) and issubclass(_type, ObjectType) + is_graphql_objecttype = isinstance(_type, GraphQLObjectType) + assert is_graphene_objecttype or is_graphql_objecttype, ( + "Type {} is not a valid ObjectType." + ).format(_type) + + class Schema(GraphQLSchema): ''' Schema Definition @@ -20,21 +31,23 @@ class Schema(GraphQLSchema): query and mutation (optional). ''' - def __init__(self, query=None, mutation=None, subscription=None, - directives=None, types=None, auto_camelcase=True): - assert inspect.isclass(query) and issubclass(query, ObjectType), ( - 'Schema query must be Object Type but got: {}.' - ).format(query) + def __init__(self, + query=None, + mutation=None, + subscription=None, + directives=None, + types=None, + auto_camelcase=True): + assert_valid_root_type(query) + assert_valid_root_type(mutation) + assert_valid_root_type(subscription) self._query = query self._mutation = mutation self._subscription = subscription self.types = types self.auto_camelcase = auto_camelcase if directives is None: - directives = [ - GraphQLIncludeDirective, - GraphQLSkipDirective - ] + directives = [GraphQLIncludeDirective, GraphQLSkipDirective] assert all(isinstance(d, GraphQLDirective) for d in directives), \ 'Schema directives must be List[GraphQLDirective] if provided but got: {}.'.format( @@ -61,7 +74,8 @@ class Schema(GraphQLSchema): ''' _type = super(Schema, self).get_type(type_name) if _type is None: - raise AttributeError('Type "{}" not found in the Schema'.format(type_name)) + raise AttributeError( + 'Type "{}" not found in the Schema'.format(type_name)) if isinstance(_type, GrapheneGraphQLType): return _type.graphene_type return _type @@ -73,7 +87,8 @@ class Schema(GraphQLSchema): return _type if is_graphene_type(_type): graphql_type = self.get_type(_type._meta.name) - assert graphql_type, "Type {} not found in this schema.".format(_type._meta.name) + assert graphql_type, "Type {} not found in this schema.".format( + _type._meta.name) assert graphql_type.graphene_type == _type return graphql_type raise Exception("{} is not a valid GraphQL type.".format(_type)) @@ -102,4 +117,8 @@ class Schema(GraphQLSchema): ] if self.types: initial_types += self.types - self._type_map = TypeMap(initial_types, auto_camelcase=self.auto_camelcase, schema=self) + self._type_map = TypeMap( + initial_types, + auto_camelcase=self.auto_camelcase, + schema=self + )