mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-23 01:57:00 +03:00
397 lines
14 KiB
Python
397 lines
14 KiB
Python
# Note that we import as `DjangoRequestFactory` and `DjangoClient` in order
|
|
# to make it harder for the user to import the wrong thing without realizing.
|
|
import io
|
|
from importlib import import_module
|
|
|
|
from django.conf import settings
|
|
from django.core.exceptions import ImproperlyConfigured
|
|
from django.core.handlers.wsgi import WSGIHandler
|
|
from django.test import override_settings, testcases
|
|
from django.test.client import Client as DjangoClient
|
|
from django.test.client import ClientHandler
|
|
from django.test.client import RequestFactory as DjangoRequestFactory
|
|
from django.utils.encoding import force_bytes
|
|
from django.utils.http import urlencode
|
|
|
|
from rest_framework.compat import coreapi, requests
|
|
from rest_framework.settings import api_settings
|
|
|
|
|
|
def force_authenticate(request, user=None, token=None):
    """
    Attach forced-authentication values to ``request``.

    ``user`` and ``token`` are stored on the private attributes
    ``_force_auth_user`` and ``_force_auth_token`` of the request object;
    either may be left as ``None``.
    """
    forced_user, forced_token = user, token
    request._force_auth_user = forced_user
    request._force_auth_token = forced_token
|
|
|
|
|
|
if requests is not None:
    class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict):
        """
        urllib3 header dict with a ``get_all(key, default)`` method.

        ``default`` is accepted for signature compatibility but ignored; the
        result of ``getheaders(key)`` is returned unchanged.
        """

        def get_all(self, key, default):
            return self.getheaders(key)

    class MockOriginalResponse:
        """
        Minimal stand-in passed to urllib3's ``HTTPResponse`` as
        ``original_response`` by ``DjangoTestAdapter.send``.

        Exposes the response headers as ``msg`` and tracks whether it has
        been closed.
        """

        def __init__(self, headers):
            self.msg = HeaderDict(headers)
            self.closed = False

        def isclosed(self):
            return self.closed

        def close(self):
            self.closed = True

    class DjangoTestAdapter(requests.adapters.HTTPAdapter):
        """
        A transport adapter for `requests`, that makes requests via the
        Django WSGI app, rather than making actual HTTP requests over the network.
        """

        def __init__(self):
            self.app = WSGIHandler()
            self.factory = DjangoRequestFactory()

        def get_environ(self, request):
            """
            Given a `requests.PreparedRequest` instance, return a WSGI environ dict.
            """
            method = request.method
            url = request.url
            kwargs = {}

            # Set request content, if any exists.
            if request.body is not None:
                if hasattr(request.body, 'read'):
                    kwargs['data'] = request.body.read()
                else:
                    kwargs['data'] = request.body
            if 'content-type' in request.headers:
                kwargs['content_type'] = request.headers['content-type']

            # Set request headers. Content-Type is passed via `content_type`
            # above; Connection and Content-Length are not forwarded.
            for key, value in request.headers.items():
                key = key.upper()
                if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'):
                    continue
                kwargs['HTTP_%s' % key.replace('-', '_')] = value

            return self.factory.generic(method, url, **kwargs).environ

        def send(self, request, *args, **kwargs):
            """
            Make an outgoing request to the Django WSGI application.

            The WSGI response body is read fully into memory and wrapped in a
            urllib3 ``HTTPResponse``, which is then turned into a
            ``requests.Response`` by ``build_response``.
            """
            raw_kwargs = {}

            def start_response(wsgi_status, wsgi_headers, exc_info=None):
                # PEP 3333 requires start_response to accept an optional
                # `exc_info` argument. Nothing has been "sent" at this point
                # (the body is only collected after the app returns), so a
                # repeat call simply replaces the recorded status and headers.
                status, _, reason = wsgi_status.partition(' ')
                raw_kwargs['status'] = int(status)
                raw_kwargs['reason'] = reason
                raw_kwargs['headers'] = wsgi_headers
                raw_kwargs['version'] = 11
                raw_kwargs['preload_content'] = False
                raw_kwargs['original_response'] = MockOriginalResponse(wsgi_headers)

            # Make the outgoing request via WSGI.
            environ = self.get_environ(request)
            wsgi_response = self.app(environ, start_response)

            # Build the underlying urllib3.HTTPResponse
            raw_kwargs['body'] = io.BytesIO(b''.join(wsgi_response))
            raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs)

            # Build the requests.Response
            return self.build_response(request, raw)

        def close(self):
            # Requests are dispatched straight to the WSGI app; this adapter
            # opens no network connections, so there is nothing to close.
            pass

    class RequestsClient(requests.Session):
        """
        A ``requests.Session`` that mounts ``DjangoTestAdapter`` for both
        ``http://`` and ``https://`` so requests are handled by the Django
        WSGI app in-process.
        """

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            adapter = DjangoTestAdapter()
            self.mount('http://', adapter)
            self.mount('https://', adapter)

        def request(self, method, url, *args, **kwargs):
            """
            Send a request; ``url`` must be fully qualified (start with "http").

            Raises ``ValueError`` for relative URLs.
            """
            if not url.startswith('http'):
                raise ValueError('Missing "http:" or "https:". Use a fully qualified URL, eg "http://testserver%s"' % url)
            return super().request(method, url, *args, **kwargs)

