""" Tests for the throttling implementations in the permissions module. """ import pytest from django.contrib.auth.models import User from django.core.cache import cache from django.core.exceptions import ImproperlyConfigured from django.http import HttpRequest from django.test import TestCase from rest_framework.request import Request from rest_framework.response import Response from rest_framework.settings import api_settings from rest_framework.test import APIRequestFactory, force_authenticate from rest_framework.throttling import ( AnonRateThrottle, BaseThrottle, ScopedRateThrottle, SimpleRateThrottle, UserRateThrottle ) from rest_framework.views import APIView class User3SecRateThrottle(UserRateThrottle): rate = '3/sec' scope = 'seconds' class User3MinRateThrottle(UserRateThrottle): rate = '3/min' scope = 'minutes' class User6MinRateThrottle(UserRateThrottle): rate = '6/min' scope = 'minutes' class NonTimeThrottle(BaseThrottle): def allow_request(self, request, view): if not hasattr(self.__class__, 'called'): self.__class__.called = True return True return False class MockView_DoubleThrottling(APIView): throttle_classes = (User3SecRateThrottle, User6MinRateThrottle,) def get(self, request): return Response('foo') class MockView(APIView): throttle_classes = (User3SecRateThrottle,) def get(self, request): return Response('foo') class MockView_MinuteThrottling(APIView): throttle_classes = (User3MinRateThrottle,) def get(self, request): return Response('foo') class MockView_NonTimeThrottling(APIView): throttle_classes = (NonTimeThrottle,) def get(self, request): return Response('foo') class ThrottlingTests(TestCase): def setUp(self): """ Reset the cache so that no throttles will be active """ cache.clear() self.factory = APIRequestFactory() def test_requests_are_throttled(self): """ Ensure request rate is limited """ request = self.factory.get('/') for dummy in range(4): response = MockView.as_view()(request) assert response.status_code == 429 def set_throttle_timer(self, view, value): """ Explicitly set the timer, overriding time.time() """ for cls in view.throttle_classes: cls.timer = lambda self: value def test_request_throttling_expires(self): """ Ensure request rate is limited for a limited duration only """ self.set_throttle_timer(MockView, 0) request = self.factory.get('/') for dummy in range(4): response = MockView.as_view()(request) assert response.status_code == 429 # Advance the timer by one second self.set_throttle_timer(MockView, 1) response = MockView.as_view()(request) assert response.status_code == 200 def ensure_is_throttled(self, view, expect): request = self.factory.get('/') request.user = User.objects.create(username='a') for dummy in range(3): view.as_view()(request) request.user = User.objects.create(username='b') response = view.as_view()(request) assert response.status_code == expect def test_request_throttling_is_per_user(self): """ Ensure request rate is only limited per user, not globally for PerUserThrottles """ self.ensure_is_throttled(MockView, 200) def test_request_throttling_multiple_throttles(self): """ Ensure all throttle classes see each request even when the request is already being throttled """ self.set_throttle_timer(MockView_DoubleThrottling, 0) request = self.factory.get('/') for dummy in range(4): response = MockView_DoubleThrottling.as_view()(request) assert response.status_code == 429 assert int(response['retry-after']) == 1 # At this point our client made 4 requests (one was throttled) in a # second. If we advance the timer by one additional second, the client # should be allowed to make 2 more before being throttled by the 2nd # throttle class, which has a limit of 6 per minute. self.set_throttle_timer(MockView_DoubleThrottling, 1) for dummy in range(2): response = MockView_DoubleThrottling.as_view()(request) assert response.status_code == 200 response = MockView_DoubleThrottling.as_view()(request) assert response.status_code == 429 assert int(response['retry-after']) == 59 # Just to make sure check again after two more seconds. self.set_throttle_timer(MockView_DoubleThrottling, 2) response = MockView_DoubleThrottling.as_view()(request) assert response.status_code == 429 assert int(response['retry-after']) == 58 def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers): """ Ensure the response returns an Retry-After field with status and next attributes set properly. """ request = self.factory.get('/') for timer, expect in expected_headers: self.set_throttle_timer(view, timer) response = view.as_view()(request) if expect is not None: assert response['Retry-After'] == expect else: assert not'Retry-After' in response def test_seconds_fields(self): """ Ensure for second based throttles. """ self.ensure_response_header_contains_proper_throttle_field( MockView, ( (0, None), (0, None), (0, None), (0, '1') ) ) def test_minutes_fields(self): """ Ensure for minute based throttles. """ self.ensure_response_header_contains_proper_throttle_field( MockView_MinuteThrottling, ( (0, None), (0, None), (0, None), (0, '60') ) ) def test_next_rate_remains_constant_if_followed(self): """ If a client follows the recommended next request rate, the throttling rate should stay constant. """ self.ensure_response_header_contains_proper_throttle_field( MockView_MinuteThrottling, ( (0, None), (20, None), (40, None), (60, None), (80, None) ) ) def test_non_time_throttle(self): """ Ensure for second based throttles. """ request = self.factory.get('/') self.assertFalse(hasattr(MockView_NonTimeThrottling.throttle_classes[0], 'called')) response = MockView_NonTimeThrottling.as_view()(request) self.assertFalse('Retry-After' in response) self.assertTrue(MockView_NonTimeThrottling.throttle_classes[0].called) response = MockView_NonTimeThrottling.as_view()(request) self.assertFalse('Retry-After' in response) class ScopedRateThrottleTests(TestCase): """ Tests for ScopedRateThrottle. """ def setUp(self): self.throttle = ScopedRateThrottle() class XYScopedRateThrottle(ScopedRateThrottle): TIMER_SECONDS = 0 THROTTLE_RATES = {'x': '3/min', 'y': '1/min'} def timer(self): return self.TIMER_SECONDS class XView(APIView): throttle_classes = (XYScopedRateThrottle,) throttle_scope = 'x' def get(self, request): return Response('x') class YView(APIView): throttle_classes = (XYScopedRateThrottle,) throttle_scope = 'y' def get(self, request): return Response('y') class UnscopedView(APIView): throttle_classes = (XYScopedRateThrottle,) def get(self, request): return Response('y') self.throttle_class = XYScopedRateThrottle self.factory = APIRequestFactory() self.x_view = XView.as_view() self.y_view = YView.as_view() self.unscoped_view = UnscopedView.as_view() def increment_timer(self, seconds=1): self.throttle_class.TIMER_SECONDS += seconds def test_scoped_rate_throttle(self): request = self.factory.get('/') # Should be able to hit x view 3 times per minute. response = self.x_view(request) assert response.status_code == 200 self.increment_timer() response = self.x_view(request) assert response.status_code == 200 self.increment_timer() response = self.x_view(request) assert response.status_code == 200 self.increment_timer() response = self.x_view(request) assert response.status_code == 429 # Should be able to hit y view 1 time per minute. self.increment_timer() response = self.y_view(request) assert response.status_code == 200 self.increment_timer() response = self.y_view(request) assert response.status_code == 429 # Ensure throttles properly reset by advancing the rest of the minute self.increment_timer(55) # Should still be able to hit x view 3 times per minute. response = self.x_view(request) assert response.status_code == 200 self.increment_timer() response = self.x_view(request) assert response.status_code == 200 self.increment_timer() response = self.x_view(request) assert response.status_code == 200 self.increment_timer() response = self.x_view(request) assert response.status_code == 429 # Should still be able to hit y view 1 time per minute. self.increment_timer() response = self.y_view(request) assert response.status_code == 200 self.increment_timer() response = self.y_view(request) assert response.status_code == 429 def test_unscoped_view_not_throttled(self): request = self.factory.get('/') for idx in range(10): self.increment_timer() response = self.unscoped_view(request) assert response.status_code == 200 def test_get_cache_key_returns_correct_key_if_user_is_authenticated(self): class DummyView: throttle_scope = 'user' request = Request(HttpRequest()) user = User.objects.create(username='test') force_authenticate(request, user) request.user = user self.throttle.allow_request(request, DummyView()) cache_key = self.throttle.get_cache_key(request, view=DummyView()) assert cache_key == 'throttle_user_%s' % user.pk class XffTestingBase(TestCase): def setUp(self): class Throttle(ScopedRateThrottle): THROTTLE_RATES = {'test_limit': '1/day'} TIMER_SECONDS = 0 def timer(self): return self.TIMER_SECONDS class View(APIView): throttle_classes = (Throttle,) throttle_scope = 'test_limit' def get(self, request): return Response('test_limit') cache.clear() self.throttle = Throttle() self.view = View.as_view() self.request = APIRequestFactory().get('/some_uri') self.request.META['REMOTE_ADDR'] = '3.3.3.3' self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 1.1.1.1, 2.2.2.2' def config_proxy(self, num_proxies): setattr(api_settings, 'NUM_PROXIES', num_proxies) class IdWithXffBasicTests(XffTestingBase): def test_accepts_request_under_limit(self): self.config_proxy(0) assert self.view(self.request).status_code == 200 def test_denies_request_over_limit(self): self.config_proxy(0) self.view(self.request) assert self.view(self.request).status_code == 429 class XffSpoofingTests(XffTestingBase): def test_xff_spoofing_doesnt_change_machine_id_with_one_app_proxy(self): self.config_proxy(1) self.view(self.request) self.request.META['HTTP_X_FORWARDED_FOR'] = '4.4.4.4, 5.5.5.5, 2.2.2.2' assert self.view(self.request).status_code == 429 def test_xff_spoofing_doesnt_change_machine_id_with_two_app_proxies(self): self.config_proxy(2) self.view(self.request) self.request.META['HTTP_X_FORWARDED_FOR'] = '4.4.4.4, 1.1.1.1, 2.2.2.2' assert self.view(self.request).status_code == 429 class XffUniqueMachinesTest(XffTestingBase): def test_unique_clients_are_counted_independently_with_one_proxy(self): self.config_proxy(1) self.view(self.request) self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 1.1.1.1, 7.7.7.7' assert self.view(self.request).status_code == 200 def test_unique_clients_are_counted_independently_with_two_proxies(self): self.config_proxy(2) self.view(self.request) self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 7.7.7.7, 2.2.2.2' assert self.view(self.request).status_code == 200 class BaseThrottleTests(TestCase): def test_allow_request_raises_not_implemented_error(self): with pytest.raises(NotImplementedError): BaseThrottle().allow_request(request={}, view={}) class SimpleRateThrottleTests(TestCase): def setUp(self): SimpleRateThrottle.scope = 'anon' def test_get_rate_raises_error_if_scope_is_missing(self): throttle = SimpleRateThrottle() with pytest.raises(ImproperlyConfigured): throttle.scope = None throttle.get_rate() def test_throttle_raises_error_if_rate_is_missing(self): SimpleRateThrottle.scope = 'invalid scope' with pytest.raises(ImproperlyConfigured): SimpleRateThrottle() def test_parse_rate_returns_tuple_with_none_if_rate_not_provided(self): rate = SimpleRateThrottle().parse_rate(None) assert rate == (None, None) def test_allow_request_returns_true_if_rate_is_none(self): assert SimpleRateThrottle().allow_request(request={}, view={}) is True def test_get_cache_key_raises_not_implemented_error(self): with pytest.raises(NotImplementedError): SimpleRateThrottle().get_cache_key({}, {}) def test_allow_request_returns_true_if_key_is_none(self): throttle = SimpleRateThrottle() throttle.rate = 'some rate' throttle.get_cache_key = lambda *args: None assert throttle.allow_request(request={}, view={}) is True def test_wait_returns_correct_waiting_time_without_history(self): throttle = SimpleRateThrottle() throttle.num_requests = 1 throttle.duration = 60 throttle.history = [] waiting_time = throttle.wait() assert isinstance(waiting_time, float) assert waiting_time == 30.0 def test_wait_returns_none_if_there_are_no_available_requests(self): throttle = SimpleRateThrottle() throttle.num_requests = 1 throttle.duration = 60 throttle.now = throttle.timer() throttle.history = [throttle.timer() for _ in range(3)] assert throttle.wait() is None class AnonRateThrottleTests(TestCase): def setUp(self): self.throttle = AnonRateThrottle() def test_authenticated_user_not_affected(self): request = Request(HttpRequest()) user = User.objects.create(username='test') force_authenticate(request, user) request.user = user assert self.throttle.get_cache_key(request, view={}) is None def test_get_cache_key_returns_correct_value(self): request = Request(HttpRequest()) cache_key = self.throttle.get_cache_key(request, view={}) assert cache_key == 'throttle_anon_None'