mirror of
				https://github.com/encode/django-rest-framework.git
				synced 2025-10-31 16:07:38 +03:00 
			
		
		
		
	* Fixed regression so that tests using `format` still work. The error only occurred in tests that return no content and use a renderer without a charset (e.g. JSONRenderer). * Fixed linting * Used early return as before * Moved the `ret` str check back to where it was
		
			
				
	
	
		
			404 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			404 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Note that we import as `DjangoRequestFactory` and `DjangoClient` in order
 | |
| # to make it harder for the user to import the wrong thing without realizing.
 | |
| import io
 | |
| from importlib import import_module
 | |
| 
 | |
| from django.conf import settings
 | |
| from django.core.exceptions import ImproperlyConfigured
 | |
| from django.core.handlers.wsgi import WSGIHandler
 | |
| from django.test import override_settings, testcases
 | |
| from django.test.client import Client as DjangoClient
 | |
| from django.test.client import ClientHandler
 | |
| from django.test.client import RequestFactory as DjangoRequestFactory
 | |
| from django.utils.encoding import force_bytes
 | |
| from django.utils.http import urlencode
 | |
| 
 | |
| from rest_framework.compat import coreapi, requests
 | |
| from rest_framework.settings import api_settings
 | |
| 
 | |
| 
 | |
def force_authenticate(request, user=None, token=None):
    """
    Attach forced-authentication values to ``request``.

    Stores ``user`` and ``token`` on the private attributes
    ``_force_auth_user`` and ``_force_auth_token``. Calling it with neither
    argument stores ``None`` in both. The attributes are consumed outside
    this module (presumably by DRF's ``Request`` to bypass the configured
    authenticators; confirm there).
    """
    request._force_auth_user, request._force_auth_token = user, token
 | |
| 
 | |
| 
 | |
if requests is not None:
    class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict):
        """
        urllib3 header mapping that also offers ``get_all``, the lookup
        ``http.cookiejar`` performs on a response's header object.
        """

        def get_all(self, key, default=None):
            """
            Return every value of header ``key``, or ``default`` if absent.

            Previously ``default`` was required but ignored (a missing header
            always produced ``[]``). It is now honoured, in line with the
            stdlib ``get_all(name, failobj=None)`` contract. ``http.cookiejar``
            passes ``[]`` as the default, so cookie extraction is unchanged.
            """
            # A present header always yields a non-empty list, so an empty
            # result means the header is missing.
            return self.getheaders(key) or default

    class MockOriginalResponse:
        """
        Minimal stand-in for the low-level response object that urllib3
        wraps: it exposes the headers as ``msg`` (a HeaderDict) and keeps
        an open/closed flag.
        """

        def __init__(self, headers):
            self.msg = HeaderDict(headers)
            self.closed = False

        def isclosed(self):
            return self.closed

        def close(self):
            self.closed = True

    class DjangoTestAdapter(requests.adapters.HTTPAdapter):
        """
        A transport adapter for `requests`, that makes requests via the
        Django WSGI app, rather than making actual HTTP requests over the network.
        """
        def __init__(self):
            # NOTE(review): HTTPAdapter.__init__ is deliberately not called,
            # presumably to avoid building a urllib3 connection pool that
            # `send` below never uses. Confirm before adding super().__init__().
            self.app = WSGIHandler()
            self.factory = DjangoRequestFactory()

        def get_environ(self, request):
            """
            Given a `requests.PreparedRequest` instance, return a WSGI environ dict.
            """
            method = request.method
            url = request.url
            kwargs = {}

            # Set request content, if any exists. File-like bodies are read
            # fully into memory.
            if request.body is not None:
                if hasattr(request.body, 'read'):
                    kwargs['data'] = request.body.read()
                else:
                    kwargs['data'] = request.body
            if 'content-type' in request.headers:
                kwargs['content_type'] = request.headers['content-type']

            # Set request headers as HTTP_* environ keys. Content-Type was
            # handled above; Connection and Content-Length are not forwarded.
            for key, value in request.headers.items():
                key = key.upper()
                if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'):
                    continue
                kwargs['HTTP_%s' % key.replace('-', '_')] = value

            return self.factory.generic(method, url, **kwargs).environ

        def send(self, request, *args, **kwargs):
            """
            Make an outgoing request to the Django WSGI application.

            Extra positional/keyword arguments are accepted for interface
            compatibility with HTTPAdapter.send and are ignored.
            """
            raw_kwargs = {}

            def start_response(wsgi_status, wsgi_headers, exc_info=None):
                # Collect the constructor arguments for urllib3's
                # HTTPResponse, e.g. '200 OK' -> status 200, reason 'OK'.
                status, _, reason = wsgi_status.partition(' ')
                raw_kwargs['status'] = int(status)
                raw_kwargs['reason'] = reason
                raw_kwargs['headers'] = wsgi_headers
                raw_kwargs['version'] = 11  # HTTP/1.1
                raw_kwargs['preload_content'] = False
                raw_kwargs['original_response'] = MockOriginalResponse(wsgi_headers)

            # Make the outgoing request via WSGI.
            environ = self.get_environ(request)
            wsgi_response = self.app(environ, start_response)

            # Build the underlying urllib3.HTTPResponse from the fully
            # buffered body.
            # NOTE(review): the WSGI iterable's close() is never called, so
            # close-time hooks on the response do not run. Calling it could
            # have side effects in tests (e.g. Django's request_finished
            # handlers), so confirm before changing.
            raw_kwargs['body'] = io.BytesIO(b''.join(wsgi_response))
            raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs)

            # Build the requests.Response
            return self.build_response(request, raw)

        def close(self):
            # Nothing to release: this adapter holds no network connections.
            pass

    class RequestsClient(requests.Session):
        """
        A `requests.Session` that sends http:// and https:// URLs to the
        Django application through DjangoTestAdapter.
        """
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            adapter = DjangoTestAdapter()
            self.mount('http://', adapter)
            self.mount('https://', adapter)

        def request(self, method, url, *args, **kwargs):
            # Only absolute URLs can match the adapter mounts above.
            if not url.startswith('http'):
                raise ValueError('Missing "http:" or "https:". Use a fully qualified URL, eg "http://testserver%s"' % url)
            return super().request(method, url, *args, **kwargs)