else:
    def RequestsClient(*args, **kwargs):
        # Placeholder used when `requests` is unavailable.
        raise ImproperlyConfigured('requests must be installed in order to use RequestsClient.')
|
|
|
|
|
|
if coreapi is not None:
    class CoreAPIClient(coreapi.Client):
        """
        A ``coreapi.Client`` whose HTTP transport uses a ``RequestsClient``
        session, so requests go through that session rather than the network
        client coreapi would otherwise create.

        Any ``transports`` keyword passed by the caller is replaced.
        """

        def __init__(self, *args, **kwargs):
            self._session = RequestsClient()
            kwargs['transports'] = [coreapi.transports.HTTPTransport(session=self.session)]
            # __init__ must return None; just delegate to the parent.
            super().__init__(*args, **kwargs)

        @property
        def session(self):
            """The ``RequestsClient`` session used by this client's transport."""
            return self._session

else:
    def CoreAPIClient(*args, **kwargs):
        # Placeholder used when `coreapi` is unavailable.
        raise ImproperlyConfigured('coreapi must be installed in order to use CoreAPIClient.')
|
|
|
|
|
|
class APIRequestFactory(DjangoRequestFactory):
    """
    A Django ``RequestFactory`` that can encode request bodies with the
    renderers listed in ``TEST_REQUEST_RENDERER_CLASSES``, chosen per call via
    ``format=`` (defaulting to ``TEST_REQUEST_DEFAULT_FORMAT``).
    """
    renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES
    default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT

    def __init__(self, enforce_csrf_checks=False, **defaults):
        self.enforce_csrf_checks = enforce_csrf_checks
        # Map each renderer's `format` name to its class; later entries with
        # the same format win.
        self.renderer_classes = {
            cls.format: cls for cls in self.renderer_classes_list
        }
        super().__init__(**defaults)

    def _encode_data(self, data, format=None, content_type=None):
        """
        Encode the data returning a two tuple of (bytes, content_type)

        - ``data is None``: returns ``('', content_type)`` unchanged.
        - explicit ``content_type``: ``data`` is coerced to bytes as-is.
        - otherwise: ``data`` is rendered with the renderer registered for
          ``format`` (or ``default_format``), and the content type is taken
          from that renderer.

        ``format`` and ``content_type`` are mutually exclusive.
        """

        if data is None:
            return ('', content_type)

        assert format is None or content_type is None, (
            'You may not set both `format` and `content_type`.'
        )

        if content_type:
            # Content type specified explicitly, treat data as a raw bytestring
            ret = force_bytes(data, settings.DEFAULT_CHARSET)

        else:
            format = format or self.default_format

            assert format in self.renderer_classes, (
                "Invalid format '{}'. Available formats are {}. "
                "Set TEST_REQUEST_RENDERER_CLASSES to enable "
                "extra request formats.".format(
                    format,
                    ', '.join(["'" + fmt + "'" for fmt in self.renderer_classes])
                )
            )

            # Use format and render the data into a bytestring
            renderer = self.renderer_classes[format]()
            ret = renderer.render(data)

            # Determine the content-type header from the renderer. Only add a
            # charset parameter when the renderer declares one; otherwise the
            # header would contain a literal "charset=None".
            content_type = renderer.media_type
            if renderer.charset:
                content_type = "{}; charset={}".format(
                    content_type, renderer.charset
                )

            # Coerce text to bytes if required. Fall back to the project
            # charset when the renderer declares none (str.encode(None) fails).
            if isinstance(ret, str):
                ret = ret.encode(renderer.charset or settings.DEFAULT_CHARSET)

        return ret, content_type

    def get(self, path, data=None, **extra):
        """
        Build a GET request; ``data`` is urlencoded into the query string.

        If no ``data`` is given, a query string embedded in ``path`` is used.
        """
        r = {
            'QUERY_STRING': urlencode(data or {}, doseq=True),
        }
        if not data and '?' in path:
            # Fix to support old behavior where you have the arguments in the
            # url. See #1461.
            # Split on the first '?' only: further '?' characters are valid
            # inside a query string and must be preserved.
            query_string = force_bytes(path.partition('?')[2])
            query_string = query_string.decode('iso-8859-1')
            r['QUERY_STRING'] = query_string
        r.update(extra)
        return self.generic('GET', path, **r)

    def post(self, path, data=None, format=None, content_type=None, **extra):
        data, content_type = self._encode_data(data, format, content_type)
        return self.generic('POST', path, data, content_type, **extra)

    def put(self, path, data=None, format=None, content_type=None, **extra):
        data, content_type = self._encode_data(data, format, content_type)
        return self.generic('PUT', path, data, content_type, **extra)

    def patch(self, path, data=None, format=None, content_type=None, **extra):
        data, content_type = self._encode_data(data, format, content_type)
        return self.generic('PATCH', path, data, content_type, **extra)

    def delete(self, path, data=None, format=None, content_type=None, **extra):
        data, content_type = self._encode_data(data, format, content_type)
        return self.generic('DELETE', path, data, content_type, **extra)

    def options(self, path, data=None, format=None, content_type=None, **extra):
        data, content_type = self._encode_data(data, format, content_type)
        return self.generic('OPTIONS', path, data, content_type, **extra)

    def generic(self, method, path, data='',
                content_type='application/octet-stream', secure=False, **extra):
        """
        Delegate to the parent ``generic``, but always put ``CONTENT_TYPE``
        into ``extra`` when a content type is given.
        """
        # Include the CONTENT_TYPE, regardless of whether or not data is empty.
        if content_type is not None:
            extra['CONTENT_TYPE'] = str(content_type)

        return super().generic(
            method, path, data, content_type, secure, **extra)

    def request(self, **kwargs):
        """
        Call the parent ``request`` and set ``_dont_enforce_csrf_checks`` on
        the result according to ``enforce_csrf_checks``.
        """
        request = super().request(**kwargs)
        request._dont_enforce_csrf_checks = not self.enforce_csrf_checks
        return request
