mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-26 03:23:59 +03:00
Always call all throttling classes on the view when checking throttles (#6711)
This commit is contained in:
parent
19ca86d8d6
commit
afb678433b
|
@ -350,9 +350,13 @@ class APIView(View):
|
||||||
Check if request should be throttled.
|
Check if request should be throttled.
|
||||||
Raises an appropriate exception if the request is throttled.
|
Raises an appropriate exception if the request is throttled.
|
||||||
"""
|
"""
|
||||||
|
throttle_durations = []
|
||||||
for throttle in self.get_throttles():
|
for throttle in self.get_throttles():
|
||||||
if not throttle.allow_request(request, self):
|
if not throttle.allow_request(request, self):
|
||||||
self.throttled(request, throttle.wait())
|
throttle_durations.append(throttle.wait())
|
||||||
|
|
||||||
|
if throttle_durations:
|
||||||
|
self.throttled(request, max(throttle_durations))
|
||||||
|
|
||||||
def determine_version(self, request, *args, **kwargs):
|
def determine_version(self, request, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -30,6 +30,11 @@ class User3MinRateThrottle(UserRateThrottle):
|
||||||
scope = 'minutes'
|
scope = 'minutes'
|
||||||
|
|
||||||
|
|
||||||
|
class User6MinRateThrottle(UserRateThrottle):
|
||||||
|
rate = '6/min'
|
||||||
|
scope = 'minutes'
|
||||||
|
|
||||||
|
|
||||||
class NonTimeThrottle(BaseThrottle):
|
class NonTimeThrottle(BaseThrottle):
|
||||||
def allow_request(self, request, view):
|
def allow_request(self, request, view):
|
||||||
if not hasattr(self.__class__, 'called'):
|
if not hasattr(self.__class__, 'called'):
|
||||||
|
@ -38,6 +43,13 @@ class NonTimeThrottle(BaseThrottle):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class MockView_DoubleThrottling(APIView):
|
||||||
|
throttle_classes = (User3SecRateThrottle, User6MinRateThrottle,)
|
||||||
|
|
||||||
|
def get(self, request):
|
||||||
|
return Response('foo')
|
||||||
|
|
||||||
|
|
||||||
class MockView(APIView):
|
class MockView(APIView):
|
||||||
throttle_classes = (User3SecRateThrottle,)
|
throttle_classes = (User3SecRateThrottle,)
|
||||||
|
|
||||||
|
@ -80,7 +92,8 @@ class ThrottlingTests(TestCase):
|
||||||
"""
|
"""
|
||||||
Explicitly set the timer, overriding time.time()
|
Explicitly set the timer, overriding time.time()
|
||||||
"""
|
"""
|
||||||
view.throttle_classes[0].timer = lambda self: value
|
for cls in view.throttle_classes:
|
||||||
|
cls.timer = lambda self: value
|
||||||
|
|
||||||
def test_request_throttling_expires(self):
|
def test_request_throttling_expires(self):
|
||||||
"""
|
"""
|
||||||
|
@ -115,6 +128,37 @@ class ThrottlingTests(TestCase):
|
||||||
"""
|
"""
|
||||||
self.ensure_is_throttled(MockView, 200)
|
self.ensure_is_throttled(MockView, 200)
|
||||||
|
|
||||||
|
def test_request_throttling_multiple_throttles(self):
|
||||||
|
"""
|
||||||
|
Ensure all throttle classes see each request even when the request is
|
||||||
|
already being throttled
|
||||||
|
"""
|
||||||
|
self.set_throttle_timer(MockView_DoubleThrottling, 0)
|
||||||
|
request = self.factory.get('/')
|
||||||
|
for dummy in range(4):
|
||||||
|
response = MockView_DoubleThrottling.as_view()(request)
|
||||||
|
assert response.status_code == 429
|
||||||
|
assert int(response['retry-after']) == 1
|
||||||
|
|
||||||
|
# At this point our client made 4 requests (one was throttled) in a
|
||||||
|
# second. If we advance the timer by one additional second, the client
|
||||||
|
# should be allowed to make 2 more before being throttled by the 2nd
|
||||||
|
# throttle class, which has a limit of 6 per minute.
|
||||||
|
self.set_throttle_timer(MockView_DoubleThrottling, 1)
|
||||||
|
for dummy in range(2):
|
||||||
|
response = MockView_DoubleThrottling.as_view()(request)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
response = MockView_DoubleThrottling.as_view()(request)
|
||||||
|
assert response.status_code == 429
|
||||||
|
assert int(response['retry-after']) == 59
|
||||||
|
|
||||||
|
# Just to make sure check again after two more seconds.
|
||||||
|
self.set_throttle_timer(MockView_DoubleThrottling, 2)
|
||||||
|
response = MockView_DoubleThrottling.as_view()(request)
|
||||||
|
assert response.status_code == 429
|
||||||
|
assert int(response['retry-after']) == 58
|
||||||
|
|
||||||
def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
|
def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
|
||||||
"""
|
"""
|
||||||
Ensure the response returns an Retry-After field with status and next attributes
|
Ensure the response returns an Retry-After field with status and next attributes
|
||||||
|
|
Loading…
Reference in New Issue
Block a user