mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-08-07 13:54:47 +03:00
Merge 0cc3f5008f
into 12576275c4
This commit is contained in:
commit
490729fdee
|
@ -178,6 +178,13 @@ except (ImportError, SyntaxError):
|
|||
uritemplate = None
|
||||
|
||||
|
||||
# requests is optional
|
||||
try:
|
||||
import requests
|
||||
except ImportError:
|
||||
requests = None
|
||||
|
||||
|
||||
# Django-guardian is optional. Import only if guardian is in INSTALLED_APPS
|
||||
# Fixes (#1712). We keep the try/except for the test suite.
|
||||
guardian = None
|
||||
|
|
|
@ -4,7 +4,10 @@
|
|||
# to make it harder for the user to import the wrong thing without realizing.
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import io
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.handlers.wsgi import WSGIHandler
|
||||
from django.test import testcases
|
||||
from django.test.client import Client as DjangoClient
|
||||
from django.test.client import RequestFactory as DjangoRequestFactory
|
||||
|
@ -13,6 +16,7 @@ from django.utils import six
|
|||
from django.utils.encoding import force_bytes
|
||||
from django.utils.http import urlencode
|
||||
|
||||
from rest_framework.compat import requests
|
||||
from rest_framework.settings import api_settings
|
||||
|
||||
|
||||
|
@ -21,6 +25,107 @@ def force_authenticate(request, user=None, token=None):
|
|||
request._force_auth_token = token
|
||||
|
||||
|
||||
if requests is not None:
|
||||
class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict):
|
||||
def get_all(self, key, default):
|
||||
return self.getheaders(key)
|
||||
|
||||
class MockOriginalResponse(object):
|
||||
def __init__(self, headers):
|
||||
self.msg = HeaderDict(headers)
|
||||
self.closed = False
|
||||
|
||||
def isclosed(self):
|
||||
return self.closed
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
class DjangoTestAdapter(requests.adapters.HTTPAdapter):
|
||||
"""
|
||||
A transport adapter for `requests`, that makes requests via the
|
||||
Django WSGI app, rather than making actual HTTP requests over the network.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.app = WSGIHandler()
|
||||
self.factory = DjangoRequestFactory()
|
||||
|
||||
def get_environ(self, request):
|
||||
"""
|
||||
Given a `requests.PreparedRequest` instance, return a WSGI environ dict.
|
||||
"""
|
||||
method = request.method
|
||||
url = request.url
|
||||
kwargs = {}
|
||||
|
||||
# Set request content, if any exists.
|
||||
if request.body is not None:
|
||||
kwargs['data'] = request.body
|
||||
if 'content-type' in request.headers:
|
||||
kwargs['content_type'] = request.headers['content-type']
|
||||
|
||||
# Set request headers.
|
||||
for key, value in request.headers.items():
|
||||
key = key.upper()
|
||||
if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'):
|
||||
continue
|
||||
kwargs['HTTP_%s' % key.replace('-', '_')] = value
|
||||
|
||||
return self.factory.generic(method, url, **kwargs).environ
|
||||
|
||||
def send(self, request, *args, **kwargs):
|
||||
"""
|
||||
Make an outgoing request to the Django WSGI application.
|
||||
"""
|
||||
raw_kwargs = {}
|
||||
|
||||
def start_response(wsgi_status, wsgi_headers):
|
||||
status, _, reason = wsgi_status.partition(' ')
|
||||
raw_kwargs['status'] = int(status)
|
||||
raw_kwargs['reason'] = reason
|
||||
raw_kwargs['headers'] = wsgi_headers
|
||||
raw_kwargs['version'] = 11
|
||||
raw_kwargs['preload_content'] = False
|
||||
raw_kwargs['original_response'] = MockOriginalResponse(wsgi_headers)
|
||||
|
||||
# Make the outgoing request via WSGI.
|
||||
environ = self.get_environ(request)
|
||||
wsgi_response = self.app(environ, start_response)
|
||||
|
||||
# Build the underlying urllib3.HTTPResponse
|
||||
raw_kwargs['body'] = io.BytesIO(b''.join(wsgi_response))
|
||||
raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs)
|
||||
|
||||
# Build the requests.Response
|
||||
return self.build_response(request, raw)
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
class DjangoTestSession(requests.Session):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(DjangoTestSession, self).__init__(*args, **kwargs)
|
||||
|
||||
adapter = DjangoTestAdapter()
|
||||
hostnames = list(settings.ALLOWED_HOSTS) + ['testserver']
|
||||
|
||||
for hostname in hostnames:
|
||||
if hostname == '*':
|
||||
hostname = ''
|
||||
self.mount('http://%s' % hostname, adapter)
|
||||
self.mount('https://%s' % hostname, adapter)
|
||||
|
||||
def request(self, method, url, *args, **kwargs):
|
||||
if ':' not in url:
|
||||
url = 'http://testserver/' + url.lstrip('/')
|
||||
return super(DjangoTestSession, self).request(method, url, *args, **kwargs)
|
||||
|
||||
|
||||
def get_requests_client():
|
||||
assert requests is not None, 'requests must be installed'
|
||||
return DjangoTestSession()
|
||||
|
||||
|
||||
class APIRequestFactory(DjangoRequestFactory):
|
||||
renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES
|
||||
default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT
|
||||
|
|
247
tests/test_requests_client.py
Normal file
247
tests/test_requests_client.py
Normal file
|
@ -0,0 +1,247 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import unittest
|
||||
|
||||
from django.conf.urls import url
|
||||
from django.contrib.auth import authenticate, login
|
||||
from django.contrib.auth.models import User
|
||||
from django.shortcuts import redirect
|
||||
from django.test import override_settings
|
||||
from django.utils.decorators import method_decorator
|
||||
from django.views.decorators.csrf import csrf_protect, ensure_csrf_cookie
|
||||
|
||||
from rest_framework.compat import is_authenticated, requests
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.test import APITestCase, get_requests_client
|
||||
from rest_framework.views import APIView
|
||||
|
||||
|
||||
class Root(APIView):
|
||||
def get(self, request):
|
||||
return Response({
|
||||
'method': request.method,
|
||||
'query_params': request.query_params,
|
||||
})
|
||||
|
||||
def post(self, request):
|
||||
files = {
|
||||
key: (value.name, value.read())
|
||||
for key, value in request.FILES.items()
|
||||
}
|
||||
post = request.POST
|
||||
json = None
|
||||
if request.META.get('CONTENT_TYPE') == 'application/json':
|
||||
json = request.data
|
||||
|
||||
return Response({
|
||||
'method': request.method,
|
||||
'query_params': request.query_params,
|
||||
'POST': post,
|
||||
'FILES': files,
|
||||
'JSON': json
|
||||
})
|
||||
|
||||
|
||||
class HeadersView(APIView):
|
||||
def get(self, request):
|
||||
headers = {
|
||||
key[5:].replace('_', '-'): value
|
||||
for key, value in request.META.items()
|
||||
if key.startswith('HTTP_')
|
||||
}
|
||||
return Response({
|
||||
'method': request.method,
|
||||
'headers': headers
|
||||
})
|
||||
|
||||
|
||||
class SessionView(APIView):
|
||||
def get(self, request):
|
||||
return Response({
|
||||
key: value for key, value in request.session.items()
|
||||
})
|
||||
|
||||
def post(self, request):
|
||||
for key, value in request.data.items():
|
||||
request.session[key] = value
|
||||
return Response({
|
||||
key: value for key, value in request.session.items()
|
||||
})
|
||||
|
||||
|
||||
class AuthView(APIView):
|
||||
@method_decorator(ensure_csrf_cookie)
|
||||
def get(self, request):
|
||||
if is_authenticated(request.user):
|
||||
username = request.user.username
|
||||
else:
|
||||
username = None
|
||||
return Response({
|
||||
'username': username
|
||||
})
|
||||
|
||||
@method_decorator(csrf_protect)
|
||||
def post(self, request):
|
||||
username = request.data['username']
|
||||
password = request.data['password']
|
||||
user = authenticate(username=username, password=password)
|
||||
if user is None:
|
||||
return Response({'error': 'incorrect credentials'})
|
||||
login(request, user)
|
||||
return redirect('/auth/')
|
||||
|
||||
|
||||
urlpatterns = [
|
||||
url(r'^$', Root.as_view()),
|
||||
url(r'^headers/$', HeadersView.as_view()),
|
||||
url(r'^session/$', SessionView.as_view()),
|
||||
url(r'^auth/$', AuthView.as_view()),
|
||||
]
|
||||
|
||||
|
||||
@unittest.skipUnless(requests, 'requests not installed')
|
||||
@override_settings(ROOT_URLCONF='tests.test_requests_client')
|
||||
class RequestsClientTests(APITestCase):
|
||||
def test_get_request(self):
|
||||
client = get_requests_client()
|
||||
response = client.get('/')
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {
|
||||
'method': 'GET',
|
||||
'query_params': {}
|
||||
}
|
||||
assert response.json() == expected
|
||||
|
||||
def test_get_request_query_params_in_url(self):
|
||||
client = get_requests_client()
|
||||
response = client.get('/?key=value')
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {
|
||||
'method': 'GET',
|
||||
'query_params': {'key': 'value'}
|
||||
}
|
||||
assert response.json() == expected
|
||||
|
||||
def test_get_request_query_params_by_kwarg(self):
|
||||
client = get_requests_client()
|
||||
response = client.get('/', params={'key': 'value'})
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {
|
||||
'method': 'GET',
|
||||
'query_params': {'key': 'value'}
|
||||
}
|
||||
assert response.json() == expected
|
||||
|
||||
def test_get_with_headers(self):
|
||||
client = get_requests_client()
|
||||
response = client.get('/headers/', headers={'User-Agent': 'example'})
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
headers = response.json()['headers']
|
||||
assert headers['USER-AGENT'] == 'example'
|
||||
|
||||
def test_post_form_request(self):
|
||||
client = get_requests_client()
|
||||
response = client.post('/', data={'key': 'value'})
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {
|
||||
'method': 'POST',
|
||||
'query_params': {},
|
||||
'POST': {'key': 'value'},
|
||||
'FILES': {},
|
||||
'JSON': None
|
||||
}
|
||||
assert response.json() == expected
|
||||
|
||||
def test_post_json_request(self):
|
||||
client = get_requests_client()
|
||||
response = client.post('/', json={'key': 'value'})
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {
|
||||
'method': 'POST',
|
||||
'query_params': {},
|
||||
'POST': {},
|
||||
'FILES': {},
|
||||
'JSON': {'key': 'value'}
|
||||
}
|
||||
assert response.json() == expected
|
||||
|
||||
def test_post_multipart_request(self):
|
||||
client = get_requests_client()
|
||||
files = {
|
||||
'file': ('report.csv', 'some,data,to,send\nanother,row,to,send\n')
|
||||
}
|
||||
response = client.post('/', files=files)
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {
|
||||
'method': 'POST',
|
||||
'query_params': {},
|
||||
'FILES': {'file': ['report.csv', 'some,data,to,send\nanother,row,to,send\n']},
|
||||
'POST': {},
|
||||
'JSON': None
|
||||
}
|
||||
assert response.json() == expected
|
||||
|
||||
def test_session(self):
|
||||
client = get_requests_client()
|
||||
response = client.get('/session/')
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {}
|
||||
assert response.json() == expected
|
||||
|
||||
response = client.post('/session/', json={'example': 'abc'})
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {'example': 'abc'}
|
||||
assert response.json() == expected
|
||||
|
||||
response = client.get('/session/')
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {'example': 'abc'}
|
||||
assert response.json() == expected
|
||||
|
||||
def test_auth(self):
|
||||
# Confirm session is not authenticated
|
||||
client = get_requests_client()
|
||||
response = client.get('/auth/')
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {
|
||||
'username': None
|
||||
}
|
||||
assert response.json() == expected
|
||||
assert 'csrftoken' in response.cookies
|
||||
csrftoken = response.cookies['csrftoken']
|
||||
|
||||
user = User.objects.create(username='tom')
|
||||
user.set_password('password')
|
||||
user.save()
|
||||
|
||||
# Perform a login
|
||||
response = client.post('/auth/', json={
|
||||
'username': 'tom',
|
||||
'password': 'password'
|
||||
}, headers={'X-CSRFToken': csrftoken})
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {
|
||||
'username': 'tom'
|
||||
}
|
||||
assert response.json() == expected
|
||||
|
||||
# Confirm session is authenticated
|
||||
response = client.get('/auth/')
|
||||
assert response.status_code == 200
|
||||
assert response.headers['Content-Type'] == 'application/json'
|
||||
expected = {
|
||||
'username': 'tom'
|
||||
}
|
||||
assert response.json() == expected
|
Loading…
Reference in New Issue
Block a user