mirror of
				https://github.com/encode/django-rest-framework.git
				synced 2025-10-31 16:07:38 +03:00 
			
		
		
		
	
		
			
				
	
	
		
			516 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			516 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """
 | |
| Tests for the throttle classes in rest_framework.throttling.
 | |
| """
 | |
| 
 | |
| import pytest
 | |
| from django.contrib.auth.models import User
 | |
| from django.core.cache import cache
 | |
| from django.core.exceptions import ImproperlyConfigured
 | |
| from django.http import HttpRequest
 | |
| from django.test import TestCase
 | |
| 
 | |
| from rest_framework.request import Request
 | |
| from rest_framework.response import Response
 | |
| from rest_framework.settings import api_settings
 | |
| from rest_framework.test import APIRequestFactory, force_authenticate
 | |
| from rest_framework.throttling import (
 | |
|     AnonRateThrottle, BaseThrottle, ScopedRateThrottle, SimpleRateThrottle,
 | |
|     UserRateThrottle
 | |
| )
 | |
| from rest_framework.views import APIView
 | |
| 
 | |
| 
 | |
class User3SecRateThrottle(UserRateThrottle):
    """
    UserRateThrottle fixture limited to '3/sec' under the 'seconds' scope.
    """

    scope = 'seconds'
    rate = '3/sec'
 | |
| 
 | |
| 
 | |
class User3MinRateThrottle(UserRateThrottle):
    """
    UserRateThrottle fixture limited to '3/min' under the 'minutes' scope.
    """

    scope = 'minutes'
    rate = '3/min'
 | |
| 
 | |
| 
 | |
class User6MinRateThrottle(UserRateThrottle):
    """
    UserRateThrottle fixture limited to '6/min' under the 'minutes' scope.
    """

    scope = 'minutes'
    rate = '6/min'
 | |
| 
 | |
| 
 | |
class NonTimeThrottle(BaseThrottle):
    """
    Throttle fixture that allows exactly one request per class.

    The first call records a ``called`` flag on the class itself (not on the
    instance), so every later call, from any instance, is refused.
    """

    def allow_request(self, request, view):
        throttle_cls = type(self)
        if hasattr(throttle_cls, 'called'):
            return False
        # First request ever seen by this class: remember it and let it pass.
        throttle_cls.called = True
        return True
 | |
| 
 | |
| 
 | |
class MockView_DoubleThrottling(APIView):
    # View fixture guarded by two throttles at once: '3/sec' and '6/min'.
    throttle_classes = (
        User3SecRateThrottle,
        User6MinRateThrottle,
    )

    def get(self, request):
        body = 'foo'
        return Response(body)
 | |
| 
 | |
| 
 | |
class MockView(APIView):
    # View fixture guarded by the single '3/sec' throttle.
    throttle_classes = (
        User3SecRateThrottle,
    )

    def get(self, request):
        body = 'foo'
        return Response(body)
 | |
| 
 | |
| 
 | |
class MockView_MinuteThrottling(APIView):
    # View fixture guarded by the single '3/min' throttle.
    throttle_classes = (
        User3MinRateThrottle,
    )

    def get(self, request):
        body = 'foo'
        return Response(body)
 | |
| 
 | |
| 
 | |
class MockView_NonTimeThrottling(APIView):
    # View fixture guarded by the one-shot NonTimeThrottle.
    throttle_classes = (
        NonTimeThrottle,
    )

    def get(self, request):
        body = 'foo'
        return Response(body)
 | |
| 
 | |
| 
 | |
