diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py index 834ced148..e449ac2f3 100644 --- a/rest_framework/throttling.py +++ b/rest_framework/throttling.py @@ -124,22 +124,25 @@ class SimpleRateThrottle(BaseThrottle): self.history = self.cache.get(self.key, []) self.now = self.timer() + print(self.history) # Drop any requests from the history which have now passed the # throttle duration while self.history and self.history[-1] <= self.now - self.duration: self.history.pop() - if len(self.history) >= self.num_requests: - return self.throttle_failure() - return self.throttle_success() - def throttle_success(self): + if len(self.history) >= self.num_requests: + return False + return True # self.throttle_success() + + def throttle_success(self, request, view): """ Inserts the current request's timestamp along with the key into the cache. """ - self.history.insert(0, self.now) - self.cache.set(self.key, self.history, self.duration) + if self.scope: + self.history.insert(0, self.now) + self.cache.set(self.key, self.history, self.duration) return True def throttle_failure(self): @@ -152,6 +155,7 @@ class SimpleRateThrottle(BaseThrottle): """ Returns the recommended next request time in seconds. """ + if self.history: remaining_duration = self.duration - (self.now - self.history[-1]) else: diff --git a/rest_framework/views.py b/rest_framework/views.py index 04951ed93..6b3da7aac 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -352,9 +352,18 @@ class APIView(View): Check if request should be throttled. Raises an appropriate exception if the request is throttled. """ - for throttle in self.get_throttles(): - if not throttle.allow_request(request, self): - self.throttled(request, throttle.wait()) + throttles = self.get_throttles() + request_allowed = all( + [throttle.allow_request(request, self) for throttle in throttles] + ) + print(request_allowed, [throttle.allow_request(request, self) for throttle in throttles]) + if request_allowed: + [throttle.throttle_success(request, self) for throttle in throttles] + else: + min_wait = min( + [throttle.wait() for throttle in throttles] + ) + self.throttled(request, min_wait) def determine_version(self, request, *args, **kwargs): """ diff --git a/tests/test_throttling.py b/tests/test_throttling.py index b220a33a6..99d2e7912 100644 --- a/tests/test_throttling.py +++ b/tests/test_throttling.py @@ -38,6 +38,9 @@ class NonTimeThrottle(BaseThrottle): return True return False + def throttle_success(self, request, view): + pass + class MockView(APIView): throttle_classes = (User3SecRateThrottle,) @@ -449,3 +452,143 @@ class AnonRateThrottleTests(TestCase): request = Request(HttpRequest()) cache_key = self.throttle.get_cache_key(request, view={}) assert cache_key == 'throttle_anon_None' + + +class TwoPeriodRateThrottleTests(TestCase): + """ + Tests for views with more than one period. + The order of the throttle classes should not matter + Eg. + 2/min and 5/day + """ + + def setUp(self): + self.DEFAULT_THROTTLE_RATES = { + 'burst-anon': '2/min', + 'sustained-anon': '5/day' + } + self.TIMER_SECONDS = 0 + cache.clear() + + class BurstRateThrottle(AnonRateThrottle): + THROTTLE_RATES = self.DEFAULT_THROTTLE_RATES + TIMER_SECONDS = self.TIMER_SECONDS + scope = 'burst-anon' + + def timer(self): + return self.TIMER_SECONDS + + class SustainedRateThrottle(AnonRateThrottle): + THROTTLE_RATES = self.DEFAULT_THROTTLE_RATES + TIMER_SECONDS = self.TIMER_SECONDS + scope = 'sustained-anon' + + def timer(self): + return self.TIMER_SECONDS + + class BurstSustainedView(APIView): + throttle_classes = (BurstRateThrottle, SustainedRateThrottle) + + def get(self, request): + return Response('x') + + class SustainedBurstView(APIView): + throttle_classes = (SustainedRateThrottle, BurstRateThrottle) + + def get(self, request): + return Response('y') + + self.factory = APIRequestFactory() + self.sustained_burst_view = SustainedBurstView.as_view() + self.burst_sustained_view = BurstSustainedView.as_view() + self.burst_sustained_throttle = BurstRateThrottle + self.sustained_burst_throttle = SustainedRateThrottle + + def increment_timer(self, seconds=1): + self.burst_sustained_throttle.TIMER_SECONDS += seconds + self.sustained_burst_throttle.TIMER_SECONDS += seconds + + def test_sustained_burst_throttles_ordering(self): + request = self.factory.get('/') + + # Should be able to hit x view 2 times per minute. + + response = self.sustained_burst_view(request) + assert response.status_code == 200 + + self.increment_timer() + response = self.sustained_burst_view(request) + assert response.status_code == 200 + + self.increment_timer() + response = self.sustained_burst_view(request) + assert response.status_code == 429 + + # Ensure throttles properly reset by advancing the rest of the minute + self.increment_timer(58) + + # Should still be able to hit x view 2 times per minute. + response = self.sustained_burst_view(request) + assert response.status_code == 200 + + self.increment_timer() + response = self.sustained_burst_view(request) + assert response.status_code == 200 + + self.increment_timer() + response = self.sustained_burst_view(request) + assert response.status_code == 429 + + # Ensure throttles properly reset by advancing the rest of the minute + self.increment_timer(58) + + # Should still be able to hit y view 1 time per minute. + self.increment_timer() + response = self.sustained_burst_view(request) + assert response.status_code == 200 + + self.increment_timer() + response = self.sustained_burst_view(request) + assert response.status_code == 429 + + def test_burst_sustained_throttles_ordering(self): + request = self.factory.get('/') + + # Should be able to hit x view 2 times per minute. + response = self.burst_sustained_view(request) + assert response.status_code == 200 + + self.increment_timer() + response = self.burst_sustained_view(request) + assert response.status_code == 200 + + self.increment_timer() + response = self.burst_sustained_view(request) + assert response.status_code == 429 + + # Ensure throttles properly reset by advancing the rest of the minute + self.increment_timer(58) + + # Should still be able to hit x view 2 times per minute. + response = self.burst_sustained_view(request) + assert response.status_code == 200 + + self.increment_timer() + response = self.burst_sustained_view(request) + assert response.status_code == 200 + + self.increment_timer() + response = self.burst_sustained_view(request) + assert response.status_code == 429 + + # Ensure throttles properly reset by advancing the rest of the minute + self.increment_timer(58) + + # Should still be able to hit y view 1 time per minute. + self.increment_timer() + response = self.burst_sustained_view(request) + assert response.status_code == 200 + + self.increment_timer() + response = self.burst_sustained_view(request) + assert response.status_code == 429