diff --git a/rest_framework/test.py b/rest_framework/test.py index 9b40353a3..74d2c868f 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -156,6 +156,52 @@ class APIClient(APIRequestFactory, DjangoClient): kwargs.update(self._credentials) return super(APIClient, self).request(**kwargs) + def get(self, path, data=None, follow=False, **extra): + response = super(APIClient, self).get(path, data=data, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + + def post(self, path, data=None, format=None, content_type=None, + follow=False, **extra): + response = super(APIClient, self).post( + path, data=data, format=format, content_type=content_type, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + + def put(self, path, data=None, format=None, content_type=None, + follow=False, **extra): + response = super(APIClient, self).put( + path, data=data, format=format, content_type=content_type, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + + def patch(self, path, data=None, format=None, content_type=None, + follow=False, **extra): + response = super(APIClient, self).patch( + path, data=data, format=format, content_type=content_type, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + + def delete(self, path, data=None, format=None, content_type=None, + follow=False, **extra): + response = super(APIClient, self).delete( + path, data=data, format=format, content_type=content_type, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + + def options(self, path, data=None, format=None, content_type=None, + follow=False, **extra): + response = super(APIClient, self).options( + path, data=data, format=format, content_type=content_type, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + def logout(self): self._credentials = {} return super(APIClient, self).logout() diff --git a/tests/test_testing.py b/tests/test_testing.py index 9c472026a..9fd5966eb 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -5,6 +5,7 @@ from django.conf.urls import patterns, url from io import BytesIO from django.contrib.auth.models import User +from django.shortcuts import redirect from django.test import TestCase from rest_framework.decorators import api_view from rest_framework.response import Response @@ -28,10 +29,16 @@ def session_view(request): }) +@api_view(['GET', 'POST', 'PUT', 'PATCH', 'DELETE', 'OPTIONS']) +def redirect_view(request): + return redirect('/view/') + + urlpatterns = patterns( '', url(r'^view/$', view), url(r'^session-view/$', session_view), + url(r'^redirect-view/$', redirect_view), ) @@ -111,6 +118,46 @@ class TestAPITestClient(TestCase): response = self.client.get('/view/') self.assertEqual(response.data['auth'], b'') + def test_follow_redirect(self): + """ + Follow redirect by setting follow argument. + """ + response = self.client.get('/redirect-view/') + self.assertEqual(response.status_code, 302) + response = self.client.get('/redirect-view/', follow=True) + self.assertIsNotNone(response.redirect_chain) + self.assertEqual(response.status_code, 200) + + response = self.client.post('/redirect-view/') + self.assertEqual(response.status_code, 302) + response = self.client.post('/redirect-view/', follow=True) + self.assertIsNotNone(response.redirect_chain) + self.assertEqual(response.status_code, 200) + + response = self.client.put('/redirect-view/') + self.assertEqual(response.status_code, 302) + response = self.client.put('/redirect-view/', follow=True) + self.assertIsNotNone(response.redirect_chain) + self.assertEqual(response.status_code, 200) + + response = self.client.patch('/redirect-view/') + self.assertEqual(response.status_code, 302) + response = self.client.patch('/redirect-view/', follow=True) + self.assertIsNotNone(response.redirect_chain) + self.assertEqual(response.status_code, 200) + + response = self.client.delete('/redirect-view/') + self.assertEqual(response.status_code, 302) + response = self.client.delete('/redirect-view/', follow=True) + self.assertIsNotNone(response.redirect_chain) + self.assertEqual(response.status_code, 200) + + response = self.client.options('/redirect-view/') + self.assertEqual(response.status_code, 302) + response = self.client.options('/redirect-view/', follow=True) + self.assertIsNotNone(response.redirect_chain) + self.assertEqual(response.status_code, 200) + class TestAPIRequestFactory(TestCase): def test_csrf_exempt_by_default(self):