Merge throttling and fix up a coupla things

This commit is contained in:
Tom Christie 2011-06-15 14:41:09 +01:00
commit 1cb84cd4e8
5 changed files with 136 additions and 50 deletions

View File

@ -31,11 +31,6 @@ _503_SERVICE_UNAVAILABLE = ErrorResponse(
{'detail': 'request was throttled'})
class ConfigurationException(BaseException):
"""To alert for bad configuration decisions as a convenience."""
pass
class BasePermission(object):
"""
A base class from which all permission classes should inherit.
@ -142,9 +137,8 @@ class BaseThrottle(BasePermission):
# Drop any requests from the history which have now passed the
# throttle duration
while self.history and self.history[0] <= self.now - self.duration:
while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
if len(self.history) >= self.num_requests:
self.throttle_failure()
else:
@ -157,14 +151,31 @@ class BaseThrottle(BasePermission):
"""
self.history.insert(0, self.now)
cache.set(self.key, self.history, self.duration)
header = 'status=SUCCESS; next=%s sec' % self.next()
self.view.add_header('X-Throttle', header)
def throttle_failure(self):
"""
Called when a request to the API has failed due to throttling.
Raises a '503 service unavailable' response.
"""
header = 'status=FAILURE; next=%s sec' % self.next()
self.view.add_header('X-Throttle', header)
raise _503_SERVICE_UNAVAILABLE
def next(self):
"""
Returns the recommended next request time in seconds.
"""
if self.history:
remaining_duration = self.duration - (self.now - self.history[-1])
else:
remaining_duration = self.duration
available_requests = self.num_requests - len(self.history) + 1
return '%.2f' % (remaining_duration / float(available_requests))
class PerUserThrottling(BaseThrottle):
"""

View File

@ -13,25 +13,27 @@ os.environ['DJANGO_SETTINGS_MODULE'] = 'djangorestframework.runtests.settings'
from django.conf import settings
from django.test.utils import get_runner
def usage():
return """
Usage: python runtests.py [UnitTestClass].[method]
You can pass the Class name of the `UnitTestClass` you want to test.
Append a method name if you only want to test a specific method of that class.
"""
def main():
TestRunner = get_runner(settings)
if hasattr(TestRunner, 'func_name'):
# Pre 1.2 test runners were just functions,
# and did not support the 'failfast' option.
import warnings
warnings.warn(
'Function-based test runners are deprecated. Test runners should be classes with a run_tests() method.',
DeprecationWarning
)
failures = TestRunner(['djangorestframework'])
test_runner = TestRunner()
if len(sys.argv) == 2:
test_case = '.' + sys.argv[1]
elif len(sys.argv) == 1:
test_case = ''
else:
test_runner = TestRunner()
if len(sys.argv) > 1:
test_case = '.' + sys.argv[1]
else:
test_case = ''
failures = test_runner.run_tests(['djangorestframework' + test_case])
print usage()
sys.exit(1)
failures = test_runner.run_tests(['djangorestframework' + test_case])
sys.exit(failures)

View File

@ -1,57 +1,65 @@
"""
Tests for the throttling implementations in the permissions module.
"""
import time
from django.conf.urls.defaults import patterns
from django.test import TestCase
from django.utils import simplejson as json
from django.contrib.auth.models import User
from django.core.cache import cache
from djangorestframework.compat import RequestFactory
from djangorestframework.views import View
from djangorestframework.permissions import PerUserThrottling, PerViewThrottling, PerResourceThrottling, ConfigurationException
from djangorestframework.permissions import PerUserThrottling, PerViewThrottling, PerResourceThrottling
from djangorestframework.resources import FormResource
class MockView(View):
permissions = ( PerUserThrottling, )
throttle = '3/sec' # 3 requests per second
throttle = '3/sec'
def get(self, request):
return 'foo'
class MockView1(MockView):
class MockView_PerViewThrottling(MockView):
permissions = ( PerViewThrottling, )
class MockView2(MockView):
class MockView_PerResourceThrottling(MockView):
permissions = ( PerResourceThrottling, )
#No resource set
class MockView3(MockView2):
resource = FormResource
class MockView_MinuteThrottling(MockView):
throttle = '3/min'
class ThrottlingTests(TestCase):
urls = 'djangorestframework.tests.throttling'
def setUp(self):
"""Reset the cache so that no throttles will be active"""
"""
Reset the cache so that no throttles will be active
"""
cache.clear()
self.factory = RequestFactory()
def test_requests_are_throttled(self):
"""Ensure request rate is limited"""
"""
Ensure request rate is limited
"""
request = self.factory.get('/')
for dummy in range(4):
response = MockView.as_view()(request)
self.assertEqual(503, response.status_code)
def set_throttle_timer(self, view, value):
"""
Explicitly set the timer, overriding time.time()
"""
view.permissions[0].timer = lambda self: value
def test_request_throttling_expires(self):
"""
Ensure request rate is limited for a limited duration only
"""
# Explicitly set the timer, overridding time.time()
MockView.permissions[0].timer = lambda self: 0
self.set_throttle_timer(MockView, 0)
request = self.factory.get('/')
for dummy in range(4):
@ -59,7 +67,7 @@ class ThrottlingTests(TestCase):
self.assertEqual(503, response.status_code)
# Advance the timer by one second
MockView.permissions[0].timer = lambda self: 1
self.set_throttle_timer(MockView, 1)
response = MockView.as_view()(request)
self.assertEqual(200, response.status_code)
@ -68,20 +76,73 @@ class ThrottlingTests(TestCase):
request = self.factory.get('/')
request.user = User.objects.create(username='a')
for dummy in range(3):
response = view.as_view()(request)
view.as_view()(request)
request.user = User.objects.create(username='b')
response = view.as_view()(request)
self.assertEqual(expect, response.status_code)
def test_request_throttling_is_per_user(self):
"""Ensure request rate is only limited per user, not globally for PerUserThrottles"""
"""
Ensure request rate is only limited per user, not globally for
PerUserThrottles
"""
self.ensure_is_throttled(MockView, 200)
def test_request_throttling_is_per_view(self):
"""Ensure request rate is limited globally per View for PerViewThrottles"""
self.ensure_is_throttled(MockView1, 503)
"""
Ensure request rate is limited globally per View for PerViewThrottles
"""
self.ensure_is_throttled(MockView_PerViewThrottling, 503)
def test_request_throttling_is_per_resource(self):
"""Ensure request rate is limited globally per Resource for PerResourceThrottles"""
self.ensure_is_throttled(MockView3, 503)
"""
Ensure request rate is limited globally per Resource for PerResourceThrottles
"""
self.ensure_is_throttled(MockView_PerResourceThrottling, 503)
def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
"""
Ensure the response returns an X-Throttle field with status and next attributes
set properly.
"""
request = self.factory.get('/')
for timer, expect in expected_headers:
self.set_throttle_timer(view, timer)
response = view.as_view()(request)
self.assertEquals(response['X-Throttle'], expect)
def test_seconds_fields(self):
"""
Ensure for second based throttles.
"""
self.ensure_response_header_contains_proper_throttle_field(MockView,
((0, 'status=SUCCESS; next=0.33 sec'),
(0, 'status=SUCCESS; next=0.50 sec'),
(0, 'status=SUCCESS; next=1.00 sec'),
(0, 'status=FAILURE; next=1.00 sec')
))
def test_minutes_fields(self):
"""
Ensure for minute based throttles.
"""
self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
((0, 'status=SUCCESS; next=20.00 sec'),
(0, 'status=SUCCESS; next=30.00 sec'),
(0, 'status=SUCCESS; next=60.00 sec'),
(0, 'status=FAILURE; next=60.00 sec')
))
def test_next_rate_remains_constant_if_followed(self):
"""
If a client follows the recommended next request rate,
the throttling rate should stay constant.
"""
self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
((0, 'status=SUCCESS; next=20.00 sec'),
(20, 'status=SUCCESS; next=20.00 sec'),
(40, 'status=SUCCESS; next=20.00 sec'),
(60, 'status=SUCCESS; next=20.00 sec'),
(80, 'status=SUCCESS; next=20.00 sec')
))

View File

@ -101,6 +101,14 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):
"""
pass
def add_header(self, field, value):
"""
Add *field* and *value* to the :attr:`headers` attribute of the :class:`View` class.
"""
self.headers[field] = value
# Note: session based authentication is explicitly CSRF validated,
# all other authentication is CSRF exempt.
@csrf_exempt
@ -108,6 +116,7 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):
self.request = request
self.args = args
self.kwargs = kwargs
self.headers = {}
# Calls to 'reverse' will not be fully qualified unless we set the scheme/host/port here.
prefix = '%s://%s' % (request.is_secure() and 'https' or 'http', request.get_host())
@ -150,6 +159,9 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):
response.headers['Allow'] = ', '.join(self.allowed_methods)
response.headers['Vary'] = 'Authenticate, Accept'
# merge with headers possibly set at some point in the view
response.headers.update(self.headers)
return self.render(response)

View File

@ -31,7 +31,7 @@ Resources
* The ``djangorestframework`` package is `available on PyPI <http://pypi.python.org/pypi/djangorestframework>`_.
* We have an active `discussion group <http://groups.google.com/group/django-rest-framework>`_ and a `project blog <http://blog.django-rest-framework.org>`_.
* Bug reports are handled on the `issue tracker <https://github.com/tomchristie/django-rest-framework/issues>`_.
* There is a `Jenkins CI server <http://datacenter.tibold.nl/job/djangorestframework/>`_ which tracks test status and coverage reporting. (Thanks Marko!)
* There is a `Jenkins CI server <http://jenkins.tibold.nl/job/djangorestframework/>`_ which tracks test status and coverage reporting. (Thanks Marko!)
Any and all questions, thoughts, bug reports and contributions are *hugely appreciated*.