mirror of
				https://github.com/encode/django-rest-framework.git
				synced 2025-11-04 09:57:55 +03:00 
			
		
		
		
	Merge pull request #3054 from thekorn/compat_requestfactory_secure
Compat requestfactory secure
This commit is contained in:
		
						commit
						894aa9b47e
					
				| 
						 | 
				
			
			@ -199,8 +199,6 @@ if 'patch' not in View.http_method_names:
 | 
			
		|||
    View.http_method_names = View.http_method_names + ['patch']
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# RequestFactory only provides `generic` from 1.5 onwards
 | 
			
		||||
from django.test.client import RequestFactory as DjangoRequestFactory
 | 
			
		||||
from django.test.client import FakePayload
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
| 
						 | 
				
			
			@ -211,24 +209,30 @@ except ImportError:
 | 
			
		|||
    from django.utils.encoding import smart_str as force_bytes_or_smart_bytes
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RequestFactory(DjangoRequestFactory):
 | 
			
		||||
    def generic(self, method, path,
 | 
			
		||||
            data='', content_type='application/octet-stream', **extra):
 | 
			
		||||
        parsed = _urlparse(path)
 | 
			
		||||
        data = force_bytes_or_smart_bytes(data, settings.DEFAULT_CHARSET)
 | 
			
		||||
        r = {
 | 
			
		||||
            'PATH_INFO': self._get_path(parsed),
 | 
			
		||||
            'QUERY_STRING': force_text(parsed[4]),
 | 
			
		||||
            'REQUEST_METHOD': six.text_type(method),
 | 
			
		||||
        }
 | 
			
		||||
        if data:
 | 
			
		||||
            r.update({
 | 
			
		||||
                'CONTENT_LENGTH': len(data),
 | 
			
		||||
                'CONTENT_TYPE': six.text_type(content_type),
 | 
			
		||||
                'wsgi.input': FakePayload(data),
 | 
			
		||||
            })
 | 
			
		||||
        r.update(extra)
 | 
			
		||||
        return self.request(**r)
 | 
			
		||||
# RequestFactory only provides `generic` from 1.5 onwards
 | 
			
		||||
if django.VERSION >= (1, 5):
 | 
			
		||||
    from django.test.client import RequestFactory
 | 
			
		||||
else:
 | 
			
		||||
    from django.test.client import RequestFactory as DjangoRequestFactory
 | 
			
		||||
 | 
			
		||||
    class RequestFactory(DjangoRequestFactory):
 | 
			
		||||
        def generic(self, method, path,
 | 
			
		||||
                data='', content_type='application/octet-stream', **extra):
 | 
			
		||||
            parsed = _urlparse(path)
 | 
			
		||||
            data = force_bytes_or_smart_bytes(data, settings.DEFAULT_CHARSET)
 | 
			
		||||
            r = {
 | 
			
		||||
                'PATH_INFO': self._get_path(parsed),
 | 
			
		||||
                'QUERY_STRING': force_text(parsed[4]),
 | 
			
		||||
                'REQUEST_METHOD': six.text_type(method),
 | 
			
		||||
            }
 | 
			
		||||
            if data:
 | 
			
		||||
                r.update({
 | 
			
		||||
                    'CONTENT_LENGTH': len(data),
 | 
			
		||||
                    'CONTENT_TYPE': six.text_type(content_type),
 | 
			
		||||
                    'wsgi.input': FakePayload(data),
 | 
			
		||||
                })
 | 
			
		||||
            r.update(extra)
 | 
			
		||||
            return self.request(**r)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Markdown is optional
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -24,6 +24,8 @@ from rest_framework.test import APIRequestFactory, APIClient
 | 
			
		|||
from rest_framework.views import APIView
 | 
			
		||||
from io import BytesIO
 | 
			
		||||
import json
 | 
			
		||||
import django
 | 
			
		||||
import pytest
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
factory = APIRequestFactory()
 | 
			
		||||
| 
						 | 
				
			
			@ -275,3 +277,16 @@ class TestAuthSetter(TestCase):
 | 
			
		|||
        request = Request(factory.get('/'))
 | 
			
		||||
        request.auth = 'DUMMY'
 | 
			
		||||
        self.assertEqual(request.auth, 'DUMMY')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.skipif(django.VERSION < (1, 7),
 | 
			
		||||
                    reason='secure argument is only available for django1.7+')
 | 
			
		||||
class TestSecure(TestCase):
 | 
			
		||||
 | 
			
		||||
    def test_default_secure_false(self):
 | 
			
		||||
        request = Request(factory.get('/', secure=False))
 | 
			
		||||
        self.assertEqual(request.scheme, 'http')
 | 
			
		||||
 | 
			
		||||
    def test_default_secure_true(self):
 | 
			
		||||
        request = Request(factory.get('/', secure=True))
 | 
			
		||||
        self.assertEqual(request.scheme, 'https')
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user