Merge pull request #3054 from thekorn/compat_requestfactory_secure

Compat requestfactory secure
This commit is contained in:
Tom Christie 2015-06-22 15:33:24 +01:00
commit 894aa9b47e
2 changed files with 39 additions and 20 deletions

View File

@ -199,8 +199,6 @@ if 'patch' not in View.http_method_names:
View.http_method_names = View.http_method_names + ['patch'] View.http_method_names = View.http_method_names + ['patch']
# RequestFactory only provides `generic` from 1.5 onwards
from django.test.client import RequestFactory as DjangoRequestFactory
from django.test.client import FakePayload from django.test.client import FakePayload
try: try:
@ -211,24 +209,30 @@ except ImportError:
from django.utils.encoding import smart_str as force_bytes_or_smart_bytes from django.utils.encoding import smart_str as force_bytes_or_smart_bytes
class RequestFactory(DjangoRequestFactory): # RequestFactory only provides `generic` from 1.5 onwards
def generic(self, method, path, if django.VERSION >= (1, 5):
data='', content_type='application/octet-stream', **extra): from django.test.client import RequestFactory
parsed = _urlparse(path) else:
data = force_bytes_or_smart_bytes(data, settings.DEFAULT_CHARSET) from django.test.client import RequestFactory as DjangoRequestFactory
r = {
'PATH_INFO': self._get_path(parsed), class RequestFactory(DjangoRequestFactory):
'QUERY_STRING': force_text(parsed[4]), def generic(self, method, path,
'REQUEST_METHOD': six.text_type(method), data='', content_type='application/octet-stream', **extra):
} parsed = _urlparse(path)
if data: data = force_bytes_or_smart_bytes(data, settings.DEFAULT_CHARSET)
r.update({ r = {
'CONTENT_LENGTH': len(data), 'PATH_INFO': self._get_path(parsed),
'CONTENT_TYPE': six.text_type(content_type), 'QUERY_STRING': force_text(parsed[4]),
'wsgi.input': FakePayload(data), 'REQUEST_METHOD': six.text_type(method),
}) }
r.update(extra) if data:
return self.request(**r) r.update({
'CONTENT_LENGTH': len(data),
'CONTENT_TYPE': six.text_type(content_type),
'wsgi.input': FakePayload(data),
})
r.update(extra)
return self.request(**r)
# Markdown is optional # Markdown is optional

View File

@ -24,6 +24,8 @@ from rest_framework.test import APIRequestFactory, APIClient
from rest_framework.views import APIView from rest_framework.views import APIView
from io import BytesIO from io import BytesIO
import json import json
import django
import pytest
factory = APIRequestFactory() factory = APIRequestFactory()
@ -275,3 +277,16 @@ class TestAuthSetter(TestCase):
request = Request(factory.get('/')) request = Request(factory.get('/'))
request.auth = 'DUMMY' request.auth = 'DUMMY'
self.assertEqual(request.auth, 'DUMMY') self.assertEqual(request.auth, 'DUMMY')
@pytest.mark.skipif(django.VERSION < (1, 7),
reason='secure argument is only available for django1.7+')
class TestSecure(TestCase):
def test_default_secure_false(self):
request = Request(factory.get('/', secure=False))
self.assertEqual(request.scheme, 'http')
def test_default_secure_true(self):
request = Request(factory.get('/', secure=True))
self.assertEqual(request.scheme, 'https')