Merge throttling and fix up a coupla things

This commit is contained in:
Tom Christie 2011-06-15 14:41:09 +01:00
commit 1cb84cd4e8
5 changed files with 136 additions and 50 deletions

View File

@ -31,11 +31,6 @@ _503_SERVICE_UNAVAILABLE = ErrorResponse(
{'detail': 'request was throttled'}) {'detail': 'request was throttled'})
class ConfigurationException(BaseException):
"""To alert for bad configuration decisions as a convenience."""
pass
class BasePermission(object): class BasePermission(object):
""" """
A base class from which all permission classes should inherit. A base class from which all permission classes should inherit.
@ -142,9 +137,8 @@ class BaseThrottle(BasePermission):
# Drop any requests from the history which have now passed the # Drop any requests from the history which have now passed the
# throttle duration # throttle duration
while self.history and self.history[0] <= self.now - self.duration: while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop() self.history.pop()
if len(self.history) >= self.num_requests: if len(self.history) >= self.num_requests:
self.throttle_failure() self.throttle_failure()
else: else:
@ -157,14 +151,31 @@ class BaseThrottle(BasePermission):
""" """
self.history.insert(0, self.now) self.history.insert(0, self.now)
cache.set(self.key, self.history, self.duration) cache.set(self.key, self.history, self.duration)
header = 'status=SUCCESS; next=%s sec' % self.next()
self.view.add_header('X-Throttle', header)
def throttle_failure(self): def throttle_failure(self):
""" """
Called when a request to the API has failed due to throttling. Called when a request to the API has failed due to throttling.
Raises a '503 service unavailable' response. Raises a '503 service unavailable' response.
""" """
header = 'status=FAILURE; next=%s sec' % self.next()
self.view.add_header('X-Throttle', header)
raise _503_SERVICE_UNAVAILABLE raise _503_SERVICE_UNAVAILABLE
def next(self):
"""
Returns the recommended next request time in seconds.
"""
if self.history:
remaining_duration = self.duration - (self.now - self.history[-1])
else:
remaining_duration = self.duration
available_requests = self.num_requests - len(self.history) + 1
return '%.2f' % (remaining_duration / float(available_requests))
class PerUserThrottling(BaseThrottle): class PerUserThrottling(BaseThrottle):
""" """

View File

@ -13,25 +13,27 @@ os.environ['DJANGO_SETTINGS_MODULE'] = 'djangorestframework.runtests.settings'
from django.conf import settings from django.conf import settings
from django.test.utils import get_runner from django.test.utils import get_runner
def usage():
return """
Usage: python runtests.py [UnitTestClass].[method]
You can pass the Class name of the `UnitTestClass` you want to test.
Append a method name if you only want to test a specific method of that class.
"""
def main(): def main():
TestRunner = get_runner(settings) TestRunner = get_runner(settings)
if hasattr(TestRunner, 'func_name'): test_runner = TestRunner()
# Pre 1.2 test runners were just functions, if len(sys.argv) == 2:
# and did not support the 'failfast' option. test_case = '.' + sys.argv[1]
import warnings elif len(sys.argv) == 1:
warnings.warn( test_case = ''
'Function-based test runners are deprecated. Test runners should be classes with a run_tests() method.',
DeprecationWarning
)
failures = TestRunner(['djangorestframework'])
else: else:
test_runner = TestRunner() print usage()
if len(sys.argv) > 1: sys.exit(1)
test_case = '.' + sys.argv[1] failures = test_runner.run_tests(['djangorestframework' + test_case])
else:
test_case = ''
failures = test_runner.run_tests(['djangorestframework' + test_case])
sys.exit(failures) sys.exit(failures)

View File

@ -1,57 +1,65 @@
""" """
Tests for the throttling implementations in the permissions module. Tests for the throttling implementations in the permissions module.
""" """
import time
from django.conf.urls.defaults import patterns
from django.test import TestCase from django.test import TestCase
from django.utils import simplejson as json
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.core.cache import cache from django.core.cache import cache
from djangorestframework.compat import RequestFactory from djangorestframework.compat import RequestFactory
from djangorestframework.views import View from djangorestframework.views import View
from djangorestframework.permissions import PerUserThrottling, PerViewThrottling, PerResourceThrottling, ConfigurationException from djangorestframework.permissions import PerUserThrottling, PerViewThrottling, PerResourceThrottling
from djangorestframework.resources import FormResource from djangorestframework.resources import FormResource
class MockView(View): class MockView(View):
permissions = ( PerUserThrottling, ) permissions = ( PerUserThrottling, )
throttle = '3/sec' # 3 requests per second throttle = '3/sec'
def get(self, request): def get(self, request):
return 'foo' return 'foo'
class MockView1(MockView): class MockView_PerViewThrottling(MockView):
permissions = ( PerViewThrottling, ) permissions = ( PerViewThrottling, )
class MockView2(MockView): class MockView_PerResourceThrottling(MockView):
permissions = ( PerResourceThrottling, ) permissions = ( PerResourceThrottling, )
#No resource set
class MockView3(MockView2):
resource = FormResource resource = FormResource
class MockView_MinuteThrottling(MockView):
throttle = '3/min'
class ThrottlingTests(TestCase): class ThrottlingTests(TestCase):
urls = 'djangorestframework.tests.throttling' urls = 'djangorestframework.tests.throttling'
def setUp(self): def setUp(self):
"""Reset the cache so that no throttles will be active""" """
Reset the cache so that no throttles will be active
"""
cache.clear() cache.clear()
self.factory = RequestFactory() self.factory = RequestFactory()
def test_requests_are_throttled(self): def test_requests_are_throttled(self):
"""Ensure request rate is limited""" """
Ensure request rate is limited
"""
request = self.factory.get('/') request = self.factory.get('/')
for dummy in range(4): for dummy in range(4):
response = MockView.as_view()(request) response = MockView.as_view()(request)
self.assertEqual(503, response.status_code) self.assertEqual(503, response.status_code)
def set_throttle_timer(self, view, value):
"""
Explicitly set the timer, overriding time.time()
"""
view.permissions[0].timer = lambda self: value
def test_request_throttling_expires(self): def test_request_throttling_expires(self):
""" """
Ensure request rate is limited for a limited duration only Ensure request rate is limited for a limited duration only
""" """
# Explicitly set the timer, overridding time.time() self.set_throttle_timer(MockView, 0)
MockView.permissions[0].timer = lambda self: 0
request = self.factory.get('/') request = self.factory.get('/')
for dummy in range(4): for dummy in range(4):
@ -59,7 +67,7 @@ class ThrottlingTests(TestCase):
self.assertEqual(503, response.status_code) self.assertEqual(503, response.status_code)
# Advance the timer by one second # Advance the timer by one second
MockView.permissions[0].timer = lambda self: 1 self.set_throttle_timer(MockView, 1)
response = MockView.as_view()(request) response = MockView.as_view()(request)
self.assertEqual(200, response.status_code) self.assertEqual(200, response.status_code)
@ -68,20 +76,73 @@ class ThrottlingTests(TestCase):
request = self.factory.get('/') request = self.factory.get('/')
request.user = User.objects.create(username='a') request.user = User.objects.create(username='a')
for dummy in range(3): for dummy in range(3):
response = view.as_view()(request) view.as_view()(request)
request.user = User.objects.create(username='b') request.user = User.objects.create(username='b')
response = view.as_view()(request) response = view.as_view()(request)
self.assertEqual(expect, response.status_code) self.assertEqual(expect, response.status_code)
def test_request_throttling_is_per_user(self): def test_request_throttling_is_per_user(self):
"""Ensure request rate is only limited per user, not globally for PerUserThrottles""" """
Ensure request rate is only limited per user, not globally for
PerUserThrottles
"""
self.ensure_is_throttled(MockView, 200) self.ensure_is_throttled(MockView, 200)
def test_request_throttling_is_per_view(self): def test_request_throttling_is_per_view(self):
"""Ensure request rate is limited globally per View for PerViewThrottles""" """
self.ensure_is_throttled(MockView1, 503) Ensure request rate is limited globally per View for PerViewThrottles
"""
self.ensure_is_throttled(MockView_PerViewThrottling, 503)
def test_request_throttling_is_per_resource(self): def test_request_throttling_is_per_resource(self):
"""Ensure request rate is limited globally per Resource for PerResourceThrottles""" """
self.ensure_is_throttled(MockView3, 503) Ensure request rate is limited globally per Resource for PerResourceThrottles
"""
self.ensure_is_throttled(MockView_PerResourceThrottling, 503)
def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
"""
Ensure the response returns an X-Throttle field with status and next attributes
set properly.
"""
request = self.factory.get('/')
for timer, expect in expected_headers:
self.set_throttle_timer(view, timer)
response = view.as_view()(request)
self.assertEquals(response['X-Throttle'], expect)
def test_seconds_fields(self):
"""
Ensure for second based throttles.
"""
self.ensure_response_header_contains_proper_throttle_field(MockView,
((0, 'status=SUCCESS; next=0.33 sec'),
(0, 'status=SUCCESS; next=0.50 sec'),
(0, 'status=SUCCESS; next=1.00 sec'),
(0, 'status=FAILURE; next=1.00 sec')
))
def test_minutes_fields(self):
"""
Ensure for minute based throttles.
"""
self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
((0, 'status=SUCCESS; next=20.00 sec'),
(0, 'status=SUCCESS; next=30.00 sec'),
(0, 'status=SUCCESS; next=60.00 sec'),
(0, 'status=FAILURE; next=60.00 sec')
))
def test_next_rate_remains_constant_if_followed(self):
"""
If a client follows the recommended next request rate,
the throttling rate should stay constant.
"""
self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
((0, 'status=SUCCESS; next=20.00 sec'),
(20, 'status=SUCCESS; next=20.00 sec'),
(40, 'status=SUCCESS; next=20.00 sec'),
(60, 'status=SUCCESS; next=20.00 sec'),
(80, 'status=SUCCESS; next=20.00 sec')
))

