# Note that we import as `DjangoRequestFactory` and `DjangoClient` in order
# to make it harder for the user to import the wrong thing without realizing.
import io
from importlib import import_module

import django
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.core.handlers.wsgi import WSGIHandler
from django.test import override_settings, testcases
from django.test.client import Client as DjangoClient
from django.test.client import ClientHandler
from django.test.client import RequestFactory as DjangoRequestFactory
from django.utils.encoding import force_bytes
from django.utils.http import urlencode

from rest_framework.compat import coreapi, requests
from rest_framework.settings import api_settings


def force_authenticate(request, user=None, token=None):
    """
    Mark `request` as forcibly authenticated.

    Stores `user` and `token` on the request as `_force_auth_user` and
    `_force_auth_token`. Passing `None` for both clears any forced
    authentication previously set on the request.
    """
    # NOTE(review): these private attributes are presumably read by the
    # framework's request wrapper; the consumer is not visible in this module.
    request._force_auth_user = user
    request._force_auth_token = token


if requests is not None:
    class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict):
        """
        A urllib3 header dict exposing a `get_all(key, default)` method.
        """
        def get_all(self, key, default):
            # `default` is accepted for signature compatibility but unused;
            # the result is whatever `getheaders` returns for `key`.
            return self.getheaders(key)

    class MockOriginalResponse:
        """
        Minimal stand-in for the low-level response object handed to
        `urllib3.HTTPResponse` as `original_response`.

        Exposes the response headers as `msg` plus `isclosed()`/`close()`.
        """
        def __init__(self, headers):
            self.msg = HeaderDict(headers)
            self.closed = False

        def isclosed(self):
            return self.closed

        def close(self):
            self.closed = True

    class DjangoTestAdapter(requests.adapters.HTTPAdapter):
        """
        A transport adapter for `requests`, that makes requests via the
        Django WSGI app, rather than making actual HTTP requests over the network.
        """
        def __init__(self):
            # Deliberately does not call `HTTPAdapter.__init__`; only the WSGI
            # handler and a request factory are set up here.
            self.app = WSGIHandler()
            self.factory = DjangoRequestFactory()

        def get_environ(self, request):
            """
            Given a `requests.PreparedRequest` instance, return a WSGI environ dict.
            """
            method = request.method
            url = request.url
            kwargs = {}

            # Set request content, if any exists.
            if request.body is not None:
                if hasattr(request.body, 'read'):
                    # File-like body: read it fully into memory.
                    kwargs['data'] = request.body.read()
                else:
                    kwargs['data'] = request.body
            if 'content-type' in request.headers:
                kwargs['content_type'] = request.headers['content-type']

            # Set request headers.
            for key, value in request.headers.items():
                key = key.upper()
                # These three are not forwarded as `HTTP_*` keys; the content
                # type was already passed as `content_type` above.
                if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'):
                    continue
                kwargs['HTTP_%s' % key.replace('-', '_')] = value

            return self.factory.generic(method, url, **kwargs).environ

        def send(self, request, *args, **kwargs):
            """
            Make an outgoing request to the Django WSGI application.

            Extra positional and keyword arguments are accepted and ignored.
            Returns the `requests.Response` built by `build_response`.
            """
            raw_kwargs = {}

            def start_response(wsgi_status, wsgi_headers, exc_info=None):
                # WSGI callback: capture status line and headers as the
                # keyword arguments for `urllib3.HTTPResponse`.
                status, _, reason = wsgi_status.partition(' ')
                raw_kwargs['status'] = int(status)
                raw_kwargs['reason'] = reason
                raw_kwargs['headers'] = wsgi_headers
                raw_kwargs['version'] = 11
                raw_kwargs['preload_content'] = False
                raw_kwargs['original_response'] = MockOriginalResponse(wsgi_headers)

            # Make the outgoing request via WSGI.
            environ = self.get_environ(request)
            wsgi_response = self.app(environ, start_response)

            # Build the underlying urllib3.HTTPResponse
            raw_kwargs['body'] = io.BytesIO(b''.join(wsgi_response))
            raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs)

            # Build the requests.Response
            return self.build_response(request, raw)

        def close(self):
            # Nothing to release: no network connections are opened.
            pass

    class RequestsClient(requests.Session):
        """
        A `requests.Session` whose `http://` and `https://` traffic is routed
        through `DjangoTestAdapter` instead of the network.
        """
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            adapter = DjangoTestAdapter()
            self.mount('http://', adapter)
            self.mount('https://', adapter)

        def request(self, method, url, *args, **kwargs):
            """
            Issue a request, raising `ValueError` if `url` does not start
            with 'http' (i.e. is not fully qualified).
            """
            if not url.startswith('http'):
                raise ValueError('Missing "http:" or "https:". Use a fully qualified URL, eg "http://testserver%s"' % url)
            return super().request(method, url, *args, **kwargs)

