From 02b338f2e338ac2cd6a7f5c081ffbec12312a00c Mon Sep 17 00:00:00 2001 From: abulaysov Date: Mon, 18 Mar 2024 02:55:16 +0300 Subject: [PATCH] add amount of period for throttling --- docs/api-guide/throttling.md | 2 +- rest_framework/throttling.py | 8 +++++++- tests/test_throttling.py | 33 +++++++++++++++++++++------------ 3 files changed, 29 insertions(+), 14 deletions(-) diff --git a/docs/api-guide/throttling.md b/docs/api-guide/throttling.md index 4c58fa713..b037966f5 100644 --- a/docs/api-guide/throttling.md +++ b/docs/api-guide/throttling.md @@ -154,7 +154,7 @@ For example, multiple user throttle rates could be implemented by using the foll 'example.throttles.SustainedRateThrottle' ], 'DEFAULT_THROTTLE_RATES': { - 'burst': '60/min', + 'burst': '60/30-min', 'sustained': '1000/day' } } diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py index c0d6cf42f..7990ca22c 100644 --- a/rest_framework/throttling.py +++ b/rest_framework/throttling.py @@ -103,7 +103,13 @@ class SimpleRateThrottle(BaseThrottle): return (None, None) num, period = rate.split('/') num_requests = int(num) - duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]] + duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400} + if "-" in period: + num_period, period = period.split("-") + duration = duration[period[0]] * int(num_period) + return (num_requests, duration) + duration = duration[period[0]] + return (num_requests, duration) def allow_request(self, request, view): diff --git a/tests/test_throttling.py b/tests/test_throttling.py index be9decebc..60ce24cae 100644 --- a/tests/test_throttling.py +++ b/tests/test_throttling.py @@ -25,6 +25,11 @@ class User3SecRateThrottle(UserRateThrottle): scope = 'seconds' +class User1RequestIn2SecRateThrottle(UserRateThrottle): + rate = '1/2-sec' + scope = 'seconds' + + class User3MinRateThrottle(UserRateThrottle): rate = '3/min' scope = 'minutes' @@ -57,6 +62,13 @@ class MockView(APIView): return Response('foo') +class MockView_1RequestIn2SecondThrottling(APIView): + throttle_classes = (User1RequestIn2SecRateThrottle,) + + def get(self, request): + return Response('foo') + + class MockView_MinuteThrottling(APIView): throttle_classes = (User3MinRateThrottle,) @@ -167,18 +179,15 @@ class ThrottlingTests(TestCase): assert response.status_code == 429 assert int(response['retry-after']) == 60 - previous_rate = User3SecRateThrottle.rate - try: - User3SecRateThrottle.rate = '1/sec' - - for dummy in range(24): - response = MockView_DoubleThrottling.as_view()(request) - - assert response.status_code == 429 - assert int(response['retry-after']) == 60 - finally: - # reset - User3SecRateThrottle.rate = previous_rate + def test_request_throttling_with_amount_of_period(self): + self.set_throttle_timer(MockView_1RequestIn2SecondThrottling, 0) + request = self.factory.get('/') + # At this point our client made two requests, second was throttled for a + # two seconds. + for _ in range(2): + response = MockView_1RequestIn2SecondThrottling.as_view()(request) + assert response.status_code == 429 + assert int(response['retry-after']) == 2 def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers): """