Fix follow does not work on APIClient

Handle follow just like Django's Client.
This commit is contained in:
Jones Chi 2014-10-03 14:42:49 +08:00 committed by ys.chi
parent ad1497898b
commit 2dfe75c23a
2 changed files with 93 additions and 0 deletions

View File

@ -156,6 +156,52 @@ class APIClient(APIRequestFactory, DjangoClient):
kwargs.update(self._credentials) kwargs.update(self._credentials)
return super(APIClient, self).request(**kwargs) return super(APIClient, self).request(**kwargs)
def get(self, path, data=None, follow=False, **extra):
response = super(APIClient, self).get(path, data=data, **extra)
if follow:
response = self._handle_redirects(response, **extra)
return response
def post(self, path, data=None, format=None, content_type=None,
follow=False, **extra):
response = super(APIClient, self).post(
path, data=data, format=format, content_type=content_type, **extra)
if follow:
response = self._handle_redirects(response, **extra)
return response
def put(self, path, data=None, format=None, content_type=None,
follow=False, **extra):
response = super(APIClient, self).put(
path, data=data, format=format, content_type=content_type, **extra)
if follow:
response = self._handle_redirects(response, **extra)
return response
def patch(self, path, data=None, format=None, content_type=None,
follow=False, **extra):
response = super(APIClient, self).patch(
path, data=data, format=format, content_type=content_type, **extra)
if follow:
response = self._handle_redirects(response, **extra)
return response
def delete(self, path, data=None, format=None, content_type=None,
follow=False, **extra):
response = super(APIClient, self).delete(
path, data=data, format=format, content_type=content_type, **extra)
if follow:
response = self._handle_redirects(response, **extra)
return response
def options(self, path, data=None, format=None, content_type=None,
follow=False, **extra):
response = super(APIClient, self).options(
path, data=data, format=format, content_type=content_type, **extra)
if follow:
response = self._handle_redirects(response, **extra)
return response
def logout(self): def logout(self):
self._credentials = {} self._credentials = {}
return super(APIClient, self).logout() return super(APIClient, self).logout()

View File

@ -5,6 +5,7 @@ from django.conf.urls import patterns, url
from io import BytesIO from io import BytesIO
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.shortcuts import redirect
from django.test import TestCase from django.test import TestCase
from rest_framework.decorators import api_view from rest_framework.decorators import api_view
from rest_framework.response import Response from rest_framework.response import Response
@ -28,10 +29,16 @@ def session_view(request):
}) })
@api_view(['GET', 'POST', 'PUT', 'PATCH', 'DELETE', 'OPTIONS'])
def redirect_view(request):
return redirect('/view/')
urlpatterns = patterns( urlpatterns = patterns(
'', '',
url(r'^view/$', view), url(r'^view/$', view),
url(r'^session-view/$', session_view), url(r'^session-view/$', session_view),
url(r'^redirect-view/$', redirect_view),
) )
@ -111,6 +118,46 @@ class TestAPITestClient(TestCase):
response = self.client.get('/view/') response = self.client.get('/view/')
self.assertEqual(response.data['auth'], b'') self.assertEqual(response.data['auth'], b'')
def test_follow_redirect(self):
"""
Follow redirect by setting follow argument.
"""
response = self.client.get('/redirect-view/')
self.assertEqual(response.status_code, 302)
response = self.client.get('/redirect-view/', follow=True)
self.assertIsNotNone(response.redirect_chain)
self.assertEqual(response.status_code, 200)
response = self.client.post('/redirect-view/')
self.assertEqual(response.status_code, 302)
response = self.client.post('/redirect-view/', follow=True)
self.assertIsNotNone(response.redirect_chain)
self.assertEqual(response.status_code, 200)
response = self.client.put('/redirect-view/')
self.assertEqual(response.status_code, 302)
response = self.client.put('/redirect-view/', follow=True)
self.assertIsNotNone(response.redirect_chain)
self.assertEqual(response.status_code, 200)
response = self.client.patch('/redirect-view/')
self.assertEqual(response.status_code, 302)
response = self.client.patch('/redirect-view/', follow=True)
self.assertIsNotNone(response.redirect_chain)
self.assertEqual(response.status_code, 200)
response = self.client.delete('/redirect-view/')
self.assertEqual(response.status_code, 302)
response = self.client.delete('/redirect-view/', follow=True)
self.assertIsNotNone(response.redirect_chain)
self.assertEqual(response.status_code, 200)
response = self.client.options('/redirect-view/')
self.assertEqual(response.status_code, 302)
response = self.client.options('/redirect-view/', follow=True)
self.assertIsNotNone(response.redirect_chain)
self.assertEqual(response.status_code, 200)
class TestAPIRequestFactory(TestCase): class TestAPIRequestFactory(TestCase):
def test_csrf_exempt_by_default(self): def test_csrf_exempt_by_default(self):