Add get_requests_client

This commit is contained in:
Tom Christie 2016-08-18 15:34:19 +01:00
parent 0b3db028a2
commit 0cc3f5008f
2 changed files with 28 additions and 21 deletions

View File

@ -121,6 +121,11 @@ if requests is not None:
return super(DjangoTestSession, self).request(method, url, *args, **kwargs) return super(DjangoTestSession, self).request(method, url, *args, **kwargs)
def get_requests_client():
assert requests is not None, 'requests must be installed'
return DjangoTestSession()
class APIRequestFactory(DjangoRequestFactory): class APIRequestFactory(DjangoRequestFactory):
renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES
default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT
@ -321,13 +326,6 @@ class APITransactionTestCase(testcases.TransactionTestCase):
class APITestCase(testcases.TestCase): class APITestCase(testcases.TestCase):
client_class = APIClient client_class = APIClient
@property
def requests(self):
if not hasattr(self, '_requests'):
assert requests is not None, 'requests must be installed'
self._requests = DjangoTestSession()
return self._requests
class APISimpleTestCase(testcases.SimpleTestCase): class APISimpleTestCase(testcases.SimpleTestCase):
client_class = APIClient client_class = APIClient

View File

@ -12,7 +12,7 @@ from django.views.decorators.csrf import csrf_protect, ensure_csrf_cookie
from rest_framework.compat import is_authenticated, requests from rest_framework.compat import is_authenticated, requests
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.test import APITestCase from rest_framework.test import APITestCase, get_requests_client
from rest_framework.views import APIView from rest_framework.views import APIView
@ -103,7 +103,8 @@ urlpatterns = [
@override_settings(ROOT_URLCONF='tests.test_requests_client') @override_settings(ROOT_URLCONF='tests.test_requests_client')
class RequestsClientTests(APITestCase): class RequestsClientTests(APITestCase):
def test_get_request(self): def test_get_request(self):
response = self.requests.get('/') client = get_requests_client()
response = client.get('/')
assert response.status_code == 200 assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json' assert response.headers['Content-Type'] == 'application/json'
expected = { expected = {
@ -113,7 +114,8 @@ class RequestsClientTests(APITestCase):
assert response.json() == expected assert response.json() == expected
def test_get_request_query_params_in_url(self): def test_get_request_query_params_in_url(self):
response = self.requests.get('/?key=value') client = get_requests_client()
response = client.get('/?key=value')
assert response.status_code == 200 assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json' assert response.headers['Content-Type'] == 'application/json'
expected = { expected = {
@ -123,7 +125,8 @@ class RequestsClientTests(APITestCase):
assert response.json() == expected assert response.json() == expected
def test_get_request_query_params_by_kwarg(self): def test_get_request_query_params_by_kwarg(self):
response = self.requests.get('/', params={'key': 'value'}) client = get_requests_client()
response = client.get('/', params={'key': 'value'})
assert response.status_code == 200 assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json' assert response.headers['Content-Type'] == 'application/json'
expected = { expected = {
@ -133,14 +136,16 @@ class RequestsClientTests(APITestCase):
assert response.json() == expected assert response.json() == expected
def test_get_with_headers(self): def test_get_with_headers(self):
response = self.requests.get('/headers/', headers={'User-Agent': 'example'}) client = get_requests_client()
response = client.get('/headers/', headers={'User-Agent': 'example'})
assert response.status_code == 200 assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json' assert response.headers['Content-Type'] == 'application/json'
headers = response.json()['headers'] headers = response.json()['headers']
assert headers['USER-AGENT'] == 'example' assert headers['USER-AGENT'] == 'example'
def test_post_form_request(self): def test_post_form_request(self):
response = self.requests.post('/', data={'key': 'value'}) client = get_requests_client()
response = client.post('/', data={'key': 'value'})
assert response.status_code == 200 assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json' assert response.headers['Content-Type'] == 'application/json'
expected = { expected = {
@ -153,7 +158,8 @@ class RequestsClientTests(APITestCase):
assert response.json() == expected assert response.json() == expected
def test_post_json_request(self): def test_post_json_request(self):
response = self.requests.post('/', json={'key': 'value'}) client = get_requests_client()
response = client.post('/', json={'key': 'value'})
assert response.status_code == 200 assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json' assert response.headers['Content-Type'] == 'application/json'
expected = { expected = {
@ -166,10 +172,11 @@ class RequestsClientTests(APITestCase):
assert response.json() == expected assert response.json() == expected
def test_post_multipart_request(self): def test_post_multipart_request(self):
client = get_requests_client()
files = { files = {
'file': ('report.csv', 'some,data,to,send\nanother,row,to,send\n') 'file': ('report.csv', 'some,data,to,send\nanother,row,to,send\n')
} }
response = self.requests.post('/', files=files) response = client.post('/', files=files)
assert response.status_code == 200 assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json' assert response.headers['Content-Type'] == 'application/json'
expected = { expected = {
@ -182,19 +189,20 @@ class RequestsClientTests(APITestCase):
assert response.json() == expected assert response.json() == expected
def test_session(self): def test_session(self):
response = self.requests.get('/session/') client = get_requests_client()
response = client.get('/session/')
assert response.status_code == 200 assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json' assert response.headers['Content-Type'] == 'application/json'
expected = {} expected = {}
assert response.json() == expected assert response.json() == expected
response = self.requests.post('/session/', json={'example': 'abc'}) response = client.post('/session/', json={'example': 'abc'})
assert response.status_code == 200 assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json' assert response.headers['Content-Type'] == 'application/json'
expected = {'example': 'abc'} expected = {'example': 'abc'}
assert response.json() == expected assert response.json() == expected
response = self.requests.get('/session/') response = client.get('/session/')
assert response.status_code == 200 assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json' assert response.headers['Content-Type'] == 'application/json'
expected = {'example': 'abc'} expected = {'example': 'abc'}
@ -202,7 +210,8 @@ class RequestsClientTests(APITestCase):
def test_auth(self): def test_auth(self):
# Confirm session is not authenticated # Confirm session is not authenticated
response = self.requests.get('/auth/') client = get_requests_client()
response = client.get('/auth/')
assert response.status_code == 200 assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json' assert response.headers['Content-Type'] == 'application/json'
expected = { expected = {
@ -217,7 +226,7 @@ class RequestsClientTests(APITestCase):
user.save() user.save()
# Perform a login # Perform a login
response = self.requests.post('/auth/', json={ response = client.post('/auth/', json={
'username': 'tom', 'username': 'tom',
'password': 'password' 'password': 'password'
}, headers={'X-CSRFToken': csrftoken}) }, headers={'X-CSRFToken': csrftoken})
@ -229,7 +238,7 @@ class RequestsClientTests(APITestCase):
assert response.json() == expected assert response.json() == expected
# Confirm session is authenticated # Confirm session is authenticated
response = self.requests.get('/auth/') response = client.get('/auth/')
assert response.status_code == 200 assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json' assert response.headers['Content-Type'] == 'application/json'
expected = { expected = {