View File

@ -101,6 +101,14 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):
""" """
pass pass
def add_header(self, field, value):
"""
Add *field* and *value* to the :attr:`headers` attribute of the :class:`View` class.
"""
self.headers[field] = value
# Note: session based authentication is explicitly CSRF validated, # Note: session based authentication is explicitly CSRF validated,
# all other authentication is CSRF exempt. # all other authentication is CSRF exempt.
@csrf_exempt @csrf_exempt
@ -108,6 +116,7 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):
self.request = request self.request = request
self.args = args self.args = args
self.kwargs = kwargs self.kwargs = kwargs
self.headers = {}
# Calls to 'reverse' will not be fully qualified unless we set the scheme/host/port here. # Calls to 'reverse' will not be fully qualified unless we set the scheme/host/port here.
prefix = '%s://%s' % (request.is_secure() and 'https' or 'http', request.get_host()) prefix = '%s://%s' % (request.is_secure() and 'https' or 'http', request.get_host())
@ -150,6 +159,9 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):
response.headers['Allow'] = ', '.join(self.allowed_methods) response.headers['Allow'] = ', '.join(self.allowed_methods)
response.headers['Vary'] = 'Authenticate, Accept' response.headers['Vary'] = 'Authenticate, Accept'
# merge with headers possibly set at some point in the view
response.headers.update(self.headers)
return self.render(response) return self.render(response)

View File

@ -31,7 +31,7 @@ Resources
* The ``djangorestframework`` package is `available on PyPI <http://pypi.python.org/pypi/djangorestframework>`_. * The ``djangorestframework`` package is `available on PyPI <http://pypi.python.org/pypi/djangorestframework>`_.
* We have an active `discussion group <http://groups.google.com/group/django-rest-framework>`_ and a `project blog <http://blog.django-rest-framework.org>`_. * We have an active `discussion group <http://groups.google.com/group/django-rest-framework>`_ and a `project blog <http://blog.django-rest-framework.org>`_.
* Bug reports are handled on the `issue tracker <https://github.com/tomchristie/django-rest-framework/issues>`_. * Bug reports are handled on the `issue tracker <https://github.com/tomchristie/django-rest-framework/issues>`_.
* There is a `Jenkins CI server <http://datacenter.tibold.nl/job/djangorestframework/>`_ which tracks test status and coverage reporting. (Thanks Marko!) * There is a `Jenkins CI server <http://jenkins.tibold.nl/job/djangorestframework/>`_ which tracks test status and coverage reporting. (Thanks Marko!)
Any and all questions, thoughts, bug reports and contributions are *hugely appreciated*. Any and all questions, thoughts, bug reports and contributions are *hugely appreciated*.