Added failing test case and possible fix

This commit is contained in:
João Pedro 2019-03-07 12:02:46 +13:00
parent 9d06e43d05
commit 1b6994dba1
3 changed files with 165 additions and 9 deletions

View File

@ -124,22 +124,25 @@ class SimpleRateThrottle(BaseThrottle):
self.history = self.cache.get(self.key, []) self.history = self.cache.get(self.key, [])
self.now = self.timer() self.now = self.timer()
print(self.history)
# Drop any requests from the history which have now passed the # Drop any requests from the history which have now passed the
# throttle duration # throttle duration
while self.history and self.history[-1] <= self.now - self.duration: while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop() self.history.pop()
if len(self.history) >= self.num_requests:
return self.throttle_failure()
return self.throttle_success()
def throttle_success(self): if len(self.history) >= self.num_requests:
return False
return True # self.throttle_success()
def throttle_success(self, request, view):
""" """
Inserts the current request's timestamp along with the key Inserts the current request's timestamp along with the key
into the cache. into the cache.
""" """
self.history.insert(0, self.now) if self.scope:
self.cache.set(self.key, self.history, self.duration) self.history.insert(0, self.now)
self.cache.set(self.key, self.history, self.duration)
return True return True
def throttle_failure(self): def throttle_failure(self):
@ -152,6 +155,7 @@ class SimpleRateThrottle(BaseThrottle):
""" """
Returns the recommended next request time in seconds. Returns the recommended next request time in seconds.
""" """
if self.history: if self.history:
remaining_duration = self.duration - (self.now - self.history[-1]) remaining_duration = self.duration - (self.now - self.history[-1])
else: else:

View File

@ -352,9 +352,18 @@ class APIView(View):
Check if request should be throttled. Check if request should be throttled.
Raises an appropriate exception if the request is throttled. Raises an appropriate exception if the request is throttled.
""" """
for throttle in self.get_throttles(): throttles = self.get_throttles()
if not throttle.allow_request(request, self): request_allowed = all(
self.throttled(request, throttle.wait()) [throttle.allow_request(request, self) for throttle in throttles]
)
print(request_allowed, [throttle.allow_request(request, self) for throttle in throttles])
if request_allowed:
[throttle.throttle_success(request, self) for throttle in throttles]
else:
min_wait = min(
[throttle.wait() for throttle in throttles]
)
self.throttled(request, min_wait)
def determine_version(self, request, *args, **kwargs): def determine_version(self, request, *args, **kwargs):
""" """

View File

@ -38,6 +38,9 @@ class NonTimeThrottle(BaseThrottle):
return True return True
return False return False
def throttle_success(self, request, view):
pass
class MockView(APIView): class MockView(APIView):
throttle_classes = (User3SecRateThrottle,) throttle_classes = (User3SecRateThrottle,)
@ -449,3 +452,143 @@ class AnonRateThrottleTests(TestCase):
request = Request(HttpRequest()) request = Request(HttpRequest())
cache_key = self.throttle.get_cache_key(request, view={}) cache_key = self.throttle.get_cache_key(request, view={})
assert cache_key == 'throttle_anon_None' assert cache_key == 'throttle_anon_None'
class TwoPeriodRateThrottleTests(TestCase):
"""
Tests for views with more than one period.
The order of the throttle classes should not matter
Eg.
2/min and 5/day
"""
def setUp(self):
self.DEFAULT_THROTTLE_RATES = {
'burst-anon': '2/min',
'sustained-anon': '5/day'
}
self.TIMER_SECONDS = 0
cache.clear()
class BurstRateThrottle(AnonRateThrottle):
THROTTLE_RATES = self.DEFAULT_THROTTLE_RATES
TIMER_SECONDS = self.TIMER_SECONDS
scope = 'burst-anon'
def timer(self):
return self.TIMER_SECONDS
class SustainedRateThrottle(AnonRateThrottle):
THROTTLE_RATES = self.DEFAULT_THROTTLE_RATES
TIMER_SECONDS = self.TIMER_SECONDS
scope = 'sustained-anon'
def timer(self):
return self.TIMER_SECONDS
class BurstSustainedView(APIView):
throttle_classes = (BurstRateThrottle, SustainedRateThrottle)
def get(self, request):
return Response('x')
class SustainedBurstView(APIView):
throttle_classes = (SustainedRateThrottle, BurstRateThrottle)
def get(self, request):
return Response('y')
self.factory = APIRequestFactory()
self.sustained_burst_view = SustainedBurstView.as_view()
self.burst_sustained_view = BurstSustainedView.as_view()
self.burst_sustained_throttle = BurstRateThrottle
self.sustained_burst_throttle = SustainedRateThrottle
def increment_timer(self, seconds=1):
self.burst_sustained_throttle.TIMER_SECONDS += seconds
self.sustained_burst_throttle.TIMER_SECONDS += seconds
def test_sustained_burst_throttles_ordering(self):
request = self.factory.get('/')
# Should be able to hit x view 2 times per minute.
response = self.sustained_burst_view(request)
assert response.status_code == 200
self.increment_timer()
response = self.sustained_burst_view(request)
assert response.status_code == 200
self.increment_timer()
response = self.sustained_burst_view(request)
assert response.status_code == 429
# Ensure throttles properly reset by advancing the rest of the minute
self.increment_timer(58)
# Should still be able to hit x view 2 times per minute.
response = self.sustained_burst_view(request)
assert response.status_code == 200
self.increment_timer()
response = self.sustained_burst_view(request)
assert response.status_code == 200
self.increment_timer()
response = self.sustained_burst_view(request)
assert response.status_code == 429
# Ensure throttles properly reset by advancing the rest of the minute
self.increment_timer(58)
# Should still be able to hit y view 1 time per minute.
self.increment_timer()
response = self.sustained_burst_view(request)
assert response.status_code == 200
self.increment_timer()
response = self.sustained_burst_view(request)
assert response.status_code == 429
def test_burst_sustained_throttles_ordering(self):
request = self.factory.get('/')
# Should be able to hit x view 2 times per minute.
response = self.burst_sustained_view(request)
assert response.status_code == 200
self.increment_timer()
response = self.burst_sustained_view(request)
assert response.status_code == 200
self.increment_timer()
response = self.burst_sustained_view(request)
assert response.status_code == 429
# Ensure throttles properly reset by advancing the rest of the minute
self.increment_timer(58)
# Should still be able to hit x view 2 times per minute.
response = self.burst_sustained_view(request)
assert response.status_code == 200
self.increment_timer()
response = self.burst_sustained_view(request)
assert response.status_code == 200
self.increment_timer()
response = self.burst_sustained_view(request)
assert response.status_code == 429
# Ensure throttles properly reset by advancing the rest of the minute
self.increment_timer(58)
# Should still be able to hit y view 1 time per minute.
self.increment_timer()
response = self.burst_sustained_view(request)
assert response.status_code == 200
self.increment_timer()
response = self.burst_sustained_view(request)
assert response.status_code == 429