"""
Tests for the throttle classes in ``rest_framework.throttling``.
"""
from __future__ import unicode_literals

import pytest
from django.contrib.auth.models import User
from django.core.cache import cache
from django.core.exceptions import ImproperlyConfigured
from django.http import HttpRequest
from django.test import TestCase

from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory, force_authenticate
from rest_framework.throttling import (
    AnonRateThrottle, BaseThrottle, ScopedRateThrottle, SimpleRateThrottle,
    UserRateThrottle
)
from rest_framework.views import APIView


class User3SecRateThrottle(UserRateThrottle):
    rate = '3/sec'
    scope = 'seconds'


class User3MinRateThrottle(UserRateThrottle):
    rate = '3/min'
    scope = 'minutes'


class NonTimeThrottle(BaseThrottle):
    """
    Throttle that allows the first request it sees and rejects all later ones.

    The "already called" flag is stored on the class, not the instance, so it
    survives across the fresh throttle instances each request creates. Tests
    that use this throttle are responsible for removing the flag afterwards.
    """

    def allow_request(self, request, view):
        if not hasattr(self.__class__, 'called'):
            self.__class__.called = True
            return True
        return False


class MockView(APIView):
    throttle_classes = (User3SecRateThrottle,)

    def get(self, request):
        return Response('foo')


class MockView_MinuteThrottling(APIView):
    throttle_classes = (User3MinRateThrottle,)

    def get(self, request):
        return Response('foo')


class MockView_NonTimeThrottling(APIView):
    throttle_classes = (NonTimeThrottle,)

    def get(self, request):
        return Response('foo')


class ThrottlingTests(TestCase):
    def setUp(self):
        """
        Reset the cache so that no throttles will be active
        """
        cache.clear()
        self.factory = APIRequestFactory()

    def test_requests_are_throttled(self):
        """
        Ensure request rate is limited
        """
        request = self.factory.get('/')
        for dummy in range(4):
            response = MockView.as_view()(request)
        assert response.status_code == 429

    def set_throttle_timer(self, view, value):
        """
        Explicitly set the timer, overriding time.time()

        The override is installed on the view's first throttle class, which is
        module-level state shared by every test. A cleanup is registered that
        puts back the ``timer`` the class itself defined, or deletes the
        override so the inherited timer is used again. Cleanups run
        last-in-first-out, so several calls within one test unwind correctly.
        """
        throttle_class = view.throttle_classes[0]
        own_attrs = vars(throttle_class)
        if 'timer' in own_attrs:
            self.addCleanup(setattr, throttle_class, 'timer', own_attrs['timer'])
        else:
            self.addCleanup(delattr, throttle_class, 'timer')
        throttle_class.timer = lambda self: value

    def test_request_throttling_expires(self):
        """
        Ensure request rate is limited for a limited duration only
        """
        self.set_throttle_timer(MockView, 0)

        request = self.factory.get('/')
        for dummy in range(4):
            response = MockView.as_view()(request)
        assert response.status_code == 429

        # Advance the timer by one second
        self.set_throttle_timer(MockView, 1)

        response = MockView.as_view()(request)
        assert response.status_code == 200

    def ensure_is_throttled(self, view, expect):
        """
        Exhaust the rate for user 'a', then check the status user 'b' gets.
        """
        request = self.factory.get('/')
        request.user = User.objects.create(username='a')
        for dummy in range(3):
            view.as_view()(request)
        request.user = User.objects.create(username='b')
        response = view.as_view()(request)
        assert response.status_code == expect

    def test_request_throttling_is_per_user(self):
        """
        Ensure request rate is only limited per user, not globally for
        PerUserThrottles
        """
        self.ensure_is_throttled(MockView, 200)

    def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
        """
        Ensure the response carries the expected Retry-After header.

        ``expected_headers`` is a sequence of ``(timer, expect)`` pairs: for
        each, the throttle clock is set to ``timer`` and one request is made.
        When ``expect`` is None the response must not have a Retry-After
        header; otherwise the header must equal ``expect``.
        """
        request = self.factory.get('/')
        for timer, expect in expected_headers:
            self.set_throttle_timer(view, timer)
            response = view.as_view()(request)
            if expect is not None:
                assert response['Retry-After'] == expect
            else:
                assert 'Retry-After' not in response

    def test_seconds_fields(self):
        """
        Ensure the Retry-After header is correct for second based throttles.
        """
        self.ensure_response_header_contains_proper_throttle_field(
            MockView, (
                (0, None),
                (0, None),
                (0, None),
                (0, '1')
            )
        )

    def test_minutes_fields(self):
        """
        Ensure the Retry-After header is correct for minute based throttles.
        """
        self.ensure_response_header_contains_proper_throttle_field(
            MockView_MinuteThrottling, (
                (0, None),
                (0, None),
                (0, None),
                (0, '60')
            )
        )

    def test_next_rate_remains_constant_if_followed(self):
        """
        If a client follows the recommended next request rate,
        the throttling rate should stay constant.
        """
        self.ensure_response_header_contains_proper_throttle_field(
            MockView_MinuteThrottling, (
                (0, None),
                (20, None),
                (40, None),
                (60, None),
                (80, None)
            )
        )

    def test_non_time_throttle(self):
        """
        Ensure a throttle that does not track time still rejects requests,
        but never sends a Retry-After header.
        """
        throttle_class = MockView_NonTimeThrottling.throttle_classes[0]
        request = self.factory.get('/')

        assert not hasattr(throttle_class, 'called')

        def forget_call():
            # NonTimeThrottle flags itself on the class; drop the flag so the
            # precondition above holds again if this test is re-run.
            if 'called' in vars(throttle_class):
                delattr(throttle_class, 'called')
        self.addCleanup(forget_call)

        response = MockView_NonTimeThrottling.as_view()(request)
        assert response.status_code == 200
        assert 'Retry-After' not in response
        assert throttle_class.called

        response = MockView_NonTimeThrottling.as_view()(request)
        assert response.status_code == 429
        assert 'Retry-After' not in response


