mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-24 02:24:03 +03:00
Allow custom throttle to provide a custom detail
This commit is contained in:
parent
9cfa4bd7cc
commit
d540b10bb7
|
@ -201,7 +201,15 @@ User requests to either `ContactListView` or `ContactDetailView` would be restri
|
||||||
|
|
||||||
To create a custom throttle, override `BaseThrottle` and implement `.allow_request(self, request, view)`. The method should return `True` if the request should be allowed, and `False` otherwise.
|
To create a custom throttle, override `BaseThrottle` and implement `.allow_request(self, request, view)`. The method should return `True` if the request should be allowed, and `False` otherwise.
|
||||||
|
|
||||||
Optionally you may also override the `.wait()` method. If implemented, `.wait()` should return a recommended number of seconds to wait before attempting the next request, or `None`. The `.wait()` method will only be called if `.allow_request()` has previously returned `False`.
|
Optionally you may also override the `.wait()` method. If implemented, `.wait()` should return one of the following:
|
||||||
|
|
||||||
|
- a single value representing the recommended number of seconds to wait before attempting the next request
|
||||||
|
- a tuple with two elements, in order:
|
||||||
|
- the recommended number of seconds to wait before attempting the next request or `None`
|
||||||
|
- a string to be used as _detail_ message
|
||||||
|
- `None` (default)
|
||||||
|
|
||||||
|
The `.wait()` method will only be called if `.allow_request()` has previously returned `False`.
|
||||||
|
|
||||||
If the `.wait()` method is implemented and the request is throttled, then a `Retry-After` header will be included in the response.
|
If the `.wait()` method is implemented and the request is throttled, then a `Retry-After` header will be included in the response.
|
||||||
|
|
||||||
|
|
|
@ -174,11 +174,11 @@ class APIView(View):
|
||||||
raise exceptions.NotAuthenticated()
|
raise exceptions.NotAuthenticated()
|
||||||
raise exceptions.PermissionDenied(detail=message, code=code)
|
raise exceptions.PermissionDenied(detail=message, code=code)
|
||||||
|
|
||||||
def throttled(self, request, wait):
|
def throttled(self, request, wait, detail=None):
|
||||||
"""
|
"""
|
||||||
If request is throttled, determine what kind of exception to raise.
|
If request is throttled, determine what kind of exception to raise.
|
||||||
"""
|
"""
|
||||||
raise exceptions.Throttled(wait)
|
raise exceptions.Throttled(wait, detail)
|
||||||
|
|
||||||
def get_authenticate_header(self, request):
|
def get_authenticate_header(self, request):
|
||||||
"""
|
"""
|
||||||
|
@ -367,8 +367,12 @@ class APIView(View):
|
||||||
if duration is not None
|
if duration is not None
|
||||||
]
|
]
|
||||||
|
|
||||||
duration = max(durations, default=None)
|
# consider also wait to return (duration, message) tuple
|
||||||
self.throttled(request, duration)
|
duration = max(durations, key=lambda d: d[0] or 0 if isinstance(d, (list, tuple)) else d, default=None)
|
||||||
|
if isinstance(duration, (list, tuple)):
|
||||||
|
self.throttled(request, *duration[:2])
|
||||||
|
else:
|
||||||
|
self.throttled(request, duration)
|
||||||
|
|
||||||
def determine_version(self, request, *args, **kwargs):
|
def determine_version(self, request, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -43,6 +43,14 @@ class NonTimeThrottle(BaseThrottle):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class CustomDetailThrottle(BaseThrottle):
|
||||||
|
def allow_request(self, request, view):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def wait(self):
|
||||||
|
return None, 'custom detail'
|
||||||
|
|
||||||
|
|
||||||
class MockView_DoubleThrottling(APIView):
|
class MockView_DoubleThrottling(APIView):
|
||||||
throttle_classes = (User3SecRateThrottle, User6MinRateThrottle,)
|
throttle_classes = (User3SecRateThrottle, User6MinRateThrottle,)
|
||||||
|
|
||||||
|
@ -57,6 +65,13 @@ class MockView(APIView):
|
||||||
return Response('foo')
|
return Response('foo')
|
||||||
|
|
||||||
|
|
||||||
|
class MockView_CustomDetail(APIView):
|
||||||
|
throttle_classes = (CustomDetailThrottle,)
|
||||||
|
|
||||||
|
def get(self, request):
|
||||||
|
return Response('foo')
|
||||||
|
|
||||||
|
|
||||||
class MockView_MinuteThrottling(APIView):
|
class MockView_MinuteThrottling(APIView):
|
||||||
throttle_classes = (User3MinRateThrottle,)
|
throttle_classes = (User3MinRateThrottle,)
|
||||||
|
|
||||||
|
@ -88,6 +103,15 @@ class ThrottlingTests(TestCase):
|
||||||
response = MockView.as_view()(request)
|
response = MockView.as_view()(request)
|
||||||
assert response.status_code == 429
|
assert response.status_code == 429
|
||||||
|
|
||||||
|
def test_requests_are_throttled_custom_detail(self):
|
||||||
|
"""
|
||||||
|
Ensure request rate is limited
|
||||||
|
"""
|
||||||
|
request = self.factory.get('/')
|
||||||
|
response = MockView_CustomDetail.as_view()(request)
|
||||||
|
assert response.status_code == 429
|
||||||
|
assert response.data == {'detail': 'custom detail'}
|
||||||
|
|
||||||
def set_throttle_timer(self, view, value):
|
def set_throttle_timer(self, view, value):
|
||||||
"""
|
"""
|
||||||
Explicitly set the timer, overriding time.time()
|
Explicitly set the timer, overriding time.time()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user