mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-08-10 07:14:48 +03:00
Merge 0cc3f5008f
into 12576275c4
This commit is contained in:
commit
490729fdee
|
@ -178,6 +178,13 @@ except (ImportError, SyntaxError):
|
||||||
uritemplate = None
|
uritemplate = None
|
||||||
|
|
||||||
|
|
||||||
|
# requests is optional
|
||||||
|
try:
|
||||||
|
import requests
|
||||||
|
except ImportError:
|
||||||
|
requests = None
|
||||||
|
|
||||||
|
|
||||||
# Django-guardian is optional. Import only if guardian is in INSTALLED_APPS
|
# Django-guardian is optional. Import only if guardian is in INSTALLED_APPS
|
||||||
# Fixes (#1712). We keep the try/except for the test suite.
|
# Fixes (#1712). We keep the try/except for the test suite.
|
||||||
guardian = None
|
guardian = None
|
||||||
|
|
|
@ -4,7 +4,10 @@
|
||||||
# to make it harder for the user to import the wrong thing without realizing.
|
# to make it harder for the user to import the wrong thing without realizing.
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
import io
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
|
from django.core.handlers.wsgi import WSGIHandler
|
||||||
from django.test import testcases
|
from django.test import testcases
|
||||||
from django.test.client import Client as DjangoClient
|
from django.test.client import Client as DjangoClient
|
||||||
from django.test.client import RequestFactory as DjangoRequestFactory
|
from django.test.client import RequestFactory as DjangoRequestFactory
|
||||||
|
@ -13,6 +16,7 @@ from django.utils import six
|
||||||
from django.utils.encoding import force_bytes
|
from django.utils.encoding import force_bytes
|
||||||
from django.utils.http import urlencode
|
from django.utils.http import urlencode
|
||||||
|
|
||||||
|
from rest_framework.compat import requests
|
||||||
from rest_framework.settings import api_settings
|
from rest_framework.settings import api_settings
|
||||||
|
|
||||||
|
|
||||||
|
@ -21,6 +25,107 @@ def force_authenticate(request, user=None, token=None):
|
||||||
request._force_auth_token = token
|
request._force_auth_token = token
|
||||||
|
|
||||||
|
|
||||||
|
if requests is not None:
|
||||||
|
class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict):
|
||||||
|
def get_all(self, key, default):
|
||||||
|
return self.getheaders(key)
|
||||||
|
|
||||||
|
class MockOriginalResponse(object):
|
||||||
|
def __init__(self, headers):
|
||||||
|
self.msg = HeaderDict(headers)
|
||||||
|
self.closed = False
|
||||||
|
|
||||||
|
def isclosed(self):
|
||||||
|
return self.closed
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.closed = True
|
||||||
|
|
||||||
|
class DjangoTestAdapter(requests.adapters.HTTPAdapter):
|
||||||
|
"""
|
||||||
|
A transport adapter for `requests`, that makes requests via the
|
||||||
|
Django WSGI app, rather than making actual HTTP requests over the network.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
self.app = WSGIHandler()
|
||||||
|
self.factory = DjangoRequestFactory()
|
||||||
|
|
||||||
|
def get_environ(self, request):
|
||||||
|
"""
|
||||||
|
Given a `requests.PreparedRequest` instance, return a WSGI environ dict.
|
||||||
|
"""
|
||||||
|
method = request.method
|
||||||
|
url = request.url
|
||||||
|
kwargs = {}
|
||||||
|
|
||||||
|
# Set request content, if any exists.
|
||||||
|
if request.body is not None:
|
||||||
|
kwargs['data'] = request.body
|
||||||
|
if 'content-type' in request.headers:
|
||||||
|
kwargs['content_type'] = request.headers['content-type']
|
||||||
|
|
||||||
|
# Set request headers.
|
||||||
|
for key, value in request.headers.items():
|
||||||
|
key = key.upper()
|
||||||
|
if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'):
|
||||||
|
continue
|
||||||
|
kwargs['HTTP_%s' % key.replace('-', '_')] = value
|
||||||
|
|
||||||
|
return self.factory.generic(method, url, **kwargs).environ
|
||||||
|
|
||||||
|
def send(self, request, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Make an outgoing request to the Django WSGI application.
|
||||||
|
"""
|
||||||
|
raw_kwargs = {}
|
||||||
|
|
||||||
|
def start_response(wsgi_status, wsgi_headers):
|
||||||
|
status, _, reason = wsgi_status.partition(' ')
|
||||||
|
raw_kwargs['status'] = int(status)
|
||||||
|
raw_kwargs['reason'] = reason
|
||||||
|
raw_kwargs['headers'] = wsgi_headers
|
||||||
|
raw_kwargs['version'] = 11
|
||||||
|
raw_kwargs['preload_content'] = False
|
||||||
|
raw_kwargs['original_response'] = MockOriginalResponse(wsgi_headers)
|
||||||
|
|
||||||
|
# Make the outgoing request via WSGI.
|
||||||
|
environ = self.get_environ(request)
|
||||||
|
wsgi_response = self.app(environ, start_response)
|
||||||
|
|
||||||
|
# Build the underlying urllib3.HTTPResponse
|
||||||
|
raw_kwargs['body'] = io.BytesIO(b''.join(wsgi_response))
|
||||||
|
raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs)
|
||||||
|
|
||||||
|
# Build the requests.Response
|
||||||
|
return self.build_response(request, raw)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class DjangoTestSession(requests.Session):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(DjangoTestSession, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
adapter = DjangoTestAdapter()
|
||||||
|
hostnames = list(settings.ALLOWED_HOSTS) + ['testserver']
|
||||||
|
|
||||||
|
for hostname in hostnames:
|
||||||
|
if hostname == '*':
|
||||||
|
hostname = ''
|
||||||
|
self.mount('http://%s' % hostname, adapter)
|
||||||
|
self.mount('https://%s' % hostname, adapter)
|
||||||
|
|
||||||
|
def request(self, method, url, *args, **kwargs):
|
||||||
|
if ':' not in url:
|
||||||
|
url = 'http://testserver/' + url.lstrip('/')
|
||||||
|
return super(DjangoTestSession, self).request(method, url, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def get_requests_client():
|
||||||
|
assert requests is not None, 'requests must be installed'
|
||||||
|
return DjangoTestSession()
|
||||||
|
|
||||||
|
|
||||||
class APIRequestFactory(DjangoRequestFactory):
|
class APIRequestFactory(DjangoRequestFactory):
|
||||||
renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES
|
renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES
|
||||||
default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT
|
default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT
|
||||||
|
|
247
tests/test_requests_client.py
Normal file
247
tests/test_requests_client.py
Normal file
|
@ -0,0 +1,247 @@
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from django.conf.urls import url
|
||||||
|
from django.contrib.auth import authenticate, login
|
||||||
|
from django.contrib.auth.models import User
|
||||||
|
from django.shortcuts import redirect
|
||||||
|
from django.test import override_settings
|
||||||
|
from django.utils.decorators import method_decorator
|
||||||
|
from django.views.decorators.csrf import csrf_protect, ensure_csrf_cookie
|
||||||
|
|
||||||
|
from rest_framework.compat import is_authenticated, requests
|
||||||
|
from rest_framework.response import Response
|
||||||
|
from rest_framework.test import APITestCase, get_requests_client
|
||||||
|
from rest_framework.views import APIView
|
||||||
|
|
||||||
|
|
||||||
|
class Root(APIView):
|
||||||
|
def get(self, request):
|
||||||
|
return Response({
|
||||||
|
'method': request.method,
|
||||||
|
'query_params': request.query_params,
|
||||||
|
})
|
||||||
|
|
||||||
|
def post(self, request):
|
||||||
|
files = {
|
||||||
|
key: (value.name, value.read())
|
||||||
|
for key, value in request.FILES.items()
|
||||||
|
}
|
||||||
|
post = request.POST
|
||||||
|
json = None
|
||||||
|
if request.META.get('CONTENT_TYPE') == 'application/json':
|
||||||
|
json = request.data
|
||||||
|
|
||||||
|
return Response({
|
||||||
|
'method': request.method,
|
||||||
|
'query_params': request.query_params,
|
||||||
|
'POST': post,
|
||||||
|
'FILES': files,
|
||||||
|
'JSON': json
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
class HeadersView(APIView):
|
||||||
|
def get(self, request):
|
||||||
|
headers = {
|
||||||
|
key[5:].replace('_', '-'): value
|
||||||
|
for key, value in request.META.items()
|
||||||
|
if key.startswith('HTTP_')
|
||||||
|
}
|
||||||
|
return Response({
|
||||||
|
'method': request.method,
|
||||||
|
'headers': headers
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
class SessionView(APIView):
|
||||||
|
def get(self, request):
|
||||||
|
return Response({
|
||||||
|
key: value for key, value in request.session.items()
|
||||||
|
})
|
||||||
|
|
||||||
|
def post(self, request):
|
||||||
|
for key, value in request.data.items():
|
||||||
|
request.session[key] = value
|
||||||
|
return Response({
|
||||||
|
key: value for key, value in request.session.items()
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
class AuthView(APIView):
|
||||||
|
@method_decorator(ensure_csrf_cookie)
|
||||||
|
def get(self, request):
|
||||||
|
if is_authenticated(request.user):
|
||||||
|
username = request.user.username
|
||||||
|
else:
|
||||||
|
username = None
|
||||||
|
return Response({
|
||||||
|
'username': username
|
||||||
|
})
|
||||||
|
|
||||||
|
@method_decorator(csrf_protect)
|
||||||
|
def post(self, request):
|
||||||
|
username = request.data['username']
|
||||||
|
password = request.data['password']
|
||||||
|
user = authenticate(username=username, password=password)
|
||||||
|
if user is None:
|
||||||
|
return Response({'error': 'incorrect credentials'})
|
||||||
|
login(request, user)
|
||||||
|
return redirect('/auth/')
|
||||||
|
|
||||||
|
|
||||||
|
urlpatterns = [
|
||||||
|
url(r'^$', Root.as_view()),
|
||||||
|
url(r'^headers/$', HeadersView.as_view()),
|
||||||
|
url(r'^session/$', SessionView.as_view()),
|
||||||
|
url(r'^auth/$', AuthView.as_view()),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipUnless(requests, 'requests not installed')
|
||||||
|
@override_settings(ROOT_URLCONF='tests.test_requests_client')
|
||||||
|
class RequestsClientTests(APITestCase):
|
||||||
|
def test_get_request(self):
|
||||||
|
client = get_requests_client()
|
||||||
|
response = client.get('/')
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.headers['Content-Type'] == 'application/json'
|
||||||
|
expected = {
|
||||||
|
'method': 'GET',
|
||||||
|
'query_params': {}
|
||||||
|
}
|
||||||
|
assert response.json() == expected
|
||||||
|
|
||||||
|
def test_get_request_query_params_in_url(self):
|
||||||
|
client = get_requests_client()
|
||||||
|
response = client.get('/?key=value')
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.headers['Content-Type'] == 'application/json'
|
||||||
|
expected = {
|
||||||
|
'method': 'GET',
|
||||||
|
'query_params': {'key': 'value'}
|
||||||
|
}
|
||||||
|
assert response.json() == expected
|
||||||
|
|
||||||
|
def test_get_request_query_params_by_kwarg(self):
|
||||||
|
client = get_requests_client()
|
||||||
|
response = client.get('/', params={'key': 'value'})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.headers['Content-Type'] == 'application/json'
|
||||||
|
expected = {
|
||||||
|
'method': 'GET',
|
||||||
|
'query_params': {'key': 'value'}
|
||||||
|
}
|
||||||
|
assert response.json() == expected
|
||||||
|
|
||||||
|
def test_get_with_headers(self):
|
||||||
|
client = get_requests_client()
|
||||||
|
response = client.get('/headers/', headers={'User-Agent': 'example'})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.headers['Content-Type'] == 'application/json'
|
||||||
|
headers = response.json()['headers']
|
||||||
|
assert headers['USER-AGENT'] == 'example'
|
||||||
|
|
||||||
|
def test_post_form_request(self):
|
||||||
|
client = get_requests_client()
|
||||||
|
response = client.post('/', data={'key': 'value'})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.headers['Content-Type'] == 'application/json'
|
||||||
|
expected = {
|
||||||
|
'method': 'POST',
|
||||||
|
'query_params': {},
|
||||||
|
'POST': {'key': 'value'},
|
||||||
|
'FILES': {},
|
||||||
|
'JSON': None
|
||||||
|
}
|
||||||
|
assert response.json() == expected
|
||||||
|
|
||||||
|
def test_post_json_request(self):
|
||||||
|
client = get_requests_client()
|
||||||
|
response = client.post('/', json={'key': 'value'})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.headers['Content-Type'] == 'application/json'
|
||||||
|
expected = {
|
||||||
|
'method': 'POST',
|
||||||
|
'query_params': {},
|
||||||
|
'POST': {},
|
||||||
|
'FILES': {},
|
||||||
|
'JSON': {'key': 'value'}
|
||||||
|
}
|
||||||
|
assert response.json() == expected
|
||||||
|
|
||||||
|
def test_post_multipart_request(self):
|
||||||
|
client = get_requests_client()
|
||||||
|
files = {
|
||||||
|
'file': ('report.csv', 'some,data,to,send\nanother,row,to,send\n')
|
||||||
|
}
|
||||||
|
response = client.post('/', files=files)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.headers['Content-Type'] == 'application/json'
|
||||||
|
expected = {
|
||||||
|
'method': 'POST',
|
||||||
|
'query_params': {},
|
||||||
|
'FILES': {'file': ['report.csv', 'some,data,to,send\nanother,row,to,send\n']},
|
||||||
|
'POST': {},
|
||||||
|
'JSON': None
|
||||||
|
}
|
||||||
|
assert response.json() == expected
|
||||||
|
|
||||||
|
def test_session(self):
|
||||||
|
client = get_requests_client()
|
||||||
|
response = client.get('/session/')
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.headers['Content-Type'] == 'application/json'
|
||||||
|
expected = {}
|
||||||
|
assert response.json() == expected
|
||||||
|
|
||||||
|
response = client.post('/session/', json={'example': 'abc'})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.headers['Content-Type'] == 'application/json'
|
||||||
|
expected = {'example': 'abc'}
|
||||||
|
assert response.json() == expected
|
||||||
|
|
||||||
|
response = client.get('/session/')
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.headers['Content-Type'] == 'application/json'
|
||||||
|
expected = {'example': 'abc'}
|
||||||
|
assert response.json() == expected
|
||||||
|
|
||||||
|
def test_auth(self):
|
||||||
|
# Confirm session is not authenticated
|
||||||
|
client = get_requests_client()
|
||||||
|
response = client.get('/auth/')
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.headers['Content-Type'] == 'application/json'
|
||||||
|
expected = {
|
||||||
|
'username': None
|
||||||
|
}
|
||||||
|
assert response.json() == expected
|
||||||
|
assert 'csrftoken' in response.cookies
|
||||||
|
csrftoken = response.cookies['csrftoken']
|
||||||
|
|
||||||
|
user = User.objects.create(username='tom')
|
||||||
|
user.set_password('password')
|
||||||
|
user.save()
|
||||||
|
|
||||||
|
# Perform a login
|
||||||
|
response = client.post('/auth/', json={
|
||||||
|
'username': 'tom',
|
||||||
|
'password': 'password'
|
||||||
|
}, headers={'X-CSRFToken': csrftoken})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.headers['Content-Type'] == 'application/json'
|
||||||
|
expected = {
|
||||||
|
'username': 'tom'
|
||||||
|
}
|
||||||
|
assert response.json() == expected
|
||||||
|
|
||||||
|
# Confirm session is authenticated
|
||||||
|
response = client.get('/auth/')
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.headers['Content-Type'] == 'application/json'
|
||||||
|
expected = {
|
||||||
|
'username': 'tom'
|
||||||
|
}
|
||||||
|
assert response.json() == expected
|
Loading…
Reference in New Issue
Block a user