class ThrottlingTests(TestCase):
    """
    Tests that drive the mock views and check the throttled responses
    (status code and ``Retry-After`` header).
    """

    def setUp(self):
        """
        Reset the cache so that no throttles will be active
        """
        cache.clear()
        self.factory = APIRequestFactory()

    @staticmethod
    def _discard_class_attr(cls, name):
        """
        Remove ``name`` from ``cls`` if it is defined directly on that class.
        """
        if name in vars(cls):
            delattr(cls, name)

    def _restore_timer_after_test(self, throttle_class):
        """
        Register a cleanup that puts ``throttle_class.timer`` back as it was.

        Cleanups run in reverse registration order, so when the timer is
        patched several times in one test the earliest registration runs
        last and reinstates the state from before the test.
        """
        missing = object()
        previous = vars(throttle_class).get('timer', missing)
        if previous is missing:
            # The class did not define its own timer; drop the override.
            self.addCleanup(self._discard_class_attr, throttle_class, 'timer')
        else:
            self.addCleanup(setattr, throttle_class, 'timer', previous)

    def test_requests_are_throttled(self):
        """
        Ensure request rate is limited
        """
        request = self.factory.get('/')
        for _ in range(4):
            response = MockView.as_view()(request)
        assert response.status_code == 429

    def set_throttle_timer(self, view, value):
        """
        Explicitly set the timer, overriding time.time()

        The override is installed on every throttle class of ``view`` and is
        undone automatically when the current test finishes, so the fake
        clock cannot leak into other tests.
        """
        for cls in view.throttle_classes:
            self._restore_timer_after_test(cls)
            cls.timer = lambda self: value

    def test_request_throttling_expires(self):
        """
        Ensure request rate is limited for a limited duration only
        """
        self.set_throttle_timer(MockView, 0)

        request = self.factory.get('/')
        for _ in range(4):
            response = MockView.as_view()(request)
        assert response.status_code == 429

        # Advance the timer by one second
        self.set_throttle_timer(MockView, 1)

        response = MockView.as_view()(request)
        assert response.status_code == 200

    def ensure_is_throttled(self, view, expect):
        """
        Send three requests as user 'a', then one as user 'b', and assert
        that the last response has status code ``expect``.
        """
        request = self.factory.get('/')
        request.user = User.objects.create(username='a')
        for _ in range(3):
            view.as_view()(request)
        request.user = User.objects.create(username='b')
        response = view.as_view()(request)
        assert response.status_code == expect

    def test_request_throttling_is_per_user(self):
        """
        Ensure request rate is only limited per user, not globally for
        PerUserThrottles
        """
        self.ensure_is_throttled(MockView, 200)

    def test_request_throttling_multiple_throttles(self):
        """
        Ensure all throttle classes see each request even when the request is
        already being throttled
        """
        self.set_throttle_timer(MockView_DoubleThrottling, 0)
        request = self.factory.get('/')
        for _ in range(4):
            response = MockView_DoubleThrottling.as_view()(request)
        assert response.status_code == 429
        assert int(response['retry-after']) == 1

        # At this point our client made 4 requests (one was throttled) in a
        # second. If we advance the timer by one additional second, the client
        # should be allowed to make 2 more before being throttled by the 2nd
        # throttle class, which has a limit of 6 per minute.
        self.set_throttle_timer(MockView_DoubleThrottling, 1)
        for _ in range(2):
            response = MockView_DoubleThrottling.as_view()(request)
            assert response.status_code == 200

        response = MockView_DoubleThrottling.as_view()(request)
        assert response.status_code == 429
        assert int(response['retry-after']) == 59

        # Just to make sure check again after two more seconds.
        self.set_throttle_timer(MockView_DoubleThrottling, 2)
        response = MockView_DoubleThrottling.as_view()(request)
        assert response.status_code == 429
        assert int(response['retry-after']) == 58

    def test_throttle_rate_change_negative(self):
        """
        Ensure a throttled client stays throttled, with the same Retry-After
        value, after the per-second throttle's rate is lowered.
        """
        self.set_throttle_timer(MockView_DoubleThrottling, 0)
        request = self.factory.get('/')
        for _ in range(24):
            response = MockView_DoubleThrottling.as_view()(request)
        assert response.status_code == 429
        assert int(response['retry-after']) == 60

        previous_rate = User3SecRateThrottle.rate
        try:
            User3SecRateThrottle.rate = '1/sec'

            for _ in range(24):
                response = MockView_DoubleThrottling.as_view()(request)

            assert response.status_code == 429
            assert int(response['retry-after']) == 60
        finally:
            # reset
            User3SecRateThrottle.rate = previous_rate

    def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
        """
        Ensure each response carries the expected Retry-After header.

        ``expected_headers`` is an iterable of ``(timer, expect)`` pairs. For
        each pair the throttle timer is set to ``timer`` and one request is
        made; ``expect`` is the exact header value, or ``None`` when the
        header must be absent.
        """
        request = self.factory.get('/')
        for timer, expect in expected_headers:
            self.set_throttle_timer(view, timer)
            response = view.as_view()(request)
            if expect is not None:
                assert response['Retry-After'] == expect
            else:
                assert 'Retry-After' not in response

    def test_seconds_fields(self):
        """
        Ensure for second based throttles.
        """
        self.ensure_response_header_contains_proper_throttle_field(
            MockView, (
                (0, None),
                (0, None),
                (0, None),
                (0, '1')
            )
        )

    def test_minutes_fields(self):
        """
        Ensure for minute based throttles.
        """
        self.ensure_response_header_contains_proper_throttle_field(
            MockView_MinuteThrottling, (
                (0, None),
                (0, None),
                (0, None),
                (0, '60')
            )
        )

    def test_next_rate_remains_constant_if_followed(self):
        """
        If a client follows the recommended next request rate,
        the throttling rate should stay constant.
        """
        self.ensure_response_header_contains_proper_throttle_field(
            MockView_MinuteThrottling, (
                (0, None),
                (20, None),
                (40, None),
                (60, None),
                (80, None)
            )
        )

    def test_non_time_throttle(self):
        """
        Ensure no Retry-After header is sent for a throttle that is not
        time based.
        """
        request = self.factory.get('/')
        throttle_class = MockView_NonTimeThrottling.throttle_classes[0]

        # The throttle stores its flag on the class; remove it afterwards so
        # the precondition below also holds if this test runs again.
        self.addCleanup(self._discard_class_attr, throttle_class, 'called')

        self.assertFalse(hasattr(throttle_class, 'called'))

        response = MockView_NonTimeThrottling.as_view()(request)
        self.assertNotIn('Retry-After', response)

        self.assertTrue(throttle_class.called)

        response = MockView_NonTimeThrottling.as_view()(request)
        self.assertNotIn('Retry-After', response)
 | |
