diff --git a/rest_framework/tests/test_throttling.py b/rest_framework/tests/test_throttling.py index da400b2fc..d35d37092 100644 --- a/rest_framework/tests/test_throttling.py +++ b/rest_framework/tests/test_throttling.py @@ -7,7 +7,7 @@ from django.contrib.auth.models import User from django.core.cache import cache from django.test.client import RequestFactory from rest_framework.views import APIView -from rest_framework.throttling import UserRateThrottle +from rest_framework.throttling import UserRateThrottle, ScopedRateThrottle from rest_framework.response import Response @@ -36,8 +36,6 @@ class MockView_MinuteThrottling(APIView): class ThrottlingTests(TestCase): - urls = 'rest_framework.tests.test_throttling' - def setUp(self): """ Reset the cache so that no throttles will be active @@ -141,3 +139,108 @@ class ThrottlingTests(TestCase): (60, None), (80, None) )) + + +class ScopedRateThrottleTests(TestCase): + """ + Tests for ScopedRateThrottle. + """ + + def setUp(self): + class XYScopedRateThrottle(ScopedRateThrottle): + TIMER_SECONDS = 0 + THROTTLE_RATES = {'x': '3/min', 'y': '1/min'} + timer = lambda self: self.TIMER_SECONDS + + class XView(APIView): + throttle_classes = (XYScopedRateThrottle,) + throttle_scope = 'x' + + def get(self, request): + return Response('x') + + class YView(APIView): + throttle_classes = (XYScopedRateThrottle,) + throttle_scope = 'y' + + def get(self, request): + return Response('y') + + class UnscopedView(APIView): + throttle_classes = (XYScopedRateThrottle,) + + def get(self, request): + return Response('y') + + self.throttle_class = XYScopedRateThrottle + self.factory = RequestFactory() + self.x_view = XView.as_view() + self.y_view = YView.as_view() + self.unscoped_view = UnscopedView.as_view() + + def increment_timer(self, seconds=1): + self.throttle_class.TIMER_SECONDS += seconds + + def test_scoped_rate_throttle(self): + request = self.factory.get('/') + + # Should be able to hit x view 3 times per minute. + response = self.x_view(request) + self.assertEqual(200, response.status_code) + + self.increment_timer() + response = self.x_view(request) + self.assertEqual(200, response.status_code) + + self.increment_timer() + response = self.x_view(request) + self.assertEqual(200, response.status_code) + + self.increment_timer() + response = self.x_view(request) + self.assertEqual(429, response.status_code) + + # Should be able to hit y view 1 time per minute. + self.increment_timer() + response = self.y_view(request) + self.assertEqual(200, response.status_code) + + self.increment_timer() + response = self.y_view(request) + self.assertEqual(429, response.status_code) + + # Ensure throttles properly reset by advancing the rest of the minute + self.increment_timer(55) + + # Should still be able to hit x view 3 times per minute. + response = self.x_view(request) + self.assertEqual(200, response.status_code) + + self.increment_timer() + response = self.x_view(request) + self.assertEqual(200, response.status_code) + + self.increment_timer() + response = self.x_view(request) + self.assertEqual(200, response.status_code) + + self.increment_timer() + response = self.x_view(request) + self.assertEqual(429, response.status_code) + + # Should still be able to hit y view 1 time per minute. + self.increment_timer() + response = self.y_view(request) + self.assertEqual(200, response.status_code) + + self.increment_timer() + response = self.y_view(request) + self.assertEqual(429, response.status_code) + + def test_unscoped_view_not_throttled(self): + request = self.factory.get('/') + + for idx in range(10): + self.increment_timer() + response = self.unscoped_view(request) + self.assertEqual(200, response.status_code) diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py index 9d89d1cbc..ec9c80ffd 100644 --- a/rest_framework/throttling.py +++ b/rest_framework/throttling.py @@ -40,9 +40,9 @@ class SimpleRateThrottle(BaseThrottle): """ timer = time.time - settings = api_settings cache_format = 'throtte_%(scope)s_%(ident)s' scope = None + THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES def __init__(self): if not getattr(self, 'rate', None): @@ -68,7 +68,7 @@ class SimpleRateThrottle(BaseThrottle): raise ImproperlyConfigured(msg) try: - return self.settings.DEFAULT_THROTTLE_RATES[self.scope] + return self.THROTTLE_RATES[self.scope] except KeyError: msg = "No default throttle rate set for '%s' scope" % self.scope raise ImproperlyConfigured(msg) @@ -187,6 +187,19 @@ class ScopedRateThrottle(SimpleRateThrottle): """ scope_attr = 'throttle_scope' + def __init__(self): + pass + + def allow_request(self, request, view): + self.scope = getattr(view, self.scope_attr, None) + + if not self.scope: + return True + + self.rate = self.get_rate() + self.num_requests, self.duration = self.parse_rate(self.rate) + return super(ScopedRateThrottle, self).allow_request(request, view) + def get_cache_key(self, request, view): """ If `view.throttle_scope` is not set, don't apply this throttle.