mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-07-18 12:12:19 +03:00
feat: introduce middleware classes
This commit is contained in:
parent
40eccb0d6c
commit
7f320d6239
|
@ -61,6 +61,9 @@ def api_view(http_method_names=None):
|
|||
WrappedAPIView.parser_classes = getattr(func, 'parser_classes',
|
||||
APIView.parser_classes)
|
||||
|
||||
WrappedAPIView.middleware_classes = getattr(func, 'middleware_classes',
|
||||
APIView.middleware_classes)
|
||||
|
||||
WrappedAPIView.authentication_classes = getattr(func, 'authentication_classes',
|
||||
APIView.authentication_classes)
|
||||
|
||||
|
|
26
rest_framework/middleware.py
Normal file
26
rest_framework/middleware.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
class BaseMiddleware:
|
||||
"""
|
||||
All middleware classes should extend BaseMiddleware.
|
||||
"""
|
||||
|
||||
def process_request(self, request):
|
||||
pass
|
||||
|
||||
def process_response(self, response):
|
||||
pass
|
||||
|
||||
|
||||
class FooMiddleware(BaseMiddleware):
|
||||
def process_request(self, request):
|
||||
request._foo = "foo"
|
||||
|
||||
def process_response(self, response):
|
||||
pass
|
||||
|
||||
|
||||
class BarMiddleware(BaseMiddleware):
|
||||
def process_request(self, request):
|
||||
pass
|
||||
|
||||
def process_response(self, response):
|
||||
response._bar = "bar"
|
|
@ -37,6 +37,10 @@ DEFAULTS = {
|
|||
'rest_framework.parsers.FormParser',
|
||||
'rest_framework.parsers.MultiPartParser'
|
||||
],
|
||||
'DEFAULT_MIDDLEWARE_CLASSES': [
|
||||
'rest_framework.middleware.FooMiddleware',
|
||||
'rest_framework.middleware.BarMiddleware'
|
||||
],
|
||||
'DEFAULT_AUTHENTICATION_CLASSES': [
|
||||
'rest_framework.authentication.SessionAuthentication',
|
||||
'rest_framework.authentication.BasicAuthentication'
|
||||
|
@ -133,6 +137,7 @@ DEFAULTS = {
|
|||
IMPORT_STRINGS = [
|
||||
'DEFAULT_RENDERER_CLASSES',
|
||||
'DEFAULT_PARSER_CLASSES',
|
||||
'DEFAULT_MIDDLEWARE_CLASSES',
|
||||
'DEFAULT_AUTHENTICATION_CLASSES',
|
||||
'DEFAULT_PERMISSION_CLASSES',
|
||||
'DEFAULT_THROTTLE_CLASSES',
|
||||
|
|
|
@ -106,6 +106,7 @@ class APIView(View):
|
|||
# The following policies may be set at either globally, or per-view.
|
||||
renderer_classes = api_settings.DEFAULT_RENDERER_CLASSES
|
||||
parser_classes = api_settings.DEFAULT_PARSER_CLASSES
|
||||
middleware_classes = api_settings.DEFAULT_MIDDLEWARE_CLASSES
|
||||
authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES
|
||||
throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES
|
||||
permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES
|
||||
|
@ -265,6 +266,12 @@ class APIView(View):
|
|||
"""
|
||||
return [parser() for parser in self.parser_classes]
|
||||
|
||||
def get_middlewares(self):
|
||||
"""
|
||||
Instantiates and returns the list of middlewares that this view can use.
|
||||
"""
|
||||
return [middleware() for middleware in self.middleware_classes]
|
||||
|
||||
def get_authenticators(self):
|
||||
"""
|
||||
Instantiates and returns the list of authenticators that this view can use.
|
||||
|
@ -479,6 +486,24 @@ class APIView(View):
|
|||
request.force_plaintext_errors(use_plaintext_traceback)
|
||||
raise exc
|
||||
|
||||
def process_request(self, request):
|
||||
"""
|
||||
Pre-process the request by the middleware instances.
|
||||
"""
|
||||
middlewares = getattr(self, "middlewares", self.get_middlewares())
|
||||
for middleware in middlewares:
|
||||
if hasattr(middleware, "process_request"):
|
||||
middleware.process_request(request)
|
||||
|
||||
def process_response(self, response):
|
||||
"""
|
||||
Pre-process the response by the middleware instances.
|
||||
"""
|
||||
middlewares = getattr(self, "middlewares", self.get_middlewares())
|
||||
for middleware in reversed(middlewares):
|
||||
if hasattr(middleware, "process_response"):
|
||||
middleware.process_response(response)
|
||||
|
||||
# Note: Views are made CSRF exempt from within `as_view` as to prevent
|
||||
# accidental removal of this exemption in cases where `dispatch` needs to
|
||||
# be overridden.
|
||||
|
@ -492,8 +517,10 @@ class APIView(View):
|
|||
request = self.initialize_request(request, *args, **kwargs)
|
||||
self.request = request
|
||||
self.headers = self.default_response_headers # deprecate?
|
||||
self.middlewares = self.get_middlewares()
|
||||
|
||||
try:
|
||||
self.process_request(request)
|
||||
self.initial(request, *args, **kwargs)
|
||||
|
||||
# Get the appropriate handler method
|
||||
|
@ -504,6 +531,7 @@ class APIView(View):
|
|||
handler = self.http_method_not_allowed
|
||||
|
||||
response = handler(request, *args, **kwargs)
|
||||
self.process_response(response)
|
||||
|
||||
except Exception as exc:
|
||||
response = self.handle_exception(exc)
|
||||
|
|
71
tests/test_rest_middleware.py
Normal file
71
tests/test_rest_middleware.py
Normal file
|
@ -0,0 +1,71 @@
|
|||
from django.test import TestCase, override_settings
|
||||
from django.urls import path
|
||||
|
||||
from rest_framework import status
|
||||
from rest_framework.middleware import (
|
||||
BarMiddleware, BaseMiddleware, FooMiddleware
|
||||
)
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.test import APIClient, APIRequestFactory
|
||||
from rest_framework.views import APIView
|
||||
|
||||
factory = APIRequestFactory()
|
||||
|
||||
|
||||
class DummyRequestMiddleware(BaseMiddleware):
|
||||
def process_request(self, request):
|
||||
request._dummy = "dummy"
|
||||
|
||||
|
||||
class DummyResponseMiddleware(BarMiddleware):
|
||||
def process_response(self, response):
|
||||
response._dummy = "dummy"
|
||||
|
||||
|
||||
class MockView(APIView):
|
||||
def get(self, request):
|
||||
response = Response(status=status.HTTP_200_OK)
|
||||
response._request = request # test client sets `request` input
|
||||
return response
|
||||
|
||||
|
||||
urlpatterns = [
|
||||
path('foo/', MockView.as_view(middleware_classes=[FooMiddleware])),
|
||||
path('bar/', MockView.as_view(middleware_classes=[BarMiddleware])),
|
||||
path('multiple/', MockView.as_view(middleware_classes=[DummyRequestMiddleware, DummyResponseMiddleware])),
|
||||
path('none/', MockView.as_view(middleware_classes=[]))
|
||||
]
|
||||
|
||||
|
||||
@override_settings(ROOT_URLCONF=__name__)
|
||||
class FooMiddlewareTests(TestCase):
|
||||
def test_foo_middleware_process_request(self):
|
||||
response = APIClient().get('/foo/')
|
||||
request = response._request
|
||||
assert getattr(request, "_foo") == "foo"
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
|
||||
@override_settings(ROOT_URLCONF=__name__)
|
||||
class BarMiddlewareTests(TestCase):
|
||||
def test_bar_middleware_process_response(self):
|
||||
response = APIClient().get('/bar/')
|
||||
assert getattr(response, "_bar") == "bar"
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
|
||||
@override_settings(ROOT_URLCONF=__name__)
|
||||
class MultipleMiddlewareClassesTests(TestCase):
|
||||
def test_multiple_middleware_classes_process_request_and_response(self):
|
||||
response = APIClient().get('/multiple/')
|
||||
request = response._request
|
||||
assert getattr(request, "_dummy") == "dummy"
|
||||
assert getattr(response, "_dummy") == "dummy"
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
|
||||
@override_settings(ROOT_URLCONF=__name__)
|
||||
class NoMiddlewareClassesTests(TestCase):
|
||||
def test_bar_middleware_process_request(self):
|
||||
response = APIClient().get('/none/')
|
||||
assert response.status_code == status.HTTP_200_OK
|
Loading…
Reference in New Issue
Block a user