implemeneted #28

This commit is contained in:
markotibold 2011-06-13 20:42:37 +02:00
parent 1720c44904
commit 437a062b6c
3 changed files with 91 additions and 26 deletions

View File

@ -31,11 +31,6 @@ _503_SERVICE_UNAVAILABLE = ErrorResponse(
{'detail': 'request was throttled'}) {'detail': 'request was throttled'})
class ConfigurationException(BaseException):
"""To alert for bad configuration decisions as a convenience."""
pass
class BasePermission(object): class BasePermission(object):
""" """
A base class from which all permission classes should inherit. A base class from which all permission classes should inherit.
@ -144,7 +139,6 @@ class BaseThrottle(BasePermission):
# throttle duration # throttle duration
while self.history and self.history[0] <= self.now - self.duration: while self.history and self.history[0] <= self.now - self.duration:
self.history.pop() self.history.pop()
if len(self.history) >= self.num_requests: if len(self.history) >= self.num_requests:
self.throttle_failure() self.throttle_failure()
else: else:
@ -157,14 +151,22 @@ class BaseThrottle(BasePermission):
""" """
self.history.insert(0, self.now) self.history.insert(0, self.now)
cache.set(self.key, self.history, self.duration) cache.set(self.key, self.history, self.duration)
self.view.add_header('X-Throttle', 'status=SUCCESS; next=%s sec' % self.next())
def throttle_failure(self): def throttle_failure(self):
""" """
Called when a request to the API has failed due to throttling. Called when a request to the API has failed due to throttling.
Raises a '503 service unavailable' response. Raises a '503 service unavailable' response.
""" """
self.view.add_header('X-Throttle', 'status=FAILURE; next=%s sec' % self.next())
raise _503_SERVICE_UNAVAILABLE raise _503_SERVICE_UNAVAILABLE
def next(self):
"""
Returns the recommended next request time in seconds.
"""
return '%.2f' % (self.duration / (self.num_requests - len(self.history) *1.0 + 1))
class PerUserThrottling(BaseThrottle): class PerUserThrottling(BaseThrottle):
""" """

View File