else:
    def RequestsClient(*args, **kwargs):
        """Placeholder that raises because `requests` is not installed."""
        raise ImproperlyConfigured('requests must be installed in order to use RequestsClient.')


if coreapi is not None:
    class CoreAPIClient(coreapi.Client):
        """
        A `coreapi.Client` whose HTTP transport uses a `RequestsClient`
        session.
        """
        def __init__(self, *args, **kwargs):
            self._session = RequestsClient()
            # Any caller-supplied `transports` value is replaced.
            kwargs['transports'] = [coreapi.transports.HTTPTransport(session=self.session)]
            super().__init__(*args, **kwargs)

        @property
        def session(self):
            """The underlying `RequestsClient` session."""
            return self._session

else:
    def CoreAPIClient(*args, **kwargs):
        """Placeholder that raises because `coreapi` is not installed."""
        raise ImproperlyConfigured('coreapi must be installed in order to use CoreAPIClient.')


class APIRequestFactory(DjangoRequestFactory):
    """
    A request factory that encodes request bodies using the renderer classes
    configured in `TEST_REQUEST_RENDERER_CLASSES`.
    """
    renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES
    default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT

    def __init__(self, enforce_csrf_checks=False, **defaults):
        self.enforce_csrf_checks = enforce_csrf_checks
        # Map each renderer's `format` to its class, for lookup by format name.
        self.renderer_classes = {}
        for cls in self.renderer_classes_list:
            self.renderer_classes[cls.format] = cls
        super().__init__(**defaults)

    def _encode_data(self, data, format=None, content_type=None):
        """
        Encode the data returning a two tuple of (bytes, content_type)

        If `content_type` is given, `data` is treated as a raw payload and
        only coerced to bytes. Otherwise `data` is rendered with the renderer
        registered for `format` (falling back to `default_format`), and the
        content type is derived from that renderer.

        When `data` is None an empty string is returned together with the
        `content_type` that was passed in.

        `format` and `content_type` are mutually exclusive; passing both, or
        an unknown `format`, trips an assertion.
        """

        if data is None:
            return ('', content_type)

        assert format is None or content_type is None, (
            'You may not set both `format` and `content_type`.'
        )

        if content_type:
            # Content type specified explicitly, treat data as a raw bytestring
            ret = force_bytes(data, settings.DEFAULT_CHARSET)

        else:
            format = format or self.default_format

            assert format in self.renderer_classes, (
                "Invalid format '{}'. Available formats are {}. "
                "Set TEST_REQUEST_RENDERER_CLASSES to enable "
                "extra request formats.".format(
                    format,
                    ', '.join(["'" + fmt + "'" for fmt in self.renderer_classes])
                )
            )

            # Use format and render the data into a bytestring
            renderer = self.renderer_classes[format]()
            ret = renderer.render(data)

            # Determine the content-type header from the renderer
            content_type = renderer.media_type
            if renderer.charset:
                content_type = "{}; charset={}".format(
                    content_type, renderer.charset
                )

            # Coerce text to bytes if required.
            if isinstance(ret, str):
                ret = ret.encode(renderer.charset)

        return ret, content_type

    def get(self, path, data=None, **extra):
        """
        Build a GET request.

        `data`, if given, is URL-encoded into the query string. If `data` is
        empty and `path` itself contains a query string, that query string is
        used instead.
        """
        r = {
            'QUERY_STRING': urlencode(data or {}, doseq=True),
        }
        if not data and '?' in path:
            # Fix to support old behavior where you have the arguments in the
            # url. See #1461.
            # Split only on the first '?': a literal '?' is valid inside a
            # query string and must not truncate it.
            query_string = force_bytes(path.split('?', 1)[1])
            query_string = query_string.decode('iso-8859-1')
            r['QUERY_STRING'] = query_string
        r.update(extra)
        return self.generic('GET', path, **r)

    def post(self, path, data=None, format=None, content_type=None, **extra):
        """Build a POST request with `data` encoded via `_encode_data`."""
        data, content_type = self._encode_data(data, format, content_type)
        return self.generic('POST', path, data, content_type, **extra)

    def put(self, path, data=None, format=None, content_type=None, **extra):
        """Build a PUT request with `data` encoded via `_encode_data`."""
        data, content_type = self._encode_data(data, format, content_type)
        return self.generic('PUT', path, data, content_type, **extra)

    def patch(self, path, data=None, format=None, content_type=None, **extra):
        """Build a PATCH request with `data` encoded via `_encode_data`."""
        data, content_type = self._encode_data(data, format, content_type)
        return self.generic('PATCH', path, data, content_type, **extra)

    def delete(self, path, data=None, format=None, content_type=None, **extra):
        """Build a DELETE request with `data` encoded via `_encode_data`."""
        data, content_type = self._encode_data(data, format, content_type)
        return self.generic('DELETE', path, data, content_type, **extra)

    def options(self, path, data=None, format=None, content_type=None, **extra):
        """Build an OPTIONS request with `data` encoded via `_encode_data`."""
        data, content_type = self._encode_data(data, format, content_type)
        return self.generic('OPTIONS', path, data, content_type, **extra)

    def generic(self, method, path, data='',
                content_type='application/octet-stream', secure=False, **extra):
        """
        Build a request for an arbitrary HTTP `method`, always forwarding
        `CONTENT_TYPE` when a content type is known.
        """
        # Include the CONTENT_TYPE, regardless of whether or not data is empty.
        if content_type is not None:
            extra['CONTENT_TYPE'] = str(content_type)

        return super().generic(
            method, path, data, content_type, secure, **extra)

    def request(self, **kwargs):
        """
        Build the request and flag it so CSRF checks are skipped unless
        `enforce_csrf_checks` was requested.
        """
        request = super().request(**kwargs)
        request._dont_enforce_csrf_checks = not self.enforce_csrf_checks
        return request