| 
 | |
| 
 | |
class ScopedRateThrottleTests(TestCase):
    """
    Tests for ScopedRateThrottle.
    """

    def setUp(self):
        # Start from an empty cache, as the other test classes in this module
        # do, so throttle history stored by an earlier test or run cannot
        # affect the counts asserted here (the fake clock restarts at 0).
        cache.clear()
        self.throttle = ScopedRateThrottle()

        class XYScopedRateThrottle(ScopedRateThrottle):
            # Fake clock; advanced explicitly through increment_timer().
            TIMER_SECONDS = 0
            THROTTLE_RATES = {'x': '3/min', 'y': '1/min'}

            def timer(self):
                return self.TIMER_SECONDS

        class XView(APIView):
            throttle_classes = (XYScopedRateThrottle,)
            throttle_scope = 'x'

            def get(self, request):
                return Response('x')

        class YView(APIView):
            throttle_classes = (XYScopedRateThrottle,)
            throttle_scope = 'y'

            def get(self, request):
                return Response('y')

        class UnscopedView(APIView):
            throttle_classes = (XYScopedRateThrottle,)

            def get(self, request):
                return Response('y')

        self.throttle_class = XYScopedRateThrottle
        self.factory = APIRequestFactory()
        self.x_view = XView.as_view()
        self.y_view = YView.as_view()
        self.unscoped_view = UnscopedView.as_view()

    def increment_timer(self, seconds=1):
        """
        Advance the fake clock shared by all XYScopedRateThrottle instances.
        """
        self.throttle_class.TIMER_SECONDS += seconds

    def test_scoped_rate_throttle(self):
        """
        Ensure the 'x' and 'y' scopes are limited independently and that the
        limits apply again after the rest of the minute has passed.
        """
        request = self.factory.get('/')
        x_view, y_view = self.x_view, self.y_view

        # Each step: seconds to advance the clock before the request, the
        # view to hit, and the status code that request must produce.
        steps = (
            # Should be able to hit x view 3 times per minute.
            (0, x_view, 200),
            (1, x_view, 200),
            (1, x_view, 200),
            (1, x_view, 429),
            # Should be able to hit y view 1 time per minute.
            (1, y_view, 200),
            (1, y_view, 429),
            # Ensure throttles properly reset by advancing the rest of the
            # minute. Should still be able to hit x view 3 times per minute.
            (55, x_view, 200),
            (1, x_view, 200),
            (1, x_view, 200),
            (1, x_view, 429),
            # Should still be able to hit y view 1 time per minute.
            (1, y_view, 200),
            (1, y_view, 429),
        )
        for advance, view, expected_status in steps:
            self.increment_timer(advance)
            response = view(request)
            assert response.status_code == expected_status

    def test_unscoped_view_not_throttled(self):
        """
        Ensure a view without ``throttle_scope`` is never throttled.
        """
        request = self.factory.get('/')

        for _ in range(10):
            self.increment_timer()
            response = self.unscoped_view(request)
            assert response.status_code == 200

    def test_get_cache_key_returns_correct_key_if_user_is_authenticated(self):
        """
        Ensure the cache key for an authenticated user is built from the
        view's scope and the user's primary key.
        """
        class DummyView:
            throttle_scope = 'user'

        request = Request(HttpRequest())
        user = User.objects.create(username='test')
        force_authenticate(request, user)
        request.user = user
        self.throttle.allow_request(request, DummyView())
        cache_key = self.throttle.get_cache_key(request, view=DummyView())
        assert cache_key == 'throttle_user_%s' % user.pk
 | |
