django-rest-framework/rest_framework/tests/throttling.py
Tom Christie b052c92ac3 Cleanup imports
Mostly adding `from __future__ import unicode_literals` everywhere.
2013-02-04 20:55:35 +00:00

144 lines
4.2 KiB
Python

"""
Tests for the throttling implementations in the permissions module.
"""
from __future__ import unicode_literals
from django.test import TestCase
from django.contrib.auth.models import User
from django.core.cache import cache
from django.test.client import RequestFactory
from rest_framework.views import APIView
from rest_framework.throttling import UserRateThrottle
from rest_framework.response import Response
class User3SecRateThrottle(UserRateThrottle):
rate = '3/sec'
scope = 'seconds'
class User3MinRateThrottle(UserRateThrottle):
rate = '3/min'
scope = 'minutes'
class MockView(APIView):
throttle_classes = (User3SecRateThrottle,)
def get(self, request):
return Response('foo')
class MockView_MinuteThrottling(APIView):
throttle_classes = (User3MinRateThrottle,)
def get(self, request):
return Response('foo')
class ThrottlingTests(TestCase):
urls = 'rest_framework.tests.throttling'
def setUp(self):
"""
Reset the cache so that no throttles will be active
"""
cache.clear()
self.factory = RequestFactory()
def test_requests_are_throttled(self):
"""
Ensure request rate is limited
"""
request = self.factory.get('/')
for dummy in range(4):
response = MockView.as_view()(request)
self.assertEqual(429, response.status_code)
def set_throttle_timer(self, view, value):
"""
Explicitly set the timer, overriding time.time()
"""
view.throttle_classes[0].timer = lambda self: value
def test_request_throttling_expires(self):
"""
Ensure request rate is limited for a limited duration only
"""
self.set_throttle_timer(MockView, 0)
request = self.factory.get('/')
for dummy in range(4):
response = MockView.as_view()(request)
self.assertEqual(429, response.status_code)
# Advance the timer by one second
self.set_throttle_timer(MockView, 1)
response = MockView.as_view()(request)
self.assertEqual(200, response.status_code)
def ensure_is_throttled(self, view, expect):
request = self.factory.get('/')
request.user = User.objects.create(username='a')
for dummy in range(3):
view.as_view()(request)
request.user = User.objects.create(username='b')
response = view.as_view()(request)
self.assertEqual(expect, response.status_code)
def test_request_throttling_is_per_user(self):
"""
Ensure request rate is only limited per user, not globally for
PerUserThrottles
"""
self.ensure_is_throttled(MockView, 200)
def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
"""
Ensure the response returns an X-Throttle field with status and next attributes
set properly.
"""
request = self.factory.get('/')
for timer, expect in expected_headers:
self.set_throttle_timer(view, timer)
response = view.as_view()(request)
if expect is not None:
self.assertEquals(response['X-Throttle-Wait-Seconds'], expect)
else:
self.assertFalse('X-Throttle-Wait-Seconds' in response)
def test_seconds_fields(self):
"""
Ensure for second based throttles.
"""
self.ensure_response_header_contains_proper_throttle_field(MockView,
((0, None),
(0, None),
(0, None),
(0, '1')
))
def test_minutes_fields(self):
"""
Ensure for minute based throttles.
"""
self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
((0, None),
(0, None),
(0, None),
(0, '60')
))
def test_next_rate_remains_constant_if_followed(self):
"""
If a client follows the recommended next request rate,
the throttling rate should stay constant.
"""
self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
((0, None),
(20, None),
(40, None),
(60, None),
(80, None)
))