# Mirror of https://github.com/encode/django-rest-framework.git
# (synced 2025-11-04 09:57:55 +03:00; 400 lines, 14 KiB, Python)
# Note that we import as `DjangoRequestFactory` and `DjangoClient` in order
# to make it harder for the user to import the wrong thing without realizing.
import io
 | 
						|
from importlib import import_module
 | 
						|
 | 
						|
from django.conf import settings
 | 
						|
from django.core.exceptions import ImproperlyConfigured
 | 
						|
from django.core.handlers.wsgi import WSGIHandler
 | 
						|
from django.test import override_settings, testcases
 | 
						|
from django.test.client import Client as DjangoClient
 | 
						|
from django.test.client import ClientHandler
 | 
						|
from django.test.client import RequestFactory as DjangoRequestFactory
 | 
						|
from django.utils.encoding import force_bytes
 | 
						|
from django.utils.http import urlencode
 | 
						|
 | 
						|
from rest_framework.compat import coreapi, requests
 | 
						|
from rest_framework.settings import api_settings
 | 
						|
 | 
						|
 | 
						|
def force_authenticate(request, user=None, token=None):
    """
    Attach forced-authentication values to `request`.

    Stores `user` on `request._force_auth_user` and `token` on
    `request._force_auth_token` (both default to None).
    """
    forced = (('_force_auth_user', user), ('_force_auth_token', token))
    for attr_name, attr_value in forced:
        setattr(request, attr_name, attr_value)
 | 
						|
 | 
						|
 | 
						|
if requests is not None:
 | 
						|
    class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict):
        """
        urllib3 header dict that also provides `get_all`, the accessor
        expected of `email.message.Message`-style header objects (for
        example by `http.cookiejar`).
        """
        def get_all(self, key, default=None):
            """
            Return the list of values stored for header `key`.

            When the header is absent, return `default` instead. This mirrors
            `email.message.Message.get_all(name, failobj=None)`. `default` was
            previously required and silently ignored; it is now optional and
            honoured.
            """
            values = self.getheaders(key)
            # `getheaders` yields an empty list for a missing header.
            return values if values else default
 | 
						|
 | 
						|
    class MockOriginalResponse:
        """
        Lightweight object passed as `original_response` when the urllib3
        `HTTPResponse` is built in `DjangoTestAdapter.send`.

        It exposes the response headers as `msg` (a `HeaderDict`) and tracks
        whether `close()` has been called.
        """
        def __init__(self, headers):
            self.msg = HeaderDict(headers)
            self.closed = False

        def isclosed(self):
            """Report whether `close()` has been called."""
            return self.closed

        def close(self):
            """Mark this response as closed."""
            self.closed = True
 | 
						|
 | 
						|
    class DjangoTestAdapter(requests.adapters.HTTPAdapter):
        """
        A transport adapter for `requests` that hands each prepared request to
        Django's `WSGIHandler` in-process, rather than making actual HTTP
        requests over the network.
        """
        def __init__(self):
            # NOTE: HTTPAdapter.__init__ is not invoked; `send` and `close`
            # are overridden below and never use a connection pool.
            self.app = WSGIHandler()
            self.factory = DjangoRequestFactory()

        def get_environ(self, request):
            """
            Build a WSGI environ dict from a `requests.PreparedRequest`.

            The body is forwarded as `data`; a file-like body is read first.
            The Content-Type header is forwarded as `content_type`. Every
            other header except Connection and Content-Length becomes an
            `HTTP_*` key.
            """
            headers = request.headers
            factory_kwargs = {}

            # Forward the request body, if there is one.
            body = request.body
            if body is not None:
                factory_kwargs['data'] = body.read() if hasattr(body, 'read') else body
            if 'content-type' in headers:
                factory_kwargs['content_type'] = headers['content-type']

            # Translate the remaining headers into WSGI HTTP_* keys.
            excluded = ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE')
            for name, value in headers.items():
                upper_name = name.upper()
                if upper_name in excluded:
                    continue
                factory_kwargs['HTTP_' + upper_name.replace('-', '_')] = value

            built = self.factory.generic(request.method, request.url, **factory_kwargs)
            return built.environ

        def send(self, request, *args, **kwargs):
            """
            Run `request` through the Django WSGI application and return the
            result as a `requests.Response`.
            """
            response_kwargs = {}

            def start_response(wsgi_status, wsgi_headers, exc_info=None):
                # Capture status line and headers for the urllib3 response.
                code, _, reason = wsgi_status.partition(' ')
                response_kwargs.update(
                    status=int(code),
                    reason=reason,
                    headers=wsgi_headers,
                    version=11,
                    preload_content=False,
                    original_response=MockOriginalResponse(wsgi_headers),
                )

            # Make the outgoing request via WSGI.
            body_chunks = self.app(self.get_environ(request), start_response)

            # Wrap the collected body in the underlying urllib3.HTTPResponse.
            response_kwargs['body'] = io.BytesIO(b''.join(body_chunks))
            raw = requests.packages.urllib3.HTTPResponse(**response_kwargs)

            return self.build_response(request, raw)

        def close(self):
            """Nothing to release: no network connections are ever opened."""
 | 
						|
 | 
						|
    class RequestsClient(requests.Session):
        """
        A `requests.Session` whose http:// and https:// traffic is served by
        the Django application through `DjangoTestAdapter`.

        URLs must be fully qualified (e.g. "http://testserver/path").
        """
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            adapter = DjangoTestAdapter()
            # A single adapter instance handles both schemes.
            for prefix in ('http://', 'https://'):
                self.mount(prefix, adapter)

        def request(self, method, url, *args, **kwargs):
            if url.startswith('http'):
                return super().request(method, url, *args, **kwargs)
            raise ValueError('Missing "http:" or "https:". Use a fully qualified URL, eg "http://testserver%s"' % url)
 | 
						|
 | 
						|