| 
 | |
| 
 | |
class XffTestingBase(TestCase):
    """
    Shared fixture for the X-Forwarded-For tests.

    Provides a view limited to '1/day' and a request whose REMOTE_ADDR is
    '3.3.3.3' and whose X-Forwarded-For header is
    '0.0.0.0, 1.1.1.1, 2.2.2.2'.
    """

    def setUp(self):

        class Throttle(ScopedRateThrottle):
            THROTTLE_RATES = {'test_limit': '1/day'}
            # Fixed fake clock: time never advances in these tests.
            TIMER_SECONDS = 0

            def timer(self):
                return self.TIMER_SECONDS

        class View(APIView):
            throttle_classes = (Throttle,)
            throttle_scope = 'test_limit'

            def get(self, request):
                return Response('test_limit')

        cache.clear()
        self.throttle = Throttle()
        self.view = View.as_view()
        self.request = APIRequestFactory().get('/some_uri')
        self.request.META['REMOTE_ADDR'] = '3.3.3.3'
        self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 1.1.1.1, 2.2.2.2'

    def config_proxy(self, num_proxies):
        """
        Set ``api_settings.NUM_PROXIES`` for the duration of the current test.

        The previous value is put back by a registered cleanup, so the
        override does not leak into tests that run afterwards.
        """
        previous = api_settings.NUM_PROXIES
        self.addCleanup(setattr, api_settings, 'NUM_PROXIES', previous)
        setattr(api_settings, 'NUM_PROXIES', num_proxies)
 | |
| 
 | |
| 
 | |
class IdWithXffBasicTests(XffTestingBase):
    """
    Basic allow/deny checks with NUM_PROXIES set to 0.
    """

    def test_accepts_request_under_limit(self):
        self.config_proxy(0)
        response = self.view(self.request)
        assert response.status_code == 200

    def test_denies_request_over_limit(self):
        self.config_proxy(0)
        # First request is only there to use up the allowance.
        self.view(self.request)
        repeat_response = self.view(self.request)
        assert repeat_response.status_code == 429
 | |
| 
 | |
| 
 | |
class XffSpoofingTests(XffTestingBase):
    """
    Changing leading X-Forwarded-For entries must not dodge the throttle.
    """

    def test_xff_spoofing_doesnt_change_machine_id_with_one_app_proxy(self):
        self.config_proxy(1)
        self.view(self.request)
        # Only the first two entries differ; the last one stays '2.2.2.2'.
        spoofed_header = '4.4.4.4, 5.5.5.5, 2.2.2.2'
        self.request.META['HTTP_X_FORWARDED_FOR'] = spoofed_header
        response = self.view(self.request)
        assert response.status_code == 429

    def test_xff_spoofing_doesnt_change_machine_id_with_two_app_proxies(self):
        self.config_proxy(2)
        self.view(self.request)
        # Only the first entry differs; the last two stay unchanged.
        spoofed_header = '4.4.4.4, 1.1.1.1, 2.2.2.2'
        self.request.META['HTTP_X_FORWARDED_FOR'] = spoofed_header
        response = self.view(self.request)
        assert response.status_code == 429
 | |
| 
 | |
| 
 | |
