mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-29 04:54:00 +03:00
Merge pull request #3054 from thekorn/compat_requestfactory_secure
Compat requestfactory secure
This commit is contained in:
commit
894aa9b47e
|
@ -199,8 +199,6 @@ if 'patch' not in View.http_method_names:
|
||||||
View.http_method_names = View.http_method_names + ['patch']
|
View.http_method_names = View.http_method_names + ['patch']
|
||||||
|
|
||||||
|
|
||||||
# RequestFactory only provides `generic` from 1.5 onwards
|
|
||||||
from django.test.client import RequestFactory as DjangoRequestFactory
|
|
||||||
from django.test.client import FakePayload
|
from django.test.client import FakePayload
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -211,24 +209,30 @@ except ImportError:
|
||||||
from django.utils.encoding import smart_str as force_bytes_or_smart_bytes
|
from django.utils.encoding import smart_str as force_bytes_or_smart_bytes
|
||||||
|
|
||||||
|
|
||||||
class RequestFactory(DjangoRequestFactory):
|
# RequestFactory only provides `generic` from 1.5 onwards
|
||||||
def generic(self, method, path,
|
if django.VERSION >= (1, 5):
|
||||||
data='', content_type='application/octet-stream', **extra):
|
from django.test.client import RequestFactory
|
||||||
parsed = _urlparse(path)
|
else:
|
||||||
data = force_bytes_or_smart_bytes(data, settings.DEFAULT_CHARSET)
|
from django.test.client import RequestFactory as DjangoRequestFactory
|
||||||
r = {
|
|
||||||
'PATH_INFO': self._get_path(parsed),
|
class RequestFactory(DjangoRequestFactory):
|
||||||
'QUERY_STRING': force_text(parsed[4]),
|
def generic(self, method, path,
|
||||||
'REQUEST_METHOD': six.text_type(method),
|
data='', content_type='application/octet-stream', **extra):
|
||||||
}
|
parsed = _urlparse(path)
|
||||||
if data:
|
data = force_bytes_or_smart_bytes(data, settings.DEFAULT_CHARSET)
|
||||||
r.update({
|
r = {
|
||||||
'CONTENT_LENGTH': len(data),
|
'PATH_INFO': self._get_path(parsed),
|
||||||
'CONTENT_TYPE': six.text_type(content_type),
|
'QUERY_STRING': force_text(parsed[4]),
|
||||||
'wsgi.input': FakePayload(data),
|
'REQUEST_METHOD': six.text_type(method),
|
||||||
})
|
}
|
||||||
r.update(extra)
|
if data:
|
||||||
return self.request(**r)
|
r.update({
|
||||||
|
'CONTENT_LENGTH': len(data),
|
||||||
|
'CONTENT_TYPE': six.text_type(content_type),
|
||||||
|
'wsgi.input': FakePayload(data),
|
||||||
|
})
|
||||||
|
r.update(extra)
|
||||||
|
return self.request(**r)
|
||||||
|
|
||||||
|
|
||||||
# Markdown is optional
|
# Markdown is optional
|
||||||
|
|
|
@ -24,6 +24,8 @@ from rest_framework.test import APIRequestFactory, APIClient
|
||||||
from rest_framework.views import APIView
|
from rest_framework.views import APIView
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
import json
|
import json
|
||||||
|
import django
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
factory = APIRequestFactory()
|
factory = APIRequestFactory()
|
||||||
|
@ -275,3 +277,16 @@ class TestAuthSetter(TestCase):
|
||||||
request = Request(factory.get('/'))
|
request = Request(factory.get('/'))
|
||||||
request.auth = 'DUMMY'
|
request.auth = 'DUMMY'
|
||||||
self.assertEqual(request.auth, 'DUMMY')
|
self.assertEqual(request.auth, 'DUMMY')
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(django.VERSION < (1, 7),
|
||||||
|
reason='secure argument is only available for django1.7+')
|
||||||
|
class TestSecure(TestCase):
|
||||||
|
|
||||||
|
def test_default_secure_false(self):
|
||||||
|
request = Request(factory.get('/', secure=False))
|
||||||
|
self.assertEqual(request.scheme, 'http')
|
||||||
|
|
||||||
|
def test_default_secure_true(self):
|
||||||
|
request = Request(factory.get('/', secure=True))
|
||||||
|
self.assertEqual(request.scheme, 'https')
|
||||||
|
|
Loading…
Reference in New Issue
Block a user