@ -1,17 +1,14 @@
""" """
Tests for the throttling implementations in the permissions module. Tests for the throttling implementations in the permissions module.
""" """
import time
from django.conf.urls.defaults import patterns
from django.test import TestCase from django.test import TestCase
from django.utils import simplejson as json
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.core.cache import cache from django.core.cache import cache
from djangorestframework.compat import RequestFactory from djangorestframework.compat import RequestFactory
from djangorestframework.views import View from djangorestframework.views import View
from djangorestframework.permissions import PerUserThrottling, PerViewThrottling, PerResourceThrottling, ConfigurationException from djangorestframework.permissions import PerUserThrottling, PerViewThrottling, PerResourceThrottling
from djangorestframework.resources import FormResource from djangorestframework.resources import FormResource
class MockView(View): class MockView(View):
@ -31,27 +28,39 @@ class MockView2(MockView):
class MockView3(MockView2): class MockView3(MockView2):
resource = FormResource resource = FormResource
class MockView4(MockView):
throttle = '3/min' # 3 request per minute
class ThrottlingTests(TestCase): class ThrottlingTests(TestCase):
urls = 'djangorestframework.tests.throttling' urls = 'djangorestframework.tests.throttling'
def setUp(self): def setUp(self):
"""Reset the cache so that no throttles will be active""" """
Reset the cache so that no throttles will be active
"""
cache.clear() cache.clear()
self.factory = RequestFactory() self.factory = RequestFactory()
def test_requests_are_throttled(self): def test_requests_are_throttled(self):
"""Ensure request rate is limited""" """
Ensure request rate is limited
"""
request = self.factory.get('/') request = self.factory.get('/')
for dummy in range(4): for dummy in range(4):
response = MockView.as_view()(request) response = MockView.as_view()(request)
self.assertEqual(503, response.status_code) self.assertEqual(503, response.status_code)
def set_throttle_timer(self, view, value):
"""
Explicitly set the timer, overriding time.time()
"""
view.permissions[0].timer = lambda self: value
def test_request_throttling_expires(self): def test_request_throttling_expires(self):
""" """
Ensure request rate is limited for a limited duration only Ensure request rate is limited for a limited duration only
""" """
# Explicitly set the timer, overridding time.time() self.set_throttle_timer(MockView, 0)
MockView.permissions[0].timer = lambda self: 0
request = self.factory.get('/') request = self.factory.get('/')
for dummy in range(4): for dummy in range(4):
@ -59,7 +68,7 @@ class ThrottlingTests(TestCase):
self.assertEqual(503, response.status_code) self.assertEqual(503, response.status_code)
# Advance the timer by one second # Advance the timer by one second
MockView.permissions[0].timer = lambda self: 1 self.set_throttle_timer(MockView, 1)
response = MockView.as_view()(request) response = MockView.as_view()(request)
self.assertEqual(200, response.status_code) self.assertEqual(200, response.status_code)
@ -68,20 +77,61 @@ class ThrottlingTests(TestCase):
request = self.factory.get('/') request = self.factory.get('/')
request.user = User.objects.create(username='a') request.user = User.objects.create(username='a')
for dummy in range(3): for dummy in range(3):
response = view.as_view()(request) view.as_view()(request)
request.user = User.objects.create(username='b') request.user = User.objects.create(username='b')
response = view.as_view()(request) response = view.as_view()(request)
self.assertEqual(expect, response.status_code) self.assertEqual(expect, response.status_code)
def test_request_throttling_is_per_user(self): def test_request_throttling_is_per_user(self):
"""Ensure request rate is only limited per user, not globally for PerUserThrottles""" """
Ensure request rate is only limited per user, not globally for
PerUserThrottles
"""
self.ensure_is_throttled(MockView, 200) self.ensure_is_throttled(MockView, 200)
def test_request_throttling_is_per_view(self): def test_request_throttling_is_per_view(self):
"""Ensure request rate is limited globally per View for PerViewThrottles""" """
Ensure request rate is limited globally per View for PerViewThrottles
"""
self.ensure_is_throttled(MockView1, 503) self.ensure_is_throttled(MockView1, 503)
def test_request_throttling_is_per_resource(self): def test_request_throttling_is_per_resource(self):
"""Ensure request rate is limited globally per Resource for PerResourceThrottles""" """
Ensure request rate is limited globally per Resource for PerResourceThrottles
"""
self.ensure_is_throttled(MockView3, 503) self.ensure_is_throttled(MockView3, 503)
def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
"""
Ensure the response returns an X-Throttle field with status and next attributes
set properly.
"""
request = self.factory.get('/')
for expect in expected_headers:
self.set_throttle_timer(view, 0)
response = view.as_view()(request)
self.assertEquals(response['X-Throttle'], expect)
def test_seconds_fields(self):
"""
Ensure for second based throttles.
"""
self.ensure_response_header_contains_proper_throttle_field(MockView,
('status=SUCCESS; next=0.33 sec',
'status=SUCCESS; next=0.50 sec',
'status=SUCCESS; next=1.00 sec',
'status=FAILURE; next=1.00 sec'
))
def test_minutes_fields(self):
"""
Ensure for minute based throttles.
"""
self.ensure_response_header_contains_proper_throttle_field(MockView4,
('status=SUCCESS; next=20.00 sec',
'status=SUCCESS; next=30.00 sec',
'status=SUCCESS; next=60.00 sec',
'status=FAILURE; next=60.00 sec'
))

View File

@ -64,6 +64,10 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):
""" """
permissions = ( permissions.FullAnonAccess, ) permissions = ( permissions.FullAnonAccess, )
"""
Headers to be sent with response.
"""
headers = {}
@classmethod @classmethod
def as_view(cls, **initkwargs): def as_view(cls, **initkwargs):
@ -101,6 +105,12 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):
""" """
pass pass
def add_header(self, field, value):
"""
Add *field* and *value* to the :attr:`headers` attribute of the :class:`View` class.
"""
self.headers[field] = value
# Note: session based authentication is explicitly CSRF validated, # Note: session based authentication is explicitly CSRF validated,
# all other authentication is CSRF exempt. # all other authentication is CSRF exempt.
@csrf_exempt @csrf_exempt
@ -150,6 +160,9 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):
response.headers['Allow'] = ', '.join(self.allowed_methods) response.headers['Allow'] = ', '.join(self.allowed_methods)
response.headers['Vary'] = 'Authenticate, Accept' response.headers['Vary'] = 'Authenticate, Accept'
# merge with headers possibly set by a Throttle class
response.headers = dict(response.headers.items() + self.headers.items())
return self.render(response) return self.render(response)