mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-26 03:23:59 +03:00
0dec36eb41
* Start test case * Added 'requests' test client * Address typos * Graceful fallback if requests is not installed. * Add cookie support * Tests for auth and CSRF * Py3 compat * py3 compat * py3 compat * Add get_requests_client * Added SchemaGenerator.should_include_link * add settings for html cutoff on related fields * Router doesn't work if prefix is blank, though project urls.py handles prefix * Fix Django 1.10 to-many deprecation * Add django.core.urlresolvers compatibility * Update django-filter & django-guardian * Check for empty router prefix; adjust URL accordingly. It's easiest to fix this issue after we have made the regex. To try to fix it before would require doing something different for List vs Detail, which means we'd have to know which type of url we're constructing before acting accordingly. * Fix misc django deprecations * Use TOC extension instead of header * Fix deprecations for py3k * Add py3k compatibility to is_simple_callable * Add is_simple_callable tests * Drop python 3.2 support (EOL, Dropped by Django) * schema_renderers= should *set* the renderers, not append to them. * API client (#4424) * Fix release notes * Add note about 'User account is disabled.' vs 'Unable to log in' * Clean up schema generation (#4527) * Handle multiple methods on custom action (#4529) * RequestsClient, CoreAPIClient * exclude_from_schema * Added 'get_schema_view()' shortcut * Added schema descriptions * Better descriptions for schemas * Add type annotation to schema generation * Coerce schema 'pk' in path to actual field name * Deprecations move into assertion errors * Use get_schema_view in tests * Update CoreJSON media type * Handle schema structure correctly when path prefixes exist. Closes #4401 * Add PendingDeprecation to Router schema generation. * Added SCHEMA_COERCE_PATH_PK and SCHEMA_COERCE_METHOD_NAMES * Renamed and documented 'get_schema_fields' interface.
361 lines
13 KiB
Python
361 lines
13 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
# Note that we import as `DjangoRequestFactory` and `DjangoClient` in order
|
|
# to make it harder for the user to import the wrong thing without realizing.
|
|
from __future__ import unicode_literals
|
|
|
|
import io
|
|
|
|
from django.conf import settings
|
|
from django.core.exceptions import ImproperlyConfigured
|
|
from django.core.handlers.wsgi import WSGIHandler
|
|
from django.test import testcases
|
|
from django.test.client import Client as DjangoClient
|
|
from django.test.client import RequestFactory as DjangoRequestFactory
|
|
from django.test.client import ClientHandler
|
|
from django.utils import six
|
|
from django.utils.encoding import force_bytes
|
|
from django.utils.http import urlencode
|
|
|
|
from rest_framework.compat import coreapi, requests
|
|
from rest_framework.settings import api_settings
|
|
|
|
|
|
def force_authenticate(request, user=None, token=None):
    """
    Tag `request` so that authentication is forced to `user` and `token`.

    The values are stored on the request's private `_force_auth_user` and
    `_force_auth_token` attributes; either may be None.
    """
    request._force_auth_user, request._force_auth_token = user, token
|
|
|
|
|
|
if requests is not None:
    class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict):
        """
        urllib3 header mapping extended with a `get_all(key, default)` method.

        Instances are exposed as `MockOriginalResponse.msg`.
        NOTE(review): presumably `get_all` is required by the cookie-extraction
        code in `requests` / `http.cookiejar` -- confirm.
        """
        def get_all(self, key, default):
            # NOTE(review): `default` is ignored; this relies on urllib3's
            # `getheaders()` returning an empty list for a missing key --
            # confirm against the installed urllib3 version.
            return self.getheaders(key)

    class MockOriginalResponse(object):
        """
        Minimal stand-in passed to urllib3's `HTTPResponse` as its
        `original_response`: it exposes the response headers as `msg` and
        tracks an open/closed flag.
        """
        def __init__(self, headers):
            self.msg = HeaderDict(headers)
            self.closed = False

        def isclosed(self):
            return self.closed

        def close(self):
            self.closed = True

    class DjangoTestAdapter(requests.adapters.HTTPAdapter):
        """
        A transport adapter for `requests`, that makes requests via the
        Django WSGI app, rather than making actual HTTP requests over the network.
        """
        def __init__(self):
            # NOTE(review): HTTPAdapter.__init__ is not called here; confirm
            # none of the inherited methods used (e.g. `build_response`)
            # depend on attributes it would set.
            self.app = WSGIHandler()
            self.factory = DjangoRequestFactory()

        def get_environ(self, request):
            """
            Given a `requests.PreparedRequest` instance, return a WSGI environ dict.

            The environ is built by Django's `RequestFactory.generic()` from the
            request's method, URL, body, content type and headers.
            """
            method = request.method
            url = request.url
            kwargs = {}

            # Set request content, if any exists. The body may be a file-like
            # object (read it fully) or an already-materialized string/bytes.
            if request.body is not None:
                if hasattr(request.body, 'read'):
                    kwargs['data'] = request.body.read()
                else:
                    kwargs['data'] = request.body
            if 'content-type' in request.headers:
                kwargs['content_type'] = request.headers['content-type']

            # Set request headers as `HTTP_*` environ keys. Connection and
            # Content-Length are skipped, and Content-Type was already passed
            # as `content_type` above.
            for key, value in request.headers.items():
                key = key.upper()
                if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'):
                    continue
                kwargs['HTTP_%s' % key.replace('-', '_')] = value

            return self.factory.generic(method, url, **kwargs).environ

        def send(self, request, *args, **kwargs):
            """
            Make an outgoing request to the Django WSGI application.

            The WSGI response is wrapped in a urllib3 `HTTPResponse` and then
            converted to a `requests.Response` via `build_response()`.
            """
            raw_kwargs = {}

            def start_response(wsgi_status, wsgi_headers, exc_info=None):
                # PEP 3333 allows the application to pass an optional third
                # `exc_info` argument, so accept it rather than failing with a
                # TypeError.
                # A status line such as "200 OK" is split into code and reason.
                status, _, reason = wsgi_status.partition(' ')
                raw_kwargs['status'] = int(status)
                raw_kwargs['reason'] = reason
                raw_kwargs['headers'] = wsgi_headers
                # Presumably httplib-style version numbering (11 == HTTP/1.1).
                raw_kwargs['version'] = 11
                raw_kwargs['preload_content'] = False
                raw_kwargs['original_response'] = MockOriginalResponse(wsgi_headers)

            # Make the outgoing request via WSGI.
            environ = self.get_environ(request)
            wsgi_response = self.app(environ, start_response)

            # Build the underlying urllib3.HTTPResponse from the fully-joined
            # body. NOTE(review): the WSGI iterable's close() is not called here.
            raw_kwargs['body'] = io.BytesIO(b''.join(wsgi_response))
            raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs)

            # Build the requests.Response
            return self.build_response(request, raw)

        def close(self):
            # No network resources are held, so there is nothing to release.
            pass

    class NoExternalRequestsAdapter(requests.adapters.HTTPAdapter):
        """
        Adapter that refuses every request by raising RuntimeError.
        (`RequestsClient` below does not mount it.)
        """
        def send(self, request, *args, **kwargs):
            msg = (
                'RequestsClient refusing to make an outgoing network request '
                'to "%s". Only "testserver" or hostnames in your ALLOWED_HOSTS '
                'setting are valid.' % request.url
            )
            raise RuntimeError(msg)

    class RequestsClient(requests.Session):
        """
        A `requests.Session` whose `http://` and `https://` traffic is routed
        through `DjangoTestAdapter` into the Django application.
        """
        def __init__(self, *args, **kwargs):
            super(RequestsClient, self).__init__(*args, **kwargs)
            adapter = DjangoTestAdapter()
            self.mount('http://', adapter)
            self.mount('https://', adapter)

        def request(self, method, url, *args, **kwargs):
            # Require a fully qualified URL so the request matches one of the
            # mounted scheme prefixes.
            if ':' not in url:
                raise ValueError('Missing "http:" or "https:". Use a fully qualified URL, eg "http://testserver%s"' % url)
            return super(RequestsClient, self).request(method, url, *args, **kwargs)

