mirror of
				https://github.com/encode/django-rest-framework.git
				synced 2025-11-04 01:47:59 +03:00 
			
		
		
		
	Added test client support for HTTP 307 and 308 redirects (#8419)
* Add retain test data on follow=True * Simplify TestAPITestClient.test_follow_redirect Inspired from Django's ClientTest.test_follow_307_and_308_redirect * Add 307 308 follow redirect test
This commit is contained in:
		
							parent
							
								
									df4d16d2f1
								
							
						
					
					
						commit
						df92e57ad6
					
				| 
						 | 
				
			
			@ -288,7 +288,7 @@ class APIClient(APIRequestFactory, DjangoClient):
 | 
			
		|||
    def get(self, path, data=None, follow=False, **extra):
 | 
			
		||||
        response = super().get(path, data=data, **extra)
 | 
			
		||||
        if follow:
 | 
			
		||||
            response = self._handle_redirects(response, **extra)
 | 
			
		||||
            response = self._handle_redirects(response, data=data, **extra)
 | 
			
		||||
        return response
 | 
			
		||||
 | 
			
		||||
    def post(self, path, data=None, format=None, content_type=None,
 | 
			
		||||
| 
						 | 
				
			
			@ -296,7 +296,7 @@ class APIClient(APIRequestFactory, DjangoClient):
 | 
			
		|||
        response = super().post(
 | 
			
		||||
            path, data=data, format=format, content_type=content_type, **extra)
 | 
			
		||||
        if follow:
 | 
			
		||||
            response = self._handle_redirects(response, **extra)
 | 
			
		||||
            response = self._handle_redirects(response, data=data, format=format, content_type=content_type, **extra)
 | 
			
		||||
        return response
 | 
			
		||||
 | 
			
		||||
    def put(self, path, data=None, format=None, content_type=None,
 | 
			
		||||
| 
						 | 
				
			
			@ -304,7 +304,7 @@ class APIClient(APIRequestFactory, DjangoClient):
 | 
			
		|||
        response = super().put(
 | 
			
		||||
            path, data=data, format=format, content_type=content_type, **extra)
 | 
			
		||||
        if follow:
 | 
			
		||||
            response = self._handle_redirects(response, **extra)
 | 
			
		||||
            response = self._handle_redirects(response, data=data, format=format, content_type=content_type, **extra)
 | 
			
		||||
        return response
 | 
			
		||||
 | 
			
		||||
    def patch(self, path, data=None, format=None, content_type=None,
 | 
			
		||||
| 
						 | 
				
			
			@ -312,7 +312,7 @@ class APIClient(APIRequestFactory, DjangoClient):
 | 
			
		|||
        response = super().patch(
 | 
			
		||||
            path, data=data, format=format, content_type=content_type, **extra)
 | 
			
		||||
        if follow:
 | 
			
		||||
            response = self._handle_redirects(response, **extra)
 | 
			
		||||
            response = self._handle_redirects(response, data=data, format=format, content_type=content_type, **extra)
 | 
			
		||||
        return response
 | 
			
		||||
 | 
			
		||||
    def delete(self, path, data=None, format=None, content_type=None,
 | 
			
		||||
| 
						 | 
				
			
			@ -320,7 +320,7 @@ class APIClient(APIRequestFactory, DjangoClient):
 | 
			
		|||
        response = super().delete(
 | 
			
		||||
            path, data=data, format=format, content_type=content_type, **extra)
 | 
			
		||||
        if follow:
 | 
			
		||||
            response = self._handle_redirects(response, **extra)
 | 
			
		||||
            response = self._handle_redirects(response, data=data, format=format, content_type=content_type, **extra)
 | 
			
		||||
        return response
 | 
			
		||||
 | 
			
		||||
    def options(self, path, data=None, format=None, content_type=None,
 | 
			
		||||
| 
						 | 
				
			
			@ -328,7 +328,7 @@ class APIClient(APIRequestFactory, DjangoClient):
 | 
			
		|||
        response = super().options(
 | 
			
		||||
            path, data=data, format=format, content_type=content_type, **extra)
 | 
			
		||||
        if follow:
 | 
			
		||||
            response = self._handle_redirects(response, **extra)
 | 
			
		||||
            response = self._handle_redirects(response, data=data, format=format, content_type=content_type, **extra)
 | 
			
		||||
        return response
 | 
			
		||||
 | 
			
		||||
    def logout(self):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,7 +1,10 @@
 | 
			
		|||
import itertools
 | 
			
		||||
from io import BytesIO
 | 
			
		||||
from unittest.mock import patch
 | 
			
		||||
 | 
			
		||||
import django
 | 
			
		||||
from django.contrib.auth.models import User
 | 
			
		||||
from django.http import HttpResponseRedirect
 | 
			
		||||
from django.shortcuts import redirect
 | 
			
		||||
from django.test import TestCase, override_settings
 | 
			
		||||
from django.urls import path
 | 
			
		||||
| 
						 | 
				
			
			@ -14,7 +17,7 @@ from rest_framework.test import (
 | 
			
		|||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@api_view(['GET', 'POST'])
 | 
			
		||||
@api_view(['GET', 'POST', 'PUT', 'PATCH', 'DELETE', 'OPTIONS'])
 | 
			
		||||
def view(request):
 | 
			
		||||
    return Response({
 | 
			
		||||
        'auth': request.META.get('HTTP_AUTHORIZATION', b''),
 | 
			
		||||
| 
						 | 
				
			
			@ -36,6 +39,11 @@ def redirect_view(request):
 | 
			
		|||
    return redirect('/view/')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@api_view(['GET', 'POST', 'PUT', 'PATCH', 'DELETE', 'OPTIONS'])
 | 
			
		||||
def redirect_307_308_view(request, code):
 | 
			
		||||
    return HttpResponseRedirect('/view/', status=code)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BasicSerializer(serializers.Serializer):
 | 
			
		||||
    flag = fields.BooleanField(default=lambda: True)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -51,6 +59,7 @@ urlpatterns = [
 | 
			
		|||
    path('view/', view),
 | 
			
		||||
    path('session-view/', session_view),
 | 
			
		||||
    path('redirect-view/', redirect_view),
 | 
			
		||||
    path('redirect-view/<int:code>/', redirect_307_308_view),
 | 
			
		||||
    path('post-view/', post_view)
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -146,41 +155,32 @@ class TestAPITestClient(TestCase):
 | 
			
		|||
        """
 | 
			
		||||
        Follow redirect by setting follow argument.
 | 
			
		||||
        """
 | 
			
		||||
        response = self.client.get('/redirect-view/')
 | 
			
		||||
        assert response.status_code == 302
 | 
			
		||||
        response = self.client.get('/redirect-view/', follow=True)
 | 
			
		||||
        assert response.redirect_chain is not None
 | 
			
		||||
        assert response.status_code == 200
 | 
			
		||||
        for method in ('get', 'post', 'put', 'patch', 'delete', 'options'):
 | 
			
		||||
            with self.subTest(method=method):
 | 
			
		||||
                req_method = getattr(self.client, method)
 | 
			
		||||
                response = req_method('/redirect-view/')
 | 
			
		||||
                assert response.status_code == 302
 | 
			
		||||
                response = req_method('/redirect-view/', follow=True)
 | 
			
		||||
                assert response.redirect_chain is not None
 | 
			
		||||
                assert response.status_code == 200
 | 
			
		||||
 | 
			
		||||
        response = self.client.post('/redirect-view/')
 | 
			
		||||
        assert response.status_code == 302
 | 
			
		||||
        response = self.client.post('/redirect-view/', follow=True)
 | 
			
		||||
        assert response.redirect_chain is not None
 | 
			
		||||
        assert response.status_code == 200
 | 
			
		||||
 | 
			
		||||
        response = self.client.put('/redirect-view/')
 | 
			
		||||
        assert response.status_code == 302
 | 
			
		||||
        response = self.client.put('/redirect-view/', follow=True)
 | 
			
		||||
        assert response.redirect_chain is not None
 | 
			
		||||
        assert response.status_code == 200
 | 
			
		||||
 | 
			
		||||
        response = self.client.patch('/redirect-view/')
 | 
			
		||||
        assert response.status_code == 302
 | 
			
		||||
        response = self.client.patch('/redirect-view/', follow=True)
 | 
			
		||||
        assert response.redirect_chain is not None
 | 
			
		||||
        assert response.status_code == 200
 | 
			
		||||
 | 
			
		||||
        response = self.client.delete('/redirect-view/')
 | 
			
		||||
        assert response.status_code == 302
 | 
			
		||||
        response = self.client.delete('/redirect-view/', follow=True)
 | 
			
		||||
        assert response.redirect_chain is not None
 | 
			
		||||
        assert response.status_code == 200
 | 
			
		||||
 | 
			
		||||
        response = self.client.options('/redirect-view/')
 | 
			
		||||
        assert response.status_code == 302
 | 
			
		||||
        response = self.client.options('/redirect-view/', follow=True)
 | 
			
		||||
        assert response.redirect_chain is not None
 | 
			
		||||
        assert response.status_code == 200
 | 
			
		||||
    def test_follow_307_308_preserve_kwargs(self, *mocked_methods):
 | 
			
		||||
        """
 | 
			
		||||
        Follow redirect by setting follow argument, and make sure the following
 | 
			
		||||
        method called with appropriate kwargs.
 | 
			
		||||
        """
 | 
			
		||||
        methods = ('get', 'post', 'put', 'patch', 'delete', 'options')
 | 
			
		||||
        codes = (307, 308)
 | 
			
		||||
        for method, code in itertools.product(methods, codes):
 | 
			
		||||
            subtest_ctx = self.subTest(method=method, code=code)
 | 
			
		||||
            patch_ctx = patch.object(self.client, method, side_effect=getattr(self.client, method))
 | 
			
		||||
            with subtest_ctx, patch_ctx as req_method:
 | 
			
		||||
                kwargs = {'data': {'example': 'test'}, 'format': 'json'}
 | 
			
		||||
                response = req_method('/redirect-view/%s/' % code, follow=True, **kwargs)
 | 
			
		||||
                assert response.redirect_chain is not None
 | 
			
		||||
                assert response.status_code == 200
 | 
			
		||||
                for _, call_args, call_kwargs in req_method.mock_calls:
 | 
			
		||||
                    assert all(call_kwargs[k] == kwargs[k] for k in kwargs if k in call_kwargs)
 | 
			
		||||
 | 
			
		||||
    def test_invalid_multipart_data(self):
 | 
			
		||||
        """
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user