diff --git a/djangorestframework/permissions.py b/djangorestframework/permissions.py index a9b3f08b8..b3fd212b8 100644 --- a/djangorestframework/permissions.py +++ b/djangorestframework/permissions.py @@ -122,3 +122,35 @@ class PerUserThrottling(BasePermission): history.insert(0, now) cache.set(key, history, duration) + +class PerResourceThrottling(BasePermission): + """ + Rate throttling of requests on a per-resource basis. + + The rate (requests / seconds) is set by a :attr:`throttle` attribute on the ``View`` class. + The attribute is a two tuple of the form (number of requests, duration in seconds). + + The user id will be used as a unique identifier if the user is authenticated. + For anonymous requests, the IP address of the client will be used. + + Previous request information used for throttling is stored in the cache. + """ + + def check_permission(self, ignore): + (num_requests, duration) = getattr(self.view, 'throttle', (0, 0)) + + + key = 'throttle_%s' % self.view.__class__.__name__ + + history = cache.get(key, []) + now = time.time() + + # Drop any requests from the history which have now passed the throttle duration + while history and history[0] < now - duration: + history.pop() + + if len(history) >= num_requests: + raise _503_THROTTLED_RESPONSE + + history.insert(0, now) + cache.set(key, history, duration) diff --git a/djangorestframework/tests/throttling.py b/djangorestframework/tests/throttling.py index b33836a24..a9e6803be 100644 --- a/djangorestframework/tests/throttling.py +++ b/djangorestframework/tests/throttling.py @@ -8,7 +8,7 @@ from django.core.cache import cache from djangorestframework.compat import RequestFactory from djangorestframework.views import View -from djangorestframework.permissions import PerUserThrottling +from djangorestframework.permissions import PerUserThrottling, PerResourceThrottling class MockView(View): @@ -18,8 +18,16 @@ class MockView(View): def get(self, request): return 'foo' +class MockView1(View): + permissions = ( PerResourceThrottling, ) + throttle = (3, 1) # 3 requests per second + + def get(self, request): + return 'foo' + urlpatterns = patterns('', (r'^$', MockView.as_view()), + (r'^1$', MockView1.as_view()), ) class ThrottlingTests(TestCase): @@ -37,7 +45,6 @@ class ThrottlingTests(TestCase): self.assertEqual(503, response.status_code) def test_request_throttling_is_per_user(self): - #Can not login user.....Dunno why... """Ensure request rate is only limited per user, not globally""" for username in ('testuser', 'another_testuser'): user = User.objects.create(username=username) @@ -49,7 +56,23 @@ class ThrottlingTests(TestCase): response = self.client.get('/') self.client.logout() self.assertTrue(self.client.login(username='another_testuser', password='test'), msg='Login failed') + response = self.client.get('/') self.assertEqual(200, response.status_code) + + def test_request_throttling_is_per_resource(self): + """Ensure request rate is limited globally per View""" + for username in ('testuser', 'another_testuser'): + user = User.objects.create(username=username) + user.set_password('test') + user.save() + + self.assertTrue(self.client.login(username='testuser', password='test'), msg='Login Failed') + for dummy in range(3): + response = self.client.get('/1') + self.client.logout() + self.assertTrue(self.client.login(username='another_testuser', password='test'), msg='Login failed') + response = self.client.get('/1') + self.assertEqual(503, response.status_code) def test_request_throttling_expires(self): """Ensure request rate is limited for a limited duration only"""