mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-02-03 05:04:31 +03:00
Fix and tests for ScopedRateThrottle. Closes #935
This commit is contained in:
parent
6cc4fe5637
commit
df957c8625
|
@ -7,7 +7,7 @@ from django.contrib.auth.models import User
|
|||
from django.core.cache import cache
|
||||
from django.test.client import RequestFactory
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.throttling import UserRateThrottle
|
||||
from rest_framework.throttling import UserRateThrottle, ScopedRateThrottle
|
||||
from rest_framework.response import Response
|
||||
|
||||
|
||||
|
@ -36,8 +36,6 @@ class MockView_MinuteThrottling(APIView):
|
|||
|
||||
|
||||
class ThrottlingTests(TestCase):
|
||||
urls = 'rest_framework.tests.test_throttling'
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Reset the cache so that no throttles will be active
|
||||
|
@ -141,3 +139,108 @@ class ThrottlingTests(TestCase):
|
|||
(60, None),
|
||||
(80, None)
|
||||
))
|
||||
|
||||
|
||||
class ScopedRateThrottleTests(TestCase):
|
||||
"""
|
||||
Tests for ScopedRateThrottle.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
class XYScopedRateThrottle(ScopedRateThrottle):
|
||||
TIMER_SECONDS = 0
|
||||
THROTTLE_RATES = {'x': '3/min', 'y': '1/min'}
|
||||
timer = lambda self: self.TIMER_SECONDS
|
||||
|
||||
class XView(APIView):
|
||||
throttle_classes = (XYScopedRateThrottle,)
|
||||
throttle_scope = 'x'
|
||||
|
||||
def get(self, request):
|
||||
return Response('x')
|
||||
|
||||
class YView(APIView):
|
||||
throttle_classes = (XYScopedRateThrottle,)
|
||||
throttle_scope = 'y'
|
||||
|
||||
def get(self, request):
|
||||
return Response('y')
|
||||
|
||||
class UnscopedView(APIView):
|
||||
throttle_classes = (XYScopedRateThrottle,)
|
||||
|
||||
def get(self, request):
|
||||
return Response('y')
|
||||
|
||||
self.throttle_class = XYScopedRateThrottle
|
||||
self.factory = RequestFactory()
|
||||
self.x_view = XView.as_view()
|
||||
self.y_view = YView.as_view()
|
||||
self.unscoped_view = UnscopedView.as_view()
|
||||
|
||||
def increment_timer(self, seconds=1):
|
||||
self.throttle_class.TIMER_SECONDS += seconds
|
||||
|
||||
def test_scoped_rate_throttle(self):
|
||||
request = self.factory.get('/')
|
||||
|
||||
# Should be able to hit x view 3 times per minute.
|
||||
response = self.x_view(request)
|
||||
self.assertEqual(200, response.status_code)
|
||||
|
||||
self.increment_timer()
|
||||
response = self.x_view(request)
|
||||
self.assertEqual(200, response.status_code)
|
||||
|
||||
self.increment_timer()
|
||||
response = self.x_view(request)
|
||||
self.assertEqual(200, response.status_code)
|
||||
|
||||
self.increment_timer()
|
||||
response = self.x_view(request)
|
||||
self.assertEqual(429, response.status_code)
|
||||
|
||||
# Should be able to hit y view 1 time per minute.
|
||||
self.increment_timer()
|
||||
response = self.y_view(request)
|
||||
self.assertEqual(200, response.status_code)
|
||||
|
||||
self.increment_timer()
|
||||
response = self.y_view(request)
|
||||
self.assertEqual(429, response.status_code)
|
||||
|
||||
# Ensure throttles properly reset by advancing the rest of the minute
|
||||
self.increment_timer(55)
|
||||
|
||||
# Should still be able to hit x view 3 times per minute.
|
||||
response = self.x_view(request)
|
||||
self.assertEqual(200, response.status_code)
|
||||
|
||||
self.increment_timer()
|
||||
response = self.x_view(request)
|
||||
self.assertEqual(200, response.status_code)
|
||||
|
||||
self.increment_timer()
|
||||
response = self.x_view(request)
|
||||
self.assertEqual(200, response.status_code)
|
||||
|
||||
self.increment_timer()
|
||||
response = self.x_view(request)
|
||||
self.assertEqual(429, response.status_code)
|
||||
|
||||
# Should still be able to hit y view 1 time per minute.
|
||||
self.increment_timer()
|
||||
response = self.y_view(request)
|
||||
self.assertEqual(200, response.status_code)
|
||||
|
||||
self.increment_timer()
|
||||
response = self.y_view(request)
|
||||
self.assertEqual(429, response.status_code)
|
||||
|
||||
def test_unscoped_view_not_throttled(self):
|
||||
request = self.factory.get('/')
|
||||
|
||||
for idx in range(10):
|
||||
self.increment_timer()
|
||||
response = self.unscoped_view(request)
|
||||
self.assertEqual(200, response.status_code)
|
||||
|
|
|
@ -40,9 +40,9 @@ class SimpleRateThrottle(BaseThrottle):
|
|||
"""
|
||||
|
||||
timer = time.time
|
||||
settings = api_settings
|
||||
cache_format = 'throtte_%(scope)s_%(ident)s'
|
||||
scope = None
|
||||
THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
|
||||
|
||||
def __init__(self):
|
||||
if not getattr(self, 'rate', None):
|
||||
|
@ -68,7 +68,7 @@ class SimpleRateThrottle(BaseThrottle):
|
|||
raise ImproperlyConfigured(msg)
|
||||
|
||||
try:
|
||||
return self.settings.DEFAULT_THROTTLE_RATES[self.scope]
|
||||
return self.THROTTLE_RATES[self.scope]
|
||||
except KeyError:
|
||||
msg = "No default throttle rate set for '%s' scope" % self.scope
|
||||
raise ImproperlyConfigured(msg)
|
||||
|
@ -187,6 +187,19 @@ class ScopedRateThrottle(SimpleRateThrottle):
|
|||
"""
|
||||
scope_attr = 'throttle_scope'
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def allow_request(self, request, view):
|
||||
self.scope = getattr(view, self.scope_attr, None)
|
||||
|
||||
if not self.scope:
|
||||
return True
|
||||
|
||||
self.rate = self.get_rate()
|
||||
self.num_requests, self.duration = self.parse_rate(self.rate)
|
||||
return super(ScopedRateThrottle, self).allow_request(request, view)
|
||||
|
||||
def get_cache_key(self, request, view):
|
||||
"""
|
||||
If `view.throttle_scope` is not set, don't apply this throttle.
|
||||
|
|
Loading…
Reference in New Issue
Block a user