diff --git a/graphene_django/tests/test_utils.py b/graphene_django/tests/test_utils.py index c0d376b..0f8c376 100644 --- a/graphene_django/tests/test_utils.py +++ b/graphene_django/tests/test_utils.py @@ -5,6 +5,7 @@ from django.utils.translation import gettext_lazy from mock import patch from ..utils import camelize, get_model_fields, GraphQLTestCase +from ..views import instantiate_middleware from .models import Film, Reporter @@ -36,6 +37,27 @@ def test_camelize(): assert camelize({0: {"field_a": ["errors"]}}) == {0: {"fieldA": ["errors"]}} +def test_instantiate_middleware_input_types(): + from django.middleware.security import SecurityMiddleware + from django.middleware.common import CommonMiddleware + from django.middleware.csrf import CsrfViewMiddleware + from django.contrib.messages.middleware import MessageMiddleware + + middleware = [ + "django.middleware.security.SecurityMiddleware", + CommonMiddleware, + CsrfViewMiddleware(), + "django.contrib.messages.middleware.MessageMiddleware", + ] + + loaded_middlewares = list(instantiate_middleware(middleware)) + + assert isinstance(loaded_middlewares[0], SecurityMiddleware) + assert isinstance(loaded_middlewares[1], CommonMiddleware) + assert isinstance(loaded_middlewares[2], CsrfViewMiddleware) + assert isinstance(loaded_middlewares[3], MessageMiddleware) + + @pytest.mark.django_db @patch("graphene_django.utils.testing.Client.post") def test_graphql_test_case_op_name(post_mock): diff --git a/graphene_django/views.py b/graphene_django/views.py index 16bf34b..c7d88af 100644 --- a/graphene_django/views.py +++ b/graphene_django/views.py @@ -9,6 +9,7 @@ from django.shortcuts import render from django.utils.decorators import method_decorator from django.views.generic import View from django.views.decorators.csrf import ensure_csrf_cookie +from django.utils.module_loading import import_string from graphql import get_default_backend from graphql.error import format_error as format_graphql_error @@ -48,6 +49,10 @@ def instantiate_middleware(middlewares): if inspect.isclass(middleware): yield middleware() continue + elif isinstance(middleware, six.string_types): + middleware_class = import_string(middleware) + yield middleware_class() + continue yield middleware