else:
    def RequestsClient(*args, **kwargs):
        """Placeholder that raises because `requests` is not installed."""
        raise ImproperlyConfigured('requests must be installed in order to use RequestsClient.')
 | |
| 
 | |
| 
 | |
if coreapi is None:
    def CoreAPIClient(*args, **kwargs):
        """Placeholder that raises because `coreapi` is not installed."""
        raise ImproperlyConfigured('coreapi must be installed in order to use CoreAPIClient.')

else:
    class CoreAPIClient(coreapi.Client):
        """
        A ``coreapi.Client`` whose only transport is an HTTP transport that
        sends requests through a ``RequestsClient`` session.

        Any ``transports`` keyword passed by the caller is overwritten.
        """

        def __init__(self, *args, **kwargs):
            self._session = RequestsClient()
            http_transport = coreapi.transports.HTTPTransport(session=self.session)
            kwargs['transports'] = [http_transport]
            super().__init__(*args, **kwargs)

        @property
        def session(self):
            """The RequestsClient used by this client's HTTP transport."""
            return self._session
 | |
| 
 | |
| 
 | |
class APIRequestFactory(DjangoRequestFactory):
    """
    Django ``RequestFactory`` that serializes request bodies using the
    renderers listed in ``TEST_REQUEST_RENDERER_CLASSES``.

    The body-carrying methods (``post``, ``put``, ``patch``, ``delete``,
    ``options``) accept either a ``format`` (a renderer's format name,
    defaulting to ``TEST_REQUEST_DEFAULT_FORMAT``) or an explicit
    ``content_type``, but not both.
    """
    renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES
    default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT

    def __init__(self, enforce_csrf_checks=False, **defaults):
        self.enforce_csrf_checks = enforce_csrf_checks
        # Index renderer classes by their ``format`` name; on duplicate
        # formats the later entry wins.
        self.renderer_classes = {
            cls.format: cls for cls in self.renderer_classes_list
        }
        super().__init__(**defaults)

    def _encode_data(self, data, format=None, content_type=None):
        """
        Encode the data returning a two tuple of (bytes, content_type)

        With an explicit ``content_type``, the data goes through
        ``_encode_json`` and is then coerced to bytes using
        ``DEFAULT_CHARSET``. Otherwise it is rendered with the renderer
        registered for ``format``, and the content type (plus charset, if
        any) is taken from that renderer.

        Raises AssertionError if both ``format`` and ``content_type`` are
        given, or if ``format`` has no registered renderer.
        """
        if data is None:
            return (b'', content_type)

        assert format is None or content_type is None, (
            'You may not set both `format` and `content_type`.'
        )

        if content_type:
            try:
                data = self._encode_json(data, content_type)
            except AttributeError:
                # e.g. the parent factory has no ``_encode_json``; fall back
                # to the data as given.
                pass

            # Content type specified explicitly, treat data as a raw bytestring
            ret = force_bytes(data, settings.DEFAULT_CHARSET)

        else:
            format = format or self.default_format

            assert format in self.renderer_classes, (
                "Invalid format '{}'. Available formats are {}. "
                "Set TEST_REQUEST_RENDERER_CLASSES to enable "
                "extra request formats.".format(
                    format,
                    ', '.join(["'" + fmt + "'" for fmt in self.renderer_classes])
                )
            )

            # Use format and render the data into a bytestring
            renderer = self.renderer_classes[format]()
            ret = renderer.render(data)

            # Determine the content-type header from the renderer
            content_type = renderer.media_type
            if renderer.charset:
                content_type = "{}; charset={}".format(
                    content_type, renderer.charset
                )

            # Coerce text to bytes if required. Renderers declaring no
            # charset fall back to DEFAULT_CHARSET instead of failing with
            # ``str.encode(None)`` -> TypeError.
            if isinstance(ret, str):
                ret = ret.encode(renderer.charset or settings.DEFAULT_CHARSET)

        return ret, content_type

    def get(self, path, data=None, **extra):
        """
        Build a GET request.

        ``data`` is urlencoded into the query string. When ``data`` is
        empty, a query string embedded in ``path`` (everything after the
        first ``?``) is used instead.
        """
        r = {
            'QUERY_STRING': urlencode(data or {}, doseq=True),
        }
        if not data and '?' in path:
            # Fix to support old behavior where you have the arguments in the
            # url. See #1461. Split on the first '?' only: '?' is a legal
            # character inside the query component itself.
            query_string = force_bytes(path.split('?', 1)[1])
            # WSGI environ strings carry bytes decoded as latin-1 (PEP 3333).
            query_string = query_string.decode('iso-8859-1')
            r['QUERY_STRING'] = query_string
        r.update(extra)
        return self.generic('GET', path, **r)

    def post(self, path, data=None, format=None, content_type=None, **extra):
        """Build a POST request whose body is encoded by ``_encode_data``."""
        data, content_type = self._encode_data(data, format, content_type)
        return self.generic('POST', path, data, content_type, **extra)

    def put(self, path, data=None, format=None, content_type=None, **extra):
        """Build a PUT request whose body is encoded by ``_encode_data``."""
        data, content_type = self._encode_data(data, format, content_type)
        return self.generic('PUT', path, data, content_type, **extra)

    def patch(self, path, data=None, format=None, content_type=None, **extra):
        """Build a PATCH request whose body is encoded by ``_encode_data``."""
        data, content_type = self._encode_data(data, format, content_type)
        return self.generic('PATCH', path, data, content_type, **extra)

    def delete(self, path, data=None, format=None, content_type=None, **extra):
        """Build a DELETE request whose body is encoded by ``_encode_data``."""
        data, content_type = self._encode_data(data, format, content_type)
        return self.generic('DELETE', path, data, content_type, **extra)

    def options(self, path, data=None, format=None, content_type=None, **extra):
        """Build an OPTIONS request whose body is encoded by ``_encode_data``."""
        data, content_type = self._encode_data(data, format, content_type)
        return self.generic('OPTIONS', path, data, content_type, **extra)

    def generic(self, method, path, data='',
                content_type='application/octet-stream', secure=False, **extra):
        # Include the CONTENT_TYPE, regardless of whether or not data is empty.
        if content_type is not None:
            extra['CONTENT_TYPE'] = str(content_type)

        return super().generic(
            method, path, data, content_type, secure, **extra)

    def request(self, **kwargs):
        # Flag the object returned by the parent so CSRF checks are skipped
        # unless this factory was created with enforce_csrf_checks=True.
        request = super().request(**kwargs)
        request._dont_enforce_csrf_checks = not self.enforce_csrf_checks
        return request
 | |
