mirror of
				https://github.com/encode/django-rest-framework.git
				synced 2025-11-04 09:57:55 +03:00 
			
		
		
		
	Always call all throttling classes on the view when checking throttles (#6711)
This commit is contained in:
		
							parent
							
								
									19ca86d8d6
								
							
						
					
					
						commit
						afb678433b
					
				| 
						 | 
					@ -350,9 +350,13 @@ class APIView(View):
 | 
				
			||||||
        Check if request should be throttled.
 | 
					        Check if request should be throttled.
 | 
				
			||||||
        Raises an appropriate exception if the request is throttled.
 | 
					        Raises an appropriate exception if the request is throttled.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
 | 
					        throttle_durations = []
 | 
				
			||||||
        for throttle in self.get_throttles():
 | 
					        for throttle in self.get_throttles():
 | 
				
			||||||
            if not throttle.allow_request(request, self):
 | 
					            if not throttle.allow_request(request, self):
 | 
				
			||||||
                self.throttled(request, throttle.wait())
 | 
					                throttle_durations.append(throttle.wait())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if throttle_durations:
 | 
				
			||||||
 | 
					            self.throttled(request, max(throttle_durations))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def determine_version(self, request, *args, **kwargs):
 | 
					    def determine_version(self, request, *args, **kwargs):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -30,6 +30,11 @@ class User3MinRateThrottle(UserRateThrottle):
 | 
				
			||||||
    scope = 'minutes'
 | 
					    scope = 'minutes'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class User6MinRateThrottle(UserRateThrottle):
 | 
				
			||||||
 | 
					    rate = '6/min'
 | 
				
			||||||
 | 
					    scope = 'minutes'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class NonTimeThrottle(BaseThrottle):
 | 
					class NonTimeThrottle(BaseThrottle):
 | 
				
			||||||
    def allow_request(self, request, view):
 | 
					    def allow_request(self, request, view):
 | 
				
			||||||
        if not hasattr(self.__class__, 'called'):
 | 
					        if not hasattr(self.__class__, 'called'):
 | 
				
			||||||
| 
						 | 
					@ -38,6 +43,13 @@ class NonTimeThrottle(BaseThrottle):
 | 
				
			||||||
        return False
 | 
					        return False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class MockView_DoubleThrottling(APIView):
 | 
				
			||||||
 | 
					    throttle_classes = (User3SecRateThrottle, User6MinRateThrottle,)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get(self, request):
 | 
				
			||||||
 | 
					        return Response('foo')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class MockView(APIView):
 | 
					class MockView(APIView):
 | 
				
			||||||
    throttle_classes = (User3SecRateThrottle,)
 | 
					    throttle_classes = (User3SecRateThrottle,)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -80,7 +92,8 @@ class ThrottlingTests(TestCase):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Explicitly set the timer, overriding time.time()
 | 
					        Explicitly set the timer, overriding time.time()
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        view.throttle_classes[0].timer = lambda self: value
 | 
					        for cls in view.throttle_classes:
 | 
				
			||||||
 | 
					            cls.timer = lambda self: value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_request_throttling_expires(self):
 | 
					    def test_request_throttling_expires(self):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
| 
						 | 
					@ -115,6 +128,37 @@ class ThrottlingTests(TestCase):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        self.ensure_is_throttled(MockView, 200)
 | 
					        self.ensure_is_throttled(MockView, 200)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_request_throttling_multiple_throttles(self):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Ensure all throttle classes see each request even when the request is
 | 
				
			||||||
 | 
					        already being throttled
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        self.set_throttle_timer(MockView_DoubleThrottling, 0)
 | 
				
			||||||
 | 
					        request = self.factory.get('/')
 | 
				
			||||||
 | 
					        for dummy in range(4):
 | 
				
			||||||
 | 
					            response = MockView_DoubleThrottling.as_view()(request)
 | 
				
			||||||
 | 
					        assert response.status_code == 429
 | 
				
			||||||
 | 
					        assert int(response['retry-after']) == 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # At this point our client made 4 requests (one was throttled) in a
 | 
				
			||||||
 | 
					        # second. If we advance the timer by one additional second, the client
 | 
				
			||||||
 | 
					        # should be allowed to make 2 more before being throttled by the 2nd
 | 
				
			||||||
 | 
					        # throttle class, which has a limit of 6 per minute.
 | 
				
			||||||
 | 
					        self.set_throttle_timer(MockView_DoubleThrottling, 1)
 | 
				
			||||||
 | 
					        for dummy in range(2):
 | 
				
			||||||
 | 
					            response = MockView_DoubleThrottling.as_view()(request)
 | 
				
			||||||
 | 
					            assert response.status_code == 200
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        response = MockView_DoubleThrottling.as_view()(request)
 | 
				
			||||||
 | 
					        assert response.status_code == 429
 | 
				
			||||||
 | 
					        assert int(response['retry-after']) == 59
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Just to make sure check again after two more seconds.
 | 
				
			||||||
 | 
					        self.set_throttle_timer(MockView_DoubleThrottling, 2)
 | 
				
			||||||
 | 
					        response = MockView_DoubleThrottling.as_view()(request)
 | 
				
			||||||
 | 
					        assert response.status_code == 429
 | 
				
			||||||
 | 
					        assert int(response['retry-after']) == 58
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
 | 
					    def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Ensure the response returns an Retry-After field with status and next attributes
 | 
					        Ensure the response returns an Retry-After field with status and next attributes
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user