mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-11 04:07:39 +03:00
implemeneted #28
This commit is contained in:
parent
1720c44904
commit
437a062b6c
|
@ -31,11 +31,6 @@ _503_SERVICE_UNAVAILABLE = ErrorResponse(
|
|||
{'detail': 'request was throttled'})
|
||||
|
||||
|
||||
class ConfigurationException(BaseException):
|
||||
"""To alert for bad configuration decisions as a convenience."""
|
||||
pass
|
||||
|
||||
|
||||
class BasePermission(object):
|
||||
"""
|
||||
A base class from which all permission classes should inherit.
|
||||
|
@ -144,12 +139,11 @@ class BaseThrottle(BasePermission):
|
|||
# throttle duration
|
||||
while self.history and self.history[0] <= self.now - self.duration:
|
||||
self.history.pop()
|
||||
|
||||
if len(self.history) >= self.num_requests:
|
||||
self.throttle_failure()
|
||||
else:
|
||||
self.throttle_success()
|
||||
|
||||
|
||||
def throttle_success(self):
|
||||
"""
|
||||
Inserts the current request's timestamp along with the key
|
||||
|
@ -157,15 +151,23 @@ class BaseThrottle(BasePermission):
|
|||
"""
|
||||
self.history.insert(0, self.now)
|
||||
cache.set(self.key, self.history, self.duration)
|
||||
|
||||
self.view.add_header('X-Throttle', 'status=SUCCESS; next=%s sec' % self.next())
|
||||
|
||||
def throttle_failure(self):
|
||||
"""
|
||||
Called when a request to the API has failed due to throttling.
|
||||
Raises a '503 service unavailable' response.
|
||||
"""
|
||||
self.view.add_header('X-Throttle', 'status=FAILURE; next=%s sec' % self.next())
|
||||
raise _503_SERVICE_UNAVAILABLE
|
||||
|
||||
|
||||
|
||||
def next(self):
|
||||
"""
|
||||
Returns the recommended next request time in seconds.
|
||||
"""
|
||||
return '%.2f' % (self.duration / (self.num_requests - len(self.history) *1.0 + 1))
|
||||
|
||||
|
||||
class PerUserThrottling(BaseThrottle):
|
||||
"""
|
||||
Limits the rate of API calls that may be made by a given user.
|
||||
|
|
|
@ -1,17 +1,14 @@
|
|||
"""
|
||||
Tests for the throttling implementations in the permissions module.
|
||||
"""
|
||||
import time
|
||||
|
||||
from django.conf.urls.defaults import patterns
|
||||
from django.test import TestCase
|
||||
from django.utils import simplejson as json
|
||||
from django.contrib.auth.models import User
|
||||
from django.core.cache import cache
|
||||
|
||||
from djangorestframework.compat import RequestFactory
|
||||
from djangorestframework.views import View
|
||||
from djangorestframework.permissions import PerUserThrottling, PerViewThrottling, PerResourceThrottling, ConfigurationException
|
||||
from djangorestframework.permissions import PerUserThrottling, PerViewThrottling, PerResourceThrottling
|
||||
from djangorestframework.resources import FormResource
|
||||
|
||||
class MockView(View):
|
||||
|
@ -30,28 +27,40 @@ class MockView2(MockView):
|
|||
|
||||
class MockView3(MockView2):
|
||||
resource = FormResource
|
||||
|
||||
class MockView4(MockView):
|
||||
throttle = '3/min' # 3 request per minute
|
||||
|
||||
class ThrottlingTests(TestCase):
|
||||
urls = 'djangorestframework.tests.throttling'
|
||||
|
||||
def setUp(self):
|
||||
"""Reset the cache so that no throttles will be active"""
|
||||
"""
|
||||
Reset the cache so that no throttles will be active
|
||||
"""
|
||||
cache.clear()
|
||||
self.factory = RequestFactory()
|
||||
|
||||
def test_requests_are_throttled(self):
|
||||
"""Ensure request rate is limited"""
|
||||
"""
|
||||
Ensure request rate is limited
|
||||
"""
|
||||
request = self.factory.get('/')
|
||||
for dummy in range(4):
|
||||
response = MockView.as_view()(request)
|
||||
self.assertEqual(503, response.status_code)
|
||||
|
||||
def set_throttle_timer(self, view, value):
|
||||
"""
|
||||
Explicitly set the timer, overriding time.time()
|
||||
"""
|
||||
view.permissions[0].timer = lambda self: value
|
||||
|
||||
def test_request_throttling_expires(self):
|
||||
"""
|
||||
Ensure request rate is limited for a limited duration only
|
||||
"""
|
||||
# Explicitly set the timer, overridding time.time()
|
||||
MockView.permissions[0].timer = lambda self: 0
|
||||
self.set_throttle_timer(MockView, 0)
|
||||
|
||||
request = self.factory.get('/')
|
||||
for dummy in range(4):
|
||||
|
@ -59,7 +68,7 @@ class ThrottlingTests(TestCase):
|
|||
self.assertEqual(503, response.status_code)
|
||||
|
||||
# Advance the timer by one second
|
||||
MockView.permissions[0].timer = lambda self: 1
|
||||
self.set_throttle_timer(MockView, 1)
|
||||
|
||||
response = MockView.as_view()(request)
|
||||
self.assertEqual(200, response.status_code)
|
||||
|
@ -68,20 +77,61 @@ class ThrottlingTests(TestCase):
|
|||
request = self.factory.get('/')
|
||||
request.user = User.objects.create(username='a')
|
||||
for dummy in range(3):
|
||||
response = view.as_view()(request)
|
||||
view.as_view()(request)
|
||||
request.user = User.objects.create(username='b')
|
||||
response = view.as_view()(request)
|
||||
self.assertEqual(expect, response.status_code)
|
||||
|
||||
def test_request_throttling_is_per_user(self):
|
||||
"""Ensure request rate is only limited per user, not globally for PerUserThrottles"""
|
||||
"""
|
||||
Ensure request rate is only limited per user, not globally for
|
||||
PerUserThrottles
|
||||
"""
|
||||
self.ensure_is_throttled(MockView, 200)
|
||||
|
||||
def test_request_throttling_is_per_view(self):
|
||||
"""Ensure request rate is limited globally per View for PerViewThrottles"""
|
||||
"""
|
||||
Ensure request rate is limited globally per View for PerViewThrottles
|
||||
"""
|
||||
self.ensure_is_throttled(MockView1, 503)
|
||||
|
||||
def test_request_throttling_is_per_resource(self):
|
||||
"""Ensure request rate is limited globally per Resource for PerResourceThrottles"""
|
||||
"""
|
||||
Ensure request rate is limited globally per Resource for PerResourceThrottles
|
||||
"""
|
||||
self.ensure_is_throttled(MockView3, 503)
|
||||
|
||||
|
||||
|
||||
def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
|
||||
"""
|
||||
Ensure the response returns an X-Throttle field with status and next attributes
|
||||
set properly.
|
||||
"""
|
||||
request = self.factory.get('/')
|
||||
for expect in expected_headers:
|
||||
self.set_throttle_timer(view, 0)
|
||||
response = view.as_view()(request)
|
||||
self.assertEquals(response['X-Throttle'], expect)
|
||||
|
||||
def test_seconds_fields(self):
|
||||
"""
|
||||
Ensure for second based throttles.
|
||||
"""
|
||||
self.ensure_response_header_contains_proper_throttle_field(MockView,
|
||||
('status=SUCCESS; next=0.33 sec',
|
||||
'status=SUCCESS; next=0.50 sec',
|
||||
'status=SUCCESS; next=1.00 sec',
|
||||
'status=FAILURE; next=1.00 sec'
|
||||
))
|
||||
|
||||
def test_minutes_fields(self):
|
||||
"""
|
||||
Ensure for minute based throttles.
|
||||
"""
|
||||
self.ensure_response_header_contains_proper_throttle_field(MockView4,
|
||||
('status=SUCCESS; next=20.00 sec',
|
||||
'status=SUCCESS; next=30.00 sec',
|
||||
'status=SUCCESS; next=60.00 sec',
|
||||
'status=FAILURE; next=60.00 sec'
|
||||
))
|
||||
|
|
@ -64,7 +64,11 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):
|
|||
"""
|
||||
permissions = ( permissions.FullAnonAccess, )
|
||||
|
||||
|
||||
"""
|
||||
Headers to be sent with response.
|
||||
"""
|
||||
headers = {}
|
||||
|
||||
@classmethod
|
||||
def as_view(cls, **initkwargs):
|
||||
"""
|
||||
|
@ -101,6 +105,12 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):
|
|||
"""
|
||||
pass
|
||||
|
||||
def add_header(self, field, value):
|
||||
"""
|
||||
Add *field* and *value* to the :attr:`headers` attribute of the :class:`View` class.
|
||||
"""
|
||||
self.headers[field] = value
|
||||
|
||||
# Note: session based authentication is explicitly CSRF validated,
|
||||
# all other authentication is CSRF exempt.
|
||||
@csrf_exempt
|
||||
|
@ -149,7 +159,10 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):
|
|||
# also it's currently sub-obtimal for HTTP caching - need to sort that out.
|
||||
response.headers['Allow'] = ', '.join(self.allowed_methods)
|
||||
response.headers['Vary'] = 'Authenticate, Accept'
|
||||
|
||||
|
||||
# merge with headers possibly set by a Throttle class
|
||||
response.headers = dict(response.headers.items() + self.headers.items())
|
||||
|
||||
return self.render(response)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user