|
|
|
|
|
|
class ForceAuthClientHandler(ClientHandler):
    """
    ``ClientHandler`` variant that can force authentication on every request
    it handles.

    The user/token to force live in ``_force_user`` and ``_force_token``
    (both ``None`` by default) and are applied in ``get_response``.
    """

    def __init__(self, *args, **kwargs):
        # Start with nothing forced; callers set these attributes later.
        self._force_user, self._force_token = None, None
        super().__init__(*args, **kwargs)

    def get_response(self, request):
        # get_response is the simplest hook that hands us the request object,
        # so patch the forced credentials onto it here before dispatching.
        user, token = self._force_user, self._force_token
        force_authenticate(request, user, token)
        response = super().get_response(request)
        return response
|
|
|
|
|
|
class APIClient(APIRequestFactory, DjangoClient):
    """
    Test client combining ``APIRequestFactory`` body encoding (``format=``)
    with Django's test ``Client``, plus persistent ``credentials()`` headers
    and ``force_authenticate()``.

    With ``follow=True`` the original ``data`` (and, for body-carrying
    methods, ``format`` and ``content_type``) is forwarded to
    ``_handle_redirects()``. A redirect that re-issues the same method
    (e.g. HTTP 307/308) then resends the same body. Without this, it would
    resend an empty default body, which the encoder may reject.
    """

    def __init__(self, enforce_csrf_checks=False, **defaults):
        super().__init__(**defaults)
        # Use a handler that can force authentication on each request.
        self.handler = ForceAuthClientHandler(enforce_csrf_checks)
        self._credentials = {}

    def credentials(self, **kwargs):
        """
        Sets headers that will be used on every outgoing request.
        """
        self._credentials = kwargs

    def force_authenticate(self, user=None, token=None):
        """
        Forcibly authenticates outgoing requests with the given
        user and/or token.

        Passing ``user=None`` also logs out, clearing any session state.
        """
        self.handler._force_user = user
        self.handler._force_token = token
        if user is None:
            self.logout()  # Also clear any possible session info if required

    def request(self, **kwargs):
        # Ensure that any credentials set get added to every request.
        kwargs.update(self._credentials)
        return super().request(**kwargs)

    def get(self, path, data=None, follow=False, **extra):
        response = super().get(path, data=data, **extra)
        if follow:
            response = self._handle_redirects(response, data=data, **extra)
        return response

    def post(self, path, data=None, format=None, content_type=None,
             follow=False, **extra):
        response = super().post(
            path, data=data, format=format, content_type=content_type, **extra)
        if follow:
            response = self._handle_redirects(
                response, data=data, format=format, content_type=content_type, **extra)
        return response

    def put(self, path, data=None, format=None, content_type=None,
            follow=False, **extra):
        response = super().put(
            path, data=data, format=format, content_type=content_type, **extra)
        if follow:
            response = self._handle_redirects(
                response, data=data, format=format, content_type=content_type, **extra)
        return response

    def patch(self, path, data=None, format=None, content_type=None,
              follow=False, **extra):
        response = super().patch(
            path, data=data, format=format, content_type=content_type, **extra)
        if follow:
            response = self._handle_redirects(
                response, data=data, format=format, content_type=content_type, **extra)
        return response

    def delete(self, path, data=None, format=None, content_type=None,
               follow=False, **extra):
        response = super().delete(
            path, data=data, format=format, content_type=content_type, **extra)
        if follow:
            response = self._handle_redirects(
                response, data=data, format=format, content_type=content_type, **extra)
        return response

    def options(self, path, data=None, format=None, content_type=None,
                follow=False, **extra):
        response = super().options(
            path, data=data, format=format, content_type=content_type, **extra)
        if follow:
            response = self._handle_redirects(
                response, data=data, format=format, content_type=content_type, **extra)
        return response

    def logout(self):
        """
        Clear stored credentials and any forced authentication, then perform
        the parent logout if a session is present.
        """
        self._credentials = {}

        # Also clear any `force_authenticate`
        self.handler._force_user = None
        self.handler._force_token = None

        if self.session:
            super().logout()
