diff --git a/graphene/types/schema.py b/graphene/types/schema.py index 2df58039..14f1d129 100644 --- a/graphene/types/schema.py +++ b/graphene/types/schema.py @@ -1,6 +1,6 @@ import inspect -from graphql import GraphQLSchema, graphql, is_type +from graphql import GraphQLSchema, MiddlewareManager, graphql, is_type from graphql.utils.introspection_query import introspection_query from graphql.utils.schema_printer import print_schema @@ -24,7 +24,7 @@ from .typemap import TypeMap, is_graphene_type class Schema(GraphQLSchema): - def __init__(self, query=None, mutation=None, subscription=None, directives=None, types=None, executor=None): + def __init__(self, query=None, mutation=None, subscription=None, directives=None, types=None, executor=None, middlewares=None): self._query = query self._mutation = mutation self._subscription = subscription @@ -40,7 +40,10 @@ class Schema(GraphQLSchema): 'Schema directives must be List[GraphQLDirective] if provided but got: {}.'.format( directives ) - + if middlewares: + self.middlewares = MiddlewareManager(*middlewares) + else: + self.middlewares = None self._directives = directives self.build_typemap() @@ -74,7 +77,8 @@ class Schema(GraphQLSchema): context_value=context_value, variable_values=variable_values, operation_name=operation_name, - executor=executor or self._executor + executor=executor or self._executor, + middlewares=self.middlewares ) def register(self, object_type): diff --git a/graphene/types/tests/test_query.py b/graphene/types/tests/test_query.py index 7bdf3651..9edbb4aa 100644 --- a/graphene/types/tests/test_query.py +++ b/graphene/types/tests/test_query.py @@ -1,6 +1,7 @@ from collections import OrderedDict from py.test import raises +from graphql import MiddlewareManager from ..objecttype import ObjectType from ..scalars import String, Int, Boolean @@ -10,12 +11,48 @@ from ..structures import List from ..schema import Schema -class Query(ObjectType): - hello = String(resolver=lambda *_: 'World') - - def test_query(): + class Query(ObjectType): + hello = String(resolver=lambda *_: 'World') + hello_schema = Schema(Query) executed = hello_schema.execute('{ hello }') - print executed.errors + assert not executed.errors + assert executed.data == {'hello': 'World'} + + +def test_query_resolve_function(): + class Query(ObjectType): + hello = String() + + def resolve_hello(self, args, context, info): + return 'World' + + hello_schema = Schema(Query) + + executed = hello_schema.execute('{ hello }') + assert not executed.errors + assert executed.data == {'hello': 'World'} + + +def test_query_middlewares(): + class Query(ObjectType): + hello = String() + other = String() + + def resolve_hello(self, args, context, info): + return 'World' + + def resolve_other(self, args, context, info): + return 'other' + + def reversed_middleware(next, *args, **kwargs): + p = next(*args, **kwargs) + return p.then(lambda x: x[::-1]) + + hello_schema = Schema(Query, middlewares=[reversed_middleware]) + + executed = hello_schema.execute('{ hello, other }') + assert not executed.errors + assert executed.data == {'hello': 'dlroW', 'other': 'rehto'}