class ForceAuthClientHandler(ClientHandler):
    """
    A patched version of ClientHandler that can enforce authentication
    on the outgoing requests.
    """

    def __init__(self, *args, **kwargs):
        self._force_user = None
        self._force_token = None
        super().__init__(*args, **kwargs)

    def get_response(self, request):
        # This is the simplest place we can hook into to patch the
        # request object.
        force_authenticate(request, self._force_user, self._force_token)
        return super().get_response(request)


class APIClient(APIRequestFactory, DjangoClient):
    """
    A test client combining `APIRequestFactory` request encoding with
    Django's test `Client`, plus persistent credentials and forced
    authentication.
    """
    def __init__(self, enforce_csrf_checks=False, **defaults):
        super().__init__(**defaults)
        self.handler = ForceAuthClientHandler(enforce_csrf_checks)
        self._credentials = {}

    def credentials(self, **kwargs):
        """
        Sets headers that will be used on every outgoing request.
        """
        self._credentials = kwargs

    def force_authenticate(self, user=None, token=None):
        """
        Forcibly authenticates outgoing requests with the given
        user and/or token.
        """
        self.handler._force_user = user
        self.handler._force_token = token
        if user is None:
            self.logout()  # Also clear any possible session info if required

    def request(self, **kwargs):
        # Ensure that any credentials set get added to every request.
        # Stored credentials take precedence over same-named per-call values.
        kwargs.update(self._credentials)
        return super().request(**kwargs)

    def get(self, path, data=None, follow=False, **extra):
        """Issue a GET request, optionally following redirects."""
        response = super().get(path, data=data, **extra)
        if follow:
            response = self._handle_redirects(response, data=data, **extra)
        return response

    def post(self, path, data=None, format=None, content_type=None,
             follow=False, **extra):
        """Issue a POST request, optionally following redirects."""
        response = super().post(
            path, data=data, format=format, content_type=content_type, **extra)
        if follow:
            response = self._handle_redirects(response, data=data, format=format, content_type=content_type, **extra)
        return response

    def put(self, path, data=None, format=None, content_type=None,
            follow=False, **extra):
        """Issue a PUT request, optionally following redirects."""
        response = super().put(
            path, data=data, format=format, content_type=content_type, **extra)
        if follow:
            response = self._handle_redirects(response, data=data, format=format, content_type=content_type, **extra)
        return response

    def patch(self, path, data=None, format=None, content_type=None,
              follow=False, **extra):
        """Issue a PATCH request, optionally following redirects."""
        response = super().patch(
            path, data=data, format=format, content_type=content_type, **extra)
        if follow:
            response = self._handle_redirects(response, data=data, format=format, content_type=content_type, **extra)
        return response

    def delete(self, path, data=None, format=None, content_type=None,
               follow=False, **extra):
        """Issue a DELETE request, optionally following redirects."""
        response = super().delete(
            path, data=data, format=format, content_type=content_type, **extra)
        if follow:
            response = self._handle_redirects(response, data=data, format=format, content_type=content_type, **extra)
        return response

    def options(self, path, data=None, format=None, content_type=None,
                follow=False, **extra):
        """Issue an OPTIONS request, optionally following redirects."""
        response = super().options(
            path, data=data, format=format, content_type=content_type, **extra)
        if follow:
            response = self._handle_redirects(response, data=data, format=format, content_type=content_type, **extra)
        return response

    def logout(self):
        """
        Clear stored credentials and forced authentication, then log out of
        the session if one exists.
        """
        self._credentials = {}

        # Also clear any `force_authenticate`
        self.handler._force_user = None
        self.handler._force_token = None

        if self.session:
            super().logout()


