Always add all requests to all throttling classes history

- Even when the request is going to fail because of the current throttle
- Even when another throttling class already failed the request

This punishes brute-forcing, rewarding clients that wait for the full
duration of the suggested throttle before making new requests, and
also makes setting multiple throttling classes more efficient/useful.
This commit is contained in:
Mathieu Pillard 2019-05-13 16:46:59 +02:00
parent 37f210a455
commit 203fb2004e
3 changed files with 90 additions and 14 deletions

View File

@ -131,19 +131,26 @@ class SimpleRateThrottle(BaseThrottle):
return self.throttle_failure() return self.throttle_failure()
return self.throttle_success() return self.throttle_success()
def throttle_success(self): def add_request_to_history(self):
""" """
Inserts the current request's timestamp along with the key Inserts the current request's timestamp along with the key
into the cache. into the cache.
""" """
self.history.insert(0, self.now) self.history.insert(0, self.now)
self.cache.set(self.key, self.history, self.duration) self.cache.set(self.key, self.history, self.duration)
def throttle_success(self):
"""
Called when a request to the API has passed throttling checks.
"""
self.add_request_to_history()
return True return True
def throttle_failure(self): def throttle_failure(self):
""" """
Called when a request to the API has failed due to throttling. Called when a request to the API has failed due to throttling.
""" """
self.add_request_to_history()
return False return False
def wait(self): def wait(self):
@ -155,9 +162,12 @@ class SimpleRateThrottle(BaseThrottle):
else: else:
remaining_duration = self.duration remaining_duration = self.duration
available_requests = self.num_requests - len(self.history) + 1 # If we go over the num of requests in history, ensure the
if available_requests <= 0: # 'available_requests' will stay at 1, suggesting clients to wait for
return None # the full duration of the throttle.
available_requests = (
self.num_requests - min(self.num_requests, len(self.history)) + 1
)
return remaining_duration / float(available_requests) return remaining_duration / float(available_requests)

View File

@ -350,9 +350,13 @@ class APIView(View):
Check if request should be throttled. Check if request should be throttled.
Raises an appropriate exception if the request is throttled. Raises an appropriate exception if the request is throttled.
""" """
throttle_durations = []
for throttle in self.get_throttles(): for throttle in self.get_throttles():
if not throttle.allow_request(request, self): if not throttle.allow_request(request, self):
self.throttled(request, throttle.wait()) throttle_durations.append(throttle.wait())
if throttle_durations:
self.throttled(request, max(throttle_durations))
def determine_version(self, request, *args, **kwargs): def determine_version(self, request, *args, **kwargs):
""" """

View File