else:
    def RequestsClient(*args, **kwargs):
        raise ImproperlyConfigured('requests must be installed in order to use RequestsClient.')
|
|
|
|
|
|
if coreapi is not None:
    class CoreAPIClient(coreapi.Client):
        """
        A `coreapi.Client` whose single HTTP transport uses a `RequestsClient`
        session, so requests are dispatched to the Django application.

        If `requests` is not installed, constructing the session raises
        ImproperlyConfigured.
        """
        def __init__(self, *args, **kwargs):
            self._session = RequestsClient()
            # Any caller-supplied `transports` argument is replaced.
            kwargs['transports'] = [coreapi.transports.HTTPTransport(session=self.session)]
            super(CoreAPIClient, self).__init__(*args, **kwargs)

        @property
        def session(self):
            """The `RequestsClient` session used by this client's transport."""
            return self._session

else:
    def CoreAPIClient(*args, **kwargs):
        raise ImproperlyConfigured('coreapi must be installed in order to use CoreAPIClient.')
|
|
|
|
|
|
class APIRequestFactory(DjangoRequestFactory):
    """
    A Django `RequestFactory` that can encode request bodies with REST
    framework renderers, selected by `format` or overridden with an explicit
    `content_type`.
    """
    renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES
    default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT

    def __init__(self, enforce_csrf_checks=False, **defaults):
        self.enforce_csrf_checks = enforce_csrf_checks
        # Map each renderer's `format` to its class; later entries in the
        # settings list win if two renderers share a format.
        self.renderer_classes = {
            cls.format: cls for cls in self.renderer_classes_list
        }
        super(APIRequestFactory, self).__init__(**defaults)

    def _encode_data(self, data, format=None, content_type=None):
        """
        Encode the data returning a two tuple of (body, content_type).

        `body` is bytes, except when `data` is None, in which case it is the
        empty string and `content_type` is passed through unchanged.

        Raises AssertionError if both `format` and `content_type` are given,
        or if `format` names no configured renderer.
        """
        if data is None:
            return ('', content_type)

        assert format is None or content_type is None, (
            'You may not set both `format` and `content_type`.'
        )

        if content_type:
            # Content type specified explicitly, treat data as a raw bytestring
            ret = force_bytes(data, settings.DEFAULT_CHARSET)

        else:
            format = format or self.default_format

            assert format in self.renderer_classes, (
                "Invalid format '{0}'. Available formats are {1}. "
                "Set TEST_REQUEST_RENDERER_CLASSES to enable "
                "extra request formats.".format(
                    format,
                    ', '.join(["'" + fmt + "'" for fmt in self.renderer_classes.keys()])
                )
            )

            # Use format and render the data into a bytestring
            renderer = self.renderer_classes[format]()
            ret = renderer.render(data)

            # Determine the content-type header from the renderer. Only add a
            # charset parameter when the renderer declares one; otherwise the
            # header would read e.g. "application/json; charset=None".
            content_type = renderer.media_type
            if renderer.charset:
                content_type = "{0}; charset={1}".format(
                    content_type, renderer.charset
                )

            # Coerce text to bytes if required.
            if isinstance(ret, six.text_type):
                ret = ret.encode(renderer.charset)

        return ret, content_type

    def get(self, path, data=None, **extra):
        """
        Build a GET request. `data` is urlencoded into the query string; when
        no `data` is given, a query string embedded in `path` is used instead.
        """
        r = {
            'QUERY_STRING': urlencode(data or {}, doseq=True),
        }
        if not data and '?' in path:
            # Fix to support old behavior where you have the arguments in the
            # url. See #1461. Split on the first '?' only: further '?'
            # characters are legal inside the query string itself.
            query_string = force_bytes(path.split('?', 1)[1])
            if six.PY3:
                query_string = query_string.decode('iso-8859-1')
            r['QUERY_STRING'] = query_string
        r.update(extra)
        return self.generic('GET', path, **r)

    def post(self, path, data=None, format=None, content_type=None, **extra):
        """Build a POST request with `data` encoded via `_encode_data()`."""
        data, content_type = self._encode_data(data, format, content_type)
        return self.generic('POST', path, data, content_type, **extra)

    def put(self, path, data=None, format=None, content_type=None, **extra):
        """Build a PUT request with `data` encoded via `_encode_data()`."""
        data, content_type = self._encode_data(data, format, content_type)
        return self.generic('PUT', path, data, content_type, **extra)

    def patch(self, path, data=None, format=None, content_type=None, **extra):
        """Build a PATCH request with `data` encoded via `_encode_data()`."""
        data, content_type = self._encode_data(data, format, content_type)
        return self.generic('PATCH', path, data, content_type, **extra)

    def delete(self, path, data=None, format=None, content_type=None, **extra):
        """Build a DELETE request with `data` encoded via `_encode_data()`."""
        data, content_type = self._encode_data(data, format, content_type)
        return self.generic('DELETE', path, data, content_type, **extra)

    def options(self, path, data=None, format=None, content_type=None, **extra):
        """Build an OPTIONS request with `data` encoded via `_encode_data()`."""
        data, content_type = self._encode_data(data, format, content_type)
        return self.generic('OPTIONS', path, data, content_type, **extra)

    def request(self, **kwargs):
        """
        Build the request via the parent class and flag whether CSRF checks
        should be skipped (`_dont_enforce_csrf_checks`), based on the
        factory's `enforce_csrf_checks` setting.
        """
        request = super(APIRequestFactory, self).request(**kwargs)
        # Presumably read by Django's CSRF handling -- confirm.
        request._dont_enforce_csrf_checks = not self.enforce_csrf_checks
        return request
