diff --git a/requirements/requirements-optionals.txt b/requirements/requirements-optionals.txt index 3a10e8b36..bafef21f6 100644 --- a/requirements/requirements-optionals.txt +++ b/requirements/requirements-optionals.txt @@ -2,4 +2,4 @@ markdown==2.6.4 django-guardian==1.4.3 django-filter==0.13.0 -coreapi==1.32.2 +coreapi==1.32.3 diff --git a/rest_framework/test.py b/rest_framework/test.py index 492edac50..b8e486b21 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -60,7 +60,10 @@ if requests is not None: # Set request content, if any exists. if request.body is not None: - kwargs['data'] = request.body + if hasattr(request.body, 'read'): + kwargs['data'] = request.body.read() + else: + kwargs['data'] = request.body if 'content-type' in request.headers: kwargs['content_type'] = request.headers['content-type'] diff --git a/tests/test_api_client.py b/tests/test_api_client.py index 9f8b9075e..b00ce3af9 100644 --- a/tests/test_api_client.py +++ b/tests/test_api_client.py @@ -1,5 +1,6 @@ from __future__ import unicode_literals +import os import tempfile import unittest @@ -7,6 +8,7 @@ from django.conf.urls import url from django.test import override_settings from rest_framework.compat import coreapi +from rest_framework.parsers import FileUploadParser from rest_framework.renderers import CoreJSONRenderer from rest_framework.response import Response from rest_framework.test import APITestCase, get_api_client @@ -40,6 +42,9 @@ def get_schema(): 'urlencoded': coreapi.Link('/example/', action='post', encoding='application/x-www-form-urlencoded', fields=[ coreapi.Field(name='example') ]), + 'raw_upload': coreapi.Link('/upload/', action='post', encoding='application/octet-stream', fields=[ + coreapi.Field(name='example', location='body') + ]), } } ) @@ -102,10 +107,23 @@ class DetailView(APIView): }) +class UploadView(APIView): + parser_classes = [FileUploadParser] + + def post(self, request): + upload = request.data['file'] + contents = upload.read() + return Response({ + 'method': request.method, + 'files': {'name': upload.name, 'contents': contents} + }) + + urlpatterns = [ url(r'^$', SchemaView.as_view()), url(r'^example/$', ListView.as_view()), - url(r'^example/(?P[0-9]+)/$', DetailView.as_view()) + url(r'^example/(?P[0-9]+)/$', DetailView.as_view()), + url(r'^upload/$', UploadView.as_view()), ] @@ -176,16 +194,21 @@ class APIClientTests(APITestCase): def test_multipart_encoding(self): client = get_api_client() schema = client.get('http://api.example.com/') - temp = tempfile.TemporaryFile() + + temp = tempfile.NamedTemporaryFile() temp.write('example file contents') - temp.seek(0) - data = client.action(schema, ['encoding', 'multipart'], params={'example': temp}) + temp.flush() + + with open(temp.name, 'rb') as upload: + name = os.path.basename(upload.name) + data = client.action(schema, ['encoding', 'multipart'], params={'example': upload}) + expected = { 'method': 'POST', 'content_type': 'multipart/form-data', 'query_params': {}, 'data': {}, - 'files': {'example': {'name': 'example', 'contents': 'example file contents'}} + 'files': {'example': {'name': name, 'contents': 'example file contents'}} } assert data == expected @@ -201,3 +224,21 @@ class APIClientTests(APITestCase): 'files': None } assert data == expected + + def test_raw_upload(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + + temp = tempfile.NamedTemporaryFile() + temp.write('example file contents') + temp.flush() + + with open(temp.name, 'rb') as upload: + name = os.path.basename(upload.name) + data = client.action(schema, ['encoding', 'raw_upload'], params={'example': upload}) + + expected = { + 'method': 'POST', + 'files': {'name': name, 'contents': 'example file contents'} + } + assert data == expected