class APITransactionTestCase(testcases.TransactionTestCase):
    """`TransactionTestCase` whose `self.client` is an `APIClient`."""
    client_class = APIClient


class APITestCase(testcases.TestCase):
    """`TestCase` whose `self.client` is an `APIClient`."""
    client_class = APIClient


class APISimpleTestCase(testcases.SimpleTestCase):
    """`SimpleTestCase` whose `self.client` is an `APIClient`."""
    client_class = APIClient


class APILiveServerTestCase(testcases.LiveServerTestCase):
    """`LiveServerTestCase` whose `self.client` is an `APIClient`."""
    client_class = APIClient


def cleanup_url_patterns(cls):
    """
    Undo the `urlpatterns` swap made by `URLPatternsTestCase.setUpClass`.

    Restores the module's original `urlpatterns` if one was saved on `cls`,
    otherwise removes the attribute that was added to the module.
    """
    if hasattr(cls, '_module_urlpatterns'):
        cls._module.urlpatterns = cls._module_urlpatterns
    else:
        del cls._module.urlpatterns


class URLPatternsTestCase(testcases.SimpleTestCase):
    """
    Isolate URL patterns on a per-TestCase basis. For example,

    class ATestCase(URLPatternsTestCase):
        urlpatterns = [...]

        def test_something(self):
            ...

    class AnotherTestCase(URLPatternsTestCase):
        urlpatterns = [...]

        def test_something_else(self):
            ...
    """
    @classmethod
    def setUpClass(cls):
        # Get the module of the TestCase subclass
        cls._module = import_module(cls.__module__)
        cls._override = override_settings(ROOT_URLCONF=cls.__module__)

        # Remember any module-level `urlpatterns` so it can be restored.
        if hasattr(cls._module, 'urlpatterns'):
            cls._module_urlpatterns = cls._module.urlpatterns

        cls._module.urlpatterns = cls.urlpatterns

        cls._override.enable()

        if django.VERSION > (4, 0):
            # On Django 4.0+ teardown is registered as class cleanups.
            cls.addClassCleanup(cls._override.disable)
            cls.addClassCleanup(cleanup_url_patterns, cls)

        super().setUpClass()

    if django.VERSION < (4, 0):
        # Older Django versions: tear down explicitly in `tearDownClass`.
        @classmethod
        def tearDownClass(cls):
            super().tearDownClass()
            cls._override.disable()
            cleanup_url_patterns(cls)