else:
 | 
						|
    def RequestsClient(*args, **kwargs):
        """Stand-in used when `requests` is unavailable; always raises."""
        message = 'requests must be installed in order to use RequestsClient.'
        raise ImproperlyConfigured(message)
 | 
						|
 | 
						|
 | 
						|
if coreapi is not None:
 | 
						|
    class CoreAPIClient(coreapi.Client):
        """
        A `coreapi.Client` whose HTTP transport uses a `RequestsClient`
        session, so its requests are routed to the Django application.

        Any `transports` keyword argument is replaced.
        """
        def __init__(self, *args, **kwargs):
            self._session = RequestsClient()
            transport = coreapi.transports.HTTPTransport(session=self.session)
            kwargs['transports'] = [transport]
            super().__init__(*args, **kwargs)

        @property
        def session(self):
            """The `RequestsClient` instance created in `__init__`."""
            return self._session
 | 
						|
 | 
						|
else:
 | 
						|
    def CoreAPIClient(*args, **kwargs):
        """Stand-in used when `coreapi` is unavailable; always raises."""
        message = 'coreapi must be installed in order to use CoreAPIClient.'
        raise ImproperlyConfigured(message)
 | 
						|
 | 
						|
 | 
						|
class APIRequestFactory(DjangoRequestFactory):
    """
    A Django `RequestFactory` that encodes request bodies with the renderers
    configured in `TEST_REQUEST_RENDERER_CLASSES`.

    The body-carrying methods (post/put/patch/delete/options) accept either a
    renderer `format` or an explicit `content_type`, but not both.
    `enforce_csrf_checks` controls the `_dont_enforce_csrf_checks` flag that
    `request()` sets on what it builds.
    """
    renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES
    default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT

    def __init__(self, enforce_csrf_checks=False, **defaults):
        self.enforce_csrf_checks = enforce_csrf_checks
        # Index renderer classes by their `format` attribute; a later class
        # declaring the same format replaces an earlier one.
        self.renderer_classes = {cls.format: cls for cls in self.renderer_classes_list}
        super().__init__(**defaults)

    def _encode_data(self, data, format=None, content_type=None):
        """
        Encode the data returning a two tuple of (bytes, content_type).

        - `data is None`: returns ('', content_type) untouched.
        - explicit `content_type`: `data` is treated as raw and coerced to
          bytes using DEFAULT_CHARSET.
        - otherwise: `data` is rendered by the renderer registered for
          `format` (or `default_format`). The content type is built from that
          renderer's media type plus its charset, when it has one.

        Raises AssertionError (when assertions are enabled) if both `format`
        and `content_type` are given, or if `format` has no renderer.
        """
        if data is None:
            return ('', content_type)

        assert format is None or content_type is None, (
            'You may not set both `format` and `content_type`.'
        )

        if content_type:
            # Content type specified explicitly, treat data as a raw bytestring
            ret = force_bytes(data, settings.DEFAULT_CHARSET)

        else:
            format = format or self.default_format

            assert format in self.renderer_classes, (
                "Invalid format '{}'. Available formats are {}. "
                "Set TEST_REQUEST_RENDERER_CLASSES to enable "
                "extra request formats.".format(
                    format,
                    ', '.join(["'" + fmt + "'" for fmt in self.renderer_classes])
                )
            )

            # Use format and render the data into a bytestring
            renderer = self.renderer_classes[format]()
            ret = renderer.render(data)

            # Determine the content-type header from the renderer
            content_type = renderer.media_type
            if renderer.charset:
                content_type = "{}; charset={}".format(
                    content_type, renderer.charset
                )

            # Coerce text to bytes if required. A renderer that declares no
            # charset would otherwise hit `str.encode(None)` and raise
            # TypeError. Fall back to DEFAULT_CHARSET, as the explicit
            # content-type branch above does.
            if isinstance(ret, str):
                ret = ret.encode(renderer.charset or settings.DEFAULT_CHARSET)

        return ret, content_type

    def get(self, path, data=None, **extra):
        """
        Build a GET request with `data` urlencoded into the query string.

        When `data` is empty and `path` already contains a query string, that
        query string is used instead.
        """
        r = {
            'QUERY_STRING': urlencode(data or {}, doseq=True),
        }
        if not data and '?' in path:
            # Fix to support old behavior where you have the arguments in the
            # url. See #1461.
            # Split on the first '?' only: any later '?' belongs to the query.
            query_string = force_bytes(path.split('?', 1)[1])
            query_string = query_string.decode('iso-8859-1')
            r['QUERY_STRING'] = query_string
        r.update(extra)
        return self.generic('GET', path, **r)

    def post(self, path, data=None, format=None, content_type=None, **extra):
        """Build a POST request whose body is encoded by `_encode_data`."""
        data, content_type = self._encode_data(data, format, content_type)
        return self.generic('POST', path, data, content_type, **extra)

    def put(self, path, data=None, format=None, content_type=None, **extra):
        """Build a PUT request whose body is encoded by `_encode_data`."""
        data, content_type = self._encode_data(data, format, content_type)
        return self.generic('PUT', path, data, content_type, **extra)

    def patch(self, path, data=None, format=None, content_type=None, **extra):
        """Build a PATCH request whose body is encoded by `_encode_data`."""
        data, content_type = self._encode_data(data, format, content_type)
        return self.generic('PATCH', path, data, content_type, **extra)

    def delete(self, path, data=None, format=None, content_type=None, **extra):
        """Build a DELETE request whose body is encoded by `_encode_data`."""
        data, content_type = self._encode_data(data, format, content_type)
        return self.generic('DELETE', path, data, content_type, **extra)

    def options(self, path, data=None, format=None, content_type=None, **extra):
        """Build an OPTIONS request whose body is encoded by `_encode_data`."""
        data, content_type = self._encode_data(data, format, content_type)
        return self.generic('OPTIONS', path, data, content_type, **extra)

    def generic(self, method, path, data='',
                content_type='application/octet-stream', secure=False, **extra):
        # Include the CONTENT_TYPE, regardless of whether or not data is empty.
        if content_type is not None:
            extra['CONTENT_TYPE'] = str(content_type)

        return super().generic(
            method, path, data, content_type, secure, **extra)

    def request(self, **kwargs):
        request = super().request(**kwargs)
        # Set `_dont_enforce_csrf_checks` to the inverse of this factory's
        # `enforce_csrf_checks` setting.
        request._dont_enforce_csrf_checks = not self.enforce_csrf_checks
        return request
 | 
						|
 | 
						|
 | 
						|
