mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-05-29 18:53:17 +03:00
* fixed test_request_throttling_is_per_user
- it didn't make a request for the 2nd user
* implemented per_resource_throttling + test needs refactoring
This commit is contained in:
parent
87db5fbda5
commit
f854bc9065
|
@ -122,3 +122,35 @@ class PerUserThrottling(BasePermission):
|
||||||
|
|
||||||
history.insert(0, now)
|
history.insert(0, now)
|
||||||
cache.set(key, history, duration)
|
cache.set(key, history, duration)
|
||||||
|
|
||||||
|
class PerResourceThrottling(BasePermission):
|
||||||
|
"""
|
||||||
|
Rate throttling of requests on a per-resource basis.
|
||||||
|
|
||||||
|
The rate (requests / seconds) is set by a :attr:`throttle` attribute on the ``View`` class.
|
||||||
|
The attribute is a two tuple of the form (number of requests, duration in seconds).
|
||||||
|
|
||||||
|
The user id will be used as a unique identifier if the user is authenticated.
|
||||||
|
For anonymous requests, the IP address of the client will be used.
|
||||||
|
|
||||||
|
Previous request information used for throttling is stored in the cache.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def check_permission(self, ignore):
|
||||||
|
(num_requests, duration) = getattr(self.view, 'throttle', (0, 0))
|
||||||
|
|
||||||
|
|
||||||
|
key = 'throttle_%s' % self.view.__class__.__name__
|
||||||
|
|
||||||
|
history = cache.get(key, [])
|
||||||
|
now = time.time()
|
||||||
|
|
||||||
|
# Drop any requests from the history which have now passed the throttle duration
|
||||||
|
while history and history[0] < now - duration:
|
||||||
|
history.pop()
|
||||||
|
|
||||||
|
if len(history) >= num_requests:
|
||||||
|
raise _503_THROTTLED_RESPONSE
|
||||||
|
|
||||||
|
history.insert(0, now)
|
||||||
|
cache.set(key, history, duration)
|
||||||
|
|
|
@ -8,7 +8,7 @@ from django.core.cache import cache
|
||||||
|
|
||||||
from djangorestframework.compat import RequestFactory
|
from djangorestframework.compat import RequestFactory
|
||||||
from djangorestframework.views import View
|
from djangorestframework.views import View
|
||||||
from djangorestframework.permissions import PerUserThrottling
|
from djangorestframework.permissions import PerUserThrottling, PerResourceThrottling
|
||||||
|
|
||||||
|
|
||||||
class MockView(View):
|
class MockView(View):
|
||||||
|
@ -18,8 +18,16 @@ class MockView(View):
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
return 'foo'
|
return 'foo'
|
||||||
|
|
||||||
|
class MockView1(View):
|
||||||
|
permissions = ( PerResourceThrottling, )
|
||||||
|
throttle = (3, 1) # 3 requests per second
|
||||||
|
|
||||||
|
def get(self, request):
|
||||||
|
return 'foo'
|
||||||
|
|
||||||
urlpatterns = patterns('',
|
urlpatterns = patterns('',
|
||||||
(r'^$', MockView.as_view()),
|
(r'^$', MockView.as_view()),
|
||||||
|
(r'^1$', MockView1.as_view()),
|
||||||
)
|
)
|
||||||
|
|
||||||
class ThrottlingTests(TestCase):
|
class ThrottlingTests(TestCase):
|
||||||
|
@ -37,7 +45,6 @@ class ThrottlingTests(TestCase):
|
||||||
self.assertEqual(503, response.status_code)
|
self.assertEqual(503, response.status_code)
|
||||||
|
|
||||||
def test_request_throttling_is_per_user(self):
|
def test_request_throttling_is_per_user(self):
|
||||||
#Can not login user.....Dunno why...
|
|
||||||
"""Ensure request rate is only limited per user, not globally"""
|
"""Ensure request rate is only limited per user, not globally"""
|
||||||
for username in ('testuser', 'another_testuser'):
|
for username in ('testuser', 'another_testuser'):
|
||||||
user = User.objects.create(username=username)
|
user = User.objects.create(username=username)
|
||||||
|
@ -49,7 +56,23 @@ class ThrottlingTests(TestCase):
|
||||||
response = self.client.get('/')
|
response = self.client.get('/')
|
||||||
self.client.logout()
|
self.client.logout()
|
||||||
self.assertTrue(self.client.login(username='another_testuser', password='test'), msg='Login failed')
|
self.assertTrue(self.client.login(username='another_testuser', password='test'), msg='Login failed')
|
||||||
|
response = self.client.get('/')
|
||||||
self.assertEqual(200, response.status_code)
|
self.assertEqual(200, response.status_code)
|
||||||
|
|
||||||
|
def test_request_throttling_is_per_resource(self):
|
||||||
|
"""Ensure request rate is limited globally per View"""
|
||||||
|
for username in ('testuser', 'another_testuser'):
|
||||||
|
user = User.objects.create(username=username)
|
||||||
|
user.set_password('test')
|
||||||
|
user.save()
|
||||||
|
|
||||||
|
self.assertTrue(self.client.login(username='testuser', password='test'), msg='Login Failed')
|
||||||
|
for dummy in range(3):
|
||||||
|
response = self.client.get('/1')
|
||||||
|
self.client.logout()
|
||||||
|
self.assertTrue(self.client.login(username='another_testuser', password='test'), msg='Login failed')
|
||||||
|
response = self.client.get('/1')
|
||||||
|
self.assertEqual(503, response.status_code)
|
||||||
|
|
||||||
def test_request_throttling_expires(self):
|
def test_request_throttling_expires(self):
|
||||||
"""Ensure request rate is limited for a limited duration only"""
|
"""Ensure request rate is limited for a limited duration only"""
|
||||||
|
|
Loading…
Reference in New Issue
Block a user