conerted throttling tests asserts to pytest

This commit is contained in:
Asif Saifuddin Auvi 2017-01-04 15:59:21 +06:00
parent 7874bcabe9
commit a7d33f4519

View File

@ -70,7 +70,7 @@ class ThrottlingTests(TestCase):
request = self.factory.get('/') request = self.factory.get('/')
for dummy in range(4): for dummy in range(4):
response = MockView.as_view()(request) response = MockView.as_view()(request)
self.assertEqual(429, response.status_code) assert response.status_code == 429
def set_throttle_timer(self, view, value): def set_throttle_timer(self, view, value):
""" """
@ -87,13 +87,13 @@ class ThrottlingTests(TestCase):
request = self.factory.get('/') request = self.factory.get('/')
for dummy in range(4): for dummy in range(4):
response = MockView.as_view()(request) response = MockView.as_view()(request)
self.assertEqual(429, response.status_code) assert response.status_code == 429
# Advance the timer by one second # Advance the timer by one second
self.set_throttle_timer(MockView, 1) self.set_throttle_timer(MockView, 1)
response = MockView.as_view()(request) response = MockView.as_view()(request)
self.assertEqual(200, response.status_code) assert response.status_code == 200
def ensure_is_throttled(self, view, expect): def ensure_is_throttled(self, view, expect):
request = self.factory.get('/') request = self.factory.get('/')
@ -102,7 +102,7 @@ class ThrottlingTests(TestCase):
view.as_view()(request) view.as_view()(request)
request.user = User.objects.create(username='b') request.user = User.objects.create(username='b')
response = view.as_view()(request) response = view.as_view()(request)
self.assertEqual(expect, response.status_code) assert response.status_code == expect
def test_request_throttling_is_per_user(self): def test_request_throttling_is_per_user(self):
""" """
@ -121,9 +121,9 @@ class ThrottlingTests(TestCase):
self.set_throttle_timer(view, timer) self.set_throttle_timer(view, timer)
response = view.as_view()(request) response = view.as_view()(request)
if expect is not None: if expect is not None:
self.assertEqual(response['Retry-After'], expect) assert response['Retry-After'] == expect
else: else:
self.assertFalse('Retry-After' in response) assert not'Retry-After' in response
def test_seconds_fields(self): def test_seconds_fields(self):
""" """
@ -230,56 +230,55 @@ class ScopedRateThrottleTests(TestCase):
# Should be able to hit x view 3 times per minute. # Should be able to hit x view 3 times per minute.
response = self.x_view(request) response = self.x_view(request)
self.assertEqual(200, response.status_code) assert response.status_code == 200
self.increment_timer() self.increment_timer()
response = self.x_view(request) response = self.x_view(request)
self.assertEqual(200, response.status_code) assert response.status_code == 200
self.increment_timer() self.increment_timer()
response = self.x_view(request) response = self.x_view(request)
self.assertEqual(200, response.status_code) assert response.status_code == 200
self.increment_timer() self.increment_timer()
response = self.x_view(request) response = self.x_view(request)
self.assertEqual(429, response.status_code) assert response.status_code == 429
# Should be able to hit y view 1 time per minute. # Should be able to hit y view 1 time per minute.
self.increment_timer() self.increment_timer()
response = self.y_view(request) response = self.y_view(request)
self.assertEqual(200, response.status_code) assert response.status_code == 200
self.increment_timer() self.increment_timer()
response = self.y_view(request) response = self.y_view(request)
self.assertEqual(429, response.status_code) assert response.status_code == 429
# Ensure throttles properly reset by advancing the rest of the minute # Ensure throttles properly reset by advancing the rest of the minute
self.increment_timer(55) self.increment_timer(55)
# Should still be able to hit x view 3 times per minute. # Should still be able to hit x view 3 times per minute.
response = self.x_view(request) response = self.x_view(request)
self.assertEqual(200, response.status_code) assert response.status_code == 200
self.increment_timer() self.increment_timer()
response = self.x_view(request) response = self.x_view(request)
self.assertEqual(200, response.status_code) assert response.status_code == 200
self.increment_timer() self.increment_timer()
response = self.x_view(request) response = self.x_view(request)
self.assertEqual(200, response.status_code) assert response.status_code == 200
self.increment_timer() self.increment_timer()
response = self.x_view(request) response = self.x_view(request)
self.assertEqual(429, response.status_code) assert response.status_code == 429
# Should still be able to hit y view 1 time per minute. # Should still be able to hit y view 1 time per minute.
self.increment_timer() self.increment_timer()
response = self.y_view(request) response = self.y_view(request)
self.assertEqual(200, response.status_code) assert response.status_code == 200
self.increment_timer() self.increment_timer()
response = self.y_view(request) response = self.y_view(request)
self.assertEqual(429, response.status_code) assert response.status_code == 429
def test_unscoped_view_not_throttled(self): def test_unscoped_view_not_throttled(self):
request = self.factory.get('/') request = self.factory.get('/')
@ -287,7 +286,7 @@ class ScopedRateThrottleTests(TestCase):
for idx in range(10): for idx in range(10):
self.increment_timer() self.increment_timer()
response = self.unscoped_view(request) response = self.unscoped_view(request)
self.assertEqual(200, response.status_code) assert response.status_code == 200
class XffTestingBase(TestCase): class XffTestingBase(TestCase):
@ -321,12 +320,12 @@ class XffTestingBase(TestCase):
class IdWithXffBasicTests(XffTestingBase): class IdWithXffBasicTests(XffTestingBase):
def test_accepts_request_under_limit(self): def test_accepts_request_under_limit(self):
self.config_proxy(0) self.config_proxy(0)
self.assertEqual(200, self.view(self.request).status_code) assert self.view(self.request).status_code == 200
def test_denies_request_over_limit(self): def test_denies_request_over_limit(self):
self.config_proxy(0) self.config_proxy(0)
self.view(self.request) self.view(self.request)
self.assertEqual(429, self.view(self.request).status_code) assert self.view(self.request).status_code == 429
class XffSpoofingTests(XffTestingBase): class XffSpoofingTests(XffTestingBase):
@ -334,13 +333,13 @@ class XffSpoofingTests(XffTestingBase):
self.config_proxy(1) self.config_proxy(1)
self.view(self.request) self.view(self.request)
self.request.META['HTTP_X_FORWARDED_FOR'] = '4.4.4.4, 5.5.5.5, 2.2.2.2' self.request.META['HTTP_X_FORWARDED_FOR'] = '4.4.4.4, 5.5.5.5, 2.2.2.2'
self.assertEqual(429, self.view(self.request).status_code) assert self.view(self.request).status_code == 429
def test_xff_spoofing_doesnt_change_machine_id_with_two_app_proxies(self): def test_xff_spoofing_doesnt_change_machine_id_with_two_app_proxies(self):
self.config_proxy(2) self.config_proxy(2)
self.view(self.request) self.view(self.request)
self.request.META['HTTP_X_FORWARDED_FOR'] = '4.4.4.4, 1.1.1.1, 2.2.2.2' self.request.META['HTTP_X_FORWARDED_FOR'] = '4.4.4.4, 1.1.1.1, 2.2.2.2'
self.assertEqual(429, self.view(self.request).status_code) assert self.view(self.request).status_code == 429
class XffUniqueMachinesTest(XffTestingBase): class XffUniqueMachinesTest(XffTestingBase):
@ -348,10 +347,10 @@ class XffUniqueMachinesTest(XffTestingBase):
self.config_proxy(1) self.config_proxy(1)
self.view(self.request) self.view(self.request)
self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 1.1.1.1, 7.7.7.7' self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 1.1.1.1, 7.7.7.7'
self.assertEqual(200, self.view(self.request).status_code) assert self.view(self.request).status_code == 200
def test_unique_clients_are_counted_independently_with_two_proxies(self): def test_unique_clients_are_counted_independently_with_two_proxies(self):
self.config_proxy(2) self.config_proxy(2)
self.view(self.request) self.view(self.request)
self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 7.7.7.7, 2.2.2.2' self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 7.7.7.7, 2.2.2.2'
self.assertEqual(200, self.view(self.request).status_code) assert self.view(self.request).status_code == 200