diff --git a/rest_framework/test.py b/rest_framework/test.py index ded9d5fe9..e17c19a43 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -121,6 +121,11 @@ if requests is not None: return super(DjangoTestSession, self).request(method, url, *args, **kwargs) +def get_requests_client(): + assert requests is not None, 'requests must be installed' + return DjangoTestSession() + + class APIRequestFactory(DjangoRequestFactory): renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT @@ -321,13 +326,6 @@ class APITransactionTestCase(testcases.TransactionTestCase): class APITestCase(testcases.TestCase): client_class = APIClient - @property - def requests(self): - if not hasattr(self, '_requests'): - assert requests is not None, 'requests must be installed' - self._requests = DjangoTestSession() - return self._requests - class APISimpleTestCase(testcases.SimpleTestCase): client_class = APIClient diff --git a/tests/test_requests_client.py b/tests/test_requests_client.py index aa99a71da..37bde1092 100644 --- a/tests/test_requests_client.py +++ b/tests/test_requests_client.py @@ -12,7 +12,7 @@ from django.views.decorators.csrf import csrf_protect, ensure_csrf_cookie from rest_framework.compat import is_authenticated, requests from rest_framework.response import Response -from rest_framework.test import APITestCase +from rest_framework.test import APITestCase, get_requests_client from rest_framework.views import APIView @@ -103,7 +103,8 @@ urlpatterns = [ @override_settings(ROOT_URLCONF='tests.test_requests_client') class RequestsClientTests(APITestCase): def test_get_request(self): - response = self.requests.get('/') + client = get_requests_client() + response = client.get('/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -113,7 +114,8 @@ class RequestsClientTests(APITestCase): assert response.json() == expected def test_get_request_query_params_in_url(self): - response = self.requests.get('/?key=value') + client = get_requests_client() + response = client.get('/?key=value') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -123,7 +125,8 @@ class RequestsClientTests(APITestCase): assert response.json() == expected def test_get_request_query_params_by_kwarg(self): - response = self.requests.get('/', params={'key': 'value'}) + client = get_requests_client() + response = client.get('/', params={'key': 'value'}) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -133,14 +136,16 @@ class RequestsClientTests(APITestCase): assert response.json() == expected def test_get_with_headers(self): - response = self.requests.get('/headers/', headers={'User-Agent': 'example'}) + client = get_requests_client() + response = client.get('/headers/', headers={'User-Agent': 'example'}) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' headers = response.json()['headers'] assert headers['USER-AGENT'] == 'example' def test_post_form_request(self): - response = self.requests.post('/', data={'key': 'value'}) + client = get_requests_client() + response = client.post('/', data={'key': 'value'}) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -153,7 +158,8 @@ class RequestsClientTests(APITestCase): assert response.json() == expected def test_post_json_request(self): - response = self.requests.post('/', json={'key': 'value'}) + client = get_requests_client() + response = client.post('/', json={'key': 'value'}) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -166,10 +172,11 @@ class RequestsClientTests(APITestCase): assert response.json() == expected def test_post_multipart_request(self): + client = get_requests_client() files = { 'file': ('report.csv', 'some,data,to,send\nanother,row,to,send\n') } - response = self.requests.post('/', files=files) + response = client.post('/', files=files) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -182,19 +189,20 @@ class RequestsClientTests(APITestCase): assert response.json() == expected def test_session(self): - response = self.requests.get('/session/') + client = get_requests_client() + response = client.get('/session/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = {} assert response.json() == expected - response = self.requests.post('/session/', json={'example': 'abc'}) + response = client.post('/session/', json={'example': 'abc'}) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = {'example': 'abc'} assert response.json() == expected - response = self.requests.get('/session/') + response = client.get('/session/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = {'example': 'abc'} @@ -202,7 +210,8 @@ class RequestsClientTests(APITestCase): def test_auth(self): # Confirm session is not authenticated - response = self.requests.get('/auth/') + client = get_requests_client() + response = client.get('/auth/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -217,7 +226,7 @@ class RequestsClientTests(APITestCase): user.save() # Perform a login - response = self.requests.post('/auth/', json={ + response = client.post('/auth/', json={ 'username': 'tom', 'password': 'password' }, headers={'X-CSRFToken': csrftoken}) @@ -229,7 +238,7 @@ class RequestsClientTests(APITestCase): assert response.json() == expected # Confirm session is authenticated - response = self.requests.get('/auth/') + response = client.get('/auth/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = {