diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index 3b572c09e..883a1c6d2 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -61,6 +61,9 @@ def api_view(http_method_names=None): WrappedAPIView.parser_classes = getattr(func, 'parser_classes', APIView.parser_classes) + WrappedAPIView.middleware_classes = getattr(func, 'middleware_classes', + APIView.middleware_classes) + WrappedAPIView.authentication_classes = getattr(func, 'authentication_classes', APIView.authentication_classes) diff --git a/rest_framework/middleware.py b/rest_framework/middleware.py new file mode 100644 index 000000000..06bfb815e --- /dev/null +++ b/rest_framework/middleware.py @@ -0,0 +1,26 @@ +class BaseMiddleware: + """ + All middleware classes should extend BaseMiddleware. + """ + + def process_request(self, request): + pass + + def process_response(self, response): + pass + + +class FooMiddleware(BaseMiddleware): + def process_request(self, request): + request._foo = "foo" + + def process_response(self, response): + pass + + +class BarMiddleware(BaseMiddleware): + def process_request(self, request): + pass + + def process_response(self, response): + response._bar = "bar" diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 96b664574..a60477f27 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -37,6 +37,10 @@ DEFAULTS = { 'rest_framework.parsers.FormParser', 'rest_framework.parsers.MultiPartParser' ], + 'DEFAULT_MIDDLEWARE_CLASSES': [ + 'rest_framework.middleware.FooMiddleware', + 'rest_framework.middleware.BarMiddleware' + ], 'DEFAULT_AUTHENTICATION_CLASSES': [ 'rest_framework.authentication.SessionAuthentication', 'rest_framework.authentication.BasicAuthentication' @@ -133,6 +137,7 @@ DEFAULTS = { IMPORT_STRINGS = [ 'DEFAULT_RENDERER_CLASSES', 'DEFAULT_PARSER_CLASSES', + 'DEFAULT_MIDDLEWARE_CLASSES', 'DEFAULT_AUTHENTICATION_CLASSES', 'DEFAULT_PERMISSION_CLASSES', 'DEFAULT_THROTTLE_CLASSES', diff --git a/rest_framework/views.py b/rest_framework/views.py index 4c30029fd..0903b2c76 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -106,6 +106,7 @@ class APIView(View): # The following policies may be set at either globally, or per-view. renderer_classes = api_settings.DEFAULT_RENDERER_CLASSES parser_classes = api_settings.DEFAULT_PARSER_CLASSES + middleware_classes = api_settings.DEFAULT_MIDDLEWARE_CLASSES authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES @@ -265,6 +266,12 @@ class APIView(View): """ return [parser() for parser in self.parser_classes] + def get_middlewares(self): + """ + Instantiates and returns the list of middlewares that this view can use. + """ + return [middleware() for middleware in self.middleware_classes] + def get_authenticators(self): """ Instantiates and returns the list of authenticators that this view can use. @@ -479,6 +486,24 @@ class APIView(View): request.force_plaintext_errors(use_plaintext_traceback) raise exc + def process_request(self, request): + """ + Pre-process the request by the middleware instances. + """ + middlewares = getattr(self, "middlewares", self.get_middlewares()) + for middleware in middlewares: + if hasattr(middleware, "process_request"): + middleware.process_request(request) + + def process_response(self, response): + """ + Pre-process the response by the middleware instances. + """ + middlewares = getattr(self, "middlewares", self.get_middlewares()) + for middleware in reversed(middlewares): + if hasattr(middleware, "process_response"): + middleware.process_response(response) + # Note: Views are made CSRF exempt from within `as_view` as to prevent # accidental removal of this exemption in cases where `dispatch` needs to # be overridden. @@ -492,8 +517,10 @@ class APIView(View): request = self.initialize_request(request, *args, **kwargs) self.request = request self.headers = self.default_response_headers # deprecate? + self.middlewares = self.get_middlewares() try: + self.process_request(request) self.initial(request, *args, **kwargs) # Get the appropriate handler method @@ -504,6 +531,7 @@ class APIView(View): handler = self.http_method_not_allowed response = handler(request, *args, **kwargs) + self.process_response(response) except Exception as exc: response = self.handle_exception(exc) diff --git a/tests/test_rest_middleware.py b/tests/test_rest_middleware.py new file mode 100644 index 000000000..b8837ca75 --- /dev/null +++ b/tests/test_rest_middleware.py @@ -0,0 +1,71 @@ +from django.test import TestCase, override_settings +from django.urls import path + +from rest_framework import status +from rest_framework.middleware import ( + BarMiddleware, BaseMiddleware, FooMiddleware +) +from rest_framework.response import Response +from rest_framework.test import APIClient, APIRequestFactory +from rest_framework.views import APIView + +factory = APIRequestFactory() + + +class DummyRequestMiddleware(BaseMiddleware): + def process_request(self, request): + request._dummy = "dummy" + + +class DummyResponseMiddleware(BarMiddleware): + def process_response(self, response): + response._dummy = "dummy" + + +class MockView(APIView): + def get(self, request): + response = Response(status=status.HTTP_200_OK) + response._request = request # test client sets `request` input + return response + + +urlpatterns = [ + path('foo/', MockView.as_view(middleware_classes=[FooMiddleware])), + path('bar/', MockView.as_view(middleware_classes=[BarMiddleware])), + path('multiple/', MockView.as_view(middleware_classes=[DummyRequestMiddleware, DummyResponseMiddleware])), + path('none/', MockView.as_view(middleware_classes=[])) +] + + +@override_settings(ROOT_URLCONF=__name__) +class FooMiddlewareTests(TestCase): + def test_foo_middleware_process_request(self): + response = APIClient().get('/foo/') + request = response._request + assert getattr(request, "_foo") == "foo" + assert response.status_code == status.HTTP_200_OK + + +@override_settings(ROOT_URLCONF=__name__) +class BarMiddlewareTests(TestCase): + def test_bar_middleware_process_response(self): + response = APIClient().get('/bar/') + assert getattr(response, "_bar") == "bar" + assert response.status_code == status.HTTP_200_OK + + +@override_settings(ROOT_URLCONF=__name__) +class MultipleMiddlewareClassesTests(TestCase): + def test_multiple_middleware_classes_process_request_and_response(self): + response = APIClient().get('/multiple/') + request = response._request + assert getattr(request, "_dummy") == "dummy" + assert getattr(response, "_dummy") == "dummy" + assert response.status_code == status.HTTP_200_OK + + +@override_settings(ROOT_URLCONF=__name__) +class NoMiddlewareClassesTests(TestCase): + def test_bar_middleware_process_request(self): + response = APIClient().get('/none/') + assert response.status_code == status.HTTP_200_OK