|
|
|
|
|
|
class APITransactionTestCase(testcases.TransactionTestCase):
    """Django ``TransactionTestCase`` that uses ``APIClient`` as its client class."""

    client_class = APIClient
|
|
|
|
|
|
class APITestCase(testcases.TestCase):
    """Django ``TestCase`` that uses ``APIClient`` as its client class."""

    client_class = APIClient
|
|
|
|
|
|
class APISimpleTestCase(testcases.SimpleTestCase):
    """Django ``SimpleTestCase`` that uses ``APIClient`` as its client class."""

    client_class = APIClient
|
|
|
|
|
|
class APILiveServerTestCase(testcases.LiveServerTestCase):
    """Django ``LiveServerTestCase`` that uses ``APIClient`` as its client class."""

    client_class = APIClient
|
|
|
|
|
|
class URLPatternsTestCase(testcases.SimpleTestCase):
    """
    Isolate URL patterns on a per-TestCase basis. For example,

    class ATestCase(URLPatternsTestCase):
        urlpatterns = [...]

        def test_something(self):
            ...

    class AnotherTestCase(URLPatternsTestCase):
        urlpatterns = [...]

        def test_something_else(self):
            ...

    For the duration of the class, the test's own module is used as
    ROOT_URLCONF and its module-level ``urlpatterns`` is replaced by the
    class's ``urlpatterns``. The original state is restored afterwards, even
    if the parent ``setUpClass``/``tearDownClass`` raises.
    """

    @classmethod
    def setUpClass(cls):
        # Get the module of the TestCase subclass
        cls._module = import_module(cls.__module__)
        cls._override = override_settings(ROOT_URLCONF=cls.__module__)

        # Record on `cls` itself whether the module had its own urlpatterns,
        # so state inherited from a parent test class is never reused.
        cls._module_had_urlpatterns = hasattr(cls._module, 'urlpatterns')
        if cls._module_had_urlpatterns:
            cls._module_urlpatterns = cls._module.urlpatterns

        cls._module.urlpatterns = cls.urlpatterns

        cls._override.enable()
        try:
            super().setUpClass()
        except BaseException:
            # unittest does not call tearDownClass() when setUpClass() fails,
            # so undo our changes here before propagating the error.
            cls._override.disable()
            cls._restore_module_urlpatterns()
            raise

    @classmethod
    def tearDownClass(cls):
        try:
            super().tearDownClass()
        finally:
            cls._override.disable()
            cls._restore_module_urlpatterns()

    @classmethod
    def _restore_module_urlpatterns(cls):
        """Put back (or remove) the module-level ``urlpatterns`` swapped in by setUpClass."""
        if cls._module_had_urlpatterns:
            cls._module.urlpatterns = cls._module_urlpatterns
        else:
            del cls._module.urlpatterns
|