class ForceAuthClientHandler(ClientHandler):
    """
    `ClientHandler` subclass that applies `force_authenticate` to every
    request it handles.

    It uses the user and token stored on `_force_user` / `_force_token`,
    which are None until set.
    """

    def __init__(self, *args, **kwargs):
        self._force_user = self._force_token = None
        super().__init__(*args, **kwargs)

    def get_response(self, request):
        # Patching the request here is the simplest available hook: the
        # request object is in hand before the normal handling runs.
        force_authenticate(request, self._force_user, self._force_token)
        return super().get_response(request)
 | 
						|
 | 
						|
 | 
						|
class APIClient(APIRequestFactory, DjangoClient):
    """
    Django test client combined with `APIRequestFactory` body encoding.

    Adds three things to the plain client:
    - stored credentials (`credentials()`) merged into every request;
    - forced authentication (`force_authenticate()`), applied by a
      `ForceAuthClientHandler`;
    - optional redirect following (`follow=True`) on each HTTP verb.
    """

    def __init__(self, enforce_csrf_checks=False, **defaults):
        super().__init__(**defaults)
        # Replace the handler with one that can apply forced authentication.
        self.handler = ForceAuthClientHandler(enforce_csrf_checks)
        self._credentials = {}

    def credentials(self, **kwargs):
        """
        Sets headers that will be used on every outgoing request.

        Each call replaces the previously stored set entirely.
        """
        self._credentials = kwargs

    def force_authenticate(self, user=None, token=None):
        """
        Forcibly authenticates outgoing requests with the given
        user and/or token.

        Passing neither clears forced authentication and calls `logout()`.
        """
        handler = self.handler
        handler._force_user = user
        handler._force_token = token
        if user is None and token is None:
            self.logout()  # Also clear any possible session info if required

    def request(self, **kwargs):
        # Stored credentials win over same-named keys supplied by the caller.
        kwargs.update(self._credentials)
        return super().request(**kwargs)

    def get(self, path, data=None, follow=False, **extra):
        response = super().get(path, data=data, **extra)
        if not follow:
            return response
        return self._handle_redirects(response, data=data, **extra)

    def post(self, path, data=None, format=None, content_type=None,
             follow=False, **extra):
        response = super().post(
            path, data=data, format=format, content_type=content_type, **extra)
        if not follow:
            return response
        return self._handle_redirects(
            response, data=data, format=format, content_type=content_type, **extra)

    def put(self, path, data=None, format=None, content_type=None,
            follow=False, **extra):
        response = super().put(
            path, data=data, format=format, content_type=content_type, **extra)
        if not follow:
            return response
        return self._handle_redirects(
            response, data=data, format=format, content_type=content_type, **extra)

    def patch(self, path, data=None, format=None, content_type=None,
              follow=False, **extra):
        response = super().patch(
            path, data=data, format=format, content_type=content_type, **extra)
        if not follow:
            return response
        return self._handle_redirects(
            response, data=data, format=format, content_type=content_type, **extra)

    def delete(self, path, data=None, format=None, content_type=None,
               follow=False, **extra):
        response = super().delete(
            path, data=data, format=format, content_type=content_type, **extra)
        if not follow:
            return response
        return self._handle_redirects(
            response, data=data, format=format, content_type=content_type, **extra)

    def options(self, path, data=None, format=None, content_type=None,
                follow=False, **extra):
        response = super().options(
            path, data=data, format=format, content_type=content_type, **extra)
        if not follow:
            return response
        return self._handle_redirects(
            response, data=data, format=format, content_type=content_type, **extra)

    def logout(self):
        """
        Clear stored credentials and forced authentication.

        Then defer to Django's `logout()` when `self.session` is truthy.
        """
        self._credentials = {}

        # Also clear any `force_authenticate`
        self.handler._force_user = self.handler._force_token = None

        if self.session:
            super().logout()
 | 
						|
 | 
						|
 | 
						|