class ScopedRateThrottleTests(TestCase):
    """
    Tests for ScopedRateThrottle.
    """

    def setUp(self):
        self.throttle = ScopedRateThrottle()

        class XYScopedRateThrottle(ScopedRateThrottle):
            TIMER_SECONDS = 0
            THROTTLE_RATES = {'x': '3/min', 'y': '1/min'}

            def timer(self):
                return self.TIMER_SECONDS

        class XView(APIView):
            throttle_classes = (XYScopedRateThrottle,)
            throttle_scope = 'x'

            def get(self, request):
                return Response('x')

        class YView(APIView):
            throttle_classes = (XYScopedRateThrottle,)
            throttle_scope = 'y'

            def get(self, request):
                return Response('y')

        class UnscopedView(APIView):
            throttle_classes = (XYScopedRateThrottle,)

            def get(self, request):
                return Response('y')

        # The throttle class is defined per test, so advancing its clock via
        # increment_timer() cannot leak between tests.
        self.throttle_class = XYScopedRateThrottle
        self.factory = APIRequestFactory()
        self.x_view = XView.as_view()
        self.y_view = YView.as_view()
        self.unscoped_view = UnscopedView.as_view()

    def increment_timer(self, seconds=1):
        self.throttle_class.TIMER_SECONDS += seconds

    def test_scoped_rate_throttle(self):
        request = self.factory.get('/')

        # Should be able to hit x view 3 times per minute.
        response = self.x_view(request)
        assert response.status_code == 200

        self.increment_timer()
        response = self.x_view(request)
        assert response.status_code == 200

        self.increment_timer()
        response = self.x_view(request)
        assert response.status_code == 200

        self.increment_timer()
        response = self.x_view(request)
        assert response.status_code == 429

        # Should be able to hit y view 1 time per minute.
        self.increment_timer()
        response = self.y_view(request)
        assert response.status_code == 200

        self.increment_timer()
        response = self.y_view(request)
        assert response.status_code == 429

        # Ensure throttles properly reset by advancing the rest of the minute
        self.increment_timer(55)

        # Should still be able to hit x view 3 times per minute.
        response = self.x_view(request)
        assert response.status_code == 200

        self.increment_timer()
        response = self.x_view(request)
        assert response.status_code == 200

        self.increment_timer()
        response = self.x_view(request)
        assert response.status_code == 200

        self.increment_timer()
        response = self.x_view(request)
        assert response.status_code == 429

        # Should still be able to hit y view 1 time per minute.
        self.increment_timer()
        response = self.y_view(request)
        assert response.status_code == 200

        self.increment_timer()
        response = self.y_view(request)
        assert response.status_code == 429

    def test_unscoped_view_not_throttled(self):
        request = self.factory.get('/')

        for _ in range(10):
            self.increment_timer()
            response = self.unscoped_view(request)
            assert response.status_code == 200

    def test_get_cache_key_returns_correct_key_if_user_is_authenticated(self):
        class DummyView(object):
            throttle_scope = 'user'

        request = Request(HttpRequest())
        user = User.objects.create(username='test')
        force_authenticate(request, user)
        request.user = user
        self.throttle.allow_request(request, DummyView())
        cache_key = self.throttle.get_cache_key(request, view=DummyView())
        assert cache_key == 'throttle_user_%s' % user.pk


class XffTestingBase(TestCase):
    def setUp(self):

        class Throttle(ScopedRateThrottle):
            THROTTLE_RATES = {'test_limit': '1/day'}
            TIMER_SECONDS = 0

            def timer(self):
                return self.TIMER_SECONDS

        class View(APIView):
            throttle_classes = (Throttle,)
            throttle_scope = 'test_limit'

            def get(self, request):
                return Response('test_limit')

        cache.clear()
        self.throttle = Throttle()
        self.view = View.as_view()
        self.request = APIRequestFactory().get('/some_uri')
        self.request.META['REMOTE_ADDR'] = '3.3.3.3'
        self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 1.1.1.1, 2.2.2.2'

    def config_proxy(self, num_proxies):
        """
        Set the NUM_PROXIES setting for the duration of the current test.

        ``api_settings`` is a process-wide object, so the previous value is
        put back by a cleanup instead of leaking into unrelated tests.
        """
        previous = api_settings.NUM_PROXIES
        self.addCleanup(setattr, api_settings, 'NUM_PROXIES', previous)
        api_settings.NUM_PROXIES = num_proxies


