From afb678433b7dc068213e6e8a359ad1f6fff05b0d Mon Sep 17 00:00:00 2001 From: Mathieu Pillard Date: Thu, 23 May 2019 15:42:29 +0200 Subject: [PATCH] Always call all throttling classes on the view when checking throttles (#6711) --- rest_framework/views.py | 6 +++++- tests/test_throttling.py | 46 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/rest_framework/views.py b/rest_framework/views.py index 6ef7021d4..832f17233 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -350,9 +350,13 @@ class APIView(View): Check if request should be throttled. Raises an appropriate exception if the request is throttled. """ + throttle_durations = [] for throttle in self.get_throttles(): if not throttle.allow_request(request, self): - self.throttled(request, throttle.wait()) + throttle_durations.append(throttle.wait()) + + if throttle_durations: + self.throttled(request, max(throttle_durations)) def determine_version(self, request, *args, **kwargs): """ diff --git a/tests/test_throttling.py b/tests/test_throttling.py index b20b6a809..3c172e263 100644 --- a/tests/test_throttling.py +++ b/tests/test_throttling.py @@ -30,6 +30,11 @@ class User3MinRateThrottle(UserRateThrottle): scope = 'minutes' +class User6MinRateThrottle(UserRateThrottle): + rate = '6/min' + scope = 'minutes' + + class NonTimeThrottle(BaseThrottle): def allow_request(self, request, view): if not hasattr(self.__class__, 'called'): @@ -38,6 +43,13 @@ class NonTimeThrottle(BaseThrottle): return False +class MockView_DoubleThrottling(APIView): + throttle_classes = (User3SecRateThrottle, User6MinRateThrottle,) + + def get(self, request): + return Response('foo') + + class MockView(APIView): throttle_classes = (User3SecRateThrottle,) @@ -80,7 +92,8 @@ class ThrottlingTests(TestCase): """ Explicitly set the timer, overriding time.time() """ - view.throttle_classes[0].timer = lambda self: value + for cls in view.throttle_classes: + cls.timer = lambda self: value def test_request_throttling_expires(self): """ @@ -115,6 +128,37 @@ class ThrottlingTests(TestCase): """ self.ensure_is_throttled(MockView, 200) + def test_request_throttling_multiple_throttles(self): + """ + Ensure all throttle classes see each request even when the request is + already being throttled + """ + self.set_throttle_timer(MockView_DoubleThrottling, 0) + request = self.factory.get('/') + for dummy in range(4): + response = MockView_DoubleThrottling.as_view()(request) + assert response.status_code == 429 + assert int(response['retry-after']) == 1 + + # At this point our client made 4 requests (one was throttled) in a + # second. If we advance the timer by one additional second, the client + # should be allowed to make 2 more before being throttled by the 2nd + # throttle class, which has a limit of 6 per minute. + self.set_throttle_timer(MockView_DoubleThrottling, 1) + for dummy in range(2): + response = MockView_DoubleThrottling.as_view()(request) + assert response.status_code == 200 + + response = MockView_DoubleThrottling.as_view()(request) + assert response.status_code == 429 + assert int(response['retry-after']) == 59 + + # Just to make sure check again after two more seconds. + self.set_throttle_timer(MockView_DoubleThrottling, 2) + response = MockView_DoubleThrottling.as_view()(request) + assert response.status_code == 429 + assert int(response['retry-after']) == 58 + def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers): """ Ensure the response returns an Retry-After field with status and next attributes