From d540b10bb7d42db1f2bf0e9a13b2bd6b14137158 Mon Sep 17 00:00:00 2001 From: Davide Setti Date: Thu, 4 Mar 2021 11:47:02 +0100 Subject: [PATCH] Allow custom throttle to provide a custom detail --- docs/api-guide/throttling.md | 10 +++++++++- rest_framework/views.py | 12 ++++++++---- tests/test_throttling.py | 24 ++++++++++++++++++++++++ 3 files changed, 41 insertions(+), 5 deletions(-) diff --git a/docs/api-guide/throttling.md b/docs/api-guide/throttling.md index 4c58fa713..098f8b6b7 100644 --- a/docs/api-guide/throttling.md +++ b/docs/api-guide/throttling.md @@ -201,7 +201,15 @@ User requests to either `ContactListView` or `ContactDetailView` would be restri To create a custom throttle, override `BaseThrottle` and implement `.allow_request(self, request, view)`. The method should return `True` if the request should be allowed, and `False` otherwise. -Optionally you may also override the `.wait()` method. If implemented, `.wait()` should return a recommended number of seconds to wait before attempting the next request, or `None`. The `.wait()` method will only be called if `.allow_request()` has previously returned `False`. +Optionally you may also override the `.wait()` method. If implemented, `.wait()` should return one of the following: + +- a single value representing the recommended number of seconds to wait before attempting the next request +- a tuple with two elements, in order: + - the recommended number of seconds to wait before attempting the next request or `None` + - a string to be used as _detail_ message +- `None` (default) + +The `.wait()` method will only be called if `.allow_request()` has previously returned `False`. If the `.wait()` method is implemented and the request is throttled, then a `Retry-After` header will be included in the response. diff --git a/rest_framework/views.py b/rest_framework/views.py index 4c30029fd..dae080803 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -174,11 +174,11 @@ class APIView(View): raise exceptions.NotAuthenticated() raise exceptions.PermissionDenied(detail=message, code=code) - def throttled(self, request, wait): + def throttled(self, request, wait, detail=None): """ If request is throttled, determine what kind of exception to raise. """ - raise exceptions.Throttled(wait) + raise exceptions.Throttled(wait, detail) def get_authenticate_header(self, request): """ @@ -367,8 +367,12 @@ class APIView(View): if duration is not None ] - duration = max(durations, default=None) - self.throttled(request, duration) + # consider also wait to return (duration, message) tuple + duration = max(durations, key=lambda d: d[0] or 0 if isinstance(d, (list, tuple)) else d, default=None) + if isinstance(duration, (list, tuple)): + self.throttled(request, *duration[:2]) + else: + self.throttled(request, duration) def determine_version(self, request, *args, **kwargs): """ diff --git a/tests/test_throttling.py b/tests/test_throttling.py index d5a61232d..36b50b204 100644 --- a/tests/test_throttling.py +++ b/tests/test_throttling.py @@ -43,6 +43,14 @@ class NonTimeThrottle(BaseThrottle): return False +class CustomDetailThrottle(BaseThrottle): + def allow_request(self, request, view): + return False + + def wait(self): + return None, 'custom detail' + + class MockView_DoubleThrottling(APIView): throttle_classes = (User3SecRateThrottle, User6MinRateThrottle,) @@ -57,6 +65,13 @@ class MockView(APIView): return Response('foo') +class MockView_CustomDetail(APIView): + throttle_classes = (CustomDetailThrottle,) + + def get(self, request): + return Response('foo') + + class MockView_MinuteThrottling(APIView): throttle_classes = (User3MinRateThrottle,) @@ -88,6 +103,15 @@ class ThrottlingTests(TestCase): response = MockView.as_view()(request) assert response.status_code == 429 + def test_requests_are_throttled_custom_detail(self): + """ + Ensure request rate is limited + """ + request = self.factory.get('/') + response = MockView_CustomDetail.as_view()(request) + assert response.status_code == 429 + assert response.data == {'detail': 'custom detail'} + def set_throttle_timer(self, view, value): """ Explicitly set the timer, overriding time.time()