This commit is contained in:
Tom Christie 2016-09-28 09:55:41 +00:00 committed by GitHub
commit 490729fdee
3 changed files with 359 additions and 0 deletions

View File

@ -178,6 +178,13 @@ except (ImportError, SyntaxError):
uritemplate = None uritemplate = None
# requests is optional
try:
import requests
except ImportError:
requests = None
# Django-guardian is optional. Import only if guardian is in INSTALLED_APPS # Django-guardian is optional. Import only if guardian is in INSTALLED_APPS
# Fixes (#1712). We keep the try/except for the test suite. # Fixes (#1712). We keep the try/except for the test suite.
guardian = None guardian = None

View File

@ -4,7 +4,10 @@
# to make it harder for the user to import the wrong thing without realizing. # to make it harder for the user to import the wrong thing without realizing.
from __future__ import unicode_literals from __future__ import unicode_literals
import io
from django.conf import settings from django.conf import settings
from django.core.handlers.wsgi import WSGIHandler
from django.test import testcases from django.test import testcases
from django.test.client import Client as DjangoClient from django.test.client import Client as DjangoClient
from django.test.client import RequestFactory as DjangoRequestFactory from django.test.client import RequestFactory as DjangoRequestFactory
@ -13,6 +16,7 @@ from django.utils import six
from django.utils.encoding import force_bytes from django.utils.encoding import force_bytes
from django.utils.http import urlencode from django.utils.http import urlencode
from rest_framework.compat import requests
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
@ -21,6 +25,107 @@ def force_authenticate(request, user=None, token=None):
request._force_auth_token = token request._force_auth_token = token
if requests is not None:
class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict):
def get_all(self, key, default):
return self.getheaders(key)
class MockOriginalResponse(object):
def __init__(self, headers):
self.msg = HeaderDict(headers)
self.closed = False
def isclosed(self):
return self.closed
def close(self):
self.closed = True
class DjangoTestAdapter(requests.adapters.HTTPAdapter):
"""
A transport adapter for `requests`, that makes requests via the
Django WSGI app, rather than making actual HTTP requests over the network.
"""
def __init__(self):
self.app = WSGIHandler()
self.factory = DjangoRequestFactory()
def get_environ(self, request):
"""
Given a `requests.PreparedRequest` instance, return a WSGI environ dict.
"""
method = request.method
url = request.url
kwargs = {}
# Set request content, if any exists.
if request.body is not None:
kwargs['data'] = request.body
if 'content-type' in request.headers:
kwargs['content_type'] = request.headers['content-type']
# Set request headers.
for key, value in request.headers.items():
key = key.upper()
if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'):
continue
kwargs['HTTP_%s' % key.replace('-', '_')] = value
return self.factory.generic(method, url, **kwargs).environ
def send(self, request, *args, **kwargs):
"""
Make an outgoing request to the Django WSGI application.
"""
raw_kwargs = {}
def start_response(wsgi_status, wsgi_headers):
status, _, reason = wsgi_status.partition(' ')
raw_kwargs['status'] = int(status)
raw_kwargs['reason'] = reason
raw_kwargs['headers'] = wsgi_headers
raw_kwargs['version'] = 11
raw_kwargs['preload_content'] = False
raw_kwargs['original_response'] = MockOriginalResponse(wsgi_headers)
# Make the outgoing request via WSGI.
environ = self.get_environ(request)
wsgi_response = self.app(environ, start_response)
# Build the underlying urllib3.HTTPResponse
raw_kwargs['body'] = io.BytesIO(b''.join(wsgi_response))
raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs)
# Build the requests.Response
return self.build_response(request, raw)
def close(self):
pass
class DjangoTestSession(requests.Session):
def __init__(self, *args, **kwargs):
super(DjangoTestSession, self).__init__(*args, **kwargs)
adapter = DjangoTestAdapter()
hostnames = list(settings.ALLOWED_HOSTS) + ['testserver']
for hostname in hostnames:
if hostname == '*':
hostname = ''
self.mount('http://%s' % hostname, adapter)
self.mount('https://%s' % hostname, adapter)
def request(self, method, url, *args, **kwargs):
if ':' not in url:
url = 'http://testserver/' + url.lstrip('/')
return super(DjangoTestSession, self).request(method, url, *args, **kwargs)
def get_requests_client():
assert requests is not None, 'requests must be installed'
return DjangoTestSession()
class APIRequestFactory(DjangoRequestFactory): class APIRequestFactory(DjangoRequestFactory):
renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES
default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT

View File