@ -1,5 +1,5 @@
""" """
Tests for the throttling implementations in the permissions module. Tests for the throttling implementations.
""" """
import pytest import pytest
@ -30,6 +30,11 @@ class User3MinRateThrottle(UserRateThrottle):
scope = 'minutes' scope = 'minutes'
class User6MinRateThrottle(UserRateThrottle):
rate = '6/min'
scope = 'minutes'
class NonTimeThrottle(BaseThrottle): class NonTimeThrottle(BaseThrottle):
def allow_request(self, request, view): def allow_request(self, request, view):
if not hasattr(self.__class__, 'called'): if not hasattr(self.__class__, 'called'):
@ -59,6 +64,13 @@ class MockView_NonTimeThrottling(APIView):
return Response('foo') return Response('foo')
class MockView_DoubleThrottling(APIView):
throttle_classes = (User3SecRateThrottle, User6MinRateThrottle,)
def get(self, request):
return Response('foo')
class ThrottlingTests(TestCase): class ThrottlingTests(TestCase):
def setUp(self): def setUp(self):
""" """
@ -80,7 +92,8 @@ class ThrottlingTests(TestCase):
""" """
Explicitly set the timer, overriding time.time() Explicitly set the timer, overriding time.time()
""" """
view.throttle_classes[0].timer = lambda self: value for cls in view.throttle_classes:
cls.timer = lambda self: value
def test_request_throttling_expires(self): def test_request_throttling_expires(self):
""" """
@ -99,11 +112,13 @@ class ThrottlingTests(TestCase):
response = MockView.as_view()(request) response = MockView.as_view()(request)
assert response.status_code == 200 assert response.status_code == 200
def ensure_is_throttled(self, view, expect): def ensure_is_throttled_separately(self, view, expect):
request = self.factory.get('/') request = self.factory.get('/')
request.user = User.objects.create(username='a') request.user = User.objects.create(username='a')
for dummy in range(3): for dummy in range(3):
view.as_view()(request) view.as_view()(request)
response = view.as_view()(request)
assert response.status_code == 429
request.user = User.objects.create(username='b') request.user = User.objects.create(username='b')
response = view.as_view()(request) response = view.as_view()(request)
assert response.status_code == expect assert response.status_code == expect
@ -113,7 +128,34 @@ class ThrottlingTests(TestCase):
Ensure request rate is only limited per user, not globally for Ensure request rate is only limited per user, not globally for
PerUserThrottles PerUserThrottles
""" """
self.ensure_is_throttled(MockView, 200) self.set_throttle_timer(MockView, 0)
self.ensure_is_throttled_separately(MockView, 200)
def test_request_throttling_multiple_throttles(self):
"""
"""
self.set_throttle_timer(MockView_DoubleThrottling, 0)
request = self.factory.get('/')
for dummy in range(4):
response = MockView_DoubleThrottling.as_view()(request)
assert response.status_code == 429
# At this point they made 4 requests (one was throttled) in a second.
# If we advance the timer by one second, they should be allowed to make
# 2 more before being throttled by the 2nd throttle class, which has a
# limit of 6 per minute.
self.set_throttle_timer(MockView_DoubleThrottling, 1)
response = MockView_DoubleThrottling.as_view()(request)
assert response.status_code == 200
response = MockView_DoubleThrottling.as_view()(request)
assert response.status_code == 200
# Shouldn't be necessary, but increment timer by one to make sure the
# throttling is caused by the User6MinRateThrottle class.
self.set_throttle_timer(MockView_DoubleThrottling, 1)
response = MockView_DoubleThrottling.as_view()(request)
assert response.status_code == 429
def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers): def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
""" """
@ -125,9 +167,11 @@ class ThrottlingTests(TestCase):
self.set_throttle_timer(view, timer) self.set_throttle_timer(view, timer)
response = view.as_view()(request) response = view.as_view()(request)
if expect is not None: if expect is not None:
assert response.status_code == 429
assert response['Retry-After'] == expect assert response['Retry-After'] == expect
else: else:
assert not'Retry-After' in response assert response.status_code == 200
assert 'Retry-After' not in response
def test_seconds_fields(self): def test_seconds_fields(self):
""" """
@ -258,10 +302,23 @@ class ScopedRateThrottleTests(TestCase):
response = self.y_view(request) response = self.y_view(request)
assert response.status_code == 429 assert response.status_code == 429
# Ensure throttles properly reset by advancing the rest of the minute # Increment by 55 seconds. Because of the recent failures it should
# still not be allowed (and this request itself should be added to the
# history).
self.increment_timer(55) self.increment_timer(55)
response = self.x_view(request)
assert response.status_code == 429
# Waiting 2 more seconds should clear the first items from the history.
self.increment_timer(2)
response = self.x_view(request)
assert response.status_code == 200
# After 58 more seconds we should be allowed to make a request again.
# Since we're incrementing timer by one between each request, the
# following 2 requests should work as well.
self.increment_timer(58)
# Should still be able to hit x view 3 times per minute.
response = self.x_view(request) response = self.x_view(request)
assert response.status_code == 200 assert response.status_code == 200
@ -423,13 +480,18 @@ class SimpleRateThrottleTests(TestCase):
assert isinstance(waiting_time, float) assert isinstance(waiting_time, float)
assert waiting_time == 30.0 assert waiting_time == 30.0
def test_wait_returns_none_if_there_are_no_available_requests(self): def test_wait_returns_duration_if_there_are_no_available_requests(self):
def timer():
return throttle.now
throttle = SimpleRateThrottle() throttle = SimpleRateThrottle()
throttle.num_requests = 1 throttle.num_requests = 1
throttle.duration = 60 throttle.duration = 60
throttle.now = throttle.timer() throttle.now = throttle.timer()
throttle.timer = timer # Force time to be fixed for this test.
# Number of requests in history is already over the limit, so clients
# should wait for at least the full duration of the throttle.
throttle.history = [throttle.timer() for _ in range(3)] throttle.history = [throttle.timer() for _ in range(3)]
assert throttle.wait() is None assert throttle.wait() == 60.0
class AnonRateThrottleTests(TestCase): class AnonRateThrottleTests(TestCase):