mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-01-24 00:04:16 +03:00
Merge pull request #1922 from JonesChi/fix_follow
Fix follow does not work on get of APIRequestFactory
This commit is contained in:
commit
5e1ed0aa95
|
@ -156,6 +156,52 @@ class APIClient(APIRequestFactory, DjangoClient):
|
||||||
kwargs.update(self._credentials)
|
kwargs.update(self._credentials)
|
||||||
return super(APIClient, self).request(**kwargs)
|
return super(APIClient, self).request(**kwargs)
|
||||||
|
|
||||||
|
def get(self, path, data=None, follow=False, **extra):
|
||||||
|
response = super(APIClient, self).get(path, data=data, **extra)
|
||||||
|
if follow:
|
||||||
|
response = self._handle_redirects(response, **extra)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def post(self, path, data=None, format=None, content_type=None,
|
||||||
|
follow=False, **extra):
|
||||||
|
response = super(APIClient, self).post(
|
||||||
|
path, data=data, format=format, content_type=content_type, **extra)
|
||||||
|
if follow:
|
||||||
|
response = self._handle_redirects(response, **extra)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def put(self, path, data=None, format=None, content_type=None,
|
||||||
|
follow=False, **extra):
|
||||||
|
response = super(APIClient, self).put(
|
||||||
|
path, data=data, format=format, content_type=content_type, **extra)
|
||||||
|
if follow:
|
||||||
|
response = self._handle_redirects(response, **extra)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def patch(self, path, data=None, format=None, content_type=None,
|
||||||
|
follow=False, **extra):
|
||||||
|
response = super(APIClient, self).patch(
|
||||||
|
path, data=data, format=format, content_type=content_type, **extra)
|
||||||
|
if follow:
|
||||||
|
response = self._handle_redirects(response, **extra)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def delete(self, path, data=None, format=None, content_type=None,
|
||||||
|
follow=False, **extra):
|
||||||
|
response = super(APIClient, self).delete(
|
||||||
|
path, data=data, format=format, content_type=content_type, **extra)
|
||||||
|
if follow:
|
||||||
|
response = self._handle_redirects(response, **extra)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def options(self, path, data=None, format=None, content_type=None,
|
||||||
|
follow=False, **extra):
|
||||||
|
response = super(APIClient, self).options(
|
||||||
|
path, data=data, format=format, content_type=content_type, **extra)
|
||||||
|
if follow:
|
||||||
|
response = self._handle_redirects(response, **extra)
|
||||||
|
return response
|
||||||
|
|
||||||
def logout(self):
|
def logout(self):
|
||||||
self._credentials = {}
|
self._credentials = {}
|
||||||
return super(APIClient, self).logout()
|
return super(APIClient, self).logout()
|
||||||
|
|
|
@ -5,6 +5,7 @@ from django.conf.urls import patterns, url
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
from django.contrib.auth.models import User
|
from django.contrib.auth.models import User
|
||||||
|
from django.shortcuts import redirect
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from rest_framework.decorators import api_view
|
from rest_framework.decorators import api_view
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
|
@ -28,10 +29,16 @@ def session_view(request):
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@api_view(['GET', 'POST', 'PUT', 'PATCH', 'DELETE', 'OPTIONS'])
|
||||||
|
def redirect_view(request):
|
||||||
|
return redirect('/view/')
|
||||||
|
|
||||||
|
|
||||||
urlpatterns = patterns(
|
urlpatterns = patterns(
|
||||||
'',
|
'',
|
||||||
url(r'^view/$', view),
|
url(r'^view/$', view),
|
||||||
url(r'^session-view/$', session_view),
|
url(r'^session-view/$', session_view),
|
||||||
|
url(r'^redirect-view/$', redirect_view),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -111,6 +118,46 @@ class TestAPITestClient(TestCase):
|
||||||
response = self.client.get('/view/')
|
response = self.client.get('/view/')
|
||||||
self.assertEqual(response.data['auth'], b'')
|
self.assertEqual(response.data['auth'], b'')
|
||||||
|
|
||||||
|
def test_follow_redirect(self):
|
||||||
|
"""
|
||||||
|
Follow redirect by setting follow argument.
|
||||||
|
"""
|
||||||
|
response = self.client.get('/redirect-view/')
|
||||||
|
self.assertEqual(response.status_code, 302)
|
||||||
|
response = self.client.get('/redirect-view/', follow=True)
|
||||||
|
self.assertIsNotNone(response.redirect_chain)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
|
response = self.client.post('/redirect-view/')
|
||||||
|
self.assertEqual(response.status_code, 302)
|
||||||
|
response = self.client.post('/redirect-view/', follow=True)
|
||||||
|
self.assertIsNotNone(response.redirect_chain)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
|
response = self.client.put('/redirect-view/')
|
||||||
|
self.assertEqual(response.status_code, 302)
|
||||||
|
response = self.client.put('/redirect-view/', follow=True)
|
||||||
|
self.assertIsNotNone(response.redirect_chain)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
|
response = self.client.patch('/redirect-view/')
|
||||||
|
self.assertEqual(response.status_code, 302)
|
||||||
|
response = self.client.patch('/redirect-view/', follow=True)
|
||||||
|
self.assertIsNotNone(response.redirect_chain)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
|
response = self.client.delete('/redirect-view/')
|
||||||
|
self.assertEqual(response.status_code, 302)
|
||||||
|
response = self.client.delete('/redirect-view/', follow=True)
|
||||||
|
self.assertIsNotNone(response.redirect_chain)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
|
response = self.client.options('/redirect-view/')
|
||||||
|
self.assertEqual(response.status_code, 302)
|
||||||
|
response = self.client.options('/redirect-view/', follow=True)
|
||||||
|
self.assertIsNotNone(response.redirect_chain)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
|
|
||||||
class TestAPIRequestFactory(TestCase):
|
class TestAPIRequestFactory(TestCase):
|
||||||
def test_csrf_exempt_by_default(self):
|
def test_csrf_exempt_by_default(self):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user