| 
 | |
| 
 | |
class ForceAuthClientHandler(ClientHandler):
    """
    Test-client handler that stamps any forced user/token onto each request
    before handing it to the normal ``ClientHandler`` processing.

    ``_force_user`` and ``_force_token`` start as ``None``. APIClient's
    ``force_authenticate``/``logout`` assign them.
    """

    def __init__(self, *args, **kwargs):
        self._force_user = self._force_token = None
        super().__init__(*args, **kwargs)

    def get_response(self, request):
        # get_response is the simplest hook where the request object is
        # available for patching before it is processed.
        force_authenticate(
            request,
            user=self._force_user,
            token=self._force_token,
        )
        return super().get_response(request)
 | |
| 
 | |
| 
 | |
class APIClient(APIRequestFactory, DjangoClient):
    """
    Django test client that encodes request bodies via APIRequestFactory,
    supports forced authentication, and can attach persistent credential
    values to every request.
    """

    def __init__(self, enforce_csrf_checks=False, **defaults):
        # Forward the flag so ``self.enforce_csrf_checks`` (assigned by
        # APIRequestFactory.__init__) reflects the caller's choice; before,
        # it was always left at False.
        super().__init__(enforce_csrf_checks=enforce_csrf_checks, **defaults)
        # Install a handler that receives the CSRF flag and applies any
        # forced authentication (replacing whatever handler the parent
        # constructor may have set).
        self.handler = ForceAuthClientHandler(enforce_csrf_checks)
        self._credentials = {}

    def credentials(self, **kwargs):
        """
        Sets headers that will be used on every outgoing request.

        The keyword arguments (e.g. ``HTTP_AUTHORIZATION='Token abc'``)
        replace any previously set credentials. Call with no arguments to
        clear them.
        """
        self._credentials = kwargs

    def force_authenticate(self, user=None, token=None):
        """
        Forcibly authenticates outgoing requests with the given
        user and/or token.

        Passing neither ``user`` nor ``token`` also calls ``logout()``.
        """
        self.handler._force_user = user
        self.handler._force_token = token
        if user is None and token is None:
            self.logout()  # Also clear any possible session info if required

    def request(self, **kwargs):
        # Ensure that any credentials set get added to every request; they
        # override per-request values with the same key.
        kwargs.update(self._credentials)
        return super().request(**kwargs)

    def _follow_if_requested(self, response, follow, **kwargs):
        """Return ``response``, or the result of following its redirects when ``follow``."""
        if follow:
            response = self._handle_redirects(response, **kwargs)
        return response

    def get(self, path, data=None, follow=False, **extra):
        response = super().get(path, data=data, **extra)
        return self._follow_if_requested(response, follow, data=data, **extra)

    def post(self, path, data=None, format=None, content_type=None,
             follow=False, **extra):
        response = super().post(
            path, data=data, format=format, content_type=content_type, **extra)
        return self._follow_if_requested(
            response, follow,
            data=data, format=format, content_type=content_type, **extra)

    def put(self, path, data=None, format=None, content_type=None,
            follow=False, **extra):
        response = super().put(
            path, data=data, format=format, content_type=content_type, **extra)
        return self._follow_if_requested(
            response, follow,
            data=data, format=format, content_type=content_type, **extra)

    def patch(self, path, data=None, format=None, content_type=None,
              follow=False, **extra):
        response = super().patch(
            path, data=data, format=format, content_type=content_type, **extra)
        return self._follow_if_requested(
            response, follow,
            data=data, format=format, content_type=content_type, **extra)

    def delete(self, path, data=None, format=None, content_type=None,
               follow=False, **extra):
        response = super().delete(
            path, data=data, format=format, content_type=content_type, **extra)
        return self._follow_if_requested(
            response, follow,
            data=data, format=format, content_type=content_type, **extra)

    def options(self, path, data=None, format=None, content_type=None,
                follow=False, **extra):
        response = super().options(
            path, data=data, format=format, content_type=content_type, **extra)
        return self._follow_if_requested(
            response, follow,
            data=data, format=format, content_type=content_type, **extra)

    def logout(self):
        """
        Clear stored credentials and any forced authentication, then call
        the parent ``logout`` when ``self.session`` is truthy.
        """
        self._credentials = {}

        # Also clear any `force_authenticate`
        self.handler._force_user = None
        self.handler._force_token = None

        if self.session:
            super().logout()
 | |
