From 203fb2004e806f4cfefa86eb11cfee06d4b74f7a Mon Sep 17 00:00:00 2001 From: Mathieu Pillard Date: Mon, 13 May 2019 16:46:59 +0200 Subject: [PATCH] Always add all requests to all throttling classes history - Even when the request is going to fail because of the current throttle - Even when another throttling class already failed the request This punishes brute-forcing, rewarding clients that wait for the full duration of the suggested throttle before making new requests, and also makes setting multiple throttling classes more efficient/useful. --- rest_framework/throttling.py | 18 ++++++-- rest_framework/views.py | 6 ++- tests/test_throttling.py | 80 ++++++++++++++++++++++++++++++++---- 3 files changed, 90 insertions(+), 14 deletions(-) diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py index 0ba2ba66b..6da578720 100644 --- a/rest_framework/throttling.py +++ b/rest_framework/throttling.py @@ -131,19 +131,26 @@ class SimpleRateThrottle(BaseThrottle): return self.throttle_failure() return self.throttle_success() - def throttle_success(self): + def add_request_to_history(self): """ Inserts the current request's timestamp along with the key into the cache. """ self.history.insert(0, self.now) self.cache.set(self.key, self.history, self.duration) + + def throttle_success(self): + """ + Called when a request to the API has passed throttling checks. + """ + self.add_request_to_history() return True def throttle_failure(self): """ Called when a request to the API has failed due to throttling. """ + self.add_request_to_history() return False def wait(self): @@ -155,9 +162,12 @@ class SimpleRateThrottle(BaseThrottle): else: remaining_duration = self.duration - available_requests = self.num_requests - len(self.history) + 1 - if available_requests <= 0: - return None + # If we go over the num of requests in history, ensure the + # 'available_requests' will stay at 1, suggesting clients to wait for + # the full duration of the throttle. + available_requests = ( + self.num_requests - min(self.num_requests, len(self.history)) + 1 + ) return remaining_duration / float(available_requests) diff --git a/rest_framework/views.py b/rest_framework/views.py index 6ef7021d4..832f17233 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -350,9 +350,13 @@ class APIView(View): Check if request should be throttled. Raises an appropriate exception if the request is throttled. """ + throttle_durations = [] for throttle in self.get_throttles(): if not throttle.allow_request(request, self): - self.throttled(request, throttle.wait()) + throttle_durations.append(throttle.wait()) + + if throttle_durations: + self.throttled(request, max(throttle_durations)) def determine_version(self, request, *args, **kwargs): """ diff --git a/tests/test_throttling.py b/tests/test_throttling.py index b20b6a809..9feaac44c 100644 --- a/tests/test_throttling.py +++ b/tests/test_throttling.py @@ -1,5 +1,5 @@ """ -Tests for the throttling implementations in the permissions module. +Tests for the throttling implementations. """ import pytest @@ -30,6 +30,11 @@ class User3MinRateThrottle(UserRateThrottle): scope = 'minutes' +class User6MinRateThrottle(UserRateThrottle): + rate = '6/min' + scope = 'minutes' + + class NonTimeThrottle(BaseThrottle): def allow_request(self, request, view): if not hasattr(self.__class__, 'called'): @@ -59,6 +64,13 @@ class MockView_NonTimeThrottling(APIView): return Response('foo') +class MockView_DoubleThrottling(APIView): + throttle_classes = (User3SecRateThrottle, User6MinRateThrottle,) + + def get(self, request): + return Response('foo') + + class ThrottlingTests(TestCase): def setUp(self): """ @@ -80,7 +92,8 @@ class ThrottlingTests(TestCase): """ Explicitly set the timer, overriding time.time() """ - view.throttle_classes[0].timer = lambda self: value + for cls in view.throttle_classes: + cls.timer = lambda self: value def test_request_throttling_expires(self): """ @@ -99,11 +112,13 @@ class ThrottlingTests(TestCase): response = MockView.as_view()(request) assert response.status_code == 200 - def ensure_is_throttled(self, view, expect): + def ensure_is_throttled_separately(self, view, expect): request = self.factory.get('/') request.user = User.objects.create(username='a') for dummy in range(3): view.as_view()(request) + response = view.as_view()(request) + assert response.status_code == 429 request.user = User.objects.create(username='b') response = view.as_view()(request) assert response.status_code == expect @@ -113,7 +128,34 @@ class ThrottlingTests(TestCase): Ensure request rate is only limited per user, not globally for PerUserThrottles """ - self.ensure_is_throttled(MockView, 200) + self.set_throttle_timer(MockView, 0) + self.ensure_is_throttled_separately(MockView, 200) + + def test_request_throttling_multiple_throttles(self): + """ + """ + self.set_throttle_timer(MockView_DoubleThrottling, 0) + request = self.factory.get('/') + for dummy in range(4): + response = MockView_DoubleThrottling.as_view()(request) + assert response.status_code == 429 + + # At this point they made 4 requests (one was throttled) in a second. + # If we advance the timer by one second, they should be allowed to make + # 2 more before being throttled by the 2nd throttle class, which has a + # limit of 6 per minute. + self.set_throttle_timer(MockView_DoubleThrottling, 1) + response = MockView_DoubleThrottling.as_view()(request) + assert response.status_code == 200 + + response = MockView_DoubleThrottling.as_view()(request) + assert response.status_code == 200 + + # Shouldn't be necessary, but increment timer by one to make sure the + # throttling is caused by the User6MinRateThrottle class. + self.set_throttle_timer(MockView_DoubleThrottling, 1) + response = MockView_DoubleThrottling.as_view()(request) + assert response.status_code == 429 def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers): """ @@ -125,9 +167,11 @@ class ThrottlingTests(TestCase): self.set_throttle_timer(view, timer) response = view.as_view()(request) if expect is not None: + assert response.status_code == 429 assert response['Retry-After'] == expect else: - assert not'Retry-After' in response + assert response.status_code == 200 + assert 'Retry-After' not in response def test_seconds_fields(self): """ @@ -258,10 +302,23 @@ class ScopedRateThrottleTests(TestCase): response = self.y_view(request) assert response.status_code == 429 - # Ensure throttles properly reset by advancing the rest of the minute + # Increment by 55 seconds. Because of the recent failures it should + # still not be allowed (and this request itself should be added to the + # history). self.increment_timer(55) + response = self.x_view(request) + assert response.status_code == 429 + + # Waiting 2 more seconds should clear the first items from the history. + self.increment_timer(2) + response = self.x_view(request) + assert response.status_code == 200 + + # After 58 more seconds we should be allowed to make a request again. + # Since we're incrementing timer by one between each request, the + # following 2 requests should work as well. + self.increment_timer(58) - # Should still be able to hit x view 3 times per minute. response = self.x_view(request) assert response.status_code == 200 @@ -423,13 +480,18 @@ class SimpleRateThrottleTests(TestCase): assert isinstance(waiting_time, float) assert waiting_time == 30.0 - def test_wait_returns_none_if_there_are_no_available_requests(self): + def test_wait_returns_duration_if_there_are_no_available_requests(self): + def timer(): + return throttle.now throttle = SimpleRateThrottle() throttle.num_requests = 1 throttle.duration = 60 throttle.now = throttle.timer() + throttle.timer = timer # Force time to be fixed for this test. + # Number of requests in history is already over the limit, so clients + # should wait for at least the full duration of the throttle. throttle.history = [throttle.timer() for _ in range(3)] - assert throttle.wait() is None + assert throttle.wait() == 60.0 class AnonRateThrottleTests(TestCase):