Bits of cleaning up for the throttling

This commit is contained in:
Tom Christie 2011-06-14 11:08:29 +01:00
parent fb26b11a75
commit 323d52e7c4
2 changed files with 77 additions and 41 deletions

View File

@ -1,6 +1,6 @@
"""
The :mod:`permissions` module bundles a set of permission classes that are used
for checking if a request passes a certain set of constraints. You can assign a permision
for checking if a request passes a certain set of constraints. You can assign a permission
class to your view by setting your View's :attr:`permissions` class attribute.
"""
@ -26,14 +26,14 @@ _403_FORBIDDEN_RESPONSE = ErrorResponse(
{'detail': 'You do not have permission to access this resource. ' +
'You may need to login or otherwise authenticate the request.'})
_503_THROTTLED_RESPONSE = ErrorResponse(
_503_SERVICE_UNAVAILABLE = ErrorResponse(
status.HTTP_503_SERVICE_UNAVAILABLE,
{'detail': 'request was throttled'})
class ConfigurationException(BaseException):
"""To alert for bad configuration desicions as a convenience."""
pass
"""To alert for bad configuration decisions as a convenience."""
pass
class BasePermission(object):
@ -93,38 +93,56 @@ class IsUserOrIsAnonReadOnly(BasePermission):
self.view.method != 'HEAD'):
raise _403_FORBIDDEN_RESPONSE
class BaseThrottle(BasePermission):
"""
Rate throttling of requests.
The rate (requests / seconds) is set by a :attr:`throttle` attribute on the ``View`` class.
The attribute is a string of the form 'number of requests/period'. Period must be an element
of (sec, min, hour, day)
The rate (requests / seconds) is set by a :attr:`throttle` attribute
on the :class:`.View` class. The attribute is a string of the form 'number of
requests/period'.
Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')
Previous request information used for throttling is stored in the cache.
"""
attr_name = 'throttle'
default = '0/sec'
timer = time.time
def get_cache_key(self):
"""Should return the cache-key corresponding to the semantics of the class that implements
the throttling behaviour.
"""
Should return a unique cache-key which can be used for throttling.
Muse be overridden.
"""
pass
def check_permission(self, auth):
num, period = getattr(self.view, 'throttle', '0/sec').split('/')
"""
Check the throttling.
Return `None` or raise an :exc:`.ErrorResponse`.
"""
num, period = getattr(self.view, self.attr_name, self.default).split('/')
self.num_requests = int(num)
self.duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
self.auth = auth
self.check_throttle()
def check_throttle(self):
"""On success calls `throttle_success`. On failure calls `throttle_failure`. """
"""
Implement the check to see if the request should be throttled.
On success calls :meth:`throttle_success`.
On failure calls :meth:`throttle_failure`.
"""
self.key = self.get_cache_key()
self.history = cache.get(self.key, [])
self.now = time.time()
self.now = self.timer()
# Drop any requests from the history which have now passed the throttle duration
while self.history and self.history[0] < self.now - self.duration:
# Drop any requests from the history which have now passed the
# throttle duration
while self.history and self.history[0] <= self.now - self.duration:
self.history.pop()
if len(self.history) >= self.num_requests:
@ -133,18 +151,28 @@ class BaseThrottle(BasePermission):
self.throttle_success()
def throttle_success(self):
"""Inserts the current request's timesatmp along with the key into the cache."""
"""
Inserts the current request's timestamp along with the key
into the cache.
"""
self.history.insert(0, self.now)
cache.set(self.key, self.history, self.duration)
def throttle_failure(self):
"""Raises a 503 """
raise _503_THROTTLED_RESPONSE
"""
Called when a request to the API has failed due to throttling.
Raises a '503 service unavailable' response.
"""
raise _503_SERVICE_UNAVAILABLE
class PerUserThrottling(BaseThrottle):
"""
The user id will be used as a unique identifier if the user is authenticated.
For anonymous requests, the IP address of the client will be used.
Limits the rate of API calls that may be made by a given user.
The user id will be used as a unique identifier if the user is
authenticated. For anonymous requests, the IP address of the client will
be used.
"""
def get_cache_key(self):
@ -152,24 +180,29 @@ class PerUserThrottling(BaseThrottle):
ident = str(self.auth)
else:
ident = self.view.request.META.get('REMOTE_ADDR', None)
return 'throttle_%s' % ident
return 'throttle_user_%s' % ident
class PerViewThrottling(BaseThrottle):
"""
The class name of the cuurent view will be used as a unique identifier.
Limits the rate of API calls that may be used on a given view.
The class name of the view is used as a unique identifier to
throttle against.
"""
def get_cache_key(self):
return 'throttle_%s' % self.view.__class__.__name__
return 'throttle_view_%s' % self.view.__class__.__name__
class PerResourceThrottling(BaseThrottle):
"""
The class name of the cuurent resource will be used as a unique identifier.
Raises :exc:`ConfigurationException` if no resource attribute is set on the view class.
Limits the rate of API calls that may be used against all views on
a given resource.
The class name of the resource is used as a unique identifier to
throttle against.
"""
def get_cache_key(self):
if self.view.resource != None:
return 'throttle_%s' % self.view.resource.__class__.__name__
raise ConfigurationException(
"A per-resource throttle was set to a view that does not have a resource.")
return 'throttle_resource_%s' % self.view.resource.__class__.__name__

View File

@ -1,3 +1,6 @@
"""
Tests for the throttling implementations in the permissions module.
"""
import time
from django.conf.urls.defaults import patterns
@ -44,12 +47,20 @@ class ThrottlingTests(TestCase):
self.assertEqual(503, response.status_code)
def test_request_throttling_expires(self):
"""Ensure request rate is limited for a limited duration only"""
"""
Ensure request rate is limited for a limited duration only
"""
# Explicitly set the timer, overridding time.time()
MockView.permissions[0].timer = lambda self: 0
request = self.factory.get('/')
for dummy in range(4):
response = MockView.as_view()(request)
self.assertEqual(503, response.status_code)
time.sleep(1)
# Advance the timer by one second
MockView.permissions[0].timer = lambda self: 1
response = MockView.as_view()(request)
self.assertEqual(200, response.status_code)
@ -63,7 +74,7 @@ class ThrottlingTests(TestCase):
self.assertEqual(expect, response.status_code)
def test_request_throttling_is_per_user(self):
"""Ensure request rate is only limited per user, not globally for PerUserTrottles"""
"""Ensure request rate is only limited per user, not globally for PerUserThrottles"""
self.ensure_is_throttled(MockView, 200)
def test_request_throttling_is_per_view(self):
@ -73,12 +84,4 @@ class ThrottlingTests(TestCase):
def test_request_throttling_is_per_resource(self):
"""Ensure request rate is limited globally per Resource for PerResourceThrottles"""
self.ensure_is_throttled(MockView3, 503)
def test_raises_no_resource_found(self):
"""Ensure an Exception is raised when someone sets at per-resource throttle
on a view with no resource set."""
request = self.factory.get('/')
view = MockView2.as_view()
self.assertRaises(ConfigurationException, view, request)