mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-07-30 18:09:59 +03:00
Added failing test case and possible fix
This commit is contained in:
parent
9d06e43d05
commit
1b6994dba1
|
@ -124,22 +124,25 @@ class SimpleRateThrottle(BaseThrottle):
|
|||
|
||||
self.history = self.cache.get(self.key, [])
|
||||
self.now = self.timer()
|
||||
print(self.history)
|
||||
|
||||
# Drop any requests from the history which have now passed the
|
||||
# throttle duration
|
||||
while self.history and self.history[-1] <= self.now - self.duration:
|
||||
self.history.pop()
|
||||
if len(self.history) >= self.num_requests:
|
||||
return self.throttle_failure()
|
||||
return self.throttle_success()
|
||||
|
||||
def throttle_success(self):
|
||||
if len(self.history) >= self.num_requests:
|
||||
return False
|
||||
return True # self.throttle_success()
|
||||
|
||||
def throttle_success(self, request, view):
|
||||
"""
|
||||
Inserts the current request's timestamp along with the key
|
||||
into the cache.
|
||||
"""
|
||||
self.history.insert(0, self.now)
|
||||
self.cache.set(self.key, self.history, self.duration)
|
||||
if self.scope:
|
||||
self.history.insert(0, self.now)
|
||||
self.cache.set(self.key, self.history, self.duration)
|
||||
return True
|
||||
|
||||
def throttle_failure(self):
|
||||
|
@ -152,6 +155,7 @@ class SimpleRateThrottle(BaseThrottle):
|
|||
"""
|
||||
Returns the recommended next request time in seconds.
|
||||
"""
|
||||
|
||||
if self.history:
|
||||
remaining_duration = self.duration - (self.now - self.history[-1])
|
||||
else:
|
||||
|
|
|
@ -352,9 +352,18 @@ class APIView(View):
|
|||
Check if request should be throttled.
|
||||
Raises an appropriate exception if the request is throttled.
|
||||
"""
|
||||
for throttle in self.get_throttles():
|
||||
if not throttle.allow_request(request, self):
|
||||
self.throttled(request, throttle.wait())
|
||||
throttles = self.get_throttles()
|
||||
request_allowed = all(
|
||||
[throttle.allow_request(request, self) for throttle in throttles]
|
||||
)
|
||||
print(request_allowed, [throttle.allow_request(request, self) for throttle in throttles])
|
||||
if request_allowed:
|
||||
[throttle.throttle_success(request, self) for throttle in throttles]
|
||||
else:
|
||||
min_wait = min(
|
||||
[throttle.wait() for throttle in throttles]
|
||||
)
|
||||
self.throttled(request, min_wait)
|
||||
|
||||
def determine_version(self, request, *args, **kwargs):
|
||||
"""
|
||||
|
|
|
@ -38,6 +38,9 @@ class NonTimeThrottle(BaseThrottle):
|
|||
return True
|
||||
return False
|
||||
|
||||
def throttle_success(self, request, view):
|
||||
pass
|
||||
|
||||
|
||||
class MockView(APIView):
|
||||
throttle_classes = (User3SecRateThrottle,)
|
||||
|
@ -449,3 +452,143 @@ class AnonRateThrottleTests(TestCase):
|
|||
request = Request(HttpRequest())
|
||||
cache_key = self.throttle.get_cache_key(request, view={})
|
||||
assert cache_key == 'throttle_anon_None'
|
||||
|
||||
|
||||
class TwoPeriodRateThrottleTests(TestCase):
|
||||
"""
|
||||
Tests for views with more than one period.
|
||||
The order of the throttle classes should not matter
|
||||
Eg.
|
||||
2/min and 5/day
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.DEFAULT_THROTTLE_RATES = {
|
||||
'burst-anon': '2/min',
|
||||
'sustained-anon': '5/day'
|
||||
}
|
||||
self.TIMER_SECONDS = 0
|
||||
cache.clear()
|
||||
|
||||
class BurstRateThrottle(AnonRateThrottle):
|
||||
THROTTLE_RATES = self.DEFAULT_THROTTLE_RATES
|
||||
TIMER_SECONDS = self.TIMER_SECONDS
|
||||
scope = 'burst-anon'
|
||||
|
||||
def timer(self):
|
||||
return self.TIMER_SECONDS
|
||||
|
||||
class SustainedRateThrottle(AnonRateThrottle):
|
||||
THROTTLE_RATES = self.DEFAULT_THROTTLE_RATES
|
||||
TIMER_SECONDS = self.TIMER_SECONDS
|
||||
scope = 'sustained-anon'
|
||||
|
||||
def timer(self):
|
||||
return self.TIMER_SECONDS
|
||||
|
||||
class BurstSustainedView(APIView):
|
||||
throttle_classes = (BurstRateThrottle, SustainedRateThrottle)
|
||||
|
||||
def get(self, request):
|
||||
return Response('x')
|
||||
|
||||
class SustainedBurstView(APIView):
|
||||
throttle_classes = (SustainedRateThrottle, BurstRateThrottle)
|
||||
|
||||
def get(self, request):
|
||||
return Response('y')
|
||||
|
||||
self.factory = APIRequestFactory()
|
||||
self.sustained_burst_view = SustainedBurstView.as_view()
|
||||
self.burst_sustained_view = BurstSustainedView.as_view()
|
||||
self.burst_sustained_throttle = BurstRateThrottle
|
||||
self.sustained_burst_throttle = SustainedRateThrottle
|
||||
|
||||
def increment_timer(self, seconds=1):
|
||||
self.burst_sustained_throttle.TIMER_SECONDS += seconds
|
||||
self.sustained_burst_throttle.TIMER_SECONDS += seconds
|
||||
|
||||
def test_sustained_burst_throttles_ordering(self):
|
||||
request = self.factory.get('/')
|
||||
|
||||
# Should be able to hit x view 2 times per minute.
|
||||
|
||||
response = self.sustained_burst_view(request)
|
||||
assert response.status_code == 200
|
||||
|
||||
self.increment_timer()
|
||||
response = self.sustained_burst_view(request)
|
||||
assert response.status_code == 200
|
||||
|
||||
self.increment_timer()
|
||||
response = self.sustained_burst_view(request)
|
||||
assert response.status_code == 429
|
||||
|
||||
# Ensure throttles properly reset by advancing the rest of the minute
|
||||
self.increment_timer(58)
|
||||
|
||||
# Should still be able to hit x view 2 times per minute.
|
||||
response = self.sustained_burst_view(request)
|
||||
assert response.status_code == 200
|
||||
|
||||
self.increment_timer()
|
||||
response = self.sustained_burst_view(request)
|
||||
assert response.status_code == 200
|
||||
|
||||
self.increment_timer()
|
||||
response = self.sustained_burst_view(request)
|
||||
assert response.status_code == 429
|
||||
|
||||
# Ensure throttles properly reset by advancing the rest of the minute
|
||||
self.increment_timer(58)
|
||||
|
||||
# Should still be able to hit y view 1 time per minute.
|
||||
self.increment_timer()
|
||||
response = self.sustained_burst_view(request)
|
||||
assert response.status_code == 200
|
||||
|
||||
self.increment_timer()
|
||||
response = self.sustained_burst_view(request)
|
||||
assert response.status_code == 429
|
||||
|
||||
def test_burst_sustained_throttles_ordering(self):
|
||||
request = self.factory.get('/')
|
||||
|
||||
# Should be able to hit x view 2 times per minute.
|
||||
response = self.burst_sustained_view(request)
|
||||
assert response.status_code == 200
|
||||
|
||||
self.increment_timer()
|
||||
response = self.burst_sustained_view(request)
|
||||
assert response.status_code == 200
|
||||
|
||||
self.increment_timer()
|
||||
response = self.burst_sustained_view(request)
|
||||
assert response.status_code == 429
|
||||
|
||||
# Ensure throttles properly reset by advancing the rest of the minute
|
||||
self.increment_timer(58)
|
||||
|
||||
# Should still be able to hit x view 2 times per minute.
|
||||
response = self.burst_sustained_view(request)
|
||||
assert response.status_code == 200
|
||||
|
||||
self.increment_timer()
|
||||
response = self.burst_sustained_view(request)
|
||||
assert response.status_code == 200
|
||||
|
||||
self.increment_timer()
|
||||
response = self.burst_sustained_view(request)
|
||||
assert response.status_code == 429
|
||||
|
||||
# Ensure throttles properly reset by advancing the rest of the minute
|
||||
self.increment_timer(58)
|
||||
|
||||
# Should still be able to hit y view 1 time per minute.
|
||||
self.increment_timer()
|
||||
response = self.burst_sustained_view(request)
|
||||
assert response.status_code == 200
|
||||
|
||||
self.increment_timer()
|
||||
response = self.burst_sustained_view(request)
|
||||
assert response.status_code == 429
|
||||
|
|
Loading…
Reference in New Issue
Block a user