diff --git a/rest_framework/request.py b/rest_framework/request.py index aafafcb32..f5738bfd5 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -373,7 +373,7 @@ class Request(object): if not _hasattr(self, '_data'): self._load_data_and_files() if is_form_media_type(self.content_type): - return self.data + return self._data return QueryDict('', encoding=self._request._encoding) @property diff --git a/rest_framework/test.py b/rest_framework/test.py index fb08c4a73..1fd530a0c 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -4,7 +4,10 @@ # to make it harder for the user to import the wrong thing without realizing. from __future__ import unicode_literals +import io + from django.conf import settings +from django.core.handlers.wsgi import WSGIHandler from django.test import testcases from django.test.client import Client as DjangoClient from django.test.client import RequestFactory as DjangoRequestFactory @@ -13,6 +16,10 @@ from django.utils import six from django.utils.encoding import force_bytes from django.utils.http import urlencode from requests import Session +from requests.adapters import BaseAdapter +from requests.models import Response +from requests.structures import CaseInsensitiveDict +from requests.utils import get_encoding_from_headers from rest_framework.settings import api_settings @@ -22,6 +29,83 @@ def force_authenticate(request, user=None, token=None): request._force_auth_token = token +class DjangoTestAdapter(BaseAdapter): + """ + A transport adaptor for `requests`, that makes requests via the + Django WSGI app, rather than making actual HTTP requests ovet the network. + """ + def __init__(self): + self.app = WSGIHandler() + self.factory = DjangoRequestFactory() + + def get_environ(self, request): + """ + Given a `requests.PreparedRequest` instance, return a WSGI environ dict. + """ + method = request.method + url = request.url + kwargs = {} + + # Set request content, if any exists. + if request.body is not None: + kwargs['data'] = request.body + if 'content-type' in request.headers: + kwargs['content_type'] = request.headers['content-type'] + + # Set request headers. + for key, value in request.headers.items(): + key = key.upper() + if key in ('CONNECTION', 'CONTENT_LENGTH', 'CONTENT-TYPE'): + continue + kwargs['HTTP_%s' % key] = value + + return self.factory.generic(method, url, **kwargs).environ + + def send(self, request, *args, **kwargs): + """ + Make an outgoing request to the Django WSGI application. + """ + response = Response() + + def start_response(status, headers): + status_code, _, reason_phrase = status.partition(' ') + response.status_code = int(status_code) + response.reason = reason_phrase + response.headers = CaseInsensitiveDict(headers) + response.encoding = get_encoding_from_headers(response.headers) + + environ = self.get_environ(request) + raw_bytes = self.app(environ, start_response) + + response.request = request + response.url = request.url + response.raw = io.BytesIO(b''.join(raw_bytes)) + + return response + + def close(self): + pass + + +class DjangoTestSession(Session): + def __init__(self, *args, **kwargs): + super(DjangoTestSession, self).__init__(*args, **kwargs) + + adapter = DjangoTestAdapter() + hostnames = list(settings.ALLOWED_HOSTS) + ['testserver'] + + for hostname in hostnames: + if hostname == '*': + hostname = '' + self.mount('http://%s' % hostname, adapter) + self.mount('https://%s' % hostname, adapter) + + def request(self, method, url, *args, **kwargs): + if ':' not in url: + url = 'http://testserver/' + url.lstrip('/') + return super(DjangoTestSession, self).request(method, url, *args, **kwargs) + + class APIRequestFactory(DjangoRequestFactory): renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT @@ -224,7 +308,7 @@ class APITestCase(testcases.TestCase): def _pre_setup(self): super(APITestCase, self)._pre_setup() - self.requests = Session() + self.requests = DjangoTestSession() class APISimpleTestCase(testcases.SimpleTestCase): diff --git a/tests/test_requests_client.py b/tests/test_requests_client.py index a36349a3f..0687dd92e 100644 --- a/tests/test_requests_client.py +++ b/tests/test_requests_client.py @@ -10,15 +10,128 @@ from rest_framework.views import APIView class Root(APIView): def get(self, request): - return Response({'hello': 'world'}) + return Response({ + 'method': request.method, + 'query_params': request.query_params, + }) + + def post(self, request): + files = { + key: (value.name, value.read()) + for key, value in request.FILES.items() + } + post = request.POST + json = None + if request.META.get('CONTENT_TYPE') == 'application/json': + json = request.data + + return Response({ + 'method': request.method, + 'query_params': request.query_params, + 'POST': post, + 'FILES': files, + 'JSON': json + }) + + +class Headers(APIView): + def get(self, request): + headers = { + key[5:]: value + for key, value in request.META.items() + if key.startswith('HTTP_') + } + return Response({ + 'method': request.method, + 'headers': headers + }) urlpatterns = [ url(r'^$', Root.as_view()), + url(r'^headers/$', Headers.as_view()), ] @override_settings(ROOT_URLCONF='tests.test_requests_client') class RequestsClientTests(APITestCase): - def test_get_root(self): - print self.requests.get('http://example.com') + def test_get_request(self): + response = self.requests.get('/') + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'GET', + 'query_params': {} + } + assert response.json() == expected + + def test_get_request_query_params_in_url(self): + response = self.requests.get('/?key=value') + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'GET', + 'query_params': {'key': 'value'} + } + assert response.json() == expected + + def test_get_request_query_params_by_kwarg(self): + response = self.requests.get('/', params={'key': 'value'}) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'GET', + 'query_params': {'key': 'value'} + } + assert response.json() == expected + + def test_get_with_headers(self): + response = self.requests.get('/headers/', headers={'User-Agent': 'example'}) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + headers = response.json()['headers'] + assert headers['USER-AGENT'] == 'example' + + def test_post_form_request(self): + response = self.requests.post('/', data={'key': 'value'}) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'POST', + 'query_params': {}, + 'POST': {'key': 'value'}, + 'FILES': {}, + 'JSON': None + } + assert response.json() == expected + + def test_post_json_request(self): + response = self.requests.post('/', json={'key': 'value'}) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'POST', + 'query_params': {}, + 'POST': {}, + 'FILES': {}, + 'JSON': {'key': 'value'} + } + assert response.json() == expected + + def test_post_multipart_request(self): + files = { + 'file': ('report.csv', 'some,data,to,send\nanother,row,to,send\n') + } + response = self.requests.post('/', files=files) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'POST', + 'query_params': {}, + 'FILES': {'file': ['report.csv', 'some,data,to,send\nanother,row,to,send\n']}, + 'POST': {}, + 'JSON': None + } + assert response.json() == expected + + # cookies/session auth