# -- coding: utf-8 -- # Note that we import as `DjangoRequestFactory` and `DjangoClient` in order # to make it harder for the user to import the wrong thing without realizing. from __future__ import unicode_literals import io from django.conf import settings from django.core.exceptions import ImproperlyConfigured from django.core.handlers.wsgi import WSGIHandler from django.test import testcases from django.test.client import Client as DjangoClient from django.test.client import RequestFactory as DjangoRequestFactory from django.test.client import ClientHandler from django.utils import six from django.utils.encoding import force_bytes from django.utils.http import urlencode from rest_framework.compat import coreapi, requests from rest_framework.settings import api_settings def force_authenticate(request, user=None, token=None): request._force_auth_user = user request._force_auth_token = token if requests is not None: class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict): def get_all(self, key, default): return self.getheaders(key) class MockOriginalResponse(object): def __init__(self, headers): self.msg = HeaderDict(headers) self.closed = False def isclosed(self): return self.closed def close(self): self.closed = True class DjangoTestAdapter(requests.adapters.HTTPAdapter): """ A transport adapter for `requests`, that makes requests via the Django WSGI app, rather than making actual HTTP requests over the network. """ def __init__(self): self.app = WSGIHandler() self.factory = DjangoRequestFactory() def get_environ(self, request): """ Given a `requests.PreparedRequest` instance, return a WSGI environ dict. """ method = request.method url = request.url kwargs = {} # Set request content, if any exists. if request.body is not None: if hasattr(request.body, 'read'): kwargs['data'] = request.body.read() else: kwargs['data'] = request.body if 'content-type' in request.headers: kwargs['content_type'] = request.headers['content-type'] # Set request headers. for key, value in request.headers.items(): key = key.upper() if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'): continue kwargs['HTTP_%s' % key.replace('-', '_')] = value return self.factory.generic(method, url, **kwargs).environ def send(self, request, *args, **kwargs): """ Make an outgoing request to the Django WSGI application. """ raw_kwargs = {} def start_response(wsgi_status, wsgi_headers): status, _, reason = wsgi_status.partition(' ') raw_kwargs['status'] = int(status) raw_kwargs['reason'] = reason raw_kwargs['headers'] = wsgi_headers raw_kwargs['version'] = 11 raw_kwargs['preload_content'] = False raw_kwargs['original_response'] = MockOriginalResponse(wsgi_headers) # Make the outgoing request via WSGI. environ = self.get_environ(request) wsgi_response = self.app(environ, start_response) # Build the underlying urllib3.HTTPResponse raw_kwargs['body'] = io.BytesIO(b''.join(wsgi_response)) raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs) # Build the requests.Response return self.build_response(request, raw) def close(self): pass class RequestsClient(requests.Session): def __init__(self, *args, **kwargs): super(RequestsClient, self).__init__(*args, **kwargs) adapter = DjangoTestAdapter() self.mount('http://', adapter) self.mount('https://', adapter) def request(self, method, url, *args, **kwargs): if not url.startswith('http'): raise ValueError('Missing "http:" or "https:". Use a fully qualified URL, eg "http://testserver%s"' % url) return super(RequestsClient, self).request(method, url, *args, **kwargs) else: def RequestsClient(*args, **kwargs): raise ImproperlyConfigured('requests must be installed in order to use RequestsClient.') if coreapi is not None: class CoreAPIClient(coreapi.Client): def __init__(self, *args, **kwargs): self._session = RequestsClient() kwargs['transports'] = [coreapi.transports.HTTPTransport(session=self.session)] return super(CoreAPIClient, self).__init__(*args, **kwargs) @property def session(self): return self._session else: def CoreAPIClient(*args, **kwargs): raise ImproperlyConfigured('coreapi must be installed in order to use CoreAPIClient.') class APIRequestFactory(DjangoRequestFactory): renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT def __init__(self, enforce_csrf_checks=False, **defaults): self.enforce_csrf_checks = enforce_csrf_checks self.renderer_classes = {} for cls in self.renderer_classes_list: self.renderer_classes[cls.format] = cls super(APIRequestFactory, self).__init__(**defaults) def _encode_data(self, data, format=None, content_type=None): """ Encode the data returning a two tuple of (bytes, content_type) """ if data is None: return ('', content_type) assert format is None or content_type is None, ( 'You may not set both `format` and `content_type`.' ) if content_type: # Content type specified explicitly, treat data as a raw bytestring ret = force_bytes(data, settings.DEFAULT_CHARSET) else: format = format or self.default_format assert format in self.renderer_classes, ( "Invalid format '{0}'. Available formats are {1}. " "Set TEST_REQUEST_RENDERER_CLASSES to enable " "extra request formats.".format( format, ', '.join(["'" + fmt + "'" for fmt in self.renderer_classes.keys()]) ) ) # Use format and render the data into a bytestring renderer = self.renderer_classes[format]() ret = renderer.render(data) # Determine the content-type header from the renderer content_type = "{0}; charset={1}".format( renderer.media_type, renderer.charset ) # Coerce text to bytes if required. if isinstance(ret, six.text_type): ret = bytes(ret.encode(renderer.charset)) return ret, content_type def get(self, path, data=None, **extra): r = { 'QUERY_STRING': urlencode(data or {}, doseq=True), } if not data and '?' in path: # Fix to support old behavior where you have the arguments in the # url. See #1461. query_string = force_bytes(path.split('?')[1]) if six.PY3: query_string = query_string.decode('iso-8859-1') r['QUERY_STRING'] = query_string r.update(extra) return self.generic('GET', path, **r) def post(self, path, data=None, format=None, content_type=None, **extra): data, content_type = self._encode_data(data, format, content_type) return self.generic('POST', path, data, content_type, **extra) def put(self, path, data=None, format=None, content_type=None, **extra): data, content_type = self._encode_data(data, format, content_type) return self.generic('PUT', path, data, content_type, **extra) def patch(self, path, data=None, format=None, content_type=None, **extra): data, content_type = self._encode_data(data, format, content_type) return self.generic('PATCH', path, data, content_type, **extra) def delete(self, path, data=None, format=None, content_type=None, **extra): data, content_type = self._encode_data(data, format, content_type) return self.generic('DELETE', path, data, content_type, **extra) def options(self, path, data=None, format=None, content_type=None, **extra): data, content_type = self._encode_data(data, format, content_type) return self.generic('OPTIONS', path, data, content_type, **extra) def generic(self, method, path, data='', content_type='application/octet-stream', secure=False, **extra): # Include the CONTENT_TYPE, regardless of whether or not data is empty. if content_type is not None: extra['CONTENT_TYPE'] = str(content_type) return super(APIRequestFactory, self).generic( method, path, data, content_type, secure, **extra) def request(self, **kwargs): request = super(APIRequestFactory, self).request(**kwargs) request._dont_enforce_csrf_checks = not self.enforce_csrf_checks return request class ForceAuthClientHandler(ClientHandler): """ A patched version of ClientHandler that can enforce authentication on the outgoing requests. """ def __init__(self, *args, **kwargs): self._force_user = None self._force_token = None super(ForceAuthClientHandler, self).__init__(*args, **kwargs) def get_response(self, request): # This is the simplest place we can hook into to patch the # request object. force_authenticate(request, self._force_user, self._force_token) return super(ForceAuthClientHandler, self).get_response(request) class APIClient(APIRequestFactory, DjangoClient): def __init__(self, enforce_csrf_checks=False, **defaults): super(APIClient, self).__init__(**defaults) self.handler = ForceAuthClientHandler(enforce_csrf_checks) self._credentials = {} def credentials(self, **kwargs): """ Sets headers that will be used on every outgoing request. """ self._credentials = kwargs def force_authenticate(self, user=None, token=None): """ Forcibly authenticates outgoing requests with the given user and/or token. """ self.handler._force_user = user self.handler._force_token = token if user is None: self.logout() # Also clear any possible session info if required def request(self, **kwargs): # Ensure that any credentials set get added to every request. kwargs.update(self._credentials) return super(APIClient, self).request(**kwargs) def get(self, path, data=None, follow=False, **extra): response = super(APIClient, self).get(path, data=data, **extra) if follow: response = self._handle_redirects(response, **extra) return response def post(self, path, data=None, format=None, content_type=None, follow=False, **extra): response = super(APIClient, self).post( path, data=data, format=format, content_type=content_type, **extra) if follow: response = self._handle_redirects(response, **extra) return response def put(self, path, data=None, format=None, content_type=None, follow=False, **extra): response = super(APIClient, self).put( path, data=data, format=format, content_type=content_type, **extra) if follow: response = self._handle_redirects(response, **extra) return response def patch(self, path, data=None, format=None, content_type=None, follow=False, **extra): response = super(APIClient, self).patch( path, data=data, format=format, content_type=content_type, **extra) if follow: response = self._handle_redirects(response, **extra) return response def delete(self, path, data=None, format=None, content_type=None, follow=False, **extra): response = super(APIClient, self).delete( path, data=data, format=format, content_type=content_type, **extra) if follow: response = self._handle_redirects(response, **extra) return response def options(self, path, data=None, format=None, content_type=None, follow=False, **extra): response = super(APIClient, self).options( path, data=data, format=format, content_type=content_type, **extra) if follow: response = self._handle_redirects(response, **extra) return response def logout(self): self._credentials = {} # Also clear any `force_authenticate` self.handler._force_user = None self.handler._force_token = None if self.session: super(APIClient, self).logout() class APITransactionTestCase(testcases.TransactionTestCase): client_class = APIClient class APITestCase(testcases.TestCase): client_class = APIClient class APISimpleTestCase(testcases.SimpleTestCase): client_class = APIClient class APILiveServerTestCase(testcases.LiveServerTestCase): client_class = APIClient