implemeneted #28

This commit is contained in:
markotibold 2011-06-13 20:42:37 +02:00
parent 1720c44904
commit 437a062b6c
3 changed files with 91 additions and 26 deletions

View File

@ -31,11 +31,6 @@ _503_SERVICE_UNAVAILABLE = ErrorResponse(
{'detail': 'request was throttled'})
class ConfigurationException(BaseException):
"""To alert for bad configuration decisions as a convenience."""
pass
class BasePermission(object):
"""
A base class from which all permission classes should inherit.
@ -144,12 +139,11 @@ class BaseThrottle(BasePermission):
# throttle duration
while self.history and self.history[0] <= self.now - self.duration:
self.history.pop()
if len(self.history) >= self.num_requests:
self.throttle_failure()
else:
self.throttle_success()
def throttle_success(self):
"""
Inserts the current request's timestamp along with the key
@ -157,15 +151,23 @@ class BaseThrottle(BasePermission):
"""
self.history.insert(0, self.now)
cache.set(self.key, self.history, self.duration)
self.view.add_header('X-Throttle', 'status=SUCCESS; next=%s sec' % self.next())
def throttle_failure(self):
"""
Called when a request to the API has failed due to throttling.
Raises a '503 service unavailable' response.
"""
self.view.add_header('X-Throttle', 'status=FAILURE; next=%s sec' % self.next())
raise _503_SERVICE_UNAVAILABLE
def next(self):
"""
Returns the recommended next request time in seconds.
"""
return '%.2f' % (self.duration / (self.num_requests - len(self.history) *1.0 + 1))
class PerUserThrottling(BaseThrottle):
"""
Limits the rate of API calls that may be made by a given user.

View File

@ -1,17 +1,14 @@
"""
Tests for the throttling implementations in the permissions module.
"""
import time
from django.conf.urls.defaults import patterns
from django.test import TestCase
from django.utils import simplejson as json
from django.contrib.auth.models import User
from django.core.cache import cache
from djangorestframework.compat import RequestFactory
from djangorestframework.views import View
from djangorestframework.permissions import PerUserThrottling, PerViewThrottling, PerResourceThrottling, ConfigurationException
from djangorestframework.permissions import PerUserThrottling, PerViewThrottling, PerResourceThrottling
from djangorestframework.resources import FormResource
class MockView(View):
@ -30,28 +27,40 @@ class MockView2(MockView):
class MockView3(MockView2):
resource = FormResource
class MockView4(MockView):
throttle = '3/min' # 3 request per minute
class ThrottlingTests(TestCase):
urls = 'djangorestframework.tests.throttling'
def setUp(self):
"""Reset the cache so that no throttles will be active"""
"""
Reset the cache so that no throttles will be active
"""
cache.clear()
self.factory = RequestFactory()
def test_requests_are_throttled(self):
"""Ensure request rate is limited"""
"""
Ensure request rate is limited
"""
request = self.factory.get('/')
for dummy in range(4):
response = MockView.as_view()(request)
self.assertEqual(503, response.status_code)
def set_throttle_timer(self, view, value):
"""
Explicitly set the timer, overriding time.time()
"""
view.permissions[0].timer = lambda self: value
def test_request_throttling_expires(self):
"""
Ensure request rate is limited for a limited duration only
"""
# Explicitly set the timer, overridding time.time()
MockView.permissions[0].timer = lambda self: 0
self.set_throttle_timer(MockView, 0)
request = self.factory.get('/')
for dummy in range(4):
@ -59,7 +68,7 @@ class ThrottlingTests(TestCase):
self.assertEqual(503, response.status_code)
# Advance the timer by one second
MockView.permissions[0].timer = lambda self: 1
self.set_throttle_timer(MockView, 1)
response = MockView.as_view()(request)
self.assertEqual(200, response.status_code)
@ -68,20 +77,61 @@ class ThrottlingTests(TestCase):
request = self.factory.get('/')
request.user = User.objects.create(username='a')
for dummy in range(3):
response = view.as_view()(request)
view.as_view()(request)
request.user = User.objects.create(username='b')
response = view.as_view()(request)
self.assertEqual(expect, response.status_code)
def test_request_throttling_is_per_user(self):
"""Ensure request rate is only limited per user, not globally for PerUserThrottles"""
"""
Ensure request rate is only limited per user, not globally for
PerUserThrottles
"""
self.ensure_is_throttled(MockView, 200)
def test_request_throttling_is_per_view(self):
"""Ensure request rate is limited globally per View for PerViewThrottles"""
"""
Ensure request rate is limited globally per View for PerViewThrottles
"""
self.ensure_is_throttled(MockView1, 503)
def test_request_throttling_is_per_resource(self):
"""Ensure request rate is limited globally per Resource for PerResourceThrottles"""
"""
Ensure request rate is limited globally per Resource for PerResourceThrottles
"""
self.ensure_is_throttled(MockView3, 503)
def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
"""
Ensure the response returns an X-Throttle field with status and next attributes
set properly.
"""
request = self.factory.get('/')
for expect in expected_headers:
self.set_throttle_timer(view, 0)
response = view.as_view()(request)
self.assertEquals(response['X-Throttle'], expect)
def test_seconds_fields(self):
"""
Ensure for second based throttles.
"""
self.ensure_response_header_contains_proper_throttle_field(MockView,
('status=SUCCESS; next=0.33 sec',
'status=SUCCESS; next=0.50 sec',
'status=SUCCESS; next=1.00 sec',
'status=FAILURE; next=1.00 sec'
))
def test_minutes_fields(self):
"""
Ensure for minute based throttles.
"""
self.ensure_response_header_contains_proper_throttle_field(MockView4,
('status=SUCCESS; next=20.00 sec',
'status=SUCCESS; next=30.00 sec',
'status=SUCCESS; next=60.00 sec',
'status=FAILURE; next=60.00 sec'
))

View File

@ -64,7 +64,11 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):
"""
permissions = ( permissions.FullAnonAccess, )
"""
Headers to be sent with response.
"""
headers = {}
@classmethod
def as_view(cls, **initkwargs):
"""
@ -101,6 +105,12 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):
"""
pass
def add_header(self, field, value):
"""
Add *field* and *value* to the :attr:`headers` attribute of the :class:`View` class.
"""
self.headers[field] = value
# Note: session based authentication is explicitly CSRF validated,
# all other authentication is CSRF exempt.
@csrf_exempt
@ -149,7 +159,10 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):
# also it's currently sub-obtimal for HTTP caching - need to sort that out.
response.headers['Allow'] = ', '.join(self.allowed_methods)
response.headers['Vary'] = 'Authenticate, Accept'
# merge with headers possibly set by a Throttle class
response.headers = dict(response.headers.items() + self.headers.items())
return self.render(response)