support django 2.1 test client json data automatically serialized

This commit is contained in:
Terence D. Honles 2019-03-15 14:10:42 -07:00
parent 4296189283
commit 803e0896c2
2 changed files with 34 additions and 7 deletions

View File

@ -152,14 +152,19 @@ class APIRequestFactory(DjangoRequestFactory):
Encode the data returning a two tuple of (bytes, content_type) Encode the data returning a two tuple of (bytes, content_type)
""" """
if data is None:
return ('', content_type)
assert format is None or content_type is None, ( assert format is None or content_type is None, (
'You may not set both `format` and `content_type`.' 'You may not set both `format` and `content_type`.'
) )
if content_type: if content_type:
try:
data = self._encode_json(data, content_type)
except AttributeError:
pass
if data is None:
data = ''
# Content type specified explicitly, treat data as a raw bytestring # Content type specified explicitly, treat data as a raw bytestring
ret = force_bytes(data, settings.DEFAULT_CHARSET) ret = force_bytes(data, settings.DEFAULT_CHARSET)
@ -177,7 +182,6 @@ class APIRequestFactory(DjangoRequestFactory):
# Use format and render the data into a bytestring # Use format and render the data into a bytestring
renderer = self.renderer_classes[format]() renderer = self.renderer_classes[format]()
ret = renderer.render(data)
# Determine the content-type header from the renderer # Determine the content-type header from the renderer
content_type = renderer.media_type content_type = renderer.media_type
@ -186,6 +190,11 @@ class APIRequestFactory(DjangoRequestFactory):
content_type, renderer.charset content_type, renderer.charset
) )
if data is None:
ret = ''
else:
ret = renderer.render(data)
# Coerce text to bytes if required. # Coerce text to bytes if required.
if isinstance(ret, str): if isinstance(ret, str):
ret = ret.encode(renderer.charset) ret = ret.encode(renderer.charset)

View File

@ -9,9 +9,9 @@ from django.shortcuts import redirect
from django.test import TestCase, override_settings from django.test import TestCase, override_settings
from django.urls import path from django.urls import path
from rest_framework import fields, serializers from rest_framework import fields, parsers, serializers
from rest_framework.authtoken.models import Token from rest_framework.authtoken.models import Token
from rest_framework.decorators import api_view from rest_framework.decorators import api_view, parser_classes
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.test import ( from rest_framework.test import (
APIClient, APIRequestFactory, URLPatternsTestCase, force_authenticate APIClient, APIRequestFactory, URLPatternsTestCase, force_authenticate
@ -51,6 +51,12 @@ class BasicSerializer(serializers.Serializer):
flag = fields.BooleanField(default=lambda: True) flag = fields.BooleanField(default=lambda: True)
@api_view(['POST'])
@parser_classes((parsers.JSONParser,))
def post_json_view(request):
return Response(request.data)
@api_view(['POST']) @api_view(['POST'])
def post_view(request): def post_view(request):
serializer = BasicSerializer(data=request.data) serializer = BasicSerializer(data=request.data)
@ -63,7 +69,8 @@ urlpatterns = [
path('session-view/', session_view), path('session-view/', session_view),
path('redirect-view/', redirect_view), path('redirect-view/', redirect_view),
path('redirect-view/<int:code>/', redirect_307_308_view), path('redirect-view/<int:code>/', redirect_307_308_view),
path('post-view/', post_view) path('post-json-view/', post_json_view),
path('post-view/', post_view),
] ]
@ -237,6 +244,17 @@ class TestAPITestClient(TestCase):
assert response.status_code == 200 assert response.status_code == 200
assert response.data == {"flag": True} assert response.data == {"flag": True}
def test_post_encodes_data_based_on_json_content_type(self):
data = {'data': True}
response = self.client.post(
'/post-json-view/',
data=data,
content_type='application/json'
)
assert response.status_code == 200
assert response.data == data
class TestAPIRequestFactory(TestCase): class TestAPIRequestFactory(TestCase):
def test_csrf_exempt_by_default(self): def test_csrf_exempt_by_default(self):