mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-07-17 19:52:25 +03:00
add amount of period for throttling
This commit is contained in:
parent
337ba211e8
commit
02b338f2e3
|
@ -154,7 +154,7 @@ For example, multiple user throttle rates could be implemented by using the foll
|
||||||
'example.throttles.SustainedRateThrottle'
|
'example.throttles.SustainedRateThrottle'
|
||||||
],
|
],
|
||||||
'DEFAULT_THROTTLE_RATES': {
|
'DEFAULT_THROTTLE_RATES': {
|
||||||
'burst': '60/min',
|
'burst': '60/30-min',
|
||||||
'sustained': '1000/day'
|
'sustained': '1000/day'
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -103,7 +103,13 @@ class SimpleRateThrottle(BaseThrottle):
|
||||||
return (None, None)
|
return (None, None)
|
||||||
num, period = rate.split('/')
|
num, period = rate.split('/')
|
||||||
num_requests = int(num)
|
num_requests = int(num)
|
||||||
duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
|
duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}
|
||||||
|
if "-" in period:
|
||||||
|
num_period, period = period.split("-")
|
||||||
|
duration = duration[period[0]] * int(num_period)
|
||||||
|
return (num_requests, duration)
|
||||||
|
duration = duration[period[0]]
|
||||||
|
|
||||||
return (num_requests, duration)
|
return (num_requests, duration)
|
||||||
|
|
||||||
def allow_request(self, request, view):
|
def allow_request(self, request, view):
|
||||||
|
|
|
@ -25,6 +25,11 @@ class User3SecRateThrottle(UserRateThrottle):
|
||||||
scope = 'seconds'
|
scope = 'seconds'
|
||||||
|
|
||||||
|
|
||||||
|
class User1RequestIn2SecRateThrottle(UserRateThrottle):
|
||||||
|
rate = '1/2-sec'
|
||||||
|
scope = 'seconds'
|
||||||
|
|
||||||
|
|
||||||
class User3MinRateThrottle(UserRateThrottle):
|
class User3MinRateThrottle(UserRateThrottle):
|
||||||
rate = '3/min'
|
rate = '3/min'
|
||||||
scope = 'minutes'
|
scope = 'minutes'
|
||||||
|
@ -57,6 +62,13 @@ class MockView(APIView):
|
||||||
return Response('foo')
|
return Response('foo')
|
||||||
|
|
||||||
|
|
||||||
|
class MockView_1RequestIn2SecondThrottling(APIView):
|
||||||
|
throttle_classes = (User1RequestIn2SecRateThrottle,)
|
||||||
|
|
||||||
|
def get(self, request):
|
||||||
|
return Response('foo')
|
||||||
|
|
||||||
|
|
||||||
class MockView_MinuteThrottling(APIView):
|
class MockView_MinuteThrottling(APIView):
|
||||||
throttle_classes = (User3MinRateThrottle,)
|
throttle_classes = (User3MinRateThrottle,)
|
||||||
|
|
||||||
|
@ -167,18 +179,15 @@ class ThrottlingTests(TestCase):
|
||||||
assert response.status_code == 429
|
assert response.status_code == 429
|
||||||
assert int(response['retry-after']) == 60
|
assert int(response['retry-after']) == 60
|
||||||
|
|
||||||
previous_rate = User3SecRateThrottle.rate
|
def test_request_throttling_with_amount_of_period(self):
|
||||||
try:
|
self.set_throttle_timer(MockView_1RequestIn2SecondThrottling, 0)
|
||||||
User3SecRateThrottle.rate = '1/sec'
|
request = self.factory.get('/')
|
||||||
|
# At this point our client made two requests, second was throttled for a
|
||||||
for dummy in range(24):
|
# two seconds.
|
||||||
response = MockView_DoubleThrottling.as_view()(request)
|
for _ in range(2):
|
||||||
|
response = MockView_1RequestIn2SecondThrottling.as_view()(request)
|
||||||
assert response.status_code == 429
|
assert response.status_code == 429
|
||||||
assert int(response['retry-after']) == 60
|
assert int(response['retry-after']) == 2
|
||||||
finally:
|
|
||||||
# reset
|
|
||||||
User3SecRateThrottle.rate = previous_rate
|
|
||||||
|
|
||||||
def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
|
def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue
Block a user