class IdWithXffBasicTests(XffTestingBase):
    def test_accepts_request_under_limit(self):
        self.config_proxy(0)
        assert self.view(self.request).status_code == 200

    def test_denies_request_over_limit(self):
        self.config_proxy(0)
        self.view(self.request)
        assert self.view(self.request).status_code == 429


class XffSpoofingTests(XffTestingBase):
    def test_xff_spoofing_doesnt_change_machine_id_with_one_app_proxy(self):
        self.config_proxy(1)
        self.view(self.request)
        self.request.META['HTTP_X_FORWARDED_FOR'] = '4.4.4.4, 5.5.5.5, 2.2.2.2'
        assert self.view(self.request).status_code == 429

    def test_xff_spoofing_doesnt_change_machine_id_with_two_app_proxies(self):
        self.config_proxy(2)
        self.view(self.request)
        self.request.META['HTTP_X_FORWARDED_FOR'] = '4.4.4.4, 1.1.1.1, 2.2.2.2'
        assert self.view(self.request).status_code == 429


class XffUniqueMachinesTest(XffTestingBase):
    def test_unique_clients_are_counted_independently_with_one_proxy(self):
        self.config_proxy(1)
        self.view(self.request)
        self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 1.1.1.1, 7.7.7.7'
        assert self.view(self.request).status_code == 200

    def test_unique_clients_are_counted_independently_with_two_proxies(self):
        self.config_proxy(2)
        self.view(self.request)
        self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 7.7.7.7, 2.2.2.2'
        assert self.view(self.request).status_code == 200


class BaseThrottleTests(TestCase):

    def test_allow_request_raises_not_implemented_error(self):
        with pytest.raises(NotImplementedError):
            BaseThrottle().allow_request(request={}, view={})


class SimpleRateThrottleTests(TestCase):

    def setUp(self):
        # These tests instantiate SimpleRateThrottle directly, which needs a
        # scope with a configured rate. ``scope`` is class-level state shared
        # with every subclass, so restore the original value after each test.
        self.addCleanup(setattr, SimpleRateThrottle, 'scope', SimpleRateThrottle.scope)
        SimpleRateThrottle.scope = 'anon'

    def test_get_rate_raises_error_if_scope_is_missing(self):
        throttle = SimpleRateThrottle()
        throttle.scope = None
        with pytest.raises(ImproperlyConfigured):
            throttle.get_rate()

    def test_throttle_raises_error_if_rate_is_missing(self):
        SimpleRateThrottle.scope = 'invalid scope'
        with pytest.raises(ImproperlyConfigured):
            SimpleRateThrottle()

    def test_parse_rate_returns_tuple_with_none_if_rate_not_provided(self):
        rate = SimpleRateThrottle().parse_rate(None)
        assert rate == (None, None)

    def test_allow_request_returns_true_if_rate_is_none(self):
        assert SimpleRateThrottle().allow_request(request={}, view={}) is True

    def test_get_cache_key_raises_not_implemented_error(self):
        with pytest.raises(NotImplementedError):
            SimpleRateThrottle().get_cache_key({}, {})

    def test_allow_request_returns_true_if_key_is_none(self):
        throttle = SimpleRateThrottle()
        throttle.rate = 'some rate'
        throttle.get_cache_key = lambda *args: None
        assert throttle.allow_request(request={}, view={}) is True

    def test_wait_returns_correct_waiting_time_without_history(self):
        throttle = SimpleRateThrottle()
        throttle.num_requests = 1
        throttle.duration = 60
        throttle.history = []
        waiting_time = throttle.wait()
        assert isinstance(waiting_time, float)
        assert waiting_time == 30.0

    def test_wait_returns_none_if_there_are_no_available_requests(self):
        throttle = SimpleRateThrottle()
        throttle.num_requests = 1
        throttle.duration = 60
        throttle.now = throttle.timer()
        throttle.history = [throttle.timer() for _ in range(3)]
        assert throttle.wait() is None


class AnonRateThrottleTests(TestCase):

    def setUp(self):
        self.throttle = AnonRateThrottle()

    def test_authenticated_user_not_affected(self):
        request = Request(HttpRequest())
        user = User.objects.create(username='test')
        force_authenticate(request, user)
        request.user = user
        assert self.throttle.get_cache_key(request, view={}) is None

    def test_get_cache_key_returns_correct_value(self):
        request = Request(HttpRequest())
        cache_key = self.throttle.get_cache_key(request, view={})
        assert cache_key == 'throttle_anon_None'