feat: introduce middleware classes

This commit is contained in:
Sui Yang 2023-08-27 02:04:02 +08:00
parent 40eccb0d6c
commit 7f320d6239
5 changed files with 133 additions and 0 deletions

View File

@ -61,6 +61,9 @@ def api_view(http_method_names=None):
WrappedAPIView.parser_classes = getattr(func, 'parser_classes',
APIView.parser_classes)
WrappedAPIView.middleware_classes = getattr(func, 'middleware_classes',
APIView.middleware_classes)
WrappedAPIView.authentication_classes = getattr(func, 'authentication_classes',
APIView.authentication_classes)

View File

@ -0,0 +1,26 @@
class BaseMiddleware:
"""
All middleware classes should extend BaseMiddleware.
"""
def process_request(self, request):
pass
def process_response(self, response):
pass
class FooMiddleware(BaseMiddleware):
def process_request(self, request):
request._foo = "foo"
def process_response(self, response):
pass
class BarMiddleware(BaseMiddleware):
def process_request(self, request):
pass
def process_response(self, response):
response._bar = "bar"

View File

@ -37,6 +37,10 @@ DEFAULTS = {
'rest_framework.parsers.FormParser',
'rest_framework.parsers.MultiPartParser'
],
'DEFAULT_MIDDLEWARE_CLASSES': [
'rest_framework.middleware.FooMiddleware',
'rest_framework.middleware.BarMiddleware'
],
'DEFAULT_AUTHENTICATION_CLASSES': [
'rest_framework.authentication.SessionAuthentication',
'rest_framework.authentication.BasicAuthentication'
@ -133,6 +137,7 @@ DEFAULTS = {
IMPORT_STRINGS = [
'DEFAULT_RENDERER_CLASSES',
'DEFAULT_PARSER_CLASSES',
'DEFAULT_MIDDLEWARE_CLASSES',
'DEFAULT_AUTHENTICATION_CLASSES',
'DEFAULT_PERMISSION_CLASSES',
'DEFAULT_THROTTLE_CLASSES',

View File

@ -106,6 +106,7 @@ class APIView(View):
# The following policies may be set at either globally, or per-view.
renderer_classes = api_settings.DEFAULT_RENDERER_CLASSES
parser_classes = api_settings.DEFAULT_PARSER_CLASSES
middleware_classes = api_settings.DEFAULT_MIDDLEWARE_CLASSES
authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES
throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES
permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES
@ -265,6 +266,12 @@ class APIView(View):
"""
return [parser() for parser in self.parser_classes]
def get_middlewares(self):
"""
Instantiates and returns the list of middlewares that this view can use.
"""
return [middleware() for middleware in self.middleware_classes]
def get_authenticators(self):
"""
Instantiates and returns the list of authenticators that this view can use.
@ -479,6 +486,24 @@ class APIView(View):
request.force_plaintext_errors(use_plaintext_traceback)
raise exc
def process_request(self, request):
"""
Pre-process the request by the middleware instances.
"""
middlewares = getattr(self, "middlewares", self.get_middlewares())
for middleware in middlewares:
if hasattr(middleware, "process_request"):
middleware.process_request(request)
def process_response(self, response):
"""
Pre-process the response by the middleware instances.
"""
middlewares = getattr(self, "middlewares", self.get_middlewares())
for middleware in reversed(middlewares):
if hasattr(middleware, "process_response"):
middleware.process_response(response)
# Note: Views are made CSRF exempt from within `as_view` as to prevent
# accidental removal of this exemption in cases where `dispatch` needs to
# be overridden.
@ -492,8 +517,10 @@ class APIView(View):
request = self.initialize_request(request, *args, **kwargs)
self.request = request
self.headers = self.default_response_headers # deprecate?
self.middlewares = self.get_middlewares()
try:
self.process_request(request)
self.initial(request, *args, **kwargs)
# Get the appropriate handler method
@ -504,6 +531,7 @@ class APIView(View):
handler = self.http_method_not_allowed
response = handler(request, *args, **kwargs)
self.process_response(response)
except Exception as exc:
response = self.handle_exception(exc)

View File

@ -0,0 +1,71 @@
from django.test import TestCase, override_settings
from django.urls import path
from rest_framework import status
from rest_framework.middleware import (
BarMiddleware, BaseMiddleware, FooMiddleware
)
from rest_framework.response import Response
from rest_framework.test import APIClient, APIRequestFactory
from rest_framework.views import APIView
factory = APIRequestFactory()
class DummyRequestMiddleware(BaseMiddleware):
def process_request(self, request):
request._dummy = "dummy"
class DummyResponseMiddleware(BarMiddleware):
def process_response(self, response):
response._dummy = "dummy"
class MockView(APIView):
def get(self, request):
response = Response(status=status.HTTP_200_OK)
response._request = request # test client sets `request` input
return response
urlpatterns = [
path('foo/', MockView.as_view(middleware_classes=[FooMiddleware])),
path('bar/', MockView.as_view(middleware_classes=[BarMiddleware])),
path('multiple/', MockView.as_view(middleware_classes=[DummyRequestMiddleware, DummyResponseMiddleware])),
path('none/', MockView.as_view(middleware_classes=[]))
]
@override_settings(ROOT_URLCONF=__name__)
class FooMiddlewareTests(TestCase):
def test_foo_middleware_process_request(self):
response = APIClient().get('/foo/')
request = response._request
assert getattr(request, "_foo") == "foo"
assert response.status_code == status.HTTP_200_OK
@override_settings(ROOT_URLCONF=__name__)
class BarMiddlewareTests(TestCase):
def test_bar_middleware_process_response(self):
response = APIClient().get('/bar/')
assert getattr(response, "_bar") == "bar"
assert response.status_code == status.HTTP_200_OK
@override_settings(ROOT_URLCONF=__name__)
class MultipleMiddlewareClassesTests(TestCase):
def test_multiple_middleware_classes_process_request_and_response(self):
response = APIClient().get('/multiple/')
request = response._request
assert getattr(request, "_dummy") == "dummy"
assert getattr(response, "_dummy") == "dummy"
assert response.status_code == status.HTTP_200_OK
@override_settings(ROOT_URLCONF=__name__)
class NoMiddlewareClassesTests(TestCase):
def test_bar_middleware_process_request(self):
response = APIClient().get('/none/')
assert response.status_code == status.HTTP_200_OK