|
|
|
|
|
|
class ForceAuthClientHandler(ClientHandler):
    """
    Django test `ClientHandler` that applies forced authentication (a user
    and/or token) to every request passing through it.
    """

    def __init__(self, *args, **kwargs):
        # Start with no forced credentials; `APIClient` assigns these later.
        self._force_user = self._force_token = None
        super(ForceAuthClientHandler, self).__init__(*args, **kwargs)

    def get_response(self, request):
        # This is the simplest hook that receives the request object, so the
        # forced credentials are attached here before normal handling runs.
        user, token = self._force_user, self._force_token
        force_authenticate(request, user, token)
        return super(ForceAuthClientHandler, self).get_response(request)
|
|
|
|
|
|
class APIClient(APIRequestFactory, DjangoClient):
    """
    Django test `Client` that encodes bodies like `APIRequestFactory`, and adds
    persistent per-request credentials plus forced authentication.
    """

    def __init__(self, enforce_csrf_checks=False, **defaults):
        super(APIClient, self).__init__(**defaults)
        # Replace the default handler with one that honours force_authenticate().
        self.handler = ForceAuthClientHandler(enforce_csrf_checks)
        self._credentials = {}

    def credentials(self, **kwargs):
        """
        Set the headers added to every outgoing request, replacing any that
        were set previously.
        """
        self._credentials = kwargs

    def force_authenticate(self, user=None, token=None):
        """
        Forcibly authenticate outgoing requests as the given user and/or token.

        When `user` is None the client is also logged out.
        """
        handler = self.handler
        handler._force_user = user
        handler._force_token = token
        if user is None:
            # Also clear any possible session info if required
            self.logout()

    def request(self, **kwargs):
        # Merge the stored credentials into every request; they override any
        # keys of the same name already present in `kwargs`.
        kwargs.update(self._credentials)
        return super(APIClient, self).request(**kwargs)

    def _maybe_follow(self, response, follow, extra):
        """Return `response`, following its redirects first when `follow` is set."""
        if not follow:
            return response
        return self._handle_redirects(response, **extra)

    def get(self, path, data=None, follow=False, **extra):
        response = super(APIClient, self).get(path, data=data, **extra)
        return self._maybe_follow(response, follow, extra)

    def post(self, path, data=None, format=None, content_type=None,
             follow=False, **extra):
        response = super(APIClient, self).post(
            path, data=data, format=format, content_type=content_type, **extra)
        return self._maybe_follow(response, follow, extra)

    def put(self, path, data=None, format=None, content_type=None,
            follow=False, **extra):
        response = super(APIClient, self).put(
            path, data=data, format=format, content_type=content_type, **extra)
        return self._maybe_follow(response, follow, extra)

    def patch(self, path, data=None, format=None, content_type=None,
              follow=False, **extra):
        response = super(APIClient, self).patch(
            path, data=data, format=format, content_type=content_type, **extra)
        return self._maybe_follow(response, follow, extra)

    def delete(self, path, data=None, format=None, content_type=None,
               follow=False, **extra):
        response = super(APIClient, self).delete(
            path, data=data, format=format, content_type=content_type, **extra)
        return self._maybe_follow(response, follow, extra)

    def options(self, path, data=None, format=None, content_type=None,
                follow=False, **extra):
        response = super(APIClient, self).options(
            path, data=data, format=format, content_type=content_type, **extra)
        return self._maybe_follow(response, follow, extra)

    def logout(self):
        """
        Drop stored credentials and forced authentication, then call Django's
        `Client.logout()` when the client has a truthy `session`.
        """
        self._credentials = {}

        # Also clear any `force_authenticate`
        handler = self.handler
        handler._force_user = handler._force_token = None

        if self.session:
            super(APIClient, self).logout()
|
|
|
|
|
|
class APITransactionTestCase(testcases.TransactionTestCase):
    """Django `TransactionTestCase` that uses `APIClient` as its `client_class`."""

    client_class = APIClient
|
|
|
|
|
|
class APITestCase(testcases.TestCase):
    """Django `TestCase` that uses `APIClient` as its `client_class`."""

    client_class = APIClient
|
|
|
|
|
|
class APISimpleTestCase(testcases.SimpleTestCase):
    """Django `SimpleTestCase` that uses `APIClient` as its `client_class`."""

    client_class = APIClient
|
|
|
|
|
|
class APILiveServerTestCase(testcases.LiveServerTestCase):
    """Django `LiveServerTestCase` that uses `APIClient` as its `client_class`."""

    client_class = APIClient
|