@ -0,0 +1,247 @@
from __future__ import unicode_literals
import unittest
from django.conf.urls import url
from django.contrib.auth import authenticate, login
from django.contrib.auth.models import User
from django.shortcuts import redirect
from django.test import override_settings
from django.utils.decorators import method_decorator
from django.views.decorators.csrf import csrf_protect, ensure_csrf_cookie
from rest_framework.compat import is_authenticated, requests
from rest_framework.response import Response
from rest_framework.test import APITestCase, get_requests_client
from rest_framework.views import APIView
class Root(APIView):
def get(self, request):
return Response({
'method': request.method,
'query_params': request.query_params,
})
def post(self, request):
files = {
key: (value.name, value.read())
for key, value in request.FILES.items()
}
post = request.POST
json = None
if request.META.get('CONTENT_TYPE') == 'application/json':
json = request.data
return Response({
'method': request.method,
'query_params': request.query_params,
'POST': post,
'FILES': files,
'JSON': json
})
class HeadersView(APIView):
def get(self, request):
headers = {
key[5:].replace('_', '-'): value
for key, value in request.META.items()
if key.startswith('HTTP_')
}
return Response({
'method': request.method,
'headers': headers
})
class SessionView(APIView):
def get(self, request):
return Response({
key: value for key, value in request.session.items()
})
def post(self, request):
for key, value in request.data.items():
request.session[key] = value
return Response({
key: value for key, value in request.session.items()
})
class AuthView(APIView):
@method_decorator(ensure_csrf_cookie)
def get(self, request):
if is_authenticated(request.user):
username = request.user.username
else:
username = None
return Response({
'username': username
})
@method_decorator(csrf_protect)
def post(self, request):
username = request.data['username']
password = request.data['password']
user = authenticate(username=username, password=password)
if user is None:
return Response({'error': 'incorrect credentials'})
login(request, user)
return redirect('/auth/')
urlpatterns = [
url(r'^$', Root.as_view()),
url(r'^headers/$', HeadersView.as_view()),
url(r'^session/$', SessionView.as_view()),
url(r'^auth/$', AuthView.as_view()),
]
@unittest.skipUnless(requests, 'requests not installed')
@override_settings(ROOT_URLCONF='tests.test_requests_client')
class RequestsClientTests(APITestCase):
def test_get_request(self):
client = get_requests_client()
response = client.get('/')
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
'method': 'GET',
'query_params': {}
}
assert response.json() == expected
def test_get_request_query_params_in_url(self):
client = get_requests_client()
response = client.get('/?key=value')
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
'method': 'GET',
'query_params': {'key': 'value'}
}
assert response.json() == expected
def test_get_request_query_params_by_kwarg(self):
client = get_requests_client()
response = client.get('/', params={'key': 'value'})
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
'method': 'GET',
'query_params': {'key': 'value'}
}
assert response.json() == expected
def test_get_with_headers(self):
client = get_requests_client()
response = client.get('/headers/', headers={'User-Agent': 'example'})
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
headers = response.json()['headers']
assert headers['USER-AGENT'] == 'example'
def test_post_form_request(self):
client = get_requests_client()
response = client.post('/', data={'key': 'value'})
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
'method': 'POST',
'query_params': {},
'POST': {'key': 'value'},
'FILES': {},
'JSON': None
}
assert response.json() == expected
def test_post_json_request(self):
client = get_requests_client()
response = client.post('/', json={'key': 'value'})
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
'method': 'POST',
'query_params': {},
'POST': {},
'FILES': {},
'JSON': {'key': 'value'}
}
assert response.json() == expected
def test_post_multipart_request(self):
client = get_requests_client()
files = {
'file': ('report.csv', 'some,data,to,send\nanother,row,to,send\n')
}
response = client.post('/', files=files)
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
'method': 'POST',
'query_params': {},
'FILES': {'file': ['report.csv', 'some,data,to,send\nanother,row,to,send\n']},
'POST': {},
'JSON': None
}
assert response.json() == expected
def test_session(self):
client = get_requests_client()
response = client.get('/session/')
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {}
assert response.json() == expected
response = client.post('/session/', json={'example': 'abc'})
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {'example': 'abc'}
assert response.json() == expected
response = client.get('/session/')
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {'example': 'abc'}
assert response.json() == expected
def test_auth(self):
# Confirm session is not authenticated
client = get_requests_client()
response = client.get('/auth/')
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
'username': None
}
assert response.json() == expected
assert 'csrftoken' in response.cookies
csrftoken = response.cookies['csrftoken']
user = User.objects.create(username='tom')
user.set_password('password')
user.save()
# Perform a login
response = client.post('/auth/', json={
'username': 'tom',
'password': 'password'
}, headers={'X-CSRFToken': csrftoken})
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
'username': 'tom'
}
assert response.json() == expected
# Confirm session is authenticated
response = client.get('/auth/')
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
'username': 'tom'
}
assert response.json() == expected