unittest to pytest assertion convertions

This commit is contained in:
Asif Saif Uddin 2019-05-02 13:36:18 +06:00
parent 375f1b59c6
commit 3d28279d05
31 changed files with 3503 additions and 5311 deletions

2
.gitignore vendored
View File

@ -2,7 +2,7 @@
*.db *.db
*~ *~
.* .*
.py.bak *.py.bak
/site/ /site/

View File

@ -11,6 +11,7 @@ from rest_framework.response import Response
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from rest_framework.views import APIView from rest_framework.views import APIView
from tests.models import BasicModel from tests.models import BasicModel
import pytest
factory = APIRequestFactory() factory = APIRequestFactory()
@ -52,51 +53,49 @@ urlpatterns = (
connection.features.uses_savepoints, connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints." "'atomic' requires transactions and savepoints."
) )
class DBTransactionTests(TestCase): def setUp(self):
def setUp(self): self.view = BasicView.as_view()
self.view = BasicView.as_view() connections.databases['default']['ATOMIC_REQUESTS'] = True
connections.databases['default']['ATOMIC_REQUESTS'] = True
def tearDown(self): def tearDown(self):
connections.databases['default']['ATOMIC_REQUESTS'] = False connections.databases['default']['ATOMIC_REQUESTS'] = False
def test_no_exception_commit_transaction(self): def test_no_exception_commit_transaction(self):
request = factory.post('/') request = factory.post('/')
with self.assertNumQueries(1):
with self.assertNumQueries(1): response = self.view(request)
response = self.view(request)
assert not transaction.get_rollback() assert not transaction.get_rollback()
assert response.status_code == status.HTTP_200_OK assertresponse.status_code==status.HTTP_200_OK
assert BasicModel.objects.count() == 1 assertBasicModel.objects.count()==1
@unittest.skipUnless( @unittest.skipUnless(
connection.features.uses_savepoints, connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints." "'atomic' requires transactions and savepoints."
) )
class DBTransactionErrorTests(TestCase): def setUp(self):
def setUp(self): self.view = ErrorView.as_view()
self.view = ErrorView.as_view() connections.databases['default']['ATOMIC_REQUESTS'] = True
connections.databases['default']['ATOMIC_REQUESTS'] = True
def tearDown(self): def tearDown(self):
connections.databases['default']['ATOMIC_REQUESTS'] = False connections.databases['default']['ATOMIC_REQUESTS'] = False
def test_generic_exception_delegate_transaction_management(self): def test_generic_exception_delegate_transaction_management(self):
""" """
Transaction is eventually managed by outer-most transaction atomic Transaction is eventually managed by outer-most transaction atomic
block. DRF do not try to interfere here. block. DRF do not try to interfere here.
We let django deal with the transaction when it will catch the Exception. We let django deal with the transaction when it will catch the Exception.
""" """
request = factory.post('/') request = factory.post('/')
with self.assertNumQueries(3): with self.assertNumQueries(3):
# 1 - begin savepoint # 1 - begin savepoint
# 2 - insert # 2 - insert
# 3 - release savepoint # 3 - release savepoint
with transaction.atomic(): with transaction.atomic():
self.assertRaises(Exception, self.view, request) with pytest.raises(Exception):
assert not transaction.get_rollback() self.view(request)
assert not transaction.get_rollback()
assert BasicModel.objects.count() == 1 assert BasicModel.objects.count() == 1
@ -104,30 +103,29 @@ class DBTransactionErrorTests(TestCase):
connection.features.uses_savepoints, connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints." "'atomic' requires transactions and savepoints."
) )
class DBTransactionAPIExceptionTests(TestCase): def setUp(self):
def setUp(self): self.view = APIExceptionView.as_view()
self.view = APIExceptionView.as_view() connections.databases['default']['ATOMIC_REQUESTS'] = True
connections.databases['default']['ATOMIC_REQUESTS'] = True
def tearDown(self): def tearDown(self):
connections.databases['default']['ATOMIC_REQUESTS'] = False connections.databases['default']['ATOMIC_REQUESTS'] = False
def test_api_exception_rollback_transaction(self): def test_api_exception_rollback_transaction(self):
""" """
Transaction is rollbacked by our transaction atomic block. Transaction is rollbacked by our transaction atomic block.
""" """
request = factory.post('/') request = factory.post('/')
num_queries = 4 if connection.features.can_release_savepoints else 3 num_queries = 4 if connection.features.can_release_savepoints else 3
with self.assertNumQueries(num_queries): with self.assertNumQueries(num_queries):
# 1 - begin savepoint # 1 - begin savepoint
# 2 - insert # 2 - insert
# 3 - rollback savepoint # 3 - rollback savepoint
# 4 - release savepoint # 4 - release savepoint
with transaction.atomic(): with transaction.atomic():
response = self.view(request) response = self.view(request)
assert transaction.get_rollback() assert transaction.get_rollback()
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert BasicModel.objects.count() == 0 assertBasicModel.objects.count()==0
@unittest.skipUnless( @unittest.skipUnless(

View File

@ -13,74 +13,68 @@ from rest_framework.authtoken.models import Token
from rest_framework.authtoken.serializers import AuthTokenSerializer from rest_framework.authtoken.serializers import AuthTokenSerializer
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError
def setUp(self):
class AuthTokenTests(TestCase): self.site = site
self.user = User.objects.create_user(username='test_user')
def setUp(self): self.token = Token.objects.create(key='test token', user=self.user)
self.site = site
self.user = User.objects.create_user(username='test_user')
self.token = Token.objects.create(key='test token', user=self.user)
def test_model_admin_displayed_fields(self): def test_model_admin_displayed_fields(self):
mock_request = object() mock_request = object()
token_admin = TokenAdmin(self.token, self.site) token_admin = TokenAdmin(self.token, self.site)
assert token_admin.get_fields(mock_request) == ('user',) assert token_admin.get_fields(mock_request) == ('user',)
def test_token_string_representation(self): def test_token_string_representation(self):
assert str(self.token) == 'test token' assert str(self.token) == 'test token'
def test_validate_raise_error_if_no_credentials_provided(self): def test_validate_raise_error_if_no_credentials_provided(self):
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
AuthTokenSerializer().validate({}) AuthTokenSerializer().validate({})
def test_whitespace_in_password(self): def test_whitespace_in_password(self):
data = {'username': self.user.username, 'password': 'test pass '} data = {'username': self.user.username, 'password': 'test pass '}
self.user.set_password(data['password']) self.user.set_password(data['password'])
self.user.save() self.user.save()
assert AuthTokenSerializer(data=data).is_valid() assert AuthTokenSerializer(data=data).is_valid()
class AuthTokenCommandTests(TestCase):
def setUp(self): def setUp(self):
self.site = site self.site = site
self.user = User.objects.create_user(username='test_user') self.user = User.objects.create_user(username='test_user')
def test_command_create_user_token(self): def test_command_create_user_token(self):
token = AuthTokenCommand().create_user_token(self.user.username, False) token = AuthTokenCommand().create_user_token(self.user.username, False)
assert token is not None assert token is not None
token_saved = Token.objects.first() token_saved = Token.objects.first()
assert token.key == token_saved.key assert token.key == token_saved.key
def test_command_create_user_token_invalid_user(self): def test_command_create_user_token_invalid_user(self):
with pytest.raises(User.DoesNotExist): with pytest.raises(User.DoesNotExist):
AuthTokenCommand().create_user_token('not_existing_user', False) AuthTokenCommand().create_user_token('not_existing_user', False)
def test_command_reset_user_token(self): def test_command_reset_user_token(self):
AuthTokenCommand().create_user_token(self.user.username, False) AuthTokenCommand().create_user_token(self.user.username, False)
first_token_key = Token.objects.first().key first_token_key = Token.objects.first().key
AuthTokenCommand().create_user_token(self.user.username, True) AuthTokenCommand().create_user_token(self.user.username, True)
second_token_key = Token.objects.first().key second_token_key = Token.objects.first().key
assert first_token_key != second_token_key
assert first_token_key != second_token_key
def test_command_do_not_reset_user_token(self): def test_command_do_not_reset_user_token(self):
AuthTokenCommand().create_user_token(self.user.username, False) AuthTokenCommand().create_user_token(self.user.username, False)
first_token_key = Token.objects.first().key first_token_key = Token.objects.first().key
AuthTokenCommand().create_user_token(self.user.username, False) AuthTokenCommand().create_user_token(self.user.username, False)
second_token_key = Token.objects.first().key second_token_key = Token.objects.first().key
assert first_token_key == second_token_key
assert first_token_key == second_token_key
def test_command_raising_error_for_invalid_user(self): def test_command_raising_error_for_invalid_user(self):
out = StringIO() out = StringIO()
with pytest.raises(CommandError): with pytest.raises(CommandError):
call_command('drf_create_token', 'not_existing_user', stdout=out) call_command('drf_create_token', 'not_existing_user', stdout=out)
def test_command_output(self): def test_command_output(self):
out = StringIO() out = StringIO()
call_command('drf_create_token', self.user.username, stdout=out) call_command('drf_create_token', self.user.username, stdout=out)
token_saved = Token.objects.first() token_saved = Token.objects.first()
self.assertIn('Generated token', out.getvalue()) assert 'Generated token' in out.getvalue()
self.assertIn(self.user.username, out.getvalue()) assert self.user.username in out.getvalue()
self.assertIn(token_saved.key, out.getvalue()) assert token_saved.key in out.getvalue()

View File

@ -17,307 +17,270 @@ from rest_framework.test import APIRequestFactory
from rest_framework.throttling import UserRateThrottle from rest_framework.throttling import UserRateThrottle
from rest_framework.views import APIView from rest_framework.views import APIView
def setUp(self):
class DecoratorTestCase(TestCase): self.factory = APIRequestFactory()
def setUp(self):
self.factory = APIRequestFactory()
def _finalize_response(self, request, response, *args, **kwargs): def _finalize_response(self, request, response, *args, **kwargs):
response.request = request response.request = request
return APIView.finalize_response(self, request, response, *args, **kwargs) return APIView.finalize_response(self, request, response, *args, **kwargs)
def test_api_view_incorrect(self): def test_api_view_incorrect(self):
""" """
If @api_view is not applied correct, we should raise an assertion. If @api_view is not applied correct, we should raise an assertion.
""" """
@api_view
def view(request):
return Response()
@api_view request = self.factory.get('/')
withpytest.raises(AssertionError):
view(request)
def test_api_view_incorrect_arguments(self):
"""
If @api_view is missing arguments, we should raise an assertion.
"""
with pytest.raises(AssertionError):
@api_view('GET')
def view(request): def view(request):
return Response() return Response()
request = self.factory.get('/')
self.assertRaises(AssertionError, view, request)
def test_api_view_incorrect_arguments(self):
"""
If @api_view is missing arguments, we should raise an assertion.
"""
with self.assertRaises(AssertionError):
@api_view('GET')
def view(request):
return Response()
def test_calling_method(self): def test_calling_method(self):
@api_view(['GET']) @api_view(['GET'])
def view(request): def view(request):
return Response({}) return Response({})
request = self.factory.get('/') request = self.factory.get('/')
response = view(request) response=view(request)
assert response.status_code == status.HTTP_200_OK assertresponse.status_code==status.HTTP_200_OK
request=self.factory.post('/')
request = self.factory.post('/') response=view(request)
response = view(request) assertresponse.status_code==status.HTTP_405_METHOD_NOT_ALLOWED
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
def test_calling_put_method(self): def test_calling_put_method(self):
@api_view(['GET', 'PUT']) @api_view(['GET', 'PUT'])
def view(request): def view(request):
return Response({}) return Response({})
request = self.factory.put('/') request = self.factory.put('/')
response = view(request) response=view(request)
assert response.status_code == status.HTTP_200_OK assertresponse.status_code==status.HTTP_200_OK
request=self.factory.post('/')
request = self.factory.post('/') response=view(request)
response = view(request) assertresponse.status_code==status.HTTP_405_METHOD_NOT_ALLOWED
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
def test_calling_patch_method(self): def test_calling_patch_method(self):
@api_view(['GET', 'PATCH']) @api_view(['GET', 'PATCH'])
def view(request): def view(request):
return Response({}) return Response({})
request = self.factory.patch('/') request = self.factory.patch('/')
response = view(request) response=view(request)
assert response.status_code == status.HTTP_200_OK assertresponse.status_code==status.HTTP_200_OK
request=self.factory.post('/')
request = self.factory.post('/') response=view(request)
response = view(request) assertresponse.status_code==status.HTTP_405_METHOD_NOT_ALLOWED
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
def test_renderer_classes(self): def test_renderer_classes(self):
@api_view(['GET']) @api_view(['GET'])
@renderer_classes([JSONRenderer]) @renderer_classes([JSONRenderer])
def view(request): def view(request):
return Response({}) return Response({})
request = self.factory.get('/') request = self.factory.get('/')
response = view(request) response=view(request)
assert isinstance(response.accepted_renderer, JSONRenderer) assertisinstance(response.accepted_renderer,JSONRenderer)
def test_parser_classes(self): def test_parser_classes(self):
@api_view(['GET']) @api_view(['GET'])
@parser_classes([JSONParser]) @parser_classes([JSONParser])
def view(request): def view(request):
assert len(request.parsers) == 1 assert len(request.parsers) == 1
assert isinstance(request.parsers[0], JSONParser) assert isinstance(request.parsers[0], JSONParser)
return Response({}) return Response({})
request = self.factory.get('/') request = self.factory.get('/')
view(request) view(request)
def test_authentication_classes(self): def test_authentication_classes(self):
@api_view(['GET']) @api_view(['GET'])
@authentication_classes([BasicAuthentication]) @authentication_classes([BasicAuthentication])
def view(request): def view(request):
assert len(request.authenticators) == 1 assert len(request.authenticators) == 1
assert isinstance(request.authenticators[0], BasicAuthentication) assert isinstance(request.authenticators[0], BasicAuthentication)
return Response({}) return Response({})
request = self.factory.get('/') request = self.factory.get('/')
view(request) view(request)
def test_permission_classes(self): def test_permission_classes(self):
@api_view(['GET']) @api_view(['GET'])
@permission_classes([IsAuthenticated]) @permission_classes([IsAuthenticated])
def view(request): def view(request):
return Response({}) return Response({})
request = self.factory.get('/') request = self.factory.get('/')
response = view(request) response=view(request)
assert response.status_code == status.HTTP_403_FORBIDDEN assertresponse.status_code==status.HTTP_403_FORBIDDEN
def test_throttle_classes(self): def test_throttle_classes(self):
class OncePerDayUserThrottle(UserRateThrottle): class OncePerDayUserThrottle(UserRateThrottle):
rate = '1/day' rate = '1/day'
@api_view(['GET']) @api_view(['GET'])
@throttle_classes([OncePerDayUserThrottle]) @throttle_classes([OncePerDayUserThrottle])
def view(request): defview(request):
return Response({}) return Response({})
request = self.factory.get('/') request = self.factory.get('/')
response = view(request) response=view(request)
assert response.status_code == status.HTTP_200_OK assertresponse.status_code==status.HTTP_200_OK
response=view(request)
response = view(request) assertresponse.status_code==status.HTTP_429_TOO_MANY_REQUESTS
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
def test_schema(self): def test_schema(self):
""" """
Checks CustomSchema class is set on view Checks CustomSchema class is set on view
""" """
class CustomSchema(AutoSchema): class CustomSchema(AutoSchema):
pass pass
@api_view(['GET']) @api_view(['GET'])
@schema(CustomSchema()) @schema(CustomSchema())
def view(request): defview(request):
return Response({}) return Response({})
assert isinstance(view.cls.schema, CustomSchema) assert isinstance(view.cls.schema, CustomSchema)
class ActionDecoratorTestCase(TestCase):
def test_defaults(self): def test_defaults(self):
@action(detail=True) @action(detail=True)
def test_action(request): def test_action(request):
"""Description""" """Description"""
assert test_action.mapping == {'get': 'test_action'} assert test_action.mapping == {'get': 'test_action'}
assert test_action.detail is True asserttest_action.detailisTrue
assert test_action.url_path == 'test_action' asserttest_action.url_path=='test_action'
assert test_action.url_name == 'test-action' asserttest_action.url_name=='test-action'
assert test_action.kwargs == { asserttest_action.kwargs=={'name':'Test action','description':'Description',}
'name': 'Test action',
'description': 'Description',
}
def test_detail_required(self): def test_detail_required(self):
with pytest.raises(AssertionError) as excinfo: with pytest.raises(AssertionError) as excinfo:
@action() @action()
def test_action(request): def test_action(request):
raise NotImplementedError raise NotImplementedError
assert str(excinfo.value) == "@action() missing required argument: 'detail'" assert str(excinfo.value) == "@action() missing required argument: 'detail'"
def test_method_mapping_http_methods(self): def test_method_mapping_http_methods(self):
# All HTTP methods should be mappable # All HTTP methods should be mappable
@action(detail=False, methods=[]) @action(detail=False, methods=[])
def test_action(): def test_action():
raise NotImplementedError raise NotImplementedError
for name in APIView.http_method_names: for name in APIView.http_method_names:
def method(): def method():
raise NotImplementedError raise NotImplementedError
method.__name__ = name method.__name__ = name
getattr(test_action.mapping, name)(method) getattr(test_action.mapping,name)(method)
# ensure the mapping returns the correct method name # ensure the mapping returns the correct method name
for name in APIView.http_method_names: for name in APIView.http_method_names:
assert test_action.mapping[name] == name assert test_action.mapping[name] == name
def test_view_name_kwargs(self): def test_view_name_kwargs(self):
""" """
'name' and 'suffix' are mutually exclusive kwargs used for generating 'name' and 'suffix' are mutually exclusive kwargs used for generating
a view's display name. a view's display name.
""" """
# by default, generate name from method @action(detail=True)
@action(detail=True) def test_action(request):
def test_action(request): raise NotImplementedError
raise NotImplementedError
assert test_action.kwargs == { assert test_action.kwargs == {'description':None,'name':'Test action',}
'description': None, @action(detail=True,name='test name')
'name': 'Test action', deftest_action(request):
} raise NotImplementedError
# name kwarg supersedes name generation assert test_action.kwargs == {'description':None,'name':'test name',}
@action(detail=True, name='test name') @action(detail=True,suffix='Suffix')
def test_action(request): deftest_action(request):
raise NotImplementedError raise NotImplementedError
assert test_action.kwargs == { assert test_action.kwargs == {'description':None,'suffix':'Suffix',}
'description': None, withpytest.raises(TypeError)asexcinfo:
'name': 'test name', action(detail=True, name='test name', suffix='Suffix')
}
# suffix kwarg supersedes name generation
@action(detail=True, suffix='Suffix')
def test_action(request):
raise NotImplementedError
assert test_action.kwargs == {
'description': None,
'suffix': 'Suffix',
}
# name + suffix is a conflict.
with pytest.raises(TypeError) as excinfo:
action(detail=True, name='test name', suffix='Suffix')
assert str(excinfo.value) == "`name` and `suffix` are mutually exclusive arguments." assert str(excinfo.value) == "`name` and `suffix` are mutually exclusive arguments."
def test_method_mapping(self): def test_method_mapping(self):
@action(detail=False) @action(detail=False)
def test_action(request): def test_action(request):
raise NotImplementedError raise NotImplementedError
@test_action.mapping.post @test_action.mapping.post
def test_action_post(request): deftest_action_post(request):
raise NotImplementedError raise NotImplementedError
# The secondary handler methods should not have the action attributes # The secondary handler methods should not have the action attributes
for name in ['mapping', 'detail', 'url_path', 'url_name', 'kwargs']: for name in ['mapping', 'detail', 'url_path', 'url_name', 'kwargs']:
assert hasattr(test_action, name) and not hasattr(test_action_post, name) assert hasattr(test_action, name) and not hasattr(test_action_post, name)
def test_method_mapping_already_mapped(self): def test_method_mapping_already_mapped(self):
@action(detail=True) @action(detail=True)
def test_action(request): def test_action(request):
raise NotImplementedError raise NotImplementedError
msg = "Method 'get' has already been mapped to '.test_action'." msg = "Method 'get' has already been mapped to '.test_action'."
with self.assertRaisesMessage(AssertionError, msg): withself.assertRaisesMessage(AssertionError,msg):
@test_action.mapping.get @test_action.mapping.get
def test_action_get(request): def test_action_get(request):
raise NotImplementedError raise NotImplementedError
def test_method_mapping_overwrite(self): def test_method_mapping_overwrite(self):
@action(detail=True) @action(detail=True)
def test_action():
raise NotImplementedError
msg = ("Method mapping does not behave like the property decorator. You ""cannot use the same method name for each mapping declaration.")
withself.assertRaisesMessage(AssertionError,msg):
@test_action.mapping.post
def test_action(): def test_action():
raise NotImplementedError raise NotImplementedError
msg = ("Method mapping does not behave like the property decorator. You "
"cannot use the same method name for each mapping declaration.")
with self.assertRaisesMessage(AssertionError, msg):
@test_action.mapping.post
def test_action():
raise NotImplementedError
def test_detail_route_deprecation(self): def test_detail_route_deprecation(self):
with pytest.warns(RemovedInDRF310Warning) as record: with pytest.warns(RemovedInDRF310Warning) as record:
@detail_route() @detail_route()
def view(request): def view(request):
raise NotImplementedError raise NotImplementedError
assert len(record) == 1 assert len(record) == 1
assert str(record[0].message) == ( assertstr(record[0].message)==("`detail_route` is deprecated and will be removed in ""3.10 in favor of `action`, which accepts a `detail` bool. Use ""`@action(detail=True)` instead.")
"`detail_route` is deprecated and will be removed in "
"3.10 in favor of `action`, which accepts a `detail` bool. Use "
"`@action(detail=True)` instead."
)
def test_list_route_deprecation(self): def test_list_route_deprecation(self):
with pytest.warns(RemovedInDRF310Warning) as record: with pytest.warns(RemovedInDRF310Warning) as record:
@list_route() @list_route()
def view(request): def view(request):
raise NotImplementedError raise NotImplementedError
assert len(record) == 1 assert len(record) == 1
assert str(record[0].message) == ( assertstr(record[0].message)==("`list_route` is deprecated and will be removed in ""3.10 in favor of `action`, which accepts a `detail` bool. Use ""`@action(detail=False)` instead.")
"`list_route` is deprecated and will be removed in "
"3.10 in favor of `action`, which accepts a `detail` bool. Use "
"`@action(detail=False)` instead."
)
def test_route_url_name_from_path(self): def test_route_url_name_from_path(self):
# pre-3.8 behavior was to base the `url_name` off of the `url_path` # pre-3.8 behavior was to base the `url_name` off of the `url_path`
with pytest.warns(RemovedInDRF310Warning): with pytest.warns(RemovedInDRF310Warning):
@list_route(url_path='foo_bar') @list_route(url_path='foo_bar')
def view(request): def view(request):
raise NotImplementedError raise NotImplementedError
assert view.url_path == 'foo_bar' assert view.url_path == 'foo_bar'
assert view.url_name == 'foo-bar' assertview.url_name=='foo-bar'

View File

@ -69,37 +69,34 @@ MARKED_DOWN_gte_21 = """<h2 id="an-example-docstring">an example docstring</h2>
</code></pre> </code></pre>
<p>indented</p> <p>indented</p>
<h2 id="hash-style-header">hash style header</h2>%s""" <h2 id="hash-style-header">hash style header</h2>%s"""
def test_view_name_uses_class_name(self):
"""
class TestViewNamesAndDescriptions(TestCase):
def test_view_name_uses_class_name(self):
"""
Ensure view names are based on the class name. Ensure view names are based on the class name.
""" """
class MockView(APIView): class MockView(APIView):
pass pass
assert MockView().get_view_name() == 'Mock' assert MockView().get_view_name() == 'Mock'
def test_view_name_uses_name_attribute(self): def test_view_name_uses_name_attribute(self):
class MockView(APIView): class MockView(APIView):
name = 'Foo' name = 'Foo'
assert MockView().get_view_name() == 'Foo' assert MockView().get_view_name() == 'Foo'
def test_view_name_uses_suffix_attribute(self): def test_view_name_uses_suffix_attribute(self):
class MockView(APIView): class MockView(APIView):
suffix = 'List' suffix = 'List'
assert MockView().get_view_name() == 'Mock List' assert MockView().get_view_name() == 'Mock List'
def test_view_name_preferences_name_over_suffix(self): def test_view_name_preferences_name_over_suffix(self):
class MockView(APIView): class MockView(APIView):
name = 'Foo' name = 'Foo'
suffix = 'List' suffix = 'List'
assert MockView().get_view_name() == 'Foo' assert MockView().get_view_name() == 'Foo'
def test_view_description_uses_docstring(self): def test_view_description_uses_docstring(self):
"""Ensure view descriptions are based on the docstring.""" """Ensure view descriptions are based on the docstring."""
class MockView(APIView): class MockView(APIView):
"""an example docstring """an example docstring
==================== ====================
* list * list
@ -124,64 +121,53 @@ class TestViewNamesAndDescriptions(TestCase):
assert MockView().get_view_description() == DESCRIPTION assert MockView().get_view_description() == DESCRIPTION
def test_view_description_uses_description_attribute(self): def test_view_description_uses_description_attribute(self):
class MockView(APIView): class MockView(APIView):
description = 'Foo' description = 'Foo'
assert MockView().get_view_description() == 'Foo' assert MockView().get_view_description() == 'Foo'
def test_view_description_allows_empty_description(self): def test_view_description_allows_empty_description(self):
class MockView(APIView): class MockView(APIView):
"""Description.""" """Description."""
description = '' description = ''
assert MockView().get_view_description() == '' assert MockView().get_view_description() == ''
def test_view_description_can_be_empty(self): def test_view_description_can_be_empty(self):
""" """
Ensure that if a view has no docstring, Ensure that if a view has no docstring,
then it's description is the empty string. then it's description is the empty string.
""" """
class MockView(APIView): class MockView(APIView):
pass pass
assert MockView().get_view_description() == '' assert MockView().get_view_description() == ''
def test_view_description_can_be_promise(self): def test_view_description_can_be_promise(self):
""" """
Ensure a view may have a docstring that is actually a lazily evaluated Ensure a view may have a docstring that is actually a lazily evaluated
class that can be converted to a string. class that can be converted to a string.
See: https://github.com/encode/django-rest-framework/issues/1708 See: https://github.com/encode/django-rest-framework/issues/1708
""" """
# use a mock object instead of gettext_lazy to ensure that we can't end class MockLazyStr:
# up with a test case string in our l10n catalog def __init__(self, string):
self.s = string
class MockLazyStr:
def __init__(self, string):
self.s = string
def __str__(self): def __str__(self):
return self.s return self.s
class MockView(APIView): class MockView(APIView):
__doc__ = MockLazyStr("a gettext string") __doc__ = MockLazyStr("a gettext string")
assert MockView().get_view_description() == 'a gettext string' assert MockView().get_view_description() == 'a gettext string'
def test_markdown(self): def test_markdown(self):
""" """
Ensure markdown to HTML works as expected. Ensure markdown to HTML works as expected.
""" """
if apply_markdown: if apply_markdown:
md_applied = apply_markdown(DESCRIPTION) md_applied = apply_markdown(DESCRIPTION)
gte_21_match = ( gte_21_match = ( md_applied == ( MARKED_DOWN_gte_21 % MARKED_DOWN_HILITE) or md_applied == ( MARKED_DOWN_gte_21 % MARKED_DOWN_NOT_HILITE))
md_applied == ( lt_21_match = ( md_applied == ( MARKED_DOWN_lt_21 % MARKED_DOWN_HILITE) or md_applied == ( MARKED_DOWN_lt_21 % MARKED_DOWN_NOT_HILITE))
MARKED_DOWN_gte_21 % MARKED_DOWN_HILITE) or assert gte_21_match or lt_21_match
md_applied == (
MARKED_DOWN_gte_21 % MARKED_DOWN_NOT_HILITE))
lt_21_match = (
md_applied == (
MARKED_DOWN_lt_21 % MARKED_DOWN_HILITE) or
md_applied == (
MARKED_DOWN_lt_21 % MARKED_DOWN_NOT_HILITE))
assert gte_21_match or lt_21_match
def test_dedent_tabs(): def test_dedent_tabs():

View File

@ -15,81 +15,79 @@ class MockList:
return [1, 2, 3] return [1, 2, 3]
class JSONEncoderTests(TestCase): """
"""
Tests the JSONEncoder method Tests the JSONEncoder method
""" """
defsetUp(self):
def setUp(self): self.encoder = JSONEncoder()
self.encoder = JSONEncoder()
def test_encode_decimal(self): def test_encode_decimal(self):
""" """
Tests encoding a decimal Tests encoding a decimal
""" """
d = Decimal(3.14) d = Decimal(3.14)
assert self.encoder.default(d) == float(d) assert self.encoder.default(d) == float(d)
def test_encode_datetime(self): def test_encode_datetime(self):
""" """
Tests encoding a datetime object Tests encoding a datetime object
""" """
current_time = datetime.now() current_time = datetime.now()
assert self.encoder.default(current_time) == current_time.isoformat() assert self.encoder.default(current_time) == current_time.isoformat()
current_time_utc = current_time.replace(tzinfo=utc) current_time_utc = current_time.replace(tzinfo=utc)
assert self.encoder.default(current_time_utc) == current_time.isoformat() + 'Z' assert self.encoder.default(current_time_utc) == current_time.isoformat() + 'Z'
def test_encode_time(self): def test_encode_time(self):
""" """
Tests encoding a timezone Tests encoding a timezone
""" """
current_time = datetime.now().time() current_time = datetime.now().time()
assert self.encoder.default(current_time) == current_time.isoformat() assert self.encoder.default(current_time) == current_time.isoformat()
def test_encode_time_tz(self): def test_encode_time_tz(self):
""" """
Tests encoding a timezone aware timestamp Tests encoding a timezone aware timestamp
""" """
current_time = datetime.now().time() current_time = datetime.now().time()
current_time = current_time.replace(tzinfo=utc) current_time = current_time.replace(tzinfo=utc)
with pytest.raises(ValueError): with pytest.raises(ValueError):
self.encoder.default(current_time) self.encoder.default(current_time)
def test_encode_date(self): def test_encode_date(self):
""" """
Tests encoding a date object Tests encoding a date object
""" """
current_date = date.today() current_date = date.today()
assert self.encoder.default(current_date) == current_date.isoformat() assert self.encoder.default(current_date) == current_date.isoformat()
def test_encode_timedelta(self): def test_encode_timedelta(self):
""" """
Tests encoding a timedelta object Tests encoding a timedelta object
""" """
delta = timedelta(hours=1) delta = timedelta(hours=1)
assert self.encoder.default(delta) == str(delta.total_seconds()) assert self.encoder.default(delta) == str(delta.total_seconds())
def test_encode_uuid(self): def test_encode_uuid(self):
""" """
Tests encoding a UUID object Tests encoding a UUID object
""" """
unique_id = uuid4() unique_id = uuid4()
assert self.encoder.default(unique_id) == str(unique_id) assert self.encoder.default(unique_id) == str(unique_id)
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed') @pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
def test_encode_coreapi_raises_error(self): deftest_encode_coreapi_raises_error(self):
""" """
Tests encoding a coreapi objects raises proper error Tests encoding a coreapi objects raises proper error
""" """
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
self.encoder.default(coreapi.Document()) self.encoder.default(coreapi.Document())
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
self.encoder.default(coreapi.Error()) self.encoder.default(coreapi.Error())
def test_encode_object_with_tolist(self): def test_encode_object_with_tolist(self):
""" """
Tests encoding a object with tolist method Tests encoding a object with tolist method
""" """
foo = MockList() foo = MockList()
assert self.encoder.default(foo) == [1, 2, 3] assert self.encoder.default(foo) == [1, 2, 3]

View File

@ -7,89 +7,58 @@ from rest_framework.exceptions import (
server_error server_error
) )
def test_get_error_details(self):
class ExceptionTestCase(TestCase): example = "string"
lazy_example = _(example)
def test_get_error_details(self): assert _get_error_details(lazy_example) == example
assert isinstance( _get_error_details(lazy_example), ErrorDetail )
example = "string" assert _get_error_details({'nested': lazy_example})['nested'] == example
lazy_example = _(example) assert isinstance( _get_error_details({'nested': lazy_example})['nested'], ErrorDetail )
assert _get_error_details([[lazy_example]])[0][0] == example
assert _get_error_details(lazy_example) == example assert isinstance( _get_error_details([[lazy_example]])[0][0], ErrorDetail )
assert isinstance(
_get_error_details(lazy_example),
ErrorDetail
)
assert _get_error_details({'nested': lazy_example})['nested'] == example
assert isinstance(
_get_error_details({'nested': lazy_example})['nested'],
ErrorDetail
)
assert _get_error_details([[lazy_example]])[0][0] == example
assert isinstance(
_get_error_details([[lazy_example]])[0][0],
ErrorDetail
)
def test_get_full_details_with_throttling(self): def test_get_full_details_with_throttling(self):
exception = Throttled() exception = Throttled()
assert exception.get_full_details() == { assert exception.get_full_details() == { 'message': 'Request was throttled.', 'code': 'throttled'}
'message': 'Request was throttled.', 'code': 'throttled'} exception = Throttled(wait=2)
assert exception.get_full_details() == { 'message': 'Request was throttled. Expected available in {} seconds.'.format(2), 'code': 'throttled'}
exception = Throttled(wait=2) exception = Throttled(wait=2, detail='Slow down!')
assert exception.get_full_details() == { assert exception.get_full_details() == { 'message': 'Slow down! Expected available in {} seconds.'.format(2), 'code': 'throttled'}
'message': 'Request was throttled. Expected available in {} seconds.'.format(2),
'code': 'throttled'}
exception = Throttled(wait=2, detail='Slow down!')
assert exception.get_full_details() == {
'message': 'Slow down! Expected available in {} seconds.'.format(2),
'code': 'throttled'}
class ErrorDetailTests(TestCase):
def test_eq(self): def test_eq(self):
assert ErrorDetail('msg') == ErrorDetail('msg') assert ErrorDetail('msg') == ErrorDetail('msg')
assert ErrorDetail('msg', 'code') == ErrorDetail('msg', code='code') assert ErrorDetail('msg', 'code') == ErrorDetail('msg', code='code')
assert ErrorDetail('msg') == 'msg'
assert ErrorDetail('msg') == 'msg' assert ErrorDetail('msg', 'code') == 'msg'
assert ErrorDetail('msg', 'code') == 'msg'
def test_ne(self): def test_ne(self):
assert ErrorDetail('msg1') != ErrorDetail('msg2') assert ErrorDetail('msg1') != ErrorDetail('msg2')
assert ErrorDetail('msg') != ErrorDetail('msg', code='invalid') assert ErrorDetail('msg') != ErrorDetail('msg', code='invalid')
assert ErrorDetail('msg1') != 'msg2'
assert ErrorDetail('msg1') != 'msg2' assert ErrorDetail('msg1', 'code') != 'msg2'
assert ErrorDetail('msg1', 'code') != 'msg2'
def test_repr(self): def test_repr(self):
assert repr(ErrorDetail('msg1')) == \ assert repr(ErrorDetail('msg1')) == 'ErrorDetail(string={!r}, code=None)'.format('msg1')
'ErrorDetail(string={!r}, code=None)'.format('msg1') assert repr(ErrorDetail('msg1', 'code')) == 'ErrorDetail(string={!r}, code={!r})'.format('msg1', 'code')
assert repr(ErrorDetail('msg1', 'code')) == \
'ErrorDetail(string={!r}, code={!r})'.format('msg1', 'code')
def test_str(self): def test_str(self):
assert str(ErrorDetail('msg1')) == 'msg1' assert str(ErrorDetail('msg1')) == 'msg1'
assert str(ErrorDetail('msg1', 'code')) == 'msg1' assert str(ErrorDetail('msg1', 'code')) == 'msg1'
def test_hash(self): def test_hash(self):
assert hash(ErrorDetail('msg')) == hash('msg') assert hash(ErrorDetail('msg')) == hash('msg')
assert hash(ErrorDetail('msg', 'code')) == hash('msg') assert hash(ErrorDetail('msg', 'code')) == hash('msg')
class TranslationTests(TestCase):
@translation.override('fr') @translation.override('fr')
def test_message(self): deftest_message(self):
# this test largely acts as a sanity test to ensure the translation files are present. # this test largely acts as a sanity test to ensure the translation files are present.
self.assertEqual(_('A server error occurred.'), 'Une erreur du serveur est survenue.') assert _('A server error occurred.') == 'Une erreur du serveur est survenue.'
self.assertEqual(str(APIException()), 'Une erreur du serveur est survenue.') assert str(APIException()) == 'Une erreur du serveur est survenue.'
def test_server_error(): def test_server_error():

View File

@ -1134,40 +1134,38 @@ class TestNoStringCoercionDecimalField(FieldValues):
) )
class TestLocalizedDecimalField(TestCase): @override_settings(USE_L10N=True, LANGUAGE_CODE='pl')
@override_settings(USE_L10N=True, LANGUAGE_CODE='pl') deftest_to_internal_value(self):
def test_to_internal_value(self): field = serializers.DecimalField(max_digits=2, decimal_places=1, localize=True)
field = serializers.DecimalField(max_digits=2, decimal_places=1, localize=True) assert field.to_internal_value('1,1') == Decimal('1.1')
assert field.to_internal_value('1,1') == Decimal('1.1')
@override_settings(USE_L10N=True, LANGUAGE_CODE='pl') @override_settings(USE_L10N=True, LANGUAGE_CODE='pl')
def test_to_representation(self): deftest_to_representation(self):
field = serializers.DecimalField(max_digits=2, decimal_places=1, localize=True) field = serializers.DecimalField(max_digits=2, decimal_places=1, localize=True)
assert field.to_representation(Decimal('1.1')) == '1,1' assert field.to_representation(Decimal('1.1')) == '1,1'
def test_localize_forces_coerce_to_string(self): def test_localize_forces_coerce_to_string(self):
field = serializers.DecimalField(max_digits=2, decimal_places=1, coerce_to_string=False, localize=True) field = serializers.DecimalField(max_digits=2, decimal_places=1, coerce_to_string=False, localize=True)
assert isinstance(field.to_representation(Decimal('1.1')), str) assert isinstance(field.to_representation(Decimal('1.1')), str)
class TestQuantizedValueForDecimal(TestCase): def test_int_quantized_value_for_decimal(self):
def test_int_quantized_value_for_decimal(self): field = serializers.DecimalField(max_digits=4, decimal_places=2)
field = serializers.DecimalField(max_digits=4, decimal_places=2) value = field.to_internal_value(12).as_tuple()
value = field.to_internal_value(12).as_tuple() expected_digit_tuple = (0, (1, 2, 0, 0), -2)
expected_digit_tuple = (0, (1, 2, 0, 0), -2) assert value == expected_digit_tuple
assert value == expected_digit_tuple
def test_string_quantized_value_for_decimal(self): def test_string_quantized_value_for_decimal(self):
field = serializers.DecimalField(max_digits=4, decimal_places=2) field = serializers.DecimalField(max_digits=4, decimal_places=2)
value = field.to_internal_value('12').as_tuple() value = field.to_internal_value('12').as_tuple()
expected_digit_tuple = (0, (1, 2, 0, 0), -2) expected_digit_tuple = (0, (1, 2, 0, 0), -2)
assert value == expected_digit_tuple assert value == expected_digit_tuple
def test_part_precision_string_quantized_value_for_decimal(self): def test_part_precision_string_quantized_value_for_decimal(self):
field = serializers.DecimalField(max_digits=4, decimal_places=2) field = serializers.DecimalField(max_digits=4, decimal_places=2)
value = field.to_internal_value('12.0').as_tuple() value = field.to_internal_value('12.0').as_tuple()
expected_digit_tuple = (0, (1, 2, 0, 0), -2) expected_digit_tuple = (0, (1, 2, 0, 0), -2)
assert value == expected_digit_tuple assert value == expected_digit_tuple
class TestNoDecimalPlaces(FieldValues): class TestNoDecimalPlaces(FieldValues):
@ -1185,17 +1183,15 @@ class TestNoDecimalPlaces(FieldValues):
field = serializers.DecimalField(max_digits=6, decimal_places=None) field = serializers.DecimalField(max_digits=6, decimal_places=None)
class TestRoundingDecimalField(TestCase): def test_valid_rounding(self):
def test_valid_rounding(self): field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding=ROUND_UP)
field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding=ROUND_UP) assert field.to_representation(Decimal('1.234')) == '1.24'
assert field.to_representation(Decimal('1.234')) == '1.24' field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding=ROUND_DOWN)
assert field.to_representation(Decimal('1.234')) == '1.23'
field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding=ROUND_DOWN)
assert field.to_representation(Decimal('1.234')) == '1.23'
def test_invalid_rounding(self): def test_invalid_rounding(self):
with pytest.raises(AssertionError) as excinfo: with pytest.raises(AssertionError) as excinfo:
serializers.DecimalField(max_digits=1, decimal_places=1, rounding='ROUND_UNKNOWN') serializers.DecimalField(max_digits=1, decimal_places=1, rounding='ROUND_UNKNOWN')
assert 'Invalid rounding option' in str(excinfo.value) assert 'Invalid rounding option' in str(excinfo.value)
@ -1369,46 +1365,41 @@ class TestTZWithDateTimeField(FieldValues):
@override_settings(TIME_ZONE='UTC', USE_TZ=True) @override_settings(TIME_ZONE='UTC', USE_TZ=True)
class TestDefaultTZDateTimeField(TestCase): """
"""
Test the current/default timezone handling in `DateTimeField`. Test the current/default timezone handling in `DateTimeField`.
""" """
@classmethod
@classmethod defsetup_class(cls):
def setup_class(cls): cls.field = serializers.DateTimeField()
cls.field = serializers.DateTimeField() cls.kolkata = pytz.timezone('Asia/Kolkata')
cls.kolkata = pytz.timezone('Asia/Kolkata')
def test_default_timezone(self): def test_default_timezone(self):
assert self.field.default_timezone() == utc assert self.field.default_timezone() == utc
def test_current_timezone(self): def test_current_timezone(self):
assert self.field.default_timezone() == utc assert self.field.default_timezone() == utc
activate(self.kolkata) activate(self.kolkata)
assert self.field.default_timezone() == self.kolkata assert self.field.default_timezone() == self.kolkata
deactivate() deactivate()
assert self.field.default_timezone() == utc assert self.field.default_timezone() == utc
@pytest.mark.skipif(pytz is None, reason='pytz not installed') @pytest.mark.skipif(pytz is None, reason='pytz not installed')
@override_settings(TIME_ZONE='UTC', USE_TZ=True) @override_settings(TIME_ZONE='UTC', USE_TZ=True)
class TestCustomTimezoneForDateTimeField(TestCase):
@classmethod @classmethod
def setup_class(cls): defsetup_class(cls):
cls.kolkata = pytz.timezone('Asia/Kolkata') cls.kolkata = pytz.timezone('Asia/Kolkata')
cls.date_format = '%d/%m/%Y %H:%M' cls.date_format = '%d/%m/%Y %H:%M'
def test_should_render_date_time_in_default_timezone(self): def test_should_render_date_time_in_default_timezone(self):
field = serializers.DateTimeField(default_timezone=self.kolkata, format=self.date_format) field = serializers.DateTimeField(default_timezone=self.kolkata, format=self.date_format)
dt = datetime.datetime(2018, 2, 8, 14, 15, 16, tzinfo=pytz.utc) dt = datetime.datetime(2018, 2, 8, 14, 15, 16, tzinfo=pytz.utc)
with override(self.kolkata):
with override(self.kolkata): rendered_date = field.to_representation(dt)
rendered_date = field.to_representation(dt)
rendered_date_in_timezone = dt.astimezone(self.kolkata).strftime(self.date_format) rendered_date_in_timezone = dt.astimezone(self.kolkata).strftime(self.date_format)
assertrendered_date==rendered_date_in_timezone
assert rendered_date == rendered_date_in_timezone
class TestNaiveDayLightSavingTimeTimeZoneDateTimeField(FieldValues): class TestNaiveDayLightSavingTimeTimeZoneDateTimeField(FieldValues):

View File

@ -13,28 +13,25 @@ from rest_framework.compat import coreschema
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
factory = APIRequestFactory() factory = APIRequestFactory()
def setUp(self):
self.original_coreapi = filters.coreapi
class BaseFilterTests(TestCase): filters.coreapi = True
def setUp(self): self.filter_backend = filters.BaseFilterBackend()
self.original_coreapi = filters.coreapi
filters.coreapi = True # mock it, because not None value needed
self.filter_backend = filters.BaseFilterBackend()
def tearDown(self): def tearDown(self):
filters.coreapi = self.original_coreapi filters.coreapi = self.original_coreapi
def test_filter_queryset_raises_error(self): def test_filter_queryset_raises_error(self):
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
self.filter_backend.filter_queryset(None, None, None) self.filter_backend.filter_queryset(None, None, None)
@pytest.mark.skipif(not coreschema, reason='coreschema is not installed') @pytest.mark.skipif(not coreschema, reason='coreschema is not installed')
def test_get_schema_fields_checks_for_coreapi(self): deftest_get_schema_fields_checks_for_coreapi(self):
filters.coreapi = None filters.coreapi = None
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
self.filter_backend.get_schema_fields({}) self.filter_backend.get_schema_fields({})
filters.coreapi = True filters.coreapi = True
assert self.filter_backend.get_schema_fields({}) == [] assertself.filter_backend.get_schema_fields({})==[]
class SearchFilterModel(models.Model): class SearchFilterModel(models.Model):
@ -48,137 +45,115 @@ class SearchFilterSerializer(serializers.ModelSerializer):
fields = '__all__' fields = '__all__'
class SearchFilterTests(TestCase): def setUp(self):
def setUp(self):
# Sequence of title/text is: # Sequence of title/text is:
# #
# z abc # z abc
# zz bcd # zz bcd
# zzz cde # zzz cde
# ... # ...
for idx in range(10): for idx in range(10):
title = 'z' * (idx + 1) title = 'z' * (idx + 1)
text = ( text = ( chr(idx + ord('a')) + chr(idx + ord('b')) + chr(idx + ord('c')) )
chr(idx + ord('a')) + SearchFilterModel(title=title, text=text).save()
chr(idx + ord('b')) +
chr(idx + ord('c'))
)
SearchFilterModel(title=title, text=text).save()
def test_search(self): def test_search(self):
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('title', 'text')
view = SearchListView.as_view()
request=factory.get('/',{'search':'b'})
response=view(request)
assertresponse.data==[{'id':1,'title':'z','text':'abc'},{'id':2,'title':'zz','text':'bcd'}]
def test_search_returns_same_queryset_if_no_search_fields_or_terms_provided(self):
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
view = SearchListView.as_view()
request=factory.get('/')
response=view(request)
expected=SearchFilterSerializer(SearchFilterModel.objects.all(),many=True).data
assertresponse.data==expected
def test_exact_search(self):
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('=title', 'text')
view = SearchListView.as_view()
request=factory.get('/',{'search':'zzz'})
response=view(request)
assertresponse.data==[{'id':3,'title':'zzz','text':'cde'}]
def test_startswith_search(self):
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('title', '^text')
view = SearchListView.as_view()
request=factory.get('/',{'search':'b'})
response=view(request)
assertresponse.data==[{'id':2,'title':'zz','text':'bcd'}]
def test_regexp_search(self):
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('$title', '$text')
view = SearchListView.as_view()
request=factory.get('/',{'search':'z{2} ^b'})
response=view(request)
assertresponse.data==[{'id':2,'title':'zz','text':'bcd'}]
def test_search_with_nonstandard_search_param(self):
with override_settings(REST_FRAMEWORK={'SEARCH_PARAM': 'query'}):
reload_module(filters)
class SearchListView(generics.ListAPIView): class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all() queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,) filter_backends = (filters.SearchFilter,)
search_fields = ('title', 'text') search_fields = ('title', 'text')
view = SearchListView.as_view()
request = factory.get('/', {'search': 'b'})
response = view(request)
assert response.data == [
{'id': 1, 'title': 'z', 'text': 'abc'},
{'id': 2, 'title': 'zz', 'text': 'bcd'}
]
def test_search_returns_same_queryset_if_no_search_fields_or_terms_provided(self):
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
view = SearchListView.as_view()
request = factory.get('/')
response = view(request)
expected = SearchFilterSerializer(SearchFilterModel.objects.all(),
many=True).data
assert response.data == expected
def test_exact_search(self):
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('=title', 'text')
view = SearchListView.as_view()
request = factory.get('/', {'search': 'zzz'})
response = view(request)
assert response.data == [
{'id': 3, 'title': 'zzz', 'text': 'cde'}
]
def test_startswith_search(self):
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('title', '^text')
view = SearchListView.as_view()
request = factory.get('/', {'search': 'b'})
response = view(request)
assert response.data == [
{'id': 2, 'title': 'zz', 'text': 'bcd'}
]
def test_regexp_search(self):
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('$title', '$text')
view = SearchListView.as_view()
request = factory.get('/', {'search': 'z{2} ^b'})
response = view(request)
assert response.data == [
{'id': 2, 'title': 'zz', 'text': 'bcd'}
]
def test_search_with_nonstandard_search_param(self):
with override_settings(REST_FRAMEWORK={'SEARCH_PARAM': 'query'}):
reload_module(filters)
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('title', 'text')
view = SearchListView.as_view() view = SearchListView.as_view()
request = factory.get('/', {'query': 'b'}) request=factory.get('/',{'query':'b'})
response = view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':1,'title':'z','text':'abc'},{'id':2,'title':'zz','text':'bcd'}]
{'id': 1, 'title': 'z', 'text': 'abc'},
{'id': 2, 'title': 'zz', 'text': 'bcd'}
]
reload_module(filters) reload_module(filters)
def test_search_with_filter_subclass(self): def test_search_with_filter_subclass(self):
class CustomSearchFilter(filters.SearchFilter): class CustomSearchFilter(filters.SearchFilter):
# Filter that dynamically changes search fields # Filter that dynamically changes search fields
def get_search_fields(self, view, request): def get_search_fields(self, view, request):
if request.query_params.get('title_only'): if request.query_params.get('title_only'):
return ('$title',) return ('$title',)
return super().get_search_fields(view, request) return super().get_search_fields(view, request)
class SearchListView(generics.ListAPIView): class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all() queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer serializer_class = SearchFilterSerializer
filter_backends = (CustomSearchFilter,) filter_backends = (CustomSearchFilter,)
search_fields = ('$title', '$text') search_fields = ('$title', '$text')
view = SearchListView.as_view() view = SearchListView.as_view()
request = factory.get('/', {'search': r'^\w{3}$'}) request=factory.get('/',{'search':r'^\w{3}$'})
response = view(request) response=view(request)
assert len(response.data) == 10 assertlen(response.data)==10
request=factory.get('/',{'search':r'^\w{3}$','title_only':'true'})
request = factory.get('/', {'search': r'^\w{3}$', 'title_only': 'true'}) response=view(request)
response = view(request) assertresponse.data==[{'id':3,'title':'zzz','text':'cde'}]
assert response.data == [
{'id': 3, 'title': 'zzz', 'text': 'cde'}
]
class AttributeModel(models.Model): class AttributeModel(models.Model):
@ -196,31 +171,21 @@ class SearchFilterFkSerializer(serializers.ModelSerializer):
fields = '__all__' fields = '__all__'
class SearchFilterFkTests(TestCase):
def test_must_call_distinct(self): def test_must_call_distinct(self):
filter_ = filters.SearchFilter() filter_ = filters.SearchFilter()
prefixes = [''] + list(filter_.lookup_prefixes) prefixes = [''] + list(filter_.lookup_prefixes)
for prefix in prefixes: for prefix in prefixes:
assert not filter_.must_call_distinct( assert not filter_.must_call_distinct( SearchFilterModelFk._meta, ["%stitle" % prefix] )
SearchFilterModelFk._meta, assert not filter_.must_call_distinct( SearchFilterModelFk._meta, ["%stitle" % prefix, "%sattribute__label" % prefix] )
["%stitle" % prefix]
)
assert not filter_.must_call_distinct(
SearchFilterModelFk._meta,
["%stitle" % prefix, "%sattribute__label" % prefix]
)
def test_must_call_distinct_restores_meta_for_each_field(self): def test_must_call_distinct_restores_meta_for_each_field(self):
# In this test case the attribute of the fk model comes first in the # In this test case the attribute of the fk model comes first in the
# list of search fields. # list of search fields.
filter_ = filters.SearchFilter() filter_ = filters.SearchFilter()
prefixes = [''] + list(filter_.lookup_prefixes) prefixes = [''] + list(filter_.lookup_prefixes)
for prefix in prefixes: for prefix in prefixes:
assert not filter_.must_call_distinct( assert not filter_.must_call_distinct( SearchFilterModelFk._meta, ["%sattribute__label" % prefix, "%stitle" % prefix] )
SearchFilterModelFk._meta,
["%sattribute__label" % prefix, "%stitle" % prefix]
)
class SearchFilterModelM2M(models.Model): class SearchFilterModelM2M(models.Model):
@ -235,53 +200,41 @@ class SearchFilterM2MSerializer(serializers.ModelSerializer):
fields = '__all__' fields = '__all__'
class SearchFilterM2MTests(TestCase): def setUp(self):
def setUp(self):
# Sequence of title/text/attributes is: # Sequence of title/text/attributes is:
# #
# z abc [1, 2, 3] # z abc [1, 2, 3]
# zz bcd [1, 2, 3] # zz bcd [1, 2, 3]
# zzz cde [1, 2, 3] # zzz cde [1, 2, 3]
# ... # ...
for idx in range(3): for idx in range(3):
label = 'w' * (idx + 1) label = 'w' * (idx + 1)
AttributeModel.objects.create(label=label) AttributeModel.objects.create(label=label)
for idx in range(10): for idx in range(10):
title = 'z' * (idx + 1) title = 'z' * (idx + 1)
text = ( text = ( chr(idx + ord('a')) + chr(idx + ord('b')) + chr(idx + ord('c')) )
chr(idx + ord('a')) + SearchFilterModelM2M(title=title, text=text).save()
chr(idx + ord('b')) +
chr(idx + ord('c'))
)
SearchFilterModelM2M(title=title, text=text).save()
SearchFilterModelM2M.objects.get(title='zz').attributes.add(1, 2, 3) SearchFilterModelM2M.objects.get(title='zz').attributes.add(1, 2, 3)
def test_m2m_search(self): def test_m2m_search(self):
class SearchListView(generics.ListAPIView): class SearchListView(generics.ListAPIView):
queryset = SearchFilterModelM2M.objects.all() queryset = SearchFilterModelM2M.objects.all()
serializer_class = SearchFilterM2MSerializer serializer_class = SearchFilterM2MSerializer
filter_backends = (filters.SearchFilter,) filter_backends = (filters.SearchFilter,)
search_fields = ('=title', 'text', 'attributes__label') search_fields = ('=title', 'text', 'attributes__label')
view = SearchListView.as_view() view = SearchListView.as_view()
request = factory.get('/', {'search': 'zz'}) request=factory.get('/',{'search':'zz'})
response = view(request) response=view(request)
assert len(response.data) == 1 assertlen(response.data)==1
def test_must_call_distinct(self): def test_must_call_distinct(self):
filter_ = filters.SearchFilter() filter_ = filters.SearchFilter()
prefixes = [''] + list(filter_.lookup_prefixes) prefixes = [''] + list(filter_.lookup_prefixes)
for prefix in prefixes: for prefix in prefixes:
assert not filter_.must_call_distinct( assert not filter_.must_call_distinct( SearchFilterModelM2M._meta, ["%stitle" % prefix] )
SearchFilterModelM2M._meta, assert filter_.must_call_distinct( SearchFilterModelM2M._meta, ["%stitle" % prefix, "%sattributes__label" % prefix] )
["%stitle" % prefix]
)
assert filter_.must_call_distinct(
SearchFilterModelM2M._meta,
["%stitle" % prefix, "%sattributes__label" % prefix]
)
class Blog(models.Model): class Blog(models.Model):
@ -300,32 +253,27 @@ class BlogSerializer(serializers.ModelSerializer):
fields = '__all__' fields = '__all__'
class SearchFilterToManyTests(TestCase):
@classmethod @classmethod
def setUpTestData(cls): defsetUpTestData(cls):
b1 = Blog.objects.create(name='Blog 1') b1 = Blog.objects.create(name='Blog 1')
b2 = Blog.objects.create(name='Blog 2') b2 = Blog.objects.create(name='Blog 2')
Entry.objects.create(blog=b1, headline='Something about Lennon', pub_date=datetime.date(1979, 1, 1))
# Multiple entries on Lennon published in 1979 - distinct should deduplicate Entry.objects.create(blog=b1, headline='Another thing about Lennon', pub_date=datetime.date(1979, 6, 1))
Entry.objects.create(blog=b1, headline='Something about Lennon', pub_date=datetime.date(1979, 1, 1)) Entry.objects.create(blog=b2, headline='Something unrelated', pub_date=datetime.date(1979, 1, 1))
Entry.objects.create(blog=b1, headline='Another thing about Lennon', pub_date=datetime.date(1979, 6, 1)) Entry.objects.create(blog=b2, headline='Retrospective on Lennon', pub_date=datetime.date(1990, 6, 1))
# Entry on Lennon *and* a separate entry in 1979 - should not match
Entry.objects.create(blog=b2, headline='Something unrelated', pub_date=datetime.date(1979, 1, 1))
Entry.objects.create(blog=b2, headline='Retrospective on Lennon', pub_date=datetime.date(1990, 6, 1))
def test_multiple_filter_conditions(self): def test_multiple_filter_conditions(self):
class SearchListView(generics.ListAPIView): class SearchListView(generics.ListAPIView):
queryset = Blog.objects.all() queryset = Blog.objects.all()
serializer_class = BlogSerializer serializer_class = BlogSerializer
filter_backends = (filters.SearchFilter,) filter_backends = (filters.SearchFilter,)
search_fields = ('=name', 'entry__headline', '=entry__pub_date__year') search_fields = ('=name', 'entry__headline', '=entry__pub_date__year')
view = SearchListView.as_view() view = SearchListView.as_view()
request = factory.get('/', {'search': 'Lennon,1979'}) request=factory.get('/',{'search':'Lennon,1979'})
response = view(request) response=view(request)
assert len(response.data) == 1 assertlen(response.data)==1
class SearchFilterAnnotatedSerializer(serializers.ModelSerializer): class SearchFilterAnnotatedSerializer(serializers.ModelSerializer):
@ -336,28 +284,23 @@ class SearchFilterAnnotatedSerializer(serializers.ModelSerializer):
fields = ('title', 'text', 'title_text') fields = ('title', 'text', 'title_text')
class SearchFilterAnnotatedFieldTests(TestCase): @classmethod
@classmethod defsetUpTestData(cls):
def setUpTestData(cls): SearchFilterModel.objects.create(title='abc', text='def')
SearchFilterModel.objects.create(title='abc', text='def') SearchFilterModel.objects.create(title='ghi', text='jkl')
SearchFilterModel.objects.create(title='ghi', text='jkl')
def test_search_in_annotated_field(self): def test_search_in_annotated_field(self):
class SearchListView(generics.ListAPIView): class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.annotate( queryset = SearchFilterModel.objects.annotate( title_text=Upper( Concat(models.F('title'), models.F('text')) ) ).all()
title_text=Upper( serializer_class = SearchFilterAnnotatedSerializer
Concat(models.F('title'), models.F('text')) filter_backends = (filters.SearchFilter,)
) search_fields = ('title_text',)
).all()
serializer_class = SearchFilterAnnotatedSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('title_text',)
view = SearchListView.as_view() view = SearchListView.as_view()
request = factory.get('/', {'search': 'ABCDEF'}) request=factory.get('/',{'search':'ABCDEF'})
response = view(request) response=view(request)
assert len(response.data) == 1 assertlen(response.data)==1
assert response.data[0]['title_text'] == 'ABCDEF' assertresponse.data[0]['title_text']=='ABCDEF'
class OrderingFilterModel(models.Model): class OrderingFilterModel(models.Model):
@ -403,253 +346,187 @@ class DjangoFilterOrderingSerializer(serializers.ModelSerializer):
fields = '__all__' fields = '__all__'
class OrderingFilterTests(TestCase): def setUp(self):
def setUp(self):
# Sequence of title/text is: # Sequence of title/text is:
# #
# zyx abc # zyx abc
# yxw bcd # yxw bcd
# xwv cde # xwv cde
for idx in range(3): for idx in range(3):
title = ( title = ( chr(ord('z') - idx) + chr(ord('y') - idx) + chr(ord('x') - idx) )
chr(ord('z') - idx) + text = ( chr(idx + ord('a')) + chr(idx + ord('b')) + chr(idx + ord('c')) )
chr(ord('y') - idx) + OrderingFilterModel(title=title, text=text).save()
chr(ord('x') - idx)
)
text = (
chr(idx + ord('a')) +
chr(idx + ord('b')) +
chr(idx + ord('c'))
)
OrderingFilterModel(title=title, text=text).save()
def test_ordering(self): def test_ordering(self):
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all() queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
ordering = ('title',) ordering = ('title',)
ordering_fields = ('text',) ordering_fields = ('text',)
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'text'}) request=factory.get('/',{'ordering':'text'})
response = view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':1,'title':'zyx','text':'abc'},{'id':2,'title':'yxw','text':'bcd'},{'id':3,'title':'xwv','text':'cde'},]
{'id': 1, 'title': 'zyx', 'text': 'abc'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 3, 'title': 'xwv', 'text': 'cde'},
]
def test_reverse_ordering(self): def test_reverse_ordering(self):
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all() queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
ordering = ('title',) ordering = ('title',)
ordering_fields = ('text',) ordering_fields = ('text',)
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('/', {'ordering': '-text'}) request=factory.get('/',{'ordering':'-text'})
response = view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},{'id':1,'title':'zyx','text':'abc'},]
{'id': 3, 'title': 'xwv', 'text': 'cde'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 1, 'title': 'zyx', 'text': 'abc'},
]
def test_incorrecturl_extrahyphens_ordering(self): def test_incorrecturl_extrahyphens_ordering(self):
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all() queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
ordering = ('title',) ordering = ('title',)
ordering_fields = ('text',) ordering_fields = ('text',)
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('/', {'ordering': '--text'}) request=factory.get('/',{'ordering':'--text'})
response = view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},{'id':1,'title':'zyx','text':'abc'},]
{'id': 3, 'title': 'xwv', 'text': 'cde'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 1, 'title': 'zyx', 'text': 'abc'},
]
def test_incorrectfield_ordering(self): def test_incorrectfield_ordering(self):
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all() queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
ordering = ('title',) ordering = ('title',)
ordering_fields = ('text',) ordering_fields = ('text',)
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'foobar'}) request=factory.get('/',{'ordering':'foobar'})
response = view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},{'id':1,'title':'zyx','text':'abc'},]
{'id': 3, 'title': 'xwv', 'text': 'cde'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 1, 'title': 'zyx', 'text': 'abc'},
]
def test_default_ordering(self): def test_default_ordering(self):
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all() queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
ordering = ('title',) ordering = ('title',)
ordering_fields = ('text',) ordering_fields = ('text',)
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('') request=factory.get('')
response = view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},{'id':1,'title':'zyx','text':'abc'},]
{'id': 3, 'title': 'xwv', 'text': 'cde'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 1, 'title': 'zyx', 'text': 'abc'},
]
def test_default_ordering_using_string(self): def test_default_ordering_using_string(self):
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all() queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
ordering = 'title' ordering = 'title'
ordering_fields = ('text',) ordering_fields = ('text',)
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('') request=factory.get('')
response = view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},{'id':1,'title':'zyx','text':'abc'},]
{'id': 3, 'title': 'xwv', 'text': 'cde'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 1, 'title': 'zyx', 'text': 'abc'},
]
def test_ordering_by_aggregate_field(self): def test_ordering_by_aggregate_field(self):
# create some related models to aggregate order by # create some related models to aggregate order by
num_objs = [2, 5, 3] num_objs = [2, 5, 3]
for obj, num_relateds in zip(OrderingFilterModel.objects.all(), for obj, num_relateds in zip(OrderingFilterModel.objects.all(), num_objs):
num_objs): for _ in range(num_relateds):
for _ in range(num_relateds): new_related = OrderingFilterRelatedModel( related_object=obj )
new_related = OrderingFilterRelatedModel( new_related.save()
related_object=obj
)
new_related.save()
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
serializer_class = OrderingFilterSerializer serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
ordering = 'title' ordering = 'title'
ordering_fields = '__all__' ordering_fields = '__all__'
queryset = OrderingFilterModel.objects.all().annotate( queryset = OrderingFilterModel.objects.all().annotate( models.Count("relateds"))
models.Count("relateds"))
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'relateds__count'}) request=factory.get('/',{'ordering':'relateds__count'})
response = view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':1,'title':'zyx','text':'abc'},{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},]
{'id': 1, 'title': 'zyx', 'text': 'abc'},
{'id': 3, 'title': 'xwv', 'text': 'cde'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
]
def test_ordering_by_dotted_source(self): def test_ordering_by_dotted_source(self):
for index, obj in enumerate(OrderingFilterModel.objects.all()): for index, obj in enumerate(OrderingFilterModel.objects.all()):
OrderingFilterRelatedModel.objects.create( OrderingFilterRelatedModel.objects.create( related_object=obj, index=index )
related_object=obj,
index=index
)
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
serializer_class = OrderingDottedRelatedSerializer serializer_class = OrderingDottedRelatedSerializer
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
queryset = OrderingFilterRelatedModel.objects.all() queryset = OrderingFilterRelatedModel.objects.all()
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'related_object__text'}) request=factory.get('/',{'ordering':'related_object__text'})
response = view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'related_title':'zyx','related_text':'abc','index':0},{'related_title':'yxw','related_text':'bcd','index':1},{'related_title':'xwv','related_text':'cde','index':2},]
{'related_title': 'zyx', 'related_text': 'abc', 'index': 0}, request=factory.get('/',{'ordering':'-index'})
{'related_title': 'yxw', 'related_text': 'bcd', 'index': 1}, response=view(request)
{'related_title': 'xwv', 'related_text': 'cde', 'index': 2}, assertresponse.data==[{'related_title':'xwv','related_text':'cde','index':2},{'related_title':'yxw','related_text':'bcd','index':1},{'related_title':'zyx','related_text':'abc','index':0},]
]
request = factory.get('/', {'ordering': '-index'})
response = view(request)
assert response.data == [
{'related_title': 'xwv', 'related_text': 'cde', 'index': 2},
{'related_title': 'yxw', 'related_text': 'bcd', 'index': 1},
{'related_title': 'zyx', 'related_text': 'abc', 'index': 0},
]
def test_ordering_with_nonstandard_ordering_param(self): def test_ordering_with_nonstandard_ordering_param(self):
with override_settings(REST_FRAMEWORK={'ORDERING_PARAM': 'order'}): with override_settings(REST_FRAMEWORK={'ORDERING_PARAM': 'order'}):
reload_module(filters) reload_module(filters)
class OrderingListView(generics.ListAPIView):
class OrderingListView(generics.ListAPIView): queryset = OrderingFilterModel.objects.all()
queryset = OrderingFilterModel.objects.all() serializer_class = OrderingFilterSerializer
serializer_class = OrderingFilterSerializer filter_backends = (filters.OrderingFilter,)
filter_backends = (filters.OrderingFilter,) ordering = ('title',)
ordering = ('title',) ordering_fields = ('text',)
ordering_fields = ('text',)
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('/', {'order': 'text'}) request=factory.get('/',{'order':'text'})
response = view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':1,'title':'zyx','text':'abc'},{'id':2,'title':'yxw','text':'bcd'},{'id':3,'title':'xwv','text':'cde'},]
{'id': 1, 'title': 'zyx', 'text': 'abc'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 3, 'title': 'xwv', 'text': 'cde'},
]
reload_module(filters) reload_module(filters)
def test_get_template_context(self): def test_get_template_context(self):
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
ordering_fields = '__all__' ordering_fields = '__all__'
serializer_class = OrderingFilterSerializer serializer_class = OrderingFilterSerializer
queryset = OrderingFilterModel.objects.all() queryset = OrderingFilterModel.objects.all()
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
request = factory.get('/', {'ordering': 'title'}, HTTP_ACCEPT='text/html') request = factory.get('/', {'ordering': 'title'}, HTTP_ACCEPT='text/html')
view = OrderingListView.as_view() view=OrderingListView.as_view()
response = view(request) response=view(request)
self.assertContains(response,'verbose title')
self.assertContains(response, 'verbose title')
def test_ordering_with_overridden_get_serializer_class(self): def test_ordering_with_overridden_get_serializer_class(self):
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all() queryset = OrderingFilterModel.objects.all()
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
ordering = ('title',) ordering = ('title',)
def get_serializer_class(self):
# note: no ordering_fields and serializer_class specified return OrderingFilterSerializer
def get_serializer_class(self):
return OrderingFilterSerializer
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'text'}) request=factory.get('/',{'ordering':'text'})
response = view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':1,'title':'zyx','text':'abc'},{'id':2,'title':'yxw','text':'bcd'},{'id':3,'title':'xwv','text':'cde'},]
{'id': 1, 'title': 'zyx', 'text': 'abc'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 3, 'title': 'xwv', 'text': 'cde'},
]
def test_ordering_with_improper_configuration(self): def test_ordering_with_improper_configuration(self):
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all() queryset = OrderingFilterModel.objects.all()
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
ordering = ('title',) ordering = ('title',)
# note: no ordering_fields and serializer_class # note: no ordering_fields and serializer_class
# or get_serializer_class specified # or get_serializer_class specified
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'text'}) request=factory.get('/',{'ordering':'text'})
with self.assertRaises(ImproperlyConfigured): withpytest.raises(ImproperlyConfigured):
view(request) view(request)
class SensitiveOrderingFilterModel(models.Model): class SensitiveOrderingFilterModel(models.Model):
@ -684,63 +561,44 @@ class SensitiveDataSerializer3(serializers.ModelSerializer):
fields = ('id', 'user') fields = ('id', 'user')
class SensitiveOrderingFilterTests(TestCase): def setUp(self):
def setUp(self): for idx in range(3):
for idx in range(3): username = {0: 'userA', 1: 'userB', 2: 'userC'}[idx]
username = {0: 'userA', 1: 'userB', 2: 'userC'}[idx] password = {0: 'passA', 1: 'passC', 2: 'passB'}[idx]
password = {0: 'passA', 1: 'passC', 2: 'passB'}[idx] SensitiveOrderingFilterModel(username=username, password=password).save()
SensitiveOrderingFilterModel(username=username, password=password).save()
def test_order_by_serializer_fields(self): def test_order_by_serializer_fields(self):
for serializer_cls in [ for serializer_cls in [ SensitiveDataSerializer1, SensitiveDataSerializer2, SensitiveDataSerializer3 ]:
SensitiveDataSerializer1, class OrderingListView(generics.ListAPIView):
SensitiveDataSerializer2, queryset = SensitiveOrderingFilterModel.objects.all().order_by('username')
SensitiveDataSerializer3 filter_backends = (filters.OrderingFilter,)
]: serializer_class = serializer_cls
class OrderingListView(generics.ListAPIView):
queryset = SensitiveOrderingFilterModel.objects.all().order_by('username')
filter_backends = (filters.OrderingFilter,)
serializer_class = serializer_cls
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('/', {'ordering': '-username'}) request=factory.get('/',{'ordering':'-username'})
response = view(request) response=view(request)
ifserializer_cls==SensitiveDataSerializer3:
if serializer_cls == SensitiveDataSerializer3: username_field = 'user'
username_field = 'user'
else: else:
username_field = 'username' username_field = 'username'
# Note: Inverse username ordering correctly applied. # Note: Inverse username ordering correctly applied.
assert response.data == [ assert response.data == [{'id':3,username_field:'userC'},{'id':2,username_field:'userB'},{'id':1,username_field:'userA'},]
{'id': 3, username_field: 'userC'},
{'id': 2, username_field: 'userB'},
{'id': 1, username_field: 'userA'},
]
def test_cannot_order_by_non_serializer_fields(self): def test_cannot_order_by_non_serializer_fields(self):
for serializer_cls in [ for serializer_cls in [ SensitiveDataSerializer1, SensitiveDataSerializer2, SensitiveDataSerializer3 ]:
SensitiveDataSerializer1, class OrderingListView(generics.ListAPIView):
SensitiveDataSerializer2, queryset = SensitiveOrderingFilterModel.objects.all().order_by('username')
SensitiveDataSerializer3 filter_backends = (filters.OrderingFilter,)
]: serializer_class = serializer_cls
class OrderingListView(generics.ListAPIView):
queryset = SensitiveOrderingFilterModel.objects.all().order_by('username')
filter_backends = (filters.OrderingFilter,)
serializer_class = serializer_cls
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'password'}) request=factory.get('/',{'ordering':'password'})
response = view(request) response=view(request)
ifserializer_cls==SensitiveDataSerializer3:
if serializer_cls == SensitiveDataSerializer3: username_field = 'user'
username_field = 'user'
else: else:
username_field = 'username' username_field = 'username'
# Note: The passwords are not in order. Default ordering is used. # Note: The passwords are not in order. Default ordering is used.
assert response.data == [ assert response.data == [{'id':1,username_field:'userA'},{'id':2,username_field:'userB'},{'id':3,username_field:'userC'},]
{'id': 1, username_field: 'userA'}, # PassB
{'id': 2, username_field: 'userB'}, # PassC
{'id': 3, username_field: 'userC'}, # PassA
]

View File

@ -23,14 +23,12 @@ urlpatterns = [
@override_settings(ROOT_URLCONF='tests.test_generateschema') @override_settings(ROOT_URLCONF='tests.test_generateschema')
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed') @pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
class GenerateSchemaTests(TestCase): """Tests for management command generateschema."""
"""Tests for management command generateschema.""" defsetUp(self):
self.out = io.StringIO()
def setUp(self):
self.out = io.StringIO()
def test_renders_default_schema_with_custom_title_url_and_description(self): def test_renders_default_schema_with_custom_title_url_and_description(self):
expected_out = """info: expected_out = """info:
description: Sample description description: Sample description
title: SampleAPI title: SampleAPI
version: '' version: ''
@ -42,45 +40,16 @@ class GenerateSchemaTests(TestCase):
servers: servers:
- url: http://api.sample.com/ - url: http://api.sample.com/
""" """
call_command('generateschema', call_command('generateschema', '--title=SampleAPI', '--url=http://api.sample.com', '--description=Sample description', stdout=self.out)
'--title=SampleAPI', assert formatting.dedent(expected_out) in self.out.getvalue()
'--url=http://api.sample.com',
'--description=Sample description',
stdout=self.out)
self.assertIn(formatting.dedent(expected_out), self.out.getvalue())
def test_renders_openapi_json_schema(self): def test_renders_openapi_json_schema(self):
expected_out = { expected_out = { "openapi": "3.0.0", "info": { "version": "", "title": "", "description": "" }, "servers": [ { "url": "" } ], "paths": { "/": { "get": { "operationId": "list" } } } }
"openapi": "3.0.0", call_command('generateschema', '--format=openapi-json', stdout=self.out)
"info": { out_json = json.loads(self.out.getvalue())
"version": "", assert out_json == expected_out
"title": "",
"description": ""
},
"servers": [
{
"url": ""
}
],
"paths": {
"/": {
"get": {
"operationId": "list"
}
}
}
}
call_command('generateschema',
'--format=openapi-json',
stdout=self.out)
out_json = json.loads(self.out.getvalue())
self.assertDictEqual(out_json, expected_out)
def test_renders_corejson_schema(self): def test_renders_corejson_schema(self):
expected_out = """{"_type":"document","":{"list":{"_type":"link","url":"/","action":"get"}}}""" expected_out = """{"_type":"document","":{"list":{"_type":"link","url":"/","action":"get"}}}"""
call_command('generateschema', call_command('generateschema', '--format=corejson', stdout=self.out)
'--format=corejson', assert expected_out in self.out.getvalue()
stdout=self.out)
self.assertIn(expected_out, self.out.getvalue())

View File

@ -75,304 +75,282 @@ class SlugBasedInstanceView(InstanceView):
# Tests # Tests
class TestRootView(TestCase): def setUp(self):
def setUp(self): """
"""
Create 3 BasicModel instances. Create 3 BasicModel instances.
""" """
items = ['foo', 'bar', 'baz'] items = ['foo', 'bar', 'baz']
for item in items: for item in items:
BasicModel(text=item).save() BasicModel(text=item).save()
self.objects = BasicModel.objects self.objects = BasicModel.objects
self.data = [ self.data=[{'id':obj.id,'text':obj.text}forobjinself.objects.all()]
{'id': obj.id, 'text': obj.text} self.view=RootView.as_view()
for obj in self.objects.all()
]
self.view = RootView.as_view()
def test_get_root_view(self): def test_get_root_view(self):
""" """
GET requests to ListCreateAPIView should return list of objects. GET requests to ListCreateAPIView should return list of objects.
""" """
request = factory.get('/') request = factory.get('/')
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request).render() response = self.view(request).render()
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == self.data assertresponse.data==self.data
def test_head_root_view(self): def test_head_root_view(self):
""" """
HEAD requests to ListCreateAPIView should return 200. HEAD requests to ListCreateAPIView should return 200.
""" """
request = factory.head('/') request = factory.head('/')
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request).render() response = self.view(request).render()
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
def test_post_root_view(self): def test_post_root_view(self):
""" """
POST requests to ListCreateAPIView should create a new object. POST requests to ListCreateAPIView should create a new object.
""" """
data = {'text': 'foobar'} data = {'text': 'foobar'}
request = factory.post('/', data, format='json') request = factory.post('/', data, format='json')
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request).render() response = self.view(request).render()
assert response.status_code == status.HTTP_201_CREATED assert response.status_code == status.HTTP_201_CREATED
assert response.data == {'id': 4, 'text': 'foobar'} assertresponse.data=={'id':4,'text':'foobar'}
created = self.objects.get(id=4) created=self.objects.get(id=4)
assert created.text == 'foobar' assertcreated.text=='foobar'
def test_put_root_view(self): def test_put_root_view(self):
""" """
PUT requests to ListCreateAPIView should not be allowed PUT requests to ListCreateAPIView should not be allowed
""" """
data = {'text': 'foobar'} data = {'text': 'foobar'}
request = factory.put('/', data, format='json') request = factory.put('/', data, format='json')
with self.assertNumQueries(0): with self.assertNumQueries(0):
response = self.view(request).render() response = self.view(request).render()
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
assert response.data == {"detail": 'Method "PUT" not allowed.'} assertresponse.data=={"detail":'Method "PUT" not allowed.'}
def test_delete_root_view(self): def test_delete_root_view(self):
""" """
DELETE requests to ListCreateAPIView should not be allowed DELETE requests to ListCreateAPIView should not be allowed
""" """
request = factory.delete('/') request = factory.delete('/')
with self.assertNumQueries(0): with self.assertNumQueries(0):
response = self.view(request).render() response = self.view(request).render()
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
assert response.data == {"detail": 'Method "DELETE" not allowed.'} assertresponse.data=={"detail":'Method "DELETE" not allowed.'}
def test_post_cannot_set_id(self): def test_post_cannot_set_id(self):
""" """
POST requests to create a new object should not be able to set the id. POST requests to create a new object should not be able to set the id.
""" """
data = {'id': 999, 'text': 'foobar'} data = {'id': 999, 'text': 'foobar'}
request = factory.post('/', data, format='json') request = factory.post('/', data, format='json')
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request).render() response = self.view(request).render()
assert response.status_code == status.HTTP_201_CREATED assert response.status_code == status.HTTP_201_CREATED
assert response.data == {'id': 4, 'text': 'foobar'} assertresponse.data=={'id':4,'text':'foobar'}
created = self.objects.get(id=4) created=self.objects.get(id=4)
assert created.text == 'foobar' assertcreated.text=='foobar'
def test_post_error_root_view(self): def test_post_error_root_view(self):
""" """
POST requests to ListCreateAPIView in HTML should include a form error. POST requests to ListCreateAPIView in HTML should include a form error.
""" """
data = {'text': 'foobar' * 100} data = {'text': 'foobar' * 100}
request = factory.post('/', data, HTTP_ACCEPT='text/html') request = factory.post('/', data, HTTP_ACCEPT='text/html')
response = self.view(request).render() response = self.view(request).render()
expected_error = '<span class="help-block">Ensure this field has no more than 100 characters.</span>' expected_error = '<span class="help-block">Ensure this field has no more than 100 characters.</span>'
assert expected_error in response.rendered_content.decode() assert expected_error in response.rendered_content.decode()
EXPECTED_QUERIES_FOR_PUT = 2 EXPECTED_QUERIES_FOR_PUT = 2
def setUp(self):
"""
class TestInstanceView(TestCase):
def setUp(self):
"""
Create 3 BasicModel instances. Create 3 BasicModel instances.
""" """
items = ['foo', 'bar', 'baz', 'filtered out'] items = ['foo', 'bar', 'baz', 'filtered out']
for item in items: for item in items:
BasicModel(text=item).save() BasicModel(text=item).save()
self.objects = BasicModel.objects.exclude(text='filtered out') self.objects = BasicModel.objects.exclude(text='filtered out')
self.data = [ self.data=[{'id':obj.id,'text':obj.text}forobjinself.objects.all()]
{'id': obj.id, 'text': obj.text} self.view=InstanceView.as_view()
for obj in self.objects.all() self.slug_based_view=SlugBasedInstanceView.as_view()
]
self.view = InstanceView.as_view()
self.slug_based_view = SlugBasedInstanceView.as_view()
def test_get_instance_view(self): def test_get_instance_view(self):
""" """
GET requests to RetrieveUpdateDestroyAPIView should return a single object. GET requests to RetrieveUpdateDestroyAPIView should return a single object.
""" """
request = factory.get('/1') request = factory.get('/1')
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == self.data[0] assertresponse.data==self.data[0]
def test_post_instance_view(self): def test_post_instance_view(self):
""" """
POST requests to RetrieveUpdateDestroyAPIView should not be allowed POST requests to RetrieveUpdateDestroyAPIView should not be allowed
""" """
data = {'text': 'foobar'} data = {'text': 'foobar'}
request = factory.post('/', data, format='json') request = factory.post('/', data, format='json')
with self.assertNumQueries(0): with self.assertNumQueries(0):
response = self.view(request).render() response = self.view(request).render()
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
assert response.data == {"detail": 'Method "POST" not allowed.'} assertresponse.data=={"detail":'Method "POST" not allowed.'}
def test_put_instance_view(self): def test_put_instance_view(self):
""" """
PUT requests to RetrieveUpdateDestroyAPIView should update an object. PUT requests to RetrieveUpdateDestroyAPIView should update an object.
""" """
data = {'text': 'foobar'} data = {'text': 'foobar'}
request = factory.put('/1', data, format='json') request = factory.put('/1', data, format='json')
with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT): with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
response = self.view(request, pk='1').render() response = self.view(request, pk='1').render()
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert dict(response.data) == {'id': 1, 'text': 'foobar'} assertdict(response.data)=={'id':1,'text':'foobar'}
updated = self.objects.get(id=1) updated=self.objects.get(id=1)
assert updated.text == 'foobar' assertupdated.text=='foobar'
def test_patch_instance_view(self): def test_patch_instance_view(self):
""" """
PATCH requests to RetrieveUpdateDestroyAPIView should update an object. PATCH requests to RetrieveUpdateDestroyAPIView should update an object.
""" """
data = {'text': 'foobar'} data = {'text': 'foobar'}
request = factory.patch('/1', data, format='json') request = factory.patch('/1', data, format='json')
with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT): response = self.view(request, pk=1).render()
response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == {'id': 1, 'text': 'foobar'} assertresponse.data=={'id':1,'text':'foobar'}
updated = self.objects.get(id=1) updated=self.objects.get(id=1)
assert updated.text == 'foobar' assertupdated.text=='foobar'
def test_delete_instance_view(self): def test_delete_instance_view(self):
""" """
DELETE requests to RetrieveUpdateDestroyAPIView should delete an object. DELETE requests to RetrieveUpdateDestroyAPIView should delete an object.
""" """
request = factory.delete('/1') request = factory.delete('/1')
with self.assertNumQueries(2): with self.assertNumQueries(2):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_204_NO_CONTENT assert response.status_code == status.HTTP_204_NO_CONTENT
assert response.content == b'' assertresponse.content==b''
ids = [obj.id for obj in self.objects.all()] ids=[obj.idforobjinself.objects.all()]
assert ids == [2, 3] assertids==[2,3]
def test_get_instance_view_incorrect_arg(self): def test_get_instance_view_incorrect_arg(self):
""" """
GET requests with an incorrect pk type, should raise 404, not 500. GET requests with an incorrect pk type, should raise 404, not 500.
Regression test for #890. Regression test for #890.
""" """
request = factory.get('/a') request = factory.get('/a')
with self.assertNumQueries(0): with self.assertNumQueries(0):
response = self.view(request, pk='a').render() response = self.view(request, pk='a').render()
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
def test_put_cannot_set_id(self): def test_put_cannot_set_id(self):
""" """
PUT requests to create a new object should not be able to set the id. PUT requests to create a new object should not be able to set the id.
""" """
data = {'id': 999, 'text': 'foobar'} data = {'id': 999, 'text': 'foobar'}
request = factory.put('/1', data, format='json') request = factory.put('/1', data, format='json')
with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT): with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == {'id': 1, 'text': 'foobar'} assertresponse.data=={'id':1,'text':'foobar'}
updated = self.objects.get(id=1) updated=self.objects.get(id=1)
assert updated.text == 'foobar' assertupdated.text=='foobar'
def test_put_to_deleted_instance(self): def test_put_to_deleted_instance(self):
""" """
PUT requests to RetrieveUpdateDestroyAPIView should return 404 if PUT requests to RetrieveUpdateDestroyAPIView should return 404 if
an object does not currently exist. an object does not currently exist.
""" """
self.objects.get(id=1).delete() self.objects.get(id=1).delete()
data = {'text': 'foobar'} data = {'text': 'foobar'}
request = factory.put('/1', data, format='json') request = factory.put('/1', data, format='json')
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
def test_put_to_filtered_out_instance(self): def test_put_to_filtered_out_instance(self):
""" """
PUT requests to an URL of instance which is filtered out should not be PUT requests to an URL of instance which is filtered out should not be
able to create new objects. able to create new objects.
""" """
data = {'text': 'foo'} data = {'text': 'foo'}
filtered_out_pk = BasicModel.objects.filter(text='filtered out')[0].pk filtered_out_pk = BasicModel.objects.filter(text='filtered out')[0].pk
request = factory.put('/{}'.format(filtered_out_pk), data, format='json') request = factory.put('/{}'.format(filtered_out_pk), data, format='json')
response = self.view(request, pk=filtered_out_pk).render() response = self.view(request, pk=filtered_out_pk).render()
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
def test_patch_cannot_create_an_object(self): def test_patch_cannot_create_an_object(self):
""" """
PATCH requests should not be able to create objects. PATCH requests should not be able to create objects.
""" """
data = {'text': 'foobar'} data = {'text': 'foobar'}
request = factory.patch('/999', data, format='json') request = factory.patch('/999', data, format='json')
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request, pk=999).render() response = self.view(request, pk=999).render()
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
assert not self.objects.filter(id=999).exists() assertnotself.objects.filter(id=999).exists()
def test_put_error_instance_view(self): def test_put_error_instance_view(self):
""" """
Incorrect PUT requests in HTML should include a form error. Incorrect PUT requests in HTML should include a form error.
""" """
data = {'text': 'foobar' * 100} data = {'text': 'foobar' * 100}
request = factory.put('/', data, HTTP_ACCEPT='text/html') request = factory.put('/', data, HTTP_ACCEPT='text/html')
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
expected_error = '<span class="help-block">Ensure this field has no more than 100 characters.</span>' expected_error = '<span class="help-block">Ensure this field has no more than 100 characters.</span>'
assert expected_error in response.rendered_content.decode() assert expected_error in response.rendered_content.decode()
class TestFKInstanceView(TestCase): def setUp(self):
def setUp(self): """
"""
Create 3 BasicModel instances. Create 3 BasicModel instances.
""" """
items = ['foo', 'bar', 'baz'] items = ['foo', 'bar', 'baz']
for item in items: for item in items:
t = ForeignKeyTarget(name=item) t = ForeignKeyTarget(name=item)
t.save() t.save()
ForeignKeySource(name='source_' + item, target=t).save() ForeignKeySource(name='source_' + item, target=t).save()
self.objects = ForeignKeySource.objects self.objects = ForeignKeySource.objects
self.data = [ self.data=[{'id':obj.id,'name':obj.name}forobjinself.objects.all()]
{'id': obj.id, 'name': obj.name} self.view=FKInstanceView.as_view()
for obj in self.objects.all()
]
self.view = FKInstanceView.as_view()
class TestOverriddenGetObject(TestCase): """
"""
Test cases for a RetrieveUpdateDestroyAPIView that does NOT use the Test cases for a RetrieveUpdateDestroyAPIView that does NOT use the
queryset/model mechanism but instead overrides get_object() queryset/model mechanism but instead overrides get_object()
""" """
defsetUp(self):
def setUp(self): """
"""
Create 3 BasicModel instances. Create 3 BasicModel instances.
""" """
items = ['foo', 'bar', 'baz'] items = ['foo', 'bar', 'baz']
for item in items: for item in items:
BasicModel(text=item).save() BasicModel(text=item).save()
self.objects = BasicModel.objects self.objects = BasicModel.objects
self.data = [ self.data=[{'id':obj.id,'text':obj.text}forobjinself.objects.all()]
{'id': obj.id, 'text': obj.text} classOverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView):
for obj in self.objects.all() """
]
class OverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView):
"""
Example detail view for override of get_object(). Example detail view for override of get_object().
""" """
serializer_class = BasicSerializer serializer_class = BasicSerializer
def get_object(self):
def get_object(self): pk = int(self.kwargs['pk'])
pk = int(self.kwargs['pk']) return get_object_or_404(BasicModel.objects.all(), id=pk)
return get_object_or_404(BasicModel.objects.all(), id=pk)
self.view = OverriddenGetObjectView.as_view() self.view = OverriddenGetObjectView.as_view()
def test_overridden_get_object_view(self): def test_overridden_get_object_view(self):
""" """
GET requests to RetrieveUpdateDestroyAPIView should return a single object. GET requests to RetrieveUpdateDestroyAPIView should return a single object.
""" """
request = factory.get('/1') request = factory.get('/1')
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == self.data[0] assertresponse.data==self.data[0]
# Regression test for #285 # Regression test for #285
@ -388,23 +366,22 @@ class CommentView(generics.ListCreateAPIView):
model = Comment model = Comment
class TestCreateModelWithAutoNowAddField(TestCase): def setUp(self):
def setUp(self): self.objects = Comment.objects
self.objects = Comment.objects self.view = CommentView.as_view()
self.view = CommentView.as_view()
def test_create_model_with_auto_now_add_field(self): def test_create_model_with_auto_now_add_field(self):
""" """
Regression test for #285 Regression test for #285
https://github.com/encode/django-rest-framework/issues/285 https://github.com/encode/django-rest-framework/issues/285
""" """
data = {'email': 'foobar@example.com', 'content': 'foobar'} data = {'email': 'foobar@example.com', 'content': 'foobar'}
request = factory.post('/', data, format='json') request = factory.post('/', data, format='json')
response = self.view(request).render() response = self.view(request).render()
assert response.status_code == status.HTTP_201_CREATED assert response.status_code == status.HTTP_201_CREATED
created = self.objects.get(id=1) created = self.objects.get(id=1)
assert created.content == 'foobar' assert created.content == 'foobar'
# Test for particularly ugly regression with m2m in browsable API # Test for particularly ugly regression with m2m in browsable API
@ -432,15 +409,14 @@ class ExampleView(generics.ListCreateAPIView):
queryset = ClassA.objects.all() queryset = ClassA.objects.all()
class TestM2MBrowsableAPI(TestCase): def test_m2m_in_browsable_api(self):
def test_m2m_in_browsable_api(self): """
"""
Test for particularly ugly regression with m2m in browsable API Test for particularly ugly regression with m2m in browsable API
""" """
request = factory.get('/', HTTP_ACCEPT='text/html') request = factory.get('/', HTTP_ACCEPT='text/html')
view = ExampleView().as_view() view = ExampleView().as_view()
response = view(request).render() response = view(request).render()
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
class InclusiveFilterBackend: class InclusiveFilterBackend:
@ -476,189 +452,179 @@ class DynamicSerializerView(generics.ListCreateAPIView):
return DynamicSerializer return DynamicSerializer
class TestFilterBackendAppliedToViews(TestCase): def setUp(self):
def setUp(self): """
"""
Create 3 BasicModel instances to filter on. Create 3 BasicModel instances to filter on.
""" """
items = ['foo', 'bar', 'baz'] items = ['foo', 'bar', 'baz']
for item in items: for item in items:
BasicModel(text=item).save() BasicModel(text=item).save()
self.objects = BasicModel.objects self.objects = BasicModel.objects
self.data = [ self.data=[{'id':obj.id,'text':obj.text}forobjinself.objects.all()]
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
def test_get_root_view_filters_by_name_with_filter_backend(self): def test_get_root_view_filters_by_name_with_filter_backend(self):
""" """
GET requests to ListCreateAPIView should return filtered list. GET requests to ListCreateAPIView should return filtered list.
""" """
root_view = RootView.as_view(filter_backends=(InclusiveFilterBackend,)) root_view = RootView.as_view(filter_backends=(InclusiveFilterBackend,))
request = factory.get('/') request = factory.get('/')
response = root_view(request).render() response = root_view(request).render()
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert len(response.data) == 1 assert len(response.data) == 1
assert response.data == [{'id': 1, 'text': 'foo'}] assert response.data == [{'id': 1, 'text': 'foo'}]
def test_get_root_view_filters_out_all_models_with_exclusive_filter_backend(self): def test_get_root_view_filters_out_all_models_with_exclusive_filter_backend(self):
""" """
GET requests to ListCreateAPIView should return empty list when all models are filtered out. GET requests to ListCreateAPIView should return empty list when all models are filtered out.
""" """
root_view = RootView.as_view(filter_backends=(ExclusiveFilterBackend,)) root_view = RootView.as_view(filter_backends=(ExclusiveFilterBackend,))
request = factory.get('/') request = factory.get('/')
response = root_view(request).render() response = root_view(request).render()
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == [] assert response.data == []
def test_get_instance_view_filters_out_name_with_filter_backend(self): def test_get_instance_view_filters_out_name_with_filter_backend(self):
""" """
GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out. GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out.
""" """
instance_view = InstanceView.as_view(filter_backends=(ExclusiveFilterBackend,)) instance_view = InstanceView.as_view(filter_backends=(ExclusiveFilterBackend,))
request = factory.get('/1') request = factory.get('/1')
response = instance_view(request, pk=1).render() response = instance_view(request, pk=1).render()
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.data == {'detail': 'Not found.'} assert response.data == {'detail': 'Not found.'}
def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self): def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self):
""" """
GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded
""" """
instance_view = InstanceView.as_view(filter_backends=(InclusiveFilterBackend,)) instance_view = InstanceView.as_view(filter_backends=(InclusiveFilterBackend,))
request = factory.get('/1') request = factory.get('/1')
response = instance_view(request, pk=1).render() response = instance_view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == {'id': 1, 'text': 'foo'} assert response.data == {'id': 1, 'text': 'foo'}
def test_dynamic_serializer_form_in_browsable_api(self): def test_dynamic_serializer_form_in_browsable_api(self):
""" """
GET requests to ListCreateAPIView should return filtered list. GET requests to ListCreateAPIView should return filtered list.
""" """
view = DynamicSerializerView.as_view() view = DynamicSerializerView.as_view()
request = factory.get('/') request = factory.get('/')
response = view(request).render() response = view(request).render()
content = response.content.decode() content = response.content.decode()
assert 'field_b' in content assert 'field_b' in content
assert 'field_a' not in content assert 'field_a' not in content
class TestGuardedQueryset(TestCase): def test_guarded_queryset(self):
def test_guarded_queryset(self): class QuerysetAccessError(generics.ListAPIView):
class QuerysetAccessError(generics.ListAPIView): queryset = BasicModel.objects.all()
queryset = BasicModel.objects.all() def get(self, request):
return Response(list(self.queryset))
def get(self, request):
return Response(list(self.queryset))
view = QuerysetAccessError.as_view() view = QuerysetAccessError.as_view()
request = factory.get('/') request=factory.get('/')
with pytest.raises(RuntimeError): withpytest.raises(RuntimeError):
view(request).render() view(request).render()
class ApiViewsTests(TestCase):
def test_create_api_view_post(self): def test_create_api_view_post(self):
class MockCreateApiView(generics.CreateAPIView): class MockCreateApiView(generics.CreateAPIView):
def create(self, request, *args, **kwargs): def create(self, request, *args, **kwargs):
self.called = True self.called = True
self.call_args = (request, args, kwargs) self.call_args = (request, args, kwargs)
view = MockCreateApiView() view = MockCreateApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'}) data=('test request',('test arg',),{'test_kwarg':'test'})
view.post('test request', 'test arg', test_kwarg='test') view.post('test request','test arg',test_kwarg='test')
assert view.called is True assertview.calledisTrue
assert view.call_args == data assertview.call_args==data
def test_destroy_api_view_delete(self): def test_destroy_api_view_delete(self):
class MockDestroyApiView(generics.DestroyAPIView): class MockDestroyApiView(generics.DestroyAPIView):
def destroy(self, request, *args, **kwargs): def destroy(self, request, *args, **kwargs):
self.called = True self.called = True
self.call_args = (request, args, kwargs) self.call_args = (request, args, kwargs)
view = MockDestroyApiView() view = MockDestroyApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'}) data=('test request',('test arg',),{'test_kwarg':'test'})
view.delete('test request', 'test arg', test_kwarg='test') view.delete('test request','test arg',test_kwarg='test')
assert view.called is True assertview.calledisTrue
assert view.call_args == data assertview.call_args==data
def test_update_api_view_partial_update(self): def test_update_api_view_partial_update(self):
class MockUpdateApiView(generics.UpdateAPIView): class MockUpdateApiView(generics.UpdateAPIView):
def partial_update(self, request, *args, **kwargs): def partial_update(self, request, *args, **kwargs):
self.called = True self.called = True
self.call_args = (request, args, kwargs) self.call_args = (request, args, kwargs)
view = MockUpdateApiView() view = MockUpdateApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'}) data=('test request',('test arg',),{'test_kwarg':'test'})
view.patch('test request', 'test arg', test_kwarg='test') view.patch('test request','test arg',test_kwarg='test')
assert view.called is True assertview.calledisTrue
assert view.call_args == data assertview.call_args==data
def test_retrieve_update_api_view_get(self): def test_retrieve_update_api_view_get(self):
class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView): class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView):
def retrieve(self, request, *args, **kwargs): def retrieve(self, request, *args, **kwargs):
self.called = True self.called = True
self.call_args = (request, args, kwargs) self.call_args = (request, args, kwargs)
view = MockRetrieveUpdateApiView() view = MockRetrieveUpdateApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'}) data=('test request',('test arg',),{'test_kwarg':'test'})
view.get('test request', 'test arg', test_kwarg='test') view.get('test request','test arg',test_kwarg='test')
assert view.called is True assertview.calledisTrue
assert view.call_args == data assertview.call_args==data
def test_retrieve_update_api_view_put(self): def test_retrieve_update_api_view_put(self):
class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView): class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView):
def update(self, request, *args, **kwargs): def update(self, request, *args, **kwargs):
self.called = True self.called = True
self.call_args = (request, args, kwargs) self.call_args = (request, args, kwargs)
view = MockRetrieveUpdateApiView() view = MockRetrieveUpdateApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'}) data=('test request',('test arg',),{'test_kwarg':'test'})
view.put('test request', 'test arg', test_kwarg='test') view.put('test request','test arg',test_kwarg='test')
assert view.called is True assertview.calledisTrue
assert view.call_args == data assertview.call_args==data
def test_retrieve_update_api_view_patch(self): def test_retrieve_update_api_view_patch(self):
class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView): class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView):
def partial_update(self, request, *args, **kwargs): def partial_update(self, request, *args, **kwargs):
self.called = True self.called = True
self.call_args = (request, args, kwargs) self.call_args = (request, args, kwargs)
view = MockRetrieveUpdateApiView() view = MockRetrieveUpdateApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'}) data=('test request',('test arg',),{'test_kwarg':'test'})
view.patch('test request', 'test arg', test_kwarg='test') view.patch('test request','test arg',test_kwarg='test')
assert view.called is True assertview.calledisTrue
assert view.call_args == data assertview.call_args==data
def test_retrieve_destroy_api_view_get(self): def test_retrieve_destroy_api_view_get(self):
class MockRetrieveDestroyUApiView(generics.RetrieveDestroyAPIView): class MockRetrieveDestroyUApiView(generics.RetrieveDestroyAPIView):
def retrieve(self, request, *args, **kwargs): def retrieve(self, request, *args, **kwargs):
self.called = True self.called = True
self.call_args = (request, args, kwargs) self.call_args = (request, args, kwargs)
view = MockRetrieveDestroyUApiView() view = MockRetrieveDestroyUApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'}) data=('test request',('test arg',),{'test_kwarg':'test'})
view.get('test request', 'test arg', test_kwarg='test') view.get('test request','test arg',test_kwarg='test')
assert view.called is True assertview.calledisTrue
assert view.call_args == data assertview.call_args==data
def test_retrieve_destroy_api_view_delete(self): def test_retrieve_destroy_api_view_delete(self):
class MockRetrieveDestroyUApiView(generics.RetrieveDestroyAPIView): class MockRetrieveDestroyUApiView(generics.RetrieveDestroyAPIView):
def destroy(self, request, *args, **kwargs): def destroy(self, request, *args, **kwargs):
self.called = True self.called = True
self.call_args = (request, args, kwargs) self.call_args = (request, args, kwargs)
view = MockRetrieveDestroyUApiView() view = MockRetrieveDestroyUApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'}) data=('test request',('test arg',),{'test_kwarg':'test'})
view.delete('test request', 'test arg', test_kwarg='test') view.delete('test request','test arg',test_kwarg='test')
assert view.called is True assertview.calledisTrue
assert view.call_args == data assertview.call_args==data
class GetObjectOr404Tests(TestCase): def setUp(self):
def setUp(self): super().setUp()
super().setUp() self.uuid_object = UUIDForeignKeyTarget.objects.create(name='bar')
self.uuid_object = UUIDForeignKeyTarget.objects.create(name='bar')
def test_get_object_or_404_with_valid_uuid(self): def test_get_object_or_404_with_valid_uuid(self):
obj = generics.get_object_or_404( obj = generics.get_object_or_404( UUIDForeignKeyTarget, pk=self.uuid_object.pk )
UUIDForeignKeyTarget, pk=self.uuid_object.pk assert obj == self.uuid_object
)
assert obj == self.uuid_object
def test_get_object_or_404_with_invalid_string_for_uuid(self): def test_get_object_or_404_with_invalid_string_for_uuid(self):
with pytest.raises(Http404): with pytest.raises(Http404):
generics.get_object_or_404(UUIDForeignKeyTarget, pk='not-a-uuid') generics.get_object_or_404(UUIDForeignKeyTarget, pk='not-a-uuid')

View File

@ -42,122 +42,113 @@ urlpatterns = [
@override_settings(ROOT_URLCONF='tests.test_htmlrenderer') @override_settings(ROOT_URLCONF='tests.test_htmlrenderer')
class TemplateHTMLRendererTests(TestCase): def setUp(self):
def setUp(self): class MockResponse:
class MockResponse: template_name = None
template_name = None
self.mock_response = MockResponse() self.mock_response = MockResponse()
self._monkey_patch_get_template() self._monkey_patch_get_template()
def _monkey_patch_get_template(self): def _monkey_patch_get_template(self):
""" """
Monkeypatch get_template Monkeypatch get_template
""" """
self.get_template = django.template.loader.get_template self.get_template = django.template.loader.get_template
def get_template(template_name, dirs=None):
def get_template(template_name, dirs=None): if template_name == 'example.html':
if template_name == 'example.html': return engines['django'].from_string("example: {{ object }}")
return engines['django'].from_string("example: {{ object }}")
raise TemplateDoesNotExist(template_name) raise TemplateDoesNotExist(template_name)
def select_template(template_name_list, dirs=None, using=None): def select_template(template_name_list, dirs=None, using=None):
if template_name_list == ['example.html']: if template_name_list == ['example.html']:
return engines['django'].from_string("example: {{ object }}") return engines['django'].from_string("example: {{ object }}")
raise TemplateDoesNotExist(template_name_list[0]) raise TemplateDoesNotExist(template_name_list[0])
django.template.loader.get_template = get_template django.template.loader.get_template = get_template
django.template.loader.select_template = select_template django.template.loader.select_template=select_template
def tearDown(self): def tearDown(self):
""" """
Revert monkeypatching Revert monkeypatching
""" """
django.template.loader.get_template = self.get_template django.template.loader.get_template = self.get_template
def test_simple_html_view(self): def test_simple_html_view(self):
response = self.client.get('/') response = self.client.get('/')
self.assertContains(response, "example: foobar") self.assertContains(response, "example: foobar")
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') assert response['Content-Type'] == 'text/html; charset=utf-8'
def test_not_found_html_view(self): def test_not_found_html_view(self):
response = self.client.get('/not_found') response = self.client.get('/not_found')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) assert response.status_code == status.HTTP_404_NOT_FOUND
self.assertEqual(response.content, b"404 Not Found") assert response.content == b"404 Not Found"
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') assert response['Content-Type'] == 'text/html; charset=utf-8'
def test_permission_denied_html_view(self): def test_permission_denied_html_view(self):
response = self.client.get('/permission_denied') response = self.client.get('/permission_denied')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) assert response.status_code == status.HTTP_403_FORBIDDEN
self.assertEqual(response.content, b"403 Forbidden") assert response.content == b"403 Forbidden"
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') assert response['Content-Type'] == 'text/html; charset=utf-8'
# 2 tests below are based on order of if statements in corresponding method # 2 tests below are based on order of if statements in corresponding method
# of TemplateHTMLRenderer # of TemplateHTMLRenderer
def test_get_template_names_returns_own_template_name(self): def test_get_template_names_returns_own_template_name(self):
renderer = TemplateHTMLRenderer() renderer = TemplateHTMLRenderer()
renderer.template_name = 'test_template' renderer.template_name = 'test_template'
template_name = renderer.get_template_names(self.mock_response, view={}) template_name = renderer.get_template_names(self.mock_response, view={})
assert template_name == ['test_template'] assert template_name == ['test_template']
def test_get_template_names_returns_view_template_name(self): def test_get_template_names_returns_view_template_name(self):
renderer = TemplateHTMLRenderer() renderer = TemplateHTMLRenderer()
class MockResponse:
class MockResponse: template_name = None
template_name = None
class MockView: class MockView:
def get_template_names(self): def get_template_names(self):
return ['template from get_template_names method'] return ['template from get_template_names method']
class MockView2: class MockView2:
template_name = 'template from template_name attribute' template_name = 'template from template_name attribute'
template_name = renderer.get_template_names(self.mock_response, template_name = renderer.get_template_names(self.mock_response,MockView())
MockView()) asserttemplate_name==['template from get_template_names method']
assert template_name == ['template from get_template_names method'] template_name=renderer.get_template_names(self.mock_response,MockView2())
asserttemplate_name==['template from template_name attribute']
template_name = renderer.get_template_names(self.mock_response,
MockView2())
assert template_name == ['template from template_name attribute']
def test_get_template_names_raises_error_if_no_template_found(self): def test_get_template_names_raises_error_if_no_template_found(self):
renderer = TemplateHTMLRenderer() renderer = TemplateHTMLRenderer()
with pytest.raises(ImproperlyConfigured): with pytest.raises(ImproperlyConfigured):
renderer.get_template_names(self.mock_response, view=object()) renderer.get_template_names(self.mock_response, view=object())
@override_settings(ROOT_URLCONF='tests.test_htmlrenderer') @override_settings(ROOT_URLCONF='tests.test_htmlrenderer')
class TemplateHTMLRendererExceptionTests(TestCase): def setUp(self):
def setUp(self): """
"""
Monkeypatch get_template Monkeypatch get_template
""" """
self.get_template = django.template.loader.get_template self.get_template = django.template.loader.get_template
def get_template(template_name):
def get_template(template_name): if template_name == '404.html':
if template_name == '404.html': return engines['django'].from_string("404: {{ detail }}")
return engines['django'].from_string("404: {{ detail }}")
if template_name == '403.html': if template_name == '403.html':
return engines['django'].from_string("403: {{ detail }}") return engines['django'].from_string("403: {{ detail }}")
raise TemplateDoesNotExist(template_name) raise TemplateDoesNotExist(template_name)
django.template.loader.get_template = get_template django.template.loader.get_template = get_template
def tearDown(self): def tearDown(self):
""" """
Revert monkeypatching Revert monkeypatching
""" """
django.template.loader.get_template = self.get_template django.template.loader.get_template = self.get_template
def test_not_found_html_view_with_template(self): def test_not_found_html_view_with_template(self):
response = self.client.get('/not_found') response = self.client.get('/not_found')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) assert response.status_code == status.HTTP_404_NOT_FOUND
self.assertTrue(response.content in ( assert response.content in ( b"404: Not found", b"404 Not Found")
b"404: Not found", b"404 Not Found")) assert response['Content-Type'] == 'text/html; charset=utf-8'
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
def test_permission_denied_html_view_with_template(self): def test_permission_denied_html_view_with_template(self):
response = self.client.get('/permission_denied') response = self.client.get('/permission_denied')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) assert response.status_code == status.HTTP_403_FORBIDDEN
self.assertTrue(response.content in (b"403: Permission denied", b"403 Forbidden")) assert response.content in (b"403: Permission denied", b"403 Forbidden")
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') assert response['Content-Type'] == 'text/html; charset=utf-8'

View File

@ -34,16 +34,15 @@ urlpatterns = [
@override_settings(ROOT_URLCONF='tests.test_lazy_hyperlinks') @override_settings(ROOT_URLCONF='tests.test_lazy_hyperlinks')
class TestLazyHyperlinkNames(TestCase): def setUp(self):
def setUp(self): self.example = Example.objects.create(text='foo')
self.example = Example.objects.create(text='foo')
def test_lazy_hyperlink_names(self): def test_lazy_hyperlink_names(self):
global str_called global str_called
context = {'request': None} context = {'request': None}
serializer = ExampleSerializer(self.example, context=context) serializer = ExampleSerializer(self.example, context=context)
JSONRenderer().render(serializer.data) JSONRenderer().render(serializer.data)
assert not str_called assert not str_called
hyperlink_string = format_value(serializer.data['url']) hyperlink_string = format_value(serializer.data['url'])
assert hyperlink_string == '<a href=/example/1/>An example</a>' assert hyperlink_string == '<a href=/example/1/>An example</a>'
assert str_called assert str_called

View File

@ -308,98 +308,49 @@ class TestMetadata:
assert options.get_serializer_info(list_serializer) == options.get_serializer_info(child_serializer) assert options.get_serializer_info(list_serializer) == options.get_serializer_info(child_serializer)
class TestSimpleMetadataFieldInfo(TestCase): def test_null_boolean_field_info_type(self):
def test_null_boolean_field_info_type(self): options = metadata.SimpleMetadata()
options = metadata.SimpleMetadata() field_info = options.get_field_info(serializers.NullBooleanField())
field_info = options.get_field_info(serializers.NullBooleanField()) assert field_info['type'] == 'boolean'
assert field_info['type'] == 'boolean'
def test_related_field_choices(self): def test_related_field_choices(self):
options = metadata.SimpleMetadata() options = metadata.SimpleMetadata()
BasicModel.objects.create() BasicModel.objects.create()
with self.assertNumQueries(0): with self.assertNumQueries(0):
field_info = options.get_field_info( field_info = options.get_field_info( serializers.RelatedField(queryset=BasicModel.objects.all()) )
serializers.RelatedField(queryset=BasicModel.objects.all())
)
assert 'choices' not in field_info assert 'choices' not in field_info
class TestModelSerializerMetadata(TestCase): def test_read_only_primary_key_related_field(self):
def test_read_only_primary_key_related_field(self): """
"""
On generic views OPTIONS should return an 'actions' key with metadata On generic views OPTIONS should return an 'actions' key with metadata
on the fields that may be supplied to PUT and POST requests. It should on the fields that may be supplied to PUT and POST requests. It should
not fail when a read_only PrimaryKeyRelatedField is present not fail when a read_only PrimaryKeyRelatedField is present
""" """
class Parent(models.Model): class Parent(models.Model):
integer_field = models.IntegerField(validators=[MinValueValidator(1), MaxValueValidator(1000)]) integer_field = models.IntegerField(validators=[MinValueValidator(1), MaxValueValidator(1000)])
children = models.ManyToManyField('Child') children = models.ManyToManyField('Child')
name = models.CharField(max_length=100, blank=True, null=True) name = models.CharField(max_length=100, blank=True, null=True)
class Child(models.Model): class Child(models.Model):
name = models.CharField(max_length=100) name = models.CharField(max_length=100)
class ExampleSerializer(serializers.ModelSerializer): class ExampleSerializer(serializers.ModelSerializer):
children = serializers.PrimaryKeyRelatedField(read_only=True, many=True) children = serializers.PrimaryKeyRelatedField(read_only=True, many=True)
class Meta:
class Meta: model = Parent
model = Parent fields = '__all__'
fields = '__all__'
class ExampleView(views.APIView): class ExampleView(views.APIView):
"""Example view.""" """Example view."""
def post(self, request): def post(self, request):
pass pass
def get_serializer(self): def get_serializer(self):
return ExampleSerializer() return ExampleSerializer()
view = ExampleView.as_view() view = ExampleView.as_view()
response = view(request=request) response=view(request=request)
expected = { expected={'name':'Example','description':'Example view.','renders':['application/json','text/html'],'parses':['application/json','application/x-www-form-urlencoded','multipart/form-data'],'actions':{'POST':{'id':{'type':'integer','required':False,'read_only':True,'label':'ID'},'children':{'type':'field','required':False,'read_only':True,'label':'Children'},'integer_field':{'type':'integer','required':True,'read_only':False,'label':'Integer field','min_value':1,'max_value':1000},'name':{'type':'string','required':False,'read_only':False,'label':'Name','max_length':100}}}}
'name': 'Example', assertresponse.status_code==status.HTTP_200_OK
'description': 'Example view.', assertresponse.data==expected
'renders': [
'application/json',
'text/html'
],
'parses': [
'application/json',
'application/x-www-form-urlencoded',
'multipart/form-data'
],
'actions': {
'POST': {
'id': {
'type': 'integer',
'required': False,
'read_only': True,
'label': 'ID'
},
'children': {
'type': 'field',
'required': False,
'read_only': True,
'label': 'Children'
},
'integer_field': {
'type': 'integer',
'required': True,
'read_only': False,
'label': 'Integer field',
'min_value': 1,
'max_value': 1000
},
'name': {
'type': 'string',
'required': False,
'read_only': False,
'label': 'Name',
'max_length': 100
}
}
}
}
assert response.status_code == status.HTTP_200_OK
assert response.data == expected

File diff suppressed because it is too large Load Diff

View File

@ -33,35 +33,31 @@ class AssociatedModelSerializer(serializers.ModelSerializer):
# Tests # Tests
class InheritedModelSerializationTests(TestCase):
def test_multitable_inherited_model_fields_as_expected(self): def test_multitable_inherited_model_fields_as_expected(self):
""" """
Assert that the parent pointer field is not included in the fields Assert that the parent pointer field is not included in the fields
serialized fields serialized fields
""" """
child = ChildModel(name1='parent name', name2='child name') child = ChildModel(name1='parent name', name2='child name')
serializer = DerivedModelSerializer(child) serializer = DerivedModelSerializer(child)
assert set(serializer.data) == {'name1', 'name2', 'id'} assert set(serializer.data) == {'name1', 'name2', 'id'}
def test_onetoone_primary_key_model_fields_as_expected(self): def test_onetoone_primary_key_model_fields_as_expected(self):
""" """
Assert that a model with a onetoone field that is the primary key is Assert that a model with a onetoone field that is the primary key is
not treated like a derived model not treated like a derived model
""" """
parent = ParentModel.objects.create(name1='parent name') parent = ParentModel.objects.create(name1='parent name')
associate = AssociatedModel.objects.create(name='hello', ref=parent) associate = AssociatedModel.objects.create(name='hello', ref=parent)
serializer = AssociatedModelSerializer(associate) serializer = AssociatedModelSerializer(associate)
assert set(serializer.data) == {'name', 'ref'} assert set(serializer.data) == {'name', 'ref'}
def test_data_is_valid_without_parent_ptr(self): def test_data_is_valid_without_parent_ptr(self):
""" """
Assert that the pointer to the parent table is not a required field Assert that the pointer to the parent table is not a required field
for input data for input data
""" """
data = { data = { 'name1': 'parent name', 'name2': 'child name', }
'name1': 'parent name', serializer = DerivedModelSerializer(data=data)
'name2': 'child name', assert serializer.is_valid() is True
}
serializer = DerivedModelSerializer(data=data)
assert serializer.is_valid() is True

View File

@ -30,70 +30,68 @@ class NoCharsetSpecifiedRenderer(BaseRenderer):
media_type = 'my/media' media_type = 'my/media'
class TestAcceptedMediaType(TestCase): def setUp(self):
def setUp(self): self.renderers = [MockJSONRenderer(), MockHTMLRenderer(), MockOpenAPIRenderer()]
self.renderers = [MockJSONRenderer(), MockHTMLRenderer(), MockOpenAPIRenderer()] self.negotiator = DefaultContentNegotiation()
self.negotiator = DefaultContentNegotiation()
def select_renderer(self, request): def select_renderer(self, request):
return self.negotiator.select_renderer(request, self.renderers) return self.negotiator.select_renderer(request, self.renderers)
def test_client_without_accept_use_renderer(self): def test_client_without_accept_use_renderer(self):
request = Request(factory.get('/')) request = Request(factory.get('/'))
accepted_renderer, accepted_media_type = self.select_renderer(request) accepted_renderer, accepted_media_type = self.select_renderer(request)
assert accepted_media_type == 'application/json' assert accepted_media_type == 'application/json'
def test_client_underspecifies_accept_use_renderer(self): def test_client_underspecifies_accept_use_renderer(self):
request = Request(factory.get('/', HTTP_ACCEPT='*/*')) request = Request(factory.get('/', HTTP_ACCEPT='*/*'))
accepted_renderer, accepted_media_type = self.select_renderer(request) accepted_renderer, accepted_media_type = self.select_renderer(request)
assert accepted_media_type == 'application/json' assert accepted_media_type == 'application/json'
def test_client_overspecifies_accept_use_client(self): def test_client_overspecifies_accept_use_client(self):
request = Request(factory.get('/', HTTP_ACCEPT='application/json; indent=8')) request = Request(factory.get('/', HTTP_ACCEPT='application/json; indent=8'))
accepted_renderer, accepted_media_type = self.select_renderer(request) accepted_renderer, accepted_media_type = self.select_renderer(request)
assert accepted_media_type == 'application/json; indent=8' assert accepted_media_type == 'application/json; indent=8'
def test_client_specifies_parameter(self): def test_client_specifies_parameter(self):
request = Request(factory.get('/', HTTP_ACCEPT='application/openapi+json;version=2.0')) request = Request(factory.get('/', HTTP_ACCEPT='application/openapi+json;version=2.0'))
accepted_renderer, accepted_media_type = self.select_renderer(request) accepted_renderer, accepted_media_type = self.select_renderer(request)
assert accepted_media_type == 'application/openapi+json;version=2.0' assert accepted_media_type == 'application/openapi+json;version=2.0'
assert accepted_renderer.format == 'swagger' assert accepted_renderer.format == 'swagger'
def test_match_is_false_if_main_types_not_match(self): def test_match_is_false_if_main_types_not_match(self):
mediatype = _MediaType('test_1') mediatype = _MediaType('test_1')
anoter_mediatype = _MediaType('test_2') anoter_mediatype = _MediaType('test_2')
assert mediatype.match(anoter_mediatype) is False assert mediatype.match(anoter_mediatype) is False
def test_mediatype_match_is_false_if_keys_not_match(self): def test_mediatype_match_is_false_if_keys_not_match(self):
mediatype = _MediaType(';test_param=foo') mediatype = _MediaType(';test_param=foo')
another_mediatype = _MediaType(';test_param=bar') another_mediatype = _MediaType(';test_param=bar')
assert mediatype.match(another_mediatype) is False assert mediatype.match(another_mediatype) is False
def test_mediatype_precedence_with_wildcard_subtype(self): def test_mediatype_precedence_with_wildcard_subtype(self):
mediatype = _MediaType('test/*') mediatype = _MediaType('test/*')
assert mediatype.precedence == 1 assert mediatype.precedence == 1
def test_mediatype_string_representation(self): def test_mediatype_string_representation(self):
mediatype = _MediaType('test/*; foo=bar') mediatype = _MediaType('test/*; foo=bar')
assert str(mediatype) == 'test/*; foo=bar' assert str(mediatype) == 'test/*; foo=bar'
def test_raise_error_if_no_suitable_renderers_found(self): def test_raise_error_if_no_suitable_renderers_found(self):
class MockRenderer: class MockRenderer:
format = 'xml' format = 'xml'
renderers = [MockRenderer()] renderers = [MockRenderer()]
with pytest.raises(Http404): withpytest.raises(Http404):
self.negotiator.filter_renderers(renderers, format='json') self.negotiator.filter_renderers(renderers, format='json')
class BaseContentNegotiationTests(TestCase):
def setUp(self): def setUp(self):
self.negotiator = BaseContentNegotiation() self.negotiator = BaseContentNegotiation()
def test_raise_error_for_abstract_select_parser_method(self): def test_raise_error_for_abstract_select_parser_method(self):
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
self.negotiator.select_parser(None, None) self.negotiator.select_parser(None, None)
def test_raise_error_for_abstract_select_renderer_method(self): def test_raise_error_for_abstract_select_renderer_method(self):
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
self.negotiator.select_renderer(None, None) self.negotiator.select_renderer(None, None)

View File

@ -28,14 +28,12 @@ class ChildAssociatedModelSerializer(serializers.ModelSerializer):
# Tests # Tests
class InheritedModelSerializationTests(TestCase):
def test_multitable_inherited_model_fields_as_expected(self): def test_multitable_inherited_model_fields_as_expected(self):
""" """
Assert that the parent pointer field is not included in the fields Assert that the parent pointer field is not included in the fields
serialized fields serialized fields
""" """
child = ChildModel(name1='parent name', name2='child name') child = ChildModel(name1='parent name', name2='child name')
serializer = DerivedModelSerializer(child) serializer = DerivedModelSerializer(child)
self.assertEqual(set(serializer.data), assert set(serializer.data) == {'name1', 'name2', 'id', 'childassociatedmodel'}
{'name1', 'name2', 'id', 'childassociatedmodel'})

View File

@ -22,157 +22,138 @@ class Form(forms.Form):
field2 = forms.CharField() field2 = forms.CharField()
class TestFormParser(TestCase): def setUp(self):
def setUp(self): self.string = "field1=abc&field2=defghijk"
self.string = "field1=abc&field2=defghijk"
def test_parse(self): def test_parse(self):
""" Make sure the `QueryDict` works OK """ """ Make sure the `QueryDict` works OK """
parser = FormParser() parser = FormParser()
stream = io.StringIO(self.string)
stream = io.StringIO(self.string) data = parser.parse(stream)
data = parser.parse(stream) assert Form(data).is_valid() is True
assert Form(data).is_valid() is True
class TestFileUploadParser(TestCase): def setUp(self):
def setUp(self): class MockRequest:
class MockRequest: pass
pass
self.stream = io.BytesIO(b"Test text file") self.stream = io.BytesIO(b"Test text file")
request = MockRequest() request=MockRequest()
request.upload_handlers = (MemoryFileUploadHandler(),) request.upload_handlers=(MemoryFileUploadHandler(),)
request.META = { request.META={'HTTP_CONTENT_DISPOSITION':'Content-Disposition: inline; filename=file.txt','HTTP_CONTENT_LENGTH':14,}
'HTTP_CONTENT_DISPOSITION': 'Content-Disposition: inline; filename=file.txt', self.parser_context={'request':request,'kwargs':{}}
'HTTP_CONTENT_LENGTH': 14,
}
self.parser_context = {'request': request, 'kwargs': {}}
def test_parse(self): def test_parse(self):
""" """
Parse raw file upload. Parse raw file upload.
""" """
parser = FileUploadParser() parser = FileUploadParser()
self.stream.seek(0) self.stream.seek(0)
data_and_files = parser.parse(self.stream, None, self.parser_context) data_and_files = parser.parse(self.stream, None, self.parser_context)
file_obj = data_and_files.files['file'] file_obj = data_and_files.files['file']
assert file_obj.size == 14 assert file_obj.size == 14
def test_parse_missing_filename(self): def test_parse_missing_filename(self):
""" """
Parse raw file upload when filename is missing. Parse raw file upload when filename is missing.
""" """
parser = FileUploadParser() parser = FileUploadParser()
self.stream.seek(0) self.stream.seek(0)
self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = '' self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = ''
with pytest.raises(ParseError) as excinfo: with pytest.raises(ParseError) as excinfo:
parser.parse(self.stream, None, self.parser_context) parser.parse(self.stream, None, self.parser_context)
assert str(excinfo.value) == 'Missing filename. Request should include a Content-Disposition header with a filename parameter.' assert str(excinfo.value) == 'Missing filename. Request should include a Content-Disposition header with a filename parameter.'
def test_parse_missing_filename_multiple_upload_handlers(self): def test_parse_missing_filename_multiple_upload_handlers(self):
""" """
Parse raw file upload with multiple handlers when filename is missing. Parse raw file upload with multiple handlers when filename is missing.
Regression test for #2109. Regression test for #2109.
""" """
parser = FileUploadParser() parser = FileUploadParser()
self.stream.seek(0) self.stream.seek(0)
self.parser_context['request'].upload_handlers = ( self.parser_context['request'].upload_handlers = ( MemoryFileUploadHandler(), MemoryFileUploadHandler() )
MemoryFileUploadHandler(), self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = ''
MemoryFileUploadHandler() with pytest.raises(ParseError) as excinfo:
) parser.parse(self.stream, None, self.parser_context)
self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = ''
with pytest.raises(ParseError) as excinfo:
parser.parse(self.stream, None, self.parser_context)
assert str(excinfo.value) == 'Missing filename. Request should include a Content-Disposition header with a filename parameter.' assert str(excinfo.value) == 'Missing filename. Request should include a Content-Disposition header with a filename parameter.'
def test_parse_missing_filename_large_file(self): def test_parse_missing_filename_large_file(self):
""" """
Parse raw file upload when filename is missing with TemporaryFileUploadHandler. Parse raw file upload when filename is missing with TemporaryFileUploadHandler.
""" """
parser = FileUploadParser() parser = FileUploadParser()
self.stream.seek(0) self.stream.seek(0)
self.parser_context['request'].upload_handlers = ( self.parser_context['request'].upload_handlers = ( TemporaryFileUploadHandler(), )
TemporaryFileUploadHandler(), self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = ''
) with pytest.raises(ParseError) as excinfo:
self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = '' parser.parse(self.stream, None, self.parser_context)
with pytest.raises(ParseError) as excinfo:
parser.parse(self.stream, None, self.parser_context)
assert str(excinfo.value) == 'Missing filename. Request should include a Content-Disposition header with a filename parameter.' assert str(excinfo.value) == 'Missing filename. Request should include a Content-Disposition header with a filename parameter.'
def test_get_filename(self): def test_get_filename(self):
parser = FileUploadParser() parser = FileUploadParser()
filename = parser.get_filename(self.stream, None, self.parser_context) filename = parser.get_filename(self.stream, None, self.parser_context)
assert filename == 'file.txt' assert filename == 'file.txt'
def test_get_encoded_filename(self): def test_get_encoded_filename(self):
parser = FileUploadParser() parser = FileUploadParser()
self.__replace_content_disposition('inline; filename*=utf-8\'\'ÀĥƦ.txt')
self.__replace_content_disposition('inline; filename*=utf-8\'\'ÀĥƦ.txt') filename = parser.get_filename(self.stream, None, self.parser_context)
filename = parser.get_filename(self.stream, None, self.parser_context) assert filename == 'ÀĥƦ.txt'
assert filename == 'ÀĥƦ.txt' self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'\'ÀĥƦ.txt')
filename = parser.get_filename(self.stream, None, self.parser_context)
self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'\'ÀĥƦ.txt') assert filename == 'ÀĥƦ.txt'
filename = parser.get_filename(self.stream, None, self.parser_context) self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'en-us\'ÀĥƦ.txt')
assert filename == 'ÀĥƦ.txt' filename = parser.get_filename(self.stream, None, self.parser_context)
assert filename == 'ÀĥƦ.txt'
self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'en-us\'ÀĥƦ.txt')
filename = parser.get_filename(self.stream, None, self.parser_context)
assert filename == 'ÀĥƦ.txt'
def __replace_content_disposition(self, disposition): def __replace_content_disposition(self, disposition):
self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = disposition self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = disposition
class TestJSONParser(TestCase): def bytes(self, value):
def bytes(self, value): return io.BytesIO(value.encode())
return io.BytesIO(value.encode())
def test_float_strictness(self): def test_float_strictness(self):
parser = JSONParser() parser = JSONParser()
for value in ['Infinity', '-Infinity', 'NaN']:
# Default to strict with pytest.raises(ParseError):
for value in ['Infinity', '-Infinity', 'NaN']: parser.parse(self.bytes(value))
with pytest.raises(ParseError):
parser.parse(self.bytes(value))
parser.strict = False parser.strict = False
assert parser.parse(self.bytes('Infinity')) == float('inf') assertparser.parse(self.bytes('Infinity'))==float('inf')
assert parser.parse(self.bytes('-Infinity')) == float('-inf') assertparser.parse(self.bytes('-Infinity'))==float('-inf')
assert math.isnan(parser.parse(self.bytes('NaN'))) assertmath.isnan(parser.parse(self.bytes('NaN')))
class TestPOSTAccessed(TestCase): def setUp(self):
def setUp(self): self.factory = APIRequestFactory()
self.factory = APIRequestFactory()
def test_post_accessed_in_post_method(self): def test_post_accessed_in_post_method(self):
django_request = self.factory.post('/', {'foo': 'bar'}) django_request = self.factory.post('/', {'foo': 'bar'})
request = Request(django_request, parsers=[FormParser(), MultiPartParser()]) request = Request(django_request, parsers=[FormParser(), MultiPartParser()])
django_request.POST django_request.POST
assert request.POST == {'foo': ['bar']} assert request.POST == {'foo': ['bar']}
assert request.data == {'foo': ['bar']} assert request.data == {'foo': ['bar']}
def test_post_accessed_in_post_method_with_json_parser(self): def test_post_accessed_in_post_method_with_json_parser(self):
django_request = self.factory.post('/', {'foo': 'bar'}) django_request = self.factory.post('/', {'foo': 'bar'})
request = Request(django_request, parsers=[JSONParser()]) request = Request(django_request, parsers=[JSONParser()])
django_request.POST django_request.POST
assert request.POST == {} assert request.POST == {}
assert request.data == {} assert request.data == {}
def test_post_accessed_in_put_method(self): def test_post_accessed_in_put_method(self):
django_request = self.factory.put('/', {'foo': 'bar'}) django_request = self.factory.put('/', {'foo': 'bar'})
request = Request(django_request, parsers=[FormParser(), MultiPartParser()]) request = Request(django_request, parsers=[FormParser(), MultiPartParser()])
django_request.POST django_request.POST
assert request.POST == {'foo': ['bar']} assert request.POST == {'foo': ['bar']}
assert request.data == {'foo': ['bar']} assert request.data == {'foo': ['bar']}
def test_request_read_before_parsing(self): def test_request_read_before_parsing(self):
django_request = self.factory.put('/', {'foo': 'bar'}) django_request = self.factory.put('/', {'foo': 'bar'})
request = Request(django_request, parsers=[FormParser(), MultiPartParser()]) request = Request(django_request, parsers=[FormParser(), MultiPartParser()])
django_request.read() django_request.read()
with pytest.raises(RawPostDataException):
request.POST
with pytest.raises(RawPostDataException): with pytest.raises(RawPostDataException):
request.POST request.POST
with pytest.raises(RawPostDataException): request.data
request.POST
request.data

View File

@ -72,177 +72,136 @@ def basic_auth_header(username, password):
return 'Basic %s' % base64_credentials return 'Basic %s' % base64_credentials
class ModelPermissionsIntegrationTests(TestCase): def setUp(self):
def setUp(self): User.objects.create_user('disallowed', 'disallowed@example.com', 'password')
User.objects.create_user('disallowed', 'disallowed@example.com', 'password') user = User.objects.create_user('permitted', 'permitted@example.com', 'password')
user = User.objects.create_user('permitted', 'permitted@example.com', 'password') user.user_permissions.set([ Permission.objects.get(codename='add_basicmodel'), Permission.objects.get(codename='change_basicmodel'), Permission.objects.get(codename='delete_basicmodel') ])
user.user_permissions.set([ user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password')
Permission.objects.get(codename='add_basicmodel'), user.user_permissions.set([ Permission.objects.get(codename='change_basicmodel'), ])
Permission.objects.get(codename='change_basicmodel'), self.permitted_credentials = basic_auth_header('permitted', 'password')
Permission.objects.get(codename='delete_basicmodel') self.disallowed_credentials = basic_auth_header('disallowed', 'password')
]) self.updateonly_credentials = basic_auth_header('updateonly', 'password')
BasicModel(text='foo').save()
user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password')
user.user_permissions.set([
Permission.objects.get(codename='change_basicmodel'),
])
self.permitted_credentials = basic_auth_header('permitted', 'password')
self.disallowed_credentials = basic_auth_header('disallowed', 'password')
self.updateonly_credentials = basic_auth_header('updateonly', 'password')
BasicModel(text='foo').save()
def test_has_create_permissions(self): def test_has_create_permissions(self):
request = factory.post('/', {'text': 'foobar'}, format='json', request = factory.post('/', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.permitted_credentials)
HTTP_AUTHORIZATION=self.permitted_credentials) response = root_view(request, pk=1)
response = root_view(request, pk=1) assert response.status_code == status.HTTP_201_CREATED
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
def test_api_root_view_discard_default_django_model_permission(self): def test_api_root_view_discard_default_django_model_permission(self):
""" """
We check that DEFAULT_PERMISSION_CLASSES can We check that DEFAULT_PERMISSION_CLASSES can
apply to APIRoot view. More specifically we check expected behavior of apply to APIRoot view. More specifically we check expected behavior of
``_ignore_model_permissions`` attribute support. ``_ignore_model_permissions`` attribute support.
""" """
request = factory.get('/', format='json', request = factory.get('/', format='json', HTTP_AUTHORIZATION=self.permitted_credentials)
HTTP_AUTHORIZATION=self.permitted_credentials) request.resolver_match = ResolverMatch('get', (), {})
request.resolver_match = ResolverMatch('get', (), {}) response = api_root_view(request)
response = api_root_view(request) assert response.status_code == status.HTTP_200_OK
self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_get_queryset_has_create_permissions(self): def test_get_queryset_has_create_permissions(self):
request = factory.post('/', {'text': 'foobar'}, format='json', request = factory.post('/', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.permitted_credentials)
HTTP_AUTHORIZATION=self.permitted_credentials) response = get_queryset_list_view(request, pk=1)
response = get_queryset_list_view(request, pk=1) assert response.status_code == status.HTTP_201_CREATED
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
def test_has_put_permissions(self): def test_has_put_permissions(self):
request = factory.put('/1', {'text': 'foobar'}, format='json', request = factory.put('/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.permitted_credentials)
HTTP_AUTHORIZATION=self.permitted_credentials) response = instance_view(request, pk='1')
response = instance_view(request, pk='1') assert response.status_code == status.HTTP_200_OK
self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_has_delete_permissions(self): def test_has_delete_permissions(self):
request = factory.delete('/1', HTTP_AUTHORIZATION=self.permitted_credentials) request = factory.delete('/1', HTTP_AUTHORIZATION=self.permitted_credentials)
response = instance_view(request, pk=1) response = instance_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) assert response.status_code == status.HTTP_204_NO_CONTENT
def test_does_not_have_create_permissions(self): def test_does_not_have_create_permissions(self):
request = factory.post('/', {'text': 'foobar'}, format='json', request = factory.post('/', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.disallowed_credentials)
HTTP_AUTHORIZATION=self.disallowed_credentials) response = root_view(request, pk=1)
response = root_view(request, pk=1) assert response.status_code == status.HTTP_403_FORBIDDEN
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_does_not_have_put_permissions(self): def test_does_not_have_put_permissions(self):
request = factory.put('/1', {'text': 'foobar'}, format='json', request = factory.put('/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.disallowed_credentials)
HTTP_AUTHORIZATION=self.disallowed_credentials) response = instance_view(request, pk='1')
response = instance_view(request, pk='1') assert response.status_code == status.HTTP_403_FORBIDDEN
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_does_not_have_delete_permissions(self): def test_does_not_have_delete_permissions(self):
request = factory.delete('/1', HTTP_AUTHORIZATION=self.disallowed_credentials) request = factory.delete('/1', HTTP_AUTHORIZATION=self.disallowed_credentials)
response = instance_view(request, pk=1) response = instance_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) assert response.status_code == status.HTTP_403_FORBIDDEN
def test_options_permitted(self): def test_options_permitted(self):
request = factory.options( request = factory.options( '/', HTTP_AUTHORIZATION=self.permitted_credentials )
'/', response = root_view(request, pk='1')
HTTP_AUTHORIZATION=self.permitted_credentials assert response.status_code == status.HTTP_200_OK
) assert 'actions' in response.data
response = root_view(request, pk='1') assert list(response.data['actions']) == ['POST']
self.assertEqual(response.status_code, status.HTTP_200_OK) request = factory.options( '/1', HTTP_AUTHORIZATION=self.permitted_credentials )
self.assertIn('actions', response.data) response = instance_view(request, pk='1')
self.assertEqual(list(response.data['actions']), ['POST']) assert response.status_code == status.HTTP_200_OK
assert 'actions' in response.data
request = factory.options( assert list(response.data['actions']) == ['PUT']
'/1',
HTTP_AUTHORIZATION=self.permitted_credentials
)
response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn('actions', response.data)
self.assertEqual(list(response.data['actions']), ['PUT'])
def test_options_disallowed(self): def test_options_disallowed(self):
request = factory.options( request = factory.options( '/', HTTP_AUTHORIZATION=self.disallowed_credentials )
'/', response = root_view(request, pk='1')
HTTP_AUTHORIZATION=self.disallowed_credentials assert response.status_code == status.HTTP_200_OK
) assert 'actions' not in response.data
response = root_view(request, pk='1') request = factory.options( '/1', HTTP_AUTHORIZATION=self.disallowed_credentials )
self.assertEqual(response.status_code, status.HTTP_200_OK) response = instance_view(request, pk='1')
self.assertNotIn('actions', response.data) assert response.status_code == status.HTTP_200_OK
assert 'actions' not in response.data
request = factory.options(
'/1',
HTTP_AUTHORIZATION=self.disallowed_credentials
)
response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertNotIn('actions', response.data)
def test_options_updateonly(self): def test_options_updateonly(self):
request = factory.options( request = factory.options( '/', HTTP_AUTHORIZATION=self.updateonly_credentials )
'/', response = root_view(request, pk='1')
HTTP_AUTHORIZATION=self.updateonly_credentials assert response.status_code == status.HTTP_200_OK
) assert 'actions' not in response.data
response = root_view(request, pk='1') request = factory.options( '/1', HTTP_AUTHORIZATION=self.updateonly_credentials )
self.assertEqual(response.status_code, status.HTTP_200_OK) response = instance_view(request, pk='1')
self.assertNotIn('actions', response.data) assert response.status_code == status.HTTP_200_OK
assert 'actions' in response.data
request = factory.options( assert list(response.data['actions']) == ['PUT']
'/1',
HTTP_AUTHORIZATION=self.updateonly_credentials
)
response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn('actions', response.data)
self.assertEqual(list(response.data['actions']), ['PUT'])
def test_empty_view_does_not_assert(self): def test_empty_view_does_not_assert(self):
request = factory.get('/1', HTTP_AUTHORIZATION=self.permitted_credentials) request = factory.get('/1', HTTP_AUTHORIZATION=self.permitted_credentials)
response = empty_list_view(request, pk=1) response = empty_list_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_200_OK) assert response.status_code == status.HTTP_200_OK
def test_calling_method_not_allowed(self): def test_calling_method_not_allowed(self):
request = factory.generic('METHOD_NOT_ALLOWED', '/', HTTP_AUTHORIZATION=self.permitted_credentials) request = factory.generic('METHOD_NOT_ALLOWED', '/', HTTP_AUTHORIZATION=self.permitted_credentials)
response = root_view(request) response = root_view(request)
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
request = factory.generic('METHOD_NOT_ALLOWED', '/1', HTTP_AUTHORIZATION=self.permitted_credentials)
request = factory.generic('METHOD_NOT_ALLOWED', '/1', HTTP_AUTHORIZATION=self.permitted_credentials) response = instance_view(request, pk='1')
response = instance_view(request, pk='1') assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
def test_check_auth_before_queryset_call(self): def test_check_auth_before_queryset_call(self):
class View(RootView): class View(RootView):
def get_queryset(_): def get_queryset(_):
self.fail('should not reach due to auth check') self.fail('should not reach due to auth check')
view = View.as_view() view = View.as_view()
request=factory.get('/',HTTP_AUTHORIZATION='')
request = factory.get('/', HTTP_AUTHORIZATION='') response=view(request)
response = view(request) assertresponse.status_code==status.HTTP_401_UNAUTHORIZED
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
def test_queryset_assertions(self): def test_queryset_assertions(self):
class View(views.APIView): class View(views.APIView):
authentication_classes = [authentication.BasicAuthentication] authentication_classes = [authentication.BasicAuthentication]
permission_classes = [permissions.DjangoModelPermissions] permission_classes = [permissions.DjangoModelPermissions]
view = View.as_view() view = View.as_view()
request=factory.get('/',HTTP_AUTHORIZATION=self.permitted_credentials)
request = factory.get('/', HTTP_AUTHORIZATION=self.permitted_credentials) msg='Cannot apply DjangoModelPermissions on a view that does not set `.queryset` or have a `.get_queryset()` method.'
msg = 'Cannot apply DjangoModelPermissions on a view that does not set `.queryset` or have a `.get_queryset()` method.' withself.assertRaisesMessage(AssertionError,msg):
with self.assertRaisesMessage(AssertionError, msg): view(request)
view(request)
# Faulty `get_queryset()` methods should trigger the above "view does not have a queryset" assertion. # Faulty `get_queryset()` methods should trigger the above "view does not have a queryset" assertion.
class View(RootView): class View(RootView):
def get_queryset(self): def get_queryset(self):
return None return None
view = View.as_view() view = View.as_view()
request=factory.get('/',HTTP_AUTHORIZATION=self.permitted_credentials)
request = factory.get('/', HTTP_AUTHORIZATION=self.permitted_credentials) withself.assertRaisesMessage(AssertionError,'View.get_queryset() returned None'):
with self.assertRaisesMessage(AssertionError, 'View.get_queryset() returned None'): view(request)
view(request)
class BasicPermModel(models.Model): class BasicPermModel(models.Model):
@ -310,149 +269,117 @@ get_queryset_object_permissions_view = GetQuerysetObjectPermissionInstanceView.a
@unittest.skipUnless(is_guardian_installed(), 'django-guardian not installed') @unittest.skipUnless(is_guardian_installed(), 'django-guardian not installed')
class ObjectPermissionsIntegrationTests(TestCase): """
"""
Integration tests for the object level permissions API. Integration tests for the object level permissions API.
""" """
def setUp(self): defsetUp(self):
from guardian.shortcuts import assign_perm from guardian.shortcuts import assign_perm
create = User.objects.create_user
# create users users = { 'fullaccess': create('fullaccess', 'fullaccess@example.com', 'password'), 'readonly': create('readonly', 'readonly@example.com', 'password'), 'writeonly': create('writeonly', 'writeonly@example.com', 'password'), 'deleteonly': create('deleteonly', 'deleteonly@example.com', 'password'), }
create = User.objects.create_user everyone = Group.objects.create(name='everyone')
users = { model_name = BasicPermModel._meta.model_name
'fullaccess': create('fullaccess', 'fullaccess@example.com', 'password'), app_label = BasicPermModel._meta.app_label
'readonly': create('readonly', 'readonly@example.com', 'password'), f = '{}_{}'.format
'writeonly': create('writeonly', 'writeonly@example.com', 'password'), perms = { 'view': f('view', model_name), 'change': f('change', model_name), 'delete': f('delete', model_name) }
'deleteonly': create('deleteonly', 'deleteonly@example.com', 'password'), for perm in perms.values():
} perm = '{}.{}'.format(app_label, perm)
assign_perm(perm, everyone)
# give everyone model level permissions, as we are not testing those
everyone = Group.objects.create(name='everyone')
model_name = BasicPermModel._meta.model_name
app_label = BasicPermModel._meta.app_label
f = '{}_{}'.format
perms = {
'view': f('view', model_name),
'change': f('change', model_name),
'delete': f('delete', model_name)
}
for perm in perms.values():
perm = '{}.{}'.format(app_label, perm)
assign_perm(perm, everyone)
everyone.user_set.add(*users.values()) everyone.user_set.add(*users.values())
readers=Group.objects.create(name='readers')
# appropriate object level permissions writers=Group.objects.create(name='writers')
readers = Group.objects.create(name='readers') deleters=Group.objects.create(name='deleters')
writers = Group.objects.create(name='writers') model=BasicPermModel.objects.create(text='foo')
deleters = Group.objects.create(name='deleters') assign_perm(perms['view'],readers,model)
assign_perm(perms['change'],writers,model)
model = BasicPermModel.objects.create(text='foo') assign_perm(perms['delete'],deleters,model)
readers.user_set.add(users['fullaccess'],users['readonly'])
assign_perm(perms['view'], readers, model) writers.user_set.add(users['fullaccess'],users['writeonly'])
assign_perm(perms['change'], writers, model) deleters.user_set.add(users['fullaccess'],users['deleteonly'])
assign_perm(perms['delete'], deleters, model) self.credentials={}
foruserinusers.values():
readers.user_set.add(users['fullaccess'], users['readonly']) self.credentials[user.username] = basic_auth_header(user.username, 'password')
writers.user_set.add(users['fullaccess'], users['writeonly'])
deleters.user_set.add(users['fullaccess'], users['deleteonly'])
self.credentials = {}
for user in users.values():
self.credentials[user.username] = basic_auth_header(user.username, 'password')
# Delete # Delete
def test_can_delete_permissions(self): def test_can_delete_permissions(self):
request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['deleteonly']) request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['deleteonly'])
response = object_permissions_view(request, pk='1') response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) assert response.status_code == status.HTTP_204_NO_CONTENT
def test_cannot_delete_permissions(self): def test_cannot_delete_permissions(self):
request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['readonly']) request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
response = object_permissions_view(request, pk='1') response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) assert response.status_code == status.HTTP_403_FORBIDDEN
# Update # Update
def test_can_update_permissions(self): def test_can_update_permissions(self):
request = factory.patch( request = factory.patch( '/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.credentials['writeonly'] )
'/1', {'text': 'foobar'}, format='json', response = object_permissions_view(request, pk='1')
HTTP_AUTHORIZATION=self.credentials['writeonly'] assert response.status_code == status.HTTP_200_OK
) assert response.data.get('text') == 'foobar'
response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data.get('text'), 'foobar')
def test_cannot_update_permissions(self): def test_cannot_update_permissions(self):
request = factory.patch( request = factory.patch( '/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.credentials['deleteonly'] )
'/1', {'text': 'foobar'}, format='json', response = object_permissions_view(request, pk='1')
HTTP_AUTHORIZATION=self.credentials['deleteonly'] assert response.status_code == status.HTTP_404_NOT_FOUND
)
response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
def test_cannot_update_permissions_non_existing(self): def test_cannot_update_permissions_non_existing(self):
request = factory.patch( request = factory.patch( '/999', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.credentials['deleteonly'] )
'/999', {'text': 'foobar'}, format='json', response = object_permissions_view(request, pk='999')
HTTP_AUTHORIZATION=self.credentials['deleteonly'] assert response.status_code == status.HTTP_404_NOT_FOUND
)
response = object_permissions_view(request, pk='999')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
# Read # Read
def test_can_read_permissions(self): def test_can_read_permissions(self):
request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly']) request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
response = object_permissions_view(request, pk='1') response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) assert response.status_code == status.HTTP_200_OK
def test_cannot_read_permissions(self): def test_cannot_read_permissions(self):
request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['writeonly']) request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['writeonly'])
response = object_permissions_view(request, pk='1') response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) assert response.status_code == status.HTTP_404_NOT_FOUND
def test_can_read_get_queryset_permissions(self): def test_can_read_get_queryset_permissions(self):
""" """
same as ``test_can_read_permissions`` but with a view same as ``test_can_read_permissions`` but with a view
that rely on ``.get_queryset()`` instead of ``.queryset``. that rely on ``.get_queryset()`` instead of ``.queryset``.
""" """
request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly']) request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
response = get_queryset_object_permissions_view(request, pk='1') response = get_queryset_object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) assert response.status_code == status.HTTP_200_OK
# Read list # Read list
def test_django_object_permissions_filter_deprecated(self): def test_django_object_permissions_filter_deprecated(self):
with warnings.catch_warnings(record=True) as w: with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always") warnings.simplefilter("always")
DjangoObjectPermissionsFilter() DjangoObjectPermissionsFilter()
message = ("`DjangoObjectPermissionsFilter` has been deprecated and moved " message = ("`DjangoObjectPermissionsFilter` has been deprecated and moved ""to the 3rd-party django-rest-framework-guardian package.")
"to the 3rd-party django-rest-framework-guardian package.") assertlen(w)==1
self.assertEqual(len(w), 1) assertw[-1].categoryisRemovedInDRF310Warning
self.assertIs(w[-1].category, RemovedInDRF310Warning) assertstr(w[-1].message)==message
self.assertEqual(str(w[-1].message), message)
def test_can_read_list_permissions(self): def test_can_read_list_permissions(self):
request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['readonly']) request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['readonly'])
object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,) object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,)
# TODO: remove in version 3.10 with warnings.catch_warnings(record=True):
with warnings.catch_warnings(record=True): warnings.simplefilter("always")
warnings.simplefilter("always") response = object_permissions_list_view(request)
response = object_permissions_list_view(request) assert response.status_code== status.HTTP_200_OK
self.assertEqual(response.status_code, status.HTTP_200_OK) assertresponse.data[0].get('id')==1
self.assertEqual(response.data[0].get('id'), 1)
def test_cannot_read_list_permissions(self): def test_cannot_read_list_permissions(self):
request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['writeonly']) request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['writeonly'])
object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,) object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,)
# TODO: remove in version 3.10 with warnings.catch_warnings(record=True):
with warnings.catch_warnings(record=True): warnings.simplefilter("always")
warnings.simplefilter("always") response = object_permissions_list_view(request)
response = object_permissions_list_view(request) assert response.status_code== status.HTTP_200_OK
self.assertEqual(response.status_code, status.HTTP_200_OK) assertresponse.data==[]
self.assertListEqual(response.data, [])
def test_cannot_method_not_allowed(self): def test_cannot_method_not_allowed(self):
request = factory.generic('METHOD_NOT_ALLOWED', '/', HTTP_AUTHORIZATION=self.credentials['readonly']) request = factory.generic('METHOD_NOT_ALLOWED', '/', HTTP_AUTHORIZATION=self.credentials['readonly'])
response = object_permissions_list_view(request) response = object_permissions_list_view(request)
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
class BasicPerm(permissions.BasePermission): class BasicPerm(permissions.BasePermission):
@ -507,203 +434,176 @@ denied_view_with_detail = DeniedViewWithDetail.as_view()
denied_object_view = DeniedObjectView.as_view() denied_object_view = DeniedObjectView.as_view()
denied_object_view_with_detail = DeniedObjectViewWithDetail.as_view() denied_object_view_with_detail = DeniedObjectViewWithDetail.as_view()
def setUp(self):
BasicModel(text='foo').save()
class CustomPermissionsTests(TestCase): User.objects.create_user('username', 'username@example.com', 'password')
def setUp(self): credentials = basic_auth_header('username', 'password')
BasicModel(text='foo').save() self.request = factory.get('/1', format='json', HTTP_AUTHORIZATION=credentials)
User.objects.create_user('username', 'username@example.com', 'password') self.custom_message = 'Custom: You cannot access this resource'
credentials = basic_auth_header('username', 'password')
self.request = factory.get('/1', format='json', HTTP_AUTHORIZATION=credentials)
self.custom_message = 'Custom: You cannot access this resource'
def test_permission_denied(self): def test_permission_denied(self):
response = denied_view(self.request, pk=1) response = denied_view(self.request, pk=1)
detail = response.data.get('detail') detail = response.data.get('detail')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) assert response.status_code == status.HTTP_403_FORBIDDEN
self.assertNotEqual(detail, self.custom_message) assert detail != self.custom_message
def test_permission_denied_with_custom_detail(self): def test_permission_denied_with_custom_detail(self):
response = denied_view_with_detail(self.request, pk=1) response = denied_view_with_detail(self.request, pk=1)
detail = response.data.get('detail') detail = response.data.get('detail')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) assert response.status_code == status.HTTP_403_FORBIDDEN
self.assertEqual(detail, self.custom_message) assert detail == self.custom_message
def test_permission_denied_for_object(self): def test_permission_denied_for_object(self):
response = denied_object_view(self.request, pk=1) response = denied_object_view(self.request, pk=1)
detail = response.data.get('detail') detail = response.data.get('detail')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) assert response.status_code == status.HTTP_403_FORBIDDEN
self.assertNotEqual(detail, self.custom_message) assert detail != self.custom_message
def test_permission_denied_for_object_with_custom_detail(self): def test_permission_denied_for_object_with_custom_detail(self):
response = denied_object_view_with_detail(self.request, pk=1) response = denied_object_view_with_detail(self.request, pk=1)
detail = response.data.get('detail') detail = response.data.get('detail')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) assert response.status_code == status.HTTP_403_FORBIDDEN
self.assertEqual(detail, self.custom_message) assert detail == self.custom_message
class PermissionsCompositionTests(TestCase):
def setUp(self): def setUp(self):
self.username = 'john' self.username = 'john'
self.email = 'lennon@thebeatles.com' self.email = 'lennon@thebeatles.com'
self.password = 'password' self.password = 'password'
self.user = User.objects.create_user( self.user = User.objects.create_user( self.username, self.email, self.password )
self.username, self.client.login(username=self.username, password=self.password)
self.email,
self.password
)
self.client.login(username=self.username, password=self.password)
def test_and_false(self): def test_and_false(self):
request = factory.get('/1', format='json') request = factory.get('/1', format='json')
request.user = AnonymousUser() request.user = AnonymousUser()
composed_perm = permissions.IsAuthenticated & permissions.AllowAny composed_perm = permissions.IsAuthenticated & permissions.AllowAny
assert composed_perm().has_permission(request, None) is False assert composed_perm().has_permission(request, None) is False
def test_and_true(self): def test_and_true(self):
request = factory.get('/1', format='json') request = factory.get('/1', format='json')
request.user = self.user request.user = self.user
composed_perm = permissions.IsAuthenticated & permissions.AllowAny composed_perm = permissions.IsAuthenticated & permissions.AllowAny
assert composed_perm().has_permission(request, None) is True assert composed_perm().has_permission(request, None) is True
def test_or_false(self): def test_or_false(self):
request = factory.get('/1', format='json') request = factory.get('/1', format='json')
request.user = AnonymousUser() request.user = AnonymousUser()
composed_perm = permissions.IsAuthenticated | permissions.AllowAny composed_perm = permissions.IsAuthenticated | permissions.AllowAny
assert composed_perm().has_permission(request, None) is True assert composed_perm().has_permission(request, None) is True
def test_or_true(self): def test_or_true(self):
request = factory.get('/1', format='json') request = factory.get('/1', format='json')
request.user = self.user request.user = self.user
composed_perm = permissions.IsAuthenticated | permissions.AllowAny composed_perm = permissions.IsAuthenticated | permissions.AllowAny
assert composed_perm().has_permission(request, None) is True assert composed_perm().has_permission(request, None) is True
def test_not_false(self): def test_not_false(self):
request = factory.get('/1', format='json') request = factory.get('/1', format='json')
request.user = AnonymousUser() request.user = AnonymousUser()
composed_perm = ~permissions.IsAuthenticated composed_perm = ~permissions.IsAuthenticated
assert composed_perm().has_permission(request, None) is True assert composed_perm().has_permission(request, None) is True
def test_not_true(self): def test_not_true(self):
request = factory.get('/1', format='json') request = factory.get('/1', format='json')
request.user = self.user request.user = self.user
composed_perm = ~permissions.AllowAny composed_perm = ~permissions.AllowAny
assert composed_perm().has_permission(request, None) is False assert composed_perm().has_permission(request, None) is False
def test_several_levels_without_negation(self): def test_several_levels_without_negation(self):
request = factory.get('/1', format='json') request = factory.get('/1', format='json')
request.user = self.user request.user = self.user
composed_perm = ( composed_perm = ( permissions.IsAuthenticated & permissions.IsAuthenticated & permissions.IsAuthenticated & permissions.IsAuthenticated )
permissions.IsAuthenticated & assert composed_perm().has_permission(request, None) is True
permissions.IsAuthenticated &
permissions.IsAuthenticated &
permissions.IsAuthenticated
)
assert composed_perm().has_permission(request, None) is True
def test_several_levels_and_precedence_with_negation(self): def test_several_levels_and_precedence_with_negation(self):
request = factory.get('/1', format='json') request = factory.get('/1', format='json')
request.user = self.user request.user = self.user
composed_perm = ( composed_perm = ( permissions.IsAuthenticated & ~ permissions.IsAdminUser & permissions.IsAuthenticated & ~(permissions.IsAdminUser & permissions.IsAdminUser) )
permissions.IsAuthenticated & assert composed_perm().has_permission(request, None) is True
~ permissions.IsAdminUser &
permissions.IsAuthenticated &
~(permissions.IsAdminUser & permissions.IsAdminUser)
)
assert composed_perm().has_permission(request, None) is True
def test_several_levels_and_precedence(self): def test_several_levels_and_precedence(self):
request = factory.get('/1', format='json') request = factory.get('/1', format='json')
request.user = self.user request.user = self.user
composed_perm = ( composed_perm = ( permissions.IsAuthenticated & permissions.IsAuthenticated | permissions.IsAuthenticated & permissions.IsAuthenticated )
permissions.IsAuthenticated & assert composed_perm().has_permission(request, None) is True
permissions.IsAuthenticated |
permissions.IsAuthenticated &
permissions.IsAuthenticated
)
assert composed_perm().has_permission(request, None) is True
@pytest.mark.skipif(not PY36, reason="assert_called_once() not available") @pytest.mark.skipif(not PY36, reason="assert_called_once() not available")
def test_or_lazyness(self): deftest_or_lazyness(self):
request = factory.get('/1', format='json') request = factory.get('/1', format='json')
request.user = AnonymousUser() request.user = AnonymousUser()
with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
composed_perm = (permissions.AllowAny | permissions.IsAuthenticated)
hasperm = composed_perm().has_permission(request, None)
assert hasperm is True
mock_allow.assert_called_once()
mock_deny.assert_not_called()
with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow: with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny: with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
composed_perm = (permissions.AllowAny | permissions.IsAuthenticated) composed_perm = (permissions.IsAuthenticated | permissions.AllowAny)
hasperm = composed_perm().has_permission(request, None) hasperm = composed_perm().has_permission(request, None)
self.assertIs(hasperm, True) assert hasperm is True
mock_allow.assert_called_once() mock_deny.assert_called_once()
mock_deny.assert_not_called() mock_allow.assert_called_once()
with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
composed_perm = (permissions.IsAuthenticated | permissions.AllowAny)
hasperm = composed_perm().has_permission(request, None)
self.assertIs(hasperm, True)
mock_deny.assert_called_once()
mock_allow.assert_called_once()
@pytest.mark.skipif(not PY36, reason="assert_called_once() not available") @pytest.mark.skipif(not PY36, reason="assert_called_once() not available")
def test_object_or_lazyness(self): deftest_object_or_lazyness(self):
request = factory.get('/1', format='json') request = factory.get('/1', format='json')
request.user = AnonymousUser() request.user = AnonymousUser()
with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
composed_perm = (permissions.AllowAny | permissions.IsAuthenticated)
hasperm = composed_perm().has_object_permission(request, None, None)
assert hasperm is True
mock_allow.assert_called_once()
mock_deny.assert_not_called()
with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow: with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny: with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
composed_perm = (permissions.AllowAny | permissions.IsAuthenticated) composed_perm = (permissions.IsAuthenticated | permissions.AllowAny)
hasperm = composed_perm().has_object_permission(request, None, None) hasperm = composed_perm().has_object_permission(request, None, None)
self.assertIs(hasperm, True) assert hasperm is True
mock_allow.assert_called_once() mock_deny.assert_called_once()
mock_deny.assert_not_called() mock_allow.assert_called_once()
with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
composed_perm = (permissions.IsAuthenticated | permissions.AllowAny)
hasperm = composed_perm().has_object_permission(request, None, None)
self.assertIs(hasperm, True)
mock_deny.assert_called_once()
mock_allow.assert_called_once()
@pytest.mark.skipif(not PY36, reason="assert_called_once() not available") @pytest.mark.skipif(not PY36, reason="assert_called_once() not available")
def test_and_lazyness(self): deftest_and_lazyness(self):
request = factory.get('/1', format='json') request = factory.get('/1', format='json')
request.user = AnonymousUser() request.user = AnonymousUser()
with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
composed_perm = (permissions.AllowAny & permissions.IsAuthenticated)
hasperm = composed_perm().has_permission(request, None)
assert hasperm is False
mock_allow.assert_called_once()
mock_deny.assert_called_once()
with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow: with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny: with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
composed_perm = (permissions.AllowAny & permissions.IsAuthenticated) composed_perm = (permissions.IsAuthenticated & permissions.AllowAny)
hasperm = composed_perm().has_permission(request, None) hasperm = composed_perm().has_permission(request, None)
self.assertIs(hasperm, False) assert hasperm is False
mock_allow.assert_called_once() mock_allow.assert_not_called()
mock_deny.assert_called_once() mock_deny.assert_called_once()
with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
composed_perm = (permissions.IsAuthenticated & permissions.AllowAny)
hasperm = composed_perm().has_permission(request, None)
self.assertIs(hasperm, False)
mock_allow.assert_not_called()
mock_deny.assert_called_once()
@pytest.mark.skipif(not PY36, reason="assert_called_once() not available") @pytest.mark.skipif(not PY36, reason="assert_called_once() not available")
def test_object_and_lazyness(self): deftest_object_and_lazyness(self):
request = factory.get('/1', format='json') request = factory.get('/1', format='json')
request.user = AnonymousUser() request.user = AnonymousUser()
with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
composed_perm = (permissions.AllowAny & permissions.IsAuthenticated)
hasperm = composed_perm().has_object_permission(request, None, None)
assert hasperm is False
mock_allow.assert_called_once()
mock_deny.assert_called_once()
with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow: with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny: with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
composed_perm = (permissions.AllowAny & permissions.IsAuthenticated) composed_perm = (permissions.IsAuthenticated & permissions.AllowAny)
hasperm = composed_perm().has_object_permission(request, None, None) hasperm = composed_perm().has_object_permission(request, None, None)
self.assertIs(hasperm, False) assert hasperm is False
mock_allow.assert_called_once() mock_allow.assert_not_called()
mock_deny.assert_called_once() mock_deny.assert_called_once()
with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
composed_perm = (permissions.IsAuthenticated & permissions.AllowAny)
hasperm = composed_perm().has_object_permission(request, None, None)
self.assertIs(hasperm, False)
mock_allow.assert_not_called()
mock_deny.assert_called_once()

View File

@ -18,41 +18,30 @@ class UserUpdate(generics.UpdateAPIView):
serializer_class = UserSerializer serializer_class = UserSerializer
class TestPrefetchRelatedUpdates(TestCase): def setUp(self):
def setUp(self): self.user = User.objects.create(username='tom', email='tom@example.com')
self.user = User.objects.create(username='tom', email='tom@example.com') self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')]
self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')] self.user.groups.set(self.groups)
self.user.groups.set(self.groups)
def test_prefetch_related_updates(self): def test_prefetch_related_updates(self):
view = UserUpdate.as_view() view = UserUpdate.as_view()
pk = self.user.pk pk = self.user.pk
groups_pk = self.groups[0].pk groups_pk = self.groups[0].pk
request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json') request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json')
response = view(request, pk=pk) response = view(request, pk=pk)
assert User.objects.get(pk=pk).groups.count() == 1 assert User.objects.get(pk=pk).groups.count() == 1
expected = { expected = { 'id': pk, 'username': 'new', 'groups': [1], 'email': 'tom@example.com' }
'id': pk, assert response.data == expected
'username': 'new',
'groups': [1],
'email': 'tom@example.com'
}
assert response.data == expected
def test_prefetch_related_excluding_instance_from_original_queryset(self): def test_prefetch_related_excluding_instance_from_original_queryset(self):
""" """
Regression test for https://github.com/encode/django-rest-framework/issues/4661 Regression test for https://github.com/encode/django-rest-framework/issues/4661
""" """
view = UserUpdate.as_view() view = UserUpdate.as_view()
pk = self.user.pk pk = self.user.pk
groups_pk = self.groups[0].pk groups_pk = self.groups[0].pk
request = factory.put('/', {'username': 'exclude', 'groups': [groups_pk]}, format='json') request = factory.put('/', {'username': 'exclude', 'groups': [groups_pk]}, format='json')
response = view(request, pk=pk) response = view(request, pk=pk)
assert User.objects.get(pk=pk).groups.count() == 1 assert User.objects.get(pk=pk).groups.count() == 1
expected = { expected = { 'id': pk, 'username': 'exclude', 'groups': [1], 'email': 'tom@example.com' }
'id': pk, assert response.data == expected
'username': 'exclude',
'groups': [1],
'email': 'tom@example.com'
}
assert response.data == expected

View File

@ -70,380 +70,268 @@ class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer):
# TODO: Add test that .data cannot be accessed prior to .is_valid # TODO: Add test that .data cannot be accessed prior to .is_valid
@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink') @override_settings(ROOT_URLCONF='tests.test_relations_hyperlink')
class HyperlinkedManyToManyTests(TestCase): def setUp(self):
def setUp(self): for idx in range(1, 4):
for idx in range(1, 4): target = ManyToManyTarget(name='target-%d' % idx)
target = ManyToManyTarget(name='target-%d' % idx) target.save()
target.save() source = ManyToManySource(name='source-%d' % idx)
source = ManyToManySource(name='source-%d' % idx) source.save()
source.save() for target in ManyToManyTarget.objects.all():
for target in ManyToManyTarget.objects.all(): source.targets.add(target)
source.targets.add(target)
def test_relative_hyperlinks(self): def test_relative_hyperlinks(self):
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': None}) serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': None})
expected = [ expected = [ {'url': '/manytomanysource/1/', 'name': 'source-1', 'targets': ['/manytomanytarget/1/']}, {'url': '/manytomanysource/2/', 'name': 'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']}, {'url': '/manytomanysource/3/', 'name': 'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']} ]
{'url': '/manytomanysource/1/', 'name': 'source-1', 'targets': ['/manytomanytarget/1/']}, with self.assertNumQueries(4):
{'url': '/manytomanysource/2/', 'name': 'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']}, assert serializer.data == expected
{'url': '/manytomanysource/3/', 'name': 'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']}
]
with self.assertNumQueries(4):
assert serializer.data == expected
def test_many_to_many_retrieve(self): def test_many_to_many_retrieve(self):
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']}, {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} ]
{'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']}, with self.assertNumQueries(4):
{'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, assert serializer.data == expected
{'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
]
with self.assertNumQueries(4):
assert serializer.data == expected
def test_many_to_many_retrieve_prefetch_related(self): def test_many_to_many_retrieve_prefetch_related(self):
queryset = ManyToManySource.objects.all().prefetch_related('targets') queryset = ManyToManySource.objects.all().prefetch_related('targets')
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
with self.assertNumQueries(2): with self.assertNumQueries(2):
serializer.data serializer.data
def test_reverse_many_to_many_retrieve(self): def test_reverse_many_to_many_retrieve(self):
queryset = ManyToManyTarget.objects.all() queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']} ]
{'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, with self.assertNumQueries(4):
{'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, assert serializer.data == expected
{'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}
]
with self.assertNumQueries(4):
assert serializer.data == expected
def test_many_to_many_update(self): def test_many_to_many_update(self):
data = {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} data = {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
instance = ManyToManySource.objects.get(pk=1) instance = ManyToManySource.objects.get(pk=1)
serializer = ManyToManySourceSerializer(instance, data=data, context={'request': request}) serializer = ManyToManySourceSerializer(instance, data=data, context={'request': request})
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == data assert serializer.data == data
queryset = ManyToManySource.objects.all()
# Ensure source 1 is updated, and everything else is as expected serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
queryset = ManyToManySource.objects.all() expected = [ {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}, {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} ]
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) assert serializer.data == expected
expected = [
{'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']},
{'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
{'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
]
assert serializer.data == expected
def test_reverse_many_to_many_update(self): def test_reverse_many_to_many_update(self):
data = {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']} data = {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']}
instance = ManyToManyTarget.objects.get(pk=1) instance = ManyToManyTarget.objects.get(pk=1)
serializer = ManyToManyTargetSerializer(instance, data=data, context={'request': request}) serializer = ManyToManyTargetSerializer(instance, data=data, context={'request': request})
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == data assert serializer.data == data
# Ensure target 1 is updated, and everything else is as expected queryset = ManyToManyTarget.objects.all()
queryset = ManyToManyTarget.objects.all() serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) expected = [ {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']}, {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']} ]
expected = [ assert serializer.data == expected
{'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']},
{'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}
]
assert serializer.data == expected
def test_many_to_many_create(self): def test_many_to_many_create(self):
data = {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']} data = {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']}
serializer = ManyToManySourceSerializer(data=data, context={'request': request}) serializer = ManyToManySourceSerializer(data=data, context={'request': request})
assert serializer.is_valid() assert serializer.is_valid()
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'source-4' assert obj.name == 'source-4'
queryset = ManyToManySource.objects.all()
# Ensure source 4 is added, and everything else is as expected serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
queryset = ManyToManySource.objects.all() expected = [ {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']}, {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}, {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']} ]
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) assert serializer.data == expected
expected = [
{'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']},
{'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
{'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']},
{'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']}
]
assert serializer.data == expected
def test_reverse_many_to_many_create(self): def test_reverse_many_to_many_create(self):
data = {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']} data = {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']}
serializer = ManyToManyTargetSerializer(data=data, context={'request': request}) serializer = ManyToManyTargetSerializer(data=data, context={'request': request})
assert serializer.is_valid() assert serializer.is_valid()
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'target-4' assert obj.name == 'target-4'
queryset = ManyToManyTarget.objects.all()
# Ensure target 4 is added, and everything else is as expected serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
queryset = ManyToManyTarget.objects.all() expected = [ {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']} ]
serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) assert serializer.data == expected
expected = [
{'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']}
]
assert serializer.data == expected
@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink') @override_settings(ROOT_URLCONF='tests.test_relations_hyperlink')
class HyperlinkedForeignKeyTests(TestCase): def setUp(self):
def setUp(self): target = ForeignKeyTarget(name='target-1')
target = ForeignKeyTarget(name='target-1') target.save()
target.save() new_target = ForeignKeyTarget(name='target-2')
new_target = ForeignKeyTarget(name='target-2') new_target.save()
new_target.save() for idx in range(1, 4):
for idx in range(1, 4): source = ForeignKeySource(name='source-%d' % idx, target=target)
source = ForeignKeySource(name='source-%d' % idx, target=target)
source.save()
def test_foreign_key_retrieve(self):
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [
{'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}
]
with self.assertNumQueries(1):
assert serializer.data == expected
def test_reverse_foreign_key_retrieve(self):
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [
{'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']},
{'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
]
with self.assertNumQueries(3):
assert serializer.data == expected
def test_foreign_key_update(self):
data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
assert serializer.is_valid()
serializer.save()
assert serializer.data == data
# Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [
{'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'},
{'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}
]
assert serializer.data == expected
def test_foreign_key_update_incorrect_type(self):
data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 2}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
assert not serializer.is_valid()
assert serializer.errors == {'target': ['Incorrect type. Expected URL string, received int.']}
def test_reverse_foreign_key_update(self):
data = {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}
instance = ForeignKeyTarget.objects.get(pk=2)
serializer = ForeignKeyTargetSerializer(instance, data=data, context={'request': request})
assert serializer.is_valid()
# We shouldn't have saved anything to the db yet since save
# hasn't been called.
queryset = ForeignKeyTarget.objects.all()
new_serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [
{'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']},
{'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
]
assert new_serializer.data == expected
serializer.save()
assert serializer.data == data
# Ensure target 2 is update, and everything else is as expected
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [
{'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']},
{'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']},
]
assert serializer.data == expected
def test_foreign_key_create(self):
data = {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'}
serializer = ForeignKeySourceSerializer(data=data, context={'request': request})
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
assert obj.name == 'source-4'
# Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [
{'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'},
]
assert serializer.data == expected
def test_reverse_foreign_key_create(self):
data = {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}
serializer = ForeignKeyTargetSerializer(data=data, context={'request': request})
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
assert obj.name == 'target-3'
# Ensure target 4 is added, and everything else is as expected
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [
{'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']},
{'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
{'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']},
]
assert serializer.data == expected
def test_foreign_key_update_with_invalid_null(self):
data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': None}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
assert not serializer.is_valid()
assert serializer.errors == {'target': ['This field may not be null.']}
@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink')
class HyperlinkedNullableForeignKeyTests(TestCase):
def setUp(self):
target = ForeignKeyTarget(name='target-1')
target.save()
for idx in range(1, 4):
if idx == 3:
target = None
source = NullableForeignKeySource(name='source-%d' % idx, target=target)
source.save()
def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [
{'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
]
assert serializer.data == expected
def test_foreign_key_create_with_valid_null(self):
data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request})
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
assert obj.name == 'source-4'
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [
{'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
{'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
]
assert serializer.data == expected
def test_foreign_key_create_with_valid_emptystring(self):
"""
The emptystring should be interpreted as null in the context
of relationships.
"""
data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': ''}
expected_data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request})
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == expected_data
assert obj.name == 'source-4'
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [
{'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
{'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
]
assert serializer.data == expected
def test_foreign_key_update_with_valid_null(self):
data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request})
assert serializer.is_valid()
serializer.save()
assert serializer.data == data
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [
{'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None},
{'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
]
assert serializer.data == expected
def test_foreign_key_update_with_valid_emptystring(self):
"""
The emptystring should be interpreted as null in the context
of relationships.
"""
data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': ''}
expected_data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request})
assert serializer.is_valid()
serializer.save()
assert serializer.data == expected_data
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [
{'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None},
{'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
]
assert serializer.data == expected
@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink')
class HyperlinkedNullableOneToOneTests(TestCase):
def setUp(self):
target = OneToOneTarget(name='target-1')
target.save()
new_target = OneToOneTarget(name='target-2')
new_target.save()
source = NullableOneToOneSource(name='source-1', target=target)
source.save() source.save()
def test_reverse_foreign_key_retrieve_with_null(self): def test_foreign_key_retrieve(self):
queryset = OneToOneTarget.objects.all() queryset = ForeignKeySource.objects.all()
serializer = NullableOneToOneTargetSerializer(queryset, many=True, context={'request': request}) serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'} ]
{'url': 'http://testserver/onetoonetarget/1/', 'name': 'target-1', 'nullable_source': 'http://testserver/nullableonetoonesource/1/'}, with self.assertNumQueries(1):
{'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None},
]
assert serializer.data == expected assert serializer.data == expected
def test_reverse_foreign_key_retrieve(self):
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']}, {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, ]
with self.assertNumQueries(3):
assert serializer.data == expected
def test_foreign_key_update(self):
data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
assert serializer.is_valid()
serializer.save()
assert serializer.data == data
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'}, {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'} ]
assert serializer.data == expected
def test_foreign_key_update_incorrect_type(self):
data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 2}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
assert not serializer.is_valid()
assert serializer.errors == {'target': ['Incorrect type. Expected URL string, received int.']}
def test_reverse_foreign_key_update(self):
data = {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}
instance = ForeignKeyTarget.objects.get(pk=2)
serializer = ForeignKeyTargetSerializer(instance, data=data, context={'request': request})
assert serializer.is_valid()
queryset = ForeignKeyTarget.objects.all()
new_serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']}, {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, ]
assert new_serializer.data == expected
serializer.save()
assert serializer.data == data
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']}, {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}, ]
assert serializer.data == expected
def test_foreign_key_create(self):
data = {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'}
serializer = ForeignKeySourceSerializer(data=data, context={'request': request})
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
assert obj.name == 'source-4'
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'}, ]
assert serializer.data == expected
def test_reverse_foreign_key_create(self):
data = {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}
serializer = ForeignKeyTargetSerializer(data=data, context={'request': request})
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
assert obj.name == 'target-3'
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']}, {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}, ]
assert serializer.data == expected
def test_foreign_key_update_with_invalid_null(self):
data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': None}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
assert not serializer.is_valid()
assert serializer.errors == {'target': ['This field may not be null.']}
@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink')
def setUp(self):
target = ForeignKeyTarget(name='target-1')
target.save()
for idx in range(1, 4):
if idx == 3:
target = None
source = NullableForeignKeySource(name='source-%d' % idx, target=target)
source.save()
def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, ]
assert serializer.data == expected
def test_foreign_key_create_with_valid_null(self):
data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request})
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
assert obj.name == 'source-4'
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} ]
assert serializer.data == expected
def test_foreign_key_create_with_valid_emptystring(self):
"""
The emptystring should be interpreted as null in the context
of relationships.
"""
data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': ''}
expected_data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request})
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == expected_data
assert obj.name == 'source-4'
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} ]
assert serializer.data == expected
def test_foreign_key_update_with_valid_null(self):
data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request})
assert serializer.is_valid()
serializer.save()
assert serializer.data == data
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}, {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, ]
assert serializer.data == expected
def test_foreign_key_update_with_valid_emptystring(self):
"""
The emptystring should be interpreted as null in the context
of relationships.
"""
data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': ''}
expected_data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request})
assert serializer.is_valid()
serializer.save()
assert serializer.data == expected_data
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}, {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, ]
assert serializer.data == expected
@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink')
def setUp(self):
target = OneToOneTarget(name='target-1')
target.save()
new_target = OneToOneTarget(name='target-2')
new_target.save()
source = NullableOneToOneSource(name='source-1', target=target)
source.save()
def test_reverse_foreign_key_retrieve_with_null(self):
queryset = OneToOneTarget.objects.all()
serializer = NullableOneToOneTargetSerializer(queryset, many=True, context={'request': request})
expected = [ {'url': 'http://testserver/onetoonetarget/1/', 'name': 'target-1', 'nullable_source': 'http://testserver/nullableonetoonesource/1/'}, {'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None}, ]
assert serializer.data == expected

View File

@ -77,496 +77,373 @@ class OneToOnePKSourceSerializer(serializers.ModelSerializer):
# TODO: Add test that .data cannot be accessed prior to .is_valid # TODO: Add test that .data cannot be accessed prior to .is_valid
class PKManyToManyTests(TestCase): def setUp(self):
def setUp(self): for idx in range(1, 4):
for idx in range(1, 4): target = ManyToManyTarget(name='target-%d' % idx)
target = ManyToManyTarget(name='target-%d' % idx) target.save()
target.save() source = ManyToManySource(name='source-%d' % idx)
source = ManyToManySource(name='source-%d' % idx) source.save()
source.save() for target in ManyToManyTarget.objects.all():
for target in ManyToManyTarget.objects.all(): source.targets.add(target)
source.targets.add(target)
def test_many_to_many_retrieve(self): def test_many_to_many_retrieve(self):
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True) serializer = ManyToManySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'targets': [1]}, {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]} ]
{'id': 1, 'name': 'source-1', 'targets': [1]}, with self.assertNumQueries(4):
{'id': 2, 'name': 'source-2', 'targets': [1, 2]}, assert serializer.data == expected
{'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
]
with self.assertNumQueries(4):
assert serializer.data == expected
def test_many_to_many_retrieve_prefetch_related(self): def test_many_to_many_retrieve_prefetch_related(self):
queryset = ManyToManySource.objects.all().prefetch_related('targets') queryset = ManyToManySource.objects.all().prefetch_related('targets')
serializer = ManyToManySourceSerializer(queryset, many=True) serializer = ManyToManySourceSerializer(queryset, many=True)
with self.assertNumQueries(2): with self.assertNumQueries(2):
serializer.data serializer.data
def test_reverse_many_to_many_retrieve(self): def test_reverse_many_to_many_retrieve(self):
queryset = ManyToManyTarget.objects.all() queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True) serializer = ManyToManyTargetSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, {'id': 3, 'name': 'target-3', 'sources': [3]} ]
{'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, with self.assertNumQueries(4):
{'id': 2, 'name': 'target-2', 'sources': [2, 3]}, assert serializer.data == expected
{'id': 3, 'name': 'target-3', 'sources': [3]}
]
with self.assertNumQueries(4):
assert serializer.data == expected
def test_many_to_many_update(self): def test_many_to_many_update(self):
data = {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]} data = {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]}
instance = ManyToManySource.objects.get(pk=1) instance = ManyToManySource.objects.get(pk=1)
serializer = ManyToManySourceSerializer(instance, data=data) serializer = ManyToManySourceSerializer(instance, data=data)
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == data assert serializer.data == data
queryset = ManyToManySource.objects.all()
# Ensure source 1 is updated, and everything else is as expected serializer = ManyToManySourceSerializer(queryset, many=True)
queryset = ManyToManySource.objects.all() expected = [ {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]}, {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]} ]
serializer = ManyToManySourceSerializer(queryset, many=True) assert serializer.data == expected
expected = [
{'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]},
{'id': 2, 'name': 'source-2', 'targets': [1, 2]},
{'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
]
assert serializer.data == expected
def test_reverse_many_to_many_update(self): def test_reverse_many_to_many_update(self):
data = {'id': 1, 'name': 'target-1', 'sources': [1]} data = {'id': 1, 'name': 'target-1', 'sources': [1]}
instance = ManyToManyTarget.objects.get(pk=1) instance = ManyToManyTarget.objects.get(pk=1)
serializer = ManyToManyTargetSerializer(instance, data=data) serializer = ManyToManyTargetSerializer(instance, data=data)
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == data assert serializer.data == data
queryset = ManyToManyTarget.objects.all()
# Ensure target 1 is updated, and everything else is as expected serializer = ManyToManyTargetSerializer(queryset, many=True)
queryset = ManyToManyTarget.objects.all() expected = [ {'id': 1, 'name': 'target-1', 'sources': [1]}, {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, {'id': 3, 'name': 'target-3', 'sources': [3]} ]
serializer = ManyToManyTargetSerializer(queryset, many=True) assert serializer.data == expected
expected = [
{'id': 1, 'name': 'target-1', 'sources': [1]},
{'id': 2, 'name': 'target-2', 'sources': [2, 3]},
{'id': 3, 'name': 'target-3', 'sources': [3]}
]
assert serializer.data == expected
def test_many_to_many_create(self): def test_many_to_many_create(self):
data = {'id': 4, 'name': 'source-4', 'targets': [1, 3]} data = {'id': 4, 'name': 'source-4', 'targets': [1, 3]}
serializer = ManyToManySourceSerializer(data=data) serializer = ManyToManySourceSerializer(data=data)
assert serializer.is_valid() assert serializer.is_valid()
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'source-4' assert obj.name == 'source-4'
queryset = ManyToManySource.objects.all()
# Ensure source 4 is added, and everything else is as expected serializer = ManyToManySourceSerializer(queryset, many=True)
queryset = ManyToManySource.objects.all() expected = [ {'id': 1, 'name': 'source-1', 'targets': [1]}, {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}, {'id': 4, 'name': 'source-4', 'targets': [1, 3]}, ]
serializer = ManyToManySourceSerializer(queryset, many=True) assert serializer.data == expected
expected = [
{'id': 1, 'name': 'source-1', 'targets': [1]},
{'id': 2, 'name': 'source-2', 'targets': [1, 2]},
{'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]},
{'id': 4, 'name': 'source-4', 'targets': [1, 3]},
]
assert serializer.data == expected
def test_many_to_many_unsaved(self): def test_many_to_many_unsaved(self):
source = ManyToManySource(name='source-unsaved') source = ManyToManySource(name='source-unsaved')
serializer = ManyToManySourceSerializer(source)
serializer = ManyToManySourceSerializer(source) expected = {'id': None, 'name': 'source-unsaved', 'targets': []}
with self.assertNumQueries(0):
expected = {'id': None, 'name': 'source-unsaved', 'targets': []} assert serializer.data == expected
# no query if source hasn't been created yet
with self.assertNumQueries(0):
assert serializer.data == expected
def test_reverse_many_to_many_create(self): def test_reverse_many_to_many_create(self):
data = {'id': 4, 'name': 'target-4', 'sources': [1, 3]} data = {'id': 4, 'name': 'target-4', 'sources': [1, 3]}
serializer = ManyToManyTargetSerializer(data=data) serializer = ManyToManyTargetSerializer(data=data)
assert serializer.is_valid() assert serializer.is_valid()
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'target-4' assert obj.name == 'target-4'
queryset = ManyToManyTarget.objects.all()
# Ensure target 4 is added, and everything else is as expected serializer = ManyToManyTargetSerializer(queryset, many=True)
queryset = ManyToManyTarget.objects.all() expected = [ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, {'id': 3, 'name': 'target-3', 'sources': [3]}, {'id': 4, 'name': 'target-4', 'sources': [1, 3]} ]
serializer = ManyToManyTargetSerializer(queryset, many=True) assert serializer.data == expected
expected = [
{'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
{'id': 2, 'name': 'target-2', 'sources': [2, 3]},
{'id': 3, 'name': 'target-3', 'sources': [3]},
{'id': 4, 'name': 'target-4', 'sources': [1, 3]}
]
assert serializer.data == expected
class PKForeignKeyTests(TestCase): def setUp(self):
def setUp(self): target = ForeignKeyTarget(name='target-1')
target = ForeignKeyTarget(name='target-1') target.save()
target.save() new_target = ForeignKeyTarget(name='target-2')
new_target = ForeignKeyTarget(name='target-2') new_target.save()
new_target.save() for idx in range(1, 4):
for idx in range(1, 4): source = ForeignKeySource(name='source-%d' % idx, target=target)
source = ForeignKeySource(name='source-%d' % idx, target=target) source.save()
source.save()
def test_foreign_key_retrieve(self): def test_foreign_key_retrieve(self):
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True) serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': 1} ]
{'id': 1, 'name': 'source-1', 'target': 1}, with self.assertNumQueries(1):
{'id': 2, 'name': 'source-2', 'target': 1}, assert serializer.data == expected
{'id': 3, 'name': 'source-3', 'target': 1}
]
with self.assertNumQueries(1):
assert serializer.data == expected
def test_reverse_foreign_key_retrieve(self): def test_reverse_foreign_key_retrieve(self):
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True) serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, {'id': 2, 'name': 'target-2', 'sources': []}, ]
{'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, with self.assertNumQueries(3):
{'id': 2, 'name': 'target-2', 'sources': []}, assert serializer.data == expected
]
with self.assertNumQueries(3):
assert serializer.data == expected
def test_reverse_foreign_key_retrieve_prefetch_related(self): def test_reverse_foreign_key_retrieve_prefetch_related(self):
queryset = ForeignKeyTarget.objects.all().prefetch_related('sources') queryset = ForeignKeyTarget.objects.all().prefetch_related('sources')
serializer = ForeignKeyTargetSerializer(queryset, many=True) serializer = ForeignKeyTargetSerializer(queryset, many=True)
with self.assertNumQueries(2): with self.assertNumQueries(2):
serializer.data serializer.data
def test_foreign_key_update(self): def test_foreign_key_update(self):
data = {'id': 1, 'name': 'source-1', 'target': 2} data = {'id': 1, 'name': 'source-1', 'target': 2}
instance = ForeignKeySource.objects.get(pk=1) instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) serializer = ForeignKeySourceSerializer(instance, data=data)
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == data assert serializer.data == data
queryset = ForeignKeySource.objects.all()
# Ensure source 1 is updated, and everything else is as expected serializer = ForeignKeySourceSerializer(queryset, many=True)
queryset = ForeignKeySource.objects.all() expected = [ {'id': 1, 'name': 'source-1', 'target': 2}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': 1} ]
serializer = ForeignKeySourceSerializer(queryset, many=True) assert serializer.data == expected
expected = [
{'id': 1, 'name': 'source-1', 'target': 2},
{'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': 1}
]
assert serializer.data == expected
def test_foreign_key_update_incorrect_type(self): def test_foreign_key_update_incorrect_type(self):
data = {'id': 1, 'name': 'source-1', 'target': 'foo'} data = {'id': 1, 'name': 'source-1', 'target': 'foo'}
instance = ForeignKeySource.objects.get(pk=1) instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) serializer = ForeignKeySourceSerializer(instance, data=data)
assert not serializer.is_valid() assert not serializer.is_valid()
assert serializer.errors == {'target': ['Incorrect type. Expected pk value, received str.']} assert serializer.errors == {'target': ['Incorrect type. Expected pk value, received str.']}
def test_reverse_foreign_key_update(self): def test_reverse_foreign_key_update(self):
data = {'id': 2, 'name': 'target-2', 'sources': [1, 3]} data = {'id': 2, 'name': 'target-2', 'sources': [1, 3]}
instance = ForeignKeyTarget.objects.get(pk=2) instance = ForeignKeyTarget.objects.get(pk=2)
serializer = ForeignKeyTargetSerializer(instance, data=data) serializer = ForeignKeyTargetSerializer(instance, data=data)
assert serializer.is_valid() assert serializer.is_valid()
# We shouldn't have saved anything to the db yet since save queryset = ForeignKeyTarget.objects.all()
# hasn't been called. new_serializer = ForeignKeyTargetSerializer(queryset, many=True)
queryset = ForeignKeyTarget.objects.all() expected = [ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, {'id': 2, 'name': 'target-2', 'sources': []}, ]
new_serializer = ForeignKeyTargetSerializer(queryset, many=True) assert new_serializer.data == expected
expected = [ serializer.save()
{'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, assert serializer.data == data
{'id': 2, 'name': 'target-2', 'sources': []}, queryset = ForeignKeyTarget.objects.all()
] serializer = ForeignKeyTargetSerializer(queryset, many=True)
assert new_serializer.data == expected expected = [ {'id': 1, 'name': 'target-1', 'sources': [2]}, {'id': 2, 'name': 'target-2', 'sources': [1, 3]}, ]
assert serializer.data == expected
serializer.save()
assert serializer.data == data
# Ensure target 2 is update, and everything else is as expected
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'target-1', 'sources': [2]},
{'id': 2, 'name': 'target-2', 'sources': [1, 3]},
]
assert serializer.data == expected
def test_foreign_key_create(self): def test_foreign_key_create(self):
data = {'id': 4, 'name': 'source-4', 'target': 2} data = {'id': 4, 'name': 'source-4', 'target': 2}
serializer = ForeignKeySourceSerializer(data=data) serializer = ForeignKeySourceSerializer(data=data)
assert serializer.is_valid() assert serializer.is_valid()
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'source-4' assert obj.name == 'source-4'
queryset = ForeignKeySource.objects.all()
# Ensure source 4 is added, and everything else is as expected serializer = ForeignKeySourceSerializer(queryset, many=True)
queryset = ForeignKeySource.objects.all() expected = [ {'id': 1, 'name': 'source-1', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': 1}, {'id': 4, 'name': 'source-4', 'target': 2}, ]
serializer = ForeignKeySourceSerializer(queryset, many=True) assert serializer.data == expected
expected = [
{'id': 1, 'name': 'source-1', 'target': 1},
{'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': 1},
{'id': 4, 'name': 'source-4', 'target': 2},
]
assert serializer.data == expected
def test_reverse_foreign_key_create(self): def test_reverse_foreign_key_create(self):
data = {'id': 3, 'name': 'target-3', 'sources': [1, 3]} data = {'id': 3, 'name': 'target-3', 'sources': [1, 3]}
serializer = ForeignKeyTargetSerializer(data=data) serializer = ForeignKeyTargetSerializer(data=data)
assert serializer.is_valid() assert serializer.is_valid()
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'target-3' assert obj.name == 'target-3'
queryset = ForeignKeyTarget.objects.all()
# Ensure target 3 is added, and everything else is as expected serializer = ForeignKeyTargetSerializer(queryset, many=True)
queryset = ForeignKeyTarget.objects.all() expected = [ {'id': 1, 'name': 'target-1', 'sources': [2]}, {'id': 2, 'name': 'target-2', 'sources': []}, {'id': 3, 'name': 'target-3', 'sources': [1, 3]}, ]
serializer = ForeignKeyTargetSerializer(queryset, many=True) assert serializer.data == expected
expected = [
{'id': 1, 'name': 'target-1', 'sources': [2]},
{'id': 2, 'name': 'target-2', 'sources': []},
{'id': 3, 'name': 'target-3', 'sources': [1, 3]},
]
assert serializer.data == expected
def test_foreign_key_update_with_invalid_null(self): def test_foreign_key_update_with_invalid_null(self):
data = {'id': 1, 'name': 'source-1', 'target': None} data = {'id': 1, 'name': 'source-1', 'target': None}
instance = ForeignKeySource.objects.get(pk=1) instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) serializer = ForeignKeySourceSerializer(instance, data=data)
assert not serializer.is_valid() assert not serializer.is_valid()
assert serializer.errors == {'target': ['This field may not be null.']} assert serializer.errors == {'target': ['This field may not be null.']}
def test_foreign_key_with_unsaved(self): def test_foreign_key_with_unsaved(self):
source = ForeignKeySource(name='source-unsaved') source = ForeignKeySource(name='source-unsaved')
expected = {'id': None, 'name': 'source-unsaved', 'target': None} expected = {'id': None, 'name': 'source-unsaved', 'target': None}
serializer = ForeignKeySourceSerializer(source)
serializer = ForeignKeySourceSerializer(source) with self.assertNumQueries(0):
assert serializer.data == expected
# no query if source hasn't been created yet
with self.assertNumQueries(0):
assert serializer.data == expected
def test_foreign_key_with_empty(self): def test_foreign_key_with_empty(self):
""" """
Regression test for #1072 Regression test for #1072
https://github.com/encode/django-rest-framework/issues/1072 https://github.com/encode/django-rest-framework/issues/1072
""" """
serializer = NullableForeignKeySourceSerializer() serializer = NullableForeignKeySourceSerializer()
assert serializer.data['target'] is None assert serializer.data['target'] is None
def test_foreign_key_not_required(self): def test_foreign_key_not_required(self):
""" """
Let's say we wanted to fill the non-nullable model field inside Let's say we wanted to fill the non-nullable model field inside
Model.save(), we would make it empty and not required. Model.save(), we would make it empty and not required.
""" """
class ModelSerializer(ForeignKeySourceSerializer): class ModelSerializer(ForeignKeySourceSerializer):
class Meta(ForeignKeySourceSerializer.Meta): class Meta(ForeignKeySourceSerializer.Meta):
extra_kwargs = {'target': {'required': False}} extra_kwargs = {'target': {'required': False}}
serializer = ModelSerializer(data={'name': 'test'}) serializer = ModelSerializer(data={'name': 'test'})
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
assert 'target' not in serializer.validated_data assert'target'notinserializer.validated_data
def test_queryset_size_without_limited_choices(self): def test_queryset_size_without_limited_choices(self):
limited_target = ForeignKeyTarget(name="limited-target") limited_target = ForeignKeyTarget(name="limited-target")
limited_target.save() limited_target.save()
queryset = ForeignKeySourceSerializer().fields["target"].get_queryset() queryset = ForeignKeySourceSerializer().fields["target"].get_queryset()
assert len(queryset) == 3 assert len(queryset) == 3
def test_queryset_size_with_limited_choices(self): def test_queryset_size_with_limited_choices(self):
limited_target = ForeignKeyTarget(name="limited-target") limited_target = ForeignKeyTarget(name="limited-target")
limited_target.save() limited_target.save()
queryset = ForeignKeySourceWithLimitedChoicesSerializer().fields["target"].get_queryset() queryset = ForeignKeySourceWithLimitedChoicesSerializer().fields["target"].get_queryset()
assert len(queryset) == 1 assert len(queryset) == 1
def test_queryset_size_with_Q_limited_choices(self): def test_queryset_size_with_Q_limited_choices(self):
limited_target = ForeignKeyTarget(name="limited-target") limited_target = ForeignKeyTarget(name="limited-target")
limited_target.save() limited_target.save()
class QLimitedChoicesSerializer(serializers.ModelSerializer):
class QLimitedChoicesSerializer(serializers.ModelSerializer): class Meta:
class Meta: model = ForeignKeySourceWithQLimitedChoices
model = ForeignKeySourceWithQLimitedChoices fields = ("id", "target")
fields = ("id", "target")
queryset = QLimitedChoicesSerializer().fields["target"].get_queryset() queryset = QLimitedChoicesSerializer().fields["target"].get_queryset()
assert len(queryset) == 1 assertlen(queryset)==1
class PKNullableForeignKeyTests(TestCase): def setUp(self):
def setUp(self): target = ForeignKeyTarget(name='target-1')
target = ForeignKeyTarget(name='target-1') target.save()
target.save() for idx in range(1, 4):
for idx in range(1, 4): if idx == 3:
if idx == 3: target = None
target = None
source = NullableForeignKeySource(name='source-%d' % idx, target=target) source = NullableForeignKeySource(name='source-%d' % idx, target=target)
source.save() source.save()
def test_foreign_key_retrieve_with_null(self): def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': None}, ]
{'id': 1, 'name': 'source-1', 'target': 1}, assert serializer.data == expected
{'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': None},
]
assert serializer.data == expected
def test_foreign_key_create_with_valid_null(self): def test_foreign_key_create_with_valid_null(self):
data = {'id': 4, 'name': 'source-4', 'target': None} data = {'id': 4, 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data) serializer = NullableForeignKeySourceSerializer(data=data)
assert serializer.is_valid() assert serializer.is_valid()
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'source-4' assert obj.name == 'source-4'
queryset = NullableForeignKeySource.objects.all()
# Ensure source 4 is created, and everything else is as expected serializer = NullableForeignKeySourceSerializer(queryset, many=True)
queryset = NullableForeignKeySource.objects.all() expected = [ {'id': 1, 'name': 'source-1', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': None}, {'id': 4, 'name': 'source-4', 'target': None} ]
serializer = NullableForeignKeySourceSerializer(queryset, many=True) assert serializer.data == expected
expected = [
{'id': 1, 'name': 'source-1', 'target': 1},
{'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': None},
{'id': 4, 'name': 'source-4', 'target': None}
]
assert serializer.data == expected
def test_foreign_key_create_with_valid_emptystring(self): def test_foreign_key_create_with_valid_emptystring(self):
""" """
The emptystring should be interpreted as null in the context The emptystring should be interpreted as null in the context
of relationships. of relationships.
""" """
data = {'id': 4, 'name': 'source-4', 'target': ''} data = {'id': 4, 'name': 'source-4', 'target': ''}
expected_data = {'id': 4, 'name': 'source-4', 'target': None} expected_data = {'id': 4, 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data) serializer = NullableForeignKeySourceSerializer(data=data)
assert serializer.is_valid() assert serializer.is_valid()
obj = serializer.save() obj = serializer.save()
assert serializer.data == expected_data assert serializer.data == expected_data
assert obj.name == 'source-4' assert obj.name == 'source-4'
queryset = NullableForeignKeySource.objects.all()
# Ensure source 4 is created, and everything else is as expected serializer = NullableForeignKeySourceSerializer(queryset, many=True)
queryset = NullableForeignKeySource.objects.all() expected = [ {'id': 1, 'name': 'source-1', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': None}, {'id': 4, 'name': 'source-4', 'target': None} ]
serializer = NullableForeignKeySourceSerializer(queryset, many=True) assert serializer.data == expected
expected = [
{'id': 1, 'name': 'source-1', 'target': 1},
{'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': None},
{'id': 4, 'name': 'source-4', 'target': None}
]
assert serializer.data == expected
def test_foreign_key_update_with_valid_null(self): def test_foreign_key_update_with_valid_null(self):
data = {'id': 1, 'name': 'source-1', 'target': None} data = {'id': 1, 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1) instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data) serializer = NullableForeignKeySourceSerializer(instance, data=data)
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == data assert serializer.data == data
queryset = NullableForeignKeySource.objects.all()
# Ensure source 1 is updated, and everything else is as expected serializer = NullableForeignKeySourceSerializer(queryset, many=True)
queryset = NullableForeignKeySource.objects.all() expected = [ {'id': 1, 'name': 'source-1', 'target': None}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': None} ]
serializer = NullableForeignKeySourceSerializer(queryset, many=True) assert serializer.data == expected
expected = [
{'id': 1, 'name': 'source-1', 'target': None},
{'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': None}
]
assert serializer.data == expected
def test_foreign_key_update_with_valid_emptystring(self): def test_foreign_key_update_with_valid_emptystring(self):
""" """
The emptystring should be interpreted as null in the context The emptystring should be interpreted as null in the context
of relationships. of relationships.
""" """
data = {'id': 1, 'name': 'source-1', 'target': ''} data = {'id': 1, 'name': 'source-1', 'target': ''}
expected_data = {'id': 1, 'name': 'source-1', 'target': None} expected_data = {'id': 1, 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1) instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data) serializer = NullableForeignKeySourceSerializer(instance, data=data)
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == expected_data assert serializer.data == expected_data
queryset = NullableForeignKeySource.objects.all()
# Ensure source 1 is updated, and everything else is as expected serializer = NullableForeignKeySourceSerializer(queryset, many=True)
queryset = NullableForeignKeySource.objects.all() expected = [ {'id': 1, 'name': 'source-1', 'target': None}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': None} ]
serializer = NullableForeignKeySourceSerializer(queryset, many=True) assert serializer.data == expected
expected = [
{'id': 1, 'name': 'source-1', 'target': None},
{'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': None}
]
assert serializer.data == expected
def test_null_uuid_foreign_key_serializes_as_none(self): def test_null_uuid_foreign_key_serializes_as_none(self):
source = NullableUUIDForeignKeySource(name='Source') source = NullableUUIDForeignKeySource(name='Source')
serializer = NullableUUIDForeignKeySourceSerializer(source) serializer = NullableUUIDForeignKeySourceSerializer(source)
data = serializer.data data = serializer.data
assert data["target"] is None assert data["target"] is None
def test_nullable_uuid_foreign_key_is_valid_when_none(self): def test_nullable_uuid_foreign_key_is_valid_when_none(self):
data = {"name": "Source", "target": None} data = {"name": "Source", "target": None}
serializer = NullableUUIDForeignKeySourceSerializer(data=data) serializer = NullableUUIDForeignKeySourceSerializer(data=data)
assert serializer.is_valid(), serializer.errors assert serializer.is_valid(), serializer.errors
class PKNullableOneToOneTests(TestCase): def setUp(self):
def setUp(self): target = OneToOneTarget(name='target-1')
target = OneToOneTarget(name='target-1') target.save()
target.save() new_target = OneToOneTarget(name='target-2')
new_target = OneToOneTarget(name='target-2') new_target.save()
new_target.save() source = NullableOneToOneSource(name='source-1', target=new_target)
source = NullableOneToOneSource(name='source-1', target=new_target) source.save()
source.save()
def test_reverse_foreign_key_retrieve_with_null(self): def test_reverse_foreign_key_retrieve_with_null(self):
queryset = OneToOneTarget.objects.all() queryset = OneToOneTarget.objects.all()
serializer = NullableOneToOneTargetSerializer(queryset, many=True) serializer = NullableOneToOneTargetSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'target-1', 'nullable_source': None}, {'id': 2, 'name': 'target-2', 'nullable_source': 1}, ]
{'id': 1, 'name': 'target-1', 'nullable_source': None}, assert serializer.data == expected
{'id': 2, 'name': 'target-2', 'nullable_source': 1},
]
assert serializer.data == expected
class OneToOnePrimaryKeyTests(TestCase):
def setUp(self): def setUp(self):
# Given: Some target models already exist # Given: Some target models already exist
self.target = target = OneToOneTarget(name='target-1') self.target = target = OneToOneTarget(name='target-1')
target.save() target.save()
self.alt_target = alt_target = OneToOneTarget(name='target-2') self.alt_target = alt_target = OneToOneTarget(name='target-2')
alt_target.save() alt_target.save()
def test_one_to_one_when_primary_key(self): def test_one_to_one_when_primary_key(self):
# When: Creating a Source pointing at the id of the second Target # When: Creating a Source pointing at the id of the second Target
target_pk = self.alt_target.id target_pk = self.alt_target.id
source = OneToOnePKSourceSerializer(data={'name': 'source-2', 'target': target_pk}) source = OneToOnePKSourceSerializer(data={'name': 'source-2', 'target': target_pk})
# Then: The source is valid with the serializer if not source.is_valid():
if not source.is_valid(): self.fail("Expected OneToOnePKTargetSerializer to be valid but had errors: {}".format(source.errors))
self.fail("Expected OneToOnePKTargetSerializer to be valid but had errors: {}".format(source.errors))
# Then: Saving the serializer creates a new object # Then: Saving the serializer creates a new object
new_source = source.save() new_source = source.save()
# Then: The new object has the same pk as the target object assertnew_source.pk==target_pk
self.assertEqual(new_source.pk, target_pk)
def test_one_to_one_when_primary_key_no_duplicates(self): def test_one_to_one_when_primary_key_no_duplicates(self):
# When: Creating a Source pointing at the id of the second Target # When: Creating a Source pointing at the id of the second Target
target_pk = self.target.id target_pk = self.target.id
data = {'name': 'source-1', 'target': target_pk} data = {'name': 'source-1', 'target': target_pk}
source = OneToOnePKSourceSerializer(data=data) source = OneToOnePKSourceSerializer(data=data)
# Then: The source is valid with the serializer assert source.is_valid()
self.assertTrue(source.is_valid()) new_source = source.save()
# Then: Saving the serializer creates a new object assert new_source.pk == target_pk
new_source = source.save() second_source = OneToOnePKSourceSerializer(data=data)
# Then: The new object has the same pk as the target object assert not second_source.is_valid()
self.assertEqual(new_source.pk, target_pk) expected = {'target': ['one to one pk source with this target already exists.']}
# When: Trying to create a second object assert second_source.errors == expected
second_source = OneToOnePKSourceSerializer(data=data)
self.assertFalse(second_source.is_valid())
expected = {'target': ['one to one pk source with this target already exists.']}
self.assertDictEqual(second_source.errors, expected)
def test_one_to_one_when_primary_key_does_not_exist(self): def test_one_to_one_when_primary_key_does_not_exist(self):
# Given: a target PK that does not exist # Given: a target PK that does not exist
target_pk = self.target.pk + self.alt_target.pk target_pk = self.target.pk + self.alt_target.pk
source = OneToOnePKSourceSerializer(data={'name': 'source-2', 'target': target_pk}) source = OneToOnePKSourceSerializer(data={'name': 'source-2', 'target': target_pk})
# Then: The source is not valid with the serializer assert not source.is_valid()
self.assertFalse(source.is_valid()) assert "Invalid pk" in source.errors['target'][0]
self.assertIn("Invalid pk", source.errors['target'][0]) assert "object does not exist" in source.errors['target'][0]
self.assertIn("object does not exist", source.errors['target'][0])

View File

@ -42,246 +42,177 @@ class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
# TODO: M2M Tests, FKTests (Non-nullable), One2One # TODO: M2M Tests, FKTests (Non-nullable), One2One
class SlugForeignKeyTests(TestCase): def setUp(self):
def setUp(self): target = ForeignKeyTarget(name='target-1')
target = ForeignKeyTarget(name='target-1') target.save()
target.save() new_target = ForeignKeyTarget(name='target-2')
new_target = ForeignKeyTarget(name='target-2') new_target.save()
new_target.save() for idx in range(1, 4):
for idx in range(1, 4): source = ForeignKeySource(name='source-%d' % idx, target=target)
source = ForeignKeySource(name='source-%d' % idx, target=target) source.save()
source.save()
def test_foreign_key_retrieve(self): def test_foreign_key_retrieve(self):
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True) serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-1'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': 'target-1'} ]
{'id': 1, 'name': 'source-1', 'target': 'target-1'}, with self.assertNumQueries(4):
{'id': 2, 'name': 'source-2', 'target': 'target-1'}, assert serializer.data == expected
{'id': 3, 'name': 'source-3', 'target': 'target-1'}
]
with self.assertNumQueries(4):
assert serializer.data == expected
def test_foreign_key_retrieve_select_related(self): def test_foreign_key_retrieve_select_related(self):
queryset = ForeignKeySource.objects.all().select_related('target') queryset = ForeignKeySource.objects.all().select_related('target')
serializer = ForeignKeySourceSerializer(queryset, many=True) serializer = ForeignKeySourceSerializer(queryset, many=True)
with self.assertNumQueries(1): with self.assertNumQueries(1):
serializer.data serializer.data
def test_reverse_foreign_key_retrieve(self): def test_reverse_foreign_key_retrieve(self):
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True) serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, {'id': 2, 'name': 'target-2', 'sources': []}, ]
{'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, assert serializer.data == expected
{'id': 2, 'name': 'target-2', 'sources': []},
]
assert serializer.data == expected
def test_reverse_foreign_key_retrieve_prefetch_related(self): def test_reverse_foreign_key_retrieve_prefetch_related(self):
queryset = ForeignKeyTarget.objects.all().prefetch_related('sources') queryset = ForeignKeyTarget.objects.all().prefetch_related('sources')
serializer = ForeignKeyTargetSerializer(queryset, many=True) serializer = ForeignKeyTargetSerializer(queryset, many=True)
with self.assertNumQueries(2): with self.assertNumQueries(2):
serializer.data serializer.data
def test_foreign_key_update(self): def test_foreign_key_update(self):
data = {'id': 1, 'name': 'source-1', 'target': 'target-2'} data = {'id': 1, 'name': 'source-1', 'target': 'target-2'}
instance = ForeignKeySource.objects.get(pk=1) instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) serializer = ForeignKeySourceSerializer(instance, data=data)
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == data assert serializer.data == data
queryset = ForeignKeySource.objects.all()
# Ensure source 1 is updated, and everything else is as expected serializer = ForeignKeySourceSerializer(queryset, many=True)
queryset = ForeignKeySource.objects.all() expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-2'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': 'target-1'} ]
serializer = ForeignKeySourceSerializer(queryset, many=True) assert serializer.data == expected
expected = [
{'id': 1, 'name': 'source-1', 'target': 'target-2'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': 'target-1'}
]
assert serializer.data == expected
def test_foreign_key_update_incorrect_type(self): def test_foreign_key_update_incorrect_type(self):
data = {'id': 1, 'name': 'source-1', 'target': 123} data = {'id': 1, 'name': 'source-1', 'target': 123}
instance = ForeignKeySource.objects.get(pk=1) instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) serializer = ForeignKeySourceSerializer(instance, data=data)
assert not serializer.is_valid() assert not serializer.is_valid()
assert serializer.errors == {'target': ['Object with name=123 does not exist.']} assert serializer.errors == {'target': ['Object with name=123 does not exist.']}
def test_reverse_foreign_key_update(self): def test_reverse_foreign_key_update(self):
data = {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']} data = {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']}
instance = ForeignKeyTarget.objects.get(pk=2) instance = ForeignKeyTarget.objects.get(pk=2)
serializer = ForeignKeyTargetSerializer(instance, data=data) serializer = ForeignKeyTargetSerializer(instance, data=data)
assert serializer.is_valid() assert serializer.is_valid()
# We shouldn't have saved anything to the db yet since save queryset = ForeignKeyTarget.objects.all()
# hasn't been called. new_serializer = ForeignKeyTargetSerializer(queryset, many=True)
queryset = ForeignKeyTarget.objects.all() expected = [ {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, {'id': 2, 'name': 'target-2', 'sources': []}, ]
new_serializer = ForeignKeyTargetSerializer(queryset, many=True) assert new_serializer.data == expected
expected = [ serializer.save()
{'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, assert serializer.data == data
{'id': 2, 'name': 'target-2', 'sources': []}, queryset = ForeignKeyTarget.objects.all()
] serializer = ForeignKeyTargetSerializer(queryset, many=True)
assert new_serializer.data == expected expected = [ {'id': 1, 'name': 'target-1', 'sources': ['source-2']}, {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']}, ]
assert serializer.data == expected
serializer.save()
assert serializer.data == data
# Ensure target 2 is update, and everything else is as expected
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'target-1', 'sources': ['source-2']},
{'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']},
]
assert serializer.data == expected
def test_foreign_key_create(self): def test_foreign_key_create(self):
data = {'id': 4, 'name': 'source-4', 'target': 'target-2'} data = {'id': 4, 'name': 'source-4', 'target': 'target-2'}
serializer = ForeignKeySourceSerializer(data=data) serializer = ForeignKeySourceSerializer(data=data)
serializer.is_valid() serializer.is_valid()
assert serializer.is_valid() assert serializer.is_valid()
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'source-4' assert obj.name == 'source-4'
queryset = ForeignKeySource.objects.all()
# Ensure source 4 is added, and everything else is as expected serializer = ForeignKeySourceSerializer(queryset, many=True)
queryset = ForeignKeySource.objects.all() expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-1'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': 'target-1'}, {'id': 4, 'name': 'source-4', 'target': 'target-2'}, ]
serializer = ForeignKeySourceSerializer(queryset, many=True) assert serializer.data == expected
expected = [
{'id': 1, 'name': 'source-1', 'target': 'target-1'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': 'target-1'},
{'id': 4, 'name': 'source-4', 'target': 'target-2'},
]
assert serializer.data == expected
def test_reverse_foreign_key_create(self): def test_reverse_foreign_key_create(self):
data = {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']} data = {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']}
serializer = ForeignKeyTargetSerializer(data=data) serializer = ForeignKeyTargetSerializer(data=data)
assert serializer.is_valid() assert serializer.is_valid()
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'target-3' assert obj.name == 'target-3'
queryset = ForeignKeyTarget.objects.all()
# Ensure target 3 is added, and everything else is as expected serializer = ForeignKeyTargetSerializer(queryset, many=True)
queryset = ForeignKeyTarget.objects.all() expected = [ {'id': 1, 'name': 'target-1', 'sources': ['source-2']}, {'id': 2, 'name': 'target-2', 'sources': []}, {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']}, ]
serializer = ForeignKeyTargetSerializer(queryset, many=True) assert serializer.data == expected
expected = [
{'id': 1, 'name': 'target-1', 'sources': ['source-2']},
{'id': 2, 'name': 'target-2', 'sources': []},
{'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']},
]
assert serializer.data == expected
def test_foreign_key_update_with_invalid_null(self): def test_foreign_key_update_with_invalid_null(self):
data = {'id': 1, 'name': 'source-1', 'target': None} data = {'id': 1, 'name': 'source-1', 'target': None}
instance = ForeignKeySource.objects.get(pk=1) instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) serializer = ForeignKeySourceSerializer(instance, data=data)
assert not serializer.is_valid() assert not serializer.is_valid()
assert serializer.errors == {'target': ['This field may not be null.']} assert serializer.errors == {'target': ['This field may not be null.']}
class SlugNullableForeignKeyTests(TestCase): def setUp(self):
def setUp(self): target = ForeignKeyTarget(name='target-1')
target = ForeignKeyTarget(name='target-1') target.save()
target.save() for idx in range(1, 4):
for idx in range(1, 4): if idx == 3:
if idx == 3: target = None
target = None
source = NullableForeignKeySource(name='source-%d' % idx, target=target) source = NullableForeignKeySource(name='source-%d' % idx, target=target)
source.save() source.save()
def test_foreign_key_retrieve_with_null(self): def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-1'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': None}, ]
{'id': 1, 'name': 'source-1', 'target': 'target-1'}, assert serializer.data == expected
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': None},
]
assert serializer.data == expected
def test_foreign_key_create_with_valid_null(self): def test_foreign_key_create_with_valid_null(self):
data = {'id': 4, 'name': 'source-4', 'target': None} data = {'id': 4, 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data) serializer = NullableForeignKeySourceSerializer(data=data)
assert serializer.is_valid() assert serializer.is_valid()
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'source-4' assert obj.name == 'source-4'
queryset = NullableForeignKeySource.objects.all()
# Ensure source 4 is created, and everything else is as expected serializer = NullableForeignKeySourceSerializer(queryset, many=True)
queryset = NullableForeignKeySource.objects.all() expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-1'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': None}, {'id': 4, 'name': 'source-4', 'target': None} ]
serializer = NullableForeignKeySourceSerializer(queryset, many=True) assert serializer.data == expected
expected = [
{'id': 1, 'name': 'source-1', 'target': 'target-1'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': None},
{'id': 4, 'name': 'source-4', 'target': None}
]
assert serializer.data == expected
def test_foreign_key_create_with_valid_emptystring(self): def test_foreign_key_create_with_valid_emptystring(self):
""" """
The emptystring should be interpreted as null in the context The emptystring should be interpreted as null in the context
of relationships. of relationships.
""" """
data = {'id': 4, 'name': 'source-4', 'target': ''} data = {'id': 4, 'name': 'source-4', 'target': ''}
expected_data = {'id': 4, 'name': 'source-4', 'target': None} expected_data = {'id': 4, 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data) serializer = NullableForeignKeySourceSerializer(data=data)
assert serializer.is_valid() assert serializer.is_valid()
obj = serializer.save() obj = serializer.save()
assert serializer.data == expected_data assert serializer.data == expected_data
assert obj.name == 'source-4' assert obj.name == 'source-4'
queryset = NullableForeignKeySource.objects.all()
# Ensure source 4 is created, and everything else is as expected serializer = NullableForeignKeySourceSerializer(queryset, many=True)
queryset = NullableForeignKeySource.objects.all() expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-1'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': None}, {'id': 4, 'name': 'source-4', 'target': None} ]
serializer = NullableForeignKeySourceSerializer(queryset, many=True) assert serializer.data == expected
expected = [
{'id': 1, 'name': 'source-1', 'target': 'target-1'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': None},
{'id': 4, 'name': 'source-4', 'target': None}
]
assert serializer.data == expected
def test_foreign_key_update_with_valid_null(self): def test_foreign_key_update_with_valid_null(self):
data = {'id': 1, 'name': 'source-1', 'target': None} data = {'id': 1, 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1) instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data) serializer = NullableForeignKeySourceSerializer(instance, data=data)
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == data assert serializer.data == data
queryset = NullableForeignKeySource.objects.all()
# Ensure source 1 is updated, and everything else is as expected serializer = NullableForeignKeySourceSerializer(queryset, many=True)
queryset = NullableForeignKeySource.objects.all() expected = [ {'id': 1, 'name': 'source-1', 'target': None}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': None} ]
serializer = NullableForeignKeySourceSerializer(queryset, many=True) assert serializer.data == expected
expected = [
{'id': 1, 'name': 'source-1', 'target': None},
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': None}
]
assert serializer.data == expected
def test_foreign_key_update_with_valid_emptystring(self): def test_foreign_key_update_with_valid_emptystring(self):
""" """
The emptystring should be interpreted as null in the context The emptystring should be interpreted as null in the context
of relationships. of relationships.
""" """
data = {'id': 1, 'name': 'source-1', 'target': ''} data = {'id': 1, 'name': 'source-1', 'target': ''}
expected_data = {'id': 1, 'name': 'source-1', 'target': None} expected_data = {'id': 1, 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1) instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data) serializer = NullableForeignKeySourceSerializer(instance, data=data)
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == expected_data assert serializer.data == expected_data
queryset = NullableForeignKeySource.objects.all()
# Ensure source 1 is updated, and everything else is as expected serializer = NullableForeignKeySourceSerializer(queryset, many=True)
queryset = NullableForeignKeySource.objects.all() expected = [ {'id': 1, 'name': 'source-1', 'target': None}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': None} ]
serializer = NullableForeignKeySourceSerializer(queryset, many=True) assert serializer.data == expected
expected = [
{'id': 1, 'name': 'source-1', 'target': None},
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': None}
]
assert serializer.data == expected

View File

@ -49,11 +49,10 @@ class DummyTestModel(models.Model):
name = models.CharField(max_length=42, default='') name = models.CharField(max_length=42, default='')
class BasicRendererTests(TestCase): def test_expected_results(self):
def test_expected_results(self): for value, renderer_cls, expected in expected_results:
for value, renderer_cls, expected in expected_results: output = renderer_cls().render(value)
output = renderer_cls().render(value) assert output == expected
self.assertEqual(output, expected)
class RendererA(BaseRenderer): class RendererA(BaseRenderer):
@ -144,123 +143,114 @@ class POSTDeniedView(APIView):
return Response() return Response()
class DocumentingRendererTests(TestCase): def test_only_permitted_forms_are_displayed(self):
def test_only_permitted_forms_are_displayed(self): view = POSTDeniedView.as_view()
view = POSTDeniedView.as_view() request = APIRequestFactory().get('/')
request = APIRequestFactory().get('/') response = view(request).render()
response = view(request).render() self.assertNotContains(response, '>POST<')
self.assertNotContains(response, '>POST<') self.assertContains(response, '>PUT<')
self.assertContains(response, '>PUT<') self.assertContains(response, '>PATCH<')
self.assertContains(response, '>PATCH<')
@override_settings(ROOT_URLCONF='tests.test_renderers') @override_settings(ROOT_URLCONF='tests.test_renderers')
class RendererEndToEndTests(TestCase): """
"""
End-to-end testing of renderers using an RendererMixin on a generic view. End-to-end testing of renderers using an RendererMixin on a generic view.
""" """
def test_default_renderer_serializes_content(self): deftest_default_renderer_serializes_content(self):
"""If the Accept header is not set the default renderer should serialize the response.""" """If the Accept header is not set the default renderer should serialize the response."""
resp = self.client.get('/') resp = self.client.get('/')
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
def test_head_method_serializes_no_content(self): def test_head_method_serializes_no_content(self):
"""No response must be included in HEAD requests.""" """No response must be included in HEAD requests."""
resp = self.client.head('/') resp = self.client.head('/')
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
self.assertEqual(resp.content, b'') assert resp.content == b''
def test_default_renderer_serializes_content_on_accept_any(self): def test_default_renderer_serializes_content_on_accept_any(self):
"""If the Accept header is set to */* the default renderer should serialize the response.""" """If the Accept header is set to */* the default renderer should serialize the response."""
resp = self.client.get('/', HTTP_ACCEPT='*/*') resp = self.client.get('/', HTTP_ACCEPT='*/*')
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_serializes_content_default_case(self): def test_specified_renderer_serializes_content_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response. """If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for the default renderer)""" (In this case we check that works for the default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type) resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_serializes_content_non_default_case(self): def test_specified_renderer_serializes_content_non_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response. """If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for a non-default renderer)""" (In this case we check that works for a non-default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type) resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
def test_unsatisfiable_accept_header_on_request_returns_406_status(self): def test_unsatisfiable_accept_header_on_request_returns_406_status(self):
"""If the Accept header is unsatisfiable we should return a 406 Not Acceptable response.""" """If the Accept header is unsatisfiable we should return a 406 Not Acceptable response."""
resp = self.client.get('/', HTTP_ACCEPT='foo/bar') resp = self.client.get('/', HTTP_ACCEPT='foo/bar')
self.assertEqual(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE) assert resp.status_code == status.HTTP_406_NOT_ACCEPTABLE
def test_specified_renderer_serializes_content_on_format_query(self): def test_specified_renderer_serializes_content_on_format_query(self):
"""If a 'format' query is specified, the renderer with the matching """If a 'format' query is specified, the renderer with the matching
format attribute should serialize the response.""" format attribute should serialize the response."""
param = '?%s=%s' % ( param = '?%s=%s' % ( api_settings.URL_FORMAT_OVERRIDE, RendererB.format )
api_settings.URL_FORMAT_OVERRIDE, resp = self.client.get('/' + param)
RendererB.format assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
) assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
resp = self.client.get('/' + param) assert resp.status_code == DUMMYSTATUS
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_on_format_kwargs(self): def test_specified_renderer_serializes_content_on_format_kwargs(self):
"""If a 'format' keyword arg is specified, the renderer with the matching """If a 'format' keyword arg is specified, the renderer with the matching
format attribute should serialize the response.""" format attribute should serialize the response."""
resp = self.client.get('/something.formatb') resp = self.client.get('/something.formatb')
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_is_used_on_format_query_with_matching_accept(self): def test_specified_renderer_is_used_on_format_query_with_matching_accept(self):
"""If both a 'format' query and a matching Accept header specified, """If both a 'format' query and a matching Accept header specified,
the renderer with the matching format attribute should serialize the response.""" the renderer with the matching format attribute should serialize the response."""
param = '?%s=%s' % ( param = '?%s=%s' % ( api_settings.URL_FORMAT_OVERRIDE, RendererB.format )
api_settings.URL_FORMAT_OVERRIDE, resp = self.client.get('/' + param, HTTP_ACCEPT=RendererB.media_type)
RendererB.format assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
) assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
resp = self.client.get('/' + param, assert resp.status_code == DUMMYSTATUS
HTTP_ACCEPT=RendererB.media_type)
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_parse_error_renderers_browsable_api(self): def test_parse_error_renderers_browsable_api(self):
"""Invalid data should still render the browsable API correctly.""" """Invalid data should still render the browsable API correctly."""
resp = self.client.post('/parseerror', data='foobar', content_type='application/json', HTTP_ACCEPT='text/html') resp = self.client.post('/parseerror', data='foobar', content_type='application/json', HTTP_ACCEPT='text/html')
self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') assert resp['Content-Type'] == 'text/html; charset=utf-8'
self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST) assert resp.status_code == status.HTTP_400_BAD_REQUEST
def test_204_no_content_responses_have_no_content_type_set(self): def test_204_no_content_responses_have_no_content_type_set(self):
""" """
Regression test for #1196 Regression test for #1196
https://github.com/encode/django-rest-framework/issues/1196 https://github.com/encode/django-rest-framework/issues/1196
""" """
resp = self.client.get('/empty') resp = self.client.get('/empty')
self.assertEqual(resp.get('Content-Type', None), None) assert resp.get('Content-Type', None) == None
self.assertEqual(resp.status_code, status.HTTP_204_NO_CONTENT) assert resp.status_code == status.HTTP_204_NO_CONTENT
def test_contains_headers_of_api_response(self): def test_contains_headers_of_api_response(self):
""" """
Issue #1437 Issue #1437
Test we display the headers of the API response and not those from the Test we display the headers of the API response and not those from the
HTML response HTML response
""" """
resp = self.client.get('/html1') resp = self.client.get('/html1')
self.assertContains(resp, '>GET, HEAD, OPTIONS<') self.assertContains(resp, '>GET, HEAD, OPTIONS<')
self.assertContains(resp, '>application/json<') self.assertContains(resp, '>application/json<')
self.assertNotContains(resp, '>text/html; charset=utf-8<') self.assertNotContains(resp, '>text/html; charset=utf-8<')
_flat_repr = '{"foo":["bar","baz"]}' _flat_repr = '{"foo":["bar","baz"]}'
@ -275,183 +265,174 @@ def strip_trailing_whitespace(content):
return re.sub(' +\n', '\n', content) return re.sub(' +\n', '\n', content)
class BaseRendererTests(TestCase): """
"""
Tests BaseRenderer Tests BaseRenderer
""" """
def test_render_raise_error(self): deftest_render_raise_error(self):
""" """
BaseRenderer.render should raise NotImplementedError BaseRenderer.render should raise NotImplementedError
""" """
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
BaseRenderer().render('test') BaseRenderer().render('test')
class JSONRendererTests(TestCase): """
"""
Tests specific to the JSON Renderer Tests specific to the JSON Renderer
""" """
deftest_render_lazy_strings(self):
def test_render_lazy_strings(self): """
"""
JSONRenderer should deal with lazy translated strings. JSONRenderer should deal with lazy translated strings.
""" """
ret = JSONRenderer().render(_('test')) ret = JSONRenderer().render(_('test'))
self.assertEqual(ret, b'"test"') assert ret == b'"test"'
def test_render_queryset_values(self): def test_render_queryset_values(self):
o = DummyTestModel.objects.create(name='dummy') o = DummyTestModel.objects.create(name='dummy')
qs = DummyTestModel.objects.values('id', 'name') qs = DummyTestModel.objects.values('id', 'name')
ret = JSONRenderer().render(qs) ret = JSONRenderer().render(qs)
data = json.loads(ret.decode()) data = json.loads(ret.decode())
self.assertEqual(data, [{'id': o.id, 'name': o.name}]) assert data == [{'id': o.id, 'name': o.name}]
def test_render_queryset_values_list(self): def test_render_queryset_values_list(self):
o = DummyTestModel.objects.create(name='dummy') o = DummyTestModel.objects.create(name='dummy')
qs = DummyTestModel.objects.values_list('id', 'name') qs = DummyTestModel.objects.values_list('id', 'name')
ret = JSONRenderer().render(qs) ret = JSONRenderer().render(qs)
data = json.loads(ret.decode()) data = json.loads(ret.decode())
self.assertEqual(data, [[o.id, o.name]]) assert data == [[o.id, o.name]]
def test_render_dict_abc_obj(self): def test_render_dict_abc_obj(self):
class Dict(MutableMapping): class Dict(MutableMapping):
def __init__(self): def __init__(self):
self._dict = {} self._dict = {}
def __getitem__(self, key): def __getitem__(self, key):
return self._dict.__getitem__(key) return self._dict.__getitem__(key)
def __setitem__(self, key, value): def __setitem__(self, key, value):
return self._dict.__setitem__(key, value) return self._dict.__setitem__(key, value)
def __delitem__(self, key): def __delitem__(self, key):
return self._dict.__delitem__(key) return self._dict.__delitem__(key)
def __iter__(self): def __iter__(self):
return self._dict.__iter__() return self._dict.__iter__()
def __len__(self): def __len__(self):
return self._dict.__len__() return self._dict.__len__()
def keys(self): def keys(self):
return self._dict.keys() return self._dict.keys()
x = Dict() x = Dict()
x['key'] = 'string value' x['key']='string value'
x[2] = 3 x[2]=3
ret = JSONRenderer().render(x) ret=JSONRenderer().render(x)
data = json.loads(ret.decode()) data=json.loads(ret.decode())
self.assertEqual(data, {'key': 'string value', '2': 3}) assertdata=={'key':'string value','2':3}
def test_render_obj_with_getitem(self): def test_render_obj_with_getitem(self):
class DictLike: class DictLike:
def __init__(self): def __init__(self):
self._dict = {} self._dict = {}
def set(self, value): def set(self, value):
self._dict = dict(value) self._dict = dict(value)
def __getitem__(self, key): def __getitem__(self, key):
return self._dict[key] return self._dict[key]
x = DictLike() x = DictLike()
x.set({'a': 1, 'b': 'string'}) x.set({'a':1,'b':'string'})
with self.assertRaises(TypeError): withpytest.raises(TypeError):
JSONRenderer().render(x) JSONRenderer().render(x)
def test_float_strictness(self): def test_float_strictness(self):
renderer = JSONRenderer() renderer = JSONRenderer()
for value in [float('inf'), float('-inf'), float('nan')]:
# Default to strict with pytest.raises(ValueError):
for value in [float('inf'), float('-inf'), float('nan')]: renderer.render(value)
with pytest.raises(ValueError):
renderer.render(value)
renderer.strict = False renderer.strict = False
assert renderer.render(float('inf')) == b'Infinity' assertrenderer.render(float('inf'))==b'Infinity'
assert renderer.render(float('-inf')) == b'-Infinity' assertrenderer.render(float('-inf'))==b'-Infinity'
assert renderer.render(float('nan')) == b'NaN' assertrenderer.render(float('nan'))==b'NaN'
def test_without_content_type_args(self): def test_without_content_type_args(self):
""" """
Test basic JSON rendering. Test basic JSON rendering.
""" """
obj = {'foo': ['bar', 'baz']} obj = {'foo': ['bar', 'baz']}
renderer = JSONRenderer() renderer = JSONRenderer()
content = renderer.render(obj, 'application/json') content = renderer.render(obj, 'application/json')
# Fix failing test case which depends on version of JSON library. assert content.decode() == _flat_repr
self.assertEqual(content.decode(), _flat_repr)
def test_with_content_type_args(self): def test_with_content_type_args(self):
""" """
Test JSON rendering with additional content type arguments supplied. Test JSON rendering with additional content type arguments supplied.
""" """
obj = {'foo': ['bar', 'baz']} obj = {'foo': ['bar', 'baz']}
renderer = JSONRenderer() renderer = JSONRenderer()
content = renderer.render(obj, 'application/json; indent=2') content = renderer.render(obj, 'application/json; indent=2')
self.assertEqual(strip_trailing_whitespace(content.decode()), _indented_repr) assert strip_trailing_whitespace(content.decode()) == _indented_repr
class UnicodeJSONRendererTests(TestCase): """
"""
Tests specific for the Unicode JSON Renderer Tests specific for the Unicode JSON Renderer
""" """
def test_proper_encoding(self): deftest_proper_encoding(self):
obj = {'countries': ['United Kingdom', 'France', 'España']} obj = {'countries': ['United Kingdom', 'France', 'España']}
renderer = JSONRenderer() renderer = JSONRenderer()
content = renderer.render(obj, 'application/json') content = renderer.render(obj, 'application/json')
self.assertEqual(content, '{"countries":["United Kingdom","France","España"]}'.encode()) assert content == '{"countries":["United Kingdom","France","España"]}'.encode()
def test_u2028_u2029(self): def test_u2028_u2029(self):
# The \u2028 and \u2029 characters should be escaped, # The \u2028 and \u2029 characters should be escaped,
# even when the non-escaping unicode representation is used. # even when the non-escaping unicode representation is used.
# Regression test for #2169 # Regression test for #2169
obj = {'should_escape': '\u2028\u2029'} obj = {'should_escape': '\u2028\u2029'}
renderer = JSONRenderer() renderer = JSONRenderer()
content = renderer.render(obj, 'application/json') content = renderer.render(obj, 'application/json')
self.assertEqual(content, '{"should_escape":"\\u2028\\u2029"}'.encode()) assert content == '{"should_escape":"\\u2028\\u2029"}'.encode()
class AsciiJSONRendererTests(TestCase): """
"""
Tests specific for the Unicode JSON Renderer Tests specific for the Unicode JSON Renderer
""" """
def test_proper_encoding(self): deftest_proper_encoding(self):
class AsciiJSONRenderer(JSONRenderer): class AsciiJSONRenderer(JSONRenderer):
ensure_ascii = True ensure_ascii = True
obj = {'countries': ['United Kingdom', 'France', 'España']} obj = {'countries': ['United Kingdom', 'France', 'España']}
renderer = AsciiJSONRenderer() renderer=AsciiJSONRenderer()
content = renderer.render(obj, 'application/json') content=renderer.render(obj,'application/json')
self.assertEqual(content, '{"countries":["United Kingdom","France","Espa\\u00f1a"]}'.encode()) assertcontent=='{"countries":["United Kingdom","France","Espa\\u00f1a"]}'.encode()
# Tests for caching issue, #346 # Tests for caching issue, #346
@override_settings(ROOT_URLCONF='tests.test_renderers') @override_settings(ROOT_URLCONF='tests.test_renderers')
class CacheRenderTest(TestCase): """
"""
Tests specific to caching responses Tests specific to caching responses
""" """
def test_head_caching(self): deftest_head_caching(self):
""" """
Test caching of HEAD requests Test caching of HEAD requests
""" """
response = self.client.head('/cache') response = self.client.head('/cache')
cache.set('key', response) cache.set('key', response)
cached_response = cache.get('key') cached_response = cache.get('key')
assert isinstance(cached_response, Response) assert isinstance(cached_response, Response)
assert cached_response.content == response.content assert cached_response.content == response.content
assert cached_response.status_code == response.status_code assert cached_response.status_code == response.status_code
def test_get_caching(self): def test_get_caching(self):
""" """
Test caching of GET requests Test caching of GET requests
""" """
response = self.client.get('/cache') response = self.client.get('/cache')
cache.set('key', response) cache.set('key', response)
cached_response = cache.get('key') cached_response = cache.get('key')
assert isinstance(cached_response, Response) assert isinstance(cached_response, Response)
assert cached_response.content == response.content assert cached_response.content == response.content
assert cached_response.status_code == response.status_code assert cached_response.status_code == response.status_code
class TestJSONIndentationStyles: class TestJSONIndentationStyles:
@ -476,150 +457,116 @@ class TestJSONIndentationStyles:
assert renderer.render(data) == b'{"a": 1, "b": 2}' assert renderer.render(data) == b'{"a": 1, "b": 2}'
class TestHiddenFieldHTMLFormRenderer(TestCase): def test_hidden_field_rendering(self):
def test_hidden_field_rendering(self): class TestSerializer(serializers.Serializer):
class TestSerializer(serializers.Serializer): published = serializers.HiddenField(default=True)
published = serializers.HiddenField(default=True)
serializer = TestSerializer(data={}) serializer = TestSerializer(data={})
serializer.is_valid() serializer.is_valid()
renderer = HTMLFormRenderer() renderer=HTMLFormRenderer()
field = serializer['published'] field=serializer['published']
rendered = renderer.render_field(field, {}) rendered=renderer.render_field(field,{})
assert rendered == '' assertrendered==''
class TestHTMLFormRenderer(TestCase): def setUp(self):
def setUp(self): class TestSerializer(serializers.Serializer):
class TestSerializer(serializers.Serializer): test_field = serializers.CharField()
test_field = serializers.CharField()
self.renderer = HTMLFormRenderer() self.renderer = HTMLFormRenderer()
self.serializer = TestSerializer(data={}) self.serializer=TestSerializer(data={})
def test_render_with_default_args(self): def test_render_with_default_args(self):
self.serializer.is_valid() self.serializer.is_valid()
renderer = HTMLFormRenderer() renderer = HTMLFormRenderer()
result = renderer.render(self.serializer.data)
result = renderer.render(self.serializer.data) assert isinstance(result, SafeText)
self.assertIsInstance(result, SafeText)
def test_render_with_provided_args(self): def test_render_with_provided_args(self):
self.serializer.is_valid() self.serializer.is_valid()
renderer = HTMLFormRenderer() renderer = HTMLFormRenderer()
result = renderer.render(self.serializer.data, None, {})
result = renderer.render(self.serializer.data, None, {}) assert isinstance(result, SafeText)
self.assertIsInstance(result, SafeText)
class TestChoiceFieldHTMLFormRenderer(TestCase): """
"""
Test rendering ChoiceField with HTMLFormRenderer. Test rendering ChoiceField with HTMLFormRenderer.
""" """
defsetUp(self):
def setUp(self): choices = ((1, 'Option1'), (2, 'Option2'), (12, 'Option12'))
choices = ((1, 'Option1'), (2, 'Option2'), (12, 'Option12')) class TestSerializer(serializers.Serializer):
test_field = serializers.ChoiceField(choices=choices, initial=2)
class TestSerializer(serializers.Serializer):
test_field = serializers.ChoiceField(choices=choices,
initial=2)
self.TestSerializer = TestSerializer self.TestSerializer = TestSerializer
self.renderer = HTMLFormRenderer() self.renderer=HTMLFormRenderer()
def test_render_initial_option(self): def test_render_initial_option(self):
serializer = self.TestSerializer() serializer = self.TestSerializer()
result = self.renderer.render(serializer.data) result = self.renderer.render(serializer.data)
assert isinstance(result, SafeText)
self.assertIsInstance(result, SafeText) self.assertInHTML('<option value="2" selected>Option2</option>', result)
self.assertInHTML('<option value="1">Option1</option>', result)
self.assertInHTML('<option value="2" selected>Option2</option>', self.assertInHTML('<option value="12">Option12</option>', result)
result)
self.assertInHTML('<option value="1">Option1</option>', result)
self.assertInHTML('<option value="12">Option12</option>', result)
def test_render_selected_option(self): def test_render_selected_option(self):
serializer = self.TestSerializer(data={'test_field': '12'}) serializer = self.TestSerializer(data={'test_field': '12'})
serializer.is_valid()
serializer.is_valid() result = self.renderer.render(serializer.data)
result = self.renderer.render(serializer.data) assert isinstance(result, SafeText)
self.assertInHTML('<option value="12" selected>Option12</option>', result)
self.assertIsInstance(result, SafeText) self.assertInHTML('<option value="1">Option1</option>', result)
self.assertInHTML('<option value="2">Option2</option>', result)
self.assertInHTML('<option value="12" selected>Option12</option>',
result)
self.assertInHTML('<option value="1">Option1</option>', result)
self.assertInHTML('<option value="2">Option2</option>', result)
class TestMultipleChoiceFieldHTMLFormRenderer(TestCase): """
"""
Test rendering MultipleChoiceField with HTMLFormRenderer. Test rendering MultipleChoiceField with HTMLFormRenderer.
""" """
defsetUp(self):
def setUp(self): self.renderer = HTMLFormRenderer()
self.renderer = HTMLFormRenderer()
def test_render_selected_option_with_string_option_ids(self): def test_render_selected_option_with_string_option_ids(self):
choices = (('1', 'Option1'), ('2', 'Option2'), ('12', 'Option12'), choices = (('1', 'Option1'), ('2', 'Option2'), ('12', 'Option12'), ('}', 'OptionBrace'))
('}', 'OptionBrace')) class TestSerializer(serializers.Serializer):
test_field = serializers.MultipleChoiceField(choices=choices)
class TestSerializer(serializers.Serializer):
test_field = serializers.MultipleChoiceField(choices=choices)
serializer = TestSerializer(data={'test_field': ['12']}) serializer = TestSerializer(data={'test_field': ['12']})
serializer.is_valid() serializer.is_valid()
result=self.renderer.render(serializer.data)
result = self.renderer.render(serializer.data) assertisinstance(result, SafeText)
self.assertInHTML('<option value="12" selected>Option12</option>',result)
self.assertIsInstance(result, SafeText) self.assertInHTML('<option value="1">Option1</option>',result)
self.assertInHTML('<option value="2">Option2</option>',result)
self.assertInHTML('<option value="12" selected>Option12</option>', self.assertInHTML('<option value="}">OptionBrace</option>',result)
result)
self.assertInHTML('<option value="1">Option1</option>', result)
self.assertInHTML('<option value="2">Option2</option>', result)
self.assertInHTML('<option value="}">OptionBrace</option>', result)
def test_render_selected_option_with_integer_option_ids(self): def test_render_selected_option_with_integer_option_ids(self):
choices = ((1, 'Option1'), (2, 'Option2'), (12, 'Option12')) choices = ((1, 'Option1'), (2, 'Option2'), (12, 'Option12'))
class TestSerializer(serializers.Serializer):
class TestSerializer(serializers.Serializer): test_field = serializers.MultipleChoiceField(choices=choices)
test_field = serializers.MultipleChoiceField(choices=choices)
serializer = TestSerializer(data={'test_field': ['12']}) serializer = TestSerializer(data={'test_field': ['12']})
serializer.is_valid() serializer.is_valid()
result=self.renderer.render(serializer.data)
result = self.renderer.render(serializer.data) assertisinstance(result, SafeText)
self.assertInHTML('<option value="12" selected>Option12</option>',result)
self.assertIsInstance(result, SafeText) self.assertInHTML('<option value="1">Option1</option>',result)
self.assertInHTML('<option value="2">Option2</option>',result)
self.assertInHTML('<option value="12" selected>Option12</option>',
result)
self.assertInHTML('<option value="1">Option1</option>', result)
self.assertInHTML('<option value="2">Option2</option>', result)
class StaticHTMLRendererTests(TestCase): """
"""
Tests specific for Static HTML Renderer Tests specific for Static HTML Renderer
""" """
def setUp(self): defsetUp(self):
self.renderer = StaticHTMLRenderer() self.renderer = StaticHTMLRenderer()
def test_static_renderer(self): def test_static_renderer(self):
data = '<html><body>text</body></html>' data = '<html><body>text</body></html>'
result = self.renderer.render(data) result = self.renderer.render(data)
assert result == data assert result == data
def test_static_renderer_with_exception(self): def test_static_renderer_with_exception(self):
context = { context = { 'response': Response(status=500, exception=True), 'request': Request(HttpRequest()) }
'response': Response(status=500, exception=True), result = self.renderer.render({}, renderer_context=context)
'request': Request(HttpRequest()) assert result == '500 Internal Server Error'
}
result = self.renderer.render({}, renderer_context=context)
assert result == '500 Internal Server Error'
class BrowsableAPIRendererTests(URLPatternsTestCase): class BrowsableAPIRendererTests(URLPatternsTestCase):
@ -658,196 +605,142 @@ class BrowsableAPIRendererTests(URLPatternsTestCase):
assert '>Extra list action<' in resp.content.decode() assert '>Extra list action<' in resp.content.decode()
class AdminRendererTests(TestCase):
def setUp(self): def setUp(self):
self.renderer = AdminRenderer() self.renderer = AdminRenderer()
def test_render_when_resource_created(self): def test_render_when_resource_created(self):
class DummyView(APIView): class DummyView(APIView):
renderer_classes = (AdminRenderer, ) renderer_classes = (AdminRenderer, )
request = Request(HttpRequest()) request = Request(HttpRequest())
request.build_absolute_uri = lambda: 'http://example.com' request.build_absolute_uri=lambda:'http://example.com'
response = Response(status=201, headers={'Location': '/test'}) response=Response(status=201,headers={'Location':'/test'})
context = { context={'view':DummyView(),'request':request,'response':response}
'view': DummyView(), result=self.renderer.render(data={'test':'test'},renderer_context=context)
'request': request, assertresult==''
'response': response assertresponse.status_code==status.HTTP_303_SEE_OTHER
} assertresponse['Location']=='http://example.com'
result = self.renderer.render(data={'test': 'test'},
renderer_context=context)
assert result == ''
assert response.status_code == status.HTTP_303_SEE_OTHER
assert response['Location'] == 'http://example.com'
def test_render_dict(self): def test_render_dict(self):
factory = APIRequestFactory() factory = APIRequestFactory()
class DummyView(APIView):
class DummyView(APIView): renderer_classes = (AdminRenderer, )
renderer_classes = (AdminRenderer, ) def get(self, request):
return Response({'foo': 'a string'})
def get(self, request):
return Response({'foo': 'a string'})
view = DummyView.as_view() view = DummyView.as_view()
request = factory.get('/') request=factory.get('/')
response = view(request) response=view(request)
response.render() response.render()
self.assertContains(response, '<tr><th>Foo</th><td>a string</td></tr>', html=True) self.assertContains(response,'<tr><th>Foo</th><td>a string</td></tr>',html=True)
def test_render_dict_with_items_key(self): def test_render_dict_with_items_key(self):
factory = APIRequestFactory() factory = APIRequestFactory()
class DummyView(APIView):
class DummyView(APIView): renderer_classes = (AdminRenderer, )
renderer_classes = (AdminRenderer, ) def get(self, request):
return Response({'items': 'a string'})
def get(self, request):
return Response({'items': 'a string'})
view = DummyView.as_view() view = DummyView.as_view()
request = factory.get('/') request=factory.get('/')
response = view(request) response=view(request)
response.render() response.render()
self.assertContains(response, '<tr><th>Items</th><td>a string</td></tr>', html=True) self.assertContains(response,'<tr><th>Items</th><td>a string</td></tr>',html=True)
def test_render_dict_with_iteritems_key(self): def test_render_dict_with_iteritems_key(self):
factory = APIRequestFactory() factory = APIRequestFactory()
class DummyView(APIView):
class DummyView(APIView): renderer_classes = (AdminRenderer, )
renderer_classes = (AdminRenderer, ) def get(self, request):
return Response({'iteritems': 'a string'})
def get(self, request):
return Response({'iteritems': 'a string'})
view = DummyView.as_view() view = DummyView.as_view()
request = factory.get('/') request=factory.get('/')
response = view(request) response=view(request)
response.render() response.render()
self.assertContains(response, '<tr><th>Iteritems</th><td>a string</td></tr>', html=True) self.assertContains(response,'<tr><th>Iteritems</th><td>a string</td></tr>',html=True)
def test_get_result_url(self): def test_get_result_url(self):
factory = APIRequestFactory() factory = APIRequestFactory()
class DummyGenericViewsetLike(APIView):
class DummyGenericViewsetLike(APIView): lookup_field = 'test'
lookup_field = 'test' def reverse_action(view, *args, **kwargs):
assert kwargs['kwargs']['test'] == 1
def reverse_action(view, *args, **kwargs): return '/example/'
self.assertEqual(kwargs['kwargs']['test'], 1)
return '/example/'
# get the view instance instead of the view function # get the view instance instead of the view function
view = DummyGenericViewsetLike.as_view() view = DummyGenericViewsetLike.as_view()
request = factory.get('/') request=factory.get('/')
response = view(request) response=view(request)
view = response.renderer_context['view'] view=response.renderer_context['view']
assertself.renderer.get_result_url({'test':1},view)=='/example/'
self.assertEqual(self.renderer.get_result_url({'test': 1}, view), '/example/') assertself.renderer.get_result_url({},view)is None
self.assertIsNone(self.renderer.get_result_url({}, view))
def test_get_result_url_no_result(self): def test_get_result_url_no_result(self):
factory = APIRequestFactory() factory = APIRequestFactory()
class DummyView(APIView):
class DummyView(APIView): lookup_field = 'test'
lookup_field = 'test'
# get the view instance instead of the view function # get the view instance instead of the view function
view = DummyView.as_view() view = DummyView.as_view()
request = factory.get('/') request=factory.get('/')
response = view(request) response=view(request)
view = response.renderer_context['view'] view=response.renderer_context['view']
assertself.renderer.get_result_url({'test':1},view)is None
self.assertIsNone(self.renderer.get_result_url({'test': 1}, view)) assertself.renderer.get_result_url({},view)is None
self.assertIsNone(self.renderer.get_result_url({}, view))
def test_get_context_result_urls(self): def test_get_context_result_urls(self):
factory = APIRequestFactory() factory = APIRequestFactory()
class DummyView(APIView):
class DummyView(APIView): lookup_field = 'test'
lookup_field = 'test' def reverse_action(view, url_name, args=None, kwargs=None):
return '/%s/%d' % (url_name, kwargs['test'])
def reverse_action(view, url_name, args=None, kwargs=None):
return '/%s/%d' % (url_name, kwargs['test'])
# get the view instance instead of the view function # get the view instance instead of the view function
view = DummyView.as_view() view = DummyView.as_view()
request = factory.get('/') request=factory.get('/')
response = view(request) response=view(request)
data=[{'test':1},{'url':'/example','test':2},{'url':None,'test':3},{},]
data = [ context={'view':DummyView(),'request':Request(request),'response':response}
{'test': 1}, context=self.renderer.get_context(data,None,context)
{'url': '/example', 'test': 2}, results=context['results']
{'url': None, 'test': 3}, assertlen(results)==4
{}, assertresults[0]['url']=='/detail/1'
] assertresults[1]['url']=='/example'
context = { assertresults[2]['url']==None
'view': DummyView(), assert'url'not inresults[3]
'request': Request(request),
'response': response
}
context = self.renderer.get_context(data, None, context)
results = context['results']
self.assertEqual(len(results), 4)
self.assertEqual(results[0]['url'], '/detail/1')
self.assertEqual(results[1]['url'], '/example')
self.assertEqual(results[2]['url'], None)
self.assertNotIn('url', results[3])
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed') @pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
class TestDocumentationRenderer(TestCase):
def test_document_with_link_named_data(self): def test_document_with_link_named_data(self):
""" """
Ref #5395: Doc's `document.data` would fail with a Link named "data". Ref #5395: Doc's `document.data` would fail with a Link named "data".
As per #4972, use templatetag instead. As per #4972, use templatetag instead.
""" """
document = coreapi.Document( document = coreapi.Document( title='Data Endpoint API', url='https://api.example.org/', content={ 'data': coreapi.Link( url='/data/', action='get', fields=[], description='Return data.' ) } )
title='Data Endpoint API', factory = APIRequestFactory()
url='https://api.example.org/', request = factory.get('/')
content={ renderer = DocumentationRenderer()
'data': coreapi.Link( html = renderer.render(document, accepted_media_type="text/html", renderer_context={"request": request})
url='/data/', assert '<h1>Data Endpoint API</h1>' in html
action='get',
fields=[],
description='Return data.'
)
}
)
factory = APIRequestFactory()
request = factory.get('/')
renderer = DocumentationRenderer()
html = renderer.render(document, accepted_media_type="text/html", renderer_context={"request": request})
assert '<h1>Data Endpoint API</h1>' in html
def test_shell_code_example_rendering(self): def test_shell_code_example_rendering(self):
template = loader.get_template('rest_framework/docs/langs/shell.html') template = loader.get_template('rest_framework/docs/langs/shell.html')
context = { context = { 'document': coreapi.Document(url='https://api.example.org/'), 'link_key': 'testcases > list', 'link': coreapi.Link(url='/data/', action='get', fields=[]), }
'document': coreapi.Document(url='https://api.example.org/'), html = template.render(context)
'link_key': 'testcases > list', assert 'testcases list' in html
'link': coreapi.Link(url='/data/', action='get', fields=[]),
}
html = template.render(context)
assert 'testcases list' in html
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed') @pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
class TestSchemaJSRenderer(TestCase):
def test_schemajs_output(self): def test_schemajs_output(self):
""" """
Test output of the SchemaJS renderer as per #5608. Django 2.0 on Py3 prints binary data as b'xyz' in templates, Test output of the SchemaJS renderer as per #5608. Django 2.0 on Py3 prints binary data as b'xyz' in templates,
and the base64 encoding used by SchemaJSRenderer outputs base64 as binary. Test fix. and the base64 encoding used by SchemaJSRenderer outputs base64 as binary. Test fix.
""" """
factory = APIRequestFactory() factory = APIRequestFactory()
request = factory.get('/') request = factory.get('/')
renderer = SchemaJSRenderer()
renderer = SchemaJSRenderer() output = renderer.render('data', renderer_context={"request": request})
assert "'ImRhdGEi'" in output
output = renderer.render('data', renderer_context={"request": request}) assert "'b'ImRhdGEi''" not in output
assert "'ImRhdGEi'" in output
assert "'b'ImRhdGEi''" not in output

View File

@ -23,18 +23,11 @@ from rest_framework.test import APIClient, APIRequestFactory
from rest_framework.views import APIView from rest_framework.views import APIView
factory = APIRequestFactory() factory = APIRequestFactory()
def test_request_type(self):
request = Request(factory.get('/'))
class TestInitializer(TestCase): message = ( 'The `request` argument must be an instance of ' '`django.http.HttpRequest`, not `rest_framework.request.Request`.' )
def test_request_type(self): with self.assertRaisesMessage(AssertionError, message):
request = Request(factory.get('/')) Request(request)
message = (
'The `request` argument must be an instance of '
'`django.http.HttpRequest`, not `rest_framework.request.Request`.'
)
with self.assertRaisesMessage(AssertionError, message):
Request(request)
class PlainTextParser(BaseParser): class PlainTextParser(BaseParser):
@ -50,79 +43,78 @@ class PlainTextParser(BaseParser):
return stream.read() return stream.read()
class TestContentParsing(TestCase): def test_standard_behaviour_determines_no_content_GET(self):
def test_standard_behaviour_determines_no_content_GET(self): """
"""
Ensure request.data returns empty QueryDict for GET request. Ensure request.data returns empty QueryDict for GET request.
""" """
request = Request(factory.get('/')) request = Request(factory.get('/'))
assert request.data == {} assert request.data == {}
def test_standard_behaviour_determines_no_content_HEAD(self): def test_standard_behaviour_determines_no_content_HEAD(self):
""" """
Ensure request.data returns empty QueryDict for HEAD request. Ensure request.data returns empty QueryDict for HEAD request.
""" """
request = Request(factory.head('/')) request = Request(factory.head('/'))
assert request.data == {} assert request.data == {}
def test_request_DATA_with_form_content(self): def test_request_DATA_with_form_content(self):
""" """
Ensure request.data returns content for POST request with form content. Ensure request.data returns content for POST request with form content.
""" """
data = {'qwerty': 'uiop'} data = {'qwerty': 'uiop'}
request = Request(factory.post('/', data)) request = Request(factory.post('/', data))
request.parsers = (FormParser(), MultiPartParser()) request.parsers = (FormParser(), MultiPartParser())
assert list(request.data.items()) == list(data.items()) assert list(request.data.items()) == list(data.items())
def test_request_DATA_with_text_content(self): def test_request_DATA_with_text_content(self):
""" """
Ensure request.data returns content for POST request with Ensure request.data returns content for POST request with
non-form content. non-form content.
""" """
content = b'qwerty' content = b'qwerty'
content_type = 'text/plain' content_type = 'text/plain'
request = Request(factory.post('/', content, content_type=content_type)) request = Request(factory.post('/', content, content_type=content_type))
request.parsers = (PlainTextParser(),) request.parsers = (PlainTextParser(),)
assert request.data == content assert request.data == content
def test_request_POST_with_form_content(self): def test_request_POST_with_form_content(self):
""" """
Ensure request.POST returns content for POST request with form content. Ensure request.POST returns content for POST request with form content.
""" """
data = {'qwerty': 'uiop'} data = {'qwerty': 'uiop'}
request = Request(factory.post('/', data)) request = Request(factory.post('/', data))
request.parsers = (FormParser(), MultiPartParser()) request.parsers = (FormParser(), MultiPartParser())
assert list(request.POST.items()) == list(data.items()) assert list(request.POST.items()) == list(data.items())
def test_request_POST_with_files(self): def test_request_POST_with_files(self):
""" """
Ensure request.POST returns no content for POST request with file content. Ensure request.POST returns no content for POST request with file content.
""" """
upload = SimpleUploadedFile("file.txt", b"file_content") upload = SimpleUploadedFile("file.txt", b"file_content")
request = Request(factory.post('/', {'upload': upload})) request = Request(factory.post('/', {'upload': upload}))
request.parsers = (FormParser(), MultiPartParser()) request.parsers = (FormParser(), MultiPartParser())
assert list(request.POST) == [] assert list(request.POST) == []
assert list(request.FILES) == ['upload'] assert list(request.FILES) == ['upload']
def test_standard_behaviour_determines_form_content_PUT(self): def test_standard_behaviour_determines_form_content_PUT(self):
""" """
Ensure request.data returns content for PUT request with form content. Ensure request.data returns content for PUT request with form content.
""" """
data = {'qwerty': 'uiop'} data = {'qwerty': 'uiop'}
request = Request(factory.put('/', data)) request = Request(factory.put('/', data))
request.parsers = (FormParser(), MultiPartParser()) request.parsers = (FormParser(), MultiPartParser())
assert list(request.data.items()) == list(data.items()) assert list(request.data.items()) == list(data.items())
def test_standard_behaviour_determines_non_form_content_PUT(self): def test_standard_behaviour_determines_non_form_content_PUT(self):
""" """
Ensure request.data returns content for PUT request with Ensure request.data returns content for PUT request with
non-form content. non-form content.
""" """
content = b'qwerty' content = b'qwerty'
content_type = 'text/plain' content_type = 'text/plain'
request = Request(factory.put('/', content, content_type=content_type)) request = Request(factory.put('/', content, content_type=content_type))
request.parsers = (PlainTextParser(), ) request.parsers = (PlainTextParser(), )
assert request.data == content assert request.data == content
class MockView(APIView): class MockView(APIView):
@ -160,175 +152,148 @@ urlpatterns = [
@override_settings( @override_settings(
ROOT_URLCONF='tests.test_request', ROOT_URLCONF='tests.test_request',
FILE_UPLOAD_HANDLERS=['django.core.files.uploadhandler.TemporaryFileUploadHandler']) FILE_UPLOAD_HANDLERS=['django.core.files.uploadhandler.TemporaryFileUploadHandler'])
class FileUploadTests(TestCase):
def test_fileuploads_closed_at_request_end(self): def test_fileuploads_closed_at_request_end(self):
with tempfile.NamedTemporaryFile() as f: with tempfile.NamedTemporaryFile() as f:
response = self.client.post('/upload/', {'file': f}) response = self.client.post('/upload/', {'file': f})
# sanity check that file was processed # sanity check that file was processed
assert len(response.data) == 1 assert len(response.data) == 1
forfileinresponse.data:
for file in response.data: assert not os.path.exists(file)
assert not os.path.exists(file)
@override_settings(ROOT_URLCONF='tests.test_request') @override_settings(ROOT_URLCONF='tests.test_request')
class TestContentParsingWithAuthentication(TestCase): def setUp(self):
def setUp(self): self.csrf_client = APIClient(enforce_csrf_checks=True)
self.csrf_client = APIClient(enforce_csrf_checks=True) self.username = 'john'
self.username = 'john' self.email = 'lennon@thebeatles.com'
self.email = 'lennon@thebeatles.com' self.password = 'password'
self.password = 'password' self.user = User.objects.create_user(self.username, self.email, self.password)
self.user = User.objects.create_user(self.username, self.email, self.password)
def test_user_logged_in_authentication_has_POST_when_not_logged_in(self): def test_user_logged_in_authentication_has_POST_when_not_logged_in(self):
""" """
Ensures request.POST exists after SessionAuthentication when user Ensures request.POST exists after SessionAuthentication when user
doesn't log in. doesn't log in.
""" """
content = {'example': 'example'} content = {'example': 'example'}
response = self.client.post('/', content)
response = self.client.post('/', content) assert status.HTTP_200_OK == response.status_code
assert status.HTTP_200_OK == response.status_code response = self.csrf_client.post('/', content)
assert status.HTTP_200_OK == response.status_code
response = self.csrf_client.post('/', content)
assert status.HTTP_200_OK == response.status_code
class TestUserSetter(TestCase):
def setUp(self): def setUp(self):
# Pass request object through session middleware so session is # Pass request object through session middleware so session is
# available to login and logout functions # available to login and logout functions
self.wrapped_request = factory.get('/') self.wrapped_request = factory.get('/')
self.request = Request(self.wrapped_request) self.request = Request(self.wrapped_request)
SessionMiddleware().process_request(self.wrapped_request) SessionMiddleware().process_request(self.wrapped_request)
AuthenticationMiddleware().process_request(self.wrapped_request) AuthenticationMiddleware().process_request(self.wrapped_request)
User.objects.create_user('ringo', 'starr@thebeatles.com', 'yellow')
User.objects.create_user('ringo', 'starr@thebeatles.com', 'yellow') self.user = authenticate(username='ringo', password='yellow')
self.user = authenticate(username='ringo', password='yellow')
def test_user_can_be_set(self): def test_user_can_be_set(self):
self.request.user = self.user self.request.user = self.user
assert self.request.user == self.user assert self.request.user == self.user
def test_user_can_login(self): def test_user_can_login(self):
login(self.request, self.user) login(self.request, self.user)
assert self.request.user == self.user assert self.request.user == self.user
def test_user_can_logout(self): def test_user_can_logout(self):
self.request.user = self.user self.request.user = self.user
assert not self.request.user.is_anonymous assert not self.request.user.is_anonymous
logout(self.request) logout(self.request)
assert self.request.user.is_anonymous assert self.request.user.is_anonymous
def test_logged_in_user_is_set_on_wrapped_request(self): def test_logged_in_user_is_set_on_wrapped_request(self):
login(self.request, self.user) login(self.request, self.user)
assert self.wrapped_request.user == self.user assert self.wrapped_request.user == self.user
def test_calling_user_fails_when_attribute_error_is_raised(self): def test_calling_user_fails_when_attribute_error_is_raised(self):
""" """
This proves that when an AttributeError is raised inside of the request.user This proves that when an AttributeError is raised inside of the request.user
property, that we can handle this and report the true, underlying error. property, that we can handle this and report the true, underlying error.
""" """
class AuthRaisesAttributeError: class AuthRaisesAttributeError:
def authenticate(self, request): def authenticate(self, request):
self.MISSPELLED_NAME_THAT_DOESNT_EXIST self.MISSPELLED_NAME_THAT_DOESNT_EXIST
request = Request(self.wrapped_request, authenticators=(AuthRaisesAttributeError(),)) request = Request(self.wrapped_request, authenticators=(AuthRaisesAttributeError(),))
assertself.wrapped_request.user.is_anonymous
# The middleware processes the underlying Django request, sets anonymous user expected=r"no attribute 'MISSPELLED_NAME_THAT_DOESNT_EXIST'"
assert self.wrapped_request.user.is_anonymous withpytest.raises(WrappedAttributeError,match=expected):
request.user
# The DRF request object does not have a user and should run authenticators
expected = r"no attribute 'MISSPELLED_NAME_THAT_DOESNT_EXIST'"
with pytest.raises(WrappedAttributeError, match=expected):
request.user
with pytest.raises(WrappedAttributeError, match=expected): with pytest.raises(WrappedAttributeError, match=expected):
hasattr(request, 'user') hasattr(request, 'user')
with pytest.raises(WrappedAttributeError, match=expected): with pytest.raises(WrappedAttributeError, match=expected):
login(request, self.user) login(request, self.user)
class TestAuthSetter(TestCase): def test_auth_can_be_set(self):
def test_auth_can_be_set(self): request = Request(factory.get('/'))
request = Request(factory.get('/')) request.auth = 'DUMMY'
request.auth = 'DUMMY' assert request.auth == 'DUMMY'
assert request.auth == 'DUMMY'
class TestSecure(TestCase):
def test_default_secure_false(self): def test_default_secure_false(self):
request = Request(factory.get('/', secure=False)) request = Request(factory.get('/', secure=False))
assert request.scheme == 'http' assert request.scheme == 'http'
def test_default_secure_true(self): def test_default_secure_true(self):
request = Request(factory.get('/', secure=True)) request = Request(factory.get('/', secure=True))
assert request.scheme == 'https' assert request.scheme == 'https'
class TestHttpRequest(TestCase): def test_attribute_access_proxy(self):
def test_attribute_access_proxy(self): http_request = factory.get('/')
http_request = factory.get('/') request = Request(http_request)
request = Request(http_request) inner_sentinel = object()
http_request.inner_property = inner_sentinel
inner_sentinel = object() assert request.inner_property is inner_sentinel
http_request.inner_property = inner_sentinel outer_sentinel = object()
assert request.inner_property is inner_sentinel request.inner_property = outer_sentinel
assert request.inner_property is outer_sentinel
outer_sentinel = object()
request.inner_property = outer_sentinel
assert request.inner_property is outer_sentinel
def test_exception_proxy(self): def test_exception_proxy(self):
# ensure the exception message is not for the underlying WSGIRequest # ensure the exception message is not for the underlying WSGIRequest
http_request = factory.get('/') http_request = factory.get('/')
request = Request(http_request) request = Request(http_request)
message = "'Request' object has no attribute 'inner_property'"
message = "'Request' object has no attribute 'inner_property'" with self.assertRaisesMessage(AttributeError, message):
with self.assertRaisesMessage(AttributeError, message): request.inner_property
request.inner_property
@override_settings(ROOT_URLCONF='tests.test_request') @override_settings(ROOT_URLCONF='tests.test_request')
def test_duplicate_request_stream_parsing_exception(self): deftest_duplicate_request_stream_parsing_exception(self):
""" """
Check assumption that duplicate stream parsing will result in a Check assumption that duplicate stream parsing will result in a
`RawPostDataException` being raised. `RawPostDataException` being raised.
""" """
response = APIClient().post('/echo/', data={'a': 'b'}, format='json') response = APIClient().post('/echo/', data={'a': 'b'}, format='json')
request = response.renderer_context['request'] request = response.renderer_context['request']
assert request.content_type.startswith('application/json')
# ensure that request stream was consumed by json parser assert response.data == {'a': 'b'}
assert request.content_type.startswith('application/json') with pytest.raises(RawPostDataException):
assert response.data == {'a': 'b'} EchoView.as_view()(request._request)
# pass same HttpRequest to view, stream already consumed
with pytest.raises(RawPostDataException):
EchoView.as_view()(request._request)
@override_settings(ROOT_URLCONF='tests.test_request') @override_settings(ROOT_URLCONF='tests.test_request')
def test_duplicate_request_form_data_access(self): deftest_duplicate_request_form_data_access(self):
""" """
Form data is copied to the underlying django request for middleware Form data is copied to the underlying django request for middleware
and file closing reasons. Duplicate processing of a request with form and file closing reasons. Duplicate processing of a request with form
data is 'safe' in so far as accessing `request.POST` does not trigger data is 'safe' in so far as accessing `request.POST` does not trigger
the duplicate stream parse exception. the duplicate stream parse exception.
""" """
response = APIClient().post('/echo/', data={'a': 'b'}) response = APIClient().post('/echo/', data={'a': 'b'})
request = response.renderer_context['request'] request = response.renderer_context['request']
assert request.content_type.startswith('multipart/form-data')
# ensure that request stream was consumed by form parser assert response.data == {'a': ['b']}
assert request.content_type.startswith('multipart/form-data') response = EchoView.as_view()(request._request)
assert response.data == {'a': ['b']} request = response.renderer_context['request']
assert request.content_type.startswith('multipart/form-data')
# pass same HttpRequest to view, form data set on underlying request assert response.data == {'a': ['b']}
response = EchoView.as_view()(request._request)
request = response.renderer_context['request']
# ensure that request stream was consumed by form parser
assert request.content_type.startswith('multipart/form-data')
assert response.data == {'a': ['b']}

View File

@ -131,155 +131,146 @@ urlpatterns = [
# TODO: Clean tests bellow - remove duplicates with above, better unit testing, ... # TODO: Clean tests bellow - remove duplicates with above, better unit testing, ...
@override_settings(ROOT_URLCONF='tests.test_response') @override_settings(ROOT_URLCONF='tests.test_response')
class RendererIntegrationTests(TestCase): """
"""
End-to-end testing of renderers using an ResponseMixin on a generic view. End-to-end testing of renderers using an ResponseMixin on a generic view.
""" """
def test_default_renderer_serializes_content(self): deftest_default_renderer_serializes_content(self):
"""If the Accept header is not set the default renderer should serialize the response.""" """If the Accept header is not set the default renderer should serialize the response."""
resp = self.client.get('/') resp = self.client.get('/')
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
def test_head_method_serializes_no_content(self): def test_head_method_serializes_no_content(self):
"""No response must be included in HEAD requests.""" """No response must be included in HEAD requests."""
resp = self.client.head('/') resp = self.client.head('/')
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
self.assertEqual(resp.content, b'') assert resp.content == b''
def test_default_renderer_serializes_content_on_accept_any(self): def test_default_renderer_serializes_content_on_accept_any(self):
"""If the Accept header is set to */* the default renderer should serialize the response.""" """If the Accept header is set to */* the default renderer should serialize the response."""
resp = self.client.get('/', HTTP_ACCEPT='*/*') resp = self.client.get('/', HTTP_ACCEPT='*/*')
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_serializes_content_default_case(self): def test_specified_renderer_serializes_content_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response. """If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for the default renderer)""" (In this case we check that works for the default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type) resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_serializes_content_non_default_case(self): def test_specified_renderer_serializes_content_non_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response. """If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for a non-default renderer)""" (In this case we check that works for a non-default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type) resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_serializes_content_on_format_query(self): def test_specified_renderer_serializes_content_on_format_query(self):
"""If a 'format' query is specified, the renderer with the matching """If a 'format' query is specified, the renderer with the matching
format attribute should serialize the response.""" format attribute should serialize the response."""
resp = self.client.get('/?format=%s' % RendererB.format) resp = self.client.get('/?format=%s' % RendererB.format)
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_serializes_content_on_format_kwargs(self): def test_specified_renderer_serializes_content_on_format_kwargs(self):
"""If a 'format' keyword arg is specified, the renderer with the matching """If a 'format' keyword arg is specified, the renderer with the matching
format attribute should serialize the response.""" format attribute should serialize the response."""
resp = self.client.get('/something.formatb') resp = self.client.get('/something.formatb')
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_is_used_on_format_query_with_matching_accept(self): def test_specified_renderer_is_used_on_format_query_with_matching_accept(self):
"""If both a 'format' query and a matching Accept header specified, """If both a 'format' query and a matching Accept header specified,
the renderer with the matching format attribute should serialize the response.""" the renderer with the matching format attribute should serialize the response."""
resp = self.client.get('/?format=%s' % RendererB.format, resp = self.client.get('/?format=%s' % RendererB.format, HTTP_ACCEPT=RendererB.media_type)
HTTP_ACCEPT=RendererB.media_type) assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) assert resp.status_code == DUMMYSTATUS
self.assertEqual(resp.status_code, DUMMYSTATUS)
@override_settings(ROOT_URLCONF='tests.test_response') @override_settings(ROOT_URLCONF='tests.test_response')
class UnsupportedMediaTypeTests(TestCase): def test_should_allow_posting_json(self):
def test_should_allow_posting_json(self): response = self.client.post('/json', data='{"test": 123}', content_type='application/json')
response = self.client.post('/json', data='{"test": 123}', content_type='application/json') assert response.status_code == 200
self.assertEqual(response.status_code, 200)
def test_should_not_allow_posting_xml(self): def test_should_not_allow_posting_xml(self):
response = self.client.post('/json', data='<test>123</test>', content_type='application/xml') response = self.client.post('/json', data='<test>123</test>', content_type='application/xml')
assert response.status_code == 415
self.assertEqual(response.status_code, 415)
def test_should_not_allow_posting_a_form(self): def test_should_not_allow_posting_a_form(self):
response = self.client.post('/json', data={'test': 123}) response = self.client.post('/json', data={'test': 123})
assert response.status_code == 415
self.assertEqual(response.status_code, 415)
@override_settings(ROOT_URLCONF='tests.test_response') @override_settings(ROOT_URLCONF='tests.test_response')
class Issue122Tests(TestCase): """
"""
Tests that covers #122. Tests that covers #122.
""" """
def test_only_html_renderer(self): deftest_only_html_renderer(self):
""" """
Test if no infinite recursion occurs. Test if no infinite recursion occurs.
""" """
self.client.get('/html') self.client.get('/html')
def test_html_renderer_is_first(self): def test_html_renderer_is_first(self):
""" """
Test if no infinite recursion occurs. Test if no infinite recursion occurs.
""" """
self.client.get('/html1') self.client.get('/html1')
@override_settings(ROOT_URLCONF='tests.test_response') @override_settings(ROOT_URLCONF='tests.test_response')
class Issue467Tests(TestCase): """
"""
Tests for #467 Tests for #467
""" """
def test_form_has_label_and_help_text(self): deftest_form_has_label_and_help_text(self):
resp = self.client.get('/html_new_model') resp = self.client.get('/html_new_model')
self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') assert resp['Content-Type'] == 'text/html; charset=utf-8'
# self.assertContains(resp, 'Text comes here') # self.assertContains(resp, 'Text comes here')
# self.assertContains(resp, 'Text description.') # self.assertContains(resp, 'Text description.')
@override_settings(ROOT_URLCONF='tests.test_response') @override_settings(ROOT_URLCONF='tests.test_response')
class Issue807Tests(TestCase): """
"""
Covers #807 Covers #807
""" """
def test_does_not_append_charset_by_default(self): deftest_does_not_append_charset_by_default(self):
""" """
Renderers don't include a charset unless set explicitly. Renderers don't include a charset unless set explicitly.
""" """
headers = {"HTTP_ACCEPT": RendererA.media_type} headers = {"HTTP_ACCEPT": RendererA.media_type}
resp = self.client.get('/', **headers) resp = self.client.get('/', **headers)
expected = "{}; charset={}".format(RendererA.media_type, 'utf-8') expected = "{}; charset={}".format(RendererA.media_type, 'utf-8')
self.assertEqual(expected, resp['Content-Type']) assert expected == resp['Content-Type']
def test_if_there_is_charset_specified_on_renderer_it_gets_appended(self): def test_if_there_is_charset_specified_on_renderer_it_gets_appended(self):
""" """
If renderer class has charset attribute declared, it gets appended If renderer class has charset attribute declared, it gets appended
to Response's Content-Type to Response's Content-Type
""" """
headers = {"HTTP_ACCEPT": RendererC.media_type} headers = {"HTTP_ACCEPT": RendererC.media_type}
resp = self.client.get('/', **headers) resp = self.client.get('/', **headers)
expected = "{}; charset={}".format(RendererC.media_type, RendererC.charset) expected = "{}; charset={}".format(RendererC.media_type, RendererC.charset)
self.assertEqual(expected, resp['Content-Type']) assert expected == resp['Content-Type']
def test_content_type_set_explicitly_on_response(self): def test_content_type_set_explicitly_on_response(self):
""" """
The content type may be set explicitly on the response. The content type may be set explicitly on the response.
""" """
headers = {"HTTP_ACCEPT": RendererC.media_type} headers = {"HTTP_ACCEPT": RendererC.media_type}
resp = self.client.get('/setbyview', **headers) resp = self.client.get('/setbyview', **headers)
self.assertEqual('setbyview', resp['Content-Type']) assert 'setbyview' == resp['Content-Type']
def test_form_has_label_and_help_text(self): def test_form_has_label_and_help_text(self):
resp = self.client.get('/html_new_model') resp = self.client.get('/html_new_model')
self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') assert resp['Content-Type'] == 'text/html; charset=utf-8'
# self.assertContains(resp, 'Text comes here') # self.assertContains(resp, 'Text comes here')
# self.assertContains(resp, 'Text description.') # self.assertContains(resp, 'Text description.')

View File

@ -30,25 +30,22 @@ class MockVersioningScheme:
@override_settings(ROOT_URLCONF='tests.test_reverse') @override_settings(ROOT_URLCONF='tests.test_reverse')
class ReverseTests(TestCase): """
"""
Tests for fully qualified URLs when using `reverse`. Tests for fully qualified URLs when using `reverse`.
""" """
def test_reversed_urls_are_fully_qualified(self): deftest_reversed_urls_are_fully_qualified(self):
request = factory.get('/view') request = factory.get('/view')
url = reverse('view', request=request) url = reverse('view', request=request)
assert url == 'http://testserver/view' assert url == 'http://testserver/view'
def test_reverse_with_versioning_scheme(self): def test_reverse_with_versioning_scheme(self):
request = factory.get('/view') request = factory.get('/view')
request.versioning_scheme = MockVersioningScheme() request.versioning_scheme = MockVersioningScheme()
url = reverse('view', request=request)
url = reverse('view', request=request) assert url == 'http://scheme-reversed/view'
assert url == 'http://scheme-reversed/view'
def test_reverse_with_versioning_scheme_fallback_to_default_on_error(self): def test_reverse_with_versioning_scheme_fallback_to_default_on_error(self):
request = factory.get('/view') request = factory.get('/view')
request.versioning_scheme = MockVersioningScheme(raise_error=True) request.versioning_scheme = MockVersioningScheme(raise_error=True)
url = reverse('view', request=request)
url = reverse('view', request=request) assert url == 'http://testserver/view'
assert url == 'http://testserver/view'

View File

@ -214,25 +214,24 @@ class TestCustomLookupFields(URLPatternsTestCase, TestCase):
assert response.data == {"url": "http://testserver/example/notes/a%20b/", "uuid": "a b", "text": "baz qux"} assert response.data == {"url": "http://testserver/example/notes/a%20b/", "uuid": "a b", "text": "baz qux"}
class TestLookupValueRegex(TestCase): """
"""
Ensure the router honors lookup_value_regex when applied Ensure the router honors lookup_value_regex when applied
to the viewset. to the viewset.
""" """
def setUp(self): defsetUp(self):
class NoteViewSet(viewsets.ModelViewSet): class NoteViewSet(viewsets.ModelViewSet):
queryset = RouterTestModel.objects.all() queryset = RouterTestModel.objects.all()
lookup_field = 'uuid' lookup_field = 'uuid'
lookup_value_regex = '[0-9a-f]{32}' lookup_value_regex = '[0-9a-f]{32}'
self.router = SimpleRouter() self.router = SimpleRouter()
self.router.register(r'notes', NoteViewSet) self.router.register(r'notes',NoteViewSet)
self.urls = self.router.urls self.urls=self.router.urls
def test_urls_limited_by_lookup_value_regex(self): def test_urls_limited_by_lookup_value_regex(self):
expected = ['^notes/$', '^notes/(?P<uuid>[0-9a-f]{32})/$'] expected = ['^notes/$', '^notes/(?P<uuid>[0-9a-f]{32})/$']
for idx in range(len(expected)): for idx in range(len(expected)):
assert expected[idx] == get_regex_pattern(self.urls[idx]) assert expected[idx] == get_regex_pattern(self.urls[idx])
@override_settings(ROOT_URLCONF='tests.test_routers') @override_settings(ROOT_URLCONF='tests.test_routers')
@ -264,97 +263,84 @@ class TestLookupUrlKwargs(URLPatternsTestCase, TestCase):
assert response.data == {"url": "http://testserver/example/notes/123/", "uuid": "123", "text": "foo bar"} assert response.data == {"url": "http://testserver/example/notes/123/", "uuid": "123", "text": "foo bar"}
class TestTrailingSlashIncluded(TestCase): def setUp(self):
def setUp(self): class NoteViewSet(viewsets.ModelViewSet):
class NoteViewSet(viewsets.ModelViewSet): queryset = RouterTestModel.objects.all()
queryset = RouterTestModel.objects.all()
self.router = SimpleRouter() self.router = SimpleRouter()
self.router.register(r'notes', NoteViewSet) self.router.register(r'notes',NoteViewSet)
self.urls = self.router.urls self.urls=self.router.urls
def test_urls_have_trailing_slash_by_default(self): def test_urls_have_trailing_slash_by_default(self):
expected = ['^notes/$', '^notes/(?P<pk>[^/.]+)/$'] expected = ['^notes/$', '^notes/(?P<pk>[^/.]+)/$']
for idx in range(len(expected)): for idx in range(len(expected)):
assert expected[idx] == get_regex_pattern(self.urls[idx]) assert expected[idx] == get_regex_pattern(self.urls[idx])
class TestTrailingSlashRemoved(TestCase): def setUp(self):
def setUp(self): class NoteViewSet(viewsets.ModelViewSet):
class NoteViewSet(viewsets.ModelViewSet): queryset = RouterTestModel.objects.all()
queryset = RouterTestModel.objects.all()
self.router = SimpleRouter(trailing_slash=False) self.router = SimpleRouter(trailing_slash=False)
self.router.register(r'notes', NoteViewSet) self.router.register(r'notes',NoteViewSet)
self.urls = self.router.urls self.urls=self.router.urls
def test_urls_can_have_trailing_slash_removed(self): def test_urls_can_have_trailing_slash_removed(self):
expected = ['^notes$', '^notes/(?P<pk>[^/.]+)$'] expected = ['^notes$', '^notes/(?P<pk>[^/.]+)$']
for idx in range(len(expected)): for idx in range(len(expected)):
assert expected[idx] == get_regex_pattern(self.urls[idx]) assert expected[idx] == get_regex_pattern(self.urls[idx])
class TestNameableRoot(TestCase): def setUp(self):
def setUp(self): class NoteViewSet(viewsets.ModelViewSet):
class NoteViewSet(viewsets.ModelViewSet): queryset = RouterTestModel.objects.all()
queryset = RouterTestModel.objects.all()
self.router = DefaultRouter() self.router = DefaultRouter()
self.router.root_view_name = 'nameable-root' self.router.root_view_name='nameable-root'
self.router.register(r'notes', NoteViewSet) self.router.register(r'notes',NoteViewSet)
self.urls = self.router.urls self.urls=self.router.urls
def test_router_has_custom_name(self): def test_router_has_custom_name(self):
expected = 'nameable-root' expected = 'nameable-root'
assert expected == self.urls[-1].name assert expected == self.urls[-1].name
class TestActionKeywordArgs(TestCase): """
"""
Ensure keyword arguments passed in the `@action` decorator Ensure keyword arguments passed in the `@action` decorator
are properly handled. Refs #940. are properly handled. Refs #940.
""" """
defsetUp(self):
def setUp(self): class TestViewSet(viewsets.ModelViewSet):
class TestViewSet(viewsets.ModelViewSet): permission_classes = []
permission_classes = [] @action(methods=['post'], detail=True, permission_classes=[permissions.AllowAny])
def custom(self, request, *args, **kwargs):
@action(methods=['post'], detail=True, permission_classes=[permissions.AllowAny]) return Response({ 'permission_classes': self.permission_classes })
def custom(self, request, *args, **kwargs):
return Response({
'permission_classes': self.permission_classes
})
self.router = SimpleRouter() self.router = SimpleRouter()
self.router.register(r'test', TestViewSet, basename='test') self.router.register(r'test',TestViewSet,basename='test')
self.view = self.router.urls[-1].callback self.view=self.router.urls[-1].callback
def test_action_kwargs(self): def test_action_kwargs(self):
request = factory.post('/test/0/custom/') request = factory.post('/test/0/custom/')
response = self.view(request) response = self.view(request)
assert response.data == {'permission_classes': [permissions.AllowAny]} assert response.data == {'permission_classes': [permissions.AllowAny]}
class TestActionAppliedToExistingRoute(TestCase): """
"""
Ensure `@action` decorator raises an except when applied Ensure `@action` decorator raises an except when applied
to an existing route to an existing route
""" """
deftest_exception_raised_when_action_applied_to_existing_route(self):
class TestViewSet(viewsets.ModelViewSet):
def test_exception_raised_when_action_applied_to_existing_route(self): @action(methods=['post'], detail=True)
class TestViewSet(viewsets.ModelViewSet): def retrieve(self, request, *args, **kwargs):
return Response({ 'hello': 'world' })
@action(methods=['post'], detail=True)
def retrieve(self, request, *args, **kwargs):
return Response({
'hello': 'world'
})
self.router = SimpleRouter() self.router = SimpleRouter()
self.router.register(r'test', TestViewSet, basename='test') self.router.register(r'test',TestViewSet,basename='test')
withpytest.raises(ImproperlyConfigured):
with pytest.raises(ImproperlyConfigured): self.router.urls
self.router.urls
class DynamicListAndDetailViewSet(viewsets.ViewSet): class DynamicListAndDetailViewSet(viewsets.ViewSet):
@ -390,44 +376,33 @@ class SubDynamicListAndDetailViewSet(DynamicListAndDetailViewSet):
pass pass
class TestDynamicListAndDetailRouter(TestCase): def setUp(self):
def setUp(self): self.router = SimpleRouter()
self.router = SimpleRouter()
def _test_list_and_detail_route_decorators(self, viewset): def _test_list_and_detail_route_decorators(self, viewset):
routes = self.router.get_routes(viewset) routes = self.router.get_routes(viewset)
decorator_routes = [r for r in routes if not (r.name.endswith('-list') or r.name.endswith('-detail'))] decorator_routes = [r for r in routes if not (r.name.endswith('-list') or r.name.endswith('-detail'))]
MethodNamesMap = namedtuple('MethodNamesMap', 'method_name url_path')
MethodNamesMap = namedtuple('MethodNamesMap', 'method_name url_path') for i, endpoint in enumerate([MethodNamesMap('list_custom_route_get', 'list_custom-route'), MethodNamesMap('list_route_get', 'list_route_get'), MethodNamesMap('list_route_post', 'list_route_post'), MethodNamesMap('detail_custom_route_get', 'detail_custom-route'), MethodNamesMap('detail_route_get', 'detail_route_get'), MethodNamesMap('detail_route_post', 'detail_route_post') ]):
# Make sure all these endpoints exist and none have been clobbered route = decorator_routes[i]
for i, endpoint in enumerate([MethodNamesMap('list_custom_route_get', 'list_custom-route'), method_name = endpoint.method_name
MethodNamesMap('list_route_get', 'list_route_get'), url_path = endpoint.url_path
MethodNamesMap('list_route_post', 'list_route_post'), if method_name.startswith('list_'):
MethodNamesMap('detail_custom_route_get', 'detail_custom-route'), assert route.url == '^{{prefix}}/{0}{{trailing_slash}}$'.format(url_path)
MethodNamesMap('detail_route_get', 'detail_route_get'),
MethodNamesMap('detail_route_post', 'detail_route_post')
]):
route = decorator_routes[i]
# check url listing
method_name = endpoint.method_name
url_path = endpoint.url_path
if method_name.startswith('list_'):
assert route.url == '^{{prefix}}/{0}{{trailing_slash}}$'.format(url_path)
else: else:
assert route.url == '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(url_path) assert route.url == '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(url_path)
# check method to function mapping # check method to function mapping
if method_name.endswith('_post'): if method_name.endswith('_post'):
method_map = 'post' method_map = 'post'
else: else:
method_map = 'get' method_map = 'get'
assert route.mapping[method_map] == method_name assert route.mapping[method_map] == method_name
def test_list_and_detail_route_decorators(self): def test_list_and_detail_route_decorators(self):
self._test_list_and_detail_route_decorators(DynamicListAndDetailViewSet) self._test_list_and_detail_route_decorators(DynamicListAndDetailViewSet)
def test_inherited_list_and_detail_route_decorators(self): def test_inherited_list_and_detail_route_decorators(self):
self._test_list_and_detail_route_decorators(SubDynamicListAndDetailViewSet) self._test_list_and_detail_route_decorators(SubDynamicListAndDetailViewSet)
class TestEmptyPrefix(URLPatternsTestCase, TestCase): class TestEmptyPrefix(URLPatternsTestCase, TestCase):
@ -490,69 +465,52 @@ class TestViewInitkwargs(URLPatternsTestCase, TestCase):
assert initkwargs['basename'] == 'routertestmodel' assert initkwargs['basename'] == 'routertestmodel'
class TestBaseNameRename(TestCase):
def test_base_name_and_basename_assertion(self): def test_base_name_and_basename_assertion(self):
router = SimpleRouter() router = SimpleRouter()
msg = "Do not provide both the `basename` and `base_name` arguments."
msg = "Do not provide both the `basename` and `base_name` arguments." with warnings.catch_warnings(record=True) as w, self.assertRaisesMessage(AssertionError, msg):
with warnings.catch_warnings(record=True) as w, \ warnings.simplefilter('always')
self.assertRaisesMessage(AssertionError, msg): router.register('mock', MockViewSet, 'mock', base_name='mock')
warnings.simplefilter('always')
router.register('mock', MockViewSet, 'mock', base_name='mock')
msg = "The `base_name` argument is pending deprecation in favor of `basename`." msg = "The `base_name` argument is pending deprecation in favor of `basename`."
assert len(w) == 1 assertlen(w)==1
assert str(w[0].message) == msg assertstr(w[0].message)==msg
def test_base_name_argument_deprecation(self): def test_base_name_argument_deprecation(self):
router = SimpleRouter() router = SimpleRouter()
with pytest.warns(RemovedInDRF311Warning) as w:
with pytest.warns(RemovedInDRF311Warning) as w: warnings.simplefilter('always')
warnings.simplefilter('always') router.register('mock', MockViewSet, base_name='mock')
router.register('mock', MockViewSet, base_name='mock')
msg = "The `base_name` argument is pending deprecation in favor of `basename`." msg = "The `base_name` argument is pending deprecation in favor of `basename`."
assert len(w) == 1 assertlen(w)==1
assert str(w[0].message) == msg assertstr(w[0].message)==msg
assert router.registry == [ assertrouter.registry==[('mock',MockViewSet,'mock'),]
('mock', MockViewSet, 'mock'),
]
def test_basename_argument_no_warnings(self): def test_basename_argument_no_warnings(self):
router = SimpleRouter() router = SimpleRouter()
with warnings.catch_warnings(record=True) as w:
with warnings.catch_warnings(record=True) as w: warnings.simplefilter('always')
warnings.simplefilter('always') router.register('mock', MockViewSet, basename='mock')
router.register('mock', MockViewSet, basename='mock')
assert len(w) == 0 assert len(w) == 0
assert router.registry == [ assertrouter.registry==[('mock',MockViewSet,'mock'),]
('mock', MockViewSet, 'mock'),
]
def test_get_default_base_name_deprecation(self): def test_get_default_base_name_deprecation(self):
msg = "`CustomRouter.get_default_base_name` method should be renamed `get_default_basename`." msg = "`CustomRouter.get_default_base_name` method should be renamed `get_default_basename`."
with pytest.warns(RemovedInDRF311Warning) as w:
# Class definition should raise a warning warnings.simplefilter('always')
with pytest.warns(RemovedInDRF311Warning) as w: class CustomRouter(SimpleRouter):
warnings.simplefilter('always') def get_default_base_name(self, viewset):
return 'foo'
class CustomRouter(SimpleRouter):
def get_default_base_name(self, viewset):
return 'foo'
assert len(w) == 1 assert len(w) == 1
assert str(w[0].message) == msg assertstr(w[0].message)==msg
withwarnings.catch_warnings(record=True)asw:
# Deprecated method implementation should still be called warnings.simplefilter('always')
with warnings.catch_warnings(record=True) as w: router = CustomRouter()
warnings.simplefilter('always') router.register('mock', MockViewSet)
router = CustomRouter()
router.register('mock', MockViewSet)
assert len(w) == 0 assert len(w) == 0
assert router.registry == [ assertrouter.registry==[('mock',MockViewSet,'foo'),]
('mock', MockViewSet, 'foo'),
]

File diff suppressed because one or more lines are too long

View File

@ -4,120 +4,66 @@ Tests to cover bulk create and update using serializers.
from django.test import TestCase from django.test import TestCase
from rest_framework import serializers from rest_framework import serializers
"""
class BulkCreateSerializerTests(TestCase):
"""
Creating multiple instances using serializers. Creating multiple instances using serializers.
""" """
defsetUp(self):
def setUp(self): class BookSerializer(serializers.Serializer):
class BookSerializer(serializers.Serializer): id = serializers.IntegerField()
id = serializers.IntegerField() title = serializers.CharField(max_length=100)
title = serializers.CharField(max_length=100) author = serializers.CharField(max_length=100)
author = serializers.CharField(max_length=100)
self.BookSerializer = BookSerializer self.BookSerializer = BookSerializer
def test_bulk_create_success(self): def test_bulk_create_success(self):
""" """
Correct bulk update serialization should return the input data. Correct bulk update serialization should return the input data.
""" """
data = [ { 'id': 0, 'title': 'The electric kool-aid acid test', 'author': 'Tom Wolfe' }, { 'id': 1, 'title': 'If this is a man', 'author': 'Primo Levi' }, { 'id': 2, 'title': 'The wind-up bird chronicle', 'author': 'Haruki Murakami' } ]
data = [ serializer = self.BookSerializer(data=data, many=True)
{ assert serializer.is_valid() is True
'id': 0, assert serializer.validated_data == data
'title': 'The electric kool-aid acid test', assert serializer.errors == []
'author': 'Tom Wolfe'
}, {
'id': 1,
'title': 'If this is a man',
'author': 'Primo Levi'
}, {
'id': 2,
'title': 'The wind-up bird chronicle',
'author': 'Haruki Murakami'
}
]
serializer = self.BookSerializer(data=data, many=True)
assert serializer.is_valid() is True
assert serializer.validated_data == data
assert serializer.errors == []
def test_bulk_create_errors(self): def test_bulk_create_errors(self):
""" """
Incorrect bulk create serialization should return errors. Incorrect bulk create serialization should return errors.
""" """
data = [ { 'id': 0, 'title': 'The electric kool-aid acid test', 'author': 'Tom Wolfe' }, { 'id': 1, 'title': 'If this is a man', 'author': 'Primo Levi' }, { 'id': 'foo', 'title': 'The wind-up bird chronicle', 'author': 'Haruki Murakami' } ]
data = [ expected_errors = [ {}, {}, {'id': ['A valid integer is required.']} ]
{ serializer = self.BookSerializer(data=data, many=True)
'id': 0, assert serializer.is_valid() is False
'title': 'The electric kool-aid acid test', assert serializer.errors == expected_errors
'author': 'Tom Wolfe' assert serializer.validated_data == []
}, {
'id': 1,
'title': 'If this is a man',
'author': 'Primo Levi'
}, {
'id': 'foo',
'title': 'The wind-up bird chronicle',
'author': 'Haruki Murakami'
}
]
expected_errors = [
{},
{},
{'id': ['A valid integer is required.']}
]
serializer = self.BookSerializer(data=data, many=True)
assert serializer.is_valid() is False
assert serializer.errors == expected_errors
assert serializer.validated_data == []
def test_invalid_list_datatype(self): def test_invalid_list_datatype(self):
""" """
Data containing list of incorrect data type should return errors. Data containing list of incorrect data type should return errors.
""" """
data = ['foo', 'bar', 'baz'] data = ['foo', 'bar', 'baz']
serializer = self.BookSerializer(data=data, many=True) serializer = self.BookSerializer(data=data, many=True)
assert serializer.is_valid() is False assert serializer.is_valid() is False
message = 'Invalid data. Expected a dictionary, but got str.'
message = 'Invalid data. Expected a dictionary, but got str.' expected_errors = [ {'non_field_errors': [message]}, {'non_field_errors': [message]}, {'non_field_errors': [message]} ]
expected_errors = [ assert serializer.errors == expected_errors
{'non_field_errors': [message]},
{'non_field_errors': [message]},
{'non_field_errors': [message]}
]
assert serializer.errors == expected_errors
def test_invalid_single_datatype(self): def test_invalid_single_datatype(self):
""" """
Data containing a single incorrect data type should return errors. Data containing a single incorrect data type should return errors.
""" """
data = 123 data = 123
serializer = self.BookSerializer(data=data, many=True) serializer = self.BookSerializer(data=data, many=True)
assert serializer.is_valid() is False assert serializer.is_valid() is False
expected_errors = {'non_field_errors': ['Expected a list of items but got type "int".']}
expected_errors = {'non_field_errors': ['Expected a list of items but got type "int".']} assert serializer.errors == expected_errors
assert serializer.errors == expected_errors
def test_invalid_single_object(self): def test_invalid_single_object(self):
""" """
Data containing only a single object, instead of a list of objects Data containing only a single object, instead of a list of objects
should return errors. should return errors.
""" """
data = { data = { 'id': 0, 'title': 'The electric kool-aid acid test', 'author': 'Tom Wolfe' }
'id': 0, serializer = self.BookSerializer(data=data, many=True)
'title': 'The electric kool-aid acid test', assert serializer.is_valid() is False
'author': 'Tom Wolfe' expected_errors = {'non_field_errors': ['Expected a list of items but got type "dict".']}
} assert serializer.errors == expected_errors
serializer = self.BookSerializer(data=data, many=True)
assert serializer.is_valid() is False
expected_errors = {'non_field_errors': ['Expected a list of items but got type "dict".']}
assert serializer.errors == expected_errors