mirror of
				https://github.com/encode/django-rest-framework.git
				synced 2025-11-04 01:47:59 +03:00 
			
		
		
		
	* fixed test_request_throttling_is_per_user - it didn't make a request for the 2nd user
				
					
				
			* implemented per_resource_throttling + test needs refactoring
This commit is contained in:
		
							parent
							
								
									87db5fbda5
								
							
						
					
					
						commit
						f854bc9065
					
				| 
						 | 
					@ -122,3 +122,35 @@ class PerUserThrottling(BasePermission):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        history.insert(0, now)
 | 
					        history.insert(0, now)
 | 
				
			||||||
        cache.set(key, history, duration)
 | 
					        cache.set(key, history, duration)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class PerResourceThrottling(BasePermission):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Rate throttling of requests on a per-resource basis.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    The rate (requests / seconds) is set by a :attr:`throttle` attribute on the ``View`` class.
 | 
				
			||||||
 | 
					    The attribute is a two tuple of the form (number of requests, duration in seconds).
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    The user id will be used as a unique identifier if the user is authenticated.
 | 
				
			||||||
 | 
					    For anonymous requests, the IP address of the client will be used.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Previous request information used for throttling is stored in the cache.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def check_permission(self, ignore):
 | 
				
			||||||
 | 
					        (num_requests, duration) = getattr(self.view, 'throttle', (0, 0))
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        key = 'throttle_%s' % self.view.__class__.__name__
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        history = cache.get(key, [])
 | 
				
			||||||
 | 
					        now = time.time()
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        # Drop any requests from the history which have now passed the throttle duration
 | 
				
			||||||
 | 
					        while history and history[0] < now - duration:
 | 
				
			||||||
 | 
					            history.pop()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if len(history) >= num_requests:
 | 
				
			||||||
 | 
					            raise _503_THROTTLED_RESPONSE
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        history.insert(0, now)
 | 
				
			||||||
 | 
					        cache.set(key, history, duration)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -8,7 +8,7 @@ from django.core.cache import cache
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from djangorestframework.compat import RequestFactory
 | 
					from djangorestframework.compat import RequestFactory
 | 
				
			||||||
from djangorestframework.views import View
 | 
					from djangorestframework.views import View
 | 
				
			||||||
from djangorestframework.permissions import PerUserThrottling
 | 
					from djangorestframework.permissions import PerUserThrottling, PerResourceThrottling
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class MockView(View):
 | 
					class MockView(View):
 | 
				
			||||||
| 
						 | 
					@ -18,8 +18,16 @@ class MockView(View):
 | 
				
			||||||
    def get(self, request):
 | 
					    def get(self, request):
 | 
				
			||||||
        return 'foo'
 | 
					        return 'foo'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class MockView1(View):
 | 
				
			||||||
 | 
					    permissions = ( PerResourceThrottling, )
 | 
				
			||||||
 | 
					    throttle = (3, 1) # 3 requests per second
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get(self, request):
 | 
				
			||||||
 | 
					        return 'foo'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
urlpatterns = patterns('',
 | 
					urlpatterns = patterns('',
 | 
				
			||||||
    (r'^$', MockView.as_view()),
 | 
					    (r'^$', MockView.as_view()),
 | 
				
			||||||
 | 
					    (r'^1$', MockView1.as_view()),
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ThrottlingTests(TestCase):
 | 
					class ThrottlingTests(TestCase):
 | 
				
			||||||
| 
						 | 
					@ -37,7 +45,6 @@ class ThrottlingTests(TestCase):
 | 
				
			||||||
        self.assertEqual(503, response.status_code)
 | 
					        self.assertEqual(503, response.status_code)
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
    def test_request_throttling_is_per_user(self):
 | 
					    def test_request_throttling_is_per_user(self):
 | 
				
			||||||
        #Can not login user.....Dunno why...
 | 
					 | 
				
			||||||
        """Ensure request rate is only limited per user, not globally"""
 | 
					        """Ensure request rate is only limited per user, not globally"""
 | 
				
			||||||
        for username in ('testuser', 'another_testuser'):
 | 
					        for username in ('testuser', 'another_testuser'):
 | 
				
			||||||
            user = User.objects.create(username=username)
 | 
					            user = User.objects.create(username=username)
 | 
				
			||||||
| 
						 | 
					@ -49,8 +56,24 @@ class ThrottlingTests(TestCase):
 | 
				
			||||||
            response = self.client.get('/')
 | 
					            response = self.client.get('/')
 | 
				
			||||||
        self.client.logout()
 | 
					        self.client.logout()
 | 
				
			||||||
        self.assertTrue(self.client.login(username='another_testuser', password='test'), msg='Login failed')
 | 
					        self.assertTrue(self.client.login(username='another_testuser', password='test'), msg='Login failed')
 | 
				
			||||||
 | 
					        response = self.client.get('/')
 | 
				
			||||||
        self.assertEqual(200, response.status_code)
 | 
					        self.assertEqual(200, response.status_code)
 | 
				
			||||||
 
 | 
					 
 | 
				
			||||||
 | 
					    def test_request_throttling_is_per_resource(self):
 | 
				
			||||||
 | 
					        """Ensure request rate is limited globally per View"""
 | 
				
			||||||
 | 
					        for username in ('testuser', 'another_testuser'):
 | 
				
			||||||
 | 
					            user = User.objects.create(username=username)
 | 
				
			||||||
 | 
					            user.set_password('test')
 | 
				
			||||||
 | 
					            user.save()
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        self.assertTrue(self.client.login(username='testuser', password='test'), msg='Login Failed')
 | 
				
			||||||
 | 
					        for dummy in range(3):
 | 
				
			||||||
 | 
					            response = self.client.get('/1')
 | 
				
			||||||
 | 
					        self.client.logout()
 | 
				
			||||||
 | 
					        self.assertTrue(self.client.login(username='another_testuser', password='test'), msg='Login failed')
 | 
				
			||||||
 | 
					        response = self.client.get('/1')
 | 
				
			||||||
 | 
					        self.assertEqual(503, response.status_code)
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
    def test_request_throttling_expires(self):
 | 
					    def test_request_throttling_expires(self):
 | 
				
			||||||
        """Ensure request rate is limited for a limited duration only"""
 | 
					        """Ensure request rate is limited for a limited duration only"""
 | 
				
			||||||
        for dummy in range(3):
 | 
					        for dummy in range(3):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user