Added failing test case and possible fix

This commit is contained in:
João Pedro 2019-03-07 12:02:46 +13:00
parent 9d06e43d05
commit 1b6994dba1
3 changed files with 165 additions and 9 deletions

View File

@ -124,22 +124,25 @@ class SimpleRateThrottle(BaseThrottle):
self.history = self.cache.get(self.key, [])
self.now = self.timer()
print(self.history)
# Drop any requests from the history which have now passed the
# throttle duration
while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
if len(self.history) >= self.num_requests:
return self.throttle_failure()
return self.throttle_success()
def throttle_success(self):
if len(self.history) >= self.num_requests:
return False
return True # self.throttle_success()
def throttle_success(self, request, view):
"""
Inserts the current request's timestamp along with the key
into the cache.
"""
self.history.insert(0, self.now)
self.cache.set(self.key, self.history, self.duration)
if self.scope:
self.history.insert(0, self.now)
self.cache.set(self.key, self.history, self.duration)
return True
def throttle_failure(self):
@ -152,6 +155,7 @@ class SimpleRateThrottle(BaseThrottle):
"""
Returns the recommended next request time in seconds.
"""
if self.history:
remaining_duration = self.duration - (self.now - self.history[-1])
else:

View File

@ -352,9 +352,18 @@ class APIView(View):
Check if request should be throttled.
Raises an appropriate exception if the request is throttled.
"""
for throttle in self.get_throttles():
if not throttle.allow_request(request, self):
self.throttled(request, throttle.wait())
throttles = self.get_throttles()
request_allowed = all(
[throttle.allow_request(request, self) for throttle in throttles]
)
print(request_allowed, [throttle.allow_request(request, self) for throttle in throttles])
if request_allowed:
[throttle.throttle_success(request, self) for throttle in throttles]
else:
min_wait = min(
[throttle.wait() for throttle in throttles]
)
self.throttled(request, min_wait)
def determine_version(self, request, *args, **kwargs):
"""

View File

@ -38,6 +38,9 @@ class NonTimeThrottle(BaseThrottle):
return True
return False
def throttle_success(self, request, view):
pass
class MockView(APIView):
throttle_classes = (User3SecRateThrottle,)
@ -449,3 +452,143 @@ class AnonRateThrottleTests(TestCase):
request = Request(HttpRequest())
cache_key = self.throttle.get_cache_key(request, view={})
assert cache_key == 'throttle_anon_None'
class TwoPeriodRateThrottleTests(TestCase):
"""
Tests for views with more than one period.
The order of the throttle classes should not matter
Eg.
2/min and 5/day
"""
def setUp(self):
self.DEFAULT_THROTTLE_RATES = {
'burst-anon': '2/min',
'sustained-anon': '5/day'
}
self.TIMER_SECONDS = 0
cache.clear()
class BurstRateThrottle(AnonRateThrottle):
THROTTLE_RATES = self.DEFAULT_THROTTLE_RATES
TIMER_SECONDS = self.TIMER_SECONDS
scope = 'burst-anon'
def timer(self):
return self.TIMER_SECONDS
class SustainedRateThrottle(AnonRateThrottle):
THROTTLE_RATES = self.DEFAULT_THROTTLE_RATES
TIMER_SECONDS = self.TIMER_SECONDS
scope = 'sustained-anon'
def timer(self):
return self.TIMER_SECONDS
class BurstSustainedView(APIView):
throttle_classes = (BurstRateThrottle, SustainedRateThrottle)
def get(self, request):
return Response('x')
class SustainedBurstView(APIView):
throttle_classes = (SustainedRateThrottle, BurstRateThrottle)
def get(self, request):
return Response('y')
self.factory = APIRequestFactory()
self.sustained_burst_view = SustainedBurstView.as_view()
self.burst_sustained_view = BurstSustainedView.as_view()
self.burst_sustained_throttle = BurstRateThrottle
self.sustained_burst_throttle = SustainedRateThrottle
def increment_timer(self, seconds=1):
self.burst_sustained_throttle.TIMER_SECONDS += seconds
self.sustained_burst_throttle.TIMER_SECONDS += seconds
def test_sustained_burst_throttles_ordering(self):
request = self.factory.get('/')
# Should be able to hit x view 2 times per minute.
response = self.sustained_burst_view(request)
assert response.status_code == 200
self.increment_timer()
response = self.sustained_burst_view(request)
assert response.status_code == 200
self.increment_timer()
response = self.sustained_burst_view(request)
assert response.status_code == 429
# Ensure throttles properly reset by advancing the rest of the minute
self.increment_timer(58)
# Should still be able to hit x view 2 times per minute.
response = self.sustained_burst_view(request)
assert response.status_code == 200
self.increment_timer()
response = self.sustained_burst_view(request)
assert response.status_code == 200
self.increment_timer()
response = self.sustained_burst_view(request)
assert response.status_code == 429
# Ensure throttles properly reset by advancing the rest of the minute
self.increment_timer(58)
# Should still be able to hit y view 1 time per minute.
self.increment_timer()
response = self.sustained_burst_view(request)
assert response.status_code == 200
self.increment_timer()
response = self.sustained_burst_view(request)
assert response.status_code == 429
def test_burst_sustained_throttles_ordering(self):
request = self.factory.get('/')
# Should be able to hit x view 2 times per minute.
response = self.burst_sustained_view(request)
assert response.status_code == 200
self.increment_timer()
response = self.burst_sustained_view(request)
assert response.status_code == 200
self.increment_timer()
response = self.burst_sustained_view(request)
assert response.status_code == 429
# Ensure throttles properly reset by advancing the rest of the minute
self.increment_timer(58)
# Should still be able to hit x view 2 times per minute.
response = self.burst_sustained_view(request)
assert response.status_code == 200
self.increment_timer()
response = self.burst_sustained_view(request)
assert response.status_code == 200
self.increment_timer()
response = self.burst_sustained_view(request)
assert response.status_code == 429
# Ensure throttles properly reset by advancing the rest of the minute
self.increment_timer(58)
# Should still be able to hit y view 1 time per minute.
self.increment_timer()
response = self.burst_sustained_view(request)
assert response.status_code == 200
self.increment_timer()
response = self.burst_sustained_view(request)
assert response.status_code == 429