diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 8555c21be..e7a73adda 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -199,8 +199,6 @@ if 'patch' not in View.http_method_names: View.http_method_names = View.http_method_names + ['patch'] -# RequestFactory only provides `generic` from 1.5 onwards -from django.test.client import RequestFactory as DjangoRequestFactory from django.test.client import FakePayload try: @@ -211,24 +209,30 @@ except ImportError: from django.utils.encoding import smart_str as force_bytes_or_smart_bytes -class RequestFactory(DjangoRequestFactory): - def generic(self, method, path, - data='', content_type='application/octet-stream', **extra): - parsed = _urlparse(path) - data = force_bytes_or_smart_bytes(data, settings.DEFAULT_CHARSET) - r = { - 'PATH_INFO': self._get_path(parsed), - 'QUERY_STRING': force_text(parsed[4]), - 'REQUEST_METHOD': six.text_type(method), - } - if data: - r.update({ - 'CONTENT_LENGTH': len(data), - 'CONTENT_TYPE': six.text_type(content_type), - 'wsgi.input': FakePayload(data), - }) - r.update(extra) - return self.request(**r) +# RequestFactory only provides `generic` from 1.5 onwards +if django.VERSION >= (1, 5): + from django.test.client import RequestFactory +else: + from django.test.client import RequestFactory as DjangoRequestFactory + + class RequestFactory(DjangoRequestFactory): + def generic(self, method, path, + data='', content_type='application/octet-stream', **extra): + parsed = _urlparse(path) + data = force_bytes_or_smart_bytes(data, settings.DEFAULT_CHARSET) + r = { + 'PATH_INFO': self._get_path(parsed), + 'QUERY_STRING': force_text(parsed[4]), + 'REQUEST_METHOD': six.text_type(method), + } + if data: + r.update({ + 'CONTENT_LENGTH': len(data), + 'CONTENT_TYPE': six.text_type(content_type), + 'wsgi.input': FakePayload(data), + }) + r.update(extra) + return self.request(**r) # Markdown is optional diff --git a/tests/test_request.py b/tests/test_request.py index ebf94530b..03d9f8e49 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -24,6 +24,8 @@ from rest_framework.test import APIRequestFactory, APIClient from rest_framework.views import APIView from io import BytesIO import json +import django +import pytest factory = APIRequestFactory() @@ -275,3 +277,16 @@ class TestAuthSetter(TestCase): request = Request(factory.get('/')) request.auth = 'DUMMY' self.assertEqual(request.auth, 'DUMMY') + + +@pytest.mark.skipif(django.VERSION < (1, 7), + reason='secure argument is only available for django1.7+') +class TestSecure(TestCase): + + def test_default_secure_false(self): + request = Request(factory.get('/', secure=False)) + self.assertEqual(request.scheme, 'http') + + def test_default_secure_true(self): + request = Request(factory.get('/', secure=True)) + self.assertEqual(request.scheme, 'https')