class APITransactionTestCase(testcases.TransactionTestCase):
    """Django `TransactionTestCase` with `client_class` set to `APIClient`."""
    client_class = APIClient
 | 
						|
 | 
						|
 | 
						|
class APITestCase(testcases.TestCase):
    """Django `TestCase` with `client_class` set to `APIClient`."""
    client_class = APIClient
 | 
						|
 | 
						|
 | 
						|
class APISimpleTestCase(testcases.SimpleTestCase):
    """Django `SimpleTestCase` with `client_class` set to `APIClient`."""
    client_class = APIClient
 | 
						|
 | 
						|
 | 
						|
class APILiveServerTestCase(testcases.LiveServerTestCase):
    """Django `LiveServerTestCase` with `client_class` set to `APIClient`."""
    client_class = APIClient
 | 
						|
 | 
						|
 | 
						|
def cleanup_url_patterns(cls):
    """
    Undo the `urlpatterns` swap made by `URLPatternsTestCase.setUpClass`.

    If `cls` itself saved the module's original patterns in
    `_module_urlpatterns`, they are written back to `cls._module`.
    Otherwise the module had no `urlpatterns` before the swap, so the
    injected attribute is deleted.
    """
    # Check only cls's own namespace. setUpClass assigns `_module_urlpatterns`
    # directly on the class being set up. A value inherited from a parent
    # test case belongs to that parent's run, possibly for a different
    # module, and must not be written into this class's module.
    if '_module_urlpatterns' in vars(cls):
        cls._module.urlpatterns = cls._module_urlpatterns
    else:
        del cls._module.urlpatterns
 | 
						|
 | 
						|
 | 
						|
