mirror of
				https://github.com/encode/django-rest-framework.git
				synced 2025-11-04 01:47:59 +03:00 
			
		
		
		
	Merge pull request #3054 from thekorn/compat_requestfactory_secure
Compat requestfactory secure
This commit is contained in:
		
						commit
						894aa9b47e
					
				| 
						 | 
					@ -199,8 +199,6 @@ if 'patch' not in View.http_method_names:
 | 
				
			||||||
    View.http_method_names = View.http_method_names + ['patch']
 | 
					    View.http_method_names = View.http_method_names + ['patch']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# RequestFactory only provides `generic` from 1.5 onwards
 | 
					 | 
				
			||||||
from django.test.client import RequestFactory as DjangoRequestFactory
 | 
					 | 
				
			||||||
from django.test.client import FakePayload
 | 
					from django.test.client import FakePayload
 | 
				
			||||||
 | 
					
 | 
				
			||||||
try:
 | 
					try:
 | 
				
			||||||
| 
						 | 
					@ -211,24 +209,30 @@ except ImportError:
 | 
				
			||||||
    from django.utils.encoding import smart_str as force_bytes_or_smart_bytes
 | 
					    from django.utils.encoding import smart_str as force_bytes_or_smart_bytes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class RequestFactory(DjangoRequestFactory):
 | 
					# RequestFactory only provides `generic` from 1.5 onwards
 | 
				
			||||||
    def generic(self, method, path,
 | 
					if django.VERSION >= (1, 5):
 | 
				
			||||||
            data='', content_type='application/octet-stream', **extra):
 | 
					    from django.test.client import RequestFactory
 | 
				
			||||||
        parsed = _urlparse(path)
 | 
					else:
 | 
				
			||||||
        data = force_bytes_or_smart_bytes(data, settings.DEFAULT_CHARSET)
 | 
					    from django.test.client import RequestFactory as DjangoRequestFactory
 | 
				
			||||||
        r = {
 | 
					
 | 
				
			||||||
            'PATH_INFO': self._get_path(parsed),
 | 
					    class RequestFactory(DjangoRequestFactory):
 | 
				
			||||||
            'QUERY_STRING': force_text(parsed[4]),
 | 
					        def generic(self, method, path,
 | 
				
			||||||
            'REQUEST_METHOD': six.text_type(method),
 | 
					                data='', content_type='application/octet-stream', **extra):
 | 
				
			||||||
        }
 | 
					            parsed = _urlparse(path)
 | 
				
			||||||
        if data:
 | 
					            data = force_bytes_or_smart_bytes(data, settings.DEFAULT_CHARSET)
 | 
				
			||||||
            r.update({
 | 
					            r = {
 | 
				
			||||||
                'CONTENT_LENGTH': len(data),
 | 
					                'PATH_INFO': self._get_path(parsed),
 | 
				
			||||||
                'CONTENT_TYPE': six.text_type(content_type),
 | 
					                'QUERY_STRING': force_text(parsed[4]),
 | 
				
			||||||
                'wsgi.input': FakePayload(data),
 | 
					                'REQUEST_METHOD': six.text_type(method),
 | 
				
			||||||
            })
 | 
					            }
 | 
				
			||||||
        r.update(extra)
 | 
					            if data:
 | 
				
			||||||
        return self.request(**r)
 | 
					                r.update({
 | 
				
			||||||
 | 
					                    'CONTENT_LENGTH': len(data),
 | 
				
			||||||
 | 
					                    'CONTENT_TYPE': six.text_type(content_type),
 | 
				
			||||||
 | 
					                    'wsgi.input': FakePayload(data),
 | 
				
			||||||
 | 
					                })
 | 
				
			||||||
 | 
					            r.update(extra)
 | 
				
			||||||
 | 
					            return self.request(**r)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Markdown is optional
 | 
					# Markdown is optional
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -24,6 +24,8 @@ from rest_framework.test import APIRequestFactory, APIClient
 | 
				
			||||||
from rest_framework.views import APIView
 | 
					from rest_framework.views import APIView
 | 
				
			||||||
from io import BytesIO
 | 
					from io import BytesIO
 | 
				
			||||||
import json
 | 
					import json
 | 
				
			||||||
 | 
					import django
 | 
				
			||||||
 | 
					import pytest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
factory = APIRequestFactory()
 | 
					factory = APIRequestFactory()
 | 
				
			||||||
| 
						 | 
					@ -275,3 +277,16 @@ class TestAuthSetter(TestCase):
 | 
				
			||||||
        request = Request(factory.get('/'))
 | 
					        request = Request(factory.get('/'))
 | 
				
			||||||
        request.auth = 'DUMMY'
 | 
					        request.auth = 'DUMMY'
 | 
				
			||||||
        self.assertEqual(request.auth, 'DUMMY')
 | 
					        self.assertEqual(request.auth, 'DUMMY')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.skipif(django.VERSION < (1, 7),
 | 
				
			||||||
 | 
					                    reason='secure argument is only available for django1.7+')
 | 
				
			||||||
 | 
					class TestSecure(TestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_default_secure_false(self):
 | 
				
			||||||
 | 
					        request = Request(factory.get('/', secure=False))
 | 
				
			||||||
 | 
					        self.assertEqual(request.scheme, 'http')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_default_secure_true(self):
 | 
				
			||||||
 | 
					        request = Request(factory.get('/', secure=True))
 | 
				
			||||||
 | 
					        self.assertEqual(request.scheme, 'https')
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user