mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-26 03:23:59 +03:00
* implemented Tom's nice config string for the trotlle rate e.g. '3/sec'
* We now have per-user, per-view and per-resource throttling * Added a new exxception class as a convenience to detect pointless throttles * refactored
This commit is contained in:
parent
f0b3b9d7ea
commit
5be359fb29
|
@ -29,6 +29,10 @@ _503_THROTTLED_RESPONSE = ErrorResponse(
|
|||
{'detail': 'request was throttled'})
|
||||
|
||||
|
||||
class ConfigurationException(BaseException):
|
||||
"""To alert for bad configuration desicions as a convenience."""
|
||||
pass
|
||||
|
||||
|
||||
class BasePermission(object):
|
||||
"""
|
||||
|
@ -87,70 +91,83 @@ class IsUserOrIsAnonReadOnly(BasePermission):
|
|||
self.view.method != 'HEAD'):
|
||||
raise _403_FORBIDDEN_RESPONSE
|
||||
|
||||
|
||||
class PerUserThrottling(BasePermission):
|
||||
class BaseThrottle(BasePermission):
|
||||
"""
|
||||
Rate throttling of requests on a per-user basis.
|
||||
Rate throttling of requests.
|
||||
|
||||
The rate (requests / seconds) is set by a :attr:`throttle` attribute on the ``View`` class.
|
||||
The attribute is a two tuple of the form (number of requests, duration in seconds).
|
||||
|
||||
The user id will be used as a unique identifier if the user is authenticated.
|
||||
For anonymous requests, the IP address of the client will be used.
|
||||
The attribute is a string of the form 'number of requests/period'. Period must be an element
|
||||
of (sec, min, hour, day)
|
||||
|
||||
Previous request information used for throttling is stored in the cache.
|
||||
"""
|
||||
|
||||
def check_permission(self, user):
|
||||
(num_requests, duration) = getattr(self.view, 'throttle', (0, 0))
|
||||
def get_cache_key(self):
|
||||
"""Should return the cache-key corresponding to the semantics of the class that implements
|
||||
the throttling behaviour.
|
||||
"""
|
||||
pass
|
||||
|
||||
if user.is_authenticated():
|
||||
ident = str(user)
|
||||
def check_permission(self, auth):
|
||||
num, period = getattr(self.view, 'throttle', '0/sec').split('/')
|
||||
self.num_requests = int(num)
|
||||
self.duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
|
||||
self.auth = auth
|
||||
self.check_throttle()
|
||||
|
||||
def check_throttle(self):
|
||||
"""On success calls `throttle_success`. On failure calls `throttle_failure`. """
|
||||
self.key = self.get_cache_key()
|
||||
self.history = cache.get(self.key, [])
|
||||
self.now = time.time()
|
||||
|
||||
# Drop any requests from the history which have now passed the throttle duration
|
||||
while self.history and self.history[0] < self.now - self.duration:
|
||||
self.history.pop()
|
||||
|
||||
if len(self.history) >= self.num_requests:
|
||||
self.throttle_failure()
|
||||
else:
|
||||
self.throttle_success()
|
||||
|
||||
def throttle_success(self):
|
||||
"""Inserts the current request's timesatmp along with the key into the cache."""
|
||||
self.history.insert(0, self.now)
|
||||
cache.set(self.key, self.history, self.duration)
|
||||
|
||||
def throttle_failure(self):
|
||||
"""Raises a 503 """
|
||||
raise _503_THROTTLED_RESPONSE
|
||||
|
||||
class PerUserThrottling(BaseThrottle):
|
||||
"""
|
||||
The user id will be used as a unique identifier if the user is authenticated.
|
||||
For anonymous requests, the IP address of the client will be used.
|
||||
"""
|
||||
|
||||
def get_cache_key(self):
|
||||
if self.auth.is_authenticated():
|
||||
ident = str(self.auth)
|
||||
else:
|
||||
ident = self.view.request.META.get('REMOTE_ADDR', None)
|
||||
return 'throttle_%s' % ident
|
||||
|
||||
key = 'throttle_%s' % ident
|
||||
history = cache.get(key, [])
|
||||
now = time.time()
|
||||
|
||||
# Drop any requests from the history which have now passed the throttle duration
|
||||
while history and history[0] < now - duration:
|
||||
history.pop()
|
||||
|
||||
if len(history) >= num_requests:
|
||||
raise _503_THROTTLED_RESPONSE
|
||||
|
||||
history.insert(0, now)
|
||||
cache.set(key, history, duration)
|
||||
|
||||
class PerResourceThrottling(BasePermission):
|
||||
class PerViewThrottling(BaseThrottle):
|
||||
"""
|
||||
Rate throttling of requests on a per-resource basis.
|
||||
|
||||
The rate (requests / seconds) is set by a :attr:`throttle` attribute on the ``View`` class.
|
||||
The attribute is a two tuple of the form (number of requests, duration in seconds).
|
||||
|
||||
The user id will be used as a unique identifier if the user is authenticated.
|
||||
For anonymous requests, the IP address of the client will be used.
|
||||
|
||||
Previous request information used for throttling is stored in the cache.
|
||||
The class name of the cuurent view will be used as a unique identifier.
|
||||
"""
|
||||
|
||||
def check_permission(self, ignore):
|
||||
(num_requests, duration) = getattr(self.view, 'throttle', (0, 0))
|
||||
def get_cache_key(self):
|
||||
return 'throttle_%s' % self.view.__class__.__name__
|
||||
|
||||
class PerResourceThrottling(BaseThrottle):
|
||||
"""
|
||||
The class name of the cuurent resource will be used as a unique identifier.
|
||||
Raises :exc:`ConfigurationException` if no resource attribute is set on the view class.
|
||||
"""
|
||||
|
||||
key = 'throttle_%s' % self.view.__class__.__name__
|
||||
|
||||
history = cache.get(key, [])
|
||||
now = time.time()
|
||||
|
||||
# Drop any requests from the history which have now passed the throttle duration
|
||||
while history and history[0] < now - duration:
|
||||
history.pop()
|
||||
|
||||
if len(history) >= num_requests:
|
||||
raise _503_THROTTLED_RESPONSE
|
||||
|
||||
history.insert(0, now)
|
||||
cache.set(key, history, duration)
|
||||
def get_cache_key(self):
|
||||
if self.view.resource != None:
|
||||
return 'throttle_%s' % self.view.resource.__class__.__name__
|
||||
raise ConfigurationException(
|
||||
"A per-resource throttle was set to a view that does not have a resource.")
|
Loading…
Reference in New Issue
Block a user