class URLPatternsTestCase(testcases.SimpleTestCase):
    """
    Isolate URL patterns on a per-TestCase basis. For example,

    class ATestCase(URLPatternsTestCase):
        urlpatterns = [...]

        def test_something(self):
            ...

    class AnotherTestCase(URLPatternsTestCase):
        urlpatterns = [...]

        def test_something_else(self):
            ...

    While the class runs, two things are swapped in:
    - the test module's `urlpatterns` attribute is replaced by the class's
      `urlpatterns`;
    - `ROOT_URLCONF` points at that module.

    Both are undone through class cleanups.
    """
    @classmethod
    def setUpClass(cls):
        # The TestCase subclass's own module doubles as the URLconf.
        module_name = cls.__module__
        cls._module = import_module(module_name)
        cls._override = override_settings(ROOT_URLCONF=module_name)

        # Keep any pre-existing module-level patterns so they can be restored.
        if hasattr(cls._module, 'urlpatterns'):
            cls._module_urlpatterns = cls._module.urlpatterns
        cls._module.urlpatterns = cls.urlpatterns

        cls._override.enable()

        # Class cleanups run in reverse registration order (LIFO).
        cls.addClassCleanup(cls._override.disable)
        cls.addClassCleanup(cleanup_url_patterns, cls)

        super().setUpClass()
 |