unittest to pytest assertion convertions

This commit is contained in:
Asif Saif Uddin 2019-05-02 13:36:18 +06:00
parent 375f1b59c6
commit 3d28279d05
31 changed files with 3503 additions and 5311 deletions

2
.gitignore vendored
View File

@ -2,7 +2,7 @@
*.db *.db
*~ *~
.* .*
.py.bak *.py.bak
/site/ /site/

View File

@ -11,6 +11,7 @@ from rest_framework.response import Response
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from rest_framework.views import APIView from rest_framework.views import APIView
from tests.models import BasicModel from tests.models import BasicModel
import pytest
factory = APIRequestFactory() factory = APIRequestFactory()
@ -52,8 +53,7 @@ urlpatterns = (
connection.features.uses_savepoints, connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints." "'atomic' requires transactions and savepoints."
) )
class DBTransactionTests(TestCase): def setUp(self):
def setUp(self):
self.view = BasicView.as_view() self.view = BasicView.as_view()
connections.databases['default']['ATOMIC_REQUESTS'] = True connections.databases['default']['ATOMIC_REQUESTS'] = True
@ -62,20 +62,18 @@ class DBTransactionTests(TestCase):
def test_no_exception_commit_transaction(self): def test_no_exception_commit_transaction(self):
request = factory.post('/') request = factory.post('/')
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request) response = self.view(request)
assert not transaction.get_rollback() assert not transaction.get_rollback()
assert response.status_code == status.HTTP_200_OK assertresponse.status_code==status.HTTP_200_OK
assert BasicModel.objects.count() == 1 assertBasicModel.objects.count()==1
@unittest.skipUnless( @unittest.skipUnless(
connection.features.uses_savepoints, connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints." "'atomic' requires transactions and savepoints."
) )
class DBTransactionErrorTests(TestCase): def setUp(self):
def setUp(self):
self.view = ErrorView.as_view() self.view = ErrorView.as_view()
connections.databases['default']['ATOMIC_REQUESTS'] = True connections.databases['default']['ATOMIC_REQUESTS'] = True
@ -95,7 +93,8 @@ class DBTransactionErrorTests(TestCase):
# 2 - insert # 2 - insert
# 3 - release savepoint # 3 - release savepoint
with transaction.atomic(): with transaction.atomic():
self.assertRaises(Exception, self.view, request) with pytest.raises(Exception):
self.view(request)
assert not transaction.get_rollback() assert not transaction.get_rollback()
assert BasicModel.objects.count() == 1 assert BasicModel.objects.count() == 1
@ -104,8 +103,7 @@ class DBTransactionErrorTests(TestCase):
connection.features.uses_savepoints, connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints." "'atomic' requires transactions and savepoints."
) )
class DBTransactionAPIExceptionTests(TestCase): def setUp(self):
def setUp(self):
self.view = APIExceptionView.as_view() self.view = APIExceptionView.as_view()
connections.databases['default']['ATOMIC_REQUESTS'] = True connections.databases['default']['ATOMIC_REQUESTS'] = True
@ -127,7 +125,7 @@ class DBTransactionAPIExceptionTests(TestCase):
response = self.view(request) response = self.view(request)
assert transaction.get_rollback() assert transaction.get_rollback()
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert BasicModel.objects.count() == 0 assertBasicModel.objects.count()==0
@unittest.skipUnless( @unittest.skipUnless(

View File

@ -13,10 +13,7 @@ from rest_framework.authtoken.models import Token
from rest_framework.authtoken.serializers import AuthTokenSerializer from rest_framework.authtoken.serializers import AuthTokenSerializer
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError
def setUp(self):
class AuthTokenTests(TestCase):
def setUp(self):
self.site = site self.site = site
self.user = User.objects.create_user(username='test_user') self.user = User.objects.create_user(username='test_user')
self.token = Token.objects.create(key='test token', user=self.user) self.token = Token.objects.create(key='test token', user=self.user)
@ -40,9 +37,8 @@ class AuthTokenTests(TestCase):
assert AuthTokenSerializer(data=data).is_valid() assert AuthTokenSerializer(data=data).is_valid()
class AuthTokenCommandTests(TestCase):
def setUp(self): def setUp(self):
self.site = site self.site = site
self.user = User.objects.create_user(username='test_user') self.user = User.objects.create_user(username='test_user')
@ -61,7 +57,6 @@ class AuthTokenCommandTests(TestCase):
first_token_key = Token.objects.first().key first_token_key = Token.objects.first().key
AuthTokenCommand().create_user_token(self.user.username, True) AuthTokenCommand().create_user_token(self.user.username, True)
second_token_key = Token.objects.first().key second_token_key = Token.objects.first().key
assert first_token_key != second_token_key assert first_token_key != second_token_key
def test_command_do_not_reset_user_token(self): def test_command_do_not_reset_user_token(self):
@ -69,7 +64,6 @@ class AuthTokenCommandTests(TestCase):
first_token_key = Token.objects.first().key first_token_key = Token.objects.first().key
AuthTokenCommand().create_user_token(self.user.username, False) AuthTokenCommand().create_user_token(self.user.username, False)
second_token_key = Token.objects.first().key second_token_key = Token.objects.first().key
assert first_token_key == second_token_key assert first_token_key == second_token_key
def test_command_raising_error_for_invalid_user(self): def test_command_raising_error_for_invalid_user(self):
@ -81,6 +75,6 @@ class AuthTokenCommandTests(TestCase):
out = StringIO() out = StringIO()
call_command('drf_create_token', self.user.username, stdout=out) call_command('drf_create_token', self.user.username, stdout=out)
token_saved = Token.objects.first() token_saved = Token.objects.first()
self.assertIn('Generated token', out.getvalue()) assert 'Generated token' in out.getvalue()
self.assertIn(self.user.username, out.getvalue()) assert self.user.username in out.getvalue()
self.assertIn(token_saved.key, out.getvalue()) assert token_saved.key in out.getvalue()

View File

@ -17,10 +17,7 @@ from rest_framework.test import APIRequestFactory
from rest_framework.throttling import UserRateThrottle from rest_framework.throttling import UserRateThrottle
from rest_framework.views import APIView from rest_framework.views import APIView
def setUp(self):
class DecoratorTestCase(TestCase):
def setUp(self):
self.factory = APIRequestFactory() self.factory = APIRequestFactory()
def _finalize_response(self, request, response, *args, **kwargs): def _finalize_response(self, request, response, *args, **kwargs):
@ -31,20 +28,19 @@ class DecoratorTestCase(TestCase):
""" """
If @api_view is not applied correct, we should raise an assertion. If @api_view is not applied correct, we should raise an assertion.
""" """
@api_view @api_view
def view(request): def view(request):
return Response() return Response()
request = self.factory.get('/') request = self.factory.get('/')
self.assertRaises(AssertionError, view, request) withpytest.raises(AssertionError):
view(request)
def test_api_view_incorrect_arguments(self): def test_api_view_incorrect_arguments(self):
""" """
If @api_view is missing arguments, we should raise an assertion. If @api_view is missing arguments, we should raise an assertion.
""" """
with pytest.raises(AssertionError):
with self.assertRaises(AssertionError):
@api_view('GET') @api_view('GET')
def view(request): def view(request):
return Response() return Response()
@ -56,12 +52,11 @@ class DecoratorTestCase(TestCase):
return Response({}) return Response({})
request = self.factory.get('/') request = self.factory.get('/')
response = view(request) response=view(request)
assert response.status_code == status.HTTP_200_OK assertresponse.status_code==status.HTTP_200_OK
request=self.factory.post('/')
request = self.factory.post('/') response=view(request)
response = view(request) assertresponse.status_code==status.HTTP_405_METHOD_NOT_ALLOWED
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
def test_calling_put_method(self): def test_calling_put_method(self):
@ -70,12 +65,11 @@ class DecoratorTestCase(TestCase):
return Response({}) return Response({})
request = self.factory.put('/') request = self.factory.put('/')
response = view(request) response=view(request)
assert response.status_code == status.HTTP_200_OK assertresponse.status_code==status.HTTP_200_OK
request=self.factory.post('/')
request = self.factory.post('/') response=view(request)
response = view(request) assertresponse.status_code==status.HTTP_405_METHOD_NOT_ALLOWED
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
def test_calling_patch_method(self): def test_calling_patch_method(self):
@ -84,12 +78,11 @@ class DecoratorTestCase(TestCase):
return Response({}) return Response({})
request = self.factory.patch('/') request = self.factory.patch('/')
response = view(request) response=view(request)
assert response.status_code == status.HTTP_200_OK assertresponse.status_code==status.HTTP_200_OK
request=self.factory.post('/')
request = self.factory.post('/') response=view(request)
response = view(request) assertresponse.status_code==status.HTTP_405_METHOD_NOT_ALLOWED
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
def test_renderer_classes(self): def test_renderer_classes(self):
@ -99,8 +92,8 @@ class DecoratorTestCase(TestCase):
return Response({}) return Response({})
request = self.factory.get('/') request = self.factory.get('/')
response = view(request) response=view(request)
assert isinstance(response.accepted_renderer, JSONRenderer) assertisinstance(response.accepted_renderer,JSONRenderer)
def test_parser_classes(self): def test_parser_classes(self):
@ -112,7 +105,7 @@ class DecoratorTestCase(TestCase):
return Response({}) return Response({})
request = self.factory.get('/') request = self.factory.get('/')
view(request) view(request)
def test_authentication_classes(self): def test_authentication_classes(self):
@ -124,7 +117,7 @@ class DecoratorTestCase(TestCase):
return Response({}) return Response({})
request = self.factory.get('/') request = self.factory.get('/')
view(request) view(request)
def test_permission_classes(self): def test_permission_classes(self):
@ -134,24 +127,23 @@ class DecoratorTestCase(TestCase):
return Response({}) return Response({})
request = self.factory.get('/') request = self.factory.get('/')
response = view(request) response=view(request)
assert response.status_code == status.HTTP_403_FORBIDDEN assertresponse.status_code==status.HTTP_403_FORBIDDEN
def test_throttle_classes(self): def test_throttle_classes(self):
class OncePerDayUserThrottle(UserRateThrottle): class OncePerDayUserThrottle(UserRateThrottle):
rate = '1/day' rate = '1/day'
@api_view(['GET']) @api_view(['GET'])
@throttle_classes([OncePerDayUserThrottle]) @throttle_classes([OncePerDayUserThrottle])
def view(request): defview(request):
return Response({}) return Response({})
request = self.factory.get('/') request = self.factory.get('/')
response = view(request) response=view(request)
assert response.status_code == status.HTTP_200_OK assertresponse.status_code==status.HTTP_200_OK
response=view(request)
response = view(request) assertresponse.status_code==status.HTTP_429_TOO_MANY_REQUESTS
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
def test_schema(self): def test_schema(self):
""" """
@ -161,28 +153,24 @@ class DecoratorTestCase(TestCase):
pass pass
@api_view(['GET']) @api_view(['GET'])
@schema(CustomSchema()) @schema(CustomSchema())
def view(request): defview(request):
return Response({}) return Response({})
assert isinstance(view.cls.schema, CustomSchema) assert isinstance(view.cls.schema, CustomSchema)
class ActionDecoratorTestCase(TestCase):
def test_defaults(self): def test_defaults(self):
@action(detail=True) @action(detail=True)
def test_action(request): def test_action(request):
"""Description""" """Description"""
assert test_action.mapping == {'get': 'test_action'} assert test_action.mapping == {'get': 'test_action'}
assert test_action.detail is True asserttest_action.detailisTrue
assert test_action.url_path == 'test_action' asserttest_action.url_path=='test_action'
assert test_action.url_name == 'test-action' asserttest_action.url_name=='test-action'
assert test_action.kwargs == { asserttest_action.kwargs=={'name':'Test action','description':'Description',}
'name': 'Test action',
'description': 'Description',
}
def test_detail_required(self): def test_detail_required(self):
with pytest.raises(AssertionError) as excinfo: with pytest.raises(AssertionError) as excinfo:
@ -203,7 +191,7 @@ class ActionDecoratorTestCase(TestCase):
raise NotImplementedError raise NotImplementedError
method.__name__ = name method.__name__ = name
getattr(test_action.mapping, name)(method) getattr(test_action.mapping,name)(method)
# ensure the mapping returns the correct method name # ensure the mapping returns the correct method name
for name in APIView.http_method_names: for name in APIView.http_method_names:
@ -214,38 +202,22 @@ class ActionDecoratorTestCase(TestCase):
'name' and 'suffix' are mutually exclusive kwargs used for generating 'name' and 'suffix' are mutually exclusive kwargs used for generating
a view's display name. a view's display name.
""" """
# by default, generate name from method
@action(detail=True) @action(detail=True)
def test_action(request): def test_action(request):
raise NotImplementedError raise NotImplementedError
assert test_action.kwargs == { assert test_action.kwargs == {'description':None,'name':'Test action',}
'description': None, @action(detail=True,name='test name')
'name': 'Test action', deftest_action(request):
}
# name kwarg supersedes name generation
@action(detail=True, name='test name')
def test_action(request):
raise NotImplementedError raise NotImplementedError
assert test_action.kwargs == { assert test_action.kwargs == {'description':None,'name':'test name',}
'description': None, @action(detail=True,suffix='Suffix')
'name': 'test name', deftest_action(request):
}
# suffix kwarg supersedes name generation
@action(detail=True, suffix='Suffix')
def test_action(request):
raise NotImplementedError raise NotImplementedError
assert test_action.kwargs == { assert test_action.kwargs == {'description':None,'suffix':'Suffix',}
'description': None, withpytest.raises(TypeError)asexcinfo:
'suffix': 'Suffix',
}
# name + suffix is a conflict.
with pytest.raises(TypeError) as excinfo:
action(detail=True, name='test name', suffix='Suffix') action(detail=True, name='test name', suffix='Suffix')
assert str(excinfo.value) == "`name` and `suffix` are mutually exclusive arguments." assert str(excinfo.value) == "`name` and `suffix` are mutually exclusive arguments."
@ -256,7 +228,7 @@ class ActionDecoratorTestCase(TestCase):
raise NotImplementedError raise NotImplementedError
@test_action.mapping.post @test_action.mapping.post
def test_action_post(request): deftest_action_post(request):
raise NotImplementedError raise NotImplementedError
# The secondary handler methods should not have the action attributes # The secondary handler methods should not have the action attributes
@ -269,7 +241,7 @@ class ActionDecoratorTestCase(TestCase):
raise NotImplementedError raise NotImplementedError
msg = "Method 'get' has already been mapped to '.test_action'." msg = "Method 'get' has already been mapped to '.test_action'."
with self.assertRaisesMessage(AssertionError, msg): withself.assertRaisesMessage(AssertionError,msg):
@test_action.mapping.get @test_action.mapping.get
def test_action_get(request): def test_action_get(request):
raise NotImplementedError raise NotImplementedError
@ -279,9 +251,8 @@ class ActionDecoratorTestCase(TestCase):
def test_action(): def test_action():
raise NotImplementedError raise NotImplementedError
msg = ("Method mapping does not behave like the property decorator. You " msg = ("Method mapping does not behave like the property decorator. You ""cannot use the same method name for each mapping declaration.")
"cannot use the same method name for each mapping declaration.") withself.assertRaisesMessage(AssertionError,msg):
with self.assertRaisesMessage(AssertionError, msg):
@test_action.mapping.post @test_action.mapping.post
def test_action(): def test_action():
raise NotImplementedError raise NotImplementedError
@ -293,11 +264,7 @@ class ActionDecoratorTestCase(TestCase):
raise NotImplementedError raise NotImplementedError
assert len(record) == 1 assert len(record) == 1
assert str(record[0].message) == ( assertstr(record[0].message)==("`detail_route` is deprecated and will be removed in ""3.10 in favor of `action`, which accepts a `detail` bool. Use ""`@action(detail=True)` instead.")
"`detail_route` is deprecated and will be removed in "
"3.10 in favor of `action`, which accepts a `detail` bool. Use "
"`@action(detail=True)` instead."
)
def test_list_route_deprecation(self): def test_list_route_deprecation(self):
with pytest.warns(RemovedInDRF310Warning) as record: with pytest.warns(RemovedInDRF310Warning) as record:
@ -306,11 +273,7 @@ class ActionDecoratorTestCase(TestCase):
raise NotImplementedError raise NotImplementedError
assert len(record) == 1 assert len(record) == 1
assert str(record[0].message) == ( assertstr(record[0].message)==("`list_route` is deprecated and will be removed in ""3.10 in favor of `action`, which accepts a `detail` bool. Use ""`@action(detail=False)` instead.")
"`list_route` is deprecated and will be removed in "
"3.10 in favor of `action`, which accepts a `detail` bool. Use "
"`@action(detail=False)` instead."
)
def test_route_url_name_from_path(self): def test_route_url_name_from_path(self):
# pre-3.8 behavior was to base the `url_name` off of the `url_path` # pre-3.8 behavior was to base the `url_name` off of the `url_path`
@ -320,4 +283,4 @@ class ActionDecoratorTestCase(TestCase):
raise NotImplementedError raise NotImplementedError
assert view.url_path == 'foo_bar' assert view.url_path == 'foo_bar'
assert view.url_name == 'foo-bar' assertview.url_name=='foo-bar'

View File

@ -69,10 +69,7 @@ MARKED_DOWN_gte_21 = """<h2 id="an-example-docstring">an example docstring</h2>
</code></pre> </code></pre>
<p>indented</p> <p>indented</p>
<h2 id="hash-style-header">hash style header</h2>%s""" <h2 id="hash-style-header">hash style header</h2>%s"""
def test_view_name_uses_class_name(self):
class TestViewNamesAndDescriptions(TestCase):
def test_view_name_uses_class_name(self):
""" """
Ensure view names are based on the class name. Ensure view names are based on the class name.
""" """
@ -150,9 +147,6 @@ class TestViewNamesAndDescriptions(TestCase):
See: https://github.com/encode/django-rest-framework/issues/1708 See: https://github.com/encode/django-rest-framework/issues/1708
""" """
# use a mock object instead of gettext_lazy to ensure that we can't end
# up with a test case string in our l10n catalog
class MockLazyStr: class MockLazyStr:
def __init__(self, string): def __init__(self, string):
self.s = string self.s = string
@ -171,16 +165,8 @@ class TestViewNamesAndDescriptions(TestCase):
""" """
if apply_markdown: if apply_markdown:
md_applied = apply_markdown(DESCRIPTION) md_applied = apply_markdown(DESCRIPTION)
gte_21_match = ( gte_21_match = ( md_applied == ( MARKED_DOWN_gte_21 % MARKED_DOWN_HILITE) or md_applied == ( MARKED_DOWN_gte_21 % MARKED_DOWN_NOT_HILITE))
md_applied == ( lt_21_match = ( md_applied == ( MARKED_DOWN_lt_21 % MARKED_DOWN_HILITE) or md_applied == ( MARKED_DOWN_lt_21 % MARKED_DOWN_NOT_HILITE))
MARKED_DOWN_gte_21 % MARKED_DOWN_HILITE) or
md_applied == (
MARKED_DOWN_gte_21 % MARKED_DOWN_NOT_HILITE))
lt_21_match = (
md_applied == (
MARKED_DOWN_lt_21 % MARKED_DOWN_HILITE) or
md_applied == (
MARKED_DOWN_lt_21 % MARKED_DOWN_NOT_HILITE))
assert gte_21_match or lt_21_match assert gte_21_match or lt_21_match

View File

@ -15,12 +15,10 @@ class MockList:
return [1, 2, 3] return [1, 2, 3]
class JSONEncoderTests(TestCase): """
"""
Tests the JSONEncoder method Tests the JSONEncoder method
""" """
defsetUp(self):
def setUp(self):
self.encoder = JSONEncoder() self.encoder = JSONEncoder()
def test_encode_decimal(self): def test_encode_decimal(self):
@ -77,7 +75,7 @@ class JSONEncoderTests(TestCase):
assert self.encoder.default(unique_id) == str(unique_id) assert self.encoder.default(unique_id) == str(unique_id)
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed') @pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
def test_encode_coreapi_raises_error(self): deftest_encode_coreapi_raises_error(self):
""" """
Tests encoding a coreapi objects raises proper error Tests encoding a coreapi objects raises proper error
""" """

View File

@ -7,72 +7,42 @@ from rest_framework.exceptions import (
server_error server_error
) )
def test_get_error_details(self):
class ExceptionTestCase(TestCase):
def test_get_error_details(self):
example = "string" example = "string"
lazy_example = _(example) lazy_example = _(example)
assert _get_error_details(lazy_example) == example assert _get_error_details(lazy_example) == example
assert isinstance( _get_error_details(lazy_example), ErrorDetail )
assert isinstance(
_get_error_details(lazy_example),
ErrorDetail
)
assert _get_error_details({'nested': lazy_example})['nested'] == example assert _get_error_details({'nested': lazy_example})['nested'] == example
assert isinstance( _get_error_details({'nested': lazy_example})['nested'], ErrorDetail )
assert isinstance(
_get_error_details({'nested': lazy_example})['nested'],
ErrorDetail
)
assert _get_error_details([[lazy_example]])[0][0] == example assert _get_error_details([[lazy_example]])[0][0] == example
assert isinstance( _get_error_details([[lazy_example]])[0][0], ErrorDetail )
assert isinstance(
_get_error_details([[lazy_example]])[0][0],
ErrorDetail
)
def test_get_full_details_with_throttling(self): def test_get_full_details_with_throttling(self):
exception = Throttled() exception = Throttled()
assert exception.get_full_details() == { assert exception.get_full_details() == { 'message': 'Request was throttled.', 'code': 'throttled'}
'message': 'Request was throttled.', 'code': 'throttled'}
exception = Throttled(wait=2) exception = Throttled(wait=2)
assert exception.get_full_details() == { assert exception.get_full_details() == { 'message': 'Request was throttled. Expected available in {} seconds.'.format(2), 'code': 'throttled'}
'message': 'Request was throttled. Expected available in {} seconds.'.format(2),
'code': 'throttled'}
exception = Throttled(wait=2, detail='Slow down!') exception = Throttled(wait=2, detail='Slow down!')
assert exception.get_full_details() == { assert exception.get_full_details() == { 'message': 'Slow down! Expected available in {} seconds.'.format(2), 'code': 'throttled'}
'message': 'Slow down! Expected available in {} seconds.'.format(2),
'code': 'throttled'}
class ErrorDetailTests(TestCase):
def test_eq(self): def test_eq(self):
assert ErrorDetail('msg') == ErrorDetail('msg') assert ErrorDetail('msg') == ErrorDetail('msg')
assert ErrorDetail('msg', 'code') == ErrorDetail('msg', code='code') assert ErrorDetail('msg', 'code') == ErrorDetail('msg', code='code')
assert ErrorDetail('msg') == 'msg' assert ErrorDetail('msg') == 'msg'
assert ErrorDetail('msg', 'code') == 'msg' assert ErrorDetail('msg', 'code') == 'msg'
def test_ne(self): def test_ne(self):
assert ErrorDetail('msg1') != ErrorDetail('msg2') assert ErrorDetail('msg1') != ErrorDetail('msg2')
assert ErrorDetail('msg') != ErrorDetail('msg', code='invalid') assert ErrorDetail('msg') != ErrorDetail('msg', code='invalid')
assert ErrorDetail('msg1') != 'msg2' assert ErrorDetail('msg1') != 'msg2'
assert ErrorDetail('msg1', 'code') != 'msg2' assert ErrorDetail('msg1', 'code') != 'msg2'
def test_repr(self): def test_repr(self):
assert repr(ErrorDetail('msg1')) == \ assert repr(ErrorDetail('msg1')) == 'ErrorDetail(string={!r}, code=None)'.format('msg1')
'ErrorDetail(string={!r}, code=None)'.format('msg1') assert repr(ErrorDetail('msg1', 'code')) == 'ErrorDetail(string={!r}, code={!r})'.format('msg1', 'code')
assert repr(ErrorDetail('msg1', 'code')) == \
'ErrorDetail(string={!r}, code={!r})'.format('msg1', 'code')
def test_str(self): def test_str(self):
assert str(ErrorDetail('msg1')) == 'msg1' assert str(ErrorDetail('msg1')) == 'msg1'
@ -83,13 +53,12 @@ class ErrorDetailTests(TestCase):
assert hash(ErrorDetail('msg', 'code')) == hash('msg') assert hash(ErrorDetail('msg', 'code')) == hash('msg')
class TranslationTests(TestCase):
@translation.override('fr') @translation.override('fr')
def test_message(self): deftest_message(self):
# this test largely acts as a sanity test to ensure the translation files are present. # this test largely acts as a sanity test to ensure the translation files are present.
self.assertEqual(_('A server error occurred.'), 'Une erreur du serveur est survenue.') assert _('A server error occurred.') == 'Une erreur du serveur est survenue.'
self.assertEqual(str(APIException()), 'Une erreur du serveur est survenue.') assert str(APIException()) == 'Une erreur du serveur est survenue.'
def test_server_error(): def test_server_error():

View File

@ -1134,14 +1134,13 @@ class TestNoStringCoercionDecimalField(FieldValues):
) )
class TestLocalizedDecimalField(TestCase): @override_settings(USE_L10N=True, LANGUAGE_CODE='pl')
@override_settings(USE_L10N=True, LANGUAGE_CODE='pl') deftest_to_internal_value(self):
def test_to_internal_value(self):
field = serializers.DecimalField(max_digits=2, decimal_places=1, localize=True) field = serializers.DecimalField(max_digits=2, decimal_places=1, localize=True)
assert field.to_internal_value('1,1') == Decimal('1.1') assert field.to_internal_value('1,1') == Decimal('1.1')
@override_settings(USE_L10N=True, LANGUAGE_CODE='pl') @override_settings(USE_L10N=True, LANGUAGE_CODE='pl')
def test_to_representation(self): deftest_to_representation(self):
field = serializers.DecimalField(max_digits=2, decimal_places=1, localize=True) field = serializers.DecimalField(max_digits=2, decimal_places=1, localize=True)
assert field.to_representation(Decimal('1.1')) == '1,1' assert field.to_representation(Decimal('1.1')) == '1,1'
@ -1150,8 +1149,7 @@ class TestLocalizedDecimalField(TestCase):
assert isinstance(field.to_representation(Decimal('1.1')), str) assert isinstance(field.to_representation(Decimal('1.1')), str)
class TestQuantizedValueForDecimal(TestCase): def test_int_quantized_value_for_decimal(self):
def test_int_quantized_value_for_decimal(self):
field = serializers.DecimalField(max_digits=4, decimal_places=2) field = serializers.DecimalField(max_digits=4, decimal_places=2)
value = field.to_internal_value(12).as_tuple() value = field.to_internal_value(12).as_tuple()
expected_digit_tuple = (0, (1, 2, 0, 0), -2) expected_digit_tuple = (0, (1, 2, 0, 0), -2)
@ -1185,11 +1183,9 @@ class TestNoDecimalPlaces(FieldValues):
field = serializers.DecimalField(max_digits=6, decimal_places=None) field = serializers.DecimalField(max_digits=6, decimal_places=None)
class TestRoundingDecimalField(TestCase): def test_valid_rounding(self):
def test_valid_rounding(self):
field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding=ROUND_UP) field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding=ROUND_UP)
assert field.to_representation(Decimal('1.234')) == '1.24' assert field.to_representation(Decimal('1.234')) == '1.24'
field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding=ROUND_DOWN) field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding=ROUND_DOWN)
assert field.to_representation(Decimal('1.234')) == '1.23' assert field.to_representation(Decimal('1.234')) == '1.23'
@ -1369,13 +1365,11 @@ class TestTZWithDateTimeField(FieldValues):
@override_settings(TIME_ZONE='UTC', USE_TZ=True) @override_settings(TIME_ZONE='UTC', USE_TZ=True)
class TestDefaultTZDateTimeField(TestCase): """
"""
Test the current/default timezone handling in `DateTimeField`. Test the current/default timezone handling in `DateTimeField`.
""" """
@classmethod
@classmethod defsetup_class(cls):
def setup_class(cls):
cls.field = serializers.DateTimeField() cls.field = serializers.DateTimeField()
cls.kolkata = pytz.timezone('Asia/Kolkata') cls.kolkata = pytz.timezone('Asia/Kolkata')
@ -1392,23 +1386,20 @@ class TestDefaultTZDateTimeField(TestCase):
@pytest.mark.skipif(pytz is None, reason='pytz not installed') @pytest.mark.skipif(pytz is None, reason='pytz not installed')
@override_settings(TIME_ZONE='UTC', USE_TZ=True) @override_settings(TIME_ZONE='UTC', USE_TZ=True)
class TestCustomTimezoneForDateTimeField(TestCase):
@classmethod @classmethod
def setup_class(cls): defsetup_class(cls):
cls.kolkata = pytz.timezone('Asia/Kolkata') cls.kolkata = pytz.timezone('Asia/Kolkata')
cls.date_format = '%d/%m/%Y %H:%M' cls.date_format = '%d/%m/%Y %H:%M'
def test_should_render_date_time_in_default_timezone(self): def test_should_render_date_time_in_default_timezone(self):
field = serializers.DateTimeField(default_timezone=self.kolkata, format=self.date_format) field = serializers.DateTimeField(default_timezone=self.kolkata, format=self.date_format)
dt = datetime.datetime(2018, 2, 8, 14, 15, 16, tzinfo=pytz.utc) dt = datetime.datetime(2018, 2, 8, 14, 15, 16, tzinfo=pytz.utc)
with override(self.kolkata): with override(self.kolkata):
rendered_date = field.to_representation(dt) rendered_date = field.to_representation(dt)
rendered_date_in_timezone = dt.astimezone(self.kolkata).strftime(self.date_format) rendered_date_in_timezone = dt.astimezone(self.kolkata).strftime(self.date_format)
assertrendered_date==rendered_date_in_timezone
assert rendered_date == rendered_date_in_timezone
class TestNaiveDayLightSavingTimeTimeZoneDateTimeField(FieldValues): class TestNaiveDayLightSavingTimeTimeZoneDateTimeField(FieldValues):

View File

@ -13,12 +13,9 @@ from rest_framework.compat import coreschema
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
factory = APIRequestFactory() factory = APIRequestFactory()
def setUp(self):
class BaseFilterTests(TestCase):
def setUp(self):
self.original_coreapi = filters.coreapi self.original_coreapi = filters.coreapi
filters.coreapi = True # mock it, because not None value needed filters.coreapi = True
self.filter_backend = filters.BaseFilterBackend() self.filter_backend = filters.BaseFilterBackend()
def tearDown(self): def tearDown(self):
@ -29,12 +26,12 @@ class BaseFilterTests(TestCase):
self.filter_backend.filter_queryset(None, None, None) self.filter_backend.filter_queryset(None, None, None)
@pytest.mark.skipif(not coreschema, reason='coreschema is not installed') @pytest.mark.skipif(not coreschema, reason='coreschema is not installed')
def test_get_schema_fields_checks_for_coreapi(self): deftest_get_schema_fields_checks_for_coreapi(self):
filters.coreapi = None filters.coreapi = None
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
self.filter_backend.get_schema_fields({}) self.filter_backend.get_schema_fields({})
filters.coreapi = True filters.coreapi = True
assert self.filter_backend.get_schema_fields({}) == [] assertself.filter_backend.get_schema_fields({})==[]
class SearchFilterModel(models.Model): class SearchFilterModel(models.Model):
@ -48,8 +45,7 @@ class SearchFilterSerializer(serializers.ModelSerializer):
fields = '__all__' fields = '__all__'
class SearchFilterTests(TestCase): def setUp(self):
def setUp(self):
# Sequence of title/text is: # Sequence of title/text is:
# #
# z abc # z abc
@ -58,11 +54,7 @@ class SearchFilterTests(TestCase):
# ... # ...
for idx in range(10): for idx in range(10):
title = 'z' * (idx + 1) title = 'z' * (idx + 1)
text = ( text = ( chr(idx + ord('a')) + chr(idx + ord('b')) + chr(idx + ord('c')) )
chr(idx + ord('a')) +
chr(idx + ord('b')) +
chr(idx + ord('c'))
)
SearchFilterModel(title=title, text=text).save() SearchFilterModel(title=title, text=text).save()
def test_search(self): def test_search(self):
@ -73,12 +65,9 @@ class SearchFilterTests(TestCase):
search_fields = ('title', 'text') search_fields = ('title', 'text')
view = SearchListView.as_view() view = SearchListView.as_view()
request = factory.get('/', {'search': 'b'}) request=factory.get('/',{'search':'b'})
response = view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':1,'title':'z','text':'abc'},{'id':2,'title':'zz','text':'bcd'}]
{'id': 1, 'title': 'z', 'text': 'abc'},
{'id': 2, 'title': 'zz', 'text': 'bcd'}
]
def test_search_returns_same_queryset_if_no_search_fields_or_terms_provided(self): def test_search_returns_same_queryset_if_no_search_fields_or_terms_provided(self):
class SearchListView(generics.ListAPIView): class SearchListView(generics.ListAPIView):
@ -87,11 +76,10 @@ class SearchFilterTests(TestCase):
filter_backends = (filters.SearchFilter,) filter_backends = (filters.SearchFilter,)
view = SearchListView.as_view() view = SearchListView.as_view()
request = factory.get('/') request=factory.get('/')
response = view(request) response=view(request)
expected = SearchFilterSerializer(SearchFilterModel.objects.all(), expected=SearchFilterSerializer(SearchFilterModel.objects.all(),many=True).data
many=True).data assertresponse.data==expected
assert response.data == expected
def test_exact_search(self): def test_exact_search(self):
class SearchListView(generics.ListAPIView): class SearchListView(generics.ListAPIView):
@ -101,11 +89,9 @@ class SearchFilterTests(TestCase):
search_fields = ('=title', 'text') search_fields = ('=title', 'text')
view = SearchListView.as_view() view = SearchListView.as_view()
request = factory.get('/', {'search': 'zzz'}) request=factory.get('/',{'search':'zzz'})
response = view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':3,'title':'zzz','text':'cde'}]
{'id': 3, 'title': 'zzz', 'text': 'cde'}
]
def test_startswith_search(self): def test_startswith_search(self):
class SearchListView(generics.ListAPIView): class SearchListView(generics.ListAPIView):
@ -115,11 +101,9 @@ class SearchFilterTests(TestCase):
search_fields = ('title', '^text') search_fields = ('title', '^text')
view = SearchListView.as_view() view = SearchListView.as_view()
request = factory.get('/', {'search': 'b'}) request=factory.get('/',{'search':'b'})
response = view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':2,'title':'zz','text':'bcd'}]
{'id': 2, 'title': 'zz', 'text': 'bcd'}
]
def test_regexp_search(self): def test_regexp_search(self):
class SearchListView(generics.ListAPIView): class SearchListView(generics.ListAPIView):
@ -129,16 +113,13 @@ class SearchFilterTests(TestCase):
search_fields = ('$title', '$text') search_fields = ('$title', '$text')
view = SearchListView.as_view() view = SearchListView.as_view()
request = factory.get('/', {'search': 'z{2} ^b'}) request=factory.get('/',{'search':'z{2} ^b'})
response = view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':2,'title':'zz','text':'bcd'}]
{'id': 2, 'title': 'zz', 'text': 'bcd'}
]
def test_search_with_nonstandard_search_param(self): def test_search_with_nonstandard_search_param(self):
with override_settings(REST_FRAMEWORK={'SEARCH_PARAM': 'query'}): with override_settings(REST_FRAMEWORK={'SEARCH_PARAM': 'query'}):
reload_module(filters) reload_module(filters)
class SearchListView(generics.ListAPIView): class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all() queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer serializer_class = SearchFilterSerializer
@ -146,12 +127,9 @@ class SearchFilterTests(TestCase):
search_fields = ('title', 'text') search_fields = ('title', 'text')
view = SearchListView.as_view() view = SearchListView.as_view()
request = factory.get('/', {'query': 'b'}) request=factory.get('/',{'query':'b'})
response = view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':1,'title':'z','text':'abc'},{'id':2,'title':'zz','text':'bcd'}]
{'id': 1, 'title': 'z', 'text': 'abc'},
{'id': 2, 'title': 'zz', 'text': 'bcd'}
]
reload_module(filters) reload_module(filters)
@ -170,15 +148,12 @@ class SearchFilterTests(TestCase):
search_fields = ('$title', '$text') search_fields = ('$title', '$text')
view = SearchListView.as_view() view = SearchListView.as_view()
request = factory.get('/', {'search': r'^\w{3}$'}) request=factory.get('/',{'search':r'^\w{3}$'})
response = view(request) response=view(request)
assert len(response.data) == 10 assertlen(response.data)==10
request=factory.get('/',{'search':r'^\w{3}$','title_only':'true'})
request = factory.get('/', {'search': r'^\w{3}$', 'title_only': 'true'}) response=view(request)
response = view(request) assertresponse.data==[{'id':3,'title':'zzz','text':'cde'}]
assert response.data == [
{'id': 3, 'title': 'zzz', 'text': 'cde'}
]
class AttributeModel(models.Model): class AttributeModel(models.Model):
@ -196,20 +171,13 @@ class SearchFilterFkSerializer(serializers.ModelSerializer):
fields = '__all__' fields = '__all__'
class SearchFilterFkTests(TestCase):
def test_must_call_distinct(self): def test_must_call_distinct(self):
filter_ = filters.SearchFilter() filter_ = filters.SearchFilter()
prefixes = [''] + list(filter_.lookup_prefixes) prefixes = [''] + list(filter_.lookup_prefixes)
for prefix in prefixes: for prefix in prefixes:
assert not filter_.must_call_distinct( assert not filter_.must_call_distinct( SearchFilterModelFk._meta, ["%stitle" % prefix] )
SearchFilterModelFk._meta, assert not filter_.must_call_distinct( SearchFilterModelFk._meta, ["%stitle" % prefix, "%sattribute__label" % prefix] )
["%stitle" % prefix]
)
assert not filter_.must_call_distinct(
SearchFilterModelFk._meta,
["%stitle" % prefix, "%sattribute__label" % prefix]
)
def test_must_call_distinct_restores_meta_for_each_field(self): def test_must_call_distinct_restores_meta_for_each_field(self):
# In this test case the attribute of the fk model comes first in the # In this test case the attribute of the fk model comes first in the
@ -217,10 +185,7 @@ class SearchFilterFkTests(TestCase):
filter_ = filters.SearchFilter() filter_ = filters.SearchFilter()
prefixes = [''] + list(filter_.lookup_prefixes) prefixes = [''] + list(filter_.lookup_prefixes)
for prefix in prefixes: for prefix in prefixes:
assert not filter_.must_call_distinct( assert not filter_.must_call_distinct( SearchFilterModelFk._meta, ["%sattribute__label" % prefix, "%stitle" % prefix] )
SearchFilterModelFk._meta,
["%sattribute__label" % prefix, "%stitle" % prefix]
)
class SearchFilterModelM2M(models.Model): class SearchFilterModelM2M(models.Model):
@ -235,8 +200,7 @@ class SearchFilterM2MSerializer(serializers.ModelSerializer):
fields = '__all__' fields = '__all__'
class SearchFilterM2MTests(TestCase): def setUp(self):
def setUp(self):
# Sequence of title/text/attributes is: # Sequence of title/text/attributes is:
# #
# z abc [1, 2, 3] # z abc [1, 2, 3]
@ -249,11 +213,7 @@ class SearchFilterM2MTests(TestCase):
for idx in range(10): for idx in range(10):
title = 'z' * (idx + 1) title = 'z' * (idx + 1)
text = ( text = ( chr(idx + ord('a')) + chr(idx + ord('b')) + chr(idx + ord('c')) )
chr(idx + ord('a')) +
chr(idx + ord('b')) +
chr(idx + ord('c'))
)
SearchFilterModelM2M(title=title, text=text).save() SearchFilterModelM2M(title=title, text=text).save()
SearchFilterModelM2M.objects.get(title='zz').attributes.add(1, 2, 3) SearchFilterModelM2M.objects.get(title='zz').attributes.add(1, 2, 3)
@ -265,23 +225,16 @@ class SearchFilterM2MTests(TestCase):
search_fields = ('=title', 'text', 'attributes__label') search_fields = ('=title', 'text', 'attributes__label')
view = SearchListView.as_view() view = SearchListView.as_view()
request = factory.get('/', {'search': 'zz'}) request=factory.get('/',{'search':'zz'})
response = view(request) response=view(request)
assert len(response.data) == 1 assertlen(response.data)==1
def test_must_call_distinct(self): def test_must_call_distinct(self):
filter_ = filters.SearchFilter() filter_ = filters.SearchFilter()
prefixes = [''] + list(filter_.lookup_prefixes) prefixes = [''] + list(filter_.lookup_prefixes)
for prefix in prefixes: for prefix in prefixes:
assert not filter_.must_call_distinct( assert not filter_.must_call_distinct( SearchFilterModelM2M._meta, ["%stitle" % prefix] )
SearchFilterModelM2M._meta, assert filter_.must_call_distinct( SearchFilterModelM2M._meta, ["%stitle" % prefix, "%sattributes__label" % prefix] )
["%stitle" % prefix]
)
assert filter_.must_call_distinct(
SearchFilterModelM2M._meta,
["%stitle" % prefix, "%sattributes__label" % prefix]
)
class Blog(models.Model): class Blog(models.Model):
@ -300,18 +253,13 @@ class BlogSerializer(serializers.ModelSerializer):
fields = '__all__' fields = '__all__'
class SearchFilterToManyTests(TestCase):
@classmethod @classmethod
def setUpTestData(cls): defsetUpTestData(cls):
b1 = Blog.objects.create(name='Blog 1') b1 = Blog.objects.create(name='Blog 1')
b2 = Blog.objects.create(name='Blog 2') b2 = Blog.objects.create(name='Blog 2')
# Multiple entries on Lennon published in 1979 - distinct should deduplicate
Entry.objects.create(blog=b1, headline='Something about Lennon', pub_date=datetime.date(1979, 1, 1)) Entry.objects.create(blog=b1, headline='Something about Lennon', pub_date=datetime.date(1979, 1, 1))
Entry.objects.create(blog=b1, headline='Another thing about Lennon', pub_date=datetime.date(1979, 6, 1)) Entry.objects.create(blog=b1, headline='Another thing about Lennon', pub_date=datetime.date(1979, 6, 1))
# Entry on Lennon *and* a separate entry in 1979 - should not match
Entry.objects.create(blog=b2, headline='Something unrelated', pub_date=datetime.date(1979, 1, 1)) Entry.objects.create(blog=b2, headline='Something unrelated', pub_date=datetime.date(1979, 1, 1))
Entry.objects.create(blog=b2, headline='Retrospective on Lennon', pub_date=datetime.date(1990, 6, 1)) Entry.objects.create(blog=b2, headline='Retrospective on Lennon', pub_date=datetime.date(1990, 6, 1))
@ -323,9 +271,9 @@ class SearchFilterToManyTests(TestCase):
search_fields = ('=name', 'entry__headline', '=entry__pub_date__year') search_fields = ('=name', 'entry__headline', '=entry__pub_date__year')
view = SearchListView.as_view() view = SearchListView.as_view()
request = factory.get('/', {'search': 'Lennon,1979'}) request=factory.get('/',{'search':'Lennon,1979'})
response = view(request) response=view(request)
assert len(response.data) == 1 assertlen(response.data)==1
class SearchFilterAnnotatedSerializer(serializers.ModelSerializer): class SearchFilterAnnotatedSerializer(serializers.ModelSerializer):
@ -336,28 +284,23 @@ class SearchFilterAnnotatedSerializer(serializers.ModelSerializer):
fields = ('title', 'text', 'title_text') fields = ('title', 'text', 'title_text')
class SearchFilterAnnotatedFieldTests(TestCase): @classmethod
@classmethod defsetUpTestData(cls):
def setUpTestData(cls):
SearchFilterModel.objects.create(title='abc', text='def') SearchFilterModel.objects.create(title='abc', text='def')
SearchFilterModel.objects.create(title='ghi', text='jkl') SearchFilterModel.objects.create(title='ghi', text='jkl')
def test_search_in_annotated_field(self): def test_search_in_annotated_field(self):
class SearchListView(generics.ListAPIView): class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.annotate( queryset = SearchFilterModel.objects.annotate( title_text=Upper( Concat(models.F('title'), models.F('text')) ) ).all()
title_text=Upper(
Concat(models.F('title'), models.F('text'))
)
).all()
serializer_class = SearchFilterAnnotatedSerializer serializer_class = SearchFilterAnnotatedSerializer
filter_backends = (filters.SearchFilter,) filter_backends = (filters.SearchFilter,)
search_fields = ('title_text',) search_fields = ('title_text',)
view = SearchListView.as_view() view = SearchListView.as_view()
request = factory.get('/', {'search': 'ABCDEF'}) request=factory.get('/',{'search':'ABCDEF'})
response = view(request) response=view(request)
assert len(response.data) == 1 assertlen(response.data)==1
assert response.data[0]['title_text'] == 'ABCDEF' assertresponse.data[0]['title_text']=='ABCDEF'
class OrderingFilterModel(models.Model): class OrderingFilterModel(models.Model):
@ -403,24 +346,15 @@ class DjangoFilterOrderingSerializer(serializers.ModelSerializer):
fields = '__all__' fields = '__all__'
class OrderingFilterTests(TestCase): def setUp(self):
def setUp(self):
# Sequence of title/text is: # Sequence of title/text is:
# #
# zyx abc # zyx abc
# yxw bcd # yxw bcd
# xwv cde # xwv cde
for idx in range(3): for idx in range(3):
title = ( title = ( chr(ord('z') - idx) + chr(ord('y') - idx) + chr(ord('x') - idx) )
chr(ord('z') - idx) + text = ( chr(idx + ord('a')) + chr(idx + ord('b')) + chr(idx + ord('c')) )
chr(ord('y') - idx) +
chr(ord('x') - idx)
)
text = (
chr(idx + ord('a')) +
chr(idx + ord('b')) +
chr(idx + ord('c'))
)
OrderingFilterModel(title=title, text=text).save() OrderingFilterModel(title=title, text=text).save()
def test_ordering(self): def test_ordering(self):
@ -432,13 +366,9 @@ class OrderingFilterTests(TestCase):
ordering_fields = ('text',) ordering_fields = ('text',)
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'text'}) request=factory.get('/',{'ordering':'text'})
response = view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':1,'title':'zyx','text':'abc'},{'id':2,'title':'yxw','text':'bcd'},{'id':3,'title':'xwv','text':'cde'},]
{'id': 1, 'title': 'zyx', 'text': 'abc'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 3, 'title': 'xwv', 'text': 'cde'},
]
def test_reverse_ordering(self): def test_reverse_ordering(self):
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
@ -449,13 +379,9 @@ class OrderingFilterTests(TestCase):
ordering_fields = ('text',) ordering_fields = ('text',)
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('/', {'ordering': '-text'}) request=factory.get('/',{'ordering':'-text'})
response = view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},{'id':1,'title':'zyx','text':'abc'},]
{'id': 3, 'title': 'xwv', 'text': 'cde'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 1, 'title': 'zyx', 'text': 'abc'},
]
def test_incorrecturl_extrahyphens_ordering(self): def test_incorrecturl_extrahyphens_ordering(self):
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
@ -466,13 +392,9 @@ class OrderingFilterTests(TestCase):
ordering_fields = ('text',) ordering_fields = ('text',)
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('/', {'ordering': '--text'}) request=factory.get('/',{'ordering':'--text'})
response = view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},{'id':1,'title':'zyx','text':'abc'},]
{'id': 3, 'title': 'xwv', 'text': 'cde'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 1, 'title': 'zyx', 'text': 'abc'},
]
def test_incorrectfield_ordering(self): def test_incorrectfield_ordering(self):
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
@ -483,13 +405,9 @@ class OrderingFilterTests(TestCase):
ordering_fields = ('text',) ordering_fields = ('text',)
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'foobar'}) request=factory.get('/',{'ordering':'foobar'})
response = view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},{'id':1,'title':'zyx','text':'abc'},]
{'id': 3, 'title': 'xwv', 'text': 'cde'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 1, 'title': 'zyx', 'text': 'abc'},
]
def test_default_ordering(self): def test_default_ordering(self):
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
@ -500,13 +418,9 @@ class OrderingFilterTests(TestCase):
ordering_fields = ('text',) ordering_fields = ('text',)
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('') request=factory.get('')
response = view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},{'id':1,'title':'zyx','text':'abc'},]
{'id': 3, 'title': 'xwv', 'text': 'cde'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 1, 'title': 'zyx', 'text': 'abc'},
]
def test_default_ordering_using_string(self): def test_default_ordering_using_string(self):
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
@ -517,23 +431,16 @@ class OrderingFilterTests(TestCase):
ordering_fields = ('text',) ordering_fields = ('text',)
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('') request=factory.get('')
response = view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},{'id':1,'title':'zyx','text':'abc'},]
{'id': 3, 'title': 'xwv', 'text': 'cde'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 1, 'title': 'zyx', 'text': 'abc'},
]
def test_ordering_by_aggregate_field(self): def test_ordering_by_aggregate_field(self):
# create some related models to aggregate order by # create some related models to aggregate order by
num_objs = [2, 5, 3] num_objs = [2, 5, 3]
for obj, num_relateds in zip(OrderingFilterModel.objects.all(), for obj, num_relateds in zip(OrderingFilterModel.objects.all(), num_objs):
num_objs):
for _ in range(num_relateds): for _ in range(num_relateds):
new_related = OrderingFilterRelatedModel( new_related = OrderingFilterRelatedModel( related_object=obj )
related_object=obj
)
new_related.save() new_related.save()
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
@ -541,25 +448,17 @@ class OrderingFilterTests(TestCase):
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
ordering = 'title' ordering = 'title'
ordering_fields = '__all__' ordering_fields = '__all__'
queryset = OrderingFilterModel.objects.all().annotate( queryset = OrderingFilterModel.objects.all().annotate( models.Count("relateds"))
models.Count("relateds"))
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'relateds__count'}) request=factory.get('/',{'ordering':'relateds__count'})
response = view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':1,'title':'zyx','text':'abc'},{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},]
{'id': 1, 'title': 'zyx', 'text': 'abc'},
{'id': 3, 'title': 'xwv', 'text': 'cde'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
]
def test_ordering_by_dotted_source(self): def test_ordering_by_dotted_source(self):
for index, obj in enumerate(OrderingFilterModel.objects.all()): for index, obj in enumerate(OrderingFilterModel.objects.all()):
OrderingFilterRelatedModel.objects.create( OrderingFilterRelatedModel.objects.create( related_object=obj, index=index )
related_object=obj,
index=index
)
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
serializer_class = OrderingDottedRelatedSerializer serializer_class = OrderingDottedRelatedSerializer
@ -567,26 +466,16 @@ class OrderingFilterTests(TestCase):
queryset = OrderingFilterRelatedModel.objects.all() queryset = OrderingFilterRelatedModel.objects.all()
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'related_object__text'}) request=factory.get('/',{'ordering':'related_object__text'})
response = view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'related_title':'zyx','related_text':'abc','index':0},{'related_title':'yxw','related_text':'bcd','index':1},{'related_title':'xwv','related_text':'cde','index':2},]
{'related_title': 'zyx', 'related_text': 'abc', 'index': 0}, request=factory.get('/',{'ordering':'-index'})
{'related_title': 'yxw', 'related_text': 'bcd', 'index': 1}, response=view(request)
{'related_title': 'xwv', 'related_text': 'cde', 'index': 2}, assertresponse.data==[{'related_title':'xwv','related_text':'cde','index':2},{'related_title':'yxw','related_text':'bcd','index':1},{'related_title':'zyx','related_text':'abc','index':0},]
]
request = factory.get('/', {'ordering': '-index'})
response = view(request)
assert response.data == [
{'related_title': 'xwv', 'related_text': 'cde', 'index': 2},
{'related_title': 'yxw', 'related_text': 'bcd', 'index': 1},
{'related_title': 'zyx', 'related_text': 'abc', 'index': 0},
]
def test_ordering_with_nonstandard_ordering_param(self): def test_ordering_with_nonstandard_ordering_param(self):
with override_settings(REST_FRAMEWORK={'ORDERING_PARAM': 'order'}): with override_settings(REST_FRAMEWORK={'ORDERING_PARAM': 'order'}):
reload_module(filters) reload_module(filters)
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all() queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer serializer_class = OrderingFilterSerializer
@ -595,13 +484,9 @@ class OrderingFilterTests(TestCase):
ordering_fields = ('text',) ordering_fields = ('text',)
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('/', {'order': 'text'}) request=factory.get('/',{'order':'text'})
response = view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':1,'title':'zyx','text':'abc'},{'id':2,'title':'yxw','text':'bcd'},{'id':3,'title':'xwv','text':'cde'},]
{'id': 1, 'title': 'zyx', 'text': 'abc'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 3, 'title': 'xwv', 'text': 'cde'},
]
reload_module(filters) reload_module(filters)
@ -613,30 +498,22 @@ class OrderingFilterTests(TestCase):
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
request = factory.get('/', {'ordering': 'title'}, HTTP_ACCEPT='text/html') request = factory.get('/', {'ordering': 'title'}, HTTP_ACCEPT='text/html')
view = OrderingListView.as_view() view=OrderingListView.as_view()
response = view(request) response=view(request)
self.assertContains(response,'verbose title')
self.assertContains(response, 'verbose title')
def test_ordering_with_overridden_get_serializer_class(self): def test_ordering_with_overridden_get_serializer_class(self):
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all() queryset = OrderingFilterModel.objects.all()
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
ordering = ('title',) ordering = ('title',)
# note: no ordering_fields and serializer_class specified
def get_serializer_class(self): def get_serializer_class(self):
return OrderingFilterSerializer return OrderingFilterSerializer
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'text'}) request=factory.get('/',{'ordering':'text'})
response = view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':1,'title':'zyx','text':'abc'},{'id':2,'title':'yxw','text':'bcd'},{'id':3,'title':'xwv','text':'cde'},]
{'id': 1, 'title': 'zyx', 'text': 'abc'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 3, 'title': 'xwv', 'text': 'cde'},
]
def test_ordering_with_improper_configuration(self): def test_ordering_with_improper_configuration(self):
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
@ -647,8 +524,8 @@ class OrderingFilterTests(TestCase):
# or get_serializer_class specified # or get_serializer_class specified
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'text'}) request=factory.get('/',{'ordering':'text'})
with self.assertRaises(ImproperlyConfigured): withpytest.raises(ImproperlyConfigured):
view(request) view(request)
@ -684,63 +561,44 @@ class SensitiveDataSerializer3(serializers.ModelSerializer):
fields = ('id', 'user') fields = ('id', 'user')
class SensitiveOrderingFilterTests(TestCase): def setUp(self):
def setUp(self):
for idx in range(3): for idx in range(3):
username = {0: 'userA', 1: 'userB', 2: 'userC'}[idx] username = {0: 'userA', 1: 'userB', 2: 'userC'}[idx]
password = {0: 'passA', 1: 'passC', 2: 'passB'}[idx] password = {0: 'passA', 1: 'passC', 2: 'passB'}[idx]
SensitiveOrderingFilterModel(username=username, password=password).save() SensitiveOrderingFilterModel(username=username, password=password).save()
def test_order_by_serializer_fields(self): def test_order_by_serializer_fields(self):
for serializer_cls in [ for serializer_cls in [ SensitiveDataSerializer1, SensitiveDataSerializer2, SensitiveDataSerializer3 ]:
SensitiveDataSerializer1,
SensitiveDataSerializer2,
SensitiveDataSerializer3
]:
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
queryset = SensitiveOrderingFilterModel.objects.all().order_by('username') queryset = SensitiveOrderingFilterModel.objects.all().order_by('username')
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
serializer_class = serializer_cls serializer_class = serializer_cls
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('/', {'ordering': '-username'}) request=factory.get('/',{'ordering':'-username'})
response = view(request) response=view(request)
ifserializer_cls==SensitiveDataSerializer3:
if serializer_cls == SensitiveDataSerializer3:
username_field = 'user' username_field = 'user'
else: else:
username_field = 'username' username_field = 'username'
# Note: Inverse username ordering correctly applied. # Note: Inverse username ordering correctly applied.
assert response.data == [ assert response.data == [{'id':3,username_field:'userC'},{'id':2,username_field:'userB'},{'id':1,username_field:'userA'},]
{'id': 3, username_field: 'userC'},
{'id': 2, username_field: 'userB'},
{'id': 1, username_field: 'userA'},
]
def test_cannot_order_by_non_serializer_fields(self): def test_cannot_order_by_non_serializer_fields(self):
for serializer_cls in [ for serializer_cls in [ SensitiveDataSerializer1, SensitiveDataSerializer2, SensitiveDataSerializer3 ]:
SensitiveDataSerializer1,
SensitiveDataSerializer2,
SensitiveDataSerializer3
]:
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
queryset = SensitiveOrderingFilterModel.objects.all().order_by('username') queryset = SensitiveOrderingFilterModel.objects.all().order_by('username')
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
serializer_class = serializer_cls serializer_class = serializer_cls
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'password'}) request=factory.get('/',{'ordering':'password'})
response = view(request) response=view(request)
ifserializer_cls==SensitiveDataSerializer3:
if serializer_cls == SensitiveDataSerializer3:
username_field = 'user' username_field = 'user'
else: else:
username_field = 'username' username_field = 'username'
# Note: The passwords are not in order. Default ordering is used. # Note: The passwords are not in order. Default ordering is used.
assert response.data == [ assert response.data == [{'id':1,username_field:'userA'},{'id':2,username_field:'userB'},{'id':3,username_field:'userC'},]
{'id': 1, username_field: 'userA'}, # PassB
{'id': 2, username_field: 'userB'}, # PassC
{'id': 3, username_field: 'userC'}, # PassA
]

View File

@ -23,10 +23,8 @@ urlpatterns = [
@override_settings(ROOT_URLCONF='tests.test_generateschema') @override_settings(ROOT_URLCONF='tests.test_generateschema')
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed') @pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
class GenerateSchemaTests(TestCase): """Tests for management command generateschema."""
"""Tests for management command generateschema.""" defsetUp(self):
def setUp(self):
self.out = io.StringIO() self.out = io.StringIO()
def test_renders_default_schema_with_custom_title_url_and_description(self): def test_renders_default_schema_with_custom_title_url_and_description(self):
@ -42,45 +40,16 @@ class GenerateSchemaTests(TestCase):
servers: servers:
- url: http://api.sample.com/ - url: http://api.sample.com/
""" """
call_command('generateschema', call_command('generateschema', '--title=SampleAPI', '--url=http://api.sample.com', '--description=Sample description', stdout=self.out)
'--title=SampleAPI', assert formatting.dedent(expected_out) in self.out.getvalue()
'--url=http://api.sample.com',
'--description=Sample description',
stdout=self.out)
self.assertIn(formatting.dedent(expected_out), self.out.getvalue())
def test_renders_openapi_json_schema(self): def test_renders_openapi_json_schema(self):
expected_out = { expected_out = { "openapi": "3.0.0", "info": { "version": "", "title": "", "description": "" }, "servers": [ { "url": "" } ], "paths": { "/": { "get": { "operationId": "list" } } } }
"openapi": "3.0.0", call_command('generateschema', '--format=openapi-json', stdout=self.out)
"info": {
"version": "",
"title": "",
"description": ""
},
"servers": [
{
"url": ""
}
],
"paths": {
"/": {
"get": {
"operationId": "list"
}
}
}
}
call_command('generateschema',
'--format=openapi-json',
stdout=self.out)
out_json = json.loads(self.out.getvalue()) out_json = json.loads(self.out.getvalue())
assert out_json == expected_out
self.assertDictEqual(out_json, expected_out)
def test_renders_corejson_schema(self): def test_renders_corejson_schema(self):
expected_out = """{"_type":"document","":{"list":{"_type":"link","url":"/","action":"get"}}}""" expected_out = """{"_type":"document","":{"list":{"_type":"link","url":"/","action":"get"}}}"""
call_command('generateschema', call_command('generateschema', '--format=corejson', stdout=self.out)
'--format=corejson', assert expected_out in self.out.getvalue()
stdout=self.out)
self.assertIn(expected_out, self.out.getvalue())

View File

@ -75,8 +75,7 @@ class SlugBasedInstanceView(InstanceView):
# Tests # Tests
class TestRootView(TestCase): def setUp(self):
def setUp(self):
""" """
Create 3 BasicModel instances. Create 3 BasicModel instances.
""" """
@ -84,11 +83,8 @@ class TestRootView(TestCase):
for item in items: for item in items:
BasicModel(text=item).save() BasicModel(text=item).save()
self.objects = BasicModel.objects self.objects = BasicModel.objects
self.data = [ self.data=[{'id':obj.id,'text':obj.text}forobjinself.objects.all()]
{'id': obj.id, 'text': obj.text} self.view=RootView.as_view()
for obj in self.objects.all()
]
self.view = RootView.as_view()
def test_get_root_view(self): def test_get_root_view(self):
""" """
@ -98,7 +94,7 @@ class TestRootView(TestCase):
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request).render() response = self.view(request).render()
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == self.data assertresponse.data==self.data
def test_head_root_view(self): def test_head_root_view(self):
""" """
@ -118,9 +114,9 @@ class TestRootView(TestCase):
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request).render() response = self.view(request).render()
assert response.status_code == status.HTTP_201_CREATED assert response.status_code == status.HTTP_201_CREATED
assert response.data == {'id': 4, 'text': 'foobar'} assertresponse.data=={'id':4,'text':'foobar'}
created = self.objects.get(id=4) created=self.objects.get(id=4)
assert created.text == 'foobar' assertcreated.text=='foobar'
def test_put_root_view(self): def test_put_root_view(self):
""" """
@ -131,7 +127,7 @@ class TestRootView(TestCase):
with self.assertNumQueries(0): with self.assertNumQueries(0):
response = self.view(request).render() response = self.view(request).render()
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
assert response.data == {"detail": 'Method "PUT" not allowed.'} assertresponse.data=={"detail":'Method "PUT" not allowed.'}
def test_delete_root_view(self): def test_delete_root_view(self):
""" """
@ -141,7 +137,7 @@ class TestRootView(TestCase):
with self.assertNumQueries(0): with self.assertNumQueries(0):
response = self.view(request).render() response = self.view(request).render()
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
assert response.data == {"detail": 'Method "DELETE" not allowed.'} assertresponse.data=={"detail":'Method "DELETE" not allowed.'}
def test_post_cannot_set_id(self): def test_post_cannot_set_id(self):
""" """
@ -152,9 +148,9 @@ class TestRootView(TestCase):
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request).render() response = self.view(request).render()
assert response.status_code == status.HTTP_201_CREATED assert response.status_code == status.HTTP_201_CREATED
assert response.data == {'id': 4, 'text': 'foobar'} assertresponse.data=={'id':4,'text':'foobar'}
created = self.objects.get(id=4) created=self.objects.get(id=4)
assert created.text == 'foobar' assertcreated.text=='foobar'
def test_post_error_root_view(self): def test_post_error_root_view(self):
""" """
@ -168,10 +164,7 @@ class TestRootView(TestCase):
EXPECTED_QUERIES_FOR_PUT = 2 EXPECTED_QUERIES_FOR_PUT = 2
def setUp(self):
class TestInstanceView(TestCase):
def setUp(self):
""" """
Create 3 BasicModel instances. Create 3 BasicModel instances.
""" """
@ -179,12 +172,9 @@ class TestInstanceView(TestCase):
for item in items: for item in items:
BasicModel(text=item).save() BasicModel(text=item).save()
self.objects = BasicModel.objects.exclude(text='filtered out') self.objects = BasicModel.objects.exclude(text='filtered out')
self.data = [ self.data=[{'id':obj.id,'text':obj.text}forobjinself.objects.all()]
{'id': obj.id, 'text': obj.text} self.view=InstanceView.as_view()
for obj in self.objects.all() self.slug_based_view=SlugBasedInstanceView.as_view()
]
self.view = InstanceView.as_view()
self.slug_based_view = SlugBasedInstanceView.as_view()
def test_get_instance_view(self): def test_get_instance_view(self):
""" """
@ -194,7 +184,7 @@ class TestInstanceView(TestCase):
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == self.data[0] assertresponse.data==self.data[0]
def test_post_instance_view(self): def test_post_instance_view(self):
""" """
@ -205,7 +195,7 @@ class TestInstanceView(TestCase):
with self.assertNumQueries(0): with self.assertNumQueries(0):
response = self.view(request).render() response = self.view(request).render()
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
assert response.data == {"detail": 'Method "POST" not allowed.'} assertresponse.data=={"detail":'Method "POST" not allowed.'}
def test_put_instance_view(self): def test_put_instance_view(self):
""" """
@ -216,9 +206,9 @@ class TestInstanceView(TestCase):
with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT): with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
response = self.view(request, pk='1').render() response = self.view(request, pk='1').render()
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert dict(response.data) == {'id': 1, 'text': 'foobar'} assertdict(response.data)=={'id':1,'text':'foobar'}
updated = self.objects.get(id=1) updated=self.objects.get(id=1)
assert updated.text == 'foobar' assertupdated.text=='foobar'
def test_patch_instance_view(self): def test_patch_instance_view(self):
""" """
@ -226,13 +216,12 @@ class TestInstanceView(TestCase):
""" """
data = {'text': 'foobar'} data = {'text': 'foobar'}
request = factory.patch('/1', data, format='json') request = factory.patch('/1', data, format='json')
with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT): with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == {'id': 1, 'text': 'foobar'} assertresponse.data=={'id':1,'text':'foobar'}
updated = self.objects.get(id=1) updated=self.objects.get(id=1)
assert updated.text == 'foobar' assertupdated.text=='foobar'
def test_delete_instance_view(self): def test_delete_instance_view(self):
""" """
@ -242,9 +231,9 @@ class TestInstanceView(TestCase):
with self.assertNumQueries(2): with self.assertNumQueries(2):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_204_NO_CONTENT assert response.status_code == status.HTTP_204_NO_CONTENT
assert response.content == b'' assertresponse.content==b''
ids = [obj.id for obj in self.objects.all()] ids=[obj.idforobjinself.objects.all()]
assert ids == [2, 3] assertids==[2,3]
def test_get_instance_view_incorrect_arg(self): def test_get_instance_view_incorrect_arg(self):
""" """
@ -265,9 +254,9 @@ class TestInstanceView(TestCase):
with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT): with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == {'id': 1, 'text': 'foobar'} assertresponse.data=={'id':1,'text':'foobar'}
updated = self.objects.get(id=1) updated=self.objects.get(id=1)
assert updated.text == 'foobar' assertupdated.text=='foobar'
def test_put_to_deleted_instance(self): def test_put_to_deleted_instance(self):
""" """
@ -301,7 +290,7 @@ class TestInstanceView(TestCase):
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request, pk=999).render() response = self.view(request, pk=999).render()
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
assert not self.objects.filter(id=999).exists() assertnotself.objects.filter(id=999).exists()
def test_put_error_instance_view(self): def test_put_error_instance_view(self):
""" """
@ -314,8 +303,7 @@ class TestInstanceView(TestCase):
assert expected_error in response.rendered_content.decode() assert expected_error in response.rendered_content.decode()
class TestFKInstanceView(TestCase): def setUp(self):
def setUp(self):
""" """
Create 3 BasicModel instances. Create 3 BasicModel instances.
""" """
@ -326,20 +314,15 @@ class TestFKInstanceView(TestCase):
ForeignKeySource(name='source_' + item, target=t).save() ForeignKeySource(name='source_' + item, target=t).save()
self.objects = ForeignKeySource.objects self.objects = ForeignKeySource.objects
self.data = [ self.data=[{'id':obj.id,'name':obj.name}forobjinself.objects.all()]
{'id': obj.id, 'name': obj.name} self.view=FKInstanceView.as_view()
for obj in self.objects.all()
]
self.view = FKInstanceView.as_view()
class TestOverriddenGetObject(TestCase): """
"""
Test cases for a RetrieveUpdateDestroyAPIView that does NOT use the Test cases for a RetrieveUpdateDestroyAPIView that does NOT use the
queryset/model mechanism but instead overrides get_object() queryset/model mechanism but instead overrides get_object()
""" """
defsetUp(self):
def setUp(self):
""" """
Create 3 BasicModel instances. Create 3 BasicModel instances.
""" """
@ -347,17 +330,12 @@ class TestOverriddenGetObject(TestCase):
for item in items: for item in items:
BasicModel(text=item).save() BasicModel(text=item).save()
self.objects = BasicModel.objects self.objects = BasicModel.objects
self.data = [ self.data=[{'id':obj.id,'text':obj.text}forobjinself.objects.all()]
{'id': obj.id, 'text': obj.text} classOverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView):
for obj in self.objects.all()
]
class OverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView):
""" """
Example detail view for override of get_object(). Example detail view for override of get_object().
""" """
serializer_class = BasicSerializer serializer_class = BasicSerializer
def get_object(self): def get_object(self):
pk = int(self.kwargs['pk']) pk = int(self.kwargs['pk'])
return get_object_or_404(BasicModel.objects.all(), id=pk) return get_object_or_404(BasicModel.objects.all(), id=pk)
@ -372,7 +350,7 @@ class TestOverriddenGetObject(TestCase):
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == self.data[0] assertresponse.data==self.data[0]
# Regression test for #285 # Regression test for #285
@ -388,8 +366,7 @@ class CommentView(generics.ListCreateAPIView):
model = Comment model = Comment
class TestCreateModelWithAutoNowAddField(TestCase): def setUp(self):
def setUp(self):
self.objects = Comment.objects self.objects = Comment.objects
self.view = CommentView.as_view() self.view = CommentView.as_view()
@ -432,8 +409,7 @@ class ExampleView(generics.ListCreateAPIView):
queryset = ClassA.objects.all() queryset = ClassA.objects.all()
class TestM2MBrowsableAPI(TestCase): def test_m2m_in_browsable_api(self):
def test_m2m_in_browsable_api(self):
""" """
Test for particularly ugly regression with m2m in browsable API Test for particularly ugly regression with m2m in browsable API
""" """
@ -476,8 +452,7 @@ class DynamicSerializerView(generics.ListCreateAPIView):
return DynamicSerializer return DynamicSerializer
class TestFilterBackendAppliedToViews(TestCase): def setUp(self):
def setUp(self):
""" """
Create 3 BasicModel instances to filter on. Create 3 BasicModel instances to filter on.
""" """
@ -485,10 +460,7 @@ class TestFilterBackendAppliedToViews(TestCase):
for item in items: for item in items:
BasicModel(text=item).save() BasicModel(text=item).save()
self.objects = BasicModel.objects self.objects = BasicModel.objects
self.data = [ self.data=[{'id':obj.id,'text':obj.text}forobjinself.objects.all()]
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
def test_get_root_view_filters_by_name_with_filter_backend(self): def test_get_root_view_filters_by_name_with_filter_backend(self):
""" """
@ -543,32 +515,29 @@ class TestFilterBackendAppliedToViews(TestCase):
assert 'field_a' not in content assert 'field_a' not in content
class TestGuardedQueryset(TestCase): def test_guarded_queryset(self):
def test_guarded_queryset(self):
class QuerysetAccessError(generics.ListAPIView): class QuerysetAccessError(generics.ListAPIView):
queryset = BasicModel.objects.all() queryset = BasicModel.objects.all()
def get(self, request): def get(self, request):
return Response(list(self.queryset)) return Response(list(self.queryset))
view = QuerysetAccessError.as_view() view = QuerysetAccessError.as_view()
request = factory.get('/') request=factory.get('/')
with pytest.raises(RuntimeError): withpytest.raises(RuntimeError):
view(request).render() view(request).render()
class ApiViewsTests(TestCase):
def test_create_api_view_post(self): def test_create_api_view_post(self):
class MockCreateApiView(generics.CreateAPIView): class MockCreateApiView(generics.CreateAPIView):
def create(self, request, *args, **kwargs): def create(self, request, *args, **kwargs):
self.called = True self.called = True
self.call_args = (request, args, kwargs) self.call_args = (request, args, kwargs)
view = MockCreateApiView() view = MockCreateApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'}) data=('test request',('test arg',),{'test_kwarg':'test'})
view.post('test request', 'test arg', test_kwarg='test') view.post('test request','test arg',test_kwarg='test')
assert view.called is True assertview.calledisTrue
assert view.call_args == data assertview.call_args==data
def test_destroy_api_view_delete(self): def test_destroy_api_view_delete(self):
class MockDestroyApiView(generics.DestroyAPIView): class MockDestroyApiView(generics.DestroyAPIView):
@ -576,10 +545,10 @@ class ApiViewsTests(TestCase):
self.called = True self.called = True
self.call_args = (request, args, kwargs) self.call_args = (request, args, kwargs)
view = MockDestroyApiView() view = MockDestroyApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'}) data=('test request',('test arg',),{'test_kwarg':'test'})
view.delete('test request', 'test arg', test_kwarg='test') view.delete('test request','test arg',test_kwarg='test')
assert view.called is True assertview.calledisTrue
assert view.call_args == data assertview.call_args==data
def test_update_api_view_partial_update(self): def test_update_api_view_partial_update(self):
class MockUpdateApiView(generics.UpdateAPIView): class MockUpdateApiView(generics.UpdateAPIView):
@ -587,10 +556,10 @@ class ApiViewsTests(TestCase):
self.called = True self.called = True
self.call_args = (request, args, kwargs) self.call_args = (request, args, kwargs)
view = MockUpdateApiView() view = MockUpdateApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'}) data=('test request',('test arg',),{'test_kwarg':'test'})
view.patch('test request', 'test arg', test_kwarg='test') view.patch('test request','test arg',test_kwarg='test')
assert view.called is True assertview.calledisTrue
assert view.call_args == data assertview.call_args==data
def test_retrieve_update_api_view_get(self): def test_retrieve_update_api_view_get(self):
class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView): class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView):
@ -598,10 +567,10 @@ class ApiViewsTests(TestCase):
self.called = True self.called = True
self.call_args = (request, args, kwargs) self.call_args = (request, args, kwargs)
view = MockRetrieveUpdateApiView() view = MockRetrieveUpdateApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'}) data=('test request',('test arg',),{'test_kwarg':'test'})
view.get('test request', 'test arg', test_kwarg='test') view.get('test request','test arg',test_kwarg='test')
assert view.called is True assertview.calledisTrue
assert view.call_args == data assertview.call_args==data
def test_retrieve_update_api_view_put(self): def test_retrieve_update_api_view_put(self):
class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView): class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView):
@ -609,10 +578,10 @@ class ApiViewsTests(TestCase):
self.called = True self.called = True
self.call_args = (request, args, kwargs) self.call_args = (request, args, kwargs)
view = MockRetrieveUpdateApiView() view = MockRetrieveUpdateApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'}) data=('test request',('test arg',),{'test_kwarg':'test'})
view.put('test request', 'test arg', test_kwarg='test') view.put('test request','test arg',test_kwarg='test')
assert view.called is True assertview.calledisTrue
assert view.call_args == data assertview.call_args==data
def test_retrieve_update_api_view_patch(self): def test_retrieve_update_api_view_patch(self):
class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView): class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView):
@ -620,10 +589,10 @@ class ApiViewsTests(TestCase):
self.called = True self.called = True
self.call_args = (request, args, kwargs) self.call_args = (request, args, kwargs)
view = MockRetrieveUpdateApiView() view = MockRetrieveUpdateApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'}) data=('test request',('test arg',),{'test_kwarg':'test'})
view.patch('test request', 'test arg', test_kwarg='test') view.patch('test request','test arg',test_kwarg='test')
assert view.called is True assertview.calledisTrue
assert view.call_args == data assertview.call_args==data
def test_retrieve_destroy_api_view_get(self): def test_retrieve_destroy_api_view_get(self):
class MockRetrieveDestroyUApiView(generics.RetrieveDestroyAPIView): class MockRetrieveDestroyUApiView(generics.RetrieveDestroyAPIView):
@ -631,10 +600,10 @@ class ApiViewsTests(TestCase):
self.called = True self.called = True
self.call_args = (request, args, kwargs) self.call_args = (request, args, kwargs)
view = MockRetrieveDestroyUApiView() view = MockRetrieveDestroyUApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'}) data=('test request',('test arg',),{'test_kwarg':'test'})
view.get('test request', 'test arg', test_kwarg='test') view.get('test request','test arg',test_kwarg='test')
assert view.called is True assertview.calledisTrue
assert view.call_args == data assertview.call_args==data
def test_retrieve_destroy_api_view_delete(self): def test_retrieve_destroy_api_view_delete(self):
class MockRetrieveDestroyUApiView(generics.RetrieveDestroyAPIView): class MockRetrieveDestroyUApiView(generics.RetrieveDestroyAPIView):
@ -642,21 +611,18 @@ class ApiViewsTests(TestCase):
self.called = True self.called = True
self.call_args = (request, args, kwargs) self.call_args = (request, args, kwargs)
view = MockRetrieveDestroyUApiView() view = MockRetrieveDestroyUApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'}) data=('test request',('test arg',),{'test_kwarg':'test'})
view.delete('test request', 'test arg', test_kwarg='test') view.delete('test request','test arg',test_kwarg='test')
assert view.called is True assertview.calledisTrue
assert view.call_args == data assertview.call_args==data
class GetObjectOr404Tests(TestCase): def setUp(self):
def setUp(self):
super().setUp() super().setUp()
self.uuid_object = UUIDForeignKeyTarget.objects.create(name='bar') self.uuid_object = UUIDForeignKeyTarget.objects.create(name='bar')
def test_get_object_or_404_with_valid_uuid(self): def test_get_object_or_404_with_valid_uuid(self):
obj = generics.get_object_or_404( obj = generics.get_object_or_404( UUIDForeignKeyTarget, pk=self.uuid_object.pk )
UUIDForeignKeyTarget, pk=self.uuid_object.pk
)
assert obj == self.uuid_object assert obj == self.uuid_object
def test_get_object_or_404_with_invalid_string_for_uuid(self): def test_get_object_or_404_with_invalid_string_for_uuid(self):

View File

@ -42,19 +42,17 @@ urlpatterns = [
@override_settings(ROOT_URLCONF='tests.test_htmlrenderer') @override_settings(ROOT_URLCONF='tests.test_htmlrenderer')
class TemplateHTMLRendererTests(TestCase): def setUp(self):
def setUp(self):
class MockResponse: class MockResponse:
template_name = None template_name = None
self.mock_response = MockResponse() self.mock_response = MockResponse()
self._monkey_patch_get_template() self._monkey_patch_get_template()
def _monkey_patch_get_template(self): def _monkey_patch_get_template(self):
""" """
Monkeypatch get_template Monkeypatch get_template
""" """
self.get_template = django.template.loader.get_template self.get_template = django.template.loader.get_template
def get_template(template_name, dirs=None): def get_template(template_name, dirs=None):
if template_name == 'example.html': if template_name == 'example.html':
return engines['django'].from_string("example: {{ object }}") return engines['django'].from_string("example: {{ object }}")
@ -66,7 +64,7 @@ class TemplateHTMLRendererTests(TestCase):
raise TemplateDoesNotExist(template_name_list[0]) raise TemplateDoesNotExist(template_name_list[0])
django.template.loader.get_template = get_template django.template.loader.get_template = get_template
django.template.loader.select_template = select_template django.template.loader.select_template=select_template
def tearDown(self): def tearDown(self):
""" """
@ -77,19 +75,19 @@ class TemplateHTMLRendererTests(TestCase):
def test_simple_html_view(self): def test_simple_html_view(self):
response = self.client.get('/') response = self.client.get('/')
self.assertContains(response, "example: foobar") self.assertContains(response, "example: foobar")
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') assert response['Content-Type'] == 'text/html; charset=utf-8'
def test_not_found_html_view(self): def test_not_found_html_view(self):
response = self.client.get('/not_found') response = self.client.get('/not_found')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) assert response.status_code == status.HTTP_404_NOT_FOUND
self.assertEqual(response.content, b"404 Not Found") assert response.content == b"404 Not Found"
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') assert response['Content-Type'] == 'text/html; charset=utf-8'
def test_permission_denied_html_view(self): def test_permission_denied_html_view(self):
response = self.client.get('/permission_denied') response = self.client.get('/permission_denied')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) assert response.status_code == status.HTTP_403_FORBIDDEN
self.assertEqual(response.content, b"403 Forbidden") assert response.content == b"403 Forbidden"
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') assert response['Content-Type'] == 'text/html; charset=utf-8'
# 2 tests below are based on order of if statements in corresponding method # 2 tests below are based on order of if statements in corresponding method
# of TemplateHTMLRenderer # of TemplateHTMLRenderer
@ -101,7 +99,6 @@ class TemplateHTMLRendererTests(TestCase):
def test_get_template_names_returns_view_template_name(self): def test_get_template_names_returns_view_template_name(self):
renderer = TemplateHTMLRenderer() renderer = TemplateHTMLRenderer()
class MockResponse: class MockResponse:
template_name = None template_name = None
@ -112,13 +109,10 @@ class TemplateHTMLRendererTests(TestCase):
class MockView2: class MockView2:
template_name = 'template from template_name attribute' template_name = 'template from template_name attribute'
template_name = renderer.get_template_names(self.mock_response, template_name = renderer.get_template_names(self.mock_response,MockView())
MockView()) asserttemplate_name==['template from get_template_names method']
assert template_name == ['template from get_template_names method'] template_name=renderer.get_template_names(self.mock_response,MockView2())
asserttemplate_name==['template from template_name attribute']
template_name = renderer.get_template_names(self.mock_response,
MockView2())
assert template_name == ['template from template_name attribute']
def test_get_template_names_raises_error_if_no_template_found(self): def test_get_template_names_raises_error_if_no_template_found(self):
renderer = TemplateHTMLRenderer() renderer = TemplateHTMLRenderer()
@ -127,13 +121,11 @@ class TemplateHTMLRendererTests(TestCase):
@override_settings(ROOT_URLCONF='tests.test_htmlrenderer') @override_settings(ROOT_URLCONF='tests.test_htmlrenderer')
class TemplateHTMLRendererExceptionTests(TestCase): def setUp(self):
def setUp(self):
""" """
Monkeypatch get_template Monkeypatch get_template
""" """
self.get_template = django.template.loader.get_template self.get_template = django.template.loader.get_template
def get_template(template_name): def get_template(template_name):
if template_name == '404.html': if template_name == '404.html':
return engines['django'].from_string("404: {{ detail }}") return engines['django'].from_string("404: {{ detail }}")
@ -151,13 +143,12 @@ class TemplateHTMLRendererExceptionTests(TestCase):
def test_not_found_html_view_with_template(self): def test_not_found_html_view_with_template(self):
response = self.client.get('/not_found') response = self.client.get('/not_found')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) assert response.status_code == status.HTTP_404_NOT_FOUND
self.assertTrue(response.content in ( assert response.content in ( b"404: Not found", b"404 Not Found")
b"404: Not found", b"404 Not Found")) assert response['Content-Type'] == 'text/html; charset=utf-8'
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
def test_permission_denied_html_view_with_template(self): def test_permission_denied_html_view_with_template(self):
response = self.client.get('/permission_denied') response = self.client.get('/permission_denied')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) assert response.status_code == status.HTTP_403_FORBIDDEN
self.assertTrue(response.content in (b"403: Permission denied", b"403 Forbidden")) assert response.content in (b"403: Permission denied", b"403 Forbidden")
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') assert response['Content-Type'] == 'text/html; charset=utf-8'

View File

@ -34,8 +34,7 @@ urlpatterns = [
@override_settings(ROOT_URLCONF='tests.test_lazy_hyperlinks') @override_settings(ROOT_URLCONF='tests.test_lazy_hyperlinks')
class TestLazyHyperlinkNames(TestCase): def setUp(self):
def setUp(self):
self.example = Example.objects.create(text='foo') self.example = Example.objects.create(text='foo')
def test_lazy_hyperlink_names(self): def test_lazy_hyperlink_names(self):

View File

@ -308,8 +308,7 @@ class TestMetadata:
assert options.get_serializer_info(list_serializer) == options.get_serializer_info(child_serializer) assert options.get_serializer_info(list_serializer) == options.get_serializer_info(child_serializer)
class TestSimpleMetadataFieldInfo(TestCase): def test_null_boolean_field_info_type(self):
def test_null_boolean_field_info_type(self):
options = metadata.SimpleMetadata() options = metadata.SimpleMetadata()
field_info = options.get_field_info(serializers.NullBooleanField()) field_info = options.get_field_info(serializers.NullBooleanField())
assert field_info['type'] == 'boolean' assert field_info['type'] == 'boolean'
@ -318,14 +317,11 @@ class TestSimpleMetadataFieldInfo(TestCase):
options = metadata.SimpleMetadata() options = metadata.SimpleMetadata()
BasicModel.objects.create() BasicModel.objects.create()
with self.assertNumQueries(0): with self.assertNumQueries(0):
field_info = options.get_field_info( field_info = options.get_field_info( serializers.RelatedField(queryset=BasicModel.objects.all()) )
serializers.RelatedField(queryset=BasicModel.objects.all())
)
assert 'choices' not in field_info assert 'choices' not in field_info
class TestModelSerializerMetadata(TestCase): def test_read_only_primary_key_related_field(self):
def test_read_only_primary_key_related_field(self):
""" """
On generic views OPTIONS should return an 'actions' key with metadata On generic views OPTIONS should return an 'actions' key with metadata
on the fields that may be supplied to PUT and POST requests. It should on the fields that may be supplied to PUT and POST requests. It should
@ -341,7 +337,6 @@ class TestModelSerializerMetadata(TestCase):
class ExampleSerializer(serializers.ModelSerializer): class ExampleSerializer(serializers.ModelSerializer):
children = serializers.PrimaryKeyRelatedField(read_only=True, many=True) children = serializers.PrimaryKeyRelatedField(read_only=True, many=True)
class Meta: class Meta:
model = Parent model = Parent
fields = '__all__' fields = '__all__'
@ -355,51 +350,7 @@ class TestModelSerializerMetadata(TestCase):
return ExampleSerializer() return ExampleSerializer()
view = ExampleView.as_view() view = ExampleView.as_view()
response = view(request=request) response=view(request=request)
expected = { expected={'name':'Example','description':'Example view.','renders':['application/json','text/html'],'parses':['application/json','application/x-www-form-urlencoded','multipart/form-data'],'actions':{'POST':{'id':{'type':'integer','required':False,'read_only':True,'label':'ID'},'children':{'type':'field','required':False,'read_only':True,'label':'Children'},'integer_field':{'type':'integer','required':True,'read_only':False,'label':'Integer field','min_value':1,'max_value':1000},'name':{'type':'string','required':False,'read_only':False,'label':'Name','max_length':100}}}}
'name': 'Example', assertresponse.status_code==status.HTTP_200_OK
'description': 'Example view.', assertresponse.data==expected
'renders': [
'application/json',
'text/html'
],
'parses': [
'application/json',
'application/x-www-form-urlencoded',
'multipart/form-data'
],
'actions': {
'POST': {
'id': {
'type': 'integer',
'required': False,
'read_only': True,
'label': 'ID'
},
'children': {
'type': 'field',
'required': False,
'read_only': True,
'label': 'Children'
},
'integer_field': {
'type': 'integer',
'required': True,
'read_only': False,
'label': 'Integer field',
'min_value': 1,
'max_value': 1000
},
'name': {
'type': 'string',
'required': False,
'read_only': False,
'label': 'Name',
'max_length': 100
}
}
}
}
assert response.status_code == status.HTTP_200_OK
assert response.data == expected

View File

@ -110,23 +110,17 @@ class UniqueChoiceModel(models.Model):
name = models.CharField(max_length=254, unique=True, choices=CHOICES) name = models.CharField(max_length=254, unique=True, choices=CHOICES)
class TestModelSerializer(TestCase): def test_create_method(self):
def test_create_method(self):
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
non_model_field = serializers.CharField() non_model_field = serializers.CharField()
class Meta: class Meta:
model = OneFieldModel model = OneFieldModel
fields = ('char_field', 'non_model_field') fields = ('char_field', 'non_model_field')
serializer = TestSerializer(data={ serializer = TestSerializer(data={'char_field':'foo','non_model_field':'bar',})
'char_field': 'foo', serializer.is_valid()
'non_model_field': 'bar', msginitial='Got a `TypeError` when calling `OneFieldModel.objects.create()`.'
}) withself.assertRaisesMessage(TypeError,msginitial):
serializer.is_valid()
msginitial = 'Got a `TypeError` when calling `OneFieldModel.objects.create()`.'
with self.assertRaisesMessage(TypeError, msginitial):
serializer.save() serializer.save()
def test_abstract_model(self): def test_abstract_model(self):
@ -136,7 +130,6 @@ class TestModelSerializer(TestCase):
""" """
class AbstractModel(models.Model): class AbstractModel(models.Model):
afield = models.CharField(max_length=255) afield = models.CharField(max_length=255)
class Meta: class Meta:
abstract = True abstract = True
@ -145,17 +138,13 @@ class TestModelSerializer(TestCase):
model = AbstractModel model = AbstractModel
fields = ('afield',) fields = ('afield',)
serializer = TestSerializer(data={ serializer = TestSerializer(data={'afield':'foo',})
'afield': 'foo', msginitial='Cannot use ModelSerializer with Abstract Models.'
}) withself.assertRaisesMessage(ValueError,msginitial):
msginitial = 'Cannot use ModelSerializer with Abstract Models.'
with self.assertRaisesMessage(ValueError, msginitial):
serializer.is_valid() serializer.is_valid()
class TestRegularFieldMappings(TestCase): def test_regular_fields(self):
def test_regular_fields(self):
""" """
Model fields should map to their equivalent serializer fields. Model fields should map to their equivalent serializer fields.
""" """
@ -189,8 +178,7 @@ class TestRegularFieldMappings(TestCase):
custom_field = ModelField(model_field=<tests.test_model_serializer.CustomField: custom_field>) custom_field = ModelField(model_field=<tests.test_model_serializer.CustomField: custom_field>)
file_path_field = FilePathField(path='/tmp/') file_path_field = FilePathField(path='/tmp/')
""") """)
assertrepr(TestSerializer())==expected
self.assertEqual(repr(TestSerializer()), expected)
def test_field_options(self): def test_field_options(self):
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
@ -209,12 +197,12 @@ class TestRegularFieldMappings(TestCase):
descriptive_field = IntegerField(help_text='Some help text', label='A label') descriptive_field = IntegerField(help_text='Some help text', label='A label')
choices_field = ChoiceField(choices=(('red', 'Red'), ('blue', 'Blue'), ('green', 'Green'))) choices_field = ChoiceField(choices=(('red', 'Red'), ('blue', 'Blue'), ('green', 'Green')))
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
# merge this into test_regular_fields / RegularFieldsModel when # merge this into test_regular_fields / RegularFieldsModel when
# Django 2.1 is the minimum supported version # Django 2.1 is the minimum supported version
@pytest.mark.skipif(django.VERSION < (2, 1), reason='Django version < 2.1') @pytest.mark.skipif(django.VERSION < (2, 1), reason='Django version < 2.1')
def test_nullable_boolean_field(self): deftest_nullable_boolean_field(self):
class NullableBooleanModel(models.Model): class NullableBooleanModel(models.Model):
field = models.BooleanField(null=True, default=False) field = models.BooleanField(null=True, default=False)
@ -227,8 +215,7 @@ class TestRegularFieldMappings(TestCase):
NullableBooleanSerializer(): NullableBooleanSerializer():
field = BooleanField(allow_null=True, required=False) field = BooleanField(allow_null=True, required=False)
""") """)
assertrepr(NullableBooleanSerializer())==expected
self.assertEqual(repr(NullableBooleanSerializer()), expected)
def test_method_field(self): def test_method_field(self):
""" """
@ -245,7 +232,7 @@ class TestRegularFieldMappings(TestCase):
auto_field = IntegerField(read_only=True) auto_field = IntegerField(read_only=True)
method = ReadOnlyField() method = ReadOnlyField()
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
def test_pk_fields(self): def test_pk_fields(self):
""" """
@ -261,7 +248,7 @@ class TestRegularFieldMappings(TestCase):
pk = IntegerField(label='Auto field', read_only=True) pk = IntegerField(label='Auto field', read_only=True)
auto_field = IntegerField(read_only=True) auto_field = IntegerField(read_only=True)
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
def test_extra_field_kwargs(self): def test_extra_field_kwargs(self):
""" """
@ -278,7 +265,7 @@ class TestRegularFieldMappings(TestCase):
auto_field = IntegerField(read_only=True) auto_field = IntegerField(read_only=True)
char_field = CharField(default='extra', max_length=100) char_field = CharField(default='extra', max_length=100)
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
def test_extra_field_kwargs_required(self): def test_extra_field_kwargs_required(self):
""" """
@ -295,7 +282,7 @@ class TestRegularFieldMappings(TestCase):
auto_field = IntegerField(read_only=False, required=False) auto_field = IntegerField(read_only=False, required=False)
char_field = CharField(max_length=100) char_field = CharField(max_length=100)
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
def test_invalid_field(self): def test_invalid_field(self):
""" """
@ -308,7 +295,7 @@ class TestRegularFieldMappings(TestCase):
fields = ('auto_field', 'invalid') fields = ('auto_field', 'invalid')
expected = 'Field name `invalid` is not valid for model `RegularFieldsModel`.' expected = 'Field name `invalid` is not valid for model `RegularFieldsModel`.'
with self.assertRaisesMessage(ImproperlyConfigured, expected): withself.assertRaisesMessage(ImproperlyConfigured,expected):
TestSerializer().fields TestSerializer().fields
def test_missing_field(self): def test_missing_field(self):
@ -318,16 +305,12 @@ class TestRegularFieldMappings(TestCase):
""" """
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
missing = serializers.ReadOnlyField() missing = serializers.ReadOnlyField()
class Meta: class Meta:
model = RegularFieldsModel model = RegularFieldsModel
fields = ('auto_field',) fields = ('auto_field',)
expected = ( expected = ("The field 'missing' was declared on serializer TestSerializer, ""but has not been included in the 'fields' option.")
"The field 'missing' was declared on serializer TestSerializer, " withself.assertRaisesMessage(AssertionError,expected):
"but has not been included in the 'fields' option."
)
with self.assertRaisesMessage(AssertionError, expected):
TestSerializer().fields TestSerializer().fields
def test_missing_superclass_field(self): def test_missing_superclass_field(self):
@ -354,8 +337,7 @@ class TestRegularFieldMappings(TestCase):
ExampleSerializer() ExampleSerializer()
class TestDurationFieldMapping(TestCase): def test_duration_field(self):
def test_duration_field(self):
class DurationFieldModel(models.Model): class DurationFieldModel(models.Model):
""" """
A model that defines DurationField. A model that defines DurationField.
@ -372,16 +354,14 @@ class TestDurationFieldMapping(TestCase):
id = IntegerField(label='ID', read_only=True) id = IntegerField(label='ID', read_only=True)
duration_field = DurationField() duration_field = DurationField()
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
def test_duration_field_with_validators(self): def test_duration_field_with_validators(self):
class ValidatedDurationFieldModel(models.Model): class ValidatedDurationFieldModel(models.Model):
""" """
A model that defines DurationField with validators. A model that defines DurationField with validators.
""" """
duration_field = models.DurationField( duration_field = models.DurationField( validators=[MinValueValidator(datetime.timedelta(days=1)), MaxValueValidator(datetime.timedelta(days=3))] )
validators=[MinValueValidator(datetime.timedelta(days=1)), MaxValueValidator(datetime.timedelta(days=3))]
)
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
class Meta: class Meta:
@ -392,16 +372,15 @@ class TestDurationFieldMapping(TestCase):
TestSerializer(): TestSerializer():
id = IntegerField(label='ID', read_only=True) id = IntegerField(label='ID', read_only=True)
duration_field = DurationField(max_value=datetime.timedelta(3), min_value=datetime.timedelta(1)) duration_field = DurationField(max_value=datetime.timedelta(3), min_value=datetime.timedelta(1))
""") if sys.version_info < (3, 7) else dedent(""" """)ifsys.version_info<(3,7)elsededent("""
TestSerializer(): TestSerializer():
id = IntegerField(label='ID', read_only=True) id = IntegerField(label='ID', read_only=True)
duration_field = DurationField(max_value=datetime.timedelta(days=3), min_value=datetime.timedelta(days=1)) duration_field = DurationField(max_value=datetime.timedelta(days=3), min_value=datetime.timedelta(days=1))
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
class TestGenericIPAddressFieldValidation(TestCase): def test_ip_address_validation(self):
def test_ip_address_validation(self):
class IPAddressFieldModel(models.Model): class IPAddressFieldModel(models.Model):
address = models.GenericIPAddressField() address = models.GenericIPAddressField()
@ -411,15 +390,12 @@ class TestGenericIPAddressFieldValidation(TestCase):
fields = '__all__' fields = '__all__'
s = TestSerializer(data={'address': 'not an ip address'}) s = TestSerializer(data={'address': 'not an ip address'})
self.assertFalse(s.is_valid()) assertnots.is_valid()
self.assertEqual(1, len(s.errors['address']), assert1==len(s.errors['address']),'Unexpected number of validation errors: ''{}'.format(s.errors)
'Unexpected number of validation errors: '
'{}'.format(s.errors))
@pytest.mark.skipif('not postgres_fields') @pytest.mark.skipif('not postgres_fields')
class TestPosgresFieldsMapping(TestCase): def test_hstore_field(self):
def test_hstore_field(self):
class HStoreFieldModel(models.Model): class HStoreFieldModel(models.Model):
hstore_field = postgres_fields.HStoreField() hstore_field = postgres_fields.HStoreField()
@ -432,7 +408,7 @@ class TestPosgresFieldsMapping(TestCase):
TestSerializer(): TestSerializer():
hstore_field = HStoreField() hstore_field = HStoreField()
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
def test_array_field(self): def test_array_field(self):
class ArrayFieldModel(models.Model): class ArrayFieldModel(models.Model):
@ -447,7 +423,7 @@ class TestPosgresFieldsMapping(TestCase):
TestSerializer(): TestSerializer():
array_field = ListField(child=CharField(label='Array field', validators=[<django.core.validators.MaxLengthValidator object>])) array_field = ListField(child=CharField(label='Array field', validators=[<django.core.validators.MaxLengthValidator object>]))
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
def test_json_field(self): def test_json_field(self):
class JSONFieldModel(models.Model): class JSONFieldModel(models.Model):
@ -462,7 +438,7 @@ class TestPosgresFieldsMapping(TestCase):
TestSerializer(): TestSerializer():
json_field = JSONField(style={'base_template': 'textarea.html'}) json_field = JSONField(style={'base_template': 'textarea.html'})
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
# Tests for relational field mappings. # Tests for relational field mappings.
@ -505,8 +481,7 @@ class UniqueTogetherModel(models.Model):
unique_together = ("foreign_key", "one_to_one") unique_together = ("foreign_key", "one_to_one")
class TestRelationalFieldMappings(TestCase): def test_pk_relations(self):
def test_pk_relations(self):
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = RelationalModel model = RelationalModel
@ -520,7 +495,7 @@ class TestRelationalFieldMappings(TestCase):
many_to_many = PrimaryKeyRelatedField(allow_empty=False, many=True, queryset=ManyToManyTargetModel.objects.all()) many_to_many = PrimaryKeyRelatedField(allow_empty=False, many=True, queryset=ManyToManyTargetModel.objects.all())
through = PrimaryKeyRelatedField(many=True, read_only=True) through = PrimaryKeyRelatedField(many=True, read_only=True)
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
def test_nested_relations(self): def test_nested_relations(self):
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
@ -545,7 +520,7 @@ class TestRelationalFieldMappings(TestCase):
id = IntegerField(label='ID', read_only=True) id = IntegerField(label='ID', read_only=True)
name = CharField(max_length=100) name = CharField(max_length=100)
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
def test_hyperlinked_relations(self): def test_hyperlinked_relations(self):
class TestSerializer(serializers.HyperlinkedModelSerializer): class TestSerializer(serializers.HyperlinkedModelSerializer):
@ -561,7 +536,7 @@ class TestRelationalFieldMappings(TestCase):
many_to_many = HyperlinkedRelatedField(allow_empty=False, many=True, queryset=ManyToManyTargetModel.objects.all(), view_name='manytomanytargetmodel-detail') many_to_many = HyperlinkedRelatedField(allow_empty=False, many=True, queryset=ManyToManyTargetModel.objects.all(), view_name='manytomanytargetmodel-detail')
through = HyperlinkedRelatedField(many=True, read_only=True, view_name='throughtargetmodel-detail') through = HyperlinkedRelatedField(many=True, read_only=True, view_name='throughtargetmodel-detail')
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
def test_nested_hyperlinked_relations(self): def test_nested_hyperlinked_relations(self):
class TestSerializer(serializers.HyperlinkedModelSerializer): class TestSerializer(serializers.HyperlinkedModelSerializer):
@ -586,7 +561,7 @@ class TestRelationalFieldMappings(TestCase):
url = HyperlinkedIdentityField(view_name='throughtargetmodel-detail') url = HyperlinkedIdentityField(view_name='throughtargetmodel-detail')
name = CharField(max_length=100) name = CharField(max_length=100)
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
def test_nested_hyperlinked_relations_starred_source(self): def test_nested_hyperlinked_relations_starred_source(self):
class TestSerializer(serializers.HyperlinkedModelSerializer): class TestSerializer(serializers.HyperlinkedModelSerializer):
@ -594,11 +569,7 @@ class TestRelationalFieldMappings(TestCase):
model = RelationalModel model = RelationalModel
depth = 1 depth = 1
fields = '__all__' fields = '__all__'
extra_kwargs = { 'url': { 'source': '*', }}
extra_kwargs = {
'url': {
'source': '*',
}}
expected = dedent(""" expected = dedent("""
TestSerializer(): TestSerializer():
@ -616,8 +587,8 @@ class TestRelationalFieldMappings(TestCase):
url = HyperlinkedIdentityField(view_name='throughtargetmodel-detail') url = HyperlinkedIdentityField(view_name='throughtargetmodel-detail')
name = CharField(max_length=100) name = CharField(max_length=100)
""") """)
self.maxDiff = None self.maxDiff=None
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
def test_nested_unique_together_relations(self): def test_nested_unique_together_relations(self):
class TestSerializer(serializers.HyperlinkedModelSerializer): class TestSerializer(serializers.HyperlinkedModelSerializer):
@ -636,7 +607,7 @@ class TestRelationalFieldMappings(TestCase):
url = HyperlinkedIdentityField(view_name='onetoonetargetmodel-detail') url = HyperlinkedIdentityField(view_name='onetoonetargetmodel-detail')
name = CharField(max_length=100) name = CharField(max_length=100)
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
def test_pk_reverse_foreign_key(self): def test_pk_reverse_foreign_key(self):
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
@ -650,7 +621,7 @@ class TestRelationalFieldMappings(TestCase):
name = CharField(max_length=100) name = CharField(max_length=100)
reverse_foreign_key = PrimaryKeyRelatedField(many=True, queryset=RelationalModel.objects.all()) reverse_foreign_key = PrimaryKeyRelatedField(many=True, queryset=RelationalModel.objects.all())
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
def test_pk_reverse_one_to_one(self): def test_pk_reverse_one_to_one(self):
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
@ -664,7 +635,7 @@ class TestRelationalFieldMappings(TestCase):
name = CharField(max_length=100) name = CharField(max_length=100)
reverse_one_to_one = PrimaryKeyRelatedField(queryset=RelationalModel.objects.all()) reverse_one_to_one = PrimaryKeyRelatedField(queryset=RelationalModel.objects.all())
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
def test_pk_reverse_many_to_many(self): def test_pk_reverse_many_to_many(self):
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
@ -678,7 +649,7 @@ class TestRelationalFieldMappings(TestCase):
name = CharField(max_length=100) name = CharField(max_length=100)
reverse_many_to_many = PrimaryKeyRelatedField(many=True, queryset=RelationalModel.objects.all()) reverse_many_to_many = PrimaryKeyRelatedField(many=True, queryset=RelationalModel.objects.all())
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
def test_pk_reverse_through(self): def test_pk_reverse_through(self):
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
@ -692,7 +663,7 @@ class TestRelationalFieldMappings(TestCase):
name = CharField(max_length=100) name = CharField(max_length=100)
reverse_through = PrimaryKeyRelatedField(many=True, read_only=True) reverse_through = PrimaryKeyRelatedField(many=True, read_only=True)
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
class DisplayValueTargetModel(models.Model): class DisplayValueTargetModel(models.Model):
@ -706,13 +677,8 @@ class DisplayValueModel(models.Model):
color = models.ForeignKey(DisplayValueTargetModel, on_delete=models.CASCADE) color = models.ForeignKey(DisplayValueTargetModel, on_delete=models.CASCADE)
class TestRelationalFieldDisplayValue(TestCase): def setUp(self):
def setUp(self): DisplayValueTargetModel.objects.bulk_create([ DisplayValueTargetModel(name='Red'), DisplayValueTargetModel(name='Yellow'), DisplayValueTargetModel(name='Green'), ])
DisplayValueTargetModel.objects.bulk_create([
DisplayValueTargetModel(name='Red'),
DisplayValueTargetModel(name='Yellow'),
DisplayValueTargetModel(name='Green'),
])
def test_default_display_value(self): def test_default_display_value(self):
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
@ -721,8 +687,8 @@ class TestRelationalFieldDisplayValue(TestCase):
fields = '__all__' fields = '__all__'
serializer = TestSerializer() serializer = TestSerializer()
expected = OrderedDict([(1, 'Red Color'), (2, 'Yellow Color'), (3, 'Green Color')]) expected=OrderedDict([(1,'Red Color'),(2,'Yellow Color'),(3,'Green Color')])
self.assertEqual(serializer.fields['color'].choices, expected) assertserializer.fields['color'].choices==expected
def test_custom_display_value(self): def test_custom_display_value(self):
class TestField(serializers.PrimaryKeyRelatedField): class TestField(serializers.PrimaryKeyRelatedField):
@ -731,33 +697,20 @@ class TestRelationalFieldDisplayValue(TestCase):
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
color = TestField(queryset=DisplayValueTargetModel.objects.all()) color = TestField(queryset=DisplayValueTargetModel.objects.all())
class Meta: class Meta:
model = DisplayValueModel model = DisplayValueModel
fields = '__all__' fields = '__all__'
serializer = TestSerializer() serializer = TestSerializer()
expected = OrderedDict([(1, 'My Red Color'), (2, 'My Yellow Color'), (3, 'My Green Color')]) expected=OrderedDict([(1,'My Red Color'),(2,'My Yellow Color'),(3,'My Green Color')])
self.assertEqual(serializer.fields['color'].choices, expected) assertserializer.fields['color'].choices==expected
class TestIntegration(TestCase): def setUp(self):
def setUp(self): self.foreign_key_target = ForeignKeyTargetModel.objects.create( name='foreign_key' )
self.foreign_key_target = ForeignKeyTargetModel.objects.create( self.one_to_one_target = OneToOneTargetModel.objects.create( name='one_to_one' )
name='foreign_key' self.many_to_many_targets = [ ManyToManyTargetModel.objects.create( name='many_to_many (%d)' % idx ) for idx in range(3) ]
) self.instance = RelationalModel.objects.create( foreign_key=self.foreign_key_target, one_to_one=self.one_to_one_target, )
self.one_to_one_target = OneToOneTargetModel.objects.create(
name='one_to_one'
)
self.many_to_many_targets = [
ManyToManyTargetModel.objects.create(
name='many_to_many (%d)' % idx
) for idx in range(3)
]
self.instance = RelationalModel.objects.create(
foreign_key=self.foreign_key_target,
one_to_one=self.one_to_one_target,
)
self.instance.many_to_many.set(self.many_to_many_targets) self.instance.many_to_many.set(self.many_to_many_targets)
def test_pk_retrival(self): def test_pk_retrival(self):
@ -767,14 +720,8 @@ class TestIntegration(TestCase):
fields = '__all__' fields = '__all__'
serializer = TestSerializer(self.instance) serializer = TestSerializer(self.instance)
expected = { expected={'id':self.instance.pk,'foreign_key':self.foreign_key_target.pk,'one_to_one':self.one_to_one_target.pk,'many_to_many':[item.pkforiteminself.many_to_many_targets],'through':[]}
'id': self.instance.pk, assertserializer.data==expected
'foreign_key': self.foreign_key_target.pk,
'one_to_one': self.one_to_one_target.pk,
'many_to_many': [item.pk for item in self.many_to_many_targets],
'through': []
}
self.assertEqual(serializer.data, expected)
def test_pk_create(self): def test_pk_create(self):
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
@ -782,47 +729,19 @@ class TestIntegration(TestCase):
model = RelationalModel model = RelationalModel
fields = '__all__' fields = '__all__'
new_foreign_key = ForeignKeyTargetModel.objects.create( new_foreign_key = ForeignKeyTargetModel.objects.create(name='foreign_key')
name='foreign_key' new_one_to_one=OneToOneTargetModel.objects.create(name='one_to_one')
) new_many_to_many=[ManyToManyTargetModel.objects.create(name='new many_to_many (%d)'%idx)foridxinrange(3)]
new_one_to_one = OneToOneTargetModel.objects.create( data={'foreign_key':new_foreign_key.pk,'one_to_one':new_one_to_one.pk,'many_to_many':[item.pkforiteminnew_many_to_many],}
name='one_to_one' serializer=TestSerializer(data=data)
) assertserializer.is_valid()
new_many_to_many = [ instance=serializer.save()
ManyToManyTargetModel.objects.create( assertinstance.foreign_key.pk==new_foreign_key.pk
name='new many_to_many (%d)' % idx assertinstance.one_to_one.pk==new_one_to_one.pk
) for idx in range(3) assert[item.pkforitemininstance.many_to_many.all()]==[item.pkforiteminnew_many_to_many]
] assertlist(instance.through.all())==[]
data = { expected={'id':instance.pk,'foreign_key':new_foreign_key.pk,'one_to_one':new_one_to_one.pk,'many_to_many':[item.pkforiteminnew_many_to_many],'through':[]}
'foreign_key': new_foreign_key.pk, assertserializer.data==expected
'one_to_one': new_one_to_one.pk,
'many_to_many': [item.pk for item in new_many_to_many],
}
# Serializer should validate okay.
serializer = TestSerializer(data=data)
assert serializer.is_valid()
# Creating the instance, relationship attributes should be set.
instance = serializer.save()
assert instance.foreign_key.pk == new_foreign_key.pk
assert instance.one_to_one.pk == new_one_to_one.pk
assert [
item.pk for item in instance.many_to_many.all()
] == [
item.pk for item in new_many_to_many
]
assert list(instance.through.all()) == []
# Representation should be correct.
expected = {
'id': instance.pk,
'foreign_key': new_foreign_key.pk,
'one_to_one': new_one_to_one.pk,
'many_to_many': [item.pk for item in new_many_to_many],
'through': []
}
self.assertEqual(serializer.data, expected)
def test_pk_update(self): def test_pk_update(self):
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
@ -830,47 +749,19 @@ class TestIntegration(TestCase):
model = RelationalModel model = RelationalModel
fields = '__all__' fields = '__all__'
new_foreign_key = ForeignKeyTargetModel.objects.create( new_foreign_key = ForeignKeyTargetModel.objects.create(name='foreign_key')
name='foreign_key' new_one_to_one=OneToOneTargetModel.objects.create(name='one_to_one')
) new_many_to_many=[ManyToManyTargetModel.objects.create(name='new many_to_many (%d)'%idx)foridxinrange(3)]
new_one_to_one = OneToOneTargetModel.objects.create( data={'foreign_key':new_foreign_key.pk,'one_to_one':new_one_to_one.pk,'many_to_many':[item.pkforiteminnew_many_to_many],}
name='one_to_one' serializer=TestSerializer(self.instance,data=data)
) assertserializer.is_valid()
new_many_to_many = [ instance=serializer.save()
ManyToManyTargetModel.objects.create( assertinstance.foreign_key.pk==new_foreign_key.pk
name='new many_to_many (%d)' % idx assertinstance.one_to_one.pk==new_one_to_one.pk
) for idx in range(3) assert[item.pkforitemininstance.many_to_many.all()]==[item.pkforiteminnew_many_to_many]
] assertlist(instance.through.all())==[]
data = { expected={'id':self.instance.pk,'foreign_key':new_foreign_key.pk,'one_to_one':new_one_to_one.pk,'many_to_many':[item.pkforiteminnew_many_to_many],'through':[]}
'foreign_key': new_foreign_key.pk, assertserializer.data==expected
'one_to_one': new_one_to_one.pk,
'many_to_many': [item.pk for item in new_many_to_many],
}
# Serializer should validate okay.
serializer = TestSerializer(self.instance, data=data)
assert serializer.is_valid()
# Creating the instance, relationship attributes should be set.
instance = serializer.save()
assert instance.foreign_key.pk == new_foreign_key.pk
assert instance.one_to_one.pk == new_one_to_one.pk
assert [
item.pk for item in instance.many_to_many.all()
] == [
item.pk for item in new_many_to_many
]
assert list(instance.through.all()) == []
# Representation should be correct.
expected = {
'id': self.instance.pk,
'foreign_key': new_foreign_key.pk,
'one_to_one': new_one_to_one.pk,
'many_to_many': [item.pk for item in new_many_to_many],
'through': []
}
self.assertEqual(serializer.data, expected)
# Tests for bulk create using `ListSerializer`. # Tests for bulk create using `ListSerializer`.
@ -879,8 +770,7 @@ class BulkCreateModel(models.Model):
name = models.CharField(max_length=10) name = models.CharField(max_length=10)
class TestBulkCreate(TestCase): def test_bulk_create(self):
def test_bulk_create(self):
class BasicModelSerializer(serializers.ModelSerializer): class BasicModelSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = BulkCreateModel model = BulkCreateModel
@ -890,35 +780,28 @@ class TestBulkCreate(TestCase):
child = BasicModelSerializer() child = BasicModelSerializer()
data = [{'name': 'a'}, {'name': 'b'}, {'name': 'c'}] data = [{'name': 'a'}, {'name': 'b'}, {'name': 'c'}]
serializer = BulkCreateSerializer(data=data) serializer=BulkCreateSerializer(data=data)
assert serializer.is_valid() assertserializer.is_valid()
instances=serializer.save()
# Objects are returned by save(). assertlen(instances)==3
instances = serializer.save() assert[item.nameforitemininstances]==['a','b','c']
assert len(instances) == 3 assertBulkCreateModel.objects.count()==3
assert [item.name for item in instances] == ['a', 'b', 'c'] assertlist(BulkCreateModel.objects.values_list('name',flat=True))==['a','b','c']
assertserializer.data==data
# Objects have been created in the database.
assert BulkCreateModel.objects.count() == 3
assert list(BulkCreateModel.objects.values_list('name', flat=True)) == ['a', 'b', 'c']
# Serializer returns correct data.
assert serializer.data == data
class MetaClassTestModel(models.Model): class MetaClassTestModel(models.Model):
text = models.CharField(max_length=100) text = models.CharField(max_length=100)
class TestSerializerMetaClass(TestCase): def test_meta_class_fields_option(self):
def test_meta_class_fields_option(self):
class ExampleSerializer(serializers.ModelSerializer): class ExampleSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = MetaClassTestModel model = MetaClassTestModel
fields = 'text' fields = 'text'
msginitial = "The `fields` option must be a list or tuple" msginitial = "The `fields` option must be a list or tuple"
with self.assertRaisesMessage(TypeError, msginitial): withself.assertRaisesMessage(TypeError,msginitial):
ExampleSerializer().fields ExampleSerializer().fields
def test_meta_class_exclude_option(self): def test_meta_class_exclude_option(self):
@ -928,7 +811,7 @@ class TestSerializerMetaClass(TestCase):
exclude = 'text' exclude = 'text'
msginitial = "The `exclude` option must be a list or tuple" msginitial = "The `exclude` option must be a list or tuple"
with self.assertRaisesMessage(TypeError, msginitial): withself.assertRaisesMessage(TypeError,msginitial):
ExampleSerializer().fields ExampleSerializer().fields
def test_meta_class_fields_and_exclude_options(self): def test_meta_class_fields_and_exclude_options(self):
@ -939,49 +822,36 @@ class TestSerializerMetaClass(TestCase):
exclude = ('text',) exclude = ('text',)
msginitial = "Cannot set both 'fields' and 'exclude' options on serializer ExampleSerializer." msginitial = "Cannot set both 'fields' and 'exclude' options on serializer ExampleSerializer."
with self.assertRaisesMessage(AssertionError, msginitial): withself.assertRaisesMessage(AssertionError,msginitial):
ExampleSerializer().fields ExampleSerializer().fields
def test_declared_fields_with_exclude_option(self): def test_declared_fields_with_exclude_option(self):
class ExampleSerializer(serializers.ModelSerializer): class ExampleSerializer(serializers.ModelSerializer):
text = serializers.CharField() text = serializers.CharField()
class Meta: class Meta:
model = MetaClassTestModel model = MetaClassTestModel
exclude = ('text',) exclude = ('text',)
expected = ( expected = ("Cannot both declare the field 'text' and include it in the ""ExampleSerializer 'exclude' option. Remove the field or, if ""inherited from a parent serializer, disable with `text = None`.")
"Cannot both declare the field 'text' and include it in the " withself.assertRaisesMessage(AssertionError,expected):
"ExampleSerializer 'exclude' option. Remove the field or, if "
"inherited from a parent serializer, disable with `text = None`."
)
with self.assertRaisesMessage(AssertionError, expected):
ExampleSerializer().fields ExampleSerializer().fields
class Issue2704TestCase(TestCase): def test_queryset_all(self):
def test_queryset_all(self):
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
additional_attr = serializers.CharField() additional_attr = serializers.CharField()
class Meta: class Meta:
model = OneFieldModel model = OneFieldModel
fields = ('char_field', 'additional_attr') fields = ('char_field', 'additional_attr')
OneFieldModel.objects.create(char_field='abc') OneFieldModel.objects.create(char_field='abc')
qs = OneFieldModel.objects.all() qs=OneFieldModel.objects.all()
foroinqs:
for o in qs:
o.additional_attr = '123' o.additional_attr = '123'
serializer = TestSerializer(instance=qs, many=True) serializer = TestSerializer(instance=qs, many=True)
expected=[{'char_field':'abc','additional_attr':'123',}]
expected = [{ assertserializer.data==expected
'char_field': 'abc',
'additional_attr': '123',
}]
assert serializer.data == expected
class DecimalFieldModel(models.Model): class DecimalFieldModel(models.Model):
@ -992,8 +862,7 @@ class DecimalFieldModel(models.Model):
) )
class TestDecimalFieldMappings(TestCase): def test_decimal_field_has_decimal_validator(self):
def test_decimal_field_has_decimal_validator(self):
""" """
Test that a `DecimalField` has no `DecimalValidator`. Test that a `DecimalField` has no `DecimalValidator`.
""" """
@ -1003,8 +872,7 @@ class TestDecimalFieldMappings(TestCase):
fields = '__all__' fields = '__all__'
serializer = TestSerializer() serializer = TestSerializer()
assertlen(serializer.fields['decimal_field'].validators)==2
assert len(serializer.fields['decimal_field'].validators) == 2
def test_min_value_is_passed(self): def test_min_value_is_passed(self):
""" """
@ -1017,8 +885,7 @@ class TestDecimalFieldMappings(TestCase):
fields = '__all__' fields = '__all__'
serializer = TestSerializer() serializer = TestSerializer()
assertserializer.fields['decimal_field'].min_value==1
assert serializer.fields['decimal_field'].min_value == 1
def test_max_value_is_passed(self): def test_max_value_is_passed(self):
""" """
@ -1031,15 +898,12 @@ class TestDecimalFieldMappings(TestCase):
fields = '__all__' fields = '__all__'
serializer = TestSerializer() serializer = TestSerializer()
assertserializer.fields['decimal_field'].max_value==3
assert serializer.fields['decimal_field'].max_value == 3
class TestMetaInheritance(TestCase): def test_extra_kwargs_not_altered(self):
def test_extra_kwargs_not_altered(self):
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
non_model_field = serializers.CharField() non_model_field = serializers.CharField()
class Meta: class Meta:
model = OneFieldModel model = OneFieldModel
read_only_fields = ('char_field', 'non_model_field') read_only_fields = ('char_field', 'non_model_field')
@ -1055,15 +919,14 @@ class TestMetaInheritance(TestCase):
char_field = CharField(read_only=True) char_field = CharField(read_only=True)
non_model_field = CharField() non_model_field = CharField()
""") """)
child_expected=dedent("""
child_expected = dedent("""
ChildSerializer(): ChildSerializer():
char_field = CharField(max_length=100) char_field = CharField(max_length=100)
non_model_field = CharField() non_model_field = CharField()
""") """)
self.assertEqual(repr(ChildSerializer()), child_expected) assertrepr(ChildSerializer())==child_expected
self.assertEqual(repr(TestSerializer()), test_expected) assertrepr(TestSerializer())==test_expected
self.assertEqual(repr(ChildSerializer()), child_expected) assertrepr(ChildSerializer())==child_expected
class OneToOneTargetTestModel(models.Model): class OneToOneTargetTestModel(models.Model):
@ -1074,25 +937,22 @@ class OneToOneSourceTestModel(models.Model):
target = models.OneToOneField(OneToOneTargetTestModel, primary_key=True, on_delete=models.CASCADE) target = models.OneToOneField(OneToOneTargetTestModel, primary_key=True, on_delete=models.CASCADE)
class TestModelFieldValues(TestCase): def test_model_field(self):
def test_model_field(self):
class ExampleSerializer(serializers.ModelSerializer): class ExampleSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = OneToOneSourceTestModel model = OneToOneSourceTestModel
fields = ('target',) fields = ('target',)
target = OneToOneTargetTestModel(id=1, text='abc') target = OneToOneTargetTestModel(id=1, text='abc')
source = OneToOneSourceTestModel(target=target) source=OneToOneSourceTestModel(target=target)
serializer = ExampleSerializer(source) serializer=ExampleSerializer(source)
self.assertEqual(serializer.data, {'target': 1}) assertserializer.data=={'target':1}
class TestUniquenessOverride(TestCase): def test_required_not_overwritten(self):
def test_required_not_overwritten(self):
class TestModel(models.Model): class TestModel(models.Model):
field_1 = models.IntegerField(null=True) field_1 = models.IntegerField(null=True)
field_2 = models.IntegerField() field_2 = models.IntegerField()
class Meta: class Meta:
unique_together = (('field_1', 'field_2'),) unique_together = (('field_1', 'field_2'),)
@ -1103,12 +963,11 @@ class TestUniquenessOverride(TestCase):
extra_kwargs = {'field_1': {'required': False}} extra_kwargs = {'field_1': {'required': False}}
fields = TestSerializer().fields fields = TestSerializer().fields
self.assertFalse(fields['field_1'].required) assertnotfields['field_1'].required
self.assertTrue(fields['field_2'].required) assertfields['field_2'].required
class Issue3674Test(TestCase): def test_nonPK_foreignkey_model_serializer(self):
def test_nonPK_foreignkey_model_serializer(self):
class TestParentModel(models.Model): class TestParentModel(models.Model):
title = models.CharField(max_length=64) title = models.CharField(max_length=64)
@ -1132,14 +991,13 @@ class Issue3674Test(TestCase):
title = CharField(max_length=64) title = CharField(max_length=64)
children = PrimaryKeyRelatedField(many=True, queryset=TestChildModel.objects.all()) children = PrimaryKeyRelatedField(many=True, queryset=TestChildModel.objects.all())
""") """)
self.assertEqual(repr(TestParentModelSerializer()), parent_expected) assertrepr(TestParentModelSerializer())==parent_expected
child_expected=dedent("""
child_expected = dedent("""
TestChildModelSerializer(): TestChildModelSerializer():
value = CharField(max_length=64, validators=[<UniqueValidator(queryset=TestChildModel.objects.all())>]) value = CharField(max_length=64, validators=[<UniqueValidator(queryset=TestChildModel.objects.all())>])
parent = PrimaryKeyRelatedField(queryset=TestParentModel.objects.all()) parent = PrimaryKeyRelatedField(queryset=TestParentModel.objects.all())
""") """)
self.assertEqual(repr(TestChildModelSerializer()), child_expected) assertrepr(TestChildModelSerializer())==child_expected
def test_nonID_PK_foreignkey_model_serializer(self): def test_nonID_PK_foreignkey_model_serializer(self):
@ -1154,20 +1012,16 @@ class Issue3674Test(TestCase):
fields = ('id', 'title', 'children') fields = ('id', 'title', 'children')
parent = Issue3674ParentModel.objects.create(title='abc') parent = Issue3674ParentModel.objects.create(title='abc')
child = Issue3674ChildModel.objects.create(value='def', parent=parent) child=Issue3674ChildModel.objects.create(value='def',parent=parent)
parent_serializer=TestParentModelSerializer(parent)
parent_serializer = TestParentModelSerializer(parent) child_serializer=TestChildModelSerializer(child)
child_serializer = TestChildModelSerializer(child) parent_expected={'children':['def'],'id':1,'title':'abc'}
assertparent_serializer.data==parent_expected
parent_expected = {'children': ['def'], 'id': 1, 'title': 'abc'} child_expected={'parent':1,'value':'def'}
self.assertEqual(parent_serializer.data, parent_expected) assertchild_serializer.data==child_expected
child_expected = {'parent': 1, 'value': 'def'}
self.assertEqual(child_serializer.data, child_expected)
class Issue4897TestCase(TestCase): def test_should_assert_if_writing_readonly_fields(self):
def test_should_assert_if_writing_readonly_fields(self):
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = OneFieldModel model = OneFieldModel
@ -1175,27 +1029,24 @@ class Issue4897TestCase(TestCase):
readonly_fields = fields readonly_fields = fields
obj = OneFieldModel.objects.create(char_field='abc') obj = OneFieldModel.objects.create(char_field='abc')
withpytest.raises(AssertionError)ascm:
with pytest.raises(AssertionError) as cm:
TestSerializer(obj).fields TestSerializer(obj).fields
cm.match(r'readonly_fields') cm.match(r'readonly_fields')
class Test5004UniqueChoiceField(TestCase): def test_unique_choice_field(self):
def test_unique_choice_field(self):
class TestUniqueChoiceSerializer(serializers.ModelSerializer): class TestUniqueChoiceSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = UniqueChoiceModel model = UniqueChoiceModel
fields = '__all__' fields = '__all__'
UniqueChoiceModel.objects.create(name='choice1') UniqueChoiceModel.objects.create(name='choice1')
serializer = TestUniqueChoiceSerializer(data={'name': 'choice1'}) serializer=TestUniqueChoiceSerializer(data={'name':'choice1'})
assert not serializer.is_valid() assertnotserializer.is_valid()
assert serializer.errors == {'name': ['unique choice model with this name already exists.']} assertserializer.errors=={'name':['unique choice model with this name already exists.']}
class TestFieldSource(TestCase): def test_traverse_nullable_fk(self):
def test_traverse_nullable_fk(self):
""" """
A dotted source with nullable elements uses default when any item in the chain is None. #5849. A dotted source with nullable elements uses default when any item in the chain is None. #5849.
@ -1203,16 +1054,13 @@ class TestFieldSource(TestCase):
but using RelatedField, rather than CharField. but using RelatedField, rather than CharField.
""" """
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
target = serializers.PrimaryKeyRelatedField( target = serializers.PrimaryKeyRelatedField( source='target.target', read_only=True, allow_null=True, default=None )
source='target.target', read_only=True, allow_null=True, default=None
)
class Meta: class Meta:
model = NestedForeignKeySource model = NestedForeignKeySource
fields = ('target', ) fields = ('target', )
model = NestedForeignKeySource.objects.create() model = NestedForeignKeySource.objects.create()
assert TestSerializer(model).data['target'] is None assertTestSerializer(model).data['target']isNone
def test_named_field_source(self): def test_named_field_source(self):
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
@ -1220,18 +1068,14 @@ class TestFieldSource(TestCase):
class Meta: class Meta:
model = RegularFieldsModel model = RegularFieldsModel
fields = ('number_field',) fields = ('number_field',)
extra_kwargs = { extra_kwargs = { 'number_field': { 'source': 'integer_field' } }
'number_field': {
'source': 'integer_field'
}
}
expected = dedent(""" expected = dedent("""
TestSerializer(): TestSerializer():
number_field = IntegerField(source='integer_field') number_field = IntegerField(source='integer_field')
""") """)
self.maxDiff = None self.maxDiff=None
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
class Issue6110TestModel(models.Model): class Issue6110TestModel(models.Model):
@ -1247,11 +1091,10 @@ class Issue6110ModelSerializer(serializers.ModelSerializer):
fields = ('name',) fields = ('name',)
class Issue6110Test(TestCase):
def test_model_serializer_custom_manager(self): def test_model_serializer_custom_manager(self):
instance = Issue6110ModelSerializer().create({'name': 'test_name'}) instance = Issue6110ModelSerializer().create({'name': 'test_name'})
self.assertEqual(instance.name, 'test_name') assert instance.name == 'test_name'
def test_model_serializer_custom_manager_error_message(self): def test_model_serializer_custom_manager_error_message(self):
msginitial = ('Got a `TypeError` when calling `Issue6110TestModel.all_objects.create()`.') msginitial = ('Got a `TypeError` when calling `Issue6110TestModel.all_objects.create()`.')

View File

@ -33,9 +33,8 @@ class AssociatedModelSerializer(serializers.ModelSerializer):
# Tests # Tests
class InheritedModelSerializationTests(TestCase):
def test_multitable_inherited_model_fields_as_expected(self): def test_multitable_inherited_model_fields_as_expected(self):
""" """
Assert that the parent pointer field is not included in the fields Assert that the parent pointer field is not included in the fields
serialized fields serialized fields
@ -59,9 +58,6 @@ class InheritedModelSerializationTests(TestCase):
Assert that the pointer to the parent table is not a required field Assert that the pointer to the parent table is not a required field
for input data for input data
""" """
data = { data = { 'name1': 'parent name', 'name2': 'child name', }
'name1': 'parent name',
'name2': 'child name',
}
serializer = DerivedModelSerializer(data=data) serializer = DerivedModelSerializer(data=data)
assert serializer.is_valid() is True assert serializer.is_valid() is True

View File

@ -30,8 +30,7 @@ class NoCharsetSpecifiedRenderer(BaseRenderer):
media_type = 'my/media' media_type = 'my/media'
class TestAcceptedMediaType(TestCase): def setUp(self):
def setUp(self):
self.renderers = [MockJSONRenderer(), MockHTMLRenderer(), MockOpenAPIRenderer()] self.renderers = [MockJSONRenderer(), MockHTMLRenderer(), MockOpenAPIRenderer()]
self.negotiator = DefaultContentNegotiation() self.negotiator = DefaultContentNegotiation()
@ -81,13 +80,12 @@ class TestAcceptedMediaType(TestCase):
class MockRenderer: class MockRenderer:
format = 'xml' format = 'xml'
renderers = [MockRenderer()] renderers = [MockRenderer()]
with pytest.raises(Http404): withpytest.raises(Http404):
self.negotiator.filter_renderers(renderers, format='json') self.negotiator.filter_renderers(renderers, format='json')
class BaseContentNegotiationTests(TestCase):
def setUp(self): def setUp(self):
self.negotiator = BaseContentNegotiation() self.negotiator = BaseContentNegotiation()
def test_raise_error_for_abstract_select_parser_method(self): def test_raise_error_for_abstract_select_parser_method(self):

View File

@ -28,14 +28,12 @@ class ChildAssociatedModelSerializer(serializers.ModelSerializer):
# Tests # Tests
class InheritedModelSerializationTests(TestCase):
def test_multitable_inherited_model_fields_as_expected(self): def test_multitable_inherited_model_fields_as_expected(self):
""" """
Assert that the parent pointer field is not included in the fields Assert that the parent pointer field is not included in the fields
serialized fields serialized fields
""" """
child = ChildModel(name1='parent name', name2='child name') child = ChildModel(name1='parent name', name2='child name')
serializer = DerivedModelSerializer(child) serializer = DerivedModelSerializer(child)
self.assertEqual(set(serializer.data), assert set(serializer.data) == {'name1', 'name2', 'id', 'childassociatedmodel'}
{'name1', 'name2', 'id', 'childassociatedmodel'})

View File

@ -22,32 +22,25 @@ class Form(forms.Form):
field2 = forms.CharField() field2 = forms.CharField()
class TestFormParser(TestCase): def setUp(self):
def setUp(self):
self.string = "field1=abc&field2=defghijk" self.string = "field1=abc&field2=defghijk"
def test_parse(self): def test_parse(self):
""" Make sure the `QueryDict` works OK """ """ Make sure the `QueryDict` works OK """
parser = FormParser() parser = FormParser()
stream = io.StringIO(self.string) stream = io.StringIO(self.string)
data = parser.parse(stream) data = parser.parse(stream)
assert Form(data).is_valid() is True assert Form(data).is_valid() is True
class TestFileUploadParser(TestCase): def setUp(self):
def setUp(self):
class MockRequest: class MockRequest:
pass pass
self.stream = io.BytesIO(b"Test text file") self.stream = io.BytesIO(b"Test text file")
request = MockRequest() request=MockRequest()
request.upload_handlers = (MemoryFileUploadHandler(),) request.upload_handlers=(MemoryFileUploadHandler(),)
request.META = { request.META={'HTTP_CONTENT_DISPOSITION':'Content-Disposition: inline; filename=file.txt','HTTP_CONTENT_LENGTH':14,}
'HTTP_CONTENT_DISPOSITION': 'Content-Disposition: inline; filename=file.txt', self.parser_context={'request':request,'kwargs':{}}
'HTTP_CONTENT_LENGTH': 14,
}
self.parser_context = {'request': request, 'kwargs': {}}
def test_parse(self): def test_parse(self):
""" """
@ -77,10 +70,7 @@ class TestFileUploadParser(TestCase):
""" """
parser = FileUploadParser() parser = FileUploadParser()
self.stream.seek(0) self.stream.seek(0)
self.parser_context['request'].upload_handlers = ( self.parser_context['request'].upload_handlers = ( MemoryFileUploadHandler(), MemoryFileUploadHandler() )
MemoryFileUploadHandler(),
MemoryFileUploadHandler()
)
self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = '' self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = ''
with pytest.raises(ParseError) as excinfo: with pytest.raises(ParseError) as excinfo:
parser.parse(self.stream, None, self.parser_context) parser.parse(self.stream, None, self.parser_context)
@ -92,9 +82,7 @@ class TestFileUploadParser(TestCase):
""" """
parser = FileUploadParser() parser = FileUploadParser()
self.stream.seek(0) self.stream.seek(0)
self.parser_context['request'].upload_handlers = ( self.parser_context['request'].upload_handlers = ( TemporaryFileUploadHandler(), )
TemporaryFileUploadHandler(),
)
self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = '' self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = ''
with pytest.raises(ParseError) as excinfo: with pytest.raises(ParseError) as excinfo:
parser.parse(self.stream, None, self.parser_context) parser.parse(self.stream, None, self.parser_context)
@ -107,15 +95,12 @@ class TestFileUploadParser(TestCase):
def test_get_encoded_filename(self): def test_get_encoded_filename(self):
parser = FileUploadParser() parser = FileUploadParser()
self.__replace_content_disposition('inline; filename*=utf-8\'\'ÀĥƦ.txt') self.__replace_content_disposition('inline; filename*=utf-8\'\'ÀĥƦ.txt')
filename = parser.get_filename(self.stream, None, self.parser_context) filename = parser.get_filename(self.stream, None, self.parser_context)
assert filename == 'ÀĥƦ.txt' assert filename == 'ÀĥƦ.txt'
self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'\'ÀĥƦ.txt') self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'\'ÀĥƦ.txt')
filename = parser.get_filename(self.stream, None, self.parser_context) filename = parser.get_filename(self.stream, None, self.parser_context)
assert filename == 'ÀĥƦ.txt' assert filename == 'ÀĥƦ.txt'
self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'en-us\'ÀĥƦ.txt') self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'en-us\'ÀĥƦ.txt')
filename = parser.get_filename(self.stream, None, self.parser_context) filename = parser.get_filename(self.stream, None, self.parser_context)
assert filename == 'ÀĥƦ.txt' assert filename == 'ÀĥƦ.txt'
@ -124,26 +109,22 @@ class TestFileUploadParser(TestCase):
self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = disposition self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = disposition
class TestJSONParser(TestCase): def bytes(self, value):
def bytes(self, value):
return io.BytesIO(value.encode()) return io.BytesIO(value.encode())
def test_float_strictness(self): def test_float_strictness(self):
parser = JSONParser() parser = JSONParser()
# Default to strict
for value in ['Infinity', '-Infinity', 'NaN']: for value in ['Infinity', '-Infinity', 'NaN']:
with pytest.raises(ParseError): with pytest.raises(ParseError):
parser.parse(self.bytes(value)) parser.parse(self.bytes(value))
parser.strict = False parser.strict = False
assert parser.parse(self.bytes('Infinity')) == float('inf') assertparser.parse(self.bytes('Infinity'))==float('inf')
assert parser.parse(self.bytes('-Infinity')) == float('-inf') assertparser.parse(self.bytes('-Infinity'))==float('-inf')
assert math.isnan(parser.parse(self.bytes('NaN'))) assertmath.isnan(parser.parse(self.bytes('NaN')))
class TestPOSTAccessed(TestCase): def setUp(self):
def setUp(self):
self.factory = APIRequestFactory() self.factory = APIRequestFactory()
def test_post_accessed_in_post_method(self): def test_post_accessed_in_post_method(self):

View File

@ -72,32 +72,21 @@ def basic_auth_header(username, password):
return 'Basic %s' % base64_credentials return 'Basic %s' % base64_credentials
class ModelPermissionsIntegrationTests(TestCase): def setUp(self):
def setUp(self):
User.objects.create_user('disallowed', 'disallowed@example.com', 'password') User.objects.create_user('disallowed', 'disallowed@example.com', 'password')
user = User.objects.create_user('permitted', 'permitted@example.com', 'password') user = User.objects.create_user('permitted', 'permitted@example.com', 'password')
user.user_permissions.set([ user.user_permissions.set([ Permission.objects.get(codename='add_basicmodel'), Permission.objects.get(codename='change_basicmodel'), Permission.objects.get(codename='delete_basicmodel') ])
Permission.objects.get(codename='add_basicmodel'),
Permission.objects.get(codename='change_basicmodel'),
Permission.objects.get(codename='delete_basicmodel')
])
user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password') user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password')
user.user_permissions.set([ user.user_permissions.set([ Permission.objects.get(codename='change_basicmodel'), ])
Permission.objects.get(codename='change_basicmodel'),
])
self.permitted_credentials = basic_auth_header('permitted', 'password') self.permitted_credentials = basic_auth_header('permitted', 'password')
self.disallowed_credentials = basic_auth_header('disallowed', 'password') self.disallowed_credentials = basic_auth_header('disallowed', 'password')
self.updateonly_credentials = basic_auth_header('updateonly', 'password') self.updateonly_credentials = basic_auth_header('updateonly', 'password')
BasicModel(text='foo').save() BasicModel(text='foo').save()
def test_has_create_permissions(self): def test_has_create_permissions(self):
request = factory.post('/', {'text': 'foobar'}, format='json', request = factory.post('/', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.permitted_credentials)
HTTP_AUTHORIZATION=self.permitted_credentials)
response = root_view(request, pk=1) response = root_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_201_CREATED) assert response.status_code == status.HTTP_201_CREATED
def test_api_root_view_discard_default_django_model_permission(self): def test_api_root_view_discard_default_django_model_permission(self):
""" """
@ -105,133 +94,104 @@ class ModelPermissionsIntegrationTests(TestCase):
apply to APIRoot view. More specifically we check expected behavior of apply to APIRoot view. More specifically we check expected behavior of
``_ignore_model_permissions`` attribute support. ``_ignore_model_permissions`` attribute support.
""" """
request = factory.get('/', format='json', request = factory.get('/', format='json', HTTP_AUTHORIZATION=self.permitted_credentials)
HTTP_AUTHORIZATION=self.permitted_credentials)
request.resolver_match = ResolverMatch('get', (), {}) request.resolver_match = ResolverMatch('get', (), {})
response = api_root_view(request) response = api_root_view(request)
self.assertEqual(response.status_code, status.HTTP_200_OK) assert response.status_code == status.HTTP_200_OK
def test_get_queryset_has_create_permissions(self): def test_get_queryset_has_create_permissions(self):
request = factory.post('/', {'text': 'foobar'}, format='json', request = factory.post('/', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.permitted_credentials)
HTTP_AUTHORIZATION=self.permitted_credentials)
response = get_queryset_list_view(request, pk=1) response = get_queryset_list_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_201_CREATED) assert response.status_code == status.HTTP_201_CREATED
def test_has_put_permissions(self): def test_has_put_permissions(self):
request = factory.put('/1', {'text': 'foobar'}, format='json', request = factory.put('/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.permitted_credentials)
HTTP_AUTHORIZATION=self.permitted_credentials)
response = instance_view(request, pk='1') response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) assert response.status_code == status.HTTP_200_OK
def test_has_delete_permissions(self): def test_has_delete_permissions(self):
request = factory.delete('/1', HTTP_AUTHORIZATION=self.permitted_credentials) request = factory.delete('/1', HTTP_AUTHORIZATION=self.permitted_credentials)
response = instance_view(request, pk=1) response = instance_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) assert response.status_code == status.HTTP_204_NO_CONTENT
def test_does_not_have_create_permissions(self): def test_does_not_have_create_permissions(self):
request = factory.post('/', {'text': 'foobar'}, format='json', request = factory.post('/', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.disallowed_credentials)
HTTP_AUTHORIZATION=self.disallowed_credentials)
response = root_view(request, pk=1) response = root_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) assert response.status_code == status.HTTP_403_FORBIDDEN
def test_does_not_have_put_permissions(self): def test_does_not_have_put_permissions(self):
request = factory.put('/1', {'text': 'foobar'}, format='json', request = factory.put('/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.disallowed_credentials)
HTTP_AUTHORIZATION=self.disallowed_credentials)
response = instance_view(request, pk='1') response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) assert response.status_code == status.HTTP_403_FORBIDDEN
def test_does_not_have_delete_permissions(self): def test_does_not_have_delete_permissions(self):
request = factory.delete('/1', HTTP_AUTHORIZATION=self.disallowed_credentials) request = factory.delete('/1', HTTP_AUTHORIZATION=self.disallowed_credentials)
response = instance_view(request, pk=1) response = instance_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) assert response.status_code == status.HTTP_403_FORBIDDEN
def test_options_permitted(self): def test_options_permitted(self):
request = factory.options( request = factory.options( '/', HTTP_AUTHORIZATION=self.permitted_credentials )
'/',
HTTP_AUTHORIZATION=self.permitted_credentials
)
response = root_view(request, pk='1') response = root_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) assert response.status_code == status.HTTP_200_OK
self.assertIn('actions', response.data) assert 'actions' in response.data
self.assertEqual(list(response.data['actions']), ['POST']) assert list(response.data['actions']) == ['POST']
request = factory.options( '/1', HTTP_AUTHORIZATION=self.permitted_credentials )
request = factory.options(
'/1',
HTTP_AUTHORIZATION=self.permitted_credentials
)
response = instance_view(request, pk='1') response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) assert response.status_code == status.HTTP_200_OK
self.assertIn('actions', response.data) assert 'actions' in response.data
self.assertEqual(list(response.data['actions']), ['PUT']) assert list(response.data['actions']) == ['PUT']
def test_options_disallowed(self): def test_options_disallowed(self):
request = factory.options( request = factory.options( '/', HTTP_AUTHORIZATION=self.disallowed_credentials )
'/',
HTTP_AUTHORIZATION=self.disallowed_credentials
)
response = root_view(request, pk='1') response = root_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) assert response.status_code == status.HTTP_200_OK
self.assertNotIn('actions', response.data) assert 'actions' not in response.data
request = factory.options( '/1', HTTP_AUTHORIZATION=self.disallowed_credentials )
request = factory.options(
'/1',
HTTP_AUTHORIZATION=self.disallowed_credentials
)
response = instance_view(request, pk='1') response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) assert response.status_code == status.HTTP_200_OK
self.assertNotIn('actions', response.data) assert 'actions' not in response.data
def test_options_updateonly(self): def test_options_updateonly(self):
request = factory.options( request = factory.options( '/', HTTP_AUTHORIZATION=self.updateonly_credentials )
'/',
HTTP_AUTHORIZATION=self.updateonly_credentials
)
response = root_view(request, pk='1') response = root_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) assert response.status_code == status.HTTP_200_OK
self.assertNotIn('actions', response.data) assert 'actions' not in response.data
request = factory.options( '/1', HTTP_AUTHORIZATION=self.updateonly_credentials )
request = factory.options(
'/1',
HTTP_AUTHORIZATION=self.updateonly_credentials
)
response = instance_view(request, pk='1') response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) assert response.status_code == status.HTTP_200_OK
self.assertIn('actions', response.data) assert 'actions' in response.data
self.assertEqual(list(response.data['actions']), ['PUT']) assert list(response.data['actions']) == ['PUT']
def test_empty_view_does_not_assert(self): def test_empty_view_does_not_assert(self):
request = factory.get('/1', HTTP_AUTHORIZATION=self.permitted_credentials) request = factory.get('/1', HTTP_AUTHORIZATION=self.permitted_credentials)
response = empty_list_view(request, pk=1) response = empty_list_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_200_OK) assert response.status_code == status.HTTP_200_OK
def test_calling_method_not_allowed(self): def test_calling_method_not_allowed(self):
request = factory.generic('METHOD_NOT_ALLOWED', '/', HTTP_AUTHORIZATION=self.permitted_credentials) request = factory.generic('METHOD_NOT_ALLOWED', '/', HTTP_AUTHORIZATION=self.permitted_credentials)
response = root_view(request) response = root_view(request)
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
request = factory.generic('METHOD_NOT_ALLOWED', '/1', HTTP_AUTHORIZATION=self.permitted_credentials) request = factory.generic('METHOD_NOT_ALLOWED', '/1', HTTP_AUTHORIZATION=self.permitted_credentials)
response = instance_view(request, pk='1') response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
def test_check_auth_before_queryset_call(self): def test_check_auth_before_queryset_call(self):
class View(RootView): class View(RootView):
def get_queryset(_): def get_queryset(_):
self.fail('should not reach due to auth check') self.fail('should not reach due to auth check')
view = View.as_view() view = View.as_view()
request=factory.get('/',HTTP_AUTHORIZATION='')
request = factory.get('/', HTTP_AUTHORIZATION='') response=view(request)
response = view(request) assertresponse.status_code==status.HTTP_401_UNAUTHORIZED
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
def test_queryset_assertions(self): def test_queryset_assertions(self):
class View(views.APIView): class View(views.APIView):
authentication_classes = [authentication.BasicAuthentication] authentication_classes = [authentication.BasicAuthentication]
permission_classes = [permissions.DjangoModelPermissions] permission_classes = [permissions.DjangoModelPermissions]
view = View.as_view() view = View.as_view()
request=factory.get('/',HTTP_AUTHORIZATION=self.permitted_credentials)
request = factory.get('/', HTTP_AUTHORIZATION=self.permitted_credentials) msg='Cannot apply DjangoModelPermissions on a view that does not set `.queryset` or have a `.get_queryset()` method.'
msg = 'Cannot apply DjangoModelPermissions on a view that does not set `.queryset` or have a `.get_queryset()` method.' withself.assertRaisesMessage(AssertionError,msg):
with self.assertRaisesMessage(AssertionError, msg):
view(request) view(request)
# Faulty `get_queryset()` methods should trigger the above "view does not have a queryset" assertion. # Faulty `get_queryset()` methods should trigger the above "view does not have a queryset" assertion.
@ -239,9 +199,8 @@ class ModelPermissionsIntegrationTests(TestCase):
def get_queryset(self): def get_queryset(self):
return None return None
view = View.as_view() view = View.as_view()
request=factory.get('/',HTTP_AUTHORIZATION=self.permitted_credentials)
request = factory.get('/', HTTP_AUTHORIZATION=self.permitted_credentials) withself.assertRaisesMessage(AssertionError,'View.get_queryset() returned None'):
with self.assertRaisesMessage(AssertionError, 'View.get_queryset() returned None'):
view(request) view(request)
@ -310,103 +269,74 @@ get_queryset_object_permissions_view = GetQuerysetObjectPermissionInstanceView.a
@unittest.skipUnless(is_guardian_installed(), 'django-guardian not installed') @unittest.skipUnless(is_guardian_installed(), 'django-guardian not installed')
class ObjectPermissionsIntegrationTests(TestCase): """
"""
Integration tests for the object level permissions API. Integration tests for the object level permissions API.
""" """
def setUp(self): defsetUp(self):
from guardian.shortcuts import assign_perm from guardian.shortcuts import assign_perm
# create users
create = User.objects.create_user create = User.objects.create_user
users = { users = { 'fullaccess': create('fullaccess', 'fullaccess@example.com', 'password'), 'readonly': create('readonly', 'readonly@example.com', 'password'), 'writeonly': create('writeonly', 'writeonly@example.com', 'password'), 'deleteonly': create('deleteonly', 'deleteonly@example.com', 'password'), }
'fullaccess': create('fullaccess', 'fullaccess@example.com', 'password'),
'readonly': create('readonly', 'readonly@example.com', 'password'),
'writeonly': create('writeonly', 'writeonly@example.com', 'password'),
'deleteonly': create('deleteonly', 'deleteonly@example.com', 'password'),
}
# give everyone model level permissions, as we are not testing those
everyone = Group.objects.create(name='everyone') everyone = Group.objects.create(name='everyone')
model_name = BasicPermModel._meta.model_name model_name = BasicPermModel._meta.model_name
app_label = BasicPermModel._meta.app_label app_label = BasicPermModel._meta.app_label
f = '{}_{}'.format f = '{}_{}'.format
perms = { perms = { 'view': f('view', model_name), 'change': f('change', model_name), 'delete': f('delete', model_name) }
'view': f('view', model_name),
'change': f('change', model_name),
'delete': f('delete', model_name)
}
for perm in perms.values(): for perm in perms.values():
perm = '{}.{}'.format(app_label, perm) perm = '{}.{}'.format(app_label, perm)
assign_perm(perm, everyone) assign_perm(perm, everyone)
everyone.user_set.add(*users.values()) everyone.user_set.add(*users.values())
readers=Group.objects.create(name='readers')
# appropriate object level permissions writers=Group.objects.create(name='writers')
readers = Group.objects.create(name='readers') deleters=Group.objects.create(name='deleters')
writers = Group.objects.create(name='writers') model=BasicPermModel.objects.create(text='foo')
deleters = Group.objects.create(name='deleters') assign_perm(perms['view'],readers,model)
assign_perm(perms['change'],writers,model)
model = BasicPermModel.objects.create(text='foo') assign_perm(perms['delete'],deleters,model)
readers.user_set.add(users['fullaccess'],users['readonly'])
assign_perm(perms['view'], readers, model) writers.user_set.add(users['fullaccess'],users['writeonly'])
assign_perm(perms['change'], writers, model) deleters.user_set.add(users['fullaccess'],users['deleteonly'])
assign_perm(perms['delete'], deleters, model) self.credentials={}
foruserinusers.values():
readers.user_set.add(users['fullaccess'], users['readonly'])
writers.user_set.add(users['fullaccess'], users['writeonly'])
deleters.user_set.add(users['fullaccess'], users['deleteonly'])
self.credentials = {}
for user in users.values():
self.credentials[user.username] = basic_auth_header(user.username, 'password') self.credentials[user.username] = basic_auth_header(user.username, 'password')
# Delete # Delete
def test_can_delete_permissions(self): def test_can_delete_permissions(self):
request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['deleteonly']) request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['deleteonly'])
response = object_permissions_view(request, pk='1') response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) assert response.status_code == status.HTTP_204_NO_CONTENT
def test_cannot_delete_permissions(self): def test_cannot_delete_permissions(self):
request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['readonly']) request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
response = object_permissions_view(request, pk='1') response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) assert response.status_code == status.HTTP_403_FORBIDDEN
# Update # Update
def test_can_update_permissions(self): def test_can_update_permissions(self):
request = factory.patch( request = factory.patch( '/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.credentials['writeonly'] )
'/1', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.credentials['writeonly']
)
response = object_permissions_view(request, pk='1') response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) assert response.status_code == status.HTTP_200_OK
self.assertEqual(response.data.get('text'), 'foobar') assert response.data.get('text') == 'foobar'
def test_cannot_update_permissions(self): def test_cannot_update_permissions(self):
request = factory.patch( request = factory.patch( '/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.credentials['deleteonly'] )
'/1', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.credentials['deleteonly']
)
response = object_permissions_view(request, pk='1') response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) assert response.status_code == status.HTTP_404_NOT_FOUND
def test_cannot_update_permissions_non_existing(self): def test_cannot_update_permissions_non_existing(self):
request = factory.patch( request = factory.patch( '/999', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.credentials['deleteonly'] )
'/999', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.credentials['deleteonly']
)
response = object_permissions_view(request, pk='999') response = object_permissions_view(request, pk='999')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) assert response.status_code == status.HTTP_404_NOT_FOUND
# Read # Read
def test_can_read_permissions(self): def test_can_read_permissions(self):
request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly']) request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
response = object_permissions_view(request, pk='1') response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) assert response.status_code == status.HTTP_200_OK
def test_cannot_read_permissions(self): def test_cannot_read_permissions(self):
request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['writeonly']) request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['writeonly'])
response = object_permissions_view(request, pk='1') response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) assert response.status_code == status.HTTP_404_NOT_FOUND
def test_can_read_get_queryset_permissions(self): def test_can_read_get_queryset_permissions(self):
""" """
@ -415,7 +345,7 @@ class ObjectPermissionsIntegrationTests(TestCase):
""" """
request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly']) request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
response = get_queryset_object_permissions_view(request, pk='1') response = get_queryset_object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) assert response.status_code == status.HTTP_200_OK
# Read list # Read list
def test_django_object_permissions_filter_deprecated(self): def test_django_object_permissions_filter_deprecated(self):
@ -423,36 +353,33 @@ class ObjectPermissionsIntegrationTests(TestCase):
warnings.simplefilter("always") warnings.simplefilter("always")
DjangoObjectPermissionsFilter() DjangoObjectPermissionsFilter()
message = ("`DjangoObjectPermissionsFilter` has been deprecated and moved " message = ("`DjangoObjectPermissionsFilter` has been deprecated and moved ""to the 3rd-party django-rest-framework-guardian package.")
"to the 3rd-party django-rest-framework-guardian package.") assertlen(w)==1
self.assertEqual(len(w), 1) assertw[-1].categoryisRemovedInDRF310Warning
self.assertIs(w[-1].category, RemovedInDRF310Warning) assertstr(w[-1].message)==message
self.assertEqual(str(w[-1].message), message)
def test_can_read_list_permissions(self): def test_can_read_list_permissions(self):
request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['readonly']) request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['readonly'])
object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,) object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,)
# TODO: remove in version 3.10
with warnings.catch_warnings(record=True): with warnings.catch_warnings(record=True):
warnings.simplefilter("always") warnings.simplefilter("always")
response = object_permissions_list_view(request) response = object_permissions_list_view(request)
self.assertEqual(response.status_code, status.HTTP_200_OK) assert response.status_code== status.HTTP_200_OK
self.assertEqual(response.data[0].get('id'), 1) assertresponse.data[0].get('id')==1
def test_cannot_read_list_permissions(self): def test_cannot_read_list_permissions(self):
request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['writeonly']) request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['writeonly'])
object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,) object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,)
# TODO: remove in version 3.10
with warnings.catch_warnings(record=True): with warnings.catch_warnings(record=True):
warnings.simplefilter("always") warnings.simplefilter("always")
response = object_permissions_list_view(request) response = object_permissions_list_view(request)
self.assertEqual(response.status_code, status.HTTP_200_OK) assert response.status_code== status.HTTP_200_OK
self.assertListEqual(response.data, []) assertresponse.data==[]
def test_cannot_method_not_allowed(self): def test_cannot_method_not_allowed(self):
request = factory.generic('METHOD_NOT_ALLOWED', '/', HTTP_AUTHORIZATION=self.credentials['readonly']) request = factory.generic('METHOD_NOT_ALLOWED', '/', HTTP_AUTHORIZATION=self.credentials['readonly'])
response = object_permissions_list_view(request) response = object_permissions_list_view(request)
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
class BasicPerm(permissions.BasePermission): class BasicPerm(permissions.BasePermission):
@ -507,10 +434,7 @@ denied_view_with_detail = DeniedViewWithDetail.as_view()
denied_object_view = DeniedObjectView.as_view() denied_object_view = DeniedObjectView.as_view()
denied_object_view_with_detail = DeniedObjectViewWithDetail.as_view() denied_object_view_with_detail = DeniedObjectViewWithDetail.as_view()
def setUp(self):
class CustomPermissionsTests(TestCase):
def setUp(self):
BasicModel(text='foo').save() BasicModel(text='foo').save()
User.objects.create_user('username', 'username@example.com', 'password') User.objects.create_user('username', 'username@example.com', 'password')
credentials = basic_auth_header('username', 'password') credentials = basic_auth_header('username', 'password')
@ -520,39 +444,34 @@ class CustomPermissionsTests(TestCase):
def test_permission_denied(self): def test_permission_denied(self):
response = denied_view(self.request, pk=1) response = denied_view(self.request, pk=1)
detail = response.data.get('detail') detail = response.data.get('detail')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) assert response.status_code == status.HTTP_403_FORBIDDEN
self.assertNotEqual(detail, self.custom_message) assert detail != self.custom_message
def test_permission_denied_with_custom_detail(self): def test_permission_denied_with_custom_detail(self):
response = denied_view_with_detail(self.request, pk=1) response = denied_view_with_detail(self.request, pk=1)
detail = response.data.get('detail') detail = response.data.get('detail')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) assert response.status_code == status.HTTP_403_FORBIDDEN
self.assertEqual(detail, self.custom_message) assert detail == self.custom_message
def test_permission_denied_for_object(self): def test_permission_denied_for_object(self):
response = denied_object_view(self.request, pk=1) response = denied_object_view(self.request, pk=1)
detail = response.data.get('detail') detail = response.data.get('detail')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) assert response.status_code == status.HTTP_403_FORBIDDEN
self.assertNotEqual(detail, self.custom_message) assert detail != self.custom_message
def test_permission_denied_for_object_with_custom_detail(self): def test_permission_denied_for_object_with_custom_detail(self):
response = denied_object_view_with_detail(self.request, pk=1) response = denied_object_view_with_detail(self.request, pk=1)
detail = response.data.get('detail') detail = response.data.get('detail')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) assert response.status_code == status.HTTP_403_FORBIDDEN
self.assertEqual(detail, self.custom_message) assert detail == self.custom_message
class PermissionsCompositionTests(TestCase):
def setUp(self): def setUp(self):
self.username = 'john' self.username = 'john'
self.email = 'lennon@thebeatles.com' self.email = 'lennon@thebeatles.com'
self.password = 'password' self.password = 'password'
self.user = User.objects.create_user( self.user = User.objects.create_user( self.username, self.email, self.password )
self.username,
self.email,
self.password
)
self.client.login(username=self.username, password=self.password) self.client.login(username=self.username, password=self.password)
def test_and_false(self): def test_and_false(self):
@ -594,46 +513,30 @@ class PermissionsCompositionTests(TestCase):
def test_several_levels_without_negation(self): def test_several_levels_without_negation(self):
request = factory.get('/1', format='json') request = factory.get('/1', format='json')
request.user = self.user request.user = self.user
composed_perm = ( composed_perm = ( permissions.IsAuthenticated & permissions.IsAuthenticated & permissions.IsAuthenticated & permissions.IsAuthenticated )
permissions.IsAuthenticated &
permissions.IsAuthenticated &
permissions.IsAuthenticated &
permissions.IsAuthenticated
)
assert composed_perm().has_permission(request, None) is True assert composed_perm().has_permission(request, None) is True
def test_several_levels_and_precedence_with_negation(self): def test_several_levels_and_precedence_with_negation(self):
request = factory.get('/1', format='json') request = factory.get('/1', format='json')
request.user = self.user request.user = self.user
composed_perm = ( composed_perm = ( permissions.IsAuthenticated & ~ permissions.IsAdminUser & permissions.IsAuthenticated & ~(permissions.IsAdminUser & permissions.IsAdminUser) )
permissions.IsAuthenticated &
~ permissions.IsAdminUser &
permissions.IsAuthenticated &
~(permissions.IsAdminUser & permissions.IsAdminUser)
)
assert composed_perm().has_permission(request, None) is True assert composed_perm().has_permission(request, None) is True
def test_several_levels_and_precedence(self): def test_several_levels_and_precedence(self):
request = factory.get('/1', format='json') request = factory.get('/1', format='json')
request.user = self.user request.user = self.user
composed_perm = ( composed_perm = ( permissions.IsAuthenticated & permissions.IsAuthenticated | permissions.IsAuthenticated & permissions.IsAuthenticated )
permissions.IsAuthenticated &
permissions.IsAuthenticated |
permissions.IsAuthenticated &
permissions.IsAuthenticated
)
assert composed_perm().has_permission(request, None) is True assert composed_perm().has_permission(request, None) is True
@pytest.mark.skipif(not PY36, reason="assert_called_once() not available") @pytest.mark.skipif(not PY36, reason="assert_called_once() not available")
def test_or_lazyness(self): deftest_or_lazyness(self):
request = factory.get('/1', format='json') request = factory.get('/1', format='json')
request.user = AnonymousUser() request.user = AnonymousUser()
with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow: with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny: with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
composed_perm = (permissions.AllowAny | permissions.IsAuthenticated) composed_perm = (permissions.AllowAny | permissions.IsAuthenticated)
hasperm = composed_perm().has_permission(request, None) hasperm = composed_perm().has_permission(request, None)
self.assertIs(hasperm, True) assert hasperm is True
mock_allow.assert_called_once() mock_allow.assert_called_once()
mock_deny.assert_not_called() mock_deny.assert_not_called()
@ -641,20 +544,19 @@ class PermissionsCompositionTests(TestCase):
with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny: with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
composed_perm = (permissions.IsAuthenticated | permissions.AllowAny) composed_perm = (permissions.IsAuthenticated | permissions.AllowAny)
hasperm = composed_perm().has_permission(request, None) hasperm = composed_perm().has_permission(request, None)
self.assertIs(hasperm, True) assert hasperm is True
mock_deny.assert_called_once() mock_deny.assert_called_once()
mock_allow.assert_called_once() mock_allow.assert_called_once()
@pytest.mark.skipif(not PY36, reason="assert_called_once() not available") @pytest.mark.skipif(not PY36, reason="assert_called_once() not available")
def test_object_or_lazyness(self): deftest_object_or_lazyness(self):
request = factory.get('/1', format='json') request = factory.get('/1', format='json')
request.user = AnonymousUser() request.user = AnonymousUser()
with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow: with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny: with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
composed_perm = (permissions.AllowAny | permissions.IsAuthenticated) composed_perm = (permissions.AllowAny | permissions.IsAuthenticated)
hasperm = composed_perm().has_object_permission(request, None, None) hasperm = composed_perm().has_object_permission(request, None, None)
self.assertIs(hasperm, True) assert hasperm is True
mock_allow.assert_called_once() mock_allow.assert_called_once()
mock_deny.assert_not_called() mock_deny.assert_not_called()
@ -662,20 +564,19 @@ class PermissionsCompositionTests(TestCase):
with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny: with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
composed_perm = (permissions.IsAuthenticated | permissions.AllowAny) composed_perm = (permissions.IsAuthenticated | permissions.AllowAny)
hasperm = composed_perm().has_object_permission(request, None, None) hasperm = composed_perm().has_object_permission(request, None, None)
self.assertIs(hasperm, True) assert hasperm is True
mock_deny.assert_called_once() mock_deny.assert_called_once()
mock_allow.assert_called_once() mock_allow.assert_called_once()
@pytest.mark.skipif(not PY36, reason="assert_called_once() not available") @pytest.mark.skipif(not PY36, reason="assert_called_once() not available")
def test_and_lazyness(self): deftest_and_lazyness(self):
request = factory.get('/1', format='json') request = factory.get('/1', format='json')
request.user = AnonymousUser() request.user = AnonymousUser()
with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow: with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny: with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
composed_perm = (permissions.AllowAny & permissions.IsAuthenticated) composed_perm = (permissions.AllowAny & permissions.IsAuthenticated)
hasperm = composed_perm().has_permission(request, None) hasperm = composed_perm().has_permission(request, None)
self.assertIs(hasperm, False) assert hasperm is False
mock_allow.assert_called_once() mock_allow.assert_called_once()
mock_deny.assert_called_once() mock_deny.assert_called_once()
@ -683,20 +584,19 @@ class PermissionsCompositionTests(TestCase):
with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny: with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
composed_perm = (permissions.IsAuthenticated & permissions.AllowAny) composed_perm = (permissions.IsAuthenticated & permissions.AllowAny)
hasperm = composed_perm().has_permission(request, None) hasperm = composed_perm().has_permission(request, None)
self.assertIs(hasperm, False) assert hasperm is False
mock_allow.assert_not_called() mock_allow.assert_not_called()
mock_deny.assert_called_once() mock_deny.assert_called_once()
@pytest.mark.skipif(not PY36, reason="assert_called_once() not available") @pytest.mark.skipif(not PY36, reason="assert_called_once() not available")
def test_object_and_lazyness(self): deftest_object_and_lazyness(self):
request = factory.get('/1', format='json') request = factory.get('/1', format='json')
request.user = AnonymousUser() request.user = AnonymousUser()
with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow: with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny: with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
composed_perm = (permissions.AllowAny & permissions.IsAuthenticated) composed_perm = (permissions.AllowAny & permissions.IsAuthenticated)
hasperm = composed_perm().has_object_permission(request, None, None) hasperm = composed_perm().has_object_permission(request, None, None)
self.assertIs(hasperm, False) assert hasperm is False
mock_allow.assert_called_once() mock_allow.assert_called_once()
mock_deny.assert_called_once() mock_deny.assert_called_once()
@ -704,6 +604,6 @@ class PermissionsCompositionTests(TestCase):
with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny: with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
composed_perm = (permissions.IsAuthenticated & permissions.AllowAny) composed_perm = (permissions.IsAuthenticated & permissions.AllowAny)
hasperm = composed_perm().has_object_permission(request, None, None) hasperm = composed_perm().has_object_permission(request, None, None)
self.assertIs(hasperm, False) assert hasperm is False
mock_allow.assert_not_called() mock_allow.assert_not_called()
mock_deny.assert_called_once() mock_deny.assert_called_once()

View File

@ -18,8 +18,7 @@ class UserUpdate(generics.UpdateAPIView):
serializer_class = UserSerializer serializer_class = UserSerializer
class TestPrefetchRelatedUpdates(TestCase): def setUp(self):
def setUp(self):
self.user = User.objects.create(username='tom', email='tom@example.com') self.user = User.objects.create(username='tom', email='tom@example.com')
self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')] self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')]
self.user.groups.set(self.groups) self.user.groups.set(self.groups)
@ -31,12 +30,7 @@ class TestPrefetchRelatedUpdates(TestCase):
request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json') request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json')
response = view(request, pk=pk) response = view(request, pk=pk)
assert User.objects.get(pk=pk).groups.count() == 1 assert User.objects.get(pk=pk).groups.count() == 1
expected = { expected = { 'id': pk, 'username': 'new', 'groups': [1], 'email': 'tom@example.com' }
'id': pk,
'username': 'new',
'groups': [1],
'email': 'tom@example.com'
}
assert response.data == expected assert response.data == expected
def test_prefetch_related_excluding_instance_from_original_queryset(self): def test_prefetch_related_excluding_instance_from_original_queryset(self):
@ -49,10 +43,5 @@ class TestPrefetchRelatedUpdates(TestCase):
request = factory.put('/', {'username': 'exclude', 'groups': [groups_pk]}, format='json') request = factory.put('/', {'username': 'exclude', 'groups': [groups_pk]}, format='json')
response = view(request, pk=pk) response = view(request, pk=pk)
assert User.objects.get(pk=pk).groups.count() == 1 assert User.objects.get(pk=pk).groups.count() == 1
expected = { expected = { 'id': pk, 'username': 'exclude', 'groups': [1], 'email': 'tom@example.com' }
'id': pk,
'username': 'exclude',
'groups': [1],
'email': 'tom@example.com'
}
assert response.data == expected assert response.data == expected

View File

@ -70,8 +70,7 @@ class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer):
# TODO: Add test that .data cannot be accessed prior to .is_valid # TODO: Add test that .data cannot be accessed prior to .is_valid
@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink') @override_settings(ROOT_URLCONF='tests.test_relations_hyperlink')
class HyperlinkedManyToManyTests(TestCase): def setUp(self):
def setUp(self):
for idx in range(1, 4): for idx in range(1, 4):
target = ManyToManyTarget(name='target-%d' % idx) target = ManyToManyTarget(name='target-%d' % idx)
target.save() target.save()
@ -83,22 +82,14 @@ class HyperlinkedManyToManyTests(TestCase):
def test_relative_hyperlinks(self): def test_relative_hyperlinks(self):
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': None}) serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': None})
expected = [ expected = [ {'url': '/manytomanysource/1/', 'name': 'source-1', 'targets': ['/manytomanytarget/1/']}, {'url': '/manytomanysource/2/', 'name': 'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']}, {'url': '/manytomanysource/3/', 'name': 'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']} ]
{'url': '/manytomanysource/1/', 'name': 'source-1', 'targets': ['/manytomanytarget/1/']},
{'url': '/manytomanysource/2/', 'name': 'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']},
{'url': '/manytomanysource/3/', 'name': 'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']}
]
with self.assertNumQueries(4): with self.assertNumQueries(4):
assert serializer.data == expected assert serializer.data == expected
def test_many_to_many_retrieve(self): def test_many_to_many_retrieve(self):
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']}, {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} ]
{'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']},
{'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
{'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
]
with self.assertNumQueries(4): with self.assertNumQueries(4):
assert serializer.data == expected assert serializer.data == expected
@ -111,11 +102,7 @@ class HyperlinkedManyToManyTests(TestCase):
def test_reverse_many_to_many_retrieve(self): def test_reverse_many_to_many_retrieve(self):
queryset = ManyToManyTarget.objects.all() queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']} ]
{'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}
]
with self.assertNumQueries(4): with self.assertNumQueries(4):
assert serializer.data == expected assert serializer.data == expected
@ -126,15 +113,9 @@ class HyperlinkedManyToManyTests(TestCase):
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == data assert serializer.data == data
# Ensure source 1 is updated, and everything else is as expected
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}, {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} ]
{'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']},
{'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
{'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
]
assert serializer.data == expected assert serializer.data == expected
def test_reverse_many_to_many_update(self): def test_reverse_many_to_many_update(self):
@ -144,15 +125,9 @@ class HyperlinkedManyToManyTests(TestCase):
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == data assert serializer.data == data
# Ensure target 1 is updated, and everything else is as expected
queryset = ManyToManyTarget.objects.all() queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']}, {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']} ]
{'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']},
{'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}
]
assert serializer.data == expected assert serializer.data == expected
def test_many_to_many_create(self): def test_many_to_many_create(self):
@ -162,16 +137,9 @@ class HyperlinkedManyToManyTests(TestCase):
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'source-4' assert obj.name == 'source-4'
# Ensure source 4 is added, and everything else is as expected
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']}, {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}, {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']} ]
{'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']},
{'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
{'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']},
{'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']}
]
assert serializer.data == expected assert serializer.data == expected
def test_reverse_many_to_many_create(self): def test_reverse_many_to_many_create(self):
@ -181,22 +149,14 @@ class HyperlinkedManyToManyTests(TestCase):
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'target-4' assert obj.name == 'target-4'
# Ensure target 4 is added, and everything else is as expected
queryset = ManyToManyTarget.objects.all() queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']} ]
{'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']}
]
assert serializer.data == expected assert serializer.data == expected
@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink') @override_settings(ROOT_URLCONF='tests.test_relations_hyperlink')
class HyperlinkedForeignKeyTests(TestCase): def setUp(self):
def setUp(self):
target = ForeignKeyTarget(name='target-1') target = ForeignKeyTarget(name='target-1')
target.save() target.save()
new_target = ForeignKeyTarget(name='target-2') new_target = ForeignKeyTarget(name='target-2')
@ -208,21 +168,14 @@ class HyperlinkedForeignKeyTests(TestCase):
def test_foreign_key_retrieve(self): def test_foreign_key_retrieve(self):
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request}) serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'} ]
{'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}
]
with self.assertNumQueries(1): with self.assertNumQueries(1):
assert serializer.data == expected assert serializer.data == expected
def test_reverse_foreign_key_retrieve(self): def test_reverse_foreign_key_retrieve(self):
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']}, {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, ]
{'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']},
{'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
]
with self.assertNumQueries(3): with self.assertNumQueries(3):
assert serializer.data == expected assert serializer.data == expected
@ -233,15 +186,9 @@ class HyperlinkedForeignKeyTests(TestCase):
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == data assert serializer.data == data
# Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request}) serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'}, {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'} ]
{'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'},
{'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_update_incorrect_type(self): def test_foreign_key_update_incorrect_type(self):
@ -256,26 +203,15 @@ class HyperlinkedForeignKeyTests(TestCase):
instance = ForeignKeyTarget.objects.get(pk=2) instance = ForeignKeyTarget.objects.get(pk=2)
serializer = ForeignKeyTargetSerializer(instance, data=data, context={'request': request}) serializer = ForeignKeyTargetSerializer(instance, data=data, context={'request': request})
assert serializer.is_valid() assert serializer.is_valid()
# We shouldn't have saved anything to the db yet since save
# hasn't been called.
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
new_serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) new_serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']}, {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, ]
{'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']},
{'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
]
assert new_serializer.data == expected assert new_serializer.data == expected
serializer.save() serializer.save()
assert serializer.data == data assert serializer.data == data
# Ensure target 2 is update, and everything else is as expected
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']}, {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}, ]
{'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']},
{'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']},
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_create(self): def test_foreign_key_create(self):
@ -285,16 +221,9 @@ class HyperlinkedForeignKeyTests(TestCase):
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'source-4' assert obj.name == 'source-4'
# Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request}) serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'}, ]
{'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'},
]
assert serializer.data == expected assert serializer.data == expected
def test_reverse_foreign_key_create(self): def test_reverse_foreign_key_create(self):
@ -304,15 +233,9 @@ class HyperlinkedForeignKeyTests(TestCase):
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'target-3' assert obj.name == 'target-3'
# Ensure target 4 is added, and everything else is as expected
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']}, {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}, ]
{'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']},
{'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
{'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']},
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_update_with_invalid_null(self): def test_foreign_key_update_with_invalid_null(self):
@ -324,24 +247,19 @@ class HyperlinkedForeignKeyTests(TestCase):
@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink') @override_settings(ROOT_URLCONF='tests.test_relations_hyperlink')
class HyperlinkedNullableForeignKeyTests(TestCase): def setUp(self):
def setUp(self):
target = ForeignKeyTarget(name='target-1') target = ForeignKeyTarget(name='target-1')
target.save() target.save()
for idx in range(1, 4): for idx in range(1, 4):
if idx == 3: if idx == 3:
target = None target = None
source = NullableForeignKeySource(name='source-%d' % idx, target=target) source = NullableForeignKeySource(name='source-%d' % idx, target=target)
source.save() source.save()
def test_foreign_key_retrieve_with_null(self): def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, ]
{'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_create_with_valid_null(self): def test_foreign_key_create_with_valid_null(self):
@ -351,16 +269,9 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'source-4' assert obj.name == 'source-4'
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} ]
{'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
{'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_create_with_valid_emptystring(self): def test_foreign_key_create_with_valid_emptystring(self):
@ -375,16 +286,9 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
obj = serializer.save() obj = serializer.save()
assert serializer.data == expected_data assert serializer.data == expected_data
assert obj.name == 'source-4' assert obj.name == 'source-4'
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} ]
{'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
{'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_update_with_valid_null(self): def test_foreign_key_update_with_valid_null(self):
@ -394,15 +298,9 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == data assert serializer.data == data
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}, {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, ]
{'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None},
{'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_update_with_valid_emptystring(self): def test_foreign_key_update_with_valid_emptystring(self):
@ -417,21 +315,14 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == expected_data assert serializer.data == expected_data
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}, {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, ]
{'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None},
{'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
]
assert serializer.data == expected assert serializer.data == expected
@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink') @override_settings(ROOT_URLCONF='tests.test_relations_hyperlink')
class HyperlinkedNullableOneToOneTests(TestCase): def setUp(self):
def setUp(self):
target = OneToOneTarget(name='target-1') target = OneToOneTarget(name='target-1')
target.save() target.save()
new_target = OneToOneTarget(name='target-2') new_target = OneToOneTarget(name='target-2')
@ -442,8 +333,5 @@ class HyperlinkedNullableOneToOneTests(TestCase):
def test_reverse_foreign_key_retrieve_with_null(self): def test_reverse_foreign_key_retrieve_with_null(self):
queryset = OneToOneTarget.objects.all() queryset = OneToOneTarget.objects.all()
serializer = NullableOneToOneTargetSerializer(queryset, many=True, context={'request': request}) serializer = NullableOneToOneTargetSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/onetoonetarget/1/', 'name': 'target-1', 'nullable_source': 'http://testserver/nullableonetoonesource/1/'}, {'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None}, ]
{'url': 'http://testserver/onetoonetarget/1/', 'name': 'target-1', 'nullable_source': 'http://testserver/nullableonetoonesource/1/'},
{'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None},
]
assert serializer.data == expected assert serializer.data == expected

View File

@ -77,8 +77,7 @@ class OneToOnePKSourceSerializer(serializers.ModelSerializer):
# TODO: Add test that .data cannot be accessed prior to .is_valid # TODO: Add test that .data cannot be accessed prior to .is_valid
class PKManyToManyTests(TestCase): def setUp(self):
def setUp(self):
for idx in range(1, 4): for idx in range(1, 4):
target = ManyToManyTarget(name='target-%d' % idx) target = ManyToManyTarget(name='target-%d' % idx)
target.save() target.save()
@ -90,11 +89,7 @@ class PKManyToManyTests(TestCase):
def test_many_to_many_retrieve(self): def test_many_to_many_retrieve(self):
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True) serializer = ManyToManySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'targets': [1]}, {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]} ]
{'id': 1, 'name': 'source-1', 'targets': [1]},
{'id': 2, 'name': 'source-2', 'targets': [1, 2]},
{'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
]
with self.assertNumQueries(4): with self.assertNumQueries(4):
assert serializer.data == expected assert serializer.data == expected
@ -107,11 +102,7 @@ class PKManyToManyTests(TestCase):
def test_reverse_many_to_many_retrieve(self): def test_reverse_many_to_many_retrieve(self):
queryset = ManyToManyTarget.objects.all() queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True) serializer = ManyToManyTargetSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, {'id': 3, 'name': 'target-3', 'sources': [3]} ]
{'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
{'id': 2, 'name': 'target-2', 'sources': [2, 3]},
{'id': 3, 'name': 'target-3', 'sources': [3]}
]
with self.assertNumQueries(4): with self.assertNumQueries(4):
assert serializer.data == expected assert serializer.data == expected
@ -122,15 +113,9 @@ class PKManyToManyTests(TestCase):
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == data assert serializer.data == data
# Ensure source 1 is updated, and everything else is as expected
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True) serializer = ManyToManySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]}, {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]} ]
{'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]},
{'id': 2, 'name': 'source-2', 'targets': [1, 2]},
{'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
]
assert serializer.data == expected assert serializer.data == expected
def test_reverse_many_to_many_update(self): def test_reverse_many_to_many_update(self):
@ -140,15 +125,9 @@ class PKManyToManyTests(TestCase):
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == data assert serializer.data == data
# Ensure target 1 is updated, and everything else is as expected
queryset = ManyToManyTarget.objects.all() queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True) serializer = ManyToManyTargetSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'target-1', 'sources': [1]}, {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, {'id': 3, 'name': 'target-3', 'sources': [3]} ]
{'id': 1, 'name': 'target-1', 'sources': [1]},
{'id': 2, 'name': 'target-2', 'sources': [2, 3]},
{'id': 3, 'name': 'target-3', 'sources': [3]}
]
assert serializer.data == expected assert serializer.data == expected
def test_many_to_many_create(self): def test_many_to_many_create(self):
@ -158,25 +137,15 @@ class PKManyToManyTests(TestCase):
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'source-4' assert obj.name == 'source-4'
# Ensure source 4 is added, and everything else is as expected
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True) serializer = ManyToManySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'targets': [1]}, {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}, {'id': 4, 'name': 'source-4', 'targets': [1, 3]}, ]
{'id': 1, 'name': 'source-1', 'targets': [1]},
{'id': 2, 'name': 'source-2', 'targets': [1, 2]},
{'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]},
{'id': 4, 'name': 'source-4', 'targets': [1, 3]},
]
assert serializer.data == expected assert serializer.data == expected
def test_many_to_many_unsaved(self): def test_many_to_many_unsaved(self):
source = ManyToManySource(name='source-unsaved') source = ManyToManySource(name='source-unsaved')
serializer = ManyToManySourceSerializer(source) serializer = ManyToManySourceSerializer(source)
expected = {'id': None, 'name': 'source-unsaved', 'targets': []} expected = {'id': None, 'name': 'source-unsaved', 'targets': []}
# no query if source hasn't been created yet
with self.assertNumQueries(0): with self.assertNumQueries(0):
assert serializer.data == expected assert serializer.data == expected
@ -187,21 +156,13 @@ class PKManyToManyTests(TestCase):
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'target-4' assert obj.name == 'target-4'
# Ensure target 4 is added, and everything else is as expected
queryset = ManyToManyTarget.objects.all() queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True) serializer = ManyToManyTargetSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, {'id': 3, 'name': 'target-3', 'sources': [3]}, {'id': 4, 'name': 'target-4', 'sources': [1, 3]} ]
{'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
{'id': 2, 'name': 'target-2', 'sources': [2, 3]},
{'id': 3, 'name': 'target-3', 'sources': [3]},
{'id': 4, 'name': 'target-4', 'sources': [1, 3]}
]
assert serializer.data == expected assert serializer.data == expected
class PKForeignKeyTests(TestCase): def setUp(self):
def setUp(self):
target = ForeignKeyTarget(name='target-1') target = ForeignKeyTarget(name='target-1')
target.save() target.save()
new_target = ForeignKeyTarget(name='target-2') new_target = ForeignKeyTarget(name='target-2')
@ -213,21 +174,14 @@ class PKForeignKeyTests(TestCase):
def test_foreign_key_retrieve(self): def test_foreign_key_retrieve(self):
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True) serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': 1} ]
{'id': 1, 'name': 'source-1', 'target': 1},
{'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': 1}
]
with self.assertNumQueries(1): with self.assertNumQueries(1):
assert serializer.data == expected assert serializer.data == expected
def test_reverse_foreign_key_retrieve(self): def test_reverse_foreign_key_retrieve(self):
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True) serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, {'id': 2, 'name': 'target-2', 'sources': []}, ]
{'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
{'id': 2, 'name': 'target-2', 'sources': []},
]
with self.assertNumQueries(3): with self.assertNumQueries(3):
assert serializer.data == expected assert serializer.data == expected
@ -244,15 +198,9 @@ class PKForeignKeyTests(TestCase):
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == data assert serializer.data == data
# Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True) serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': 2}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': 1} ]
{'id': 1, 'name': 'source-1', 'target': 2},
{'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': 1}
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_update_incorrect_type(self): def test_foreign_key_update_incorrect_type(self):
@ -267,26 +215,15 @@ class PKForeignKeyTests(TestCase):
instance = ForeignKeyTarget.objects.get(pk=2) instance = ForeignKeyTarget.objects.get(pk=2)
serializer = ForeignKeyTargetSerializer(instance, data=data) serializer = ForeignKeyTargetSerializer(instance, data=data)
assert serializer.is_valid() assert serializer.is_valid()
# We shouldn't have saved anything to the db yet since save
# hasn't been called.
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
new_serializer = ForeignKeyTargetSerializer(queryset, many=True) new_serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, {'id': 2, 'name': 'target-2', 'sources': []}, ]
{'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
{'id': 2, 'name': 'target-2', 'sources': []},
]
assert new_serializer.data == expected assert new_serializer.data == expected
serializer.save() serializer.save()
assert serializer.data == data assert serializer.data == data
# Ensure target 2 is update, and everything else is as expected
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True) serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'target-1', 'sources': [2]}, {'id': 2, 'name': 'target-2', 'sources': [1, 3]}, ]
{'id': 1, 'name': 'target-1', 'sources': [2]},
{'id': 2, 'name': 'target-2', 'sources': [1, 3]},
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_create(self): def test_foreign_key_create(self):
@ -296,16 +233,9 @@ class PKForeignKeyTests(TestCase):
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'source-4' assert obj.name == 'source-4'
# Ensure source 4 is added, and everything else is as expected
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True) serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': 1}, {'id': 4, 'name': 'source-4', 'target': 2}, ]
{'id': 1, 'name': 'source-1', 'target': 1},
{'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': 1},
{'id': 4, 'name': 'source-4', 'target': 2},
]
assert serializer.data == expected assert serializer.data == expected
def test_reverse_foreign_key_create(self): def test_reverse_foreign_key_create(self):
@ -315,15 +245,9 @@ class PKForeignKeyTests(TestCase):
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'target-3' assert obj.name == 'target-3'
# Ensure target 3 is added, and everything else is as expected
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True) serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'target-1', 'sources': [2]}, {'id': 2, 'name': 'target-2', 'sources': []}, {'id': 3, 'name': 'target-3', 'sources': [1, 3]}, ]
{'id': 1, 'name': 'target-1', 'sources': [2]},
{'id': 2, 'name': 'target-2', 'sources': []},
{'id': 3, 'name': 'target-3', 'sources': [1, 3]},
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_update_with_invalid_null(self): def test_foreign_key_update_with_invalid_null(self):
@ -336,10 +260,7 @@ class PKForeignKeyTests(TestCase):
def test_foreign_key_with_unsaved(self): def test_foreign_key_with_unsaved(self):
source = ForeignKeySource(name='source-unsaved') source = ForeignKeySource(name='source-unsaved')
expected = {'id': None, 'name': 'source-unsaved', 'target': None} expected = {'id': None, 'name': 'source-unsaved', 'target': None}
serializer = ForeignKeySourceSerializer(source) serializer = ForeignKeySourceSerializer(source)
# no query if source hasn't been created yet
with self.assertNumQueries(0): with self.assertNumQueries(0):
assert serializer.data == expected assert serializer.data == expected
@ -361,8 +282,8 @@ class PKForeignKeyTests(TestCase):
class Meta(ForeignKeySourceSerializer.Meta): class Meta(ForeignKeySourceSerializer.Meta):
extra_kwargs = {'target': {'required': False}} extra_kwargs = {'target': {'required': False}}
serializer = ModelSerializer(data={'name': 'test'}) serializer = ModelSerializer(data={'name': 'test'})
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
assert 'target' not in serializer.validated_data assert'target'notinserializer.validated_data
def test_queryset_size_without_limited_choices(self): def test_queryset_size_without_limited_choices(self):
limited_target = ForeignKeyTarget(name="limited-target") limited_target = ForeignKeyTarget(name="limited-target")
@ -379,34 +300,28 @@ class PKForeignKeyTests(TestCase):
def test_queryset_size_with_Q_limited_choices(self): def test_queryset_size_with_Q_limited_choices(self):
limited_target = ForeignKeyTarget(name="limited-target") limited_target = ForeignKeyTarget(name="limited-target")
limited_target.save() limited_target.save()
class QLimitedChoicesSerializer(serializers.ModelSerializer): class QLimitedChoicesSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = ForeignKeySourceWithQLimitedChoices model = ForeignKeySourceWithQLimitedChoices
fields = ("id", "target") fields = ("id", "target")
queryset = QLimitedChoicesSerializer().fields["target"].get_queryset() queryset = QLimitedChoicesSerializer().fields["target"].get_queryset()
assert len(queryset) == 1 assertlen(queryset)==1
class PKNullableForeignKeyTests(TestCase): def setUp(self):
def setUp(self):
target = ForeignKeyTarget(name='target-1') target = ForeignKeyTarget(name='target-1')
target.save() target.save()
for idx in range(1, 4): for idx in range(1, 4):
if idx == 3: if idx == 3:
target = None target = None
source = NullableForeignKeySource(name='source-%d' % idx, target=target) source = NullableForeignKeySource(name='source-%d' % idx, target=target)
source.save() source.save()
def test_foreign_key_retrieve_with_null(self): def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': None}, ]
{'id': 1, 'name': 'source-1', 'target': 1},
{'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': None},
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_create_with_valid_null(self): def test_foreign_key_create_with_valid_null(self):
@ -416,16 +331,9 @@ class PKNullableForeignKeyTests(TestCase):
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'source-4' assert obj.name == 'source-4'
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': None}, {'id': 4, 'name': 'source-4', 'target': None} ]
{'id': 1, 'name': 'source-1', 'target': 1},
{'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': None},
{'id': 4, 'name': 'source-4', 'target': None}
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_create_with_valid_emptystring(self): def test_foreign_key_create_with_valid_emptystring(self):
@ -440,16 +348,9 @@ class PKNullableForeignKeyTests(TestCase):
obj = serializer.save() obj = serializer.save()
assert serializer.data == expected_data assert serializer.data == expected_data
assert obj.name == 'source-4' assert obj.name == 'source-4'
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': None}, {'id': 4, 'name': 'source-4', 'target': None} ]
{'id': 1, 'name': 'source-1', 'target': 1},
{'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': None},
{'id': 4, 'name': 'source-4', 'target': None}
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_update_with_valid_null(self): def test_foreign_key_update_with_valid_null(self):
@ -459,15 +360,9 @@ class PKNullableForeignKeyTests(TestCase):
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == data assert serializer.data == data
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': None}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': None} ]
{'id': 1, 'name': 'source-1', 'target': None},
{'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': None}
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_update_with_valid_emptystring(self): def test_foreign_key_update_with_valid_emptystring(self):
@ -482,15 +377,9 @@ class PKNullableForeignKeyTests(TestCase):
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == expected_data assert serializer.data == expected_data
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': None}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': None} ]
{'id': 1, 'name': 'source-1', 'target': None},
{'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': None}
]
assert serializer.data == expected assert serializer.data == expected
def test_null_uuid_foreign_key_serializes_as_none(self): def test_null_uuid_foreign_key_serializes_as_none(self):
@ -505,8 +394,7 @@ class PKNullableForeignKeyTests(TestCase):
assert serializer.is_valid(), serializer.errors assert serializer.is_valid(), serializer.errors
class PKNullableOneToOneTests(TestCase): def setUp(self):
def setUp(self):
target = OneToOneTarget(name='target-1') target = OneToOneTarget(name='target-1')
target.save() target.save()
new_target = OneToOneTarget(name='target-2') new_target = OneToOneTarget(name='target-2')
@ -517,16 +405,12 @@ class PKNullableOneToOneTests(TestCase):
def test_reverse_foreign_key_retrieve_with_null(self): def test_reverse_foreign_key_retrieve_with_null(self):
queryset = OneToOneTarget.objects.all() queryset = OneToOneTarget.objects.all()
serializer = NullableOneToOneTargetSerializer(queryset, many=True) serializer = NullableOneToOneTargetSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'target-1', 'nullable_source': None}, {'id': 2, 'name': 'target-2', 'nullable_source': 1}, ]
{'id': 1, 'name': 'target-1', 'nullable_source': None},
{'id': 2, 'name': 'target-2', 'nullable_source': 1},
]
assert serializer.data == expected assert serializer.data == expected
class OneToOnePrimaryKeyTests(TestCase):
def setUp(self): def setUp(self):
# Given: Some target models already exist # Given: Some target models already exist
self.target = target = OneToOneTarget(name='target-1') self.target = target = OneToOneTarget(name='target-1')
target.save() target.save()
@ -537,36 +421,29 @@ class OneToOnePrimaryKeyTests(TestCase):
# When: Creating a Source pointing at the id of the second Target # When: Creating a Source pointing at the id of the second Target
target_pk = self.alt_target.id target_pk = self.alt_target.id
source = OneToOnePKSourceSerializer(data={'name': 'source-2', 'target': target_pk}) source = OneToOnePKSourceSerializer(data={'name': 'source-2', 'target': target_pk})
# Then: The source is valid with the serializer
if not source.is_valid(): if not source.is_valid():
self.fail("Expected OneToOnePKTargetSerializer to be valid but had errors: {}".format(source.errors)) self.fail("Expected OneToOnePKTargetSerializer to be valid but had errors: {}".format(source.errors))
# Then: Saving the serializer creates a new object # Then: Saving the serializer creates a new object
new_source = source.save() new_source = source.save()
# Then: The new object has the same pk as the target object assertnew_source.pk==target_pk
self.assertEqual(new_source.pk, target_pk)
def test_one_to_one_when_primary_key_no_duplicates(self): def test_one_to_one_when_primary_key_no_duplicates(self):
# When: Creating a Source pointing at the id of the second Target # When: Creating a Source pointing at the id of the second Target
target_pk = self.target.id target_pk = self.target.id
data = {'name': 'source-1', 'target': target_pk} data = {'name': 'source-1', 'target': target_pk}
source = OneToOnePKSourceSerializer(data=data) source = OneToOnePKSourceSerializer(data=data)
# Then: The source is valid with the serializer assert source.is_valid()
self.assertTrue(source.is_valid())
# Then: Saving the serializer creates a new object
new_source = source.save() new_source = source.save()
# Then: The new object has the same pk as the target object assert new_source.pk == target_pk
self.assertEqual(new_source.pk, target_pk)
# When: Trying to create a second object
second_source = OneToOnePKSourceSerializer(data=data) second_source = OneToOnePKSourceSerializer(data=data)
self.assertFalse(second_source.is_valid()) assert not second_source.is_valid()
expected = {'target': ['one to one pk source with this target already exists.']} expected = {'target': ['one to one pk source with this target already exists.']}
self.assertDictEqual(second_source.errors, expected) assert second_source.errors == expected
def test_one_to_one_when_primary_key_does_not_exist(self): def test_one_to_one_when_primary_key_does_not_exist(self):
# Given: a target PK that does not exist # Given: a target PK that does not exist
target_pk = self.target.pk + self.alt_target.pk target_pk = self.target.pk + self.alt_target.pk
source = OneToOnePKSourceSerializer(data={'name': 'source-2', 'target': target_pk}) source = OneToOnePKSourceSerializer(data={'name': 'source-2', 'target': target_pk})
# Then: The source is not valid with the serializer assert not source.is_valid()
self.assertFalse(source.is_valid()) assert "Invalid pk" in source.errors['target'][0]
self.assertIn("Invalid pk", source.errors['target'][0]) assert "object does not exist" in source.errors['target'][0]
self.assertIn("object does not exist", source.errors['target'][0])

View File

@ -42,8 +42,7 @@ class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
# TODO: M2M Tests, FKTests (Non-nullable), One2One # TODO: M2M Tests, FKTests (Non-nullable), One2One
class SlugForeignKeyTests(TestCase): def setUp(self):
def setUp(self):
target = ForeignKeyTarget(name='target-1') target = ForeignKeyTarget(name='target-1')
target.save() target.save()
new_target = ForeignKeyTarget(name='target-2') new_target = ForeignKeyTarget(name='target-2')
@ -55,11 +54,7 @@ class SlugForeignKeyTests(TestCase):
def test_foreign_key_retrieve(self): def test_foreign_key_retrieve(self):
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True) serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-1'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': 'target-1'} ]
{'id': 1, 'name': 'source-1', 'target': 'target-1'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': 'target-1'}
]
with self.assertNumQueries(4): with self.assertNumQueries(4):
assert serializer.data == expected assert serializer.data == expected
@ -72,10 +67,7 @@ class SlugForeignKeyTests(TestCase):
def test_reverse_foreign_key_retrieve(self): def test_reverse_foreign_key_retrieve(self):
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True) serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, {'id': 2, 'name': 'target-2', 'sources': []}, ]
{'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
{'id': 2, 'name': 'target-2', 'sources': []},
]
assert serializer.data == expected assert serializer.data == expected
def test_reverse_foreign_key_retrieve_prefetch_related(self): def test_reverse_foreign_key_retrieve_prefetch_related(self):
@ -91,15 +83,9 @@ class SlugForeignKeyTests(TestCase):
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == data assert serializer.data == data
# Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True) serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-2'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': 'target-1'} ]
{'id': 1, 'name': 'source-1', 'target': 'target-2'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': 'target-1'}
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_update_incorrect_type(self): def test_foreign_key_update_incorrect_type(self):
@ -114,26 +100,15 @@ class SlugForeignKeyTests(TestCase):
instance = ForeignKeyTarget.objects.get(pk=2) instance = ForeignKeyTarget.objects.get(pk=2)
serializer = ForeignKeyTargetSerializer(instance, data=data) serializer = ForeignKeyTargetSerializer(instance, data=data)
assert serializer.is_valid() assert serializer.is_valid()
# We shouldn't have saved anything to the db yet since save
# hasn't been called.
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
new_serializer = ForeignKeyTargetSerializer(queryset, many=True) new_serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, {'id': 2, 'name': 'target-2', 'sources': []}, ]
{'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
{'id': 2, 'name': 'target-2', 'sources': []},
]
assert new_serializer.data == expected assert new_serializer.data == expected
serializer.save() serializer.save()
assert serializer.data == data assert serializer.data == data
# Ensure target 2 is update, and everything else is as expected
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True) serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'target-1', 'sources': ['source-2']}, {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']}, ]
{'id': 1, 'name': 'target-1', 'sources': ['source-2']},
{'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']},
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_create(self): def test_foreign_key_create(self):
@ -144,16 +119,9 @@ class SlugForeignKeyTests(TestCase):
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'source-4' assert obj.name == 'source-4'
# Ensure source 4 is added, and everything else is as expected
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True) serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-1'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': 'target-1'}, {'id': 4, 'name': 'source-4', 'target': 'target-2'}, ]
{'id': 1, 'name': 'source-1', 'target': 'target-1'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': 'target-1'},
{'id': 4, 'name': 'source-4', 'target': 'target-2'},
]
assert serializer.data == expected assert serializer.data == expected
def test_reverse_foreign_key_create(self): def test_reverse_foreign_key_create(self):
@ -163,15 +131,9 @@ class SlugForeignKeyTests(TestCase):
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'target-3' assert obj.name == 'target-3'
# Ensure target 3 is added, and everything else is as expected
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True) serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'target-1', 'sources': ['source-2']}, {'id': 2, 'name': 'target-2', 'sources': []}, {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']}, ]
{'id': 1, 'name': 'target-1', 'sources': ['source-2']},
{'id': 2, 'name': 'target-2', 'sources': []},
{'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']},
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_update_with_invalid_null(self): def test_foreign_key_update_with_invalid_null(self):
@ -182,24 +144,19 @@ class SlugForeignKeyTests(TestCase):
assert serializer.errors == {'target': ['This field may not be null.']} assert serializer.errors == {'target': ['This field may not be null.']}
class SlugNullableForeignKeyTests(TestCase): def setUp(self):
def setUp(self):
target = ForeignKeyTarget(name='target-1') target = ForeignKeyTarget(name='target-1')
target.save() target.save()
for idx in range(1, 4): for idx in range(1, 4):
if idx == 3: if idx == 3:
target = None target = None
source = NullableForeignKeySource(name='source-%d' % idx, target=target) source = NullableForeignKeySource(name='source-%d' % idx, target=target)
source.save() source.save()
def test_foreign_key_retrieve_with_null(self): def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-1'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': None}, ]
{'id': 1, 'name': 'source-1', 'target': 'target-1'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': None},
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_create_with_valid_null(self): def test_foreign_key_create_with_valid_null(self):
@ -209,16 +166,9 @@ class SlugNullableForeignKeyTests(TestCase):
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'source-4' assert obj.name == 'source-4'
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-1'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': None}, {'id': 4, 'name': 'source-4', 'target': None} ]
{'id': 1, 'name': 'source-1', 'target': 'target-1'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': None},
{'id': 4, 'name': 'source-4', 'target': None}
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_create_with_valid_emptystring(self): def test_foreign_key_create_with_valid_emptystring(self):
@ -233,16 +183,9 @@ class SlugNullableForeignKeyTests(TestCase):
obj = serializer.save() obj = serializer.save()
assert serializer.data == expected_data assert serializer.data == expected_data
assert obj.name == 'source-4' assert obj.name == 'source-4'
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-1'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': None}, {'id': 4, 'name': 'source-4', 'target': None} ]
{'id': 1, 'name': 'source-1', 'target': 'target-1'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': None},
{'id': 4, 'name': 'source-4', 'target': None}
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_update_with_valid_null(self): def test_foreign_key_update_with_valid_null(self):
@ -252,15 +195,9 @@ class SlugNullableForeignKeyTests(TestCase):
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == data assert serializer.data == data
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': None}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': None} ]
{'id': 1, 'name': 'source-1', 'target': None},
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': None}
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_update_with_valid_emptystring(self): def test_foreign_key_update_with_valid_emptystring(self):
@ -275,13 +212,7 @@ class SlugNullableForeignKeyTests(TestCase):
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == expected_data assert serializer.data == expected_data
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': None}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': None} ]
{'id': 1, 'name': 'source-1', 'target': None},
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': None}
]
assert serializer.data == expected assert serializer.data == expected

View File

@ -49,11 +49,10 @@ class DummyTestModel(models.Model):
name = models.CharField(max_length=42, default='') name = models.CharField(max_length=42, default='')
class BasicRendererTests(TestCase): def test_expected_results(self):
def test_expected_results(self):
for value, renderer_cls, expected in expected_results: for value, renderer_cls, expected in expected_results:
output = renderer_cls().render(value) output = renderer_cls().render(value)
self.assertEqual(output, expected) assert output == expected
class RendererA(BaseRenderer): class RendererA(BaseRenderer):
@ -144,8 +143,7 @@ class POSTDeniedView(APIView):
return Response() return Response()
class DocumentingRendererTests(TestCase): def test_only_permitted_forms_are_displayed(self):
def test_only_permitted_forms_are_displayed(self):
view = POSTDeniedView.as_view() view = POSTDeniedView.as_view()
request = APIRequestFactory().get('/') request = APIRequestFactory().get('/')
response = view(request).render() response = view(request).render()
@ -155,90 +153,82 @@ class DocumentingRendererTests(TestCase):
@override_settings(ROOT_URLCONF='tests.test_renderers') @override_settings(ROOT_URLCONF='tests.test_renderers')
class RendererEndToEndTests(TestCase): """
"""
End-to-end testing of renderers using an RendererMixin on a generic view. End-to-end testing of renderers using an RendererMixin on a generic view.
""" """
def test_default_renderer_serializes_content(self): deftest_default_renderer_serializes_content(self):
"""If the Accept header is not set the default renderer should serialize the response.""" """If the Accept header is not set the default renderer should serialize the response."""
resp = self.client.get('/') resp = self.client.get('/')
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
def test_head_method_serializes_no_content(self): def test_head_method_serializes_no_content(self):
"""No response must be included in HEAD requests.""" """No response must be included in HEAD requests."""
resp = self.client.head('/') resp = self.client.head('/')
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
self.assertEqual(resp.content, b'') assert resp.content == b''
def test_default_renderer_serializes_content_on_accept_any(self): def test_default_renderer_serializes_content_on_accept_any(self):
"""If the Accept header is set to */* the default renderer should serialize the response.""" """If the Accept header is set to */* the default renderer should serialize the response."""
resp = self.client.get('/', HTTP_ACCEPT='*/*') resp = self.client.get('/', HTTP_ACCEPT='*/*')
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_serializes_content_default_case(self): def test_specified_renderer_serializes_content_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response. """If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for the default renderer)""" (In this case we check that works for the default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type) resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_serializes_content_non_default_case(self): def test_specified_renderer_serializes_content_non_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response. """If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for a non-default renderer)""" (In this case we check that works for a non-default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type) resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
def test_unsatisfiable_accept_header_on_request_returns_406_status(self): def test_unsatisfiable_accept_header_on_request_returns_406_status(self):
"""If the Accept header is unsatisfiable we should return a 406 Not Acceptable response.""" """If the Accept header is unsatisfiable we should return a 406 Not Acceptable response."""
resp = self.client.get('/', HTTP_ACCEPT='foo/bar') resp = self.client.get('/', HTTP_ACCEPT='foo/bar')
self.assertEqual(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE) assert resp.status_code == status.HTTP_406_NOT_ACCEPTABLE
def test_specified_renderer_serializes_content_on_format_query(self): def test_specified_renderer_serializes_content_on_format_query(self):
"""If a 'format' query is specified, the renderer with the matching """If a 'format' query is specified, the renderer with the matching
format attribute should serialize the response.""" format attribute should serialize the response."""
param = '?%s=%s' % ( param = '?%s=%s' % ( api_settings.URL_FORMAT_OVERRIDE, RendererB.format )
api_settings.URL_FORMAT_OVERRIDE,
RendererB.format
)
resp = self.client.get('/' + param) resp = self.client.get('/' + param)
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_serializes_content_on_format_kwargs(self): def test_specified_renderer_serializes_content_on_format_kwargs(self):
"""If a 'format' keyword arg is specified, the renderer with the matching """If a 'format' keyword arg is specified, the renderer with the matching
format attribute should serialize the response.""" format attribute should serialize the response."""
resp = self.client.get('/something.formatb') resp = self.client.get('/something.formatb')
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_is_used_on_format_query_with_matching_accept(self): def test_specified_renderer_is_used_on_format_query_with_matching_accept(self):
"""If both a 'format' query and a matching Accept header specified, """If both a 'format' query and a matching Accept header specified,
the renderer with the matching format attribute should serialize the response.""" the renderer with the matching format attribute should serialize the response."""
param = '?%s=%s' % ( param = '?%s=%s' % ( api_settings.URL_FORMAT_OVERRIDE, RendererB.format )
api_settings.URL_FORMAT_OVERRIDE, resp = self.client.get('/' + param, HTTP_ACCEPT=RendererB.media_type)
RendererB.format assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
) assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
resp = self.client.get('/' + param, assert resp.status_code == DUMMYSTATUS
HTTP_ACCEPT=RendererB.media_type)
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_parse_error_renderers_browsable_api(self): def test_parse_error_renderers_browsable_api(self):
"""Invalid data should still render the browsable API correctly.""" """Invalid data should still render the browsable API correctly."""
resp = self.client.post('/parseerror', data='foobar', content_type='application/json', HTTP_ACCEPT='text/html') resp = self.client.post('/parseerror', data='foobar', content_type='application/json', HTTP_ACCEPT='text/html')
self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') assert resp['Content-Type'] == 'text/html; charset=utf-8'
self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST) assert resp.status_code == status.HTTP_400_BAD_REQUEST
def test_204_no_content_responses_have_no_content_type_set(self): def test_204_no_content_responses_have_no_content_type_set(self):
""" """
@ -247,8 +237,8 @@ class RendererEndToEndTests(TestCase):
https://github.com/encode/django-rest-framework/issues/1196 https://github.com/encode/django-rest-framework/issues/1196
""" """
resp = self.client.get('/empty') resp = self.client.get('/empty')
self.assertEqual(resp.get('Content-Type', None), None) assert resp.get('Content-Type', None) == None
self.assertEqual(resp.status_code, status.HTTP_204_NO_CONTENT) assert resp.status_code == status.HTTP_204_NO_CONTENT
def test_contains_headers_of_api_response(self): def test_contains_headers_of_api_response(self):
""" """
@ -275,11 +265,10 @@ def strip_trailing_whitespace(content):
return re.sub(' +\n', '\n', content) return re.sub(' +\n', '\n', content)
class BaseRendererTests(TestCase): """
"""
Tests BaseRenderer Tests BaseRenderer
""" """
def test_render_raise_error(self): deftest_render_raise_error(self):
""" """
BaseRenderer.render should raise NotImplementedError BaseRenderer.render should raise NotImplementedError
""" """
@ -287,31 +276,29 @@ class BaseRendererTests(TestCase):
BaseRenderer().render('test') BaseRenderer().render('test')
class JSONRendererTests(TestCase): """
"""
Tests specific to the JSON Renderer Tests specific to the JSON Renderer
""" """
deftest_render_lazy_strings(self):
def test_render_lazy_strings(self):
""" """
JSONRenderer should deal with lazy translated strings. JSONRenderer should deal with lazy translated strings.
""" """
ret = JSONRenderer().render(_('test')) ret = JSONRenderer().render(_('test'))
self.assertEqual(ret, b'"test"') assert ret == b'"test"'
def test_render_queryset_values(self): def test_render_queryset_values(self):
o = DummyTestModel.objects.create(name='dummy') o = DummyTestModel.objects.create(name='dummy')
qs = DummyTestModel.objects.values('id', 'name') qs = DummyTestModel.objects.values('id', 'name')
ret = JSONRenderer().render(qs) ret = JSONRenderer().render(qs)
data = json.loads(ret.decode()) data = json.loads(ret.decode())
self.assertEqual(data, [{'id': o.id, 'name': o.name}]) assert data == [{'id': o.id, 'name': o.name}]
def test_render_queryset_values_list(self): def test_render_queryset_values_list(self):
o = DummyTestModel.objects.create(name='dummy') o = DummyTestModel.objects.create(name='dummy')
qs = DummyTestModel.objects.values_list('id', 'name') qs = DummyTestModel.objects.values_list('id', 'name')
ret = JSONRenderer().render(qs) ret = JSONRenderer().render(qs)
data = json.loads(ret.decode()) data = json.loads(ret.decode())
self.assertEqual(data, [[o.id, o.name]]) assert data == [[o.id, o.name]]
def test_render_dict_abc_obj(self): def test_render_dict_abc_obj(self):
class Dict(MutableMapping): class Dict(MutableMapping):
@ -337,11 +324,11 @@ class JSONRendererTests(TestCase):
return self._dict.keys() return self._dict.keys()
x = Dict() x = Dict()
x['key'] = 'string value' x['key']='string value'
x[2] = 3 x[2]=3
ret = JSONRenderer().render(x) ret=JSONRenderer().render(x)
data = json.loads(ret.decode()) data=json.loads(ret.decode())
self.assertEqual(data, {'key': 'string value', '2': 3}) assertdata=={'key':'string value','2':3}
def test_render_obj_with_getitem(self): def test_render_obj_with_getitem(self):
class DictLike: class DictLike:
@ -355,22 +342,20 @@ class JSONRendererTests(TestCase):
return self._dict[key] return self._dict[key]
x = DictLike() x = DictLike()
x.set({'a': 1, 'b': 'string'}) x.set({'a':1,'b':'string'})
with self.assertRaises(TypeError): withpytest.raises(TypeError):
JSONRenderer().render(x) JSONRenderer().render(x)
def test_float_strictness(self): def test_float_strictness(self):
renderer = JSONRenderer() renderer = JSONRenderer()
# Default to strict
for value in [float('inf'), float('-inf'), float('nan')]: for value in [float('inf'), float('-inf'), float('nan')]:
with pytest.raises(ValueError): with pytest.raises(ValueError):
renderer.render(value) renderer.render(value)
renderer.strict = False renderer.strict = False
assert renderer.render(float('inf')) == b'Infinity' assertrenderer.render(float('inf'))==b'Infinity'
assert renderer.render(float('-inf')) == b'-Infinity' assertrenderer.render(float('-inf'))==b'-Infinity'
assert renderer.render(float('nan')) == b'NaN' assertrenderer.render(float('nan'))==b'NaN'
def test_without_content_type_args(self): def test_without_content_type_args(self):
""" """
@ -379,8 +364,7 @@ class JSONRendererTests(TestCase):
obj = {'foo': ['bar', 'baz']} obj = {'foo': ['bar', 'baz']}
renderer = JSONRenderer() renderer = JSONRenderer()
content = renderer.render(obj, 'application/json') content = renderer.render(obj, 'application/json')
# Fix failing test case which depends on version of JSON library. assert content.decode() == _flat_repr
self.assertEqual(content.decode(), _flat_repr)
def test_with_content_type_args(self): def test_with_content_type_args(self):
""" """
@ -389,18 +373,17 @@ class JSONRendererTests(TestCase):
obj = {'foo': ['bar', 'baz']} obj = {'foo': ['bar', 'baz']}
renderer = JSONRenderer() renderer = JSONRenderer()
content = renderer.render(obj, 'application/json; indent=2') content = renderer.render(obj, 'application/json; indent=2')
self.assertEqual(strip_trailing_whitespace(content.decode()), _indented_repr) assert strip_trailing_whitespace(content.decode()) == _indented_repr
class UnicodeJSONRendererTests(TestCase): """
"""
Tests specific for the Unicode JSON Renderer Tests specific for the Unicode JSON Renderer
""" """
def test_proper_encoding(self): deftest_proper_encoding(self):
obj = {'countries': ['United Kingdom', 'France', 'España']} obj = {'countries': ['United Kingdom', 'France', 'España']}
renderer = JSONRenderer() renderer = JSONRenderer()
content = renderer.render(obj, 'application/json') content = renderer.render(obj, 'application/json')
self.assertEqual(content, '{"countries":["United Kingdom","France","España"]}'.encode()) assert content == '{"countries":["United Kingdom","France","España"]}'.encode()
def test_u2028_u2029(self): def test_u2028_u2029(self):
# The \u2028 and \u2029 characters should be escaped, # The \u2028 and \u2029 characters should be escaped,
@ -409,29 +392,27 @@ class UnicodeJSONRendererTests(TestCase):
obj = {'should_escape': '\u2028\u2029'} obj = {'should_escape': '\u2028\u2029'}
renderer = JSONRenderer() renderer = JSONRenderer()
content = renderer.render(obj, 'application/json') content = renderer.render(obj, 'application/json')
self.assertEqual(content, '{"should_escape":"\\u2028\\u2029"}'.encode()) assert content == '{"should_escape":"\\u2028\\u2029"}'.encode()
class AsciiJSONRendererTests(TestCase): """
"""
Tests specific for the Unicode JSON Renderer Tests specific for the Unicode JSON Renderer
""" """
def test_proper_encoding(self): deftest_proper_encoding(self):
class AsciiJSONRenderer(JSONRenderer): class AsciiJSONRenderer(JSONRenderer):
ensure_ascii = True ensure_ascii = True
obj = {'countries': ['United Kingdom', 'France', 'España']} obj = {'countries': ['United Kingdom', 'France', 'España']}
renderer = AsciiJSONRenderer() renderer=AsciiJSONRenderer()
content = renderer.render(obj, 'application/json') content=renderer.render(obj,'application/json')
self.assertEqual(content, '{"countries":["United Kingdom","France","Espa\\u00f1a"]}'.encode()) assertcontent=='{"countries":["United Kingdom","France","Espa\\u00f1a"]}'.encode()
# Tests for caching issue, #346 # Tests for caching issue, #346
@override_settings(ROOT_URLCONF='tests.test_renderers') @override_settings(ROOT_URLCONF='tests.test_renderers')
class CacheRenderTest(TestCase): """
"""
Tests specific to caching responses Tests specific to caching responses
""" """
def test_head_caching(self): deftest_head_caching(self):
""" """
Test caching of HEAD requests Test caching of HEAD requests
""" """
@ -476,136 +457,105 @@ class TestJSONIndentationStyles:
assert renderer.render(data) == b'{"a": 1, "b": 2}' assert renderer.render(data) == b'{"a": 1, "b": 2}'
class TestHiddenFieldHTMLFormRenderer(TestCase): def test_hidden_field_rendering(self):
def test_hidden_field_rendering(self):
class TestSerializer(serializers.Serializer): class TestSerializer(serializers.Serializer):
published = serializers.HiddenField(default=True) published = serializers.HiddenField(default=True)
serializer = TestSerializer(data={}) serializer = TestSerializer(data={})
serializer.is_valid() serializer.is_valid()
renderer = HTMLFormRenderer() renderer=HTMLFormRenderer()
field = serializer['published'] field=serializer['published']
rendered = renderer.render_field(field, {}) rendered=renderer.render_field(field,{})
assert rendered == '' assertrendered==''
class TestHTMLFormRenderer(TestCase): def setUp(self):
def setUp(self):
class TestSerializer(serializers.Serializer): class TestSerializer(serializers.Serializer):
test_field = serializers.CharField() test_field = serializers.CharField()
self.renderer = HTMLFormRenderer() self.renderer = HTMLFormRenderer()
self.serializer = TestSerializer(data={}) self.serializer=TestSerializer(data={})
def test_render_with_default_args(self): def test_render_with_default_args(self):
self.serializer.is_valid() self.serializer.is_valid()
renderer = HTMLFormRenderer() renderer = HTMLFormRenderer()
result = renderer.render(self.serializer.data) result = renderer.render(self.serializer.data)
assert isinstance(result, SafeText)
self.assertIsInstance(result, SafeText)
def test_render_with_provided_args(self): def test_render_with_provided_args(self):
self.serializer.is_valid() self.serializer.is_valid()
renderer = HTMLFormRenderer() renderer = HTMLFormRenderer()
result = renderer.render(self.serializer.data, None, {}) result = renderer.render(self.serializer.data, None, {})
assert isinstance(result, SafeText)
self.assertIsInstance(result, SafeText)
class TestChoiceFieldHTMLFormRenderer(TestCase): """
"""
Test rendering ChoiceField with HTMLFormRenderer. Test rendering ChoiceField with HTMLFormRenderer.
""" """
defsetUp(self):
def setUp(self):
choices = ((1, 'Option1'), (2, 'Option2'), (12, 'Option12')) choices = ((1, 'Option1'), (2, 'Option2'), (12, 'Option12'))
class TestSerializer(serializers.Serializer): class TestSerializer(serializers.Serializer):
test_field = serializers.ChoiceField(choices=choices, test_field = serializers.ChoiceField(choices=choices, initial=2)
initial=2)
self.TestSerializer = TestSerializer self.TestSerializer = TestSerializer
self.renderer = HTMLFormRenderer() self.renderer=HTMLFormRenderer()
def test_render_initial_option(self): def test_render_initial_option(self):
serializer = self.TestSerializer() serializer = self.TestSerializer()
result = self.renderer.render(serializer.data) result = self.renderer.render(serializer.data)
assert isinstance(result, SafeText)
self.assertIsInstance(result, SafeText) self.assertInHTML('<option value="2" selected>Option2</option>', result)
self.assertInHTML('<option value="2" selected>Option2</option>',
result)
self.assertInHTML('<option value="1">Option1</option>', result) self.assertInHTML('<option value="1">Option1</option>', result)
self.assertInHTML('<option value="12">Option12</option>', result) self.assertInHTML('<option value="12">Option12</option>', result)
def test_render_selected_option(self): def test_render_selected_option(self):
serializer = self.TestSerializer(data={'test_field': '12'}) serializer = self.TestSerializer(data={'test_field': '12'})
serializer.is_valid() serializer.is_valid()
result = self.renderer.render(serializer.data) result = self.renderer.render(serializer.data)
assert isinstance(result, SafeText)
self.assertIsInstance(result, SafeText) self.assertInHTML('<option value="12" selected>Option12</option>', result)
self.assertInHTML('<option value="12" selected>Option12</option>',
result)
self.assertInHTML('<option value="1">Option1</option>', result) self.assertInHTML('<option value="1">Option1</option>', result)
self.assertInHTML('<option value="2">Option2</option>', result) self.assertInHTML('<option value="2">Option2</option>', result)
class TestMultipleChoiceFieldHTMLFormRenderer(TestCase): """
"""
Test rendering MultipleChoiceField with HTMLFormRenderer. Test rendering MultipleChoiceField with HTMLFormRenderer.
""" """
defsetUp(self):
def setUp(self):
self.renderer = HTMLFormRenderer() self.renderer = HTMLFormRenderer()
def test_render_selected_option_with_string_option_ids(self): def test_render_selected_option_with_string_option_ids(self):
choices = (('1', 'Option1'), ('2', 'Option2'), ('12', 'Option12'), choices = (('1', 'Option1'), ('2', 'Option2'), ('12', 'Option12'), ('}', 'OptionBrace'))
('}', 'OptionBrace'))
class TestSerializer(serializers.Serializer): class TestSerializer(serializers.Serializer):
test_field = serializers.MultipleChoiceField(choices=choices) test_field = serializers.MultipleChoiceField(choices=choices)
serializer = TestSerializer(data={'test_field': ['12']}) serializer = TestSerializer(data={'test_field': ['12']})
serializer.is_valid() serializer.is_valid()
result=self.renderer.render(serializer.data)
result = self.renderer.render(serializer.data) assertisinstance(result, SafeText)
self.assertInHTML('<option value="12" selected>Option12</option>',result)
self.assertIsInstance(result, SafeText) self.assertInHTML('<option value="1">Option1</option>',result)
self.assertInHTML('<option value="2">Option2</option>',result)
self.assertInHTML('<option value="12" selected>Option12</option>', self.assertInHTML('<option value="}">OptionBrace</option>',result)
result)
self.assertInHTML('<option value="1">Option1</option>', result)
self.assertInHTML('<option value="2">Option2</option>', result)
self.assertInHTML('<option value="}">OptionBrace</option>', result)
def test_render_selected_option_with_integer_option_ids(self): def test_render_selected_option_with_integer_option_ids(self):
choices = ((1, 'Option1'), (2, 'Option2'), (12, 'Option12')) choices = ((1, 'Option1'), (2, 'Option2'), (12, 'Option12'))
class TestSerializer(serializers.Serializer): class TestSerializer(serializers.Serializer):
test_field = serializers.MultipleChoiceField(choices=choices) test_field = serializers.MultipleChoiceField(choices=choices)
serializer = TestSerializer(data={'test_field': ['12']}) serializer = TestSerializer(data={'test_field': ['12']})
serializer.is_valid() serializer.is_valid()
result=self.renderer.render(serializer.data)
result = self.renderer.render(serializer.data) assertisinstance(result, SafeText)
self.assertInHTML('<option value="12" selected>Option12</option>',result)
self.assertIsInstance(result, SafeText) self.assertInHTML('<option value="1">Option1</option>',result)
self.assertInHTML('<option value="2">Option2</option>',result)
self.assertInHTML('<option value="12" selected>Option12</option>',
result)
self.assertInHTML('<option value="1">Option1</option>', result)
self.assertInHTML('<option value="2">Option2</option>', result)
class StaticHTMLRendererTests(TestCase): """
"""
Tests specific for Static HTML Renderer Tests specific for Static HTML Renderer
""" """
def setUp(self): defsetUp(self):
self.renderer = StaticHTMLRenderer() self.renderer = StaticHTMLRenderer()
def test_static_renderer(self): def test_static_renderer(self):
@ -614,10 +564,7 @@ class StaticHTMLRendererTests(TestCase):
assert result == data assert result == data
def test_static_renderer_with_exception(self): def test_static_renderer_with_exception(self):
context = { context = { 'response': Response(status=500, exception=True), 'request': Request(HttpRequest()) }
'response': Response(status=500, exception=True),
'request': Request(HttpRequest())
}
result = self.renderer.render({}, renderer_context=context) result = self.renderer.render({}, renderer_context=context)
assert result == '500 Internal Server Error' assert result == '500 Internal Server Error'
@ -658,196 +605,142 @@ class BrowsableAPIRendererTests(URLPatternsTestCase):
assert '>Extra list action<' in resp.content.decode() assert '>Extra list action<' in resp.content.decode()
class AdminRendererTests(TestCase):
def setUp(self): def setUp(self):
self.renderer = AdminRenderer() self.renderer = AdminRenderer()
def test_render_when_resource_created(self): def test_render_when_resource_created(self):
class DummyView(APIView): class DummyView(APIView):
renderer_classes = (AdminRenderer, ) renderer_classes = (AdminRenderer, )
request = Request(HttpRequest()) request = Request(HttpRequest())
request.build_absolute_uri = lambda: 'http://example.com' request.build_absolute_uri=lambda:'http://example.com'
response = Response(status=201, headers={'Location': '/test'}) response=Response(status=201,headers={'Location':'/test'})
context = { context={'view':DummyView(),'request':request,'response':response}
'view': DummyView(), result=self.renderer.render(data={'test':'test'},renderer_context=context)
'request': request, assertresult==''
'response': response assertresponse.status_code==status.HTTP_303_SEE_OTHER
} assertresponse['Location']=='http://example.com'
result = self.renderer.render(data={'test': 'test'},
renderer_context=context)
assert result == ''
assert response.status_code == status.HTTP_303_SEE_OTHER
assert response['Location'] == 'http://example.com'
def test_render_dict(self): def test_render_dict(self):
factory = APIRequestFactory() factory = APIRequestFactory()
class DummyView(APIView): class DummyView(APIView):
renderer_classes = (AdminRenderer, ) renderer_classes = (AdminRenderer, )
def get(self, request): def get(self, request):
return Response({'foo': 'a string'}) return Response({'foo': 'a string'})
view = DummyView.as_view() view = DummyView.as_view()
request = factory.get('/') request=factory.get('/')
response = view(request) response=view(request)
response.render() response.render()
self.assertContains(response, '<tr><th>Foo</th><td>a string</td></tr>', html=True) self.assertContains(response,'<tr><th>Foo</th><td>a string</td></tr>',html=True)
def test_render_dict_with_items_key(self): def test_render_dict_with_items_key(self):
factory = APIRequestFactory() factory = APIRequestFactory()
class DummyView(APIView): class DummyView(APIView):
renderer_classes = (AdminRenderer, ) renderer_classes = (AdminRenderer, )
def get(self, request): def get(self, request):
return Response({'items': 'a string'}) return Response({'items': 'a string'})
view = DummyView.as_view() view = DummyView.as_view()
request = factory.get('/') request=factory.get('/')
response = view(request) response=view(request)
response.render() response.render()
self.assertContains(response, '<tr><th>Items</th><td>a string</td></tr>', html=True) self.assertContains(response,'<tr><th>Items</th><td>a string</td></tr>',html=True)
def test_render_dict_with_iteritems_key(self): def test_render_dict_with_iteritems_key(self):
factory = APIRequestFactory() factory = APIRequestFactory()
class DummyView(APIView): class DummyView(APIView):
renderer_classes = (AdminRenderer, ) renderer_classes = (AdminRenderer, )
def get(self, request): def get(self, request):
return Response({'iteritems': 'a string'}) return Response({'iteritems': 'a string'})
view = DummyView.as_view() view = DummyView.as_view()
request = factory.get('/') request=factory.get('/')
response = view(request) response=view(request)
response.render() response.render()
self.assertContains(response, '<tr><th>Iteritems</th><td>a string</td></tr>', html=True) self.assertContains(response,'<tr><th>Iteritems</th><td>a string</td></tr>',html=True)
def test_get_result_url(self): def test_get_result_url(self):
factory = APIRequestFactory() factory = APIRequestFactory()
class DummyGenericViewsetLike(APIView): class DummyGenericViewsetLike(APIView):
lookup_field = 'test' lookup_field = 'test'
def reverse_action(view, *args, **kwargs): def reverse_action(view, *args, **kwargs):
self.assertEqual(kwargs['kwargs']['test'], 1) assert kwargs['kwargs']['test'] == 1
return '/example/' return '/example/'
# get the view instance instead of the view function # get the view instance instead of the view function
view = DummyGenericViewsetLike.as_view() view = DummyGenericViewsetLike.as_view()
request = factory.get('/') request=factory.get('/')
response = view(request) response=view(request)
view = response.renderer_context['view'] view=response.renderer_context['view']
assertself.renderer.get_result_url({'test':1},view)=='/example/'
self.assertEqual(self.renderer.get_result_url({'test': 1}, view), '/example/') assertself.renderer.get_result_url({},view)is None
self.assertIsNone(self.renderer.get_result_url({}, view))
def test_get_result_url_no_result(self): def test_get_result_url_no_result(self):
factory = APIRequestFactory() factory = APIRequestFactory()
class DummyView(APIView): class DummyView(APIView):
lookup_field = 'test' lookup_field = 'test'
# get the view instance instead of the view function # get the view instance instead of the view function
view = DummyView.as_view() view = DummyView.as_view()
request = factory.get('/') request=factory.get('/')
response = view(request) response=view(request)
view = response.renderer_context['view'] view=response.renderer_context['view']
assertself.renderer.get_result_url({'test':1},view)is None
self.assertIsNone(self.renderer.get_result_url({'test': 1}, view)) assertself.renderer.get_result_url({},view)is None
self.assertIsNone(self.renderer.get_result_url({}, view))
def test_get_context_result_urls(self): def test_get_context_result_urls(self):
factory = APIRequestFactory() factory = APIRequestFactory()
class DummyView(APIView): class DummyView(APIView):
lookup_field = 'test' lookup_field = 'test'
def reverse_action(view, url_name, args=None, kwargs=None): def reverse_action(view, url_name, args=None, kwargs=None):
return '/%s/%d' % (url_name, kwargs['test']) return '/%s/%d' % (url_name, kwargs['test'])
# get the view instance instead of the view function # get the view instance instead of the view function
view = DummyView.as_view() view = DummyView.as_view()
request = factory.get('/') request=factory.get('/')
response = view(request) response=view(request)
data=[{'test':1},{'url':'/example','test':2},{'url':None,'test':3},{},]
data = [ context={'view':DummyView(),'request':Request(request),'response':response}
{'test': 1}, context=self.renderer.get_context(data,None,context)
{'url': '/example', 'test': 2}, results=context['results']
{'url': None, 'test': 3}, assertlen(results)==4
{}, assertresults[0]['url']=='/detail/1'
] assertresults[1]['url']=='/example'
context = { assertresults[2]['url']==None
'view': DummyView(), assert'url'not inresults[3]
'request': Request(request),
'response': response
}
context = self.renderer.get_context(data, None, context)
results = context['results']
self.assertEqual(len(results), 4)
self.assertEqual(results[0]['url'], '/detail/1')
self.assertEqual(results[1]['url'], '/example')
self.assertEqual(results[2]['url'], None)
self.assertNotIn('url', results[3])
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed') @pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
class TestDocumentationRenderer(TestCase):
def test_document_with_link_named_data(self): def test_document_with_link_named_data(self):
""" """
Ref #5395: Doc's `document.data` would fail with a Link named "data". Ref #5395: Doc's `document.data` would fail with a Link named "data".
As per #4972, use templatetag instead. As per #4972, use templatetag instead.
""" """
document = coreapi.Document( document = coreapi.Document( title='Data Endpoint API', url='https://api.example.org/', content={ 'data': coreapi.Link( url='/data/', action='get', fields=[], description='Return data.' ) } )
title='Data Endpoint API',
url='https://api.example.org/',
content={
'data': coreapi.Link(
url='/data/',
action='get',
fields=[],
description='Return data.'
)
}
)
factory = APIRequestFactory() factory = APIRequestFactory()
request = factory.get('/') request = factory.get('/')
renderer = DocumentationRenderer() renderer = DocumentationRenderer()
html = renderer.render(document, accepted_media_type="text/html", renderer_context={"request": request}) html = renderer.render(document, accepted_media_type="text/html", renderer_context={"request": request})
assert '<h1>Data Endpoint API</h1>' in html assert '<h1>Data Endpoint API</h1>' in html
def test_shell_code_example_rendering(self): def test_shell_code_example_rendering(self):
template = loader.get_template('rest_framework/docs/langs/shell.html') template = loader.get_template('rest_framework/docs/langs/shell.html')
context = { context = { 'document': coreapi.Document(url='https://api.example.org/'), 'link_key': 'testcases > list', 'link': coreapi.Link(url='/data/', action='get', fields=[]), }
'document': coreapi.Document(url='https://api.example.org/'),
'link_key': 'testcases > list',
'link': coreapi.Link(url='/data/', action='get', fields=[]),
}
html = template.render(context) html = template.render(context)
assert 'testcases list' in html assert 'testcases list' in html
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed') @pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
class TestSchemaJSRenderer(TestCase):
def test_schemajs_output(self): def test_schemajs_output(self):
""" """
Test output of the SchemaJS renderer as per #5608. Django 2.0 on Py3 prints binary data as b'xyz' in templates, Test output of the SchemaJS renderer as per #5608. Django 2.0 on Py3 prints binary data as b'xyz' in templates,
and the base64 encoding used by SchemaJSRenderer outputs base64 as binary. Test fix. and the base64 encoding used by SchemaJSRenderer outputs base64 as binary. Test fix.
""" """
factory = APIRequestFactory() factory = APIRequestFactory()
request = factory.get('/') request = factory.get('/')
renderer = SchemaJSRenderer() renderer = SchemaJSRenderer()
output = renderer.render('data', renderer_context={"request": request}) output = renderer.render('data', renderer_context={"request": request})
assert "'ImRhdGEi'" in output assert "'ImRhdGEi'" in output
assert "'b'ImRhdGEi''" not in output assert "'b'ImRhdGEi''" not in output

View File

@ -23,16 +23,9 @@ from rest_framework.test import APIClient, APIRequestFactory
from rest_framework.views import APIView from rest_framework.views import APIView
factory = APIRequestFactory() factory = APIRequestFactory()
def test_request_type(self):
class TestInitializer(TestCase):
def test_request_type(self):
request = Request(factory.get('/')) request = Request(factory.get('/'))
message = ( 'The `request` argument must be an instance of ' '`django.http.HttpRequest`, not `rest_framework.request.Request`.' )
message = (
'The `request` argument must be an instance of '
'`django.http.HttpRequest`, not `rest_framework.request.Request`.'
)
with self.assertRaisesMessage(AssertionError, message): with self.assertRaisesMessage(AssertionError, message):
Request(request) Request(request)
@ -50,8 +43,7 @@ class PlainTextParser(BaseParser):
return stream.read() return stream.read()
class TestContentParsing(TestCase): def test_standard_behaviour_determines_no_content_GET(self):
def test_standard_behaviour_determines_no_content_GET(self):
""" """
Ensure request.data returns empty QueryDict for GET request. Ensure request.data returns empty QueryDict for GET request.
""" """
@ -160,22 +152,19 @@ urlpatterns = [
@override_settings( @override_settings(
ROOT_URLCONF='tests.test_request', ROOT_URLCONF='tests.test_request',
FILE_UPLOAD_HANDLERS=['django.core.files.uploadhandler.TemporaryFileUploadHandler']) FILE_UPLOAD_HANDLERS=['django.core.files.uploadhandler.TemporaryFileUploadHandler'])
class FileUploadTests(TestCase):
def test_fileuploads_closed_at_request_end(self): def test_fileuploads_closed_at_request_end(self):
with tempfile.NamedTemporaryFile() as f: with tempfile.NamedTemporaryFile() as f:
response = self.client.post('/upload/', {'file': f}) response = self.client.post('/upload/', {'file': f})
# sanity check that file was processed # sanity check that file was processed
assert len(response.data) == 1 assert len(response.data) == 1
forfileinresponse.data:
for file in response.data:
assert not os.path.exists(file) assert not os.path.exists(file)
@override_settings(ROOT_URLCONF='tests.test_request') @override_settings(ROOT_URLCONF='tests.test_request')
class TestContentParsingWithAuthentication(TestCase): def setUp(self):
def setUp(self):
self.csrf_client = APIClient(enforce_csrf_checks=True) self.csrf_client = APIClient(enforce_csrf_checks=True)
self.username = 'john' self.username = 'john'
self.email = 'lennon@thebeatles.com' self.email = 'lennon@thebeatles.com'
@ -188,24 +177,20 @@ class TestContentParsingWithAuthentication(TestCase):
doesn't log in. doesn't log in.
""" """
content = {'example': 'example'} content = {'example': 'example'}
response = self.client.post('/', content) response = self.client.post('/', content)
assert status.HTTP_200_OK == response.status_code assert status.HTTP_200_OK == response.status_code
response = self.csrf_client.post('/', content) response = self.csrf_client.post('/', content)
assert status.HTTP_200_OK == response.status_code assert status.HTTP_200_OK == response.status_code
class TestUserSetter(TestCase):
def setUp(self): def setUp(self):
# Pass request object through session middleware so session is # Pass request object through session middleware so session is
# available to login and logout functions # available to login and logout functions
self.wrapped_request = factory.get('/') self.wrapped_request = factory.get('/')
self.request = Request(self.wrapped_request) self.request = Request(self.wrapped_request)
SessionMiddleware().process_request(self.wrapped_request) SessionMiddleware().process_request(self.wrapped_request)
AuthenticationMiddleware().process_request(self.wrapped_request) AuthenticationMiddleware().process_request(self.wrapped_request)
User.objects.create_user('ringo', 'starr@thebeatles.com', 'yellow') User.objects.create_user('ringo', 'starr@thebeatles.com', 'yellow')
self.user = authenticate(username='ringo', password='yellow') self.user = authenticate(username='ringo', password='yellow')
@ -237,13 +222,9 @@ class TestUserSetter(TestCase):
self.MISSPELLED_NAME_THAT_DOESNT_EXIST self.MISSPELLED_NAME_THAT_DOESNT_EXIST
request = Request(self.wrapped_request, authenticators=(AuthRaisesAttributeError(),)) request = Request(self.wrapped_request, authenticators=(AuthRaisesAttributeError(),))
assertself.wrapped_request.user.is_anonymous
# The middleware processes the underlying Django request, sets anonymous user expected=r"no attribute 'MISSPELLED_NAME_THAT_DOESNT_EXIST'"
assert self.wrapped_request.user.is_anonymous withpytest.raises(WrappedAttributeError,match=expected):
# The DRF request object does not have a user and should run authenticators
expected = r"no attribute 'MISSPELLED_NAME_THAT_DOESNT_EXIST'"
with pytest.raises(WrappedAttributeError, match=expected):
request.user request.user
with pytest.raises(WrappedAttributeError, match=expected): with pytest.raises(WrappedAttributeError, match=expected):
@ -253,16 +234,14 @@ class TestUserSetter(TestCase):
login(request, self.user) login(request, self.user)
class TestAuthSetter(TestCase): def test_auth_can_be_set(self):
def test_auth_can_be_set(self):
request = Request(factory.get('/')) request = Request(factory.get('/'))
request.auth = 'DUMMY' request.auth = 'DUMMY'
assert request.auth == 'DUMMY' assert request.auth == 'DUMMY'
class TestSecure(TestCase):
def test_default_secure_false(self): def test_default_secure_false(self):
request = Request(factory.get('/', secure=False)) request = Request(factory.get('/', secure=False))
assert request.scheme == 'http' assert request.scheme == 'http'
@ -271,15 +250,12 @@ class TestSecure(TestCase):
assert request.scheme == 'https' assert request.scheme == 'https'
class TestHttpRequest(TestCase): def test_attribute_access_proxy(self):
def test_attribute_access_proxy(self):
http_request = factory.get('/') http_request = factory.get('/')
request = Request(http_request) request = Request(http_request)
inner_sentinel = object() inner_sentinel = object()
http_request.inner_property = inner_sentinel http_request.inner_property = inner_sentinel
assert request.inner_property is inner_sentinel assert request.inner_property is inner_sentinel
outer_sentinel = object() outer_sentinel = object()
request.inner_property = outer_sentinel request.inner_property = outer_sentinel
assert request.inner_property is outer_sentinel assert request.inner_property is outer_sentinel
@ -288,30 +264,25 @@ class TestHttpRequest(TestCase):
# ensure the exception message is not for the underlying WSGIRequest # ensure the exception message is not for the underlying WSGIRequest
http_request = factory.get('/') http_request = factory.get('/')
request = Request(http_request) request = Request(http_request)
message = "'Request' object has no attribute 'inner_property'" message = "'Request' object has no attribute 'inner_property'"
with self.assertRaisesMessage(AttributeError, message): with self.assertRaisesMessage(AttributeError, message):
request.inner_property request.inner_property
@override_settings(ROOT_URLCONF='tests.test_request') @override_settings(ROOT_URLCONF='tests.test_request')
def test_duplicate_request_stream_parsing_exception(self): deftest_duplicate_request_stream_parsing_exception(self):
""" """
Check assumption that duplicate stream parsing will result in a Check assumption that duplicate stream parsing will result in a
`RawPostDataException` being raised. `RawPostDataException` being raised.
""" """
response = APIClient().post('/echo/', data={'a': 'b'}, format='json') response = APIClient().post('/echo/', data={'a': 'b'}, format='json')
request = response.renderer_context['request'] request = response.renderer_context['request']
# ensure that request stream was consumed by json parser
assert request.content_type.startswith('application/json') assert request.content_type.startswith('application/json')
assert response.data == {'a': 'b'} assert response.data == {'a': 'b'}
# pass same HttpRequest to view, stream already consumed
with pytest.raises(RawPostDataException): with pytest.raises(RawPostDataException):
EchoView.as_view()(request._request) EchoView.as_view()(request._request)
@override_settings(ROOT_URLCONF='tests.test_request') @override_settings(ROOT_URLCONF='tests.test_request')
def test_duplicate_request_form_data_access(self): deftest_duplicate_request_form_data_access(self):
""" """
Form data is copied to the underlying django request for middleware Form data is copied to the underlying django request for middleware
and file closing reasons. Duplicate processing of a request with form and file closing reasons. Duplicate processing of a request with form
@ -320,15 +291,9 @@ class TestHttpRequest(TestCase):
""" """
response = APIClient().post('/echo/', data={'a': 'b'}) response = APIClient().post('/echo/', data={'a': 'b'})
request = response.renderer_context['request'] request = response.renderer_context['request']
# ensure that request stream was consumed by form parser
assert request.content_type.startswith('multipart/form-data') assert request.content_type.startswith('multipart/form-data')
assert response.data == {'a': ['b']} assert response.data == {'a': ['b']}
# pass same HttpRequest to view, form data set on underlying request
response = EchoView.as_view()(request._request) response = EchoView.as_view()(request._request)
request = response.renderer_context['request'] request = response.renderer_context['request']
# ensure that request stream was consumed by form parser
assert request.content_type.startswith('multipart/form-data') assert request.content_type.startswith('multipart/form-data')
assert response.data == {'a': ['b']} assert response.data == {'a': ['b']}

View File

@ -131,97 +131,90 @@ urlpatterns = [
# TODO: Clean tests bellow - remove duplicates with above, better unit testing, ... # TODO: Clean tests bellow - remove duplicates with above, better unit testing, ...
@override_settings(ROOT_URLCONF='tests.test_response') @override_settings(ROOT_URLCONF='tests.test_response')
class RendererIntegrationTests(TestCase): """
"""
End-to-end testing of renderers using an ResponseMixin on a generic view. End-to-end testing of renderers using an ResponseMixin on a generic view.
""" """
def test_default_renderer_serializes_content(self): deftest_default_renderer_serializes_content(self):
"""If the Accept header is not set the default renderer should serialize the response.""" """If the Accept header is not set the default renderer should serialize the response."""
resp = self.client.get('/') resp = self.client.get('/')
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
def test_head_method_serializes_no_content(self): def test_head_method_serializes_no_content(self):
"""No response must be included in HEAD requests.""" """No response must be included in HEAD requests."""
resp = self.client.head('/') resp = self.client.head('/')
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
self.assertEqual(resp.content, b'') assert resp.content == b''
def test_default_renderer_serializes_content_on_accept_any(self): def test_default_renderer_serializes_content_on_accept_any(self):
"""If the Accept header is set to */* the default renderer should serialize the response.""" """If the Accept header is set to */* the default renderer should serialize the response."""
resp = self.client.get('/', HTTP_ACCEPT='*/*') resp = self.client.get('/', HTTP_ACCEPT='*/*')
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_serializes_content_default_case(self): def test_specified_renderer_serializes_content_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response. """If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for the default renderer)""" (In this case we check that works for the default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type) resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_serializes_content_non_default_case(self): def test_specified_renderer_serializes_content_non_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response. """If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for a non-default renderer)""" (In this case we check that works for a non-default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type) resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_serializes_content_on_format_query(self): def test_specified_renderer_serializes_content_on_format_query(self):
"""If a 'format' query is specified, the renderer with the matching """If a 'format' query is specified, the renderer with the matching
format attribute should serialize the response.""" format attribute should serialize the response."""
resp = self.client.get('/?format=%s' % RendererB.format) resp = self.client.get('/?format=%s' % RendererB.format)
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_serializes_content_on_format_kwargs(self): def test_specified_renderer_serializes_content_on_format_kwargs(self):
"""If a 'format' keyword arg is specified, the renderer with the matching """If a 'format' keyword arg is specified, the renderer with the matching
format attribute should serialize the response.""" format attribute should serialize the response."""
resp = self.client.get('/something.formatb') resp = self.client.get('/something.formatb')
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_is_used_on_format_query_with_matching_accept(self): def test_specified_renderer_is_used_on_format_query_with_matching_accept(self):
"""If both a 'format' query and a matching Accept header specified, """If both a 'format' query and a matching Accept header specified,
the renderer with the matching format attribute should serialize the response.""" the renderer with the matching format attribute should serialize the response."""
resp = self.client.get('/?format=%s' % RendererB.format, resp = self.client.get('/?format=%s' % RendererB.format, HTTP_ACCEPT=RendererB.media_type)
HTTP_ACCEPT=RendererB.media_type) assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) assert resp.status_code == DUMMYSTATUS
self.assertEqual(resp.status_code, DUMMYSTATUS)
@override_settings(ROOT_URLCONF='tests.test_response') @override_settings(ROOT_URLCONF='tests.test_response')
class UnsupportedMediaTypeTests(TestCase): def test_should_allow_posting_json(self):
def test_should_allow_posting_json(self):
response = self.client.post('/json', data='{"test": 123}', content_type='application/json') response = self.client.post('/json', data='{"test": 123}', content_type='application/json')
assert response.status_code == 200
self.assertEqual(response.status_code, 200)
def test_should_not_allow_posting_xml(self): def test_should_not_allow_posting_xml(self):
response = self.client.post('/json', data='<test>123</test>', content_type='application/xml') response = self.client.post('/json', data='<test>123</test>', content_type='application/xml')
assert response.status_code == 415
self.assertEqual(response.status_code, 415)
def test_should_not_allow_posting_a_form(self): def test_should_not_allow_posting_a_form(self):
response = self.client.post('/json', data={'test': 123}) response = self.client.post('/json', data={'test': 123})
assert response.status_code == 415
self.assertEqual(response.status_code, 415)
@override_settings(ROOT_URLCONF='tests.test_response') @override_settings(ROOT_URLCONF='tests.test_response')
class Issue122Tests(TestCase): """
"""
Tests that covers #122. Tests that covers #122.
""" """
def test_only_html_renderer(self): deftest_only_html_renderer(self):
""" """
Test if no infinite recursion occurs. Test if no infinite recursion occurs.
""" """
@ -235,30 +228,28 @@ class Issue122Tests(TestCase):
@override_settings(ROOT_URLCONF='tests.test_response') @override_settings(ROOT_URLCONF='tests.test_response')
class Issue467Tests(TestCase): """
"""
Tests for #467 Tests for #467
""" """
def test_form_has_label_and_help_text(self): deftest_form_has_label_and_help_text(self):
resp = self.client.get('/html_new_model') resp = self.client.get('/html_new_model')
self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') assert resp['Content-Type'] == 'text/html; charset=utf-8'
# self.assertContains(resp, 'Text comes here') # self.assertContains(resp, 'Text comes here')
# self.assertContains(resp, 'Text description.') # self.assertContains(resp, 'Text description.')
@override_settings(ROOT_URLCONF='tests.test_response') @override_settings(ROOT_URLCONF='tests.test_response')
class Issue807Tests(TestCase): """
"""
Covers #807 Covers #807
""" """
def test_does_not_append_charset_by_default(self): deftest_does_not_append_charset_by_default(self):
""" """
Renderers don't include a charset unless set explicitly. Renderers don't include a charset unless set explicitly.
""" """
headers = {"HTTP_ACCEPT": RendererA.media_type} headers = {"HTTP_ACCEPT": RendererA.media_type}
resp = self.client.get('/', **headers) resp = self.client.get('/', **headers)
expected = "{}; charset={}".format(RendererA.media_type, 'utf-8') expected = "{}; charset={}".format(RendererA.media_type, 'utf-8')
self.assertEqual(expected, resp['Content-Type']) assert expected == resp['Content-Type']
def test_if_there_is_charset_specified_on_renderer_it_gets_appended(self): def test_if_there_is_charset_specified_on_renderer_it_gets_appended(self):
""" """
@ -268,7 +259,7 @@ class Issue807Tests(TestCase):
headers = {"HTTP_ACCEPT": RendererC.media_type} headers = {"HTTP_ACCEPT": RendererC.media_type}
resp = self.client.get('/', **headers) resp = self.client.get('/', **headers)
expected = "{}; charset={}".format(RendererC.media_type, RendererC.charset) expected = "{}; charset={}".format(RendererC.media_type, RendererC.charset)
self.assertEqual(expected, resp['Content-Type']) assert expected == resp['Content-Type']
def test_content_type_set_explicitly_on_response(self): def test_content_type_set_explicitly_on_response(self):
""" """
@ -276,10 +267,10 @@ class Issue807Tests(TestCase):
""" """
headers = {"HTTP_ACCEPT": RendererC.media_type} headers = {"HTTP_ACCEPT": RendererC.media_type}
resp = self.client.get('/setbyview', **headers) resp = self.client.get('/setbyview', **headers)
self.assertEqual('setbyview', resp['Content-Type']) assert 'setbyview' == resp['Content-Type']
def test_form_has_label_and_help_text(self): def test_form_has_label_and_help_text(self):
resp = self.client.get('/html_new_model') resp = self.client.get('/html_new_model')
self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') assert resp['Content-Type'] == 'text/html; charset=utf-8'
# self.assertContains(resp, 'Text comes here') # self.assertContains(resp, 'Text comes here')
# self.assertContains(resp, 'Text description.') # self.assertContains(resp, 'Text description.')

View File

@ -30,11 +30,10 @@ class MockVersioningScheme:
@override_settings(ROOT_URLCONF='tests.test_reverse') @override_settings(ROOT_URLCONF='tests.test_reverse')
class ReverseTests(TestCase): """
"""
Tests for fully qualified URLs when using `reverse`. Tests for fully qualified URLs when using `reverse`.
""" """
def test_reversed_urls_are_fully_qualified(self): deftest_reversed_urls_are_fully_qualified(self):
request = factory.get('/view') request = factory.get('/view')
url = reverse('view', request=request) url = reverse('view', request=request)
assert url == 'http://testserver/view' assert url == 'http://testserver/view'
@ -42,13 +41,11 @@ class ReverseTests(TestCase):
def test_reverse_with_versioning_scheme(self): def test_reverse_with_versioning_scheme(self):
request = factory.get('/view') request = factory.get('/view')
request.versioning_scheme = MockVersioningScheme() request.versioning_scheme = MockVersioningScheme()
url = reverse('view', request=request) url = reverse('view', request=request)
assert url == 'http://scheme-reversed/view' assert url == 'http://scheme-reversed/view'
def test_reverse_with_versioning_scheme_fallback_to_default_on_error(self): def test_reverse_with_versioning_scheme_fallback_to_default_on_error(self):
request = factory.get('/view') request = factory.get('/view')
request.versioning_scheme = MockVersioningScheme(raise_error=True) request.versioning_scheme = MockVersioningScheme(raise_error=True)
url = reverse('view', request=request) url = reverse('view', request=request)
assert url == 'http://testserver/view' assert url == 'http://testserver/view'

View File

@ -214,20 +214,19 @@ class TestCustomLookupFields(URLPatternsTestCase, TestCase):
assert response.data == {"url": "http://testserver/example/notes/a%20b/", "uuid": "a b", "text": "baz qux"} assert response.data == {"url": "http://testserver/example/notes/a%20b/", "uuid": "a b", "text": "baz qux"}
class TestLookupValueRegex(TestCase): """
"""
Ensure the router honors lookup_value_regex when applied Ensure the router honors lookup_value_regex when applied
to the viewset. to the viewset.
""" """
def setUp(self): defsetUp(self):
class NoteViewSet(viewsets.ModelViewSet): class NoteViewSet(viewsets.ModelViewSet):
queryset = RouterTestModel.objects.all() queryset = RouterTestModel.objects.all()
lookup_field = 'uuid' lookup_field = 'uuid'
lookup_value_regex = '[0-9a-f]{32}' lookup_value_regex = '[0-9a-f]{32}'
self.router = SimpleRouter() self.router = SimpleRouter()
self.router.register(r'notes', NoteViewSet) self.router.register(r'notes',NoteViewSet)
self.urls = self.router.urls self.urls=self.router.urls
def test_urls_limited_by_lookup_value_regex(self): def test_urls_limited_by_lookup_value_regex(self):
expected = ['^notes/$', '^notes/(?P<uuid>[0-9a-f]{32})/$'] expected = ['^notes/$', '^notes/(?P<uuid>[0-9a-f]{32})/$']
@ -264,14 +263,13 @@ class TestLookupUrlKwargs(URLPatternsTestCase, TestCase):
assert response.data == {"url": "http://testserver/example/notes/123/", "uuid": "123", "text": "foo bar"} assert response.data == {"url": "http://testserver/example/notes/123/", "uuid": "123", "text": "foo bar"}
class TestTrailingSlashIncluded(TestCase): def setUp(self):
def setUp(self):
class NoteViewSet(viewsets.ModelViewSet): class NoteViewSet(viewsets.ModelViewSet):
queryset = RouterTestModel.objects.all() queryset = RouterTestModel.objects.all()
self.router = SimpleRouter() self.router = SimpleRouter()
self.router.register(r'notes', NoteViewSet) self.router.register(r'notes',NoteViewSet)
self.urls = self.router.urls self.urls=self.router.urls
def test_urls_have_trailing_slash_by_default(self): def test_urls_have_trailing_slash_by_default(self):
expected = ['^notes/$', '^notes/(?P<pk>[^/.]+)/$'] expected = ['^notes/$', '^notes/(?P<pk>[^/.]+)/$']
@ -279,14 +277,13 @@ class TestTrailingSlashIncluded(TestCase):
assert expected[idx] == get_regex_pattern(self.urls[idx]) assert expected[idx] == get_regex_pattern(self.urls[idx])
class TestTrailingSlashRemoved(TestCase): def setUp(self):
def setUp(self):
class NoteViewSet(viewsets.ModelViewSet): class NoteViewSet(viewsets.ModelViewSet):
queryset = RouterTestModel.objects.all() queryset = RouterTestModel.objects.all()
self.router = SimpleRouter(trailing_slash=False) self.router = SimpleRouter(trailing_slash=False)
self.router.register(r'notes', NoteViewSet) self.router.register(r'notes',NoteViewSet)
self.urls = self.router.urls self.urls=self.router.urls
def test_urls_can_have_trailing_slash_removed(self): def test_urls_can_have_trailing_slash_removed(self):
expected = ['^notes$', '^notes/(?P<pk>[^/.]+)$'] expected = ['^notes$', '^notes/(?P<pk>[^/.]+)$']
@ -294,40 +291,34 @@ class TestTrailingSlashRemoved(TestCase):
assert expected[idx] == get_regex_pattern(self.urls[idx]) assert expected[idx] == get_regex_pattern(self.urls[idx])
class TestNameableRoot(TestCase): def setUp(self):
def setUp(self):
class NoteViewSet(viewsets.ModelViewSet): class NoteViewSet(viewsets.ModelViewSet):
queryset = RouterTestModel.objects.all() queryset = RouterTestModel.objects.all()
self.router = DefaultRouter() self.router = DefaultRouter()
self.router.root_view_name = 'nameable-root' self.router.root_view_name='nameable-root'
self.router.register(r'notes', NoteViewSet) self.router.register(r'notes',NoteViewSet)
self.urls = self.router.urls self.urls=self.router.urls
def test_router_has_custom_name(self): def test_router_has_custom_name(self):
expected = 'nameable-root' expected = 'nameable-root'
assert expected == self.urls[-1].name assert expected == self.urls[-1].name
class TestActionKeywordArgs(TestCase): """
"""
Ensure keyword arguments passed in the `@action` decorator Ensure keyword arguments passed in the `@action` decorator
are properly handled. Refs #940. are properly handled. Refs #940.
""" """
defsetUp(self):
def setUp(self):
class TestViewSet(viewsets.ModelViewSet): class TestViewSet(viewsets.ModelViewSet):
permission_classes = [] permission_classes = []
@action(methods=['post'], detail=True, permission_classes=[permissions.AllowAny]) @action(methods=['post'], detail=True, permission_classes=[permissions.AllowAny])
def custom(self, request, *args, **kwargs): def custom(self, request, *args, **kwargs):
return Response({ return Response({ 'permission_classes': self.permission_classes })
'permission_classes': self.permission_classes
})
self.router = SimpleRouter() self.router = SimpleRouter()
self.router.register(r'test', TestViewSet, basename='test') self.router.register(r'test',TestViewSet,basename='test')
self.view = self.router.urls[-1].callback self.view=self.router.urls[-1].callback
def test_action_kwargs(self): def test_action_kwargs(self):
request = factory.post('/test/0/custom/') request = factory.post('/test/0/custom/')
@ -335,25 +326,20 @@ class TestActionKeywordArgs(TestCase):
assert response.data == {'permission_classes': [permissions.AllowAny]} assert response.data == {'permission_classes': [permissions.AllowAny]}
class TestActionAppliedToExistingRoute(TestCase): """
"""
Ensure `@action` decorator raises an except when applied Ensure `@action` decorator raises an except when applied
to an existing route to an existing route
""" """
deftest_exception_raised_when_action_applied_to_existing_route(self):
def test_exception_raised_when_action_applied_to_existing_route(self):
class TestViewSet(viewsets.ModelViewSet): class TestViewSet(viewsets.ModelViewSet):
@action(methods=['post'], detail=True) @action(methods=['post'], detail=True)
def retrieve(self, request, *args, **kwargs): def retrieve(self, request, *args, **kwargs):
return Response({ return Response({ 'hello': 'world' })
'hello': 'world'
})
self.router = SimpleRouter() self.router = SimpleRouter()
self.router.register(r'test', TestViewSet, basename='test') self.router.register(r'test',TestViewSet,basename='test')
withpytest.raises(ImproperlyConfigured):
with pytest.raises(ImproperlyConfigured):
self.router.urls self.router.urls
@ -390,28 +376,17 @@ class SubDynamicListAndDetailViewSet(DynamicListAndDetailViewSet):
pass pass
class TestDynamicListAndDetailRouter(TestCase): def setUp(self):
def setUp(self):
self.router = SimpleRouter() self.router = SimpleRouter()
def _test_list_and_detail_route_decorators(self, viewset): def _test_list_and_detail_route_decorators(self, viewset):
routes = self.router.get_routes(viewset) routes = self.router.get_routes(viewset)
decorator_routes = [r for r in routes if not (r.name.endswith('-list') or r.name.endswith('-detail'))] decorator_routes = [r for r in routes if not (r.name.endswith('-list') or r.name.endswith('-detail'))]
MethodNamesMap = namedtuple('MethodNamesMap', 'method_name url_path') MethodNamesMap = namedtuple('MethodNamesMap', 'method_name url_path')
# Make sure all these endpoints exist and none have been clobbered for i, endpoint in enumerate([MethodNamesMap('list_custom_route_get', 'list_custom-route'), MethodNamesMap('list_route_get', 'list_route_get'), MethodNamesMap('list_route_post', 'list_route_post'), MethodNamesMap('detail_custom_route_get', 'detail_custom-route'), MethodNamesMap('detail_route_get', 'detail_route_get'), MethodNamesMap('detail_route_post', 'detail_route_post') ]):
for i, endpoint in enumerate([MethodNamesMap('list_custom_route_get', 'list_custom-route'),
MethodNamesMap('list_route_get', 'list_route_get'),
MethodNamesMap('list_route_post', 'list_route_post'),
MethodNamesMap('detail_custom_route_get', 'detail_custom-route'),
MethodNamesMap('detail_route_get', 'detail_route_get'),
MethodNamesMap('detail_route_post', 'detail_route_post')
]):
route = decorator_routes[i] route = decorator_routes[i]
# check url listing
method_name = endpoint.method_name method_name = endpoint.method_name
url_path = endpoint.url_path url_path = endpoint.url_path
if method_name.startswith('list_'): if method_name.startswith('list_'):
assert route.url == '^{{prefix}}/{0}{{trailing_slash}}$'.format(url_path) assert route.url == '^{{prefix}}/{0}{{trailing_slash}}$'.format(url_path)
else: else:
@ -490,69 +465,52 @@ class TestViewInitkwargs(URLPatternsTestCase, TestCase):
assert initkwargs['basename'] == 'routertestmodel' assert initkwargs['basename'] == 'routertestmodel'
class TestBaseNameRename(TestCase):
def test_base_name_and_basename_assertion(self): def test_base_name_and_basename_assertion(self):
router = SimpleRouter() router = SimpleRouter()
msg = "Do not provide both the `basename` and `base_name` arguments." msg = "Do not provide both the `basename` and `base_name` arguments."
with warnings.catch_warnings(record=True) as w, \ with warnings.catch_warnings(record=True) as w, self.assertRaisesMessage(AssertionError, msg):
self.assertRaisesMessage(AssertionError, msg):
warnings.simplefilter('always') warnings.simplefilter('always')
router.register('mock', MockViewSet, 'mock', base_name='mock') router.register('mock', MockViewSet, 'mock', base_name='mock')
msg = "The `base_name` argument is pending deprecation in favor of `basename`." msg = "The `base_name` argument is pending deprecation in favor of `basename`."
assert len(w) == 1 assertlen(w)==1
assert str(w[0].message) == msg assertstr(w[0].message)==msg
def test_base_name_argument_deprecation(self): def test_base_name_argument_deprecation(self):
router = SimpleRouter() router = SimpleRouter()
with pytest.warns(RemovedInDRF311Warning) as w: with pytest.warns(RemovedInDRF311Warning) as w:
warnings.simplefilter('always') warnings.simplefilter('always')
router.register('mock', MockViewSet, base_name='mock') router.register('mock', MockViewSet, base_name='mock')
msg = "The `base_name` argument is pending deprecation in favor of `basename`." msg = "The `base_name` argument is pending deprecation in favor of `basename`."
assert len(w) == 1 assertlen(w)==1
assert str(w[0].message) == msg assertstr(w[0].message)==msg
assert router.registry == [ assertrouter.registry==[('mock',MockViewSet,'mock'),]
('mock', MockViewSet, 'mock'),
]
def test_basename_argument_no_warnings(self): def test_basename_argument_no_warnings(self):
router = SimpleRouter() router = SimpleRouter()
with warnings.catch_warnings(record=True) as w: with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always') warnings.simplefilter('always')
router.register('mock', MockViewSet, basename='mock') router.register('mock', MockViewSet, basename='mock')
assert len(w) == 0 assert len(w) == 0
assert router.registry == [ assertrouter.registry==[('mock',MockViewSet,'mock'),]
('mock', MockViewSet, 'mock'),
]
def test_get_default_base_name_deprecation(self): def test_get_default_base_name_deprecation(self):
msg = "`CustomRouter.get_default_base_name` method should be renamed `get_default_basename`." msg = "`CustomRouter.get_default_base_name` method should be renamed `get_default_basename`."
# Class definition should raise a warning
with pytest.warns(RemovedInDRF311Warning) as w: with pytest.warns(RemovedInDRF311Warning) as w:
warnings.simplefilter('always') warnings.simplefilter('always')
class CustomRouter(SimpleRouter): class CustomRouter(SimpleRouter):
def get_default_base_name(self, viewset): def get_default_base_name(self, viewset):
return 'foo' return 'foo'
assert len(w) == 1 assert len(w) == 1
assert str(w[0].message) == msg assertstr(w[0].message)==msg
withwarnings.catch_warnings(record=True)asw:
# Deprecated method implementation should still be called
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always') warnings.simplefilter('always')
router = CustomRouter() router = CustomRouter()
router.register('mock', MockViewSet) router.register('mock', MockViewSet)
assert len(w) == 0 assert len(w) == 0
assert router.registry == [ assertrouter.registry==[('mock',MockViewSet,'foo'),]
('mock', MockViewSet, 'foo'),
]

File diff suppressed because one or more lines are too long

View File

@ -4,14 +4,10 @@ Tests to cover bulk create and update using serializers.
from django.test import TestCase from django.test import TestCase
from rest_framework import serializers from rest_framework import serializers
"""
class BulkCreateSerializerTests(TestCase):
"""
Creating multiple instances using serializers. Creating multiple instances using serializers.
""" """
defsetUp(self):
def setUp(self):
class BookSerializer(serializers.Serializer): class BookSerializer(serializers.Serializer):
id = serializers.IntegerField() id = serializers.IntegerField()
title = serializers.CharField(max_length=100) title = serializers.CharField(max_length=100)
@ -23,23 +19,7 @@ class BulkCreateSerializerTests(TestCase):
""" """
Correct bulk update serialization should return the input data. Correct bulk update serialization should return the input data.
""" """
data = [ { 'id': 0, 'title': 'The electric kool-aid acid test', 'author': 'Tom Wolfe' }, { 'id': 1, 'title': 'If this is a man', 'author': 'Primo Levi' }, { 'id': 2, 'title': 'The wind-up bird chronicle', 'author': 'Haruki Murakami' } ]
data = [
{
'id': 0,
'title': 'The electric kool-aid acid test',
'author': 'Tom Wolfe'
}, {
'id': 1,
'title': 'If this is a man',
'author': 'Primo Levi'
}, {
'id': 2,
'title': 'The wind-up bird chronicle',
'author': 'Haruki Murakami'
}
]
serializer = self.BookSerializer(data=data, many=True) serializer = self.BookSerializer(data=data, many=True)
assert serializer.is_valid() is True assert serializer.is_valid() is True
assert serializer.validated_data == data assert serializer.validated_data == data
@ -49,28 +29,8 @@ class BulkCreateSerializerTests(TestCase):
""" """
Incorrect bulk create serialization should return errors. Incorrect bulk create serialization should return errors.
""" """
data = [ { 'id': 0, 'title': 'The electric kool-aid acid test', 'author': 'Tom Wolfe' }, { 'id': 1, 'title': 'If this is a man', 'author': 'Primo Levi' }, { 'id': 'foo', 'title': 'The wind-up bird chronicle', 'author': 'Haruki Murakami' } ]
data = [ expected_errors = [ {}, {}, {'id': ['A valid integer is required.']} ]
{
'id': 0,
'title': 'The electric kool-aid acid test',
'author': 'Tom Wolfe'
}, {
'id': 1,
'title': 'If this is a man',
'author': 'Primo Levi'
}, {
'id': 'foo',
'title': 'The wind-up bird chronicle',
'author': 'Haruki Murakami'
}
]
expected_errors = [
{},
{},
{'id': ['A valid integer is required.']}
]
serializer = self.BookSerializer(data=data, many=True) serializer = self.BookSerializer(data=data, many=True)
assert serializer.is_valid() is False assert serializer.is_valid() is False
assert serializer.errors == expected_errors assert serializer.errors == expected_errors
@ -83,14 +43,8 @@ class BulkCreateSerializerTests(TestCase):
data = ['foo', 'bar', 'baz'] data = ['foo', 'bar', 'baz']
serializer = self.BookSerializer(data=data, many=True) serializer = self.BookSerializer(data=data, many=True)
assert serializer.is_valid() is False assert serializer.is_valid() is False
message = 'Invalid data. Expected a dictionary, but got str.' message = 'Invalid data. Expected a dictionary, but got str.'
expected_errors = [ expected_errors = [ {'non_field_errors': [message]}, {'non_field_errors': [message]}, {'non_field_errors': [message]} ]
{'non_field_errors': [message]},
{'non_field_errors': [message]},
{'non_field_errors': [message]}
]
assert serializer.errors == expected_errors assert serializer.errors == expected_errors
def test_invalid_single_datatype(self): def test_invalid_single_datatype(self):
@ -100,9 +54,7 @@ class BulkCreateSerializerTests(TestCase):
data = 123 data = 123
serializer = self.BookSerializer(data=data, many=True) serializer = self.BookSerializer(data=data, many=True)
assert serializer.is_valid() is False assert serializer.is_valid() is False
expected_errors = {'non_field_errors': ['Expected a list of items but got type "int".']} expected_errors = {'non_field_errors': ['Expected a list of items but got type "int".']}
assert serializer.errors == expected_errors assert serializer.errors == expected_errors
def test_invalid_single_object(self): def test_invalid_single_object(self):
@ -110,14 +62,8 @@ class BulkCreateSerializerTests(TestCase):
Data containing only a single object, instead of a list of objects Data containing only a single object, instead of a list of objects
should return errors. should return errors.
""" """
data = { data = { 'id': 0, 'title': 'The electric kool-aid acid test', 'author': 'Tom Wolfe' }
'id': 0,
'title': 'The electric kool-aid acid test',
'author': 'Tom Wolfe'
}
serializer = self.BookSerializer(data=data, many=True) serializer = self.BookSerializer(data=data, many=True)
assert serializer.is_valid() is False assert serializer.is_valid() is False
expected_errors = {'non_field_errors': ['Expected a list of items but got type "dict".']} expected_errors = {'non_field_errors': ['Expected a list of items but got type "dict".']}
assert serializer.errors == expected_errors assert serializer.errors == expected_errors