class XffUniqueMachinesTest(XffTestingBase):
    """
    Requests whose trailing X-Forwarded-For entries differ are counted
    separately.
    """

    def test_unique_clients_are_counted_independently_with_one_proxy(self):
        self.config_proxy(1)
        self.view(self.request)
        # The last entry changes from '2.2.2.2' to '7.7.7.7'.
        other_client_header = '0.0.0.0, 1.1.1.1, 7.7.7.7'
        self.request.META['HTTP_X_FORWARDED_FOR'] = other_client_header
        response = self.view(self.request)
        assert response.status_code == 200

    def test_unique_clients_are_counted_independently_with_two_proxies(self):
        self.config_proxy(2)
        self.view(self.request)
        # The second-to-last entry changes from '1.1.1.1' to '7.7.7.7'.
        other_client_header = '0.0.0.0, 7.7.7.7, 2.2.2.2'
        self.request.META['HTTP_X_FORWARDED_FOR'] = other_client_header
        response = self.view(self.request)
        assert response.status_code == 200
 | |
| 
 | |
| 
 | |
class BaseThrottleTests(TestCase):
    """
    Tests for the BaseThrottle base class.
    """

    def test_allow_request_raises_not_implemented_error(self):
        throttle = BaseThrottle()
        with pytest.raises(NotImplementedError):
            throttle.allow_request(request={}, view={})
 | |
| 
 | |
| 
 | |
class SimpleRateThrottleTests(TestCase):
    """
    Unit tests for SimpleRateThrottle.
    """

    def setUp(self):
        # ``scope`` is rebound on the class itself, here and in one test
        # below. Register the restore first so the change cannot leak into
        # other tests that use SimpleRateThrottle or its subclasses.
        self.addCleanup(
            setattr, SimpleRateThrottle, 'scope', SimpleRateThrottle.scope
        )
        SimpleRateThrottle.scope = 'anon'

    def test_get_rate_raises_error_if_scope_is_missing(self):
        throttle = SimpleRateThrottle()
        throttle.scope = None
        # Only the call under test sits inside the raises block.
        with pytest.raises(ImproperlyConfigured):
            throttle.get_rate()

    def test_throttle_raises_error_if_rate_is_missing(self):
        # Class-level change; undone by the cleanup registered in setUp().
        SimpleRateThrottle.scope = 'invalid scope'
        with pytest.raises(ImproperlyConfigured):
            SimpleRateThrottle()

    def test_parse_rate_returns_tuple_with_none_if_rate_not_provided(self):
        rate = SimpleRateThrottle().parse_rate(None)
        assert rate == (None, None)

    def test_allow_request_returns_true_if_rate_is_none(self):
        assert SimpleRateThrottle().allow_request(request={}, view={}) is True

    def test_get_cache_key_raises_not_implemented_error(self):
        throttle = SimpleRateThrottle()
        with pytest.raises(NotImplementedError):
            throttle.get_cache_key({}, {})

    def test_allow_request_returns_true_if_key_is_none(self):
        throttle = SimpleRateThrottle()
        throttle.rate = 'some rate'
        # Instance-level stub: no cache key is produced for any request.
        throttle.get_cache_key = lambda *args: None
        assert throttle.allow_request(request={}, view={}) is True

    def test_wait_returns_correct_waiting_time_without_history(self):
        throttle = SimpleRateThrottle()
        throttle.num_requests = 1
        throttle.duration = 60
        throttle.history = []
        waiting_time = throttle.wait()
        assert isinstance(waiting_time, float)
        assert waiting_time == 30.0

    def test_wait_returns_none_if_there_are_no_available_requests(self):
        throttle = SimpleRateThrottle()
        throttle.num_requests = 1
        throttle.duration = 60
        throttle.now = throttle.timer()
        # Three history entries against a limit of one request.
        throttle.history = [throttle.timer() for _ in range(3)]
        assert throttle.wait() is None
 | |
| 
 | |
| 
 | |
class AnonRateThrottleTests(TestCase):
    """
    Tests for the cache keys produced by AnonRateThrottle.
    """

    def setUp(self):
        self.throttle = AnonRateThrottle()

    def test_authenticated_user_not_affected(self):
        user = User.objects.create(username='test')
        request = Request(HttpRequest())
        force_authenticate(request, user)
        request.user = user
        cache_key = self.throttle.get_cache_key(request, view={})
        assert cache_key is None

    def test_get_cache_key_returns_correct_value(self):
        # Bare request: no user is attached and no address is set in META.
        anonymous_request = Request(HttpRequest())
        cache_key = self.throttle.get_cache_key(anonymous_request, view={})
        assert cache_key == 'throttle_anon_None'
 |