Always call all throttling classes on the view when checking throttles (#6711)

This commit is contained in:
Mathieu Pillard 2019-05-23 15:42:29 +02:00 committed by Tom Christie
parent 19ca86d8d6
commit afb678433b
2 changed files with 50 additions and 2 deletions

View File

@ -350,9 +350,13 @@ class APIView(View):
Check if request should be throttled. Check if request should be throttled.
Raises an appropriate exception if the request is throttled. Raises an appropriate exception if the request is throttled.
""" """
throttle_durations = []
for throttle in self.get_throttles(): for throttle in self.get_throttles():
if not throttle.allow_request(request, self): if not throttle.allow_request(request, self):
self.throttled(request, throttle.wait()) throttle_durations.append(throttle.wait())
if throttle_durations:
self.throttled(request, max(throttle_durations))
def determine_version(self, request, *args, **kwargs): def determine_version(self, request, *args, **kwargs):
""" """

View File

@ -30,6 +30,11 @@ class User3MinRateThrottle(UserRateThrottle):
scope = 'minutes' scope = 'minutes'
class User6MinRateThrottle(UserRateThrottle):
rate = '6/min'
scope = 'minutes'
class NonTimeThrottle(BaseThrottle): class NonTimeThrottle(BaseThrottle):
def allow_request(self, request, view): def allow_request(self, request, view):
if not hasattr(self.__class__, 'called'): if not hasattr(self.__class__, 'called'):
@ -38,6 +43,13 @@ class NonTimeThrottle(BaseThrottle):
return False return False
class MockView_DoubleThrottling(APIView):
throttle_classes = (User3SecRateThrottle, User6MinRateThrottle,)
def get(self, request):
return Response('foo')
class MockView(APIView): class MockView(APIView):
throttle_classes = (User3SecRateThrottle,) throttle_classes = (User3SecRateThrottle,)
@ -80,7 +92,8 @@ class ThrottlingTests(TestCase):
""" """
Explicitly set the timer, overriding time.time() Explicitly set the timer, overriding time.time()
""" """
view.throttle_classes[0].timer = lambda self: value for cls in view.throttle_classes:
cls.timer = lambda self: value
def test_request_throttling_expires(self): def test_request_throttling_expires(self):
""" """
@ -115,6 +128,37 @@ class ThrottlingTests(TestCase):
""" """
self.ensure_is_throttled(MockView, 200) self.ensure_is_throttled(MockView, 200)
def test_request_throttling_multiple_throttles(self):
"""
Ensure all throttle classes see each request even when the request is
already being throttled
"""
self.set_throttle_timer(MockView_DoubleThrottling, 0)
request = self.factory.get('/')
for dummy in range(4):
response = MockView_DoubleThrottling.as_view()(request)
assert response.status_code == 429
assert int(response['retry-after']) == 1
# At this point our client made 4 requests (one was throttled) in a
# second. If we advance the timer by one additional second, the client
# should be allowed to make 2 more before being throttled by the 2nd
# throttle class, which has a limit of 6 per minute.
self.set_throttle_timer(MockView_DoubleThrottling, 1)
for dummy in range(2):
response = MockView_DoubleThrottling.as_view()(request)
assert response.status_code == 200
response = MockView_DoubleThrottling.as_view()(request)
assert response.status_code == 429
assert int(response['retry-after']) == 59
# Just to make sure check again after two more seconds.
self.set_throttle_timer(MockView_DoubleThrottling, 2)
response = MockView_DoubleThrottling.as_view()(request)
assert response.status_code == 429
assert int(response['retry-after']) == 58
def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers): def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
""" """
Ensure the response returns an Retry-After field with status and next attributes Ensure the response returns an Retry-After field with status and next attributes