mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-07-18 12:12:19 +03:00
feat: introduce middleware classes
This commit is contained in:
parent
40eccb0d6c
commit
7f320d6239
|
@ -61,6 +61,9 @@ def api_view(http_method_names=None):
|
||||||
WrappedAPIView.parser_classes = getattr(func, 'parser_classes',
|
WrappedAPIView.parser_classes = getattr(func, 'parser_classes',
|
||||||
APIView.parser_classes)
|
APIView.parser_classes)
|
||||||
|
|
||||||
|
WrappedAPIView.middleware_classes = getattr(func, 'middleware_classes',
|
||||||
|
APIView.middleware_classes)
|
||||||
|
|
||||||
WrappedAPIView.authentication_classes = getattr(func, 'authentication_classes',
|
WrappedAPIView.authentication_classes = getattr(func, 'authentication_classes',
|
||||||
APIView.authentication_classes)
|
APIView.authentication_classes)
|
||||||
|
|
||||||
|
|
26
rest_framework/middleware.py
Normal file
26
rest_framework/middleware.py
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
class BaseMiddleware:
|
||||||
|
"""
|
||||||
|
All middleware classes should extend BaseMiddleware.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def process_request(self, request):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def process_response(self, response):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class FooMiddleware(BaseMiddleware):
|
||||||
|
def process_request(self, request):
|
||||||
|
request._foo = "foo"
|
||||||
|
|
||||||
|
def process_response(self, response):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BarMiddleware(BaseMiddleware):
|
||||||
|
def process_request(self, request):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def process_response(self, response):
|
||||||
|
response._bar = "bar"
|
|
@ -37,6 +37,10 @@ DEFAULTS = {
|
||||||
'rest_framework.parsers.FormParser',
|
'rest_framework.parsers.FormParser',
|
||||||
'rest_framework.parsers.MultiPartParser'
|
'rest_framework.parsers.MultiPartParser'
|
||||||
],
|
],
|
||||||
|
'DEFAULT_MIDDLEWARE_CLASSES': [
|
||||||
|
'rest_framework.middleware.FooMiddleware',
|
||||||
|
'rest_framework.middleware.BarMiddleware'
|
||||||
|
],
|
||||||
'DEFAULT_AUTHENTICATION_CLASSES': [
|
'DEFAULT_AUTHENTICATION_CLASSES': [
|
||||||
'rest_framework.authentication.SessionAuthentication',
|
'rest_framework.authentication.SessionAuthentication',
|
||||||
'rest_framework.authentication.BasicAuthentication'
|
'rest_framework.authentication.BasicAuthentication'
|
||||||
|
@ -133,6 +137,7 @@ DEFAULTS = {
|
||||||
IMPORT_STRINGS = [
|
IMPORT_STRINGS = [
|
||||||
'DEFAULT_RENDERER_CLASSES',
|
'DEFAULT_RENDERER_CLASSES',
|
||||||
'DEFAULT_PARSER_CLASSES',
|
'DEFAULT_PARSER_CLASSES',
|
||||||
|
'DEFAULT_MIDDLEWARE_CLASSES',
|
||||||
'DEFAULT_AUTHENTICATION_CLASSES',
|
'DEFAULT_AUTHENTICATION_CLASSES',
|
||||||
'DEFAULT_PERMISSION_CLASSES',
|
'DEFAULT_PERMISSION_CLASSES',
|
||||||
'DEFAULT_THROTTLE_CLASSES',
|
'DEFAULT_THROTTLE_CLASSES',
|
||||||
|
|
|
@ -106,6 +106,7 @@ class APIView(View):
|
||||||
# The following policies may be set at either globally, or per-view.
|
# The following policies may be set at either globally, or per-view.
|
||||||
renderer_classes = api_settings.DEFAULT_RENDERER_CLASSES
|
renderer_classes = api_settings.DEFAULT_RENDERER_CLASSES
|
||||||
parser_classes = api_settings.DEFAULT_PARSER_CLASSES
|
parser_classes = api_settings.DEFAULT_PARSER_CLASSES
|
||||||
|
middleware_classes = api_settings.DEFAULT_MIDDLEWARE_CLASSES
|
||||||
authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES
|
authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES
|
||||||
throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES
|
throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES
|
||||||
permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES
|
permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES
|
||||||
|
@ -265,6 +266,12 @@ class APIView(View):
|
||||||
"""
|
"""
|
||||||
return [parser() for parser in self.parser_classes]
|
return [parser() for parser in self.parser_classes]
|
||||||
|
|
||||||
|
def get_middlewares(self):
|
||||||
|
"""
|
||||||
|
Instantiates and returns the list of middlewares that this view can use.
|
||||||
|
"""
|
||||||
|
return [middleware() for middleware in self.middleware_classes]
|
||||||
|
|
||||||
def get_authenticators(self):
|
def get_authenticators(self):
|
||||||
"""
|
"""
|
||||||
Instantiates and returns the list of authenticators that this view can use.
|
Instantiates and returns the list of authenticators that this view can use.
|
||||||
|
@ -479,6 +486,24 @@ class APIView(View):
|
||||||
request.force_plaintext_errors(use_plaintext_traceback)
|
request.force_plaintext_errors(use_plaintext_traceback)
|
||||||
raise exc
|
raise exc
|
||||||
|
|
||||||
|
def process_request(self, request):
|
||||||
|
"""
|
||||||
|
Pre-process the request by the middleware instances.
|
||||||
|
"""
|
||||||
|
middlewares = getattr(self, "middlewares", self.get_middlewares())
|
||||||
|
for middleware in middlewares:
|
||||||
|
if hasattr(middleware, "process_request"):
|
||||||
|
middleware.process_request(request)
|
||||||
|
|
||||||
|
def process_response(self, response):
|
||||||
|
"""
|
||||||
|
Pre-process the response by the middleware instances.
|
||||||
|
"""
|
||||||
|
middlewares = getattr(self, "middlewares", self.get_middlewares())
|
||||||
|
for middleware in reversed(middlewares):
|
||||||
|
if hasattr(middleware, "process_response"):
|
||||||
|
middleware.process_response(response)
|
||||||
|
|
||||||
# Note: Views are made CSRF exempt from within `as_view` as to prevent
|
# Note: Views are made CSRF exempt from within `as_view` as to prevent
|
||||||
# accidental removal of this exemption in cases where `dispatch` needs to
|
# accidental removal of this exemption in cases where `dispatch` needs to
|
||||||
# be overridden.
|
# be overridden.
|
||||||
|
@ -492,8 +517,10 @@ class APIView(View):
|
||||||
request = self.initialize_request(request, *args, **kwargs)
|
request = self.initialize_request(request, *args, **kwargs)
|
||||||
self.request = request
|
self.request = request
|
||||||
self.headers = self.default_response_headers # deprecate?
|
self.headers = self.default_response_headers # deprecate?
|
||||||
|
self.middlewares = self.get_middlewares()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
self.process_request(request)
|
||||||
self.initial(request, *args, **kwargs)
|
self.initial(request, *args, **kwargs)
|
||||||
|
|
||||||
# Get the appropriate handler method
|
# Get the appropriate handler method
|
||||||
|
@ -504,6 +531,7 @@ class APIView(View):
|
||||||
handler = self.http_method_not_allowed
|
handler = self.http_method_not_allowed
|
||||||
|
|
||||||
response = handler(request, *args, **kwargs)
|
response = handler(request, *args, **kwargs)
|
||||||
|
self.process_response(response)
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
response = self.handle_exception(exc)
|
response = self.handle_exception(exc)
|
||||||
|
|
71
tests/test_rest_middleware.py
Normal file
71
tests/test_rest_middleware.py
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
from django.test import TestCase, override_settings
|
||||||
|
from django.urls import path
|
||||||
|
|
||||||
|
from rest_framework import status
|
||||||
|
from rest_framework.middleware import (
|
||||||
|
BarMiddleware, BaseMiddleware, FooMiddleware
|
||||||
|
)
|
||||||
|
from rest_framework.response import Response
|
||||||
|
from rest_framework.test import APIClient, APIRequestFactory
|
||||||
|
from rest_framework.views import APIView
|
||||||
|
|
||||||
|
factory = APIRequestFactory()
|
||||||
|
|
||||||
|
|
||||||
|
class DummyRequestMiddleware(BaseMiddleware):
|
||||||
|
def process_request(self, request):
|
||||||
|
request._dummy = "dummy"
|
||||||
|
|
||||||
|
|
||||||
|
class DummyResponseMiddleware(BarMiddleware):
|
||||||
|
def process_response(self, response):
|
||||||
|
response._dummy = "dummy"
|
||||||
|
|
||||||
|
|
||||||
|
class MockView(APIView):
|
||||||
|
def get(self, request):
|
||||||
|
response = Response(status=status.HTTP_200_OK)
|
||||||
|
response._request = request # test client sets `request` input
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
urlpatterns = [
|
||||||
|
path('foo/', MockView.as_view(middleware_classes=[FooMiddleware])),
|
||||||
|
path('bar/', MockView.as_view(middleware_classes=[BarMiddleware])),
|
||||||
|
path('multiple/', MockView.as_view(middleware_classes=[DummyRequestMiddleware, DummyResponseMiddleware])),
|
||||||
|
path('none/', MockView.as_view(middleware_classes=[]))
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@override_settings(ROOT_URLCONF=__name__)
|
||||||
|
class FooMiddlewareTests(TestCase):
|
||||||
|
def test_foo_middleware_process_request(self):
|
||||||
|
response = APIClient().get('/foo/')
|
||||||
|
request = response._request
|
||||||
|
assert getattr(request, "_foo") == "foo"
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
|
|
||||||
|
@override_settings(ROOT_URLCONF=__name__)
|
||||||
|
class BarMiddlewareTests(TestCase):
|
||||||
|
def test_bar_middleware_process_response(self):
|
||||||
|
response = APIClient().get('/bar/')
|
||||||
|
assert getattr(response, "_bar") == "bar"
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
|
|
||||||
|
@override_settings(ROOT_URLCONF=__name__)
|
||||||
|
class MultipleMiddlewareClassesTests(TestCase):
|
||||||
|
def test_multiple_middleware_classes_process_request_and_response(self):
|
||||||
|
response = APIClient().get('/multiple/')
|
||||||
|
request = response._request
|
||||||
|
assert getattr(request, "_dummy") == "dummy"
|
||||||
|
assert getattr(response, "_dummy") == "dummy"
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
|
|
||||||
|
@override_settings(ROOT_URLCONF=__name__)
|
||||||
|
class NoMiddlewareClassesTests(TestCase):
|
||||||
|
def test_bar_middleware_process_request(self):
|
||||||
|
response = APIClient().get('/none/')
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
Loading…
Reference in New Issue
Block a user