mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-07-30 18:09:59 +03:00
Always add all requests to all throttling classes history
- Even when the request is going to fail because of the current throttle - Even when another throttling class already failed the request This punishes brute-forcing, rewarding clients that wait for the full duration of the suggested throttle before making new requests, and also makes setting multiple throttling classes more efficient/useful.
This commit is contained in:
parent
37f210a455
commit
203fb2004e
|
@ -131,19 +131,26 @@ class SimpleRateThrottle(BaseThrottle):
|
||||||
return self.throttle_failure()
|
return self.throttle_failure()
|
||||||
return self.throttle_success()
|
return self.throttle_success()
|
||||||
|
|
||||||
def throttle_success(self):
|
def add_request_to_history(self):
|
||||||
"""
|
"""
|
||||||
Inserts the current request's timestamp along with the key
|
Inserts the current request's timestamp along with the key
|
||||||
into the cache.
|
into the cache.
|
||||||
"""
|
"""
|
||||||
self.history.insert(0, self.now)
|
self.history.insert(0, self.now)
|
||||||
self.cache.set(self.key, self.history, self.duration)
|
self.cache.set(self.key, self.history, self.duration)
|
||||||
|
|
||||||
|
def throttle_success(self):
|
||||||
|
"""
|
||||||
|
Called when a request to the API has passed throttling checks.
|
||||||
|
"""
|
||||||
|
self.add_request_to_history()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def throttle_failure(self):
|
def throttle_failure(self):
|
||||||
"""
|
"""
|
||||||
Called when a request to the API has failed due to throttling.
|
Called when a request to the API has failed due to throttling.
|
||||||
"""
|
"""
|
||||||
|
self.add_request_to_history()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def wait(self):
|
def wait(self):
|
||||||
|
@ -155,9 +162,12 @@ class SimpleRateThrottle(BaseThrottle):
|
||||||
else:
|
else:
|
||||||
remaining_duration = self.duration
|
remaining_duration = self.duration
|
||||||
|
|
||||||
available_requests = self.num_requests - len(self.history) + 1
|
# If we go over the num of requests in history, ensure the
|
||||||
if available_requests <= 0:
|
# 'available_requests' will stay at 1, suggesting clients to wait for
|
||||||
return None
|
# the full duration of the throttle.
|
||||||
|
available_requests = (
|
||||||
|
self.num_requests - min(self.num_requests, len(self.history)) + 1
|
||||||
|
)
|
||||||
|
|
||||||
return remaining_duration / float(available_requests)
|
return remaining_duration / float(available_requests)
|
||||||
|
|
||||||
|
|
|
@ -350,9 +350,13 @@ class APIView(View):
|
||||||
Check if request should be throttled.
|
Check if request should be throttled.
|
||||||
Raises an appropriate exception if the request is throttled.
|
Raises an appropriate exception if the request is throttled.
|
||||||
"""
|
"""
|
||||||
|
throttle_durations = []
|
||||||
for throttle in self.get_throttles():
|
for throttle in self.get_throttles():
|
||||||
if not throttle.allow_request(request, self):
|
if not throttle.allow_request(request, self):
|
||||||
self.throttled(request, throttle.wait())
|
throttle_durations.append(throttle.wait())
|
||||||
|
|
||||||
|
if throttle_durations:
|
||||||
|
self.throttled(request, max(throttle_durations))
|
||||||
|
|
||||||
def determine_version(self, request, *args, **kwargs):
|
def determine_version(self, request, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""
|
"""
|
||||||
Tests for the throttling implementations in the permissions module.
|
Tests for the throttling implementations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -30,6 +30,11 @@ class User3MinRateThrottle(UserRateThrottle):
|
||||||
scope = 'minutes'
|
scope = 'minutes'
|
||||||
|
|
||||||
|
|
||||||
|
class User6MinRateThrottle(UserRateThrottle):
|
||||||
|
rate = '6/min'
|
||||||
|
scope = 'minutes'
|
||||||
|
|
||||||
|
|
||||||
class NonTimeThrottle(BaseThrottle):
|
class NonTimeThrottle(BaseThrottle):
|
||||||
def allow_request(self, request, view):
|
def allow_request(self, request, view):
|
||||||
if not hasattr(self.__class__, 'called'):
|
if not hasattr(self.__class__, 'called'):
|
||||||
|
@ -59,6 +64,13 @@ class MockView_NonTimeThrottling(APIView):
|
||||||
return Response('foo')
|
return Response('foo')
|
||||||
|
|
||||||
|
|
||||||
|
class MockView_DoubleThrottling(APIView):
|
||||||
|
throttle_classes = (User3SecRateThrottle, User6MinRateThrottle,)
|
||||||
|
|
||||||
|
def get(self, request):
|
||||||
|
return Response('foo')
|
||||||
|
|
||||||
|
|
||||||
class ThrottlingTests(TestCase):
|
class ThrottlingTests(TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
"""
|
"""
|
||||||
|
@ -80,7 +92,8 @@ class ThrottlingTests(TestCase):
|
||||||
"""
|
"""
|
||||||
Explicitly set the timer, overriding time.time()
|
Explicitly set the timer, overriding time.time()
|
||||||
"""
|
"""
|
||||||
view.throttle_classes[0].timer = lambda self: value
|
for cls in view.throttle_classes:
|
||||||
|
cls.timer = lambda self: value
|
||||||
|
|
||||||
def test_request_throttling_expires(self):
|
def test_request_throttling_expires(self):
|
||||||
"""
|
"""
|
||||||
|
@ -99,11 +112,13 @@ class ThrottlingTests(TestCase):
|
||||||
response = MockView.as_view()(request)
|
response = MockView.as_view()(request)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
def ensure_is_throttled(self, view, expect):
|
def ensure_is_throttled_separately(self, view, expect):
|
||||||
request = self.factory.get('/')
|
request = self.factory.get('/')
|
||||||
request.user = User.objects.create(username='a')
|
request.user = User.objects.create(username='a')
|
||||||
for dummy in range(3):
|
for dummy in range(3):
|
||||||
view.as_view()(request)
|
view.as_view()(request)
|
||||||
|
response = view.as_view()(request)
|
||||||
|
assert response.status_code == 429
|
||||||
request.user = User.objects.create(username='b')
|
request.user = User.objects.create(username='b')
|
||||||
response = view.as_view()(request)
|
response = view.as_view()(request)
|
||||||
assert response.status_code == expect
|
assert response.status_code == expect
|
||||||
|
@ -113,7 +128,34 @@ class ThrottlingTests(TestCase):
|
||||||
Ensure request rate is only limited per user, not globally for
|
Ensure request rate is only limited per user, not globally for
|
||||||
PerUserThrottles
|
PerUserThrottles
|
||||||
"""
|
"""
|
||||||
self.ensure_is_throttled(MockView, 200)
|
self.set_throttle_timer(MockView, 0)
|
||||||
|
self.ensure_is_throttled_separately(MockView, 200)
|
||||||
|
|
||||||
|
def test_request_throttling_multiple_throttles(self):
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
self.set_throttle_timer(MockView_DoubleThrottling, 0)
|
||||||
|
request = self.factory.get('/')
|
||||||
|
for dummy in range(4):
|
||||||
|
response = MockView_DoubleThrottling.as_view()(request)
|
||||||
|
assert response.status_code == 429
|
||||||
|
|
||||||
|
# At this point they made 4 requests (one was throttled) in a second.
|
||||||
|
# If we advance the timer by one second, they should be allowed to make
|
||||||
|
# 2 more before being throttled by the 2nd throttle class, which has a
|
||||||
|
# limit of 6 per minute.
|
||||||
|
self.set_throttle_timer(MockView_DoubleThrottling, 1)
|
||||||
|
response = MockView_DoubleThrottling.as_view()(request)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
response = MockView_DoubleThrottling.as_view()(request)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# Shouldn't be necessary, but increment timer by one to make sure the
|
||||||
|
# throttling is caused by the User6MinRateThrottle class.
|
||||||
|
self.set_throttle_timer(MockView_DoubleThrottling, 1)
|
||||||
|
response = MockView_DoubleThrottling.as_view()(request)
|
||||||
|
assert response.status_code == 429
|
||||||
|
|
||||||
def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
|
def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
|
||||||
"""
|
"""
|
||||||
|
@ -125,9 +167,11 @@ class ThrottlingTests(TestCase):
|
||||||
self.set_throttle_timer(view, timer)
|
self.set_throttle_timer(view, timer)
|
||||||
response = view.as_view()(request)
|
response = view.as_view()(request)
|
||||||
if expect is not None:
|
if expect is not None:
|
||||||
|
assert response.status_code == 429
|
||||||
assert response['Retry-After'] == expect
|
assert response['Retry-After'] == expect
|
||||||
else:
|
else:
|
||||||
assert not'Retry-After' in response
|
assert response.status_code == 200
|
||||||
|
assert 'Retry-After' not in response
|
||||||
|
|
||||||
def test_seconds_fields(self):
|
def test_seconds_fields(self):
|
||||||
"""
|
"""
|
||||||
|
@ -258,10 +302,23 @@ class ScopedRateThrottleTests(TestCase):
|
||||||
response = self.y_view(request)
|
response = self.y_view(request)
|
||||||
assert response.status_code == 429
|
assert response.status_code == 429
|
||||||
|
|
||||||
# Ensure throttles properly reset by advancing the rest of the minute
|
# Increment by 55 seconds. Because of the recent failures it should
|
||||||
|
# still not be allowed (and this request itself should be added to the
|
||||||
|
# history).
|
||||||
self.increment_timer(55)
|
self.increment_timer(55)
|
||||||
|
response = self.x_view(request)
|
||||||
|
assert response.status_code == 429
|
||||||
|
|
||||||
|
# Waiting 2 more seconds should clear the first items from the history.
|
||||||
|
self.increment_timer(2)
|
||||||
|
response = self.x_view(request)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# After 58 more seconds we should be allowed to make a request again.
|
||||||
|
# Since we're incrementing timer by one between each request, the
|
||||||
|
# following 2 requests should work as well.
|
||||||
|
self.increment_timer(58)
|
||||||
|
|
||||||
# Should still be able to hit x view 3 times per minute.
|
|
||||||
response = self.x_view(request)
|
response = self.x_view(request)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
@ -423,13 +480,18 @@ class SimpleRateThrottleTests(TestCase):
|
||||||
assert isinstance(waiting_time, float)
|
assert isinstance(waiting_time, float)
|
||||||
assert waiting_time == 30.0
|
assert waiting_time == 30.0
|
||||||
|
|
||||||
def test_wait_returns_none_if_there_are_no_available_requests(self):
|
def test_wait_returns_duration_if_there_are_no_available_requests(self):
|
||||||
|
def timer():
|
||||||
|
return throttle.now
|
||||||
throttle = SimpleRateThrottle()
|
throttle = SimpleRateThrottle()
|
||||||
throttle.num_requests = 1
|
throttle.num_requests = 1
|
||||||
throttle.duration = 60
|
throttle.duration = 60
|
||||||
throttle.now = throttle.timer()
|
throttle.now = throttle.timer()
|
||||||
|
throttle.timer = timer # Force time to be fixed for this test.
|
||||||
|
# Number of requests in history is already over the limit, so clients
|
||||||
|
# should wait for at least the full duration of the throttle.
|
||||||
throttle.history = [throttle.timer() for _ in range(3)]
|
throttle.history = [throttle.timer() for _ in range(3)]
|
||||||
assert throttle.wait() is None
|
assert throttle.wait() == 60.0
|
||||||
|
|
||||||
|
|
||||||
class AnonRateThrottleTests(TestCase):
|
class AnonRateThrottleTests(TestCase):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user