Fix follow does not work on APIClient

Handle follow just like Django's Client.
This commit is contained in:
Jones Chi 2014-10-03 14:42:49 +08:00 committed by ys.chi
parent ad1497898b
commit 2dfe75c23a
2 changed files with 93 additions and 0 deletions

View File

@ -156,6 +156,52 @@ class APIClient(APIRequestFactory, DjangoClient):
kwargs.update(self._credentials)
return super(APIClient, self).request(**kwargs)
def get(self, path, data=None, follow=False, **extra):
response = super(APIClient, self).get(path, data=data, **extra)
if follow:
response = self._handle_redirects(response, **extra)
return response
def post(self, path, data=None, format=None, content_type=None,
follow=False, **extra):
response = super(APIClient, self).post(
path, data=data, format=format, content_type=content_type, **extra)
if follow:
response = self._handle_redirects(response, **extra)
return response
def put(self, path, data=None, format=None, content_type=None,
follow=False, **extra):
response = super(APIClient, self).put(
path, data=data, format=format, content_type=content_type, **extra)
if follow:
response = self._handle_redirects(response, **extra)
return response
def patch(self, path, data=None, format=None, content_type=None,
follow=False, **extra):
response = super(APIClient, self).patch(
path, data=data, format=format, content_type=content_type, **extra)
if follow:
response = self._handle_redirects(response, **extra)
return response
def delete(self, path, data=None, format=None, content_type=None,
follow=False, **extra):
response = super(APIClient, self).delete(
path, data=data, format=format, content_type=content_type, **extra)
if follow:
response = self._handle_redirects(response, **extra)
return response
def options(self, path, data=None, format=None, content_type=None,
follow=False, **extra):
response = super(APIClient, self).options(
path, data=data, format=format, content_type=content_type, **extra)
if follow:
response = self._handle_redirects(response, **extra)
return response
def logout(self):
self._credentials = {}
return super(APIClient, self).logout()

View File

@ -5,6 +5,7 @@ from django.conf.urls import patterns, url
from io import BytesIO
from django.contrib.auth.models import User
from django.shortcuts import redirect
from django.test import TestCase
from rest_framework.decorators import api_view
from rest_framework.response import Response
@ -28,10 +29,16 @@ def session_view(request):
})
@api_view(['GET', 'POST', 'PUT', 'PATCH', 'DELETE', 'OPTIONS'])
def redirect_view(request):
return redirect('/view/')
urlpatterns = patterns(
'',
url(r'^view/$', view),
url(r'^session-view/$', session_view),
url(r'^redirect-view/$', redirect_view),
)
@ -111,6 +118,46 @@ class TestAPITestClient(TestCase):
response = self.client.get('/view/')
self.assertEqual(response.data['auth'], b'')
def test_follow_redirect(self):
"""
Follow redirect by setting follow argument.
"""
response = self.client.get('/redirect-view/')
self.assertEqual(response.status_code, 302)
response = self.client.get('/redirect-view/', follow=True)
self.assertIsNotNone(response.redirect_chain)
self.assertEqual(response.status_code, 200)
response = self.client.post('/redirect-view/')
self.assertEqual(response.status_code, 302)
response = self.client.post('/redirect-view/', follow=True)
self.assertIsNotNone(response.redirect_chain)
self.assertEqual(response.status_code, 200)
response = self.client.put('/redirect-view/')
self.assertEqual(response.status_code, 302)
response = self.client.put('/redirect-view/', follow=True)
self.assertIsNotNone(response.redirect_chain)
self.assertEqual(response.status_code, 200)
response = self.client.patch('/redirect-view/')
self.assertEqual(response.status_code, 302)
response = self.client.patch('/redirect-view/', follow=True)
self.assertIsNotNone(response.redirect_chain)
self.assertEqual(response.status_code, 200)
response = self.client.delete('/redirect-view/')
self.assertEqual(response.status_code, 302)
response = self.client.delete('/redirect-view/', follow=True)
self.assertIsNotNone(response.redirect_chain)
self.assertEqual(response.status_code, 200)
response = self.client.options('/redirect-view/')
self.assertEqual(response.status_code, 302)
response = self.client.options('/redirect-view/', follow=True)
self.assertIsNotNone(response.redirect_chain)
self.assertEqual(response.status_code, 200)
class TestAPIRequestFactory(TestCase):
def test_csrf_exempt_by_default(self):