mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-08-07 13:54:47 +03:00
Added 'requests' test client
This commit is contained in:
parent
5abac93c01
commit
3d1fff3f26
|
@ -373,7 +373,7 @@ class Request(object):
|
||||||
if not _hasattr(self, '_data'):
|
if not _hasattr(self, '_data'):
|
||||||
self._load_data_and_files()
|
self._load_data_and_files()
|
||||||
if is_form_media_type(self.content_type):
|
if is_form_media_type(self.content_type):
|
||||||
return self.data
|
return self._data
|
||||||
return QueryDict('', encoding=self._request._encoding)
|
return QueryDict('', encoding=self._request._encoding)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
|
@ -4,7 +4,10 @@
|
||||||
# to make it harder for the user to import the wrong thing without realizing.
|
# to make it harder for the user to import the wrong thing without realizing.
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
import io
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
|
from django.core.handlers.wsgi import WSGIHandler
|
||||||
from django.test import testcases
|
from django.test import testcases
|
||||||
from django.test.client import Client as DjangoClient
|
from django.test.client import Client as DjangoClient
|
||||||
from django.test.client import RequestFactory as DjangoRequestFactory
|
from django.test.client import RequestFactory as DjangoRequestFactory
|
||||||
|
@ -13,6 +16,10 @@ from django.utils import six
|
||||||
from django.utils.encoding import force_bytes
|
from django.utils.encoding import force_bytes
|
||||||
from django.utils.http import urlencode
|
from django.utils.http import urlencode
|
||||||
from requests import Session
|
from requests import Session
|
||||||
|
from requests.adapters import BaseAdapter
|
||||||
|
from requests.models import Response
|
||||||
|
from requests.structures import CaseInsensitiveDict
|
||||||
|
from requests.utils import get_encoding_from_headers
|
||||||
|
|
||||||
from rest_framework.settings import api_settings
|
from rest_framework.settings import api_settings
|
||||||
|
|
||||||
|
@ -22,6 +29,83 @@ def force_authenticate(request, user=None, token=None):
|
||||||
request._force_auth_token = token
|
request._force_auth_token = token
|
||||||
|
|
||||||
|
|
||||||
|
class DjangoTestAdapter(BaseAdapter):
|
||||||
|
"""
|
||||||
|
A transport adaptor for `requests`, that makes requests via the
|
||||||
|
Django WSGI app, rather than making actual HTTP requests ovet the network.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
self.app = WSGIHandler()
|
||||||
|
self.factory = DjangoRequestFactory()
|
||||||
|
|
||||||
|
def get_environ(self, request):
|
||||||
|
"""
|
||||||
|
Given a `requests.PreparedRequest` instance, return a WSGI environ dict.
|
||||||
|
"""
|
||||||
|
method = request.method
|
||||||
|
url = request.url
|
||||||
|
kwargs = {}
|
||||||
|
|
||||||
|
# Set request content, if any exists.
|
||||||
|
if request.body is not None:
|
||||||
|
kwargs['data'] = request.body
|
||||||
|
if 'content-type' in request.headers:
|
||||||
|
kwargs['content_type'] = request.headers['content-type']
|
||||||
|
|
||||||
|
# Set request headers.
|
||||||
|
for key, value in request.headers.items():
|
||||||
|
key = key.upper()
|
||||||
|
if key in ('CONNECTION', 'CONTENT_LENGTH', 'CONTENT-TYPE'):
|
||||||
|
continue
|
||||||
|
kwargs['HTTP_%s' % key] = value
|
||||||
|
|
||||||
|
return self.factory.generic(method, url, **kwargs).environ
|
||||||
|
|
||||||
|
def send(self, request, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Make an outgoing request to the Django WSGI application.
|
||||||
|
"""
|
||||||
|
response = Response()
|
||||||
|
|
||||||
|
def start_response(status, headers):
|
||||||
|
status_code, _, reason_phrase = status.partition(' ')
|
||||||
|
response.status_code = int(status_code)
|
||||||
|
response.reason = reason_phrase
|
||||||
|
response.headers = CaseInsensitiveDict(headers)
|
||||||
|
response.encoding = get_encoding_from_headers(response.headers)
|
||||||
|
|
||||||
|
environ = self.get_environ(request)
|
||||||
|
raw_bytes = self.app(environ, start_response)
|
||||||
|
|
||||||
|
response.request = request
|
||||||
|
response.url = request.url
|
||||||
|
response.raw = io.BytesIO(b''.join(raw_bytes))
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DjangoTestSession(Session):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(DjangoTestSession, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
adapter = DjangoTestAdapter()
|
||||||
|
hostnames = list(settings.ALLOWED_HOSTS) + ['testserver']
|
||||||
|
|
||||||
|
for hostname in hostnames:
|
||||||
|
if hostname == '*':
|
||||||
|
hostname = ''
|
||||||
|
self.mount('http://%s' % hostname, adapter)
|
||||||
|
self.mount('https://%s' % hostname, adapter)
|
||||||
|
|
||||||
|
def request(self, method, url, *args, **kwargs):
|
||||||
|
if ':' not in url:
|
||||||
|
url = 'http://testserver/' + url.lstrip('/')
|
||||||
|
return super(DjangoTestSession, self).request(method, url, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class APIRequestFactory(DjangoRequestFactory):
|
class APIRequestFactory(DjangoRequestFactory):
|
||||||
renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES
|
renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES
|
||||||
default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT
|
default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT
|
||||||
|
@ -224,7 +308,7 @@ class APITestCase(testcases.TestCase):
|
||||||
|
|
||||||
def _pre_setup(self):
|
def _pre_setup(self):
|
||||||
super(APITestCase, self)._pre_setup()
|
super(APITestCase, self)._pre_setup()
|
||||||
self.requests = Session()
|
self.requests = DjangoTestSession()
|
||||||
|
|
||||||
|
|
||||||
class APISimpleTestCase(testcases.SimpleTestCase):
|
class APISimpleTestCase(testcases.SimpleTestCase):
|
||||||
|
|
|
@ -10,15 +10,128 @@ from rest_framework.views import APIView
|
||||||
|
|
||||||
class Root(APIView):
|
class Root(APIView):
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
return Response({'hello': 'world'})
|
return Response({
|
||||||
|
'method': request.method,
|
||||||
|
'query_params': request.query_params,
|
||||||
|
})
|
||||||
|
|
||||||
|
def post(self, request):
|
||||||
|
files = {
|
||||||
|
key: (value.name, value.read())
|
||||||
|
for key, value in request.FILES.items()
|
||||||
|
}
|
||||||
|
post = request.POST
|
||||||
|
json = None
|
||||||
|
if request.META.get('CONTENT_TYPE') == 'application/json':
|
||||||
|
json = request.data
|
||||||
|
|
||||||
|
return Response({
|
||||||
|
'method': request.method,
|
||||||
|
'query_params': request.query_params,
|
||||||
|
'POST': post,
|
||||||
|
'FILES': files,
|
||||||
|
'JSON': json
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
class Headers(APIView):
|
||||||
|
def get(self, request):
|
||||||
|
headers = {
|
||||||
|
key[5:]: value
|
||||||
|
for key, value in request.META.items()
|
||||||
|
if key.startswith('HTTP_')
|
||||||
|
}
|
||||||
|
return Response({
|
||||||
|
'method': request.method,
|
||||||
|
'headers': headers
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
urlpatterns = [
|
urlpatterns = [
|
||||||
url(r'^$', Root.as_view()),
|
url(r'^$', Root.as_view()),
|
||||||
|
url(r'^headers/$', Headers.as_view()),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@override_settings(ROOT_URLCONF='tests.test_requests_client')
|
@override_settings(ROOT_URLCONF='tests.test_requests_client')
|
||||||
class RequestsClientTests(APITestCase):
|
class RequestsClientTests(APITestCase):
|
||||||
def test_get_root(self):
|
def test_get_request(self):
|
||||||
print self.requests.get('http://example.com')
|
response = self.requests.get('/')
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.headers['Content-Type'] == 'application/json'
|
||||||
|
expected = {
|
||||||
|
'method': 'GET',
|
||||||
|
'query_params': {}
|
||||||
|
}
|
||||||
|
assert response.json() == expected
|
||||||
|
|
||||||
|
def test_get_request_query_params_in_url(self):
|
||||||
|
response = self.requests.get('/?key=value')
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.headers['Content-Type'] == 'application/json'
|
||||||
|
expected = {
|
||||||
|
'method': 'GET',
|
||||||
|
'query_params': {'key': 'value'}
|
||||||
|
}
|
||||||
|
assert response.json() == expected
|
||||||
|
|
||||||
|
def test_get_request_query_params_by_kwarg(self):
|
||||||
|
response = self.requests.get('/', params={'key': 'value'})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.headers['Content-Type'] == 'application/json'
|
||||||
|
expected = {
|
||||||
|
'method': 'GET',
|
||||||
|
'query_params': {'key': 'value'}
|
||||||
|
}
|
||||||
|
assert response.json() == expected
|
||||||
|
|
||||||
|
def test_get_with_headers(self):
|
||||||
|
response = self.requests.get('/headers/', headers={'User-Agent': 'example'})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.headers['Content-Type'] == 'application/json'
|
||||||
|
headers = response.json()['headers']
|
||||||
|
assert headers['USER-AGENT'] == 'example'
|
||||||
|
|
||||||
|
def test_post_form_request(self):
|
||||||
|
response = self.requests.post('/', data={'key': 'value'})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.headers['Content-Type'] == 'application/json'
|
||||||
|
expected = {
|
||||||
|
'method': 'POST',
|
||||||
|
'query_params': {},
|
||||||
|
'POST': {'key': 'value'},
|
||||||
|
'FILES': {},
|
||||||
|
'JSON': None
|
||||||
|
}
|
||||||
|
assert response.json() == expected
|
||||||
|
|
||||||
|
def test_post_json_request(self):
|
||||||
|
response = self.requests.post('/', json={'key': 'value'})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.headers['Content-Type'] == 'application/json'
|
||||||
|
expected = {
|
||||||
|
'method': 'POST',
|
||||||
|
'query_params': {},
|
||||||
|
'POST': {},
|
||||||
|
'FILES': {},
|
||||||
|
'JSON': {'key': 'value'}
|
||||||
|
}
|
||||||
|
assert response.json() == expected
|
||||||
|
|
||||||
|
def test_post_multipart_request(self):
|
||||||
|
files = {
|
||||||
|
'file': ('report.csv', 'some,data,to,send\nanother,row,to,send\n')
|
||||||
|
}
|
||||||
|
response = self.requests.post('/', files=files)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.headers['Content-Type'] == 'application/json'
|
||||||
|
expected = {
|
||||||
|
'method': 'POST',
|
||||||
|
'query_params': {},
|
||||||
|
'FILES': {'file': ['report.csv', 'some,data,to,send\nanother,row,to,send\n']},
|
||||||
|
'POST': {},
|
||||||
|
'JSON': None
|
||||||
|
}
|
||||||
|
assert response.json() == expected
|
||||||
|
|
||||||
|
# cookies/session auth
|
||||||
|
|
Loading…
Reference in New Issue
Block a user