Added test client support for HTTP 307 and 308 redirects (#8419)

* Add retain test data on follow=True

* Simplify TestAPITestClient.test_follow_redirect

Inspired from Django's ClientTest.test_follow_307_and_308_redirect

* Add 307 308 follow redirect test
This commit is contained in:
hashlash 2022-03-24 16:57:42 +07:00 committed by GitHub
parent df4d16d2f1
commit df92e57ad6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 41 additions and 41 deletions

View File

@ -288,7 +288,7 @@ class APIClient(APIRequestFactory, DjangoClient):
def get(self, path, data=None, follow=False, **extra): def get(self, path, data=None, follow=False, **extra):
response = super().get(path, data=data, **extra) response = super().get(path, data=data, **extra)
if follow: if follow:
response = self._handle_redirects(response, **extra) response = self._handle_redirects(response, data=data, **extra)
return response return response
def post(self, path, data=None, format=None, content_type=None, def post(self, path, data=None, format=None, content_type=None,
@ -296,7 +296,7 @@ class APIClient(APIRequestFactory, DjangoClient):
response = super().post( response = super().post(
path, data=data, format=format, content_type=content_type, **extra) path, data=data, format=format, content_type=content_type, **extra)
if follow: if follow:
response = self._handle_redirects(response, **extra) response = self._handle_redirects(response, data=data, format=format, content_type=content_type, **extra)
return response return response
def put(self, path, data=None, format=None, content_type=None, def put(self, path, data=None, format=None, content_type=None,
@ -304,7 +304,7 @@ class APIClient(APIRequestFactory, DjangoClient):
response = super().put( response = super().put(
path, data=data, format=format, content_type=content_type, **extra) path, data=data, format=format, content_type=content_type, **extra)
if follow: if follow:
response = self._handle_redirects(response, **extra) response = self._handle_redirects(response, data=data, format=format, content_type=content_type, **extra)
return response return response
def patch(self, path, data=None, format=None, content_type=None, def patch(self, path, data=None, format=None, content_type=None,
@ -312,7 +312,7 @@ class APIClient(APIRequestFactory, DjangoClient):
response = super().patch( response = super().patch(
path, data=data, format=format, content_type=content_type, **extra) path, data=data, format=format, content_type=content_type, **extra)
if follow: if follow:
response = self._handle_redirects(response, **extra) response = self._handle_redirects(response, data=data, format=format, content_type=content_type, **extra)
return response return response
def delete(self, path, data=None, format=None, content_type=None, def delete(self, path, data=None, format=None, content_type=None,
@ -320,7 +320,7 @@ class APIClient(APIRequestFactory, DjangoClient):
response = super().delete( response = super().delete(
path, data=data, format=format, content_type=content_type, **extra) path, data=data, format=format, content_type=content_type, **extra)
if follow: if follow:
response = self._handle_redirects(response, **extra) response = self._handle_redirects(response, data=data, format=format, content_type=content_type, **extra)
return response return response
def options(self, path, data=None, format=None, content_type=None, def options(self, path, data=None, format=None, content_type=None,
@ -328,7 +328,7 @@ class APIClient(APIRequestFactory, DjangoClient):
response = super().options( response = super().options(
path, data=data, format=format, content_type=content_type, **extra) path, data=data, format=format, content_type=content_type, **extra)
if follow: if follow:
response = self._handle_redirects(response, **extra) response = self._handle_redirects(response, data=data, format=format, content_type=content_type, **extra)
return response return response
def logout(self): def logout(self):

View File

@ -1,7 +1,10 @@
import itertools
from io import BytesIO from io import BytesIO
from unittest.mock import patch
import django import django
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.http import HttpResponseRedirect
from django.shortcuts import redirect from django.shortcuts import redirect
from django.test import TestCase, override_settings from django.test import TestCase, override_settings
from django.urls import path from django.urls import path
@ -14,7 +17,7 @@ from rest_framework.test import (
) )
@api_view(['GET', 'POST']) @api_view(['GET', 'POST', 'PUT', 'PATCH', 'DELETE', 'OPTIONS'])
def view(request): def view(request):
return Response({ return Response({
'auth': request.META.get('HTTP_AUTHORIZATION', b''), 'auth': request.META.get('HTTP_AUTHORIZATION', b''),
@ -36,6 +39,11 @@ def redirect_view(request):
return redirect('/view/') return redirect('/view/')
@api_view(['GET', 'POST', 'PUT', 'PATCH', 'DELETE', 'OPTIONS'])
def redirect_307_308_view(request, code):
return HttpResponseRedirect('/view/', status=code)
class BasicSerializer(serializers.Serializer): class BasicSerializer(serializers.Serializer):
flag = fields.BooleanField(default=lambda: True) flag = fields.BooleanField(default=lambda: True)
@ -51,6 +59,7 @@ urlpatterns = [
path('view/', view), path('view/', view),
path('session-view/', session_view), path('session-view/', session_view),
path('redirect-view/', redirect_view), path('redirect-view/', redirect_view),
path('redirect-view/<int:code>/', redirect_307_308_view),
path('post-view/', post_view) path('post-view/', post_view)
] ]
@ -146,41 +155,32 @@ class TestAPITestClient(TestCase):
""" """
Follow redirect by setting follow argument. Follow redirect by setting follow argument.
""" """
response = self.client.get('/redirect-view/') for method in ('get', 'post', 'put', 'patch', 'delete', 'options'):
with self.subTest(method=method):
req_method = getattr(self.client, method)
response = req_method('/redirect-view/')
assert response.status_code == 302 assert response.status_code == 302
response = self.client.get('/redirect-view/', follow=True) response = req_method('/redirect-view/', follow=True)
assert response.redirect_chain is not None assert response.redirect_chain is not None
assert response.status_code == 200 assert response.status_code == 200
response = self.client.post('/redirect-view/') def test_follow_307_308_preserve_kwargs(self, *mocked_methods):
assert response.status_code == 302 """
response = self.client.post('/redirect-view/', follow=True) Follow redirect by setting follow argument, and make sure the following
assert response.redirect_chain is not None method called with appropriate kwargs.
assert response.status_code == 200 """
methods = ('get', 'post', 'put', 'patch', 'delete', 'options')
response = self.client.put('/redirect-view/') codes = (307, 308)
assert response.status_code == 302 for method, code in itertools.product(methods, codes):
response = self.client.put('/redirect-view/', follow=True) subtest_ctx = self.subTest(method=method, code=code)
assert response.redirect_chain is not None patch_ctx = patch.object(self.client, method, side_effect=getattr(self.client, method))
assert response.status_code == 200 with subtest_ctx, patch_ctx as req_method:
kwargs = {'data': {'example': 'test'}, 'format': 'json'}
response = self.client.patch('/redirect-view/') response = req_method('/redirect-view/%s/' % code, follow=True, **kwargs)
assert response.status_code == 302
response = self.client.patch('/redirect-view/', follow=True)
assert response.redirect_chain is not None
assert response.status_code == 200
response = self.client.delete('/redirect-view/')
assert response.status_code == 302
response = self.client.delete('/redirect-view/', follow=True)
assert response.redirect_chain is not None
assert response.status_code == 200
response = self.client.options('/redirect-view/')
assert response.status_code == 302
response = self.client.options('/redirect-view/', follow=True)
assert response.redirect_chain is not None assert response.redirect_chain is not None
assert response.status_code == 200 assert response.status_code == 200
for _, call_args, call_kwargs in req_method.mock_calls:
assert all(call_kwargs[k] == kwargs[k] for k in kwargs if k in call_kwargs)
def test_invalid_multipart_data(self): def test_invalid_multipart_data(self):
""" """