| 
 | |
| 
 | |
class APITransactionTestCase(testcases.TransactionTestCase):
    """Django ``TransactionTestCase`` configured to use ``APIClient`` as its client class."""

    client_class = APIClient
 | |
| 
 | |
| 
 | |
class APITestCase(testcases.TestCase):
    """Django ``TestCase`` configured to use ``APIClient`` as its client class."""

    client_class = APIClient
 | |
| 
 | |
| 
 | |
class APISimpleTestCase(testcases.SimpleTestCase):
    """Django ``SimpleTestCase`` configured to use ``APIClient`` as its client class."""

    client_class = APIClient
 | |
| 
 | |
| 
 | |
class APILiveServerTestCase(testcases.LiveServerTestCase):
    """Django ``LiveServerTestCase`` configured to use ``APIClient`` as its client class."""

    client_class = APIClient
 | |
| 
 | |
| 
 | |
def cleanup_url_patterns(cls):
    """
    Undo the ``urlpatterns`` override that ``URLPatternsTestCase.setUpClass``
    applied to ``cls._module``.

    If ``cls`` itself recorded the module's original patterns in
    ``_module_urlpatterns``, they are restored. Otherwise the module had
    none, so the injected ``urlpatterns`` attribute is deleted.
    """
    # Inspect only attributes defined directly on ``cls``. ``hasattr`` would
    # also find a ``_module_urlpatterns`` inherited from a parent test case
    # that ran earlier, and would then copy that parent's saved patterns
    # onto this (possibly different) module instead of removing them.
    if '_module_urlpatterns' in vars(cls):
        cls._module.urlpatterns = cls._module_urlpatterns
    else:
        del cls._module.urlpatterns
 | |
| 
 | |
| 
 | |
class URLPatternsTestCase(testcases.SimpleTestCase):
    """
    Isolate URL patterns on a per-TestCase basis.

    While the test case runs, the module that defines the subclass serves
    as ``ROOT_URLCONF`` and its ``urlpatterns`` are replaced by the class's
    ``urlpatterns``. Both changes are undone by class cleanups. For example::

        class ATestCase(URLPatternsTestCase):
            urlpatterns = [...]

            def test_something(self):
                ...

        class AnotherTestCase(URLPatternsTestCase):
            urlpatterns = [...]

            def test_something_else(self):
                ...
    """
    @classmethod
    def setUpClass(cls):
        module_name = cls.__module__

        # The TestCase subclass's own module doubles as the URLconf.
        cls._module = import_module(module_name)
        cls._override = override_settings(ROOT_URLCONF=module_name)

        # Remember any pre-existing patterns so cleanup can put them back.
        target = cls._module
        if hasattr(target, 'urlpatterns'):
            cls._module_urlpatterns = target.urlpatterns
        target.urlpatterns = cls.urlpatterns

        cls._override.enable()

        # Class cleanups run in reverse registration order: the module's
        # patterns are restored first, then the settings override is disabled.
        cls.addClassCleanup(cls._override.disable)
        cls.addClassCleanup(cleanup_url_patterns, cls)

        super().setUpClass()
 |