diff --git a/.gitignore b/.gitignore index 52d63d225..82e885ede 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,7 @@ *.db *~ .* -.py.bak +*.py.bak /site/ diff --git a/tests/test_atomic_requests.py b/tests/test_atomic_requests.py index de04d2c06..4c88b3586 100644 --- a/tests/test_atomic_requests.py +++ b/tests/test_atomic_requests.py @@ -11,6 +11,7 @@ from rest_framework.response import Response from rest_framework.test import APIRequestFactory from rest_framework.views import APIView from tests.models import BasicModel +import pytest factory = APIRequestFactory() @@ -52,51 +53,49 @@ urlpatterns = ( connection.features.uses_savepoints, "'atomic' requires transactions and savepoints." ) -class DBTransactionTests(TestCase): - def setUp(self): - self.view = BasicView.as_view() - connections.databases['default']['ATOMIC_REQUESTS'] = True +def setUp(self): + self.view = BasicView.as_view() + connections.databases['default']['ATOMIC_REQUESTS'] = True def tearDown(self): - connections.databases['default']['ATOMIC_REQUESTS'] = False + connections.databases['default']['ATOMIC_REQUESTS'] = False def test_no_exception_commit_transaction(self): - request = factory.post('/') - - with self.assertNumQueries(1): - response = self.view(request) + request = factory.post('/') + with self.assertNumQueries(1): + response = self.view(request) assert not transaction.get_rollback() - assert response.status_code == status.HTTP_200_OK - assert BasicModel.objects.count() == 1 +assertresponse.status_code==status.HTTP_200_OK +assertBasicModel.objects.count()==1 @unittest.skipUnless( connection.features.uses_savepoints, "'atomic' requires transactions and savepoints." ) -class DBTransactionErrorTests(TestCase): - def setUp(self): - self.view = ErrorView.as_view() - connections.databases['default']['ATOMIC_REQUESTS'] = True +def setUp(self): + self.view = ErrorView.as_view() + connections.databases['default']['ATOMIC_REQUESTS'] = True def tearDown(self): - connections.databases['default']['ATOMIC_REQUESTS'] = False + connections.databases['default']['ATOMIC_REQUESTS'] = False def test_generic_exception_delegate_transaction_management(self): - """ + """ Transaction is eventually managed by outer-most transaction atomic block. DRF do not try to interfere here. We let django deal with the transaction when it will catch the Exception. """ - request = factory.post('/') - with self.assertNumQueries(3): + request = factory.post('/') + with self.assertNumQueries(3): # 1 - begin savepoint # 2 - insert # 3 - release savepoint - with transaction.atomic(): - self.assertRaises(Exception, self.view, request) - assert not transaction.get_rollback() + with transaction.atomic(): + with pytest.raises(Exception): + self.view(request) + assert not transaction.get_rollback() assert BasicModel.objects.count() == 1 @@ -104,30 +103,29 @@ class DBTransactionErrorTests(TestCase): connection.features.uses_savepoints, "'atomic' requires transactions and savepoints." ) -class DBTransactionAPIExceptionTests(TestCase): - def setUp(self): - self.view = APIExceptionView.as_view() - connections.databases['default']['ATOMIC_REQUESTS'] = True +def setUp(self): + self.view = APIExceptionView.as_view() + connections.databases['default']['ATOMIC_REQUESTS'] = True def tearDown(self): - connections.databases['default']['ATOMIC_REQUESTS'] = False + connections.databases['default']['ATOMIC_REQUESTS'] = False def test_api_exception_rollback_transaction(self): - """ + """ Transaction is rollbacked by our transaction atomic block. """ - request = factory.post('/') - num_queries = 4 if connection.features.can_release_savepoints else 3 - with self.assertNumQueries(num_queries): + request = factory.post('/') + num_queries = 4 if connection.features.can_release_savepoints else 3 + with self.assertNumQueries(num_queries): # 1 - begin savepoint # 2 - insert # 3 - rollback savepoint # 4 - release savepoint - with transaction.atomic(): - response = self.view(request) - assert transaction.get_rollback() + with transaction.atomic(): + response = self.view(request) + assert transaction.get_rollback() assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR - assert BasicModel.objects.count() == 0 +assertBasicModel.objects.count()==0 @unittest.skipUnless( diff --git a/tests/test_authtoken.py b/tests/test_authtoken.py index 036e317ef..f22551b26 100644 --- a/tests/test_authtoken.py +++ b/tests/test_authtoken.py @@ -13,74 +13,68 @@ from rest_framework.authtoken.models import Token from rest_framework.authtoken.serializers import AuthTokenSerializer from rest_framework.exceptions import ValidationError - -class AuthTokenTests(TestCase): - - def setUp(self): - self.site = site - self.user = User.objects.create_user(username='test_user') - self.token = Token.objects.create(key='test token', user=self.user) +def setUp(self): + self.site = site + self.user = User.objects.create_user(username='test_user') + self.token = Token.objects.create(key='test token', user=self.user) def test_model_admin_displayed_fields(self): - mock_request = object() - token_admin = TokenAdmin(self.token, self.site) - assert token_admin.get_fields(mock_request) == ('user',) + mock_request = object() + token_admin = TokenAdmin(self.token, self.site) + assert token_admin.get_fields(mock_request) == ('user',) def test_token_string_representation(self): - assert str(self.token) == 'test token' + assert str(self.token) == 'test token' def test_validate_raise_error_if_no_credentials_provided(self): - with pytest.raises(ValidationError): - AuthTokenSerializer().validate({}) + with pytest.raises(ValidationError): + AuthTokenSerializer().validate({}) def test_whitespace_in_password(self): - data = {'username': self.user.username, 'password': 'test pass '} - self.user.set_password(data['password']) - self.user.save() - assert AuthTokenSerializer(data=data).is_valid() + data = {'username': self.user.username, 'password': 'test pass '} + self.user.set_password(data['password']) + self.user.save() + assert AuthTokenSerializer(data=data).is_valid() -class AuthTokenCommandTests(TestCase): - def setUp(self): - self.site = site - self.user = User.objects.create_user(username='test_user') +def setUp(self): + self.site = site + self.user = User.objects.create_user(username='test_user') def test_command_create_user_token(self): - token = AuthTokenCommand().create_user_token(self.user.username, False) - assert token is not None - token_saved = Token.objects.first() - assert token.key == token_saved.key + token = AuthTokenCommand().create_user_token(self.user.username, False) + assert token is not None + token_saved = Token.objects.first() + assert token.key == token_saved.key def test_command_create_user_token_invalid_user(self): - with pytest.raises(User.DoesNotExist): - AuthTokenCommand().create_user_token('not_existing_user', False) + with pytest.raises(User.DoesNotExist): + AuthTokenCommand().create_user_token('not_existing_user', False) def test_command_reset_user_token(self): - AuthTokenCommand().create_user_token(self.user.username, False) - first_token_key = Token.objects.first().key - AuthTokenCommand().create_user_token(self.user.username, True) - second_token_key = Token.objects.first().key - - assert first_token_key != second_token_key + AuthTokenCommand().create_user_token(self.user.username, False) + first_token_key = Token.objects.first().key + AuthTokenCommand().create_user_token(self.user.username, True) + second_token_key = Token.objects.first().key + assert first_token_key != second_token_key def test_command_do_not_reset_user_token(self): - AuthTokenCommand().create_user_token(self.user.username, False) - first_token_key = Token.objects.first().key - AuthTokenCommand().create_user_token(self.user.username, False) - second_token_key = Token.objects.first().key - - assert first_token_key == second_token_key + AuthTokenCommand().create_user_token(self.user.username, False) + first_token_key = Token.objects.first().key + AuthTokenCommand().create_user_token(self.user.username, False) + second_token_key = Token.objects.first().key + assert first_token_key == second_token_key def test_command_raising_error_for_invalid_user(self): - out = StringIO() - with pytest.raises(CommandError): - call_command('drf_create_token', 'not_existing_user', stdout=out) + out = StringIO() + with pytest.raises(CommandError): + call_command('drf_create_token', 'not_existing_user', stdout=out) def test_command_output(self): - out = StringIO() - call_command('drf_create_token', self.user.username, stdout=out) - token_saved = Token.objects.first() - self.assertIn('Generated token', out.getvalue()) - self.assertIn(self.user.username, out.getvalue()) - self.assertIn(token_saved.key, out.getvalue()) + out = StringIO() + call_command('drf_create_token', self.user.username, stdout=out) + token_saved = Token.objects.first() + assert 'Generated token' in out.getvalue() + assert self.user.username in out.getvalue() + assert token_saved.key in out.getvalue() diff --git a/tests/test_decorators.py b/tests/test_decorators.py index bd30449e5..8b81ad461 100644 --- a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -17,307 +17,270 @@ from rest_framework.test import APIRequestFactory from rest_framework.throttling import UserRateThrottle from rest_framework.views import APIView - -class DecoratorTestCase(TestCase): - - def setUp(self): - self.factory = APIRequestFactory() +def setUp(self): + self.factory = APIRequestFactory() def _finalize_response(self, request, response, *args, **kwargs): - response.request = request - return APIView.finalize_response(self, request, response, *args, **kwargs) + response.request = request + return APIView.finalize_response(self, request, response, *args, **kwargs) def test_api_view_incorrect(self): - """ + """ If @api_view is not applied correct, we should raise an assertion. """ + @api_view + def view(request): + return Response() - @api_view + request = self.factory.get('/') +withpytest.raises(AssertionError): +view(request) + + def test_api_view_incorrect_arguments(self): + """ + If @api_view is missing arguments, we should raise an assertion. + """ + with pytest.raises(AssertionError): + @api_view('GET') def view(request): return Response() - request = self.factory.get('/') - self.assertRaises(AssertionError, view, request) - - def test_api_view_incorrect_arguments(self): - """ - If @api_view is missing arguments, we should raise an assertion. - """ - - with self.assertRaises(AssertionError): - @api_view('GET') - def view(request): - return Response() - def test_calling_method(self): - @api_view(['GET']) - def view(request): - return Response({}) + @api_view(['GET']) + def view(request): + return Response({}) request = self.factory.get('/') - response = view(request) - assert response.status_code == status.HTTP_200_OK - - request = self.factory.post('/') - response = view(request) - assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED +response=view(request) +assertresponse.status_code==status.HTTP_200_OK +request=self.factory.post('/') +response=view(request) +assertresponse.status_code==status.HTTP_405_METHOD_NOT_ALLOWED def test_calling_put_method(self): - @api_view(['GET', 'PUT']) - def view(request): - return Response({}) + @api_view(['GET', 'PUT']) + def view(request): + return Response({}) request = self.factory.put('/') - response = view(request) - assert response.status_code == status.HTTP_200_OK - - request = self.factory.post('/') - response = view(request) - assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED +response=view(request) +assertresponse.status_code==status.HTTP_200_OK +request=self.factory.post('/') +response=view(request) +assertresponse.status_code==status.HTTP_405_METHOD_NOT_ALLOWED def test_calling_patch_method(self): - @api_view(['GET', 'PATCH']) - def view(request): - return Response({}) + @api_view(['GET', 'PATCH']) + def view(request): + return Response({}) request = self.factory.patch('/') - response = view(request) - assert response.status_code == status.HTTP_200_OK - - request = self.factory.post('/') - response = view(request) - assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED +response=view(request) +assertresponse.status_code==status.HTTP_200_OK +request=self.factory.post('/') +response=view(request) +assertresponse.status_code==status.HTTP_405_METHOD_NOT_ALLOWED def test_renderer_classes(self): - @api_view(['GET']) - @renderer_classes([JSONRenderer]) - def view(request): - return Response({}) + @api_view(['GET']) + @renderer_classes([JSONRenderer]) + def view(request): + return Response({}) request = self.factory.get('/') - response = view(request) - assert isinstance(response.accepted_renderer, JSONRenderer) +response=view(request) +assertisinstance(response.accepted_renderer,JSONRenderer) def test_parser_classes(self): - @api_view(['GET']) - @parser_classes([JSONParser]) - def view(request): - assert len(request.parsers) == 1 - assert isinstance(request.parsers[0], JSONParser) - return Response({}) + @api_view(['GET']) + @parser_classes([JSONParser]) + def view(request): + assert len(request.parsers) == 1 + assert isinstance(request.parsers[0], JSONParser) + return Response({}) request = self.factory.get('/') - view(request) +view(request) def test_authentication_classes(self): - @api_view(['GET']) - @authentication_classes([BasicAuthentication]) - def view(request): - assert len(request.authenticators) == 1 - assert isinstance(request.authenticators[0], BasicAuthentication) - return Response({}) + @api_view(['GET']) + @authentication_classes([BasicAuthentication]) + def view(request): + assert len(request.authenticators) == 1 + assert isinstance(request.authenticators[0], BasicAuthentication) + return Response({}) request = self.factory.get('/') - view(request) +view(request) def test_permission_classes(self): - @api_view(['GET']) - @permission_classes([IsAuthenticated]) - def view(request): - return Response({}) + @api_view(['GET']) + @permission_classes([IsAuthenticated]) + def view(request): + return Response({}) request = self.factory.get('/') - response = view(request) - assert response.status_code == status.HTTP_403_FORBIDDEN +response=view(request) +assertresponse.status_code==status.HTTP_403_FORBIDDEN def test_throttle_classes(self): - class OncePerDayUserThrottle(UserRateThrottle): - rate = '1/day' + class OncePerDayUserThrottle(UserRateThrottle): + rate = '1/day' @api_view(['GET']) - @throttle_classes([OncePerDayUserThrottle]) - def view(request): - return Response({}) +@throttle_classes([OncePerDayUserThrottle]) +defview(request): + return Response({}) request = self.factory.get('/') - response = view(request) - assert response.status_code == status.HTTP_200_OK - - response = view(request) - assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS +response=view(request) +assertresponse.status_code==status.HTTP_200_OK +response=view(request) +assertresponse.status_code==status.HTTP_429_TOO_MANY_REQUESTS def test_schema(self): - """ + """ Checks CustomSchema class is set on view """ - class CustomSchema(AutoSchema): - pass + class CustomSchema(AutoSchema): + pass @api_view(['GET']) - @schema(CustomSchema()) - def view(request): - return Response({}) +@schema(CustomSchema()) +defview(request): + return Response({}) assert isinstance(view.cls.schema, CustomSchema) -class ActionDecoratorTestCase(TestCase): - def test_defaults(self): - @action(detail=True) - def test_action(request): - """Description""" +def test_defaults(self): + @action(detail=True) + def test_action(request): + """Description""" assert test_action.mapping == {'get': 'test_action'} - assert test_action.detail is True - assert test_action.url_path == 'test_action' - assert test_action.url_name == 'test-action' - assert test_action.kwargs == { - 'name': 'Test action', - 'description': 'Description', - } +asserttest_action.detailisTrue +asserttest_action.url_path=='test_action' +asserttest_action.url_name=='test-action' +asserttest_action.kwargs=={'name':'Test action','description':'Description',} def test_detail_required(self): - with pytest.raises(AssertionError) as excinfo: - @action() - def test_action(request): - raise NotImplementedError + with pytest.raises(AssertionError) as excinfo: + @action() + def test_action(request): + raise NotImplementedError assert str(excinfo.value) == "@action() missing required argument: 'detail'" def test_method_mapping_http_methods(self): # All HTTP methods should be mappable - @action(detail=False, methods=[]) - def test_action(): - raise NotImplementedError + @action(detail=False, methods=[]) + def test_action(): + raise NotImplementedError for name in APIView.http_method_names: - def method(): - raise NotImplementedError + def method(): + raise NotImplementedError method.__name__ = name - getattr(test_action.mapping, name)(method) +getattr(test_action.mapping,name)(method) # ensure the mapping returns the correct method name for name in APIView.http_method_names: - assert test_action.mapping[name] == name + assert test_action.mapping[name] == name def test_view_name_kwargs(self): - """ + """ 'name' and 'suffix' are mutually exclusive kwargs used for generating a view's display name. """ - # by default, generate name from method - @action(detail=True) - def test_action(request): - raise NotImplementedError + @action(detail=True) + def test_action(request): + raise NotImplementedError - assert test_action.kwargs == { - 'description': None, - 'name': 'Test action', - } + assert test_action.kwargs == {'description':None,'name':'Test action',} +@action(detail=True,name='test name') +deftest_action(request): + raise NotImplementedError - # name kwarg supersedes name generation - @action(detail=True, name='test name') - def test_action(request): - raise NotImplementedError + assert test_action.kwargs == {'description':None,'name':'test name',} +@action(detail=True,suffix='Suffix') +deftest_action(request): + raise NotImplementedError - assert test_action.kwargs == { - 'description': None, - 'name': 'test name', - } - - # suffix kwarg supersedes name generation - @action(detail=True, suffix='Suffix') - def test_action(request): - raise NotImplementedError - - assert test_action.kwargs == { - 'description': None, - 'suffix': 'Suffix', - } - - # name + suffix is a conflict. - with pytest.raises(TypeError) as excinfo: - action(detail=True, name='test name', suffix='Suffix') + assert test_action.kwargs == {'description':None,'suffix':'Suffix',} +withpytest.raises(TypeError)asexcinfo: + action(detail=True, name='test name', suffix='Suffix') assert str(excinfo.value) == "`name` and `suffix` are mutually exclusive arguments." def test_method_mapping(self): - @action(detail=False) - def test_action(request): - raise NotImplementedError + @action(detail=False) + def test_action(request): + raise NotImplementedError @test_action.mapping.post - def test_action_post(request): - raise NotImplementedError +deftest_action_post(request): + raise NotImplementedError # The secondary handler methods should not have the action attributes for name in ['mapping', 'detail', 'url_path', 'url_name', 'kwargs']: - assert hasattr(test_action, name) and not hasattr(test_action_post, name) + assert hasattr(test_action, name) and not hasattr(test_action_post, name) def test_method_mapping_already_mapped(self): - @action(detail=True) - def test_action(request): - raise NotImplementedError + @action(detail=True) + def test_action(request): + raise NotImplementedError msg = "Method 'get' has already been mapped to '.test_action'." - with self.assertRaisesMessage(AssertionError, msg): - @test_action.mapping.get - def test_action_get(request): - raise NotImplementedError +withself.assertRaisesMessage(AssertionError,msg): + @test_action.mapping.get + def test_action_get(request): + raise NotImplementedError def test_method_mapping_overwrite(self): - @action(detail=True) + @action(detail=True) + def test_action(): + raise NotImplementedError + + msg = ("Method mapping does not behave like the property decorator. You ""cannot use the same method name for each mapping declaration.") +withself.assertRaisesMessage(AssertionError,msg): + @test_action.mapping.post def test_action(): raise NotImplementedError - msg = ("Method mapping does not behave like the property decorator. You " - "cannot use the same method name for each mapping declaration.") - with self.assertRaisesMessage(AssertionError, msg): - @test_action.mapping.post - def test_action(): - raise NotImplementedError - def test_detail_route_deprecation(self): - with pytest.warns(RemovedInDRF310Warning) as record: - @detail_route() - def view(request): - raise NotImplementedError + with pytest.warns(RemovedInDRF310Warning) as record: + @detail_route() + def view(request): + raise NotImplementedError assert len(record) == 1 - assert str(record[0].message) == ( - "`detail_route` is deprecated and will be removed in " - "3.10 in favor of `action`, which accepts a `detail` bool. Use " - "`@action(detail=True)` instead." - ) +assertstr(record[0].message)==("`detail_route` is deprecated and will be removed in ""3.10 in favor of `action`, which accepts a `detail` bool. Use ""`@action(detail=True)` instead.") def test_list_route_deprecation(self): - with pytest.warns(RemovedInDRF310Warning) as record: - @list_route() - def view(request): - raise NotImplementedError + with pytest.warns(RemovedInDRF310Warning) as record: + @list_route() + def view(request): + raise NotImplementedError assert len(record) == 1 - assert str(record[0].message) == ( - "`list_route` is deprecated and will be removed in " - "3.10 in favor of `action`, which accepts a `detail` bool. Use " - "`@action(detail=False)` instead." - ) +assertstr(record[0].message)==("`list_route` is deprecated and will be removed in ""3.10 in favor of `action`, which accepts a `detail` bool. Use ""`@action(detail=False)` instead.") def test_route_url_name_from_path(self): # pre-3.8 behavior was to base the `url_name` off of the `url_path` - with pytest.warns(RemovedInDRF310Warning): - @list_route(url_path='foo_bar') - def view(request): - raise NotImplementedError + with pytest.warns(RemovedInDRF310Warning): + @list_route(url_path='foo_bar') + def view(request): + raise NotImplementedError assert view.url_path == 'foo_bar' - assert view.url_name == 'foo-bar' +assertview.url_name=='foo-bar' diff --git a/tests/test_description.py b/tests/test_description.py index ae00fe4a9..a585eabb5 100644 --- a/tests/test_description.py +++ b/tests/test_description.py @@ -69,37 +69,34 @@ MARKED_DOWN_gte_21 = """

an example docstring

indented

hash style header

%s""" - - -class TestViewNamesAndDescriptions(TestCase): - def test_view_name_uses_class_name(self): - """ +def test_view_name_uses_class_name(self): + """ Ensure view names are based on the class name. """ - class MockView(APIView): - pass + class MockView(APIView): + pass assert MockView().get_view_name() == 'Mock' def test_view_name_uses_name_attribute(self): - class MockView(APIView): - name = 'Foo' + class MockView(APIView): + name = 'Foo' assert MockView().get_view_name() == 'Foo' def test_view_name_uses_suffix_attribute(self): - class MockView(APIView): - suffix = 'List' + class MockView(APIView): + suffix = 'List' assert MockView().get_view_name() == 'Mock List' def test_view_name_preferences_name_over_suffix(self): - class MockView(APIView): - name = 'Foo' - suffix = 'List' + class MockView(APIView): + name = 'Foo' + suffix = 'List' assert MockView().get_view_name() == 'Foo' def test_view_description_uses_docstring(self): - """Ensure view descriptions are based on the docstring.""" - class MockView(APIView): - """an example docstring + """Ensure view descriptions are based on the docstring.""" + class MockView(APIView): + """an example docstring ==================== * list @@ -124,64 +121,53 @@ class TestViewNamesAndDescriptions(TestCase): assert MockView().get_view_description() == DESCRIPTION def test_view_description_uses_description_attribute(self): - class MockView(APIView): - description = 'Foo' + class MockView(APIView): + description = 'Foo' assert MockView().get_view_description() == 'Foo' def test_view_description_allows_empty_description(self): - class MockView(APIView): - """Description.""" - description = '' + class MockView(APIView): + """Description.""" + description = '' assert MockView().get_view_description() == '' def test_view_description_can_be_empty(self): - """ + """ Ensure that if a view has no docstring, then it's description is the empty string. """ - class MockView(APIView): - pass + class MockView(APIView): + pass assert MockView().get_view_description() == '' def test_view_description_can_be_promise(self): - """ + """ Ensure a view may have a docstring that is actually a lazily evaluated class that can be converted to a string. See: https://github.com/encode/django-rest-framework/issues/1708 """ - # use a mock object instead of gettext_lazy to ensure that we can't end - # up with a test case string in our l10n catalog - - class MockLazyStr: - def __init__(self, string): - self.s = string + class MockLazyStr: + def __init__(self, string): + self.s = string def __str__(self): - return self.s + return self.s class MockView(APIView): - __doc__ = MockLazyStr("a gettext string") + __doc__ = MockLazyStr("a gettext string") assert MockView().get_view_description() == 'a gettext string' def test_markdown(self): - """ + """ Ensure markdown to HTML works as expected. """ - if apply_markdown: - md_applied = apply_markdown(DESCRIPTION) - gte_21_match = ( - md_applied == ( - MARKED_DOWN_gte_21 % MARKED_DOWN_HILITE) or - md_applied == ( - MARKED_DOWN_gte_21 % MARKED_DOWN_NOT_HILITE)) - lt_21_match = ( - md_applied == ( - MARKED_DOWN_lt_21 % MARKED_DOWN_HILITE) or - md_applied == ( - MARKED_DOWN_lt_21 % MARKED_DOWN_NOT_HILITE)) - assert gte_21_match or lt_21_match + if apply_markdown: + md_applied = apply_markdown(DESCRIPTION) + gte_21_match = ( md_applied == ( MARKED_DOWN_gte_21 % MARKED_DOWN_HILITE) or md_applied == ( MARKED_DOWN_gte_21 % MARKED_DOWN_NOT_HILITE)) + lt_21_match = ( md_applied == ( MARKED_DOWN_lt_21 % MARKED_DOWN_HILITE) or md_applied == ( MARKED_DOWN_lt_21 % MARKED_DOWN_NOT_HILITE)) + assert gte_21_match or lt_21_match def test_dedent_tabs(): diff --git a/tests/test_encoders.py b/tests/test_encoders.py index c66954b80..232725000 100644 --- a/tests/test_encoders.py +++ b/tests/test_encoders.py @@ -15,81 +15,79 @@ class MockList: return [1, 2, 3] -class JSONEncoderTests(TestCase): - """ +""" Tests the JSONEncoder method """ - - def setUp(self): - self.encoder = JSONEncoder() +defsetUp(self): + self.encoder = JSONEncoder() def test_encode_decimal(self): - """ + """ Tests encoding a decimal """ - d = Decimal(3.14) - assert self.encoder.default(d) == float(d) + d = Decimal(3.14) + assert self.encoder.default(d) == float(d) def test_encode_datetime(self): - """ + """ Tests encoding a datetime object """ - current_time = datetime.now() - assert self.encoder.default(current_time) == current_time.isoformat() - current_time_utc = current_time.replace(tzinfo=utc) - assert self.encoder.default(current_time_utc) == current_time.isoformat() + 'Z' + current_time = datetime.now() + assert self.encoder.default(current_time) == current_time.isoformat() + current_time_utc = current_time.replace(tzinfo=utc) + assert self.encoder.default(current_time_utc) == current_time.isoformat() + 'Z' def test_encode_time(self): - """ + """ Tests encoding a timezone """ - current_time = datetime.now().time() - assert self.encoder.default(current_time) == current_time.isoformat() + current_time = datetime.now().time() + assert self.encoder.default(current_time) == current_time.isoformat() def test_encode_time_tz(self): - """ + """ Tests encoding a timezone aware timestamp """ - current_time = datetime.now().time() - current_time = current_time.replace(tzinfo=utc) - with pytest.raises(ValueError): - self.encoder.default(current_time) + current_time = datetime.now().time() + current_time = current_time.replace(tzinfo=utc) + with pytest.raises(ValueError): + self.encoder.default(current_time) def test_encode_date(self): - """ + """ Tests encoding a date object """ - current_date = date.today() - assert self.encoder.default(current_date) == current_date.isoformat() + current_date = date.today() + assert self.encoder.default(current_date) == current_date.isoformat() def test_encode_timedelta(self): - """ + """ Tests encoding a timedelta object """ - delta = timedelta(hours=1) - assert self.encoder.default(delta) == str(delta.total_seconds()) + delta = timedelta(hours=1) + assert self.encoder.default(delta) == str(delta.total_seconds()) def test_encode_uuid(self): - """ + """ Tests encoding a UUID object """ - unique_id = uuid4() - assert self.encoder.default(unique_id) == str(unique_id) + unique_id = uuid4() + assert self.encoder.default(unique_id) == str(unique_id) @pytest.mark.skipif(not coreapi, reason='coreapi is not installed') - def test_encode_coreapi_raises_error(self): - """ +deftest_encode_coreapi_raises_error(self): + """ Tests encoding a coreapi objects raises proper error """ - with pytest.raises(RuntimeError): - self.encoder.default(coreapi.Document()) + with pytest.raises(RuntimeError): + self.encoder.default(coreapi.Document()) with pytest.raises(RuntimeError): - self.encoder.default(coreapi.Error()) + self.encoder.default(coreapi.Error()) def test_encode_object_with_tolist(self): - """ + """ Tests encoding a object with tolist method """ - foo = MockList() - assert self.encoder.default(foo) == [1, 2, 3] + foo = MockList() + assert self.encoder.default(foo) == [1, 2, 3] diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 9516bfec9..70ad77f1f 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -7,89 +7,58 @@ from rest_framework.exceptions import ( server_error ) +def test_get_error_details(self): -class ExceptionTestCase(TestCase): - - def test_get_error_details(self): - - example = "string" - lazy_example = _(example) - - assert _get_error_details(lazy_example) == example - - assert isinstance( - _get_error_details(lazy_example), - ErrorDetail - ) - - assert _get_error_details({'nested': lazy_example})['nested'] == example - - assert isinstance( - _get_error_details({'nested': lazy_example})['nested'], - ErrorDetail - ) - - assert _get_error_details([[lazy_example]])[0][0] == example - - assert isinstance( - _get_error_details([[lazy_example]])[0][0], - ErrorDetail - ) + example = "string" + lazy_example = _(example) + assert _get_error_details(lazy_example) == example + assert isinstance( _get_error_details(lazy_example), ErrorDetail ) + assert _get_error_details({'nested': lazy_example})['nested'] == example + assert isinstance( _get_error_details({'nested': lazy_example})['nested'], ErrorDetail ) + assert _get_error_details([[lazy_example]])[0][0] == example + assert isinstance( _get_error_details([[lazy_example]])[0][0], ErrorDetail ) def test_get_full_details_with_throttling(self): - exception = Throttled() - assert exception.get_full_details() == { - 'message': 'Request was throttled.', 'code': 'throttled'} - - exception = Throttled(wait=2) - assert exception.get_full_details() == { - 'message': 'Request was throttled. Expected available in {} seconds.'.format(2), - 'code': 'throttled'} - - exception = Throttled(wait=2, detail='Slow down!') - assert exception.get_full_details() == { - 'message': 'Slow down! Expected available in {} seconds.'.format(2), - 'code': 'throttled'} + exception = Throttled() + assert exception.get_full_details() == { 'message': 'Request was throttled.', 'code': 'throttled'} + exception = Throttled(wait=2) + assert exception.get_full_details() == { 'message': 'Request was throttled. Expected available in {} seconds.'.format(2), 'code': 'throttled'} + exception = Throttled(wait=2, detail='Slow down!') + assert exception.get_full_details() == { 'message': 'Slow down! Expected available in {} seconds.'.format(2), 'code': 'throttled'} -class ErrorDetailTests(TestCase): - def test_eq(self): - assert ErrorDetail('msg') == ErrorDetail('msg') - assert ErrorDetail('msg', 'code') == ErrorDetail('msg', code='code') - - assert ErrorDetail('msg') == 'msg' - assert ErrorDetail('msg', 'code') == 'msg' +def test_eq(self): + assert ErrorDetail('msg') == ErrorDetail('msg') + assert ErrorDetail('msg', 'code') == ErrorDetail('msg', code='code') + assert ErrorDetail('msg') == 'msg' + assert ErrorDetail('msg', 'code') == 'msg' def test_ne(self): - assert ErrorDetail('msg1') != ErrorDetail('msg2') - assert ErrorDetail('msg') != ErrorDetail('msg', code='invalid') - - assert ErrorDetail('msg1') != 'msg2' - assert ErrorDetail('msg1', 'code') != 'msg2' + assert ErrorDetail('msg1') != ErrorDetail('msg2') + assert ErrorDetail('msg') != ErrorDetail('msg', code='invalid') + assert ErrorDetail('msg1') != 'msg2' + assert ErrorDetail('msg1', 'code') != 'msg2' def test_repr(self): - assert repr(ErrorDetail('msg1')) == \ - 'ErrorDetail(string={!r}, code=None)'.format('msg1') - assert repr(ErrorDetail('msg1', 'code')) == \ - 'ErrorDetail(string={!r}, code={!r})'.format('msg1', 'code') + assert repr(ErrorDetail('msg1')) == 'ErrorDetail(string={!r}, code=None)'.format('msg1') + assert repr(ErrorDetail('msg1', 'code')) == 'ErrorDetail(string={!r}, code={!r})'.format('msg1', 'code') def test_str(self): - assert str(ErrorDetail('msg1')) == 'msg1' - assert str(ErrorDetail('msg1', 'code')) == 'msg1' + assert str(ErrorDetail('msg1')) == 'msg1' + assert str(ErrorDetail('msg1', 'code')) == 'msg1' def test_hash(self): - assert hash(ErrorDetail('msg')) == hash('msg') - assert hash(ErrorDetail('msg', 'code')) == hash('msg') + assert hash(ErrorDetail('msg')) == hash('msg') + assert hash(ErrorDetail('msg', 'code')) == hash('msg') -class TranslationTests(TestCase): - @translation.override('fr') - def test_message(self): +@translation.override('fr') +deftest_message(self): # this test largely acts as a sanity test to ensure the translation files are present. - self.assertEqual(_('A server error occurred.'), 'Une erreur du serveur est survenue.') - self.assertEqual(str(APIException()), 'Une erreur du serveur est survenue.') + assert _('A server error occurred.') == 'Une erreur du serveur est survenue.' + assert str(APIException()) == 'Une erreur du serveur est survenue.' def test_server_error(): diff --git a/tests/test_fields.py b/tests/test_fields.py index e0833564b..909895f51 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -1134,40 +1134,38 @@ class TestNoStringCoercionDecimalField(FieldValues): ) -class TestLocalizedDecimalField(TestCase): - @override_settings(USE_L10N=True, LANGUAGE_CODE='pl') - def test_to_internal_value(self): - field = serializers.DecimalField(max_digits=2, decimal_places=1, localize=True) - assert field.to_internal_value('1,1') == Decimal('1.1') +@override_settings(USE_L10N=True, LANGUAGE_CODE='pl') +deftest_to_internal_value(self): + field = serializers.DecimalField(max_digits=2, decimal_places=1, localize=True) + assert field.to_internal_value('1,1') == Decimal('1.1') @override_settings(USE_L10N=True, LANGUAGE_CODE='pl') - def test_to_representation(self): - field = serializers.DecimalField(max_digits=2, decimal_places=1, localize=True) - assert field.to_representation(Decimal('1.1')) == '1,1' +deftest_to_representation(self): + field = serializers.DecimalField(max_digits=2, decimal_places=1, localize=True) + assert field.to_representation(Decimal('1.1')) == '1,1' def test_localize_forces_coerce_to_string(self): - field = serializers.DecimalField(max_digits=2, decimal_places=1, coerce_to_string=False, localize=True) - assert isinstance(field.to_representation(Decimal('1.1')), str) + field = serializers.DecimalField(max_digits=2, decimal_places=1, coerce_to_string=False, localize=True) + assert isinstance(field.to_representation(Decimal('1.1')), str) -class TestQuantizedValueForDecimal(TestCase): - def test_int_quantized_value_for_decimal(self): - field = serializers.DecimalField(max_digits=4, decimal_places=2) - value = field.to_internal_value(12).as_tuple() - expected_digit_tuple = (0, (1, 2, 0, 0), -2) - assert value == expected_digit_tuple +def test_int_quantized_value_for_decimal(self): + field = serializers.DecimalField(max_digits=4, decimal_places=2) + value = field.to_internal_value(12).as_tuple() + expected_digit_tuple = (0, (1, 2, 0, 0), -2) + assert value == expected_digit_tuple def test_string_quantized_value_for_decimal(self): - field = serializers.DecimalField(max_digits=4, decimal_places=2) - value = field.to_internal_value('12').as_tuple() - expected_digit_tuple = (0, (1, 2, 0, 0), -2) - assert value == expected_digit_tuple + field = serializers.DecimalField(max_digits=4, decimal_places=2) + value = field.to_internal_value('12').as_tuple() + expected_digit_tuple = (0, (1, 2, 0, 0), -2) + assert value == expected_digit_tuple def test_part_precision_string_quantized_value_for_decimal(self): - field = serializers.DecimalField(max_digits=4, decimal_places=2) - value = field.to_internal_value('12.0').as_tuple() - expected_digit_tuple = (0, (1, 2, 0, 0), -2) - assert value == expected_digit_tuple + field = serializers.DecimalField(max_digits=4, decimal_places=2) + value = field.to_internal_value('12.0').as_tuple() + expected_digit_tuple = (0, (1, 2, 0, 0), -2) + assert value == expected_digit_tuple class TestNoDecimalPlaces(FieldValues): @@ -1185,17 +1183,15 @@ class TestNoDecimalPlaces(FieldValues): field = serializers.DecimalField(max_digits=6, decimal_places=None) -class TestRoundingDecimalField(TestCase): - def test_valid_rounding(self): - field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding=ROUND_UP) - assert field.to_representation(Decimal('1.234')) == '1.24' - - field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding=ROUND_DOWN) - assert field.to_representation(Decimal('1.234')) == '1.23' +def test_valid_rounding(self): + field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding=ROUND_UP) + assert field.to_representation(Decimal('1.234')) == '1.24' + field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding=ROUND_DOWN) + assert field.to_representation(Decimal('1.234')) == '1.23' def test_invalid_rounding(self): - with pytest.raises(AssertionError) as excinfo: - serializers.DecimalField(max_digits=1, decimal_places=1, rounding='ROUND_UNKNOWN') + with pytest.raises(AssertionError) as excinfo: + serializers.DecimalField(max_digits=1, decimal_places=1, rounding='ROUND_UNKNOWN') assert 'Invalid rounding option' in str(excinfo.value) @@ -1369,46 +1365,41 @@ class TestTZWithDateTimeField(FieldValues): @override_settings(TIME_ZONE='UTC', USE_TZ=True) -class TestDefaultTZDateTimeField(TestCase): - """ +""" Test the current/default timezone handling in `DateTimeField`. """ - - @classmethod - def setup_class(cls): - cls.field = serializers.DateTimeField() - cls.kolkata = pytz.timezone('Asia/Kolkata') +@classmethod +defsetup_class(cls): + cls.field = serializers.DateTimeField() + cls.kolkata = pytz.timezone('Asia/Kolkata') def test_default_timezone(self): - assert self.field.default_timezone() == utc + assert self.field.default_timezone() == utc def test_current_timezone(self): - assert self.field.default_timezone() == utc - activate(self.kolkata) - assert self.field.default_timezone() == self.kolkata - deactivate() - assert self.field.default_timezone() == utc + assert self.field.default_timezone() == utc + activate(self.kolkata) + assert self.field.default_timezone() == self.kolkata + deactivate() + assert self.field.default_timezone() == utc @pytest.mark.skipif(pytz is None, reason='pytz not installed') @override_settings(TIME_ZONE='UTC', USE_TZ=True) -class TestCustomTimezoneForDateTimeField(TestCase): - @classmethod - def setup_class(cls): - cls.kolkata = pytz.timezone('Asia/Kolkata') - cls.date_format = '%d/%m/%Y %H:%M' +@classmethod +defsetup_class(cls): + cls.kolkata = pytz.timezone('Asia/Kolkata') + cls.date_format = '%d/%m/%Y %H:%M' def test_should_render_date_time_in_default_timezone(self): - field = serializers.DateTimeField(default_timezone=self.kolkata, format=self.date_format) - dt = datetime.datetime(2018, 2, 8, 14, 15, 16, tzinfo=pytz.utc) - - with override(self.kolkata): - rendered_date = field.to_representation(dt) + field = serializers.DateTimeField(default_timezone=self.kolkata, format=self.date_format) + dt = datetime.datetime(2018, 2, 8, 14, 15, 16, tzinfo=pytz.utc) + with override(self.kolkata): + rendered_date = field.to_representation(dt) rendered_date_in_timezone = dt.astimezone(self.kolkata).strftime(self.date_format) - - assert rendered_date == rendered_date_in_timezone +assertrendered_date==rendered_date_in_timezone class TestNaiveDayLightSavingTimeTimeZoneDateTimeField(FieldValues): diff --git a/tests/test_filters.py b/tests/test_filters.py index a52f40103..ec678eb09 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -13,28 +13,25 @@ from rest_framework.compat import coreschema from rest_framework.test import APIRequestFactory factory = APIRequestFactory() - - -class BaseFilterTests(TestCase): - def setUp(self): - self.original_coreapi = filters.coreapi - filters.coreapi = True # mock it, because not None value needed - self.filter_backend = filters.BaseFilterBackend() +def setUp(self): + self.original_coreapi = filters.coreapi + filters.coreapi = True + self.filter_backend = filters.BaseFilterBackend() def tearDown(self): - filters.coreapi = self.original_coreapi + filters.coreapi = self.original_coreapi def test_filter_queryset_raises_error(self): - with pytest.raises(NotImplementedError): - self.filter_backend.filter_queryset(None, None, None) + with pytest.raises(NotImplementedError): + self.filter_backend.filter_queryset(None, None, None) @pytest.mark.skipif(not coreschema, reason='coreschema is not installed') - def test_get_schema_fields_checks_for_coreapi(self): - filters.coreapi = None - with pytest.raises(AssertionError): - self.filter_backend.get_schema_fields({}) +deftest_get_schema_fields_checks_for_coreapi(self): + filters.coreapi = None + with pytest.raises(AssertionError): + self.filter_backend.get_schema_fields({}) filters.coreapi = True - assert self.filter_backend.get_schema_fields({}) == [] +assertself.filter_backend.get_schema_fields({})==[] class SearchFilterModel(models.Model): @@ -48,137 +45,115 @@ class SearchFilterSerializer(serializers.ModelSerializer): fields = '__all__' -class SearchFilterTests(TestCase): - def setUp(self): +def setUp(self): # Sequence of title/text is: # # z abc # zz bcd # zzz cde # ... - for idx in range(10): - title = 'z' * (idx + 1) - text = ( - chr(idx + ord('a')) + - chr(idx + ord('b')) + - chr(idx + ord('c')) - ) - SearchFilterModel(title=title, text=text).save() + for idx in range(10): + title = 'z' * (idx + 1) + text = ( chr(idx + ord('a')) + chr(idx + ord('b')) + chr(idx + ord('c')) ) + SearchFilterModel(title=title, text=text).save() def test_search(self): + class SearchListView(generics.ListAPIView): + queryset = SearchFilterModel.objects.all() + serializer_class = SearchFilterSerializer + filter_backends = (filters.SearchFilter,) + search_fields = ('title', 'text') + + view = SearchListView.as_view() +request=factory.get('/',{'search':'b'}) +response=view(request) +assertresponse.data==[{'id':1,'title':'z','text':'abc'},{'id':2,'title':'zz','text':'bcd'}] + + def test_search_returns_same_queryset_if_no_search_fields_or_terms_provided(self): + class SearchListView(generics.ListAPIView): + queryset = SearchFilterModel.objects.all() + serializer_class = SearchFilterSerializer + filter_backends = (filters.SearchFilter,) + + view = SearchListView.as_view() +request=factory.get('/') +response=view(request) +expected=SearchFilterSerializer(SearchFilterModel.objects.all(),many=True).data +assertresponse.data==expected + + def test_exact_search(self): + class SearchListView(generics.ListAPIView): + queryset = SearchFilterModel.objects.all() + serializer_class = SearchFilterSerializer + filter_backends = (filters.SearchFilter,) + search_fields = ('=title', 'text') + + view = SearchListView.as_view() +request=factory.get('/',{'search':'zzz'}) +response=view(request) +assertresponse.data==[{'id':3,'title':'zzz','text':'cde'}] + + def test_startswith_search(self): + class SearchListView(generics.ListAPIView): + queryset = SearchFilterModel.objects.all() + serializer_class = SearchFilterSerializer + filter_backends = (filters.SearchFilter,) + search_fields = ('title', '^text') + + view = SearchListView.as_view() +request=factory.get('/',{'search':'b'}) +response=view(request) +assertresponse.data==[{'id':2,'title':'zz','text':'bcd'}] + + def test_regexp_search(self): + class SearchListView(generics.ListAPIView): + queryset = SearchFilterModel.objects.all() + serializer_class = SearchFilterSerializer + filter_backends = (filters.SearchFilter,) + search_fields = ('$title', '$text') + + view = SearchListView.as_view() +request=factory.get('/',{'search':'z{2} ^b'}) +response=view(request) +assertresponse.data==[{'id':2,'title':'zz','text':'bcd'}] + + def test_search_with_nonstandard_search_param(self): + with override_settings(REST_FRAMEWORK={'SEARCH_PARAM': 'query'}): + reload_module(filters) class SearchListView(generics.ListAPIView): queryset = SearchFilterModel.objects.all() serializer_class = SearchFilterSerializer filter_backends = (filters.SearchFilter,) search_fields = ('title', 'text') - view = SearchListView.as_view() - request = factory.get('/', {'search': 'b'}) - response = view(request) - assert response.data == [ - {'id': 1, 'title': 'z', 'text': 'abc'}, - {'id': 2, 'title': 'zz', 'text': 'bcd'} - ] - - def test_search_returns_same_queryset_if_no_search_fields_or_terms_provided(self): - class SearchListView(generics.ListAPIView): - queryset = SearchFilterModel.objects.all() - serializer_class = SearchFilterSerializer - filter_backends = (filters.SearchFilter,) - - view = SearchListView.as_view() - request = factory.get('/') - response = view(request) - expected = SearchFilterSerializer(SearchFilterModel.objects.all(), - many=True).data - assert response.data == expected - - def test_exact_search(self): - class SearchListView(generics.ListAPIView): - queryset = SearchFilterModel.objects.all() - serializer_class = SearchFilterSerializer - filter_backends = (filters.SearchFilter,) - search_fields = ('=title', 'text') - - view = SearchListView.as_view() - request = factory.get('/', {'search': 'zzz'}) - response = view(request) - assert response.data == [ - {'id': 3, 'title': 'zzz', 'text': 'cde'} - ] - - def test_startswith_search(self): - class SearchListView(generics.ListAPIView): - queryset = SearchFilterModel.objects.all() - serializer_class = SearchFilterSerializer - filter_backends = (filters.SearchFilter,) - search_fields = ('title', '^text') - - view = SearchListView.as_view() - request = factory.get('/', {'search': 'b'}) - response = view(request) - assert response.data == [ - {'id': 2, 'title': 'zz', 'text': 'bcd'} - ] - - def test_regexp_search(self): - class SearchListView(generics.ListAPIView): - queryset = SearchFilterModel.objects.all() - serializer_class = SearchFilterSerializer - filter_backends = (filters.SearchFilter,) - search_fields = ('$title', '$text') - - view = SearchListView.as_view() - request = factory.get('/', {'search': 'z{2} ^b'}) - response = view(request) - assert response.data == [ - {'id': 2, 'title': 'zz', 'text': 'bcd'} - ] - - def test_search_with_nonstandard_search_param(self): - with override_settings(REST_FRAMEWORK={'SEARCH_PARAM': 'query'}): - reload_module(filters) - - class SearchListView(generics.ListAPIView): - queryset = SearchFilterModel.objects.all() - serializer_class = SearchFilterSerializer - filter_backends = (filters.SearchFilter,) - search_fields = ('title', 'text') - view = SearchListView.as_view() - request = factory.get('/', {'query': 'b'}) - response = view(request) - assert response.data == [ - {'id': 1, 'title': 'z', 'text': 'abc'}, - {'id': 2, 'title': 'zz', 'text': 'bcd'} - ] +request=factory.get('/',{'query':'b'}) +response=view(request) +assertresponse.data==[{'id':1,'title':'z','text':'abc'},{'id':2,'title':'zz','text':'bcd'}] reload_module(filters) def test_search_with_filter_subclass(self): - class CustomSearchFilter(filters.SearchFilter): + class CustomSearchFilter(filters.SearchFilter): # Filter that dynamically changes search fields - def get_search_fields(self, view, request): - if request.query_params.get('title_only'): - return ('$title',) + def get_search_fields(self, view, request): + if request.query_params.get('title_only'): + return ('$title',) return super().get_search_fields(view, request) class SearchListView(generics.ListAPIView): - queryset = SearchFilterModel.objects.all() - serializer_class = SearchFilterSerializer - filter_backends = (CustomSearchFilter,) - search_fields = ('$title', '$text') + queryset = SearchFilterModel.objects.all() + serializer_class = SearchFilterSerializer + filter_backends = (CustomSearchFilter,) + search_fields = ('$title', '$text') view = SearchListView.as_view() - request = factory.get('/', {'search': r'^\w{3}$'}) - response = view(request) - assert len(response.data) == 10 - - request = factory.get('/', {'search': r'^\w{3}$', 'title_only': 'true'}) - response = view(request) - assert response.data == [ - {'id': 3, 'title': 'zzz', 'text': 'cde'} - ] +request=factory.get('/',{'search':r'^\w{3}$'}) +response=view(request) +assertlen(response.data)==10 +request=factory.get('/',{'search':r'^\w{3}$','title_only':'true'}) +response=view(request) +assertresponse.data==[{'id':3,'title':'zzz','text':'cde'}] class AttributeModel(models.Model): @@ -196,31 +171,21 @@ class SearchFilterFkSerializer(serializers.ModelSerializer): fields = '__all__' -class SearchFilterFkTests(TestCase): - def test_must_call_distinct(self): - filter_ = filters.SearchFilter() - prefixes = [''] + list(filter_.lookup_prefixes) - for prefix in prefixes: - assert not filter_.must_call_distinct( - SearchFilterModelFk._meta, - ["%stitle" % prefix] - ) - assert not filter_.must_call_distinct( - SearchFilterModelFk._meta, - ["%stitle" % prefix, "%sattribute__label" % prefix] - ) +def test_must_call_distinct(self): + filter_ = filters.SearchFilter() + prefixes = [''] + list(filter_.lookup_prefixes) + for prefix in prefixes: + assert not filter_.must_call_distinct( SearchFilterModelFk._meta, ["%stitle" % prefix] ) + assert not filter_.must_call_distinct( SearchFilterModelFk._meta, ["%stitle" % prefix, "%sattribute__label" % prefix] ) def test_must_call_distinct_restores_meta_for_each_field(self): # In this test case the attribute of the fk model comes first in the # list of search fields. - filter_ = filters.SearchFilter() - prefixes = [''] + list(filter_.lookup_prefixes) - for prefix in prefixes: - assert not filter_.must_call_distinct( - SearchFilterModelFk._meta, - ["%sattribute__label" % prefix, "%stitle" % prefix] - ) + filter_ = filters.SearchFilter() + prefixes = [''] + list(filter_.lookup_prefixes) + for prefix in prefixes: + assert not filter_.must_call_distinct( SearchFilterModelFk._meta, ["%sattribute__label" % prefix, "%stitle" % prefix] ) class SearchFilterModelM2M(models.Model): @@ -235,53 +200,41 @@ class SearchFilterM2MSerializer(serializers.ModelSerializer): fields = '__all__' -class SearchFilterM2MTests(TestCase): - def setUp(self): +def setUp(self): # Sequence of title/text/attributes is: # # z abc [1, 2, 3] # zz bcd [1, 2, 3] # zzz cde [1, 2, 3] # ... - for idx in range(3): - label = 'w' * (idx + 1) - AttributeModel.objects.create(label=label) + for idx in range(3): + label = 'w' * (idx + 1) + AttributeModel.objects.create(label=label) for idx in range(10): - title = 'z' * (idx + 1) - text = ( - chr(idx + ord('a')) + - chr(idx + ord('b')) + - chr(idx + ord('c')) - ) - SearchFilterModelM2M(title=title, text=text).save() + title = 'z' * (idx + 1) + text = ( chr(idx + ord('a')) + chr(idx + ord('b')) + chr(idx + ord('c')) ) + SearchFilterModelM2M(title=title, text=text).save() SearchFilterModelM2M.objects.get(title='zz').attributes.add(1, 2, 3) def test_m2m_search(self): - class SearchListView(generics.ListAPIView): - queryset = SearchFilterModelM2M.objects.all() - serializer_class = SearchFilterM2MSerializer - filter_backends = (filters.SearchFilter,) - search_fields = ('=title', 'text', 'attributes__label') + class SearchListView(generics.ListAPIView): + queryset = SearchFilterModelM2M.objects.all() + serializer_class = SearchFilterM2MSerializer + filter_backends = (filters.SearchFilter,) + search_fields = ('=title', 'text', 'attributes__label') view = SearchListView.as_view() - request = factory.get('/', {'search': 'zz'}) - response = view(request) - assert len(response.data) == 1 +request=factory.get('/',{'search':'zz'}) +response=view(request) +assertlen(response.data)==1 def test_must_call_distinct(self): - filter_ = filters.SearchFilter() - prefixes = [''] + list(filter_.lookup_prefixes) - for prefix in prefixes: - assert not filter_.must_call_distinct( - SearchFilterModelM2M._meta, - ["%stitle" % prefix] - ) - - assert filter_.must_call_distinct( - SearchFilterModelM2M._meta, - ["%stitle" % prefix, "%sattributes__label" % prefix] - ) + filter_ = filters.SearchFilter() + prefixes = [''] + list(filter_.lookup_prefixes) + for prefix in prefixes: + assert not filter_.must_call_distinct( SearchFilterModelM2M._meta, ["%stitle" % prefix] ) + assert filter_.must_call_distinct( SearchFilterModelM2M._meta, ["%stitle" % prefix, "%sattributes__label" % prefix] ) class Blog(models.Model): @@ -300,32 +253,27 @@ class BlogSerializer(serializers.ModelSerializer): fields = '__all__' -class SearchFilterToManyTests(TestCase): - @classmethod - def setUpTestData(cls): - b1 = Blog.objects.create(name='Blog 1') - b2 = Blog.objects.create(name='Blog 2') - - # Multiple entries on Lennon published in 1979 - distinct should deduplicate - Entry.objects.create(blog=b1, headline='Something about Lennon', pub_date=datetime.date(1979, 1, 1)) - Entry.objects.create(blog=b1, headline='Another thing about Lennon', pub_date=datetime.date(1979, 6, 1)) - - # Entry on Lennon *and* a separate entry in 1979 - should not match - Entry.objects.create(blog=b2, headline='Something unrelated', pub_date=datetime.date(1979, 1, 1)) - Entry.objects.create(blog=b2, headline='Retrospective on Lennon', pub_date=datetime.date(1990, 6, 1)) +@classmethod +defsetUpTestData(cls): + b1 = Blog.objects.create(name='Blog 1') + b2 = Blog.objects.create(name='Blog 2') + Entry.objects.create(blog=b1, headline='Something about Lennon', pub_date=datetime.date(1979, 1, 1)) + Entry.objects.create(blog=b1, headline='Another thing about Lennon', pub_date=datetime.date(1979, 6, 1)) + Entry.objects.create(blog=b2, headline='Something unrelated', pub_date=datetime.date(1979, 1, 1)) + Entry.objects.create(blog=b2, headline='Retrospective on Lennon', pub_date=datetime.date(1990, 6, 1)) def test_multiple_filter_conditions(self): - class SearchListView(generics.ListAPIView): - queryset = Blog.objects.all() - serializer_class = BlogSerializer - filter_backends = (filters.SearchFilter,) - search_fields = ('=name', 'entry__headline', '=entry__pub_date__year') + class SearchListView(generics.ListAPIView): + queryset = Blog.objects.all() + serializer_class = BlogSerializer + filter_backends = (filters.SearchFilter,) + search_fields = ('=name', 'entry__headline', '=entry__pub_date__year') view = SearchListView.as_view() - request = factory.get('/', {'search': 'Lennon,1979'}) - response = view(request) - assert len(response.data) == 1 +request=factory.get('/',{'search':'Lennon,1979'}) +response=view(request) +assertlen(response.data)==1 class SearchFilterAnnotatedSerializer(serializers.ModelSerializer): @@ -336,28 +284,23 @@ class SearchFilterAnnotatedSerializer(serializers.ModelSerializer): fields = ('title', 'text', 'title_text') -class SearchFilterAnnotatedFieldTests(TestCase): - @classmethod - def setUpTestData(cls): - SearchFilterModel.objects.create(title='abc', text='def') - SearchFilterModel.objects.create(title='ghi', text='jkl') +@classmethod +defsetUpTestData(cls): + SearchFilterModel.objects.create(title='abc', text='def') + SearchFilterModel.objects.create(title='ghi', text='jkl') def test_search_in_annotated_field(self): - class SearchListView(generics.ListAPIView): - queryset = SearchFilterModel.objects.annotate( - title_text=Upper( - Concat(models.F('title'), models.F('text')) - ) - ).all() - serializer_class = SearchFilterAnnotatedSerializer - filter_backends = (filters.SearchFilter,) - search_fields = ('title_text',) + class SearchListView(generics.ListAPIView): + queryset = SearchFilterModel.objects.annotate( title_text=Upper( Concat(models.F('title'), models.F('text')) ) ).all() + serializer_class = SearchFilterAnnotatedSerializer + filter_backends = (filters.SearchFilter,) + search_fields = ('title_text',) view = SearchListView.as_view() - request = factory.get('/', {'search': 'ABCDEF'}) - response = view(request) - assert len(response.data) == 1 - assert response.data[0]['title_text'] == 'ABCDEF' +request=factory.get('/',{'search':'ABCDEF'}) +response=view(request) +assertlen(response.data)==1 +assertresponse.data[0]['title_text']=='ABCDEF' class OrderingFilterModel(models.Model): @@ -403,253 +346,187 @@ class DjangoFilterOrderingSerializer(serializers.ModelSerializer): fields = '__all__' -class OrderingFilterTests(TestCase): - def setUp(self): +def setUp(self): # Sequence of title/text is: # # zyx abc # yxw bcd # xwv cde - for idx in range(3): - title = ( - chr(ord('z') - idx) + - chr(ord('y') - idx) + - chr(ord('x') - idx) - ) - text = ( - chr(idx + ord('a')) + - chr(idx + ord('b')) + - chr(idx + ord('c')) - ) - OrderingFilterModel(title=title, text=text).save() + for idx in range(3): + title = ( chr(ord('z') - idx) + chr(ord('y') - idx) + chr(ord('x') - idx) ) + text = ( chr(idx + ord('a')) + chr(idx + ord('b')) + chr(idx + ord('c')) ) + OrderingFilterModel(title=title, text=text).save() def test_ordering(self): - class OrderingListView(generics.ListAPIView): - queryset = OrderingFilterModel.objects.all() - serializer_class = OrderingFilterSerializer - filter_backends = (filters.OrderingFilter,) - ordering = ('title',) - ordering_fields = ('text',) + class OrderingListView(generics.ListAPIView): + queryset = OrderingFilterModel.objects.all() + serializer_class = OrderingFilterSerializer + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) + ordering_fields = ('text',) view = OrderingListView.as_view() - request = factory.get('/', {'ordering': 'text'}) - response = view(request) - assert response.data == [ - {'id': 1, 'title': 'zyx', 'text': 'abc'}, - {'id': 2, 'title': 'yxw', 'text': 'bcd'}, - {'id': 3, 'title': 'xwv', 'text': 'cde'}, - ] +request=factory.get('/',{'ordering':'text'}) +response=view(request) +assertresponse.data==[{'id':1,'title':'zyx','text':'abc'},{'id':2,'title':'yxw','text':'bcd'},{'id':3,'title':'xwv','text':'cde'},] def test_reverse_ordering(self): - class OrderingListView(generics.ListAPIView): - queryset = OrderingFilterModel.objects.all() - serializer_class = OrderingFilterSerializer - filter_backends = (filters.OrderingFilter,) - ordering = ('title',) - ordering_fields = ('text',) + class OrderingListView(generics.ListAPIView): + queryset = OrderingFilterModel.objects.all() + serializer_class = OrderingFilterSerializer + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) + ordering_fields = ('text',) view = OrderingListView.as_view() - request = factory.get('/', {'ordering': '-text'}) - response = view(request) - assert response.data == [ - {'id': 3, 'title': 'xwv', 'text': 'cde'}, - {'id': 2, 'title': 'yxw', 'text': 'bcd'}, - {'id': 1, 'title': 'zyx', 'text': 'abc'}, - ] +request=factory.get('/',{'ordering':'-text'}) +response=view(request) +assertresponse.data==[{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},{'id':1,'title':'zyx','text':'abc'},] def test_incorrecturl_extrahyphens_ordering(self): - class OrderingListView(generics.ListAPIView): - queryset = OrderingFilterModel.objects.all() - serializer_class = OrderingFilterSerializer - filter_backends = (filters.OrderingFilter,) - ordering = ('title',) - ordering_fields = ('text',) + class OrderingListView(generics.ListAPIView): + queryset = OrderingFilterModel.objects.all() + serializer_class = OrderingFilterSerializer + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) + ordering_fields = ('text',) view = OrderingListView.as_view() - request = factory.get('/', {'ordering': '--text'}) - response = view(request) - assert response.data == [ - {'id': 3, 'title': 'xwv', 'text': 'cde'}, - {'id': 2, 'title': 'yxw', 'text': 'bcd'}, - {'id': 1, 'title': 'zyx', 'text': 'abc'}, - ] +request=factory.get('/',{'ordering':'--text'}) +response=view(request) +assertresponse.data==[{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},{'id':1,'title':'zyx','text':'abc'},] def test_incorrectfield_ordering(self): - class OrderingListView(generics.ListAPIView): - queryset = OrderingFilterModel.objects.all() - serializer_class = OrderingFilterSerializer - filter_backends = (filters.OrderingFilter,) - ordering = ('title',) - ordering_fields = ('text',) + class OrderingListView(generics.ListAPIView): + queryset = OrderingFilterModel.objects.all() + serializer_class = OrderingFilterSerializer + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) + ordering_fields = ('text',) view = OrderingListView.as_view() - request = factory.get('/', {'ordering': 'foobar'}) - response = view(request) - assert response.data == [ - {'id': 3, 'title': 'xwv', 'text': 'cde'}, - {'id': 2, 'title': 'yxw', 'text': 'bcd'}, - {'id': 1, 'title': 'zyx', 'text': 'abc'}, - ] +request=factory.get('/',{'ordering':'foobar'}) +response=view(request) +assertresponse.data==[{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},{'id':1,'title':'zyx','text':'abc'},] def test_default_ordering(self): - class OrderingListView(generics.ListAPIView): - queryset = OrderingFilterModel.objects.all() - serializer_class = OrderingFilterSerializer - filter_backends = (filters.OrderingFilter,) - ordering = ('title',) - ordering_fields = ('text',) + class OrderingListView(generics.ListAPIView): + queryset = OrderingFilterModel.objects.all() + serializer_class = OrderingFilterSerializer + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) + ordering_fields = ('text',) view = OrderingListView.as_view() - request = factory.get('') - response = view(request) - assert response.data == [ - {'id': 3, 'title': 'xwv', 'text': 'cde'}, - {'id': 2, 'title': 'yxw', 'text': 'bcd'}, - {'id': 1, 'title': 'zyx', 'text': 'abc'}, - ] +request=factory.get('') +response=view(request) +assertresponse.data==[{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},{'id':1,'title':'zyx','text':'abc'},] def test_default_ordering_using_string(self): - class OrderingListView(generics.ListAPIView): - queryset = OrderingFilterModel.objects.all() - serializer_class = OrderingFilterSerializer - filter_backends = (filters.OrderingFilter,) - ordering = 'title' - ordering_fields = ('text',) + class OrderingListView(generics.ListAPIView): + queryset = OrderingFilterModel.objects.all() + serializer_class = OrderingFilterSerializer + filter_backends = (filters.OrderingFilter,) + ordering = 'title' + ordering_fields = ('text',) view = OrderingListView.as_view() - request = factory.get('') - response = view(request) - assert response.data == [ - {'id': 3, 'title': 'xwv', 'text': 'cde'}, - {'id': 2, 'title': 'yxw', 'text': 'bcd'}, - {'id': 1, 'title': 'zyx', 'text': 'abc'}, - ] +request=factory.get('') +response=view(request) +assertresponse.data==[{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},{'id':1,'title':'zyx','text':'abc'},] def test_ordering_by_aggregate_field(self): # create some related models to aggregate order by - num_objs = [2, 5, 3] - for obj, num_relateds in zip(OrderingFilterModel.objects.all(), - num_objs): - for _ in range(num_relateds): - new_related = OrderingFilterRelatedModel( - related_object=obj - ) - new_related.save() + num_objs = [2, 5, 3] + for obj, num_relateds in zip(OrderingFilterModel.objects.all(), num_objs): + for _ in range(num_relateds): + new_related = OrderingFilterRelatedModel( related_object=obj ) + new_related.save() class OrderingListView(generics.ListAPIView): - serializer_class = OrderingFilterSerializer - filter_backends = (filters.OrderingFilter,) - ordering = 'title' - ordering_fields = '__all__' - queryset = OrderingFilterModel.objects.all().annotate( - models.Count("relateds")) + serializer_class = OrderingFilterSerializer + filter_backends = (filters.OrderingFilter,) + ordering = 'title' + ordering_fields = '__all__' + queryset = OrderingFilterModel.objects.all().annotate( models.Count("relateds")) view = OrderingListView.as_view() - request = factory.get('/', {'ordering': 'relateds__count'}) - response = view(request) - assert response.data == [ - {'id': 1, 'title': 'zyx', 'text': 'abc'}, - {'id': 3, 'title': 'xwv', 'text': 'cde'}, - {'id': 2, 'title': 'yxw', 'text': 'bcd'}, - ] +request=factory.get('/',{'ordering':'relateds__count'}) +response=view(request) +assertresponse.data==[{'id':1,'title':'zyx','text':'abc'},{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},] def test_ordering_by_dotted_source(self): - for index, obj in enumerate(OrderingFilterModel.objects.all()): - OrderingFilterRelatedModel.objects.create( - related_object=obj, - index=index - ) + for index, obj in enumerate(OrderingFilterModel.objects.all()): + OrderingFilterRelatedModel.objects.create( related_object=obj, index=index ) class OrderingListView(generics.ListAPIView): - serializer_class = OrderingDottedRelatedSerializer - filter_backends = (filters.OrderingFilter,) - queryset = OrderingFilterRelatedModel.objects.all() + serializer_class = OrderingDottedRelatedSerializer + filter_backends = (filters.OrderingFilter,) + queryset = OrderingFilterRelatedModel.objects.all() view = OrderingListView.as_view() - request = factory.get('/', {'ordering': 'related_object__text'}) - response = view(request) - assert response.data == [ - {'related_title': 'zyx', 'related_text': 'abc', 'index': 0}, - {'related_title': 'yxw', 'related_text': 'bcd', 'index': 1}, - {'related_title': 'xwv', 'related_text': 'cde', 'index': 2}, - ] - - request = factory.get('/', {'ordering': '-index'}) - response = view(request) - assert response.data == [ - {'related_title': 'xwv', 'related_text': 'cde', 'index': 2}, - {'related_title': 'yxw', 'related_text': 'bcd', 'index': 1}, - {'related_title': 'zyx', 'related_text': 'abc', 'index': 0}, - ] +request=factory.get('/',{'ordering':'related_object__text'}) +response=view(request) +assertresponse.data==[{'related_title':'zyx','related_text':'abc','index':0},{'related_title':'yxw','related_text':'bcd','index':1},{'related_title':'xwv','related_text':'cde','index':2},] +request=factory.get('/',{'ordering':'-index'}) +response=view(request) +assertresponse.data==[{'related_title':'xwv','related_text':'cde','index':2},{'related_title':'yxw','related_text':'bcd','index':1},{'related_title':'zyx','related_text':'abc','index':0},] def test_ordering_with_nonstandard_ordering_param(self): - with override_settings(REST_FRAMEWORK={'ORDERING_PARAM': 'order'}): - reload_module(filters) - - class OrderingListView(generics.ListAPIView): - queryset = OrderingFilterModel.objects.all() - serializer_class = OrderingFilterSerializer - filter_backends = (filters.OrderingFilter,) - ordering = ('title',) - ordering_fields = ('text',) + with override_settings(REST_FRAMEWORK={'ORDERING_PARAM': 'order'}): + reload_module(filters) + class OrderingListView(generics.ListAPIView): + queryset = OrderingFilterModel.objects.all() + serializer_class = OrderingFilterSerializer + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) + ordering_fields = ('text',) view = OrderingListView.as_view() - request = factory.get('/', {'order': 'text'}) - response = view(request) - assert response.data == [ - {'id': 1, 'title': 'zyx', 'text': 'abc'}, - {'id': 2, 'title': 'yxw', 'text': 'bcd'}, - {'id': 3, 'title': 'xwv', 'text': 'cde'}, - ] +request=factory.get('/',{'order':'text'}) +response=view(request) +assertresponse.data==[{'id':1,'title':'zyx','text':'abc'},{'id':2,'title':'yxw','text':'bcd'},{'id':3,'title':'xwv','text':'cde'},] reload_module(filters) def test_get_template_context(self): - class OrderingListView(generics.ListAPIView): - ordering_fields = '__all__' - serializer_class = OrderingFilterSerializer - queryset = OrderingFilterModel.objects.all() - filter_backends = (filters.OrderingFilter,) + class OrderingListView(generics.ListAPIView): + ordering_fields = '__all__' + serializer_class = OrderingFilterSerializer + queryset = OrderingFilterModel.objects.all() + filter_backends = (filters.OrderingFilter,) request = factory.get('/', {'ordering': 'title'}, HTTP_ACCEPT='text/html') - view = OrderingListView.as_view() - response = view(request) - - self.assertContains(response, 'verbose title') +view=OrderingListView.as_view() +response=view(request) +self.assertContains(response,'verbose title') def test_ordering_with_overridden_get_serializer_class(self): - class OrderingListView(generics.ListAPIView): - queryset = OrderingFilterModel.objects.all() - filter_backends = (filters.OrderingFilter,) - ordering = ('title',) - - # note: no ordering_fields and serializer_class specified - - def get_serializer_class(self): - return OrderingFilterSerializer + class OrderingListView(generics.ListAPIView): + queryset = OrderingFilterModel.objects.all() + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) + def get_serializer_class(self): + return OrderingFilterSerializer view = OrderingListView.as_view() - request = factory.get('/', {'ordering': 'text'}) - response = view(request) - assert response.data == [ - {'id': 1, 'title': 'zyx', 'text': 'abc'}, - {'id': 2, 'title': 'yxw', 'text': 'bcd'}, - {'id': 3, 'title': 'xwv', 'text': 'cde'}, - ] +request=factory.get('/',{'ordering':'text'}) +response=view(request) +assertresponse.data==[{'id':1,'title':'zyx','text':'abc'},{'id':2,'title':'yxw','text':'bcd'},{'id':3,'title':'xwv','text':'cde'},] def test_ordering_with_improper_configuration(self): - class OrderingListView(generics.ListAPIView): - queryset = OrderingFilterModel.objects.all() - filter_backends = (filters.OrderingFilter,) - ordering = ('title',) + class OrderingListView(generics.ListAPIView): + queryset = OrderingFilterModel.objects.all() + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) # note: no ordering_fields and serializer_class # or get_serializer_class specified view = OrderingListView.as_view() - request = factory.get('/', {'ordering': 'text'}) - with self.assertRaises(ImproperlyConfigured): - view(request) +request=factory.get('/',{'ordering':'text'}) +withpytest.raises(ImproperlyConfigured): + view(request) class SensitiveOrderingFilterModel(models.Model): @@ -684,63 +561,44 @@ class SensitiveDataSerializer3(serializers.ModelSerializer): fields = ('id', 'user') -class SensitiveOrderingFilterTests(TestCase): - def setUp(self): - for idx in range(3): - username = {0: 'userA', 1: 'userB', 2: 'userC'}[idx] - password = {0: 'passA', 1: 'passC', 2: 'passB'}[idx] - SensitiveOrderingFilterModel(username=username, password=password).save() +def setUp(self): + for idx in range(3): + username = {0: 'userA', 1: 'userB', 2: 'userC'}[idx] + password = {0: 'passA', 1: 'passC', 2: 'passB'}[idx] + SensitiveOrderingFilterModel(username=username, password=password).save() def test_order_by_serializer_fields(self): - for serializer_cls in [ - SensitiveDataSerializer1, - SensitiveDataSerializer2, - SensitiveDataSerializer3 - ]: - class OrderingListView(generics.ListAPIView): - queryset = SensitiveOrderingFilterModel.objects.all().order_by('username') - filter_backends = (filters.OrderingFilter,) - serializer_class = serializer_cls + for serializer_cls in [ SensitiveDataSerializer1, SensitiveDataSerializer2, SensitiveDataSerializer3 ]: + class OrderingListView(generics.ListAPIView): + queryset = SensitiveOrderingFilterModel.objects.all().order_by('username') + filter_backends = (filters.OrderingFilter,) + serializer_class = serializer_cls view = OrderingListView.as_view() - request = factory.get('/', {'ordering': '-username'}) - response = view(request) - - if serializer_cls == SensitiveDataSerializer3: - username_field = 'user' +request=factory.get('/',{'ordering':'-username'}) +response=view(request) +ifserializer_cls==SensitiveDataSerializer3: + username_field = 'user' else: - username_field = 'username' + username_field = 'username' # Note: Inverse username ordering correctly applied. - assert response.data == [ - {'id': 3, username_field: 'userC'}, - {'id': 2, username_field: 'userB'}, - {'id': 1, username_field: 'userA'}, - ] + assert response.data == [{'id':3,username_field:'userC'},{'id':2,username_field:'userB'},{'id':1,username_field:'userA'},] def test_cannot_order_by_non_serializer_fields(self): - for serializer_cls in [ - SensitiveDataSerializer1, - SensitiveDataSerializer2, - SensitiveDataSerializer3 - ]: - class OrderingListView(generics.ListAPIView): - queryset = SensitiveOrderingFilterModel.objects.all().order_by('username') - filter_backends = (filters.OrderingFilter,) - serializer_class = serializer_cls + for serializer_cls in [ SensitiveDataSerializer1, SensitiveDataSerializer2, SensitiveDataSerializer3 ]: + class OrderingListView(generics.ListAPIView): + queryset = SensitiveOrderingFilterModel.objects.all().order_by('username') + filter_backends = (filters.OrderingFilter,) + serializer_class = serializer_cls view = OrderingListView.as_view() - request = factory.get('/', {'ordering': 'password'}) - response = view(request) - - if serializer_cls == SensitiveDataSerializer3: - username_field = 'user' +request=factory.get('/',{'ordering':'password'}) +response=view(request) +ifserializer_cls==SensitiveDataSerializer3: + username_field = 'user' else: - username_field = 'username' + username_field = 'username' # Note: The passwords are not in order. Default ordering is used. - assert response.data == [ - {'id': 1, username_field: 'userA'}, # PassB - {'id': 2, username_field: 'userB'}, # PassC - {'id': 3, username_field: 'userC'}, # PassA - ] + assert response.data == [{'id':1,username_field:'userA'},{'id':2,username_field:'userB'},{'id':3,username_field:'userC'},] diff --git a/tests/test_generateschema.py b/tests/test_generateschema.py index a6a1f2bed..02180f681 100644 --- a/tests/test_generateschema.py +++ b/tests/test_generateschema.py @@ -23,14 +23,12 @@ urlpatterns = [ @override_settings(ROOT_URLCONF='tests.test_generateschema') @pytest.mark.skipif(not coreapi, reason='coreapi is not installed') -class GenerateSchemaTests(TestCase): - """Tests for management command generateschema.""" - - def setUp(self): - self.out = io.StringIO() +"""Tests for management command generateschema.""" +defsetUp(self): + self.out = io.StringIO() def test_renders_default_schema_with_custom_title_url_and_description(self): - expected_out = """info: + expected_out = """info: description: Sample description title: SampleAPI version: '' @@ -42,45 +40,16 @@ class GenerateSchemaTests(TestCase): servers: - url: http://api.sample.com/ """ - call_command('generateschema', - '--title=SampleAPI', - '--url=http://api.sample.com', - '--description=Sample description', - stdout=self.out) - - self.assertIn(formatting.dedent(expected_out), self.out.getvalue()) + call_command('generateschema', '--title=SampleAPI', '--url=http://api.sample.com', '--description=Sample description', stdout=self.out) + assert formatting.dedent(expected_out) in self.out.getvalue() def test_renders_openapi_json_schema(self): - expected_out = { - "openapi": "3.0.0", - "info": { - "version": "", - "title": "", - "description": "" - }, - "servers": [ - { - "url": "" - } - ], - "paths": { - "/": { - "get": { - "operationId": "list" - } - } - } - } - call_command('generateschema', - '--format=openapi-json', - stdout=self.out) - out_json = json.loads(self.out.getvalue()) - - self.assertDictEqual(out_json, expected_out) + expected_out = { "openapi": "3.0.0", "info": { "version": "", "title": "", "description": "" }, "servers": [ { "url": "" } ], "paths": { "/": { "get": { "operationId": "list" } } } } + call_command('generateschema', '--format=openapi-json', stdout=self.out) + out_json = json.loads(self.out.getvalue()) + assert out_json == expected_out def test_renders_corejson_schema(self): - expected_out = """{"_type":"document","":{"list":{"_type":"link","url":"/","action":"get"}}}""" - call_command('generateschema', - '--format=corejson', - stdout=self.out) - self.assertIn(expected_out, self.out.getvalue()) + expected_out = """{"_type":"document","":{"list":{"_type":"link","url":"/","action":"get"}}}""" + call_command('generateschema', '--format=corejson', stdout=self.out) + assert expected_out in self.out.getvalue() diff --git a/tests/test_generics.py b/tests/test_generics.py index 0b91e3465..48a65cd5d 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -75,304 +75,282 @@ class SlugBasedInstanceView(InstanceView): # Tests -class TestRootView(TestCase): - def setUp(self): - """ +def setUp(self): + """ Create 3 BasicModel instances. """ - items = ['foo', 'bar', 'baz'] - for item in items: - BasicModel(text=item).save() + items = ['foo', 'bar', 'baz'] + for item in items: + BasicModel(text=item).save() self.objects = BasicModel.objects - self.data = [ - {'id': obj.id, 'text': obj.text} - for obj in self.objects.all() - ] - self.view = RootView.as_view() +self.data=[{'id':obj.id,'text':obj.text}forobjinself.objects.all()] +self.view=RootView.as_view() def test_get_root_view(self): - """ + """ GET requests to ListCreateAPIView should return list of objects. """ - request = factory.get('/') - with self.assertNumQueries(1): - response = self.view(request).render() + request = factory.get('/') + with self.assertNumQueries(1): + response = self.view(request).render() assert response.status_code == status.HTTP_200_OK - assert response.data == self.data +assertresponse.data==self.data def test_head_root_view(self): - """ + """ HEAD requests to ListCreateAPIView should return 200. """ - request = factory.head('/') - with self.assertNumQueries(1): - response = self.view(request).render() + request = factory.head('/') + with self.assertNumQueries(1): + response = self.view(request).render() assert response.status_code == status.HTTP_200_OK def test_post_root_view(self): - """ + """ POST requests to ListCreateAPIView should create a new object. """ - data = {'text': 'foobar'} - request = factory.post('/', data, format='json') - with self.assertNumQueries(1): - response = self.view(request).render() + data = {'text': 'foobar'} + request = factory.post('/', data, format='json') + with self.assertNumQueries(1): + response = self.view(request).render() assert response.status_code == status.HTTP_201_CREATED - assert response.data == {'id': 4, 'text': 'foobar'} - created = self.objects.get(id=4) - assert created.text == 'foobar' +assertresponse.data=={'id':4,'text':'foobar'} +created=self.objects.get(id=4) +assertcreated.text=='foobar' def test_put_root_view(self): - """ + """ PUT requests to ListCreateAPIView should not be allowed """ - data = {'text': 'foobar'} - request = factory.put('/', data, format='json') - with self.assertNumQueries(0): - response = self.view(request).render() + data = {'text': 'foobar'} + request = factory.put('/', data, format='json') + with self.assertNumQueries(0): + response = self.view(request).render() assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED - assert response.data == {"detail": 'Method "PUT" not allowed.'} +assertresponse.data=={"detail":'Method "PUT" not allowed.'} def test_delete_root_view(self): - """ + """ DELETE requests to ListCreateAPIView should not be allowed """ - request = factory.delete('/') - with self.assertNumQueries(0): - response = self.view(request).render() + request = factory.delete('/') + with self.assertNumQueries(0): + response = self.view(request).render() assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED - assert response.data == {"detail": 'Method "DELETE" not allowed.'} +assertresponse.data=={"detail":'Method "DELETE" not allowed.'} def test_post_cannot_set_id(self): - """ + """ POST requests to create a new object should not be able to set the id. """ - data = {'id': 999, 'text': 'foobar'} - request = factory.post('/', data, format='json') - with self.assertNumQueries(1): - response = self.view(request).render() + data = {'id': 999, 'text': 'foobar'} + request = factory.post('/', data, format='json') + with self.assertNumQueries(1): + response = self.view(request).render() assert response.status_code == status.HTTP_201_CREATED - assert response.data == {'id': 4, 'text': 'foobar'} - created = self.objects.get(id=4) - assert created.text == 'foobar' +assertresponse.data=={'id':4,'text':'foobar'} +created=self.objects.get(id=4) +assertcreated.text=='foobar' def test_post_error_root_view(self): - """ + """ POST requests to ListCreateAPIView in HTML should include a form error. """ - data = {'text': 'foobar' * 100} - request = factory.post('/', data, HTTP_ACCEPT='text/html') - response = self.view(request).render() - expected_error = 'Ensure this field has no more than 100 characters.' - assert expected_error in response.rendered_content.decode() + data = {'text': 'foobar' * 100} + request = factory.post('/', data, HTTP_ACCEPT='text/html') + response = self.view(request).render() + expected_error = 'Ensure this field has no more than 100 characters.' + assert expected_error in response.rendered_content.decode() EXPECTED_QUERIES_FOR_PUT = 2 - - -class TestInstanceView(TestCase): - def setUp(self): - """ +def setUp(self): + """ Create 3 BasicModel instances. """ - items = ['foo', 'bar', 'baz', 'filtered out'] - for item in items: - BasicModel(text=item).save() + items = ['foo', 'bar', 'baz', 'filtered out'] + for item in items: + BasicModel(text=item).save() self.objects = BasicModel.objects.exclude(text='filtered out') - self.data = [ - {'id': obj.id, 'text': obj.text} - for obj in self.objects.all() - ] - self.view = InstanceView.as_view() - self.slug_based_view = SlugBasedInstanceView.as_view() +self.data=[{'id':obj.id,'text':obj.text}forobjinself.objects.all()] +self.view=InstanceView.as_view() +self.slug_based_view=SlugBasedInstanceView.as_view() def test_get_instance_view(self): - """ + """ GET requests to RetrieveUpdateDestroyAPIView should return a single object. """ - request = factory.get('/1') - with self.assertNumQueries(1): - response = self.view(request, pk=1).render() + request = factory.get('/1') + with self.assertNumQueries(1): + response = self.view(request, pk=1).render() assert response.status_code == status.HTTP_200_OK - assert response.data == self.data[0] +assertresponse.data==self.data[0] def test_post_instance_view(self): - """ + """ POST requests to RetrieveUpdateDestroyAPIView should not be allowed """ - data = {'text': 'foobar'} - request = factory.post('/', data, format='json') - with self.assertNumQueries(0): - response = self.view(request).render() + data = {'text': 'foobar'} + request = factory.post('/', data, format='json') + with self.assertNumQueries(0): + response = self.view(request).render() assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED - assert response.data == {"detail": 'Method "POST" not allowed.'} +assertresponse.data=={"detail":'Method "POST" not allowed.'} def test_put_instance_view(self): - """ + """ PUT requests to RetrieveUpdateDestroyAPIView should update an object. """ - data = {'text': 'foobar'} - request = factory.put('/1', data, format='json') - with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT): - response = self.view(request, pk='1').render() + data = {'text': 'foobar'} + request = factory.put('/1', data, format='json') + with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT): + response = self.view(request, pk='1').render() assert response.status_code == status.HTTP_200_OK - assert dict(response.data) == {'id': 1, 'text': 'foobar'} - updated = self.objects.get(id=1) - assert updated.text == 'foobar' +assertdict(response.data)=={'id':1,'text':'foobar'} +updated=self.objects.get(id=1) +assertupdated.text=='foobar' def test_patch_instance_view(self): - """ + """ PATCH requests to RetrieveUpdateDestroyAPIView should update an object. """ - data = {'text': 'foobar'} - request = factory.patch('/1', data, format='json') - - with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT): - response = self.view(request, pk=1).render() + data = {'text': 'foobar'} + request = factory.patch('/1', data, format='json') + with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT): + response = self.view(request, pk=1).render() assert response.status_code == status.HTTP_200_OK - assert response.data == {'id': 1, 'text': 'foobar'} - updated = self.objects.get(id=1) - assert updated.text == 'foobar' +assertresponse.data=={'id':1,'text':'foobar'} +updated=self.objects.get(id=1) +assertupdated.text=='foobar' def test_delete_instance_view(self): - """ + """ DELETE requests to RetrieveUpdateDestroyAPIView should delete an object. """ - request = factory.delete('/1') - with self.assertNumQueries(2): - response = self.view(request, pk=1).render() + request = factory.delete('/1') + with self.assertNumQueries(2): + response = self.view(request, pk=1).render() assert response.status_code == status.HTTP_204_NO_CONTENT - assert response.content == b'' - ids = [obj.id for obj in self.objects.all()] - assert ids == [2, 3] +assertresponse.content==b'' +ids=[obj.idforobjinself.objects.all()] +assertids==[2,3] def test_get_instance_view_incorrect_arg(self): - """ + """ GET requests with an incorrect pk type, should raise 404, not 500. Regression test for #890. """ - request = factory.get('/a') - with self.assertNumQueries(0): - response = self.view(request, pk='a').render() + request = factory.get('/a') + with self.assertNumQueries(0): + response = self.view(request, pk='a').render() assert response.status_code == status.HTTP_404_NOT_FOUND def test_put_cannot_set_id(self): - """ + """ PUT requests to create a new object should not be able to set the id. """ - data = {'id': 999, 'text': 'foobar'} - request = factory.put('/1', data, format='json') - with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT): - response = self.view(request, pk=1).render() + data = {'id': 999, 'text': 'foobar'} + request = factory.put('/1', data, format='json') + with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT): + response = self.view(request, pk=1).render() assert response.status_code == status.HTTP_200_OK - assert response.data == {'id': 1, 'text': 'foobar'} - updated = self.objects.get(id=1) - assert updated.text == 'foobar' +assertresponse.data=={'id':1,'text':'foobar'} +updated=self.objects.get(id=1) +assertupdated.text=='foobar' def test_put_to_deleted_instance(self): - """ + """ PUT requests to RetrieveUpdateDestroyAPIView should return 404 if an object does not currently exist. """ - self.objects.get(id=1).delete() - data = {'text': 'foobar'} - request = factory.put('/1', data, format='json') - with self.assertNumQueries(1): - response = self.view(request, pk=1).render() + self.objects.get(id=1).delete() + data = {'text': 'foobar'} + request = factory.put('/1', data, format='json') + with self.assertNumQueries(1): + response = self.view(request, pk=1).render() assert response.status_code == status.HTTP_404_NOT_FOUND def test_put_to_filtered_out_instance(self): - """ + """ PUT requests to an URL of instance which is filtered out should not be able to create new objects. """ - data = {'text': 'foo'} - filtered_out_pk = BasicModel.objects.filter(text='filtered out')[0].pk - request = factory.put('/{}'.format(filtered_out_pk), data, format='json') - response = self.view(request, pk=filtered_out_pk).render() - assert response.status_code == status.HTTP_404_NOT_FOUND + data = {'text': 'foo'} + filtered_out_pk = BasicModel.objects.filter(text='filtered out')[0].pk + request = factory.put('/{}'.format(filtered_out_pk), data, format='json') + response = self.view(request, pk=filtered_out_pk).render() + assert response.status_code == status.HTTP_404_NOT_FOUND def test_patch_cannot_create_an_object(self): - """ + """ PATCH requests should not be able to create objects. """ - data = {'text': 'foobar'} - request = factory.patch('/999', data, format='json') - with self.assertNumQueries(1): - response = self.view(request, pk=999).render() + data = {'text': 'foobar'} + request = factory.patch('/999', data, format='json') + with self.assertNumQueries(1): + response = self.view(request, pk=999).render() assert response.status_code == status.HTTP_404_NOT_FOUND - assert not self.objects.filter(id=999).exists() +assertnotself.objects.filter(id=999).exists() def test_put_error_instance_view(self): - """ + """ Incorrect PUT requests in HTML should include a form error. """ - data = {'text': 'foobar' * 100} - request = factory.put('/', data, HTTP_ACCEPT='text/html') - response = self.view(request, pk=1).render() - expected_error = 'Ensure this field has no more than 100 characters.' - assert expected_error in response.rendered_content.decode() + data = {'text': 'foobar' * 100} + request = factory.put('/', data, HTTP_ACCEPT='text/html') + response = self.view(request, pk=1).render() + expected_error = 'Ensure this field has no more than 100 characters.' + assert expected_error in response.rendered_content.decode() -class TestFKInstanceView(TestCase): - def setUp(self): - """ +def setUp(self): + """ Create 3 BasicModel instances. """ - items = ['foo', 'bar', 'baz'] - for item in items: - t = ForeignKeyTarget(name=item) - t.save() - ForeignKeySource(name='source_' + item, target=t).save() + items = ['foo', 'bar', 'baz'] + for item in items: + t = ForeignKeyTarget(name=item) + t.save() + ForeignKeySource(name='source_' + item, target=t).save() self.objects = ForeignKeySource.objects - self.data = [ - {'id': obj.id, 'name': obj.name} - for obj in self.objects.all() - ] - self.view = FKInstanceView.as_view() +self.data=[{'id':obj.id,'name':obj.name}forobjinself.objects.all()] +self.view=FKInstanceView.as_view() -class TestOverriddenGetObject(TestCase): - """ +""" Test cases for a RetrieveUpdateDestroyAPIView that does NOT use the queryset/model mechanism but instead overrides get_object() """ - - def setUp(self): - """ +defsetUp(self): + """ Create 3 BasicModel instances. """ - items = ['foo', 'bar', 'baz'] - for item in items: - BasicModel(text=item).save() + items = ['foo', 'bar', 'baz'] + for item in items: + BasicModel(text=item).save() self.objects = BasicModel.objects - self.data = [ - {'id': obj.id, 'text': obj.text} - for obj in self.objects.all() - ] - - class OverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView): - """ +self.data=[{'id':obj.id,'text':obj.text}forobjinself.objects.all()] +classOverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView): + """ Example detail view for override of get_object(). """ - serializer_class = BasicSerializer - - def get_object(self): - pk = int(self.kwargs['pk']) - return get_object_or_404(BasicModel.objects.all(), id=pk) + serializer_class = BasicSerializer + def get_object(self): + pk = int(self.kwargs['pk']) + return get_object_or_404(BasicModel.objects.all(), id=pk) self.view = OverriddenGetObjectView.as_view() def test_overridden_get_object_view(self): - """ + """ GET requests to RetrieveUpdateDestroyAPIView should return a single object. """ - request = factory.get('/1') - with self.assertNumQueries(1): - response = self.view(request, pk=1).render() + request = factory.get('/1') + with self.assertNumQueries(1): + response = self.view(request, pk=1).render() assert response.status_code == status.HTTP_200_OK - assert response.data == self.data[0] +assertresponse.data==self.data[0] # Regression test for #285 @@ -388,23 +366,22 @@ class CommentView(generics.ListCreateAPIView): model = Comment -class TestCreateModelWithAutoNowAddField(TestCase): - def setUp(self): - self.objects = Comment.objects - self.view = CommentView.as_view() +def setUp(self): + self.objects = Comment.objects + self.view = CommentView.as_view() def test_create_model_with_auto_now_add_field(self): - """ + """ Regression test for #285 https://github.com/encode/django-rest-framework/issues/285 """ - data = {'email': 'foobar@example.com', 'content': 'foobar'} - request = factory.post('/', data, format='json') - response = self.view(request).render() - assert response.status_code == status.HTTP_201_CREATED - created = self.objects.get(id=1) - assert created.content == 'foobar' + data = {'email': 'foobar@example.com', 'content': 'foobar'} + request = factory.post('/', data, format='json') + response = self.view(request).render() + assert response.status_code == status.HTTP_201_CREATED + created = self.objects.get(id=1) + assert created.content == 'foobar' # Test for particularly ugly regression with m2m in browsable API @@ -432,15 +409,14 @@ class ExampleView(generics.ListCreateAPIView): queryset = ClassA.objects.all() -class TestM2MBrowsableAPI(TestCase): - def test_m2m_in_browsable_api(self): - """ +def test_m2m_in_browsable_api(self): + """ Test for particularly ugly regression with m2m in browsable API """ - request = factory.get('/', HTTP_ACCEPT='text/html') - view = ExampleView().as_view() - response = view(request).render() - assert response.status_code == status.HTTP_200_OK + request = factory.get('/', HTTP_ACCEPT='text/html') + view = ExampleView().as_view() + response = view(request).render() + assert response.status_code == status.HTTP_200_OK class InclusiveFilterBackend: @@ -476,189 +452,179 @@ class DynamicSerializerView(generics.ListCreateAPIView): return DynamicSerializer -class TestFilterBackendAppliedToViews(TestCase): - def setUp(self): - """ +def setUp(self): + """ Create 3 BasicModel instances to filter on. """ - items = ['foo', 'bar', 'baz'] - for item in items: - BasicModel(text=item).save() + items = ['foo', 'bar', 'baz'] + for item in items: + BasicModel(text=item).save() self.objects = BasicModel.objects - self.data = [ - {'id': obj.id, 'text': obj.text} - for obj in self.objects.all() - ] +self.data=[{'id':obj.id,'text':obj.text}forobjinself.objects.all()] def test_get_root_view_filters_by_name_with_filter_backend(self): - """ + """ GET requests to ListCreateAPIView should return filtered list. """ - root_view = RootView.as_view(filter_backends=(InclusiveFilterBackend,)) - request = factory.get('/') - response = root_view(request).render() - assert response.status_code == status.HTTP_200_OK - assert len(response.data) == 1 - assert response.data == [{'id': 1, 'text': 'foo'}] + root_view = RootView.as_view(filter_backends=(InclusiveFilterBackend,)) + request = factory.get('/') + response = root_view(request).render() + assert response.status_code == status.HTTP_200_OK + assert len(response.data) == 1 + assert response.data == [{'id': 1, 'text': 'foo'}] def test_get_root_view_filters_out_all_models_with_exclusive_filter_backend(self): - """ + """ GET requests to ListCreateAPIView should return empty list when all models are filtered out. """ - root_view = RootView.as_view(filter_backends=(ExclusiveFilterBackend,)) - request = factory.get('/') - response = root_view(request).render() - assert response.status_code == status.HTTP_200_OK - assert response.data == [] + root_view = RootView.as_view(filter_backends=(ExclusiveFilterBackend,)) + request = factory.get('/') + response = root_view(request).render() + assert response.status_code == status.HTTP_200_OK + assert response.data == [] def test_get_instance_view_filters_out_name_with_filter_backend(self): - """ + """ GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out. """ - instance_view = InstanceView.as_view(filter_backends=(ExclusiveFilterBackend,)) - request = factory.get('/1') - response = instance_view(request, pk=1).render() - assert response.status_code == status.HTTP_404_NOT_FOUND - assert response.data == {'detail': 'Not found.'} + instance_view = InstanceView.as_view(filter_backends=(ExclusiveFilterBackend,)) + request = factory.get('/1') + response = instance_view(request, pk=1).render() + assert response.status_code == status.HTTP_404_NOT_FOUND + assert response.data == {'detail': 'Not found.'} def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self): - """ + """ GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded """ - instance_view = InstanceView.as_view(filter_backends=(InclusiveFilterBackend,)) - request = factory.get('/1') - response = instance_view(request, pk=1).render() - assert response.status_code == status.HTTP_200_OK - assert response.data == {'id': 1, 'text': 'foo'} + instance_view = InstanceView.as_view(filter_backends=(InclusiveFilterBackend,)) + request = factory.get('/1') + response = instance_view(request, pk=1).render() + assert response.status_code == status.HTTP_200_OK + assert response.data == {'id': 1, 'text': 'foo'} def test_dynamic_serializer_form_in_browsable_api(self): - """ + """ GET requests to ListCreateAPIView should return filtered list. """ - view = DynamicSerializerView.as_view() - request = factory.get('/') - response = view(request).render() - content = response.content.decode() - assert 'field_b' in content - assert 'field_a' not in content + view = DynamicSerializerView.as_view() + request = factory.get('/') + response = view(request).render() + content = response.content.decode() + assert 'field_b' in content + assert 'field_a' not in content -class TestGuardedQueryset(TestCase): - def test_guarded_queryset(self): - class QuerysetAccessError(generics.ListAPIView): - queryset = BasicModel.objects.all() - - def get(self, request): - return Response(list(self.queryset)) +def test_guarded_queryset(self): + class QuerysetAccessError(generics.ListAPIView): + queryset = BasicModel.objects.all() + def get(self, request): + return Response(list(self.queryset)) view = QuerysetAccessError.as_view() - request = factory.get('/') - with pytest.raises(RuntimeError): - view(request).render() +request=factory.get('/') +withpytest.raises(RuntimeError): + view(request).render() -class ApiViewsTests(TestCase): - def test_create_api_view_post(self): - class MockCreateApiView(generics.CreateAPIView): - def create(self, request, *args, **kwargs): - self.called = True - self.call_args = (request, args, kwargs) +def test_create_api_view_post(self): + class MockCreateApiView(generics.CreateAPIView): + def create(self, request, *args, **kwargs): + self.called = True + self.call_args = (request, args, kwargs) view = MockCreateApiView() - data = ('test request', ('test arg',), {'test_kwarg': 'test'}) - view.post('test request', 'test arg', test_kwarg='test') - assert view.called is True - assert view.call_args == data +data=('test request',('test arg',),{'test_kwarg':'test'}) +view.post('test request','test arg',test_kwarg='test') +assertview.calledisTrue +assertview.call_args==data def test_destroy_api_view_delete(self): - class MockDestroyApiView(generics.DestroyAPIView): - def destroy(self, request, *args, **kwargs): - self.called = True - self.call_args = (request, args, kwargs) + class MockDestroyApiView(generics.DestroyAPIView): + def destroy(self, request, *args, **kwargs): + self.called = True + self.call_args = (request, args, kwargs) view = MockDestroyApiView() - data = ('test request', ('test arg',), {'test_kwarg': 'test'}) - view.delete('test request', 'test arg', test_kwarg='test') - assert view.called is True - assert view.call_args == data +data=('test request',('test arg',),{'test_kwarg':'test'}) +view.delete('test request','test arg',test_kwarg='test') +assertview.calledisTrue +assertview.call_args==data def test_update_api_view_partial_update(self): - class MockUpdateApiView(generics.UpdateAPIView): - def partial_update(self, request, *args, **kwargs): - self.called = True - self.call_args = (request, args, kwargs) + class MockUpdateApiView(generics.UpdateAPIView): + def partial_update(self, request, *args, **kwargs): + self.called = True + self.call_args = (request, args, kwargs) view = MockUpdateApiView() - data = ('test request', ('test arg',), {'test_kwarg': 'test'}) - view.patch('test request', 'test arg', test_kwarg='test') - assert view.called is True - assert view.call_args == data +data=('test request',('test arg',),{'test_kwarg':'test'}) +view.patch('test request','test arg',test_kwarg='test') +assertview.calledisTrue +assertview.call_args==data def test_retrieve_update_api_view_get(self): - class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView): - def retrieve(self, request, *args, **kwargs): - self.called = True - self.call_args = (request, args, kwargs) + class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView): + def retrieve(self, request, *args, **kwargs): + self.called = True + self.call_args = (request, args, kwargs) view = MockRetrieveUpdateApiView() - data = ('test request', ('test arg',), {'test_kwarg': 'test'}) - view.get('test request', 'test arg', test_kwarg='test') - assert view.called is True - assert view.call_args == data +data=('test request',('test arg',),{'test_kwarg':'test'}) +view.get('test request','test arg',test_kwarg='test') +assertview.calledisTrue +assertview.call_args==data def test_retrieve_update_api_view_put(self): - class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView): - def update(self, request, *args, **kwargs): - self.called = True - self.call_args = (request, args, kwargs) + class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView): + def update(self, request, *args, **kwargs): + self.called = True + self.call_args = (request, args, kwargs) view = MockRetrieveUpdateApiView() - data = ('test request', ('test arg',), {'test_kwarg': 'test'}) - view.put('test request', 'test arg', test_kwarg='test') - assert view.called is True - assert view.call_args == data +data=('test request',('test arg',),{'test_kwarg':'test'}) +view.put('test request','test arg',test_kwarg='test') +assertview.calledisTrue +assertview.call_args==data def test_retrieve_update_api_view_patch(self): - class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView): - def partial_update(self, request, *args, **kwargs): - self.called = True - self.call_args = (request, args, kwargs) + class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView): + def partial_update(self, request, *args, **kwargs): + self.called = True + self.call_args = (request, args, kwargs) view = MockRetrieveUpdateApiView() - data = ('test request', ('test arg',), {'test_kwarg': 'test'}) - view.patch('test request', 'test arg', test_kwarg='test') - assert view.called is True - assert view.call_args == data +data=('test request',('test arg',),{'test_kwarg':'test'}) +view.patch('test request','test arg',test_kwarg='test') +assertview.calledisTrue +assertview.call_args==data def test_retrieve_destroy_api_view_get(self): - class MockRetrieveDestroyUApiView(generics.RetrieveDestroyAPIView): - def retrieve(self, request, *args, **kwargs): - self.called = True - self.call_args = (request, args, kwargs) + class MockRetrieveDestroyUApiView(generics.RetrieveDestroyAPIView): + def retrieve(self, request, *args, **kwargs): + self.called = True + self.call_args = (request, args, kwargs) view = MockRetrieveDestroyUApiView() - data = ('test request', ('test arg',), {'test_kwarg': 'test'}) - view.get('test request', 'test arg', test_kwarg='test') - assert view.called is True - assert view.call_args == data +data=('test request',('test arg',),{'test_kwarg':'test'}) +view.get('test request','test arg',test_kwarg='test') +assertview.calledisTrue +assertview.call_args==data def test_retrieve_destroy_api_view_delete(self): - class MockRetrieveDestroyUApiView(generics.RetrieveDestroyAPIView): - def destroy(self, request, *args, **kwargs): - self.called = True - self.call_args = (request, args, kwargs) + class MockRetrieveDestroyUApiView(generics.RetrieveDestroyAPIView): + def destroy(self, request, *args, **kwargs): + self.called = True + self.call_args = (request, args, kwargs) view = MockRetrieveDestroyUApiView() - data = ('test request', ('test arg',), {'test_kwarg': 'test'}) - view.delete('test request', 'test arg', test_kwarg='test') - assert view.called is True - assert view.call_args == data +data=('test request',('test arg',),{'test_kwarg':'test'}) +view.delete('test request','test arg',test_kwarg='test') +assertview.calledisTrue +assertview.call_args==data -class GetObjectOr404Tests(TestCase): - def setUp(self): - super().setUp() - self.uuid_object = UUIDForeignKeyTarget.objects.create(name='bar') +def setUp(self): + super().setUp() + self.uuid_object = UUIDForeignKeyTarget.objects.create(name='bar') def test_get_object_or_404_with_valid_uuid(self): - obj = generics.get_object_or_404( - UUIDForeignKeyTarget, pk=self.uuid_object.pk - ) - assert obj == self.uuid_object + obj = generics.get_object_or_404( UUIDForeignKeyTarget, pk=self.uuid_object.pk ) + assert obj == self.uuid_object def test_get_object_or_404_with_invalid_string_for_uuid(self): - with pytest.raises(Http404): - generics.get_object_or_404(UUIDForeignKeyTarget, pk='not-a-uuid') + with pytest.raises(Http404): + generics.get_object_or_404(UUIDForeignKeyTarget, pk='not-a-uuid') diff --git a/tests/test_htmlrenderer.py b/tests/test_htmlrenderer.py index e31a9ced5..ba4e112c9 100644 --- a/tests/test_htmlrenderer.py +++ b/tests/test_htmlrenderer.py @@ -42,122 +42,113 @@ urlpatterns = [ @override_settings(ROOT_URLCONF='tests.test_htmlrenderer') -class TemplateHTMLRendererTests(TestCase): - def setUp(self): - class MockResponse: - template_name = None +def setUp(self): + class MockResponse: + template_name = None self.mock_response = MockResponse() - self._monkey_patch_get_template() +self._monkey_patch_get_template() def _monkey_patch_get_template(self): - """ + """ Monkeypatch get_template """ - self.get_template = django.template.loader.get_template - - def get_template(template_name, dirs=None): - if template_name == 'example.html': - return engines['django'].from_string("example: {{ object }}") + self.get_template = django.template.loader.get_template + def get_template(template_name, dirs=None): + if template_name == 'example.html': + return engines['django'].from_string("example: {{ object }}") raise TemplateDoesNotExist(template_name) def select_template(template_name_list, dirs=None, using=None): - if template_name_list == ['example.html']: - return engines['django'].from_string("example: {{ object }}") + if template_name_list == ['example.html']: + return engines['django'].from_string("example: {{ object }}") raise TemplateDoesNotExist(template_name_list[0]) django.template.loader.get_template = get_template - django.template.loader.select_template = select_template +django.template.loader.select_template=select_template def tearDown(self): - """ + """ Revert monkeypatching """ - django.template.loader.get_template = self.get_template + django.template.loader.get_template = self.get_template def test_simple_html_view(self): - response = self.client.get('/') - self.assertContains(response, "example: foobar") - self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') + response = self.client.get('/') + self.assertContains(response, "example: foobar") + assert response['Content-Type'] == 'text/html; charset=utf-8' def test_not_found_html_view(self): - response = self.client.get('/not_found') - self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - self.assertEqual(response.content, b"404 Not Found") - self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') + response = self.client.get('/not_found') + assert response.status_code == status.HTTP_404_NOT_FOUND + assert response.content == b"404 Not Found" + assert response['Content-Type'] == 'text/html; charset=utf-8' def test_permission_denied_html_view(self): - response = self.client.get('/permission_denied') - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - self.assertEqual(response.content, b"403 Forbidden") - self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') + response = self.client.get('/permission_denied') + assert response.status_code == status.HTTP_403_FORBIDDEN + assert response.content == b"403 Forbidden" + assert response['Content-Type'] == 'text/html; charset=utf-8' # 2 tests below are based on order of if statements in corresponding method # of TemplateHTMLRenderer def test_get_template_names_returns_own_template_name(self): - renderer = TemplateHTMLRenderer() - renderer.template_name = 'test_template' - template_name = renderer.get_template_names(self.mock_response, view={}) - assert template_name == ['test_template'] + renderer = TemplateHTMLRenderer() + renderer.template_name = 'test_template' + template_name = renderer.get_template_names(self.mock_response, view={}) + assert template_name == ['test_template'] def test_get_template_names_returns_view_template_name(self): - renderer = TemplateHTMLRenderer() - - class MockResponse: - template_name = None + renderer = TemplateHTMLRenderer() + class MockResponse: + template_name = None class MockView: - def get_template_names(self): - return ['template from get_template_names method'] + def get_template_names(self): + return ['template from get_template_names method'] class MockView2: - template_name = 'template from template_name attribute' + template_name = 'template from template_name attribute' - template_name = renderer.get_template_names(self.mock_response, - MockView()) - assert template_name == ['template from get_template_names method'] - - template_name = renderer.get_template_names(self.mock_response, - MockView2()) - assert template_name == ['template from template_name attribute'] + template_name = renderer.get_template_names(self.mock_response,MockView()) +asserttemplate_name==['template from get_template_names method'] +template_name=renderer.get_template_names(self.mock_response,MockView2()) +asserttemplate_name==['template from template_name attribute'] def test_get_template_names_raises_error_if_no_template_found(self): - renderer = TemplateHTMLRenderer() - with pytest.raises(ImproperlyConfigured): - renderer.get_template_names(self.mock_response, view=object()) + renderer = TemplateHTMLRenderer() + with pytest.raises(ImproperlyConfigured): + renderer.get_template_names(self.mock_response, view=object()) @override_settings(ROOT_URLCONF='tests.test_htmlrenderer') -class TemplateHTMLRendererExceptionTests(TestCase): - def setUp(self): - """ +def setUp(self): + """ Monkeypatch get_template """ - self.get_template = django.template.loader.get_template - - def get_template(template_name): - if template_name == '404.html': - return engines['django'].from_string("404: {{ detail }}") + self.get_template = django.template.loader.get_template + def get_template(template_name): + if template_name == '404.html': + return engines['django'].from_string("404: {{ detail }}") if template_name == '403.html': - return engines['django'].from_string("403: {{ detail }}") + return engines['django'].from_string("403: {{ detail }}") raise TemplateDoesNotExist(template_name) django.template.loader.get_template = get_template def tearDown(self): - """ + """ Revert monkeypatching """ - django.template.loader.get_template = self.get_template + django.template.loader.get_template = self.get_template def test_not_found_html_view_with_template(self): - response = self.client.get('/not_found') - self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - self.assertTrue(response.content in ( - b"404: Not found", b"404 Not Found")) - self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') + response = self.client.get('/not_found') + assert response.status_code == status.HTTP_404_NOT_FOUND + assert response.content in ( b"404: Not found", b"404 Not Found") + assert response['Content-Type'] == 'text/html; charset=utf-8' def test_permission_denied_html_view_with_template(self): - response = self.client.get('/permission_denied') - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - self.assertTrue(response.content in (b"403: Permission denied", b"403 Forbidden")) - self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') + response = self.client.get('/permission_denied') + assert response.status_code == status.HTTP_403_FORBIDDEN + assert response.content in (b"403: Permission denied", b"403 Forbidden") + assert response['Content-Type'] == 'text/html; charset=utf-8' diff --git a/tests/test_lazy_hyperlinks.py b/tests/test_lazy_hyperlinks.py index cf3ee735f..a4287d8f0 100644 --- a/tests/test_lazy_hyperlinks.py +++ b/tests/test_lazy_hyperlinks.py @@ -34,16 +34,15 @@ urlpatterns = [ @override_settings(ROOT_URLCONF='tests.test_lazy_hyperlinks') -class TestLazyHyperlinkNames(TestCase): - def setUp(self): - self.example = Example.objects.create(text='foo') +def setUp(self): + self.example = Example.objects.create(text='foo') def test_lazy_hyperlink_names(self): - global str_called - context = {'request': None} - serializer = ExampleSerializer(self.example, context=context) - JSONRenderer().render(serializer.data) - assert not str_called - hyperlink_string = format_value(serializer.data['url']) - assert hyperlink_string == 'An example' - assert str_called + global str_called + context = {'request': None} + serializer = ExampleSerializer(self.example, context=context) + JSONRenderer().render(serializer.data) + assert not str_called + hyperlink_string = format_value(serializer.data['url']) + assert hyperlink_string == 'An example' + assert str_called diff --git a/tests/test_metadata.py b/tests/test_metadata.py index e1a1fd352..d67272f1f 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -308,98 +308,49 @@ class TestMetadata: assert options.get_serializer_info(list_serializer) == options.get_serializer_info(child_serializer) -class TestSimpleMetadataFieldInfo(TestCase): - def test_null_boolean_field_info_type(self): - options = metadata.SimpleMetadata() - field_info = options.get_field_info(serializers.NullBooleanField()) - assert field_info['type'] == 'boolean' +def test_null_boolean_field_info_type(self): + options = metadata.SimpleMetadata() + field_info = options.get_field_info(serializers.NullBooleanField()) + assert field_info['type'] == 'boolean' def test_related_field_choices(self): - options = metadata.SimpleMetadata() - BasicModel.objects.create() - with self.assertNumQueries(0): - field_info = options.get_field_info( - serializers.RelatedField(queryset=BasicModel.objects.all()) - ) + options = metadata.SimpleMetadata() + BasicModel.objects.create() + with self.assertNumQueries(0): + field_info = options.get_field_info( serializers.RelatedField(queryset=BasicModel.objects.all()) ) assert 'choices' not in field_info -class TestModelSerializerMetadata(TestCase): - def test_read_only_primary_key_related_field(self): - """ +def test_read_only_primary_key_related_field(self): + """ On generic views OPTIONS should return an 'actions' key with metadata on the fields that may be supplied to PUT and POST requests. It should not fail when a read_only PrimaryKeyRelatedField is present """ - class Parent(models.Model): - integer_field = models.IntegerField(validators=[MinValueValidator(1), MaxValueValidator(1000)]) - children = models.ManyToManyField('Child') - name = models.CharField(max_length=100, blank=True, null=True) + class Parent(models.Model): + integer_field = models.IntegerField(validators=[MinValueValidator(1), MaxValueValidator(1000)]) + children = models.ManyToManyField('Child') + name = models.CharField(max_length=100, blank=True, null=True) class Child(models.Model): - name = models.CharField(max_length=100) + name = models.CharField(max_length=100) class ExampleSerializer(serializers.ModelSerializer): - children = serializers.PrimaryKeyRelatedField(read_only=True, many=True) - - class Meta: - model = Parent - fields = '__all__' + children = serializers.PrimaryKeyRelatedField(read_only=True, many=True) + class Meta: + model = Parent + fields = '__all__' class ExampleView(views.APIView): - """Example view.""" - def post(self, request): - pass + """Example view.""" + def post(self, request): + pass def get_serializer(self): - return ExampleSerializer() + return ExampleSerializer() view = ExampleView.as_view() - response = view(request=request) - expected = { - 'name': 'Example', - 'description': 'Example view.', - 'renders': [ - 'application/json', - 'text/html' - ], - 'parses': [ - 'application/json', - 'application/x-www-form-urlencoded', - 'multipart/form-data' - ], - 'actions': { - 'POST': { - 'id': { - 'type': 'integer', - 'required': False, - 'read_only': True, - 'label': 'ID' - }, - 'children': { - 'type': 'field', - 'required': False, - 'read_only': True, - 'label': 'Children' - }, - 'integer_field': { - 'type': 'integer', - 'required': True, - 'read_only': False, - 'label': 'Integer field', - 'min_value': 1, - 'max_value': 1000 - }, - 'name': { - 'type': 'string', - 'required': False, - 'read_only': False, - 'label': 'Name', - 'max_length': 100 - } - } - } - } - - assert response.status_code == status.HTTP_200_OK - assert response.data == expected +response=view(request=request) +expected={'name':'Example','description':'Example view.','renders':['application/json','text/html'],'parses':['application/json','application/x-www-form-urlencoded','multipart/form-data'],'actions':{'POST':{'id':{'type':'integer','required':False,'read_only':True,'label':'ID'},'children':{'type':'field','required':False,'read_only':True,'label':'Children'},'integer_field':{'type':'integer','required':True,'read_only':False,'label':'Integer field','min_value':1,'max_value':1000},'name':{'type':'string','required':False,'read_only':False,'label':'Name','max_length':100}}}} +assertresponse.status_code==status.HTTP_200_OK +assertresponse.data==expected diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py index 413d7885d..cde068f03 100644 --- a/tests/test_model_serializer.py +++ b/tests/test_model_serializer.py @@ -110,59 +110,48 @@ class UniqueChoiceModel(models.Model): name = models.CharField(max_length=254, unique=True, choices=CHOICES) -class TestModelSerializer(TestCase): - def test_create_method(self): - class TestSerializer(serializers.ModelSerializer): - non_model_field = serializers.CharField() +def test_create_method(self): + class TestSerializer(serializers.ModelSerializer): + non_model_field = serializers.CharField() + class Meta: + model = OneFieldModel + fields = ('char_field', 'non_model_field') - class Meta: - model = OneFieldModel - fields = ('char_field', 'non_model_field') - - serializer = TestSerializer(data={ - 'char_field': 'foo', - 'non_model_field': 'bar', - }) - serializer.is_valid() - - msginitial = 'Got a `TypeError` when calling `OneFieldModel.objects.create()`.' - with self.assertRaisesMessage(TypeError, msginitial): - serializer.save() + serializer = TestSerializer(data={'char_field':'foo','non_model_field':'bar',}) +serializer.is_valid() +msginitial='Got a `TypeError` when calling `OneFieldModel.objects.create()`.' +withself.assertRaisesMessage(TypeError,msginitial): + serializer.save() def test_abstract_model(self): - """ + """ Test that trying to use ModelSerializer with Abstract Models throws a ValueError exception. """ - class AbstractModel(models.Model): - afield = models.CharField(max_length=255) - - class Meta: - abstract = True + class AbstractModel(models.Model): + afield = models.CharField(max_length=255) + class Meta: + abstract = True class TestSerializer(serializers.ModelSerializer): - class Meta: - model = AbstractModel - fields = ('afield',) + class Meta: + model = AbstractModel + fields = ('afield',) - serializer = TestSerializer(data={ - 'afield': 'foo', - }) - - msginitial = 'Cannot use ModelSerializer with Abstract Models.' - with self.assertRaisesMessage(ValueError, msginitial): - serializer.is_valid() + serializer = TestSerializer(data={'afield':'foo',}) +msginitial='Cannot use ModelSerializer with Abstract Models.' +withself.assertRaisesMessage(ValueError,msginitial): + serializer.is_valid() -class TestRegularFieldMappings(TestCase): - def test_regular_fields(self): - """ +def test_regular_fields(self): + """ Model fields should map to their equivalent serializer fields. """ - class TestSerializer(serializers.ModelSerializer): - class Meta: - model = RegularFieldsModel - fields = '__all__' + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = RegularFieldsModel + fields = '__all__' expected = dedent(""" TestSerializer(): @@ -189,14 +178,13 @@ class TestRegularFieldMappings(TestCase): custom_field = ModelField(model_field=) file_path_field = FilePathField(path='/tmp/') """) - - self.assertEqual(repr(TestSerializer()), expected) +assertrepr(TestSerializer())==expected def test_field_options(self): - class TestSerializer(serializers.ModelSerializer): - class Meta: - model = FieldOptionsModel - fields = '__all__' + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = FieldOptionsModel + fields = '__all__' expected = dedent(""" TestSerializer(): @@ -209,260 +197,248 @@ class TestRegularFieldMappings(TestCase): descriptive_field = IntegerField(help_text='Some help text', label='A label') choices_field = ChoiceField(choices=(('red', 'Red'), ('blue', 'Blue'), ('green', 'Green'))) """) - self.assertEqual(repr(TestSerializer()), expected) +assertrepr(TestSerializer())==expected # merge this into test_regular_fields / RegularFieldsModel when # Django 2.1 is the minimum supported version @pytest.mark.skipif(django.VERSION < (2, 1), reason='Django version < 2.1') - def test_nullable_boolean_field(self): - class NullableBooleanModel(models.Model): - field = models.BooleanField(null=True, default=False) +deftest_nullable_boolean_field(self): + class NullableBooleanModel(models.Model): + field = models.BooleanField(null=True, default=False) class NullableBooleanSerializer(serializers.ModelSerializer): - class Meta: - model = NullableBooleanModel - fields = ['field'] + class Meta: + model = NullableBooleanModel + fields = ['field'] expected = dedent(""" NullableBooleanSerializer(): field = BooleanField(allow_null=True, required=False) """) - - self.assertEqual(repr(NullableBooleanSerializer()), expected) +assertrepr(NullableBooleanSerializer())==expected def test_method_field(self): - """ + """ Properties and methods on the model should be allowed as `Meta.fields` values, and should map to `ReadOnlyField`. """ - class TestSerializer(serializers.ModelSerializer): - class Meta: - model = RegularFieldsModel - fields = ('auto_field', 'method') + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = RegularFieldsModel + fields = ('auto_field', 'method') expected = dedent(""" TestSerializer(): auto_field = IntegerField(read_only=True) method = ReadOnlyField() """) - self.assertEqual(repr(TestSerializer()), expected) +assertrepr(TestSerializer())==expected def test_pk_fields(self): - """ + """ Both `pk` and the actual primary key name are valid in `Meta.fields`. """ - class TestSerializer(serializers.ModelSerializer): - class Meta: - model = RegularFieldsModel - fields = ('pk', 'auto_field') + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = RegularFieldsModel + fields = ('pk', 'auto_field') expected = dedent(""" TestSerializer(): pk = IntegerField(label='Auto field', read_only=True) auto_field = IntegerField(read_only=True) """) - self.assertEqual(repr(TestSerializer()), expected) +assertrepr(TestSerializer())==expected def test_extra_field_kwargs(self): - """ + """ Ensure `extra_kwargs` are passed to generated fields. """ - class TestSerializer(serializers.ModelSerializer): - class Meta: - model = RegularFieldsModel - fields = ('auto_field', 'char_field') - extra_kwargs = {'char_field': {'default': 'extra'}} + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = RegularFieldsModel + fields = ('auto_field', 'char_field') + extra_kwargs = {'char_field': {'default': 'extra'}} expected = dedent(""" TestSerializer(): auto_field = IntegerField(read_only=True) char_field = CharField(default='extra', max_length=100) """) - self.assertEqual(repr(TestSerializer()), expected) +assertrepr(TestSerializer())==expected def test_extra_field_kwargs_required(self): - """ + """ Ensure `extra_kwargs` are passed to generated fields. """ - class TestSerializer(serializers.ModelSerializer): - class Meta: - model = RegularFieldsModel - fields = ('auto_field', 'char_field') - extra_kwargs = {'auto_field': {'required': False, 'read_only': False}} + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = RegularFieldsModel + fields = ('auto_field', 'char_field') + extra_kwargs = {'auto_field': {'required': False, 'read_only': False}} expected = dedent(""" TestSerializer(): auto_field = IntegerField(read_only=False, required=False) char_field = CharField(max_length=100) """) - self.assertEqual(repr(TestSerializer()), expected) +assertrepr(TestSerializer())==expected def test_invalid_field(self): - """ + """ Field names that do not map to a model field or relationship should raise a configuration errror. """ - class TestSerializer(serializers.ModelSerializer): - class Meta: - model = RegularFieldsModel - fields = ('auto_field', 'invalid') + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = RegularFieldsModel + fields = ('auto_field', 'invalid') expected = 'Field name `invalid` is not valid for model `RegularFieldsModel`.' - with self.assertRaisesMessage(ImproperlyConfigured, expected): - TestSerializer().fields +withself.assertRaisesMessage(ImproperlyConfigured,expected): + TestSerializer().fields def test_missing_field(self): - """ + """ Fields that have been declared on the serializer class must be included in the `Meta.fields` if it exists. """ - class TestSerializer(serializers.ModelSerializer): - missing = serializers.ReadOnlyField() + class TestSerializer(serializers.ModelSerializer): + missing = serializers.ReadOnlyField() + class Meta: + model = RegularFieldsModel + fields = ('auto_field',) - class Meta: - model = RegularFieldsModel - fields = ('auto_field',) - - expected = ( - "The field 'missing' was declared on serializer TestSerializer, " - "but has not been included in the 'fields' option." - ) - with self.assertRaisesMessage(AssertionError, expected): - TestSerializer().fields + expected = ("The field 'missing' was declared on serializer TestSerializer, ""but has not been included in the 'fields' option.") +withself.assertRaisesMessage(AssertionError,expected): + TestSerializer().fields def test_missing_superclass_field(self): - """ + """ Fields that have been declared on a parent of the serializer class may be excluded from the `Meta.fields` option. """ - class TestSerializer(serializers.ModelSerializer): - missing = serializers.ReadOnlyField() + class TestSerializer(serializers.ModelSerializer): + missing = serializers.ReadOnlyField() class ChildSerializer(TestSerializer): - class Meta: - model = RegularFieldsModel - fields = ('auto_field',) + class Meta: + model = RegularFieldsModel + fields = ('auto_field',) ChildSerializer().fields def test_choices_with_nonstandard_args(self): - class ExampleSerializer(serializers.ModelSerializer): - class Meta: - model = ChoicesModel - fields = '__all__' + class ExampleSerializer(serializers.ModelSerializer): + class Meta: + model = ChoicesModel + fields = '__all__' ExampleSerializer() -class TestDurationFieldMapping(TestCase): - def test_duration_field(self): - class DurationFieldModel(models.Model): - """ +def test_duration_field(self): + class DurationFieldModel(models.Model): + """ A model that defines DurationField. """ - duration_field = models.DurationField() + duration_field = models.DurationField() class TestSerializer(serializers.ModelSerializer): - class Meta: - model = DurationFieldModel - fields = '__all__' + class Meta: + model = DurationFieldModel + fields = '__all__' expected = dedent(""" TestSerializer(): id = IntegerField(label='ID', read_only=True) duration_field = DurationField() """) - self.assertEqual(repr(TestSerializer()), expected) +assertrepr(TestSerializer())==expected def test_duration_field_with_validators(self): - class ValidatedDurationFieldModel(models.Model): - """ + class ValidatedDurationFieldModel(models.Model): + """ A model that defines DurationField with validators. """ - duration_field = models.DurationField( - validators=[MinValueValidator(datetime.timedelta(days=1)), MaxValueValidator(datetime.timedelta(days=3))] - ) + duration_field = models.DurationField( validators=[MinValueValidator(datetime.timedelta(days=1)), MaxValueValidator(datetime.timedelta(days=3))] ) class TestSerializer(serializers.ModelSerializer): - class Meta: - model = ValidatedDurationFieldModel - fields = '__all__' + class Meta: + model = ValidatedDurationFieldModel + fields = '__all__' expected = dedent(""" TestSerializer(): id = IntegerField(label='ID', read_only=True) duration_field = DurationField(max_value=datetime.timedelta(3), min_value=datetime.timedelta(1)) - """) if sys.version_info < (3, 7) else dedent(""" + """)ifsys.version_info<(3,7)elsededent(""" TestSerializer(): id = IntegerField(label='ID', read_only=True) duration_field = DurationField(max_value=datetime.timedelta(days=3), min_value=datetime.timedelta(days=1)) """) - self.assertEqual(repr(TestSerializer()), expected) +assertrepr(TestSerializer())==expected -class TestGenericIPAddressFieldValidation(TestCase): - def test_ip_address_validation(self): - class IPAddressFieldModel(models.Model): - address = models.GenericIPAddressField() +def test_ip_address_validation(self): + class IPAddressFieldModel(models.Model): + address = models.GenericIPAddressField() class TestSerializer(serializers.ModelSerializer): - class Meta: - model = IPAddressFieldModel - fields = '__all__' + class Meta: + model = IPAddressFieldModel + fields = '__all__' s = TestSerializer(data={'address': 'not an ip address'}) - self.assertFalse(s.is_valid()) - self.assertEqual(1, len(s.errors['address']), - 'Unexpected number of validation errors: ' - '{}'.format(s.errors)) +assertnots.is_valid() +assert1==len(s.errors['address']),'Unexpected number of validation errors: ''{}'.format(s.errors) @pytest.mark.skipif('not postgres_fields') -class TestPosgresFieldsMapping(TestCase): - def test_hstore_field(self): - class HStoreFieldModel(models.Model): - hstore_field = postgres_fields.HStoreField() +def test_hstore_field(self): + class HStoreFieldModel(models.Model): + hstore_field = postgres_fields.HStoreField() class TestSerializer(serializers.ModelSerializer): - class Meta: - model = HStoreFieldModel - fields = ['hstore_field'] + class Meta: + model = HStoreFieldModel + fields = ['hstore_field'] expected = dedent(""" TestSerializer(): hstore_field = HStoreField() """) - self.assertEqual(repr(TestSerializer()), expected) +assertrepr(TestSerializer())==expected def test_array_field(self): - class ArrayFieldModel(models.Model): - array_field = postgres_fields.ArrayField(base_field=models.CharField()) + class ArrayFieldModel(models.Model): + array_field = postgres_fields.ArrayField(base_field=models.CharField()) class TestSerializer(serializers.ModelSerializer): - class Meta: - model = ArrayFieldModel - fields = ['array_field'] + class Meta: + model = ArrayFieldModel + fields = ['array_field'] expected = dedent(""" TestSerializer(): array_field = ListField(child=CharField(label='Array field', validators=[])) """) - self.assertEqual(repr(TestSerializer()), expected) +assertrepr(TestSerializer())==expected def test_json_field(self): - class JSONFieldModel(models.Model): - json_field = postgres_fields.JSONField() + class JSONFieldModel(models.Model): + json_field = postgres_fields.JSONField() class TestSerializer(serializers.ModelSerializer): - class Meta: - model = JSONFieldModel - fields = ['json_field'] + class Meta: + model = JSONFieldModel + fields = ['json_field'] expected = dedent(""" TestSerializer(): json_field = JSONField(style={'base_template': 'textarea.html'}) """) - self.assertEqual(repr(TestSerializer()), expected) +assertrepr(TestSerializer())==expected # Tests for relational field mappings. @@ -505,12 +481,11 @@ class UniqueTogetherModel(models.Model): unique_together = ("foreign_key", "one_to_one") -class TestRelationalFieldMappings(TestCase): - def test_pk_relations(self): - class TestSerializer(serializers.ModelSerializer): - class Meta: - model = RelationalModel - fields = '__all__' +def test_pk_relations(self): + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = RelationalModel + fields = '__all__' expected = dedent(""" TestSerializer(): @@ -520,14 +495,14 @@ class TestRelationalFieldMappings(TestCase): many_to_many = PrimaryKeyRelatedField(allow_empty=False, many=True, queryset=ManyToManyTargetModel.objects.all()) through = PrimaryKeyRelatedField(many=True, read_only=True) """) - self.assertEqual(repr(TestSerializer()), expected) +assertrepr(TestSerializer())==expected def test_nested_relations(self): - class TestSerializer(serializers.ModelSerializer): - class Meta: - model = RelationalModel - depth = 1 - fields = '__all__' + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = RelationalModel + depth = 1 + fields = '__all__' expected = dedent(""" TestSerializer(): @@ -545,13 +520,13 @@ class TestRelationalFieldMappings(TestCase): id = IntegerField(label='ID', read_only=True) name = CharField(max_length=100) """) - self.assertEqual(repr(TestSerializer()), expected) +assertrepr(TestSerializer())==expected def test_hyperlinked_relations(self): - class TestSerializer(serializers.HyperlinkedModelSerializer): - class Meta: - model = RelationalModel - fields = '__all__' + class TestSerializer(serializers.HyperlinkedModelSerializer): + class Meta: + model = RelationalModel + fields = '__all__' expected = dedent(""" TestSerializer(): @@ -561,14 +536,14 @@ class TestRelationalFieldMappings(TestCase): many_to_many = HyperlinkedRelatedField(allow_empty=False, many=True, queryset=ManyToManyTargetModel.objects.all(), view_name='manytomanytargetmodel-detail') through = HyperlinkedRelatedField(many=True, read_only=True, view_name='throughtargetmodel-detail') """) - self.assertEqual(repr(TestSerializer()), expected) +assertrepr(TestSerializer())==expected def test_nested_hyperlinked_relations(self): - class TestSerializer(serializers.HyperlinkedModelSerializer): - class Meta: - model = RelationalModel - depth = 1 - fields = '__all__' + class TestSerializer(serializers.HyperlinkedModelSerializer): + class Meta: + model = RelationalModel + depth = 1 + fields = '__all__' expected = dedent(""" TestSerializer(): @@ -586,19 +561,15 @@ class TestRelationalFieldMappings(TestCase): url = HyperlinkedIdentityField(view_name='throughtargetmodel-detail') name = CharField(max_length=100) """) - self.assertEqual(repr(TestSerializer()), expected) +assertrepr(TestSerializer())==expected def test_nested_hyperlinked_relations_starred_source(self): - class TestSerializer(serializers.HyperlinkedModelSerializer): - class Meta: - model = RelationalModel - depth = 1 - fields = '__all__' - - extra_kwargs = { - 'url': { - 'source': '*', - }} + class TestSerializer(serializers.HyperlinkedModelSerializer): + class Meta: + model = RelationalModel + depth = 1 + fields = '__all__' + extra_kwargs = { 'url': { 'source': '*', }} expected = dedent(""" TestSerializer(): @@ -616,15 +587,15 @@ class TestRelationalFieldMappings(TestCase): url = HyperlinkedIdentityField(view_name='throughtargetmodel-detail') name = CharField(max_length=100) """) - self.maxDiff = None - self.assertEqual(repr(TestSerializer()), expected) +self.maxDiff=None +assertrepr(TestSerializer())==expected def test_nested_unique_together_relations(self): - class TestSerializer(serializers.HyperlinkedModelSerializer): - class Meta: - model = UniqueTogetherModel - depth = 1 - fields = '__all__' + class TestSerializer(serializers.HyperlinkedModelSerializer): + class Meta: + model = UniqueTogetherModel + depth = 1 + fields = '__all__' expected = dedent(""" TestSerializer(): @@ -636,13 +607,13 @@ class TestRelationalFieldMappings(TestCase): url = HyperlinkedIdentityField(view_name='onetoonetargetmodel-detail') name = CharField(max_length=100) """) - self.assertEqual(repr(TestSerializer()), expected) +assertrepr(TestSerializer())==expected def test_pk_reverse_foreign_key(self): - class TestSerializer(serializers.ModelSerializer): - class Meta: - model = ForeignKeyTargetModel - fields = ('id', 'name', 'reverse_foreign_key') + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = ForeignKeyTargetModel + fields = ('id', 'name', 'reverse_foreign_key') expected = dedent(""" TestSerializer(): @@ -650,13 +621,13 @@ class TestRelationalFieldMappings(TestCase): name = CharField(max_length=100) reverse_foreign_key = PrimaryKeyRelatedField(many=True, queryset=RelationalModel.objects.all()) """) - self.assertEqual(repr(TestSerializer()), expected) +assertrepr(TestSerializer())==expected def test_pk_reverse_one_to_one(self): - class TestSerializer(serializers.ModelSerializer): - class Meta: - model = OneToOneTargetModel - fields = ('id', 'name', 'reverse_one_to_one') + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = OneToOneTargetModel + fields = ('id', 'name', 'reverse_one_to_one') expected = dedent(""" TestSerializer(): @@ -664,13 +635,13 @@ class TestRelationalFieldMappings(TestCase): name = CharField(max_length=100) reverse_one_to_one = PrimaryKeyRelatedField(queryset=RelationalModel.objects.all()) """) - self.assertEqual(repr(TestSerializer()), expected) +assertrepr(TestSerializer())==expected def test_pk_reverse_many_to_many(self): - class TestSerializer(serializers.ModelSerializer): - class Meta: - model = ManyToManyTargetModel - fields = ('id', 'name', 'reverse_many_to_many') + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = ManyToManyTargetModel + fields = ('id', 'name', 'reverse_many_to_many') expected = dedent(""" TestSerializer(): @@ -678,13 +649,13 @@ class TestRelationalFieldMappings(TestCase): name = CharField(max_length=100) reverse_many_to_many = PrimaryKeyRelatedField(many=True, queryset=RelationalModel.objects.all()) """) - self.assertEqual(repr(TestSerializer()), expected) +assertrepr(TestSerializer())==expected def test_pk_reverse_through(self): - class TestSerializer(serializers.ModelSerializer): - class Meta: - model = ThroughTargetModel - fields = ('id', 'name', 'reverse_through') + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = ThroughTargetModel + fields = ('id', 'name', 'reverse_through') expected = dedent(""" TestSerializer(): @@ -692,7 +663,7 @@ class TestRelationalFieldMappings(TestCase): name = CharField(max_length=100) reverse_through = PrimaryKeyRelatedField(many=True, read_only=True) """) - self.assertEqual(repr(TestSerializer()), expected) +assertrepr(TestSerializer())==expected class DisplayValueTargetModel(models.Model): @@ -706,171 +677,91 @@ class DisplayValueModel(models.Model): color = models.ForeignKey(DisplayValueTargetModel, on_delete=models.CASCADE) -class TestRelationalFieldDisplayValue(TestCase): - def setUp(self): - DisplayValueTargetModel.objects.bulk_create([ - DisplayValueTargetModel(name='Red'), - DisplayValueTargetModel(name='Yellow'), - DisplayValueTargetModel(name='Green'), - ]) +def setUp(self): + DisplayValueTargetModel.objects.bulk_create([ DisplayValueTargetModel(name='Red'), DisplayValueTargetModel(name='Yellow'), DisplayValueTargetModel(name='Green'), ]) def test_default_display_value(self): - class TestSerializer(serializers.ModelSerializer): - class Meta: - model = DisplayValueModel - fields = '__all__' + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = DisplayValueModel + fields = '__all__' serializer = TestSerializer() - expected = OrderedDict([(1, 'Red Color'), (2, 'Yellow Color'), (3, 'Green Color')]) - self.assertEqual(serializer.fields['color'].choices, expected) +expected=OrderedDict([(1,'Red Color'),(2,'Yellow Color'),(3,'Green Color')]) +assertserializer.fields['color'].choices==expected def test_custom_display_value(self): - class TestField(serializers.PrimaryKeyRelatedField): - def display_value(self, instance): - return 'My %s Color' % (instance.name) + class TestField(serializers.PrimaryKeyRelatedField): + def display_value(self, instance): + return 'My %s Color' % (instance.name) class TestSerializer(serializers.ModelSerializer): - color = TestField(queryset=DisplayValueTargetModel.objects.all()) - - class Meta: - model = DisplayValueModel - fields = '__all__' + color = TestField(queryset=DisplayValueTargetModel.objects.all()) + class Meta: + model = DisplayValueModel + fields = '__all__' serializer = TestSerializer() - expected = OrderedDict([(1, 'My Red Color'), (2, 'My Yellow Color'), (3, 'My Green Color')]) - self.assertEqual(serializer.fields['color'].choices, expected) +expected=OrderedDict([(1,'My Red Color'),(2,'My Yellow Color'),(3,'My Green Color')]) +assertserializer.fields['color'].choices==expected -class TestIntegration(TestCase): - def setUp(self): - self.foreign_key_target = ForeignKeyTargetModel.objects.create( - name='foreign_key' - ) - self.one_to_one_target = OneToOneTargetModel.objects.create( - name='one_to_one' - ) - self.many_to_many_targets = [ - ManyToManyTargetModel.objects.create( - name='many_to_many (%d)' % idx - ) for idx in range(3) - ] - self.instance = RelationalModel.objects.create( - foreign_key=self.foreign_key_target, - one_to_one=self.one_to_one_target, - ) - self.instance.many_to_many.set(self.many_to_many_targets) +def setUp(self): + self.foreign_key_target = ForeignKeyTargetModel.objects.create( name='foreign_key' ) + self.one_to_one_target = OneToOneTargetModel.objects.create( name='one_to_one' ) + self.many_to_many_targets = [ ManyToManyTargetModel.objects.create( name='many_to_many (%d)' % idx ) for idx in range(3) ] + self.instance = RelationalModel.objects.create( foreign_key=self.foreign_key_target, one_to_one=self.one_to_one_target, ) + self.instance.many_to_many.set(self.many_to_many_targets) def test_pk_retrival(self): - class TestSerializer(serializers.ModelSerializer): - class Meta: - model = RelationalModel - fields = '__all__' + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = RelationalModel + fields = '__all__' serializer = TestSerializer(self.instance) - expected = { - 'id': self.instance.pk, - 'foreign_key': self.foreign_key_target.pk, - 'one_to_one': self.one_to_one_target.pk, - 'many_to_many': [item.pk for item in self.many_to_many_targets], - 'through': [] - } - self.assertEqual(serializer.data, expected) +expected={'id':self.instance.pk,'foreign_key':self.foreign_key_target.pk,'one_to_one':self.one_to_one_target.pk,'many_to_many':[item.pkforiteminself.many_to_many_targets],'through':[]} +assertserializer.data==expected def test_pk_create(self): - class TestSerializer(serializers.ModelSerializer): - class Meta: - model = RelationalModel - fields = '__all__' + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = RelationalModel + fields = '__all__' - new_foreign_key = ForeignKeyTargetModel.objects.create( - name='foreign_key' - ) - new_one_to_one = OneToOneTargetModel.objects.create( - name='one_to_one' - ) - new_many_to_many = [ - ManyToManyTargetModel.objects.create( - name='new many_to_many (%d)' % idx - ) for idx in range(3) - ] - data = { - 'foreign_key': new_foreign_key.pk, - 'one_to_one': new_one_to_one.pk, - 'many_to_many': [item.pk for item in new_many_to_many], - } - - # Serializer should validate okay. - serializer = TestSerializer(data=data) - assert serializer.is_valid() - - # Creating the instance, relationship attributes should be set. - instance = serializer.save() - assert instance.foreign_key.pk == new_foreign_key.pk - assert instance.one_to_one.pk == new_one_to_one.pk - assert [ - item.pk for item in instance.many_to_many.all() - ] == [ - item.pk for item in new_many_to_many - ] - assert list(instance.through.all()) == [] - - # Representation should be correct. - expected = { - 'id': instance.pk, - 'foreign_key': new_foreign_key.pk, - 'one_to_one': new_one_to_one.pk, - 'many_to_many': [item.pk for item in new_many_to_many], - 'through': [] - } - self.assertEqual(serializer.data, expected) + new_foreign_key = ForeignKeyTargetModel.objects.create(name='foreign_key') +new_one_to_one=OneToOneTargetModel.objects.create(name='one_to_one') +new_many_to_many=[ManyToManyTargetModel.objects.create(name='new many_to_many (%d)'%idx)foridxinrange(3)] +data={'foreign_key':new_foreign_key.pk,'one_to_one':new_one_to_one.pk,'many_to_many':[item.pkforiteminnew_many_to_many],} +serializer=TestSerializer(data=data) +assertserializer.is_valid() +instance=serializer.save() +assertinstance.foreign_key.pk==new_foreign_key.pk +assertinstance.one_to_one.pk==new_one_to_one.pk +assert[item.pkforitemininstance.many_to_many.all()]==[item.pkforiteminnew_many_to_many] +assertlist(instance.through.all())==[] +expected={'id':instance.pk,'foreign_key':new_foreign_key.pk,'one_to_one':new_one_to_one.pk,'many_to_many':[item.pkforiteminnew_many_to_many],'through':[]} +assertserializer.data==expected def test_pk_update(self): - class TestSerializer(serializers.ModelSerializer): - class Meta: - model = RelationalModel - fields = '__all__' + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = RelationalModel + fields = '__all__' - new_foreign_key = ForeignKeyTargetModel.objects.create( - name='foreign_key' - ) - new_one_to_one = OneToOneTargetModel.objects.create( - name='one_to_one' - ) - new_many_to_many = [ - ManyToManyTargetModel.objects.create( - name='new many_to_many (%d)' % idx - ) for idx in range(3) - ] - data = { - 'foreign_key': new_foreign_key.pk, - 'one_to_one': new_one_to_one.pk, - 'many_to_many': [item.pk for item in new_many_to_many], - } - - # Serializer should validate okay. - serializer = TestSerializer(self.instance, data=data) - assert serializer.is_valid() - - # Creating the instance, relationship attributes should be set. - instance = serializer.save() - assert instance.foreign_key.pk == new_foreign_key.pk - assert instance.one_to_one.pk == new_one_to_one.pk - assert [ - item.pk for item in instance.many_to_many.all() - ] == [ - item.pk for item in new_many_to_many - ] - assert list(instance.through.all()) == [] - - # Representation should be correct. - expected = { - 'id': self.instance.pk, - 'foreign_key': new_foreign_key.pk, - 'one_to_one': new_one_to_one.pk, - 'many_to_many': [item.pk for item in new_many_to_many], - 'through': [] - } - self.assertEqual(serializer.data, expected) + new_foreign_key = ForeignKeyTargetModel.objects.create(name='foreign_key') +new_one_to_one=OneToOneTargetModel.objects.create(name='one_to_one') +new_many_to_many=[ManyToManyTargetModel.objects.create(name='new many_to_many (%d)'%idx)foridxinrange(3)] +data={'foreign_key':new_foreign_key.pk,'one_to_one':new_one_to_one.pk,'many_to_many':[item.pkforiteminnew_many_to_many],} +serializer=TestSerializer(self.instance,data=data) +assertserializer.is_valid() +instance=serializer.save() +assertinstance.foreign_key.pk==new_foreign_key.pk +assertinstance.one_to_one.pk==new_one_to_one.pk +assert[item.pkforitemininstance.many_to_many.all()]==[item.pkforiteminnew_many_to_many] +assertlist(instance.through.all())==[] +expected={'id':self.instance.pk,'foreign_key':new_foreign_key.pk,'one_to_one':new_one_to_one.pk,'many_to_many':[item.pkforiteminnew_many_to_many],'through':[]} +assertserializer.data==expected # Tests for bulk create using `ListSerializer`. @@ -879,109 +770,88 @@ class BulkCreateModel(models.Model): name = models.CharField(max_length=10) -class TestBulkCreate(TestCase): - def test_bulk_create(self): - class BasicModelSerializer(serializers.ModelSerializer): - class Meta: - model = BulkCreateModel - fields = ('name',) +def test_bulk_create(self): + class BasicModelSerializer(serializers.ModelSerializer): + class Meta: + model = BulkCreateModel + fields = ('name',) class BulkCreateSerializer(serializers.ListSerializer): - child = BasicModelSerializer() + child = BasicModelSerializer() data = [{'name': 'a'}, {'name': 'b'}, {'name': 'c'}] - serializer = BulkCreateSerializer(data=data) - assert serializer.is_valid() - - # Objects are returned by save(). - instances = serializer.save() - assert len(instances) == 3 - assert [item.name for item in instances] == ['a', 'b', 'c'] - - # Objects have been created in the database. - assert BulkCreateModel.objects.count() == 3 - assert list(BulkCreateModel.objects.values_list('name', flat=True)) == ['a', 'b', 'c'] - - # Serializer returns correct data. - assert serializer.data == data +serializer=BulkCreateSerializer(data=data) +assertserializer.is_valid() +instances=serializer.save() +assertlen(instances)==3 +assert[item.nameforitemininstances]==['a','b','c'] +assertBulkCreateModel.objects.count()==3 +assertlist(BulkCreateModel.objects.values_list('name',flat=True))==['a','b','c'] +assertserializer.data==data class MetaClassTestModel(models.Model): text = models.CharField(max_length=100) -class TestSerializerMetaClass(TestCase): - def test_meta_class_fields_option(self): - class ExampleSerializer(serializers.ModelSerializer): - class Meta: - model = MetaClassTestModel - fields = 'text' +def test_meta_class_fields_option(self): + class ExampleSerializer(serializers.ModelSerializer): + class Meta: + model = MetaClassTestModel + fields = 'text' msginitial = "The `fields` option must be a list or tuple" - with self.assertRaisesMessage(TypeError, msginitial): - ExampleSerializer().fields +withself.assertRaisesMessage(TypeError,msginitial): + ExampleSerializer().fields def test_meta_class_exclude_option(self): - class ExampleSerializer(serializers.ModelSerializer): - class Meta: - model = MetaClassTestModel - exclude = 'text' + class ExampleSerializer(serializers.ModelSerializer): + class Meta: + model = MetaClassTestModel + exclude = 'text' msginitial = "The `exclude` option must be a list or tuple" - with self.assertRaisesMessage(TypeError, msginitial): - ExampleSerializer().fields +withself.assertRaisesMessage(TypeError,msginitial): + ExampleSerializer().fields def test_meta_class_fields_and_exclude_options(self): - class ExampleSerializer(serializers.ModelSerializer): - class Meta: - model = MetaClassTestModel - fields = ('text',) - exclude = ('text',) + class ExampleSerializer(serializers.ModelSerializer): + class Meta: + model = MetaClassTestModel + fields = ('text',) + exclude = ('text',) msginitial = "Cannot set both 'fields' and 'exclude' options on serializer ExampleSerializer." - with self.assertRaisesMessage(AssertionError, msginitial): - ExampleSerializer().fields +withself.assertRaisesMessage(AssertionError,msginitial): + ExampleSerializer().fields def test_declared_fields_with_exclude_option(self): - class ExampleSerializer(serializers.ModelSerializer): - text = serializers.CharField() + class ExampleSerializer(serializers.ModelSerializer): + text = serializers.CharField() + class Meta: + model = MetaClassTestModel + exclude = ('text',) - class Meta: - model = MetaClassTestModel - exclude = ('text',) - - expected = ( - "Cannot both declare the field 'text' and include it in the " - "ExampleSerializer 'exclude' option. Remove the field or, if " - "inherited from a parent serializer, disable with `text = None`." - ) - with self.assertRaisesMessage(AssertionError, expected): - ExampleSerializer().fields + expected = ("Cannot both declare the field 'text' and include it in the ""ExampleSerializer 'exclude' option. Remove the field or, if ""inherited from a parent serializer, disable with `text = None`.") +withself.assertRaisesMessage(AssertionError,expected): + ExampleSerializer().fields -class Issue2704TestCase(TestCase): - def test_queryset_all(self): - class TestSerializer(serializers.ModelSerializer): - additional_attr = serializers.CharField() - - class Meta: - model = OneFieldModel - fields = ('char_field', 'additional_attr') +def test_queryset_all(self): + class TestSerializer(serializers.ModelSerializer): + additional_attr = serializers.CharField() + class Meta: + model = OneFieldModel + fields = ('char_field', 'additional_attr') OneFieldModel.objects.create(char_field='abc') - qs = OneFieldModel.objects.all() - - for o in qs: - o.additional_attr = '123' +qs=OneFieldModel.objects.all() +foroinqs: + o.additional_attr = '123' serializer = TestSerializer(instance=qs, many=True) - - expected = [{ - 'char_field': 'abc', - 'additional_attr': '123', - }] - - assert serializer.data == expected +expected=[{'char_field':'abc','additional_attr':'123',}] +assertserializer.data==expected class DecimalFieldModel(models.Model): @@ -992,78 +862,71 @@ class DecimalFieldModel(models.Model): ) -class TestDecimalFieldMappings(TestCase): - def test_decimal_field_has_decimal_validator(self): - """ +def test_decimal_field_has_decimal_validator(self): + """ Test that a `DecimalField` has no `DecimalValidator`. """ - class TestSerializer(serializers.ModelSerializer): - class Meta: - model = DecimalFieldModel - fields = '__all__' + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = DecimalFieldModel + fields = '__all__' serializer = TestSerializer() - - assert len(serializer.fields['decimal_field'].validators) == 2 +assertlen(serializer.fields['decimal_field'].validators)==2 def test_min_value_is_passed(self): - """ + """ Test that the `MinValueValidator` is converted to the `min_value` argument for the field. """ - class TestSerializer(serializers.ModelSerializer): - class Meta: - model = DecimalFieldModel - fields = '__all__' + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = DecimalFieldModel + fields = '__all__' serializer = TestSerializer() - - assert serializer.fields['decimal_field'].min_value == 1 +assertserializer.fields['decimal_field'].min_value==1 def test_max_value_is_passed(self): - """ + """ Test that the `MaxValueValidator` is converted to the `max_value` argument for the field. """ - class TestSerializer(serializers.ModelSerializer): - class Meta: - model = DecimalFieldModel - fields = '__all__' + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = DecimalFieldModel + fields = '__all__' serializer = TestSerializer() - - assert serializer.fields['decimal_field'].max_value == 3 +assertserializer.fields['decimal_field'].max_value==3 -class TestMetaInheritance(TestCase): - def test_extra_kwargs_not_altered(self): - class TestSerializer(serializers.ModelSerializer): - non_model_field = serializers.CharField() - - class Meta: - model = OneFieldModel - read_only_fields = ('char_field', 'non_model_field') - fields = read_only_fields - extra_kwargs = {} +def test_extra_kwargs_not_altered(self): + class TestSerializer(serializers.ModelSerializer): + non_model_field = serializers.CharField() + class Meta: + model = OneFieldModel + read_only_fields = ('char_field', 'non_model_field') + fields = read_only_fields + extra_kwargs = {} class ChildSerializer(TestSerializer): - class Meta(TestSerializer.Meta): - read_only_fields = () + class Meta(TestSerializer.Meta): + read_only_fields = () test_expected = dedent(""" TestSerializer(): char_field = CharField(read_only=True) non_model_field = CharField() """) - - child_expected = dedent(""" +child_expected=dedent(""" ChildSerializer(): char_field = CharField(max_length=100) non_model_field = CharField() """) - self.assertEqual(repr(ChildSerializer()), child_expected) - self.assertEqual(repr(TestSerializer()), test_expected) - self.assertEqual(repr(ChildSerializer()), child_expected) +assertrepr(ChildSerializer())==child_expected +assertrepr(TestSerializer())==test_expected +assertrepr(ChildSerializer())==child_expected class OneToOneTargetTestModel(models.Model): @@ -1074,57 +937,53 @@ class OneToOneSourceTestModel(models.Model): target = models.OneToOneField(OneToOneTargetTestModel, primary_key=True, on_delete=models.CASCADE) -class TestModelFieldValues(TestCase): - def test_model_field(self): - class ExampleSerializer(serializers.ModelSerializer): - class Meta: - model = OneToOneSourceTestModel - fields = ('target',) +def test_model_field(self): + class ExampleSerializer(serializers.ModelSerializer): + class Meta: + model = OneToOneSourceTestModel + fields = ('target',) target = OneToOneTargetTestModel(id=1, text='abc') - source = OneToOneSourceTestModel(target=target) - serializer = ExampleSerializer(source) - self.assertEqual(serializer.data, {'target': 1}) +source=OneToOneSourceTestModel(target=target) +serializer=ExampleSerializer(source) +assertserializer.data=={'target':1} -class TestUniquenessOverride(TestCase): - def test_required_not_overwritten(self): - class TestModel(models.Model): - field_1 = models.IntegerField(null=True) - field_2 = models.IntegerField() - - class Meta: - unique_together = (('field_1', 'field_2'),) +def test_required_not_overwritten(self): + class TestModel(models.Model): + field_1 = models.IntegerField(null=True) + field_2 = models.IntegerField() + class Meta: + unique_together = (('field_1', 'field_2'),) class TestSerializer(serializers.ModelSerializer): - class Meta: - model = TestModel - fields = '__all__' - extra_kwargs = {'field_1': {'required': False}} + class Meta: + model = TestModel + fields = '__all__' + extra_kwargs = {'field_1': {'required': False}} fields = TestSerializer().fields - self.assertFalse(fields['field_1'].required) - self.assertTrue(fields['field_2'].required) +assertnotfields['field_1'].required +assertfields['field_2'].required -class Issue3674Test(TestCase): - def test_nonPK_foreignkey_model_serializer(self): - class TestParentModel(models.Model): - title = models.CharField(max_length=64) +def test_nonPK_foreignkey_model_serializer(self): + class TestParentModel(models.Model): + title = models.CharField(max_length=64) class TestChildModel(models.Model): - parent = models.ForeignKey(TestParentModel, related_name='children', on_delete=models.CASCADE) - value = models.CharField(primary_key=True, max_length=64) + parent = models.ForeignKey(TestParentModel, related_name='children', on_delete=models.CASCADE) + value = models.CharField(primary_key=True, max_length=64) class TestChildModelSerializer(serializers.ModelSerializer): - class Meta: - model = TestChildModel - fields = ('value', 'parent') + class Meta: + model = TestChildModel + fields = ('value', 'parent') class TestParentModelSerializer(serializers.ModelSerializer): - class Meta: - model = TestParentModel - fields = ('id', 'title', 'children') + class Meta: + model = TestParentModel + fields = ('id', 'title', 'children') parent_expected = dedent(""" TestParentModelSerializer(): @@ -1132,106 +991,91 @@ class Issue3674Test(TestCase): title = CharField(max_length=64) children = PrimaryKeyRelatedField(many=True, queryset=TestChildModel.objects.all()) """) - self.assertEqual(repr(TestParentModelSerializer()), parent_expected) - - child_expected = dedent(""" +assertrepr(TestParentModelSerializer())==parent_expected +child_expected=dedent(""" TestChildModelSerializer(): value = CharField(max_length=64, validators=[]) parent = PrimaryKeyRelatedField(queryset=TestParentModel.objects.all()) """) - self.assertEqual(repr(TestChildModelSerializer()), child_expected) +assertrepr(TestChildModelSerializer())==child_expected def test_nonID_PK_foreignkey_model_serializer(self): - class TestChildModelSerializer(serializers.ModelSerializer): - class Meta: - model = Issue3674ChildModel - fields = ('value', 'parent') + class TestChildModelSerializer(serializers.ModelSerializer): + class Meta: + model = Issue3674ChildModel + fields = ('value', 'parent') class TestParentModelSerializer(serializers.ModelSerializer): - class Meta: - model = Issue3674ParentModel - fields = ('id', 'title', 'children') + class Meta: + model = Issue3674ParentModel + fields = ('id', 'title', 'children') parent = Issue3674ParentModel.objects.create(title='abc') - child = Issue3674ChildModel.objects.create(value='def', parent=parent) - - parent_serializer = TestParentModelSerializer(parent) - child_serializer = TestChildModelSerializer(child) - - parent_expected = {'children': ['def'], 'id': 1, 'title': 'abc'} - self.assertEqual(parent_serializer.data, parent_expected) - - child_expected = {'parent': 1, 'value': 'def'} - self.assertEqual(child_serializer.data, child_expected) +child=Issue3674ChildModel.objects.create(value='def',parent=parent) +parent_serializer=TestParentModelSerializer(parent) +child_serializer=TestChildModelSerializer(child) +parent_expected={'children':['def'],'id':1,'title':'abc'} +assertparent_serializer.data==parent_expected +child_expected={'parent':1,'value':'def'} +assertchild_serializer.data==child_expected -class Issue4897TestCase(TestCase): - def test_should_assert_if_writing_readonly_fields(self): - class TestSerializer(serializers.ModelSerializer): - class Meta: - model = OneFieldModel - fields = ('char_field',) - readonly_fields = fields +def test_should_assert_if_writing_readonly_fields(self): + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = OneFieldModel + fields = ('char_field',) + readonly_fields = fields obj = OneFieldModel.objects.create(char_field='abc') - - with pytest.raises(AssertionError) as cm: - TestSerializer(obj).fields +withpytest.raises(AssertionError)ascm: + TestSerializer(obj).fields cm.match(r'readonly_fields') -class Test5004UniqueChoiceField(TestCase): - def test_unique_choice_field(self): - class TestUniqueChoiceSerializer(serializers.ModelSerializer): - class Meta: - model = UniqueChoiceModel - fields = '__all__' +def test_unique_choice_field(self): + class TestUniqueChoiceSerializer(serializers.ModelSerializer): + class Meta: + model = UniqueChoiceModel + fields = '__all__' UniqueChoiceModel.objects.create(name='choice1') - serializer = TestUniqueChoiceSerializer(data={'name': 'choice1'}) - assert not serializer.is_valid() - assert serializer.errors == {'name': ['unique choice model with this name already exists.']} +serializer=TestUniqueChoiceSerializer(data={'name':'choice1'}) +assertnotserializer.is_valid() +assertserializer.errors=={'name':['unique choice model with this name already exists.']} -class TestFieldSource(TestCase): - def test_traverse_nullable_fk(self): - """ +def test_traverse_nullable_fk(self): + """ A dotted source with nullable elements uses default when any item in the chain is None. #5849. Similar to model example from test_serializer.py `test_default_for_multiple_dotted_source` method, but using RelatedField, rather than CharField. """ - class TestSerializer(serializers.ModelSerializer): - target = serializers.PrimaryKeyRelatedField( - source='target.target', read_only=True, allow_null=True, default=None - ) - - class Meta: - model = NestedForeignKeySource - fields = ('target', ) + class TestSerializer(serializers.ModelSerializer): + target = serializers.PrimaryKeyRelatedField( source='target.target', read_only=True, allow_null=True, default=None ) + class Meta: + model = NestedForeignKeySource + fields = ('target', ) model = NestedForeignKeySource.objects.create() - assert TestSerializer(model).data['target'] is None +assertTestSerializer(model).data['target']isNone def test_named_field_source(self): - class TestSerializer(serializers.ModelSerializer): + class TestSerializer(serializers.ModelSerializer): - class Meta: - model = RegularFieldsModel - fields = ('number_field',) - extra_kwargs = { - 'number_field': { - 'source': 'integer_field' - } - } + class Meta: + model = RegularFieldsModel + fields = ('number_field',) + extra_kwargs = { 'number_field': { 'source': 'integer_field' } } expected = dedent(""" TestSerializer(): number_field = IntegerField(source='integer_field') """) - self.maxDiff = None - self.assertEqual(repr(TestSerializer()), expected) +self.maxDiff=None +assertrepr(TestSerializer())==expected class Issue6110TestModel(models.Model): @@ -1247,13 +1091,12 @@ class Issue6110ModelSerializer(serializers.ModelSerializer): fields = ('name',) -class Issue6110Test(TestCase): - def test_model_serializer_custom_manager(self): - instance = Issue6110ModelSerializer().create({'name': 'test_name'}) - self.assertEqual(instance.name, 'test_name') +def test_model_serializer_custom_manager(self): + instance = Issue6110ModelSerializer().create({'name': 'test_name'}) + assert instance.name == 'test_name' def test_model_serializer_custom_manager_error_message(self): - msginitial = ('Got a `TypeError` when calling `Issue6110TestModel.all_objects.create()`.') - with self.assertRaisesMessage(TypeError, msginitial): - Issue6110ModelSerializer().create({'wrong_param': 'wrong_param'}) + msginitial = ('Got a `TypeError` when calling `Issue6110TestModel.all_objects.create()`.') + with self.assertRaisesMessage(TypeError, msginitial): + Issue6110ModelSerializer().create({'wrong_param': 'wrong_param'}) diff --git a/tests/test_multitable_inheritance.py b/tests/test_multitable_inheritance.py index 1e8ab3448..52b6052b8 100644 --- a/tests/test_multitable_inheritance.py +++ b/tests/test_multitable_inheritance.py @@ -33,35 +33,31 @@ class AssociatedModelSerializer(serializers.ModelSerializer): # Tests -class InheritedModelSerializationTests(TestCase): - def test_multitable_inherited_model_fields_as_expected(self): - """ +def test_multitable_inherited_model_fields_as_expected(self): + """ Assert that the parent pointer field is not included in the fields serialized fields """ - child = ChildModel(name1='parent name', name2='child name') - serializer = DerivedModelSerializer(child) - assert set(serializer.data) == {'name1', 'name2', 'id'} + child = ChildModel(name1='parent name', name2='child name') + serializer = DerivedModelSerializer(child) + assert set(serializer.data) == {'name1', 'name2', 'id'} def test_onetoone_primary_key_model_fields_as_expected(self): - """ + """ Assert that a model with a onetoone field that is the primary key is not treated like a derived model """ - parent = ParentModel.objects.create(name1='parent name') - associate = AssociatedModel.objects.create(name='hello', ref=parent) - serializer = AssociatedModelSerializer(associate) - assert set(serializer.data) == {'name', 'ref'} + parent = ParentModel.objects.create(name1='parent name') + associate = AssociatedModel.objects.create(name='hello', ref=parent) + serializer = AssociatedModelSerializer(associate) + assert set(serializer.data) == {'name', 'ref'} def test_data_is_valid_without_parent_ptr(self): - """ + """ Assert that the pointer to the parent table is not a required field for input data """ - data = { - 'name1': 'parent name', - 'name2': 'child name', - } - serializer = DerivedModelSerializer(data=data) - assert serializer.is_valid() is True + data = { 'name1': 'parent name', 'name2': 'child name', } + serializer = DerivedModelSerializer(data=data) + assert serializer.is_valid() is True diff --git a/tests/test_negotiation.py b/tests/test_negotiation.py index 089a86c62..eab8a9b88 100644 --- a/tests/test_negotiation.py +++ b/tests/test_negotiation.py @@ -30,70 +30,68 @@ class NoCharsetSpecifiedRenderer(BaseRenderer): media_type = 'my/media' -class TestAcceptedMediaType(TestCase): - def setUp(self): - self.renderers = [MockJSONRenderer(), MockHTMLRenderer(), MockOpenAPIRenderer()] - self.negotiator = DefaultContentNegotiation() +def setUp(self): + self.renderers = [MockJSONRenderer(), MockHTMLRenderer(), MockOpenAPIRenderer()] + self.negotiator = DefaultContentNegotiation() def select_renderer(self, request): - return self.negotiator.select_renderer(request, self.renderers) + return self.negotiator.select_renderer(request, self.renderers) def test_client_without_accept_use_renderer(self): - request = Request(factory.get('/')) - accepted_renderer, accepted_media_type = self.select_renderer(request) - assert accepted_media_type == 'application/json' + request = Request(factory.get('/')) + accepted_renderer, accepted_media_type = self.select_renderer(request) + assert accepted_media_type == 'application/json' def test_client_underspecifies_accept_use_renderer(self): - request = Request(factory.get('/', HTTP_ACCEPT='*/*')) - accepted_renderer, accepted_media_type = self.select_renderer(request) - assert accepted_media_type == 'application/json' + request = Request(factory.get('/', HTTP_ACCEPT='*/*')) + accepted_renderer, accepted_media_type = self.select_renderer(request) + assert accepted_media_type == 'application/json' def test_client_overspecifies_accept_use_client(self): - request = Request(factory.get('/', HTTP_ACCEPT='application/json; indent=8')) - accepted_renderer, accepted_media_type = self.select_renderer(request) - assert accepted_media_type == 'application/json; indent=8' + request = Request(factory.get('/', HTTP_ACCEPT='application/json; indent=8')) + accepted_renderer, accepted_media_type = self.select_renderer(request) + assert accepted_media_type == 'application/json; indent=8' def test_client_specifies_parameter(self): - request = Request(factory.get('/', HTTP_ACCEPT='application/openapi+json;version=2.0')) - accepted_renderer, accepted_media_type = self.select_renderer(request) - assert accepted_media_type == 'application/openapi+json;version=2.0' - assert accepted_renderer.format == 'swagger' + request = Request(factory.get('/', HTTP_ACCEPT='application/openapi+json;version=2.0')) + accepted_renderer, accepted_media_type = self.select_renderer(request) + assert accepted_media_type == 'application/openapi+json;version=2.0' + assert accepted_renderer.format == 'swagger' def test_match_is_false_if_main_types_not_match(self): - mediatype = _MediaType('test_1') - anoter_mediatype = _MediaType('test_2') - assert mediatype.match(anoter_mediatype) is False + mediatype = _MediaType('test_1') + anoter_mediatype = _MediaType('test_2') + assert mediatype.match(anoter_mediatype) is False def test_mediatype_match_is_false_if_keys_not_match(self): - mediatype = _MediaType(';test_param=foo') - another_mediatype = _MediaType(';test_param=bar') - assert mediatype.match(another_mediatype) is False + mediatype = _MediaType(';test_param=foo') + another_mediatype = _MediaType(';test_param=bar') + assert mediatype.match(another_mediatype) is False def test_mediatype_precedence_with_wildcard_subtype(self): - mediatype = _MediaType('test/*') - assert mediatype.precedence == 1 + mediatype = _MediaType('test/*') + assert mediatype.precedence == 1 def test_mediatype_string_representation(self): - mediatype = _MediaType('test/*; foo=bar') - assert str(mediatype) == 'test/*; foo=bar' + mediatype = _MediaType('test/*; foo=bar') + assert str(mediatype) == 'test/*; foo=bar' def test_raise_error_if_no_suitable_renderers_found(self): - class MockRenderer: - format = 'xml' + class MockRenderer: + format = 'xml' renderers = [MockRenderer()] - with pytest.raises(Http404): - self.negotiator.filter_renderers(renderers, format='json') +withpytest.raises(Http404): + self.negotiator.filter_renderers(renderers, format='json') -class BaseContentNegotiationTests(TestCase): - def setUp(self): - self.negotiator = BaseContentNegotiation() +def setUp(self): + self.negotiator = BaseContentNegotiation() def test_raise_error_for_abstract_select_parser_method(self): - with pytest.raises(NotImplementedError): - self.negotiator.select_parser(None, None) + with pytest.raises(NotImplementedError): + self.negotiator.select_parser(None, None) def test_raise_error_for_abstract_select_renderer_method(self): - with pytest.raises(NotImplementedError): - self.negotiator.select_renderer(None, None) + with pytest.raises(NotImplementedError): + self.negotiator.select_renderer(None, None) diff --git a/tests/test_one_to_one_with_inheritance.py b/tests/test_one_to_one_with_inheritance.py index 40793d7ca..7f67fbb7c 100644 --- a/tests/test_one_to_one_with_inheritance.py +++ b/tests/test_one_to_one_with_inheritance.py @@ -28,14 +28,12 @@ class ChildAssociatedModelSerializer(serializers.ModelSerializer): # Tests -class InheritedModelSerializationTests(TestCase): - def test_multitable_inherited_model_fields_as_expected(self): - """ +def test_multitable_inherited_model_fields_as_expected(self): + """ Assert that the parent pointer field is not included in the fields serialized fields """ - child = ChildModel(name1='parent name', name2='child name') - serializer = DerivedModelSerializer(child) - self.assertEqual(set(serializer.data), - {'name1', 'name2', 'id', 'childassociatedmodel'}) + child = ChildModel(name1='parent name', name2='child name') + serializer = DerivedModelSerializer(child) + assert set(serializer.data) == {'name1', 'name2', 'id', 'childassociatedmodel'} diff --git a/tests/test_parsers.py b/tests/test_parsers.py index dcd62fac9..8040c160e 100644 --- a/tests/test_parsers.py +++ b/tests/test_parsers.py @@ -22,157 +22,138 @@ class Form(forms.Form): field2 = forms.CharField() -class TestFormParser(TestCase): - def setUp(self): - self.string = "field1=abc&field2=defghijk" +def setUp(self): + self.string = "field1=abc&field2=defghijk" def test_parse(self): - """ Make sure the `QueryDict` works OK """ - parser = FormParser() - - stream = io.StringIO(self.string) - data = parser.parse(stream) - - assert Form(data).is_valid() is True + """ Make sure the `QueryDict` works OK """ + parser = FormParser() + stream = io.StringIO(self.string) + data = parser.parse(stream) + assert Form(data).is_valid() is True -class TestFileUploadParser(TestCase): - def setUp(self): - class MockRequest: - pass +def setUp(self): + class MockRequest: + pass self.stream = io.BytesIO(b"Test text file") - request = MockRequest() - request.upload_handlers = (MemoryFileUploadHandler(),) - request.META = { - 'HTTP_CONTENT_DISPOSITION': 'Content-Disposition: inline; filename=file.txt', - 'HTTP_CONTENT_LENGTH': 14, - } - self.parser_context = {'request': request, 'kwargs': {}} +request=MockRequest() +request.upload_handlers=(MemoryFileUploadHandler(),) +request.META={'HTTP_CONTENT_DISPOSITION':'Content-Disposition: inline; filename=file.txt','HTTP_CONTENT_LENGTH':14,} +self.parser_context={'request':request,'kwargs':{}} def test_parse(self): - """ + """ Parse raw file upload. """ - parser = FileUploadParser() - self.stream.seek(0) - data_and_files = parser.parse(self.stream, None, self.parser_context) - file_obj = data_and_files.files['file'] - assert file_obj.size == 14 + parser = FileUploadParser() + self.stream.seek(0) + data_and_files = parser.parse(self.stream, None, self.parser_context) + file_obj = data_and_files.files['file'] + assert file_obj.size == 14 def test_parse_missing_filename(self): - """ + """ Parse raw file upload when filename is missing. """ - parser = FileUploadParser() - self.stream.seek(0) - self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = '' - with pytest.raises(ParseError) as excinfo: - parser.parse(self.stream, None, self.parser_context) + parser = FileUploadParser() + self.stream.seek(0) + self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = '' + with pytest.raises(ParseError) as excinfo: + parser.parse(self.stream, None, self.parser_context) assert str(excinfo.value) == 'Missing filename. Request should include a Content-Disposition header with a filename parameter.' def test_parse_missing_filename_multiple_upload_handlers(self): - """ + """ Parse raw file upload with multiple handlers when filename is missing. Regression test for #2109. """ - parser = FileUploadParser() - self.stream.seek(0) - self.parser_context['request'].upload_handlers = ( - MemoryFileUploadHandler(), - MemoryFileUploadHandler() - ) - self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = '' - with pytest.raises(ParseError) as excinfo: - parser.parse(self.stream, None, self.parser_context) + parser = FileUploadParser() + self.stream.seek(0) + self.parser_context['request'].upload_handlers = ( MemoryFileUploadHandler(), MemoryFileUploadHandler() ) + self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = '' + with pytest.raises(ParseError) as excinfo: + parser.parse(self.stream, None, self.parser_context) assert str(excinfo.value) == 'Missing filename. Request should include a Content-Disposition header with a filename parameter.' def test_parse_missing_filename_large_file(self): - """ + """ Parse raw file upload when filename is missing with TemporaryFileUploadHandler. """ - parser = FileUploadParser() - self.stream.seek(0) - self.parser_context['request'].upload_handlers = ( - TemporaryFileUploadHandler(), - ) - self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = '' - with pytest.raises(ParseError) as excinfo: - parser.parse(self.stream, None, self.parser_context) + parser = FileUploadParser() + self.stream.seek(0) + self.parser_context['request'].upload_handlers = ( TemporaryFileUploadHandler(), ) + self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = '' + with pytest.raises(ParseError) as excinfo: + parser.parse(self.stream, None, self.parser_context) assert str(excinfo.value) == 'Missing filename. Request should include a Content-Disposition header with a filename parameter.' def test_get_filename(self): - parser = FileUploadParser() - filename = parser.get_filename(self.stream, None, self.parser_context) - assert filename == 'file.txt' + parser = FileUploadParser() + filename = parser.get_filename(self.stream, None, self.parser_context) + assert filename == 'file.txt' def test_get_encoded_filename(self): - parser = FileUploadParser() - - self.__replace_content_disposition('inline; filename*=utf-8\'\'ÀĥƦ.txt') - filename = parser.get_filename(self.stream, None, self.parser_context) - assert filename == 'ÀĥƦ.txt' - - self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'\'ÀĥƦ.txt') - filename = parser.get_filename(self.stream, None, self.parser_context) - assert filename == 'ÀĥƦ.txt' - - self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'en-us\'ÀĥƦ.txt') - filename = parser.get_filename(self.stream, None, self.parser_context) - assert filename == 'ÀĥƦ.txt' + parser = FileUploadParser() + self.__replace_content_disposition('inline; filename*=utf-8\'\'ÀĥƦ.txt') + filename = parser.get_filename(self.stream, None, self.parser_context) + assert filename == 'ÀĥƦ.txt' + self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'\'ÀĥƦ.txt') + filename = parser.get_filename(self.stream, None, self.parser_context) + assert filename == 'ÀĥƦ.txt' + self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'en-us\'ÀĥƦ.txt') + filename = parser.get_filename(self.stream, None, self.parser_context) + assert filename == 'ÀĥƦ.txt' def __replace_content_disposition(self, disposition): - self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = disposition + self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = disposition -class TestJSONParser(TestCase): - def bytes(self, value): - return io.BytesIO(value.encode()) +def bytes(self, value): + return io.BytesIO(value.encode()) def test_float_strictness(self): - parser = JSONParser() - - # Default to strict - for value in ['Infinity', '-Infinity', 'NaN']: - with pytest.raises(ParseError): - parser.parse(self.bytes(value)) + parser = JSONParser() + for value in ['Infinity', '-Infinity', 'NaN']: + with pytest.raises(ParseError): + parser.parse(self.bytes(value)) parser.strict = False - assert parser.parse(self.bytes('Infinity')) == float('inf') - assert parser.parse(self.bytes('-Infinity')) == float('-inf') - assert math.isnan(parser.parse(self.bytes('NaN'))) +assertparser.parse(self.bytes('Infinity'))==float('inf') +assertparser.parse(self.bytes('-Infinity'))==float('-inf') +assertmath.isnan(parser.parse(self.bytes('NaN'))) -class TestPOSTAccessed(TestCase): - def setUp(self): - self.factory = APIRequestFactory() +def setUp(self): + self.factory = APIRequestFactory() def test_post_accessed_in_post_method(self): - django_request = self.factory.post('/', {'foo': 'bar'}) - request = Request(django_request, parsers=[FormParser(), MultiPartParser()]) - django_request.POST - assert request.POST == {'foo': ['bar']} - assert request.data == {'foo': ['bar']} + django_request = self.factory.post('/', {'foo': 'bar'}) + request = Request(django_request, parsers=[FormParser(), MultiPartParser()]) + django_request.POST + assert request.POST == {'foo': ['bar']} + assert request.data == {'foo': ['bar']} def test_post_accessed_in_post_method_with_json_parser(self): - django_request = self.factory.post('/', {'foo': 'bar'}) - request = Request(django_request, parsers=[JSONParser()]) - django_request.POST - assert request.POST == {} - assert request.data == {} + django_request = self.factory.post('/', {'foo': 'bar'}) + request = Request(django_request, parsers=[JSONParser()]) + django_request.POST + assert request.POST == {} + assert request.data == {} def test_post_accessed_in_put_method(self): - django_request = self.factory.put('/', {'foo': 'bar'}) - request = Request(django_request, parsers=[FormParser(), MultiPartParser()]) - django_request.POST - assert request.POST == {'foo': ['bar']} - assert request.data == {'foo': ['bar']} + django_request = self.factory.put('/', {'foo': 'bar'}) + request = Request(django_request, parsers=[FormParser(), MultiPartParser()]) + django_request.POST + assert request.POST == {'foo': ['bar']} + assert request.data == {'foo': ['bar']} def test_request_read_before_parsing(self): - django_request = self.factory.put('/', {'foo': 'bar'}) - request = Request(django_request, parsers=[FormParser(), MultiPartParser()]) - django_request.read() + django_request = self.factory.put('/', {'foo': 'bar'}) + request = Request(django_request, parsers=[FormParser(), MultiPartParser()]) + django_request.read() + with pytest.raises(RawPostDataException): + request.POST with pytest.raises(RawPostDataException): - request.POST - with pytest.raises(RawPostDataException): - request.POST - request.data + request.POST + request.data diff --git a/tests/test_permissions.py b/tests/test_permissions.py index 9c9300694..92d1d4c4c 100644 --- a/tests/test_permissions.py +++ b/tests/test_permissions.py @@ -72,177 +72,136 @@ def basic_auth_header(username, password): return 'Basic %s' % base64_credentials -class ModelPermissionsIntegrationTests(TestCase): - def setUp(self): - User.objects.create_user('disallowed', 'disallowed@example.com', 'password') - user = User.objects.create_user('permitted', 'permitted@example.com', 'password') - user.user_permissions.set([ - Permission.objects.get(codename='add_basicmodel'), - Permission.objects.get(codename='change_basicmodel'), - Permission.objects.get(codename='delete_basicmodel') - ]) - - user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password') - user.user_permissions.set([ - Permission.objects.get(codename='change_basicmodel'), - ]) - - self.permitted_credentials = basic_auth_header('permitted', 'password') - self.disallowed_credentials = basic_auth_header('disallowed', 'password') - self.updateonly_credentials = basic_auth_header('updateonly', 'password') - - BasicModel(text='foo').save() +def setUp(self): + User.objects.create_user('disallowed', 'disallowed@example.com', 'password') + user = User.objects.create_user('permitted', 'permitted@example.com', 'password') + user.user_permissions.set([ Permission.objects.get(codename='add_basicmodel'), Permission.objects.get(codename='change_basicmodel'), Permission.objects.get(codename='delete_basicmodel') ]) + user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password') + user.user_permissions.set([ Permission.objects.get(codename='change_basicmodel'), ]) + self.permitted_credentials = basic_auth_header('permitted', 'password') + self.disallowed_credentials = basic_auth_header('disallowed', 'password') + self.updateonly_credentials = basic_auth_header('updateonly', 'password') + BasicModel(text='foo').save() def test_has_create_permissions(self): - request = factory.post('/', {'text': 'foobar'}, format='json', - HTTP_AUTHORIZATION=self.permitted_credentials) - response = root_view(request, pk=1) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) + request = factory.post('/', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.permitted_credentials) + response = root_view(request, pk=1) + assert response.status_code == status.HTTP_201_CREATED def test_api_root_view_discard_default_django_model_permission(self): - """ + """ We check that DEFAULT_PERMISSION_CLASSES can apply to APIRoot view. More specifically we check expected behavior of ``_ignore_model_permissions`` attribute support. """ - request = factory.get('/', format='json', - HTTP_AUTHORIZATION=self.permitted_credentials) - request.resolver_match = ResolverMatch('get', (), {}) - response = api_root_view(request) - self.assertEqual(response.status_code, status.HTTP_200_OK) + request = factory.get('/', format='json', HTTP_AUTHORIZATION=self.permitted_credentials) + request.resolver_match = ResolverMatch('get', (), {}) + response = api_root_view(request) + assert response.status_code == status.HTTP_200_OK def test_get_queryset_has_create_permissions(self): - request = factory.post('/', {'text': 'foobar'}, format='json', - HTTP_AUTHORIZATION=self.permitted_credentials) - response = get_queryset_list_view(request, pk=1) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) + request = factory.post('/', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.permitted_credentials) + response = get_queryset_list_view(request, pk=1) + assert response.status_code == status.HTTP_201_CREATED def test_has_put_permissions(self): - request = factory.put('/1', {'text': 'foobar'}, format='json', - HTTP_AUTHORIZATION=self.permitted_credentials) - response = instance_view(request, pk='1') - self.assertEqual(response.status_code, status.HTTP_200_OK) + request = factory.put('/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.permitted_credentials) + response = instance_view(request, pk='1') + assert response.status_code == status.HTTP_200_OK def test_has_delete_permissions(self): - request = factory.delete('/1', HTTP_AUTHORIZATION=self.permitted_credentials) - response = instance_view(request, pk=1) - self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + request = factory.delete('/1', HTTP_AUTHORIZATION=self.permitted_credentials) + response = instance_view(request, pk=1) + assert response.status_code == status.HTTP_204_NO_CONTENT def test_does_not_have_create_permissions(self): - request = factory.post('/', {'text': 'foobar'}, format='json', - HTTP_AUTHORIZATION=self.disallowed_credentials) - response = root_view(request, pk=1) - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + request = factory.post('/', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.disallowed_credentials) + response = root_view(request, pk=1) + assert response.status_code == status.HTTP_403_FORBIDDEN def test_does_not_have_put_permissions(self): - request = factory.put('/1', {'text': 'foobar'}, format='json', - HTTP_AUTHORIZATION=self.disallowed_credentials) - response = instance_view(request, pk='1') - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + request = factory.put('/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.disallowed_credentials) + response = instance_view(request, pk='1') + assert response.status_code == status.HTTP_403_FORBIDDEN def test_does_not_have_delete_permissions(self): - request = factory.delete('/1', HTTP_AUTHORIZATION=self.disallowed_credentials) - response = instance_view(request, pk=1) - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + request = factory.delete('/1', HTTP_AUTHORIZATION=self.disallowed_credentials) + response = instance_view(request, pk=1) + assert response.status_code == status.HTTP_403_FORBIDDEN def test_options_permitted(self): - request = factory.options( - '/', - HTTP_AUTHORIZATION=self.permitted_credentials - ) - response = root_view(request, pk='1') - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertIn('actions', response.data) - self.assertEqual(list(response.data['actions']), ['POST']) - - request = factory.options( - '/1', - HTTP_AUTHORIZATION=self.permitted_credentials - ) - response = instance_view(request, pk='1') - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertIn('actions', response.data) - self.assertEqual(list(response.data['actions']), ['PUT']) + request = factory.options( '/', HTTP_AUTHORIZATION=self.permitted_credentials ) + response = root_view(request, pk='1') + assert response.status_code == status.HTTP_200_OK + assert 'actions' in response.data + assert list(response.data['actions']) == ['POST'] + request = factory.options( '/1', HTTP_AUTHORIZATION=self.permitted_credentials ) + response = instance_view(request, pk='1') + assert response.status_code == status.HTTP_200_OK + assert 'actions' in response.data + assert list(response.data['actions']) == ['PUT'] def test_options_disallowed(self): - request = factory.options( - '/', - HTTP_AUTHORIZATION=self.disallowed_credentials - ) - response = root_view(request, pk='1') - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertNotIn('actions', response.data) - - request = factory.options( - '/1', - HTTP_AUTHORIZATION=self.disallowed_credentials - ) - response = instance_view(request, pk='1') - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertNotIn('actions', response.data) + request = factory.options( '/', HTTP_AUTHORIZATION=self.disallowed_credentials ) + response = root_view(request, pk='1') + assert response.status_code == status.HTTP_200_OK + assert 'actions' not in response.data + request = factory.options( '/1', HTTP_AUTHORIZATION=self.disallowed_credentials ) + response = instance_view(request, pk='1') + assert response.status_code == status.HTTP_200_OK + assert 'actions' not in response.data def test_options_updateonly(self): - request = factory.options( - '/', - HTTP_AUTHORIZATION=self.updateonly_credentials - ) - response = root_view(request, pk='1') - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertNotIn('actions', response.data) - - request = factory.options( - '/1', - HTTP_AUTHORIZATION=self.updateonly_credentials - ) - response = instance_view(request, pk='1') - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertIn('actions', response.data) - self.assertEqual(list(response.data['actions']), ['PUT']) + request = factory.options( '/', HTTP_AUTHORIZATION=self.updateonly_credentials ) + response = root_view(request, pk='1') + assert response.status_code == status.HTTP_200_OK + assert 'actions' not in response.data + request = factory.options( '/1', HTTP_AUTHORIZATION=self.updateonly_credentials ) + response = instance_view(request, pk='1') + assert response.status_code == status.HTTP_200_OK + assert 'actions' in response.data + assert list(response.data['actions']) == ['PUT'] def test_empty_view_does_not_assert(self): - request = factory.get('/1', HTTP_AUTHORIZATION=self.permitted_credentials) - response = empty_list_view(request, pk=1) - self.assertEqual(response.status_code, status.HTTP_200_OK) + request = factory.get('/1', HTTP_AUTHORIZATION=self.permitted_credentials) + response = empty_list_view(request, pk=1) + assert response.status_code == status.HTTP_200_OK def test_calling_method_not_allowed(self): - request = factory.generic('METHOD_NOT_ALLOWED', '/', HTTP_AUTHORIZATION=self.permitted_credentials) - response = root_view(request) - self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) - - request = factory.generic('METHOD_NOT_ALLOWED', '/1', HTTP_AUTHORIZATION=self.permitted_credentials) - response = instance_view(request, pk='1') - self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) + request = factory.generic('METHOD_NOT_ALLOWED', '/', HTTP_AUTHORIZATION=self.permitted_credentials) + response = root_view(request) + assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED + request = factory.generic('METHOD_NOT_ALLOWED', '/1', HTTP_AUTHORIZATION=self.permitted_credentials) + response = instance_view(request, pk='1') + assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED def test_check_auth_before_queryset_call(self): - class View(RootView): - def get_queryset(_): - self.fail('should not reach due to auth check') + class View(RootView): + def get_queryset(_): + self.fail('should not reach due to auth check') view = View.as_view() - - request = factory.get('/', HTTP_AUTHORIZATION='') - response = view(request) - self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) +request=factory.get('/',HTTP_AUTHORIZATION='') +response=view(request) +assertresponse.status_code==status.HTTP_401_UNAUTHORIZED def test_queryset_assertions(self): - class View(views.APIView): - authentication_classes = [authentication.BasicAuthentication] - permission_classes = [permissions.DjangoModelPermissions] + class View(views.APIView): + authentication_classes = [authentication.BasicAuthentication] + permission_classes = [permissions.DjangoModelPermissions] view = View.as_view() - - request = factory.get('/', HTTP_AUTHORIZATION=self.permitted_credentials) - msg = 'Cannot apply DjangoModelPermissions on a view that does not set `.queryset` or have a `.get_queryset()` method.' - with self.assertRaisesMessage(AssertionError, msg): - view(request) +request=factory.get('/',HTTP_AUTHORIZATION=self.permitted_credentials) +msg='Cannot apply DjangoModelPermissions on a view that does not set `.queryset` or have a `.get_queryset()` method.' +withself.assertRaisesMessage(AssertionError,msg): + view(request) # Faulty `get_queryset()` methods should trigger the above "view does not have a queryset" assertion. class View(RootView): - def get_queryset(self): - return None + def get_queryset(self): + return None view = View.as_view() - - request = factory.get('/', HTTP_AUTHORIZATION=self.permitted_credentials) - with self.assertRaisesMessage(AssertionError, 'View.get_queryset() returned None'): - view(request) +request=factory.get('/',HTTP_AUTHORIZATION=self.permitted_credentials) +withself.assertRaisesMessage(AssertionError,'View.get_queryset() returned None'): + view(request) class BasicPermModel(models.Model): @@ -310,149 +269,117 @@ get_queryset_object_permissions_view = GetQuerysetObjectPermissionInstanceView.a @unittest.skipUnless(is_guardian_installed(), 'django-guardian not installed') -class ObjectPermissionsIntegrationTests(TestCase): - """ +""" Integration tests for the object level permissions API. """ - def setUp(self): - from guardian.shortcuts import assign_perm - - # create users - create = User.objects.create_user - users = { - 'fullaccess': create('fullaccess', 'fullaccess@example.com', 'password'), - 'readonly': create('readonly', 'readonly@example.com', 'password'), - 'writeonly': create('writeonly', 'writeonly@example.com', 'password'), - 'deleteonly': create('deleteonly', 'deleteonly@example.com', 'password'), - } - - # give everyone model level permissions, as we are not testing those - everyone = Group.objects.create(name='everyone') - model_name = BasicPermModel._meta.model_name - app_label = BasicPermModel._meta.app_label - f = '{}_{}'.format - perms = { - 'view': f('view', model_name), - 'change': f('change', model_name), - 'delete': f('delete', model_name) - } - for perm in perms.values(): - perm = '{}.{}'.format(app_label, perm) - assign_perm(perm, everyone) +defsetUp(self): + from guardian.shortcuts import assign_perm + create = User.objects.create_user + users = { 'fullaccess': create('fullaccess', 'fullaccess@example.com', 'password'), 'readonly': create('readonly', 'readonly@example.com', 'password'), 'writeonly': create('writeonly', 'writeonly@example.com', 'password'), 'deleteonly': create('deleteonly', 'deleteonly@example.com', 'password'), } + everyone = Group.objects.create(name='everyone') + model_name = BasicPermModel._meta.model_name + app_label = BasicPermModel._meta.app_label + f = '{}_{}'.format + perms = { 'view': f('view', model_name), 'change': f('change', model_name), 'delete': f('delete', model_name) } + for perm in perms.values(): + perm = '{}.{}'.format(app_label, perm) + assign_perm(perm, everyone) everyone.user_set.add(*users.values()) - - # appropriate object level permissions - readers = Group.objects.create(name='readers') - writers = Group.objects.create(name='writers') - deleters = Group.objects.create(name='deleters') - - model = BasicPermModel.objects.create(text='foo') - - assign_perm(perms['view'], readers, model) - assign_perm(perms['change'], writers, model) - assign_perm(perms['delete'], deleters, model) - - readers.user_set.add(users['fullaccess'], users['readonly']) - writers.user_set.add(users['fullaccess'], users['writeonly']) - deleters.user_set.add(users['fullaccess'], users['deleteonly']) - - self.credentials = {} - for user in users.values(): - self.credentials[user.username] = basic_auth_header(user.username, 'password') +readers=Group.objects.create(name='readers') +writers=Group.objects.create(name='writers') +deleters=Group.objects.create(name='deleters') +model=BasicPermModel.objects.create(text='foo') +assign_perm(perms['view'],readers,model) +assign_perm(perms['change'],writers,model) +assign_perm(perms['delete'],deleters,model) +readers.user_set.add(users['fullaccess'],users['readonly']) +writers.user_set.add(users['fullaccess'],users['writeonly']) +deleters.user_set.add(users['fullaccess'],users['deleteonly']) +self.credentials={} +foruserinusers.values(): + self.credentials[user.username] = basic_auth_header(user.username, 'password') # Delete def test_can_delete_permissions(self): - request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['deleteonly']) - response = object_permissions_view(request, pk='1') - self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['deleteonly']) + response = object_permissions_view(request, pk='1') + assert response.status_code == status.HTTP_204_NO_CONTENT def test_cannot_delete_permissions(self): - request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['readonly']) - response = object_permissions_view(request, pk='1') - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['readonly']) + response = object_permissions_view(request, pk='1') + assert response.status_code == status.HTTP_403_FORBIDDEN # Update def test_can_update_permissions(self): - request = factory.patch( - '/1', {'text': 'foobar'}, format='json', - HTTP_AUTHORIZATION=self.credentials['writeonly'] - ) - response = object_permissions_view(request, pk='1') - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data.get('text'), 'foobar') + request = factory.patch( '/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.credentials['writeonly'] ) + response = object_permissions_view(request, pk='1') + assert response.status_code == status.HTTP_200_OK + assert response.data.get('text') == 'foobar' def test_cannot_update_permissions(self): - request = factory.patch( - '/1', {'text': 'foobar'}, format='json', - HTTP_AUTHORIZATION=self.credentials['deleteonly'] - ) - response = object_permissions_view(request, pk='1') - self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + request = factory.patch( '/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.credentials['deleteonly'] ) + response = object_permissions_view(request, pk='1') + assert response.status_code == status.HTTP_404_NOT_FOUND def test_cannot_update_permissions_non_existing(self): - request = factory.patch( - '/999', {'text': 'foobar'}, format='json', - HTTP_AUTHORIZATION=self.credentials['deleteonly'] - ) - response = object_permissions_view(request, pk='999') - self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + request = factory.patch( '/999', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.credentials['deleteonly'] ) + response = object_permissions_view(request, pk='999') + assert response.status_code == status.HTTP_404_NOT_FOUND # Read def test_can_read_permissions(self): - request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly']) - response = object_permissions_view(request, pk='1') - self.assertEqual(response.status_code, status.HTTP_200_OK) + request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly']) + response = object_permissions_view(request, pk='1') + assert response.status_code == status.HTTP_200_OK def test_cannot_read_permissions(self): - request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['writeonly']) - response = object_permissions_view(request, pk='1') - self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['writeonly']) + response = object_permissions_view(request, pk='1') + assert response.status_code == status.HTTP_404_NOT_FOUND def test_can_read_get_queryset_permissions(self): - """ + """ same as ``test_can_read_permissions`` but with a view that rely on ``.get_queryset()`` instead of ``.queryset``. """ - request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly']) - response = get_queryset_object_permissions_view(request, pk='1') - self.assertEqual(response.status_code, status.HTTP_200_OK) + request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly']) + response = get_queryset_object_permissions_view(request, pk='1') + assert response.status_code == status.HTTP_200_OK # Read list def test_django_object_permissions_filter_deprecated(self): - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - DjangoObjectPermissionsFilter() + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + DjangoObjectPermissionsFilter() - message = ("`DjangoObjectPermissionsFilter` has been deprecated and moved " - "to the 3rd-party django-rest-framework-guardian package.") - self.assertEqual(len(w), 1) - self.assertIs(w[-1].category, RemovedInDRF310Warning) - self.assertEqual(str(w[-1].message), message) + message = ("`DjangoObjectPermissionsFilter` has been deprecated and moved ""to the 3rd-party django-rest-framework-guardian package.") +assertlen(w)==1 +assertw[-1].categoryisRemovedInDRF310Warning +assertstr(w[-1].message)==message def test_can_read_list_permissions(self): - request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['readonly']) - object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,) - # TODO: remove in version 3.10 - with warnings.catch_warnings(record=True): - warnings.simplefilter("always") - response = object_permissions_list_view(request) - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data[0].get('id'), 1) + request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['readonly']) + object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,) + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + response = object_permissions_list_view(request) + assert response.status_code== status.HTTP_200_OK +assertresponse.data[0].get('id')==1 def test_cannot_read_list_permissions(self): - request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['writeonly']) - object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,) - # TODO: remove in version 3.10 - with warnings.catch_warnings(record=True): - warnings.simplefilter("always") - response = object_permissions_list_view(request) - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertListEqual(response.data, []) + request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['writeonly']) + object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,) + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + response = object_permissions_list_view(request) + assert response.status_code== status.HTTP_200_OK +assertresponse.data==[] def test_cannot_method_not_allowed(self): - request = factory.generic('METHOD_NOT_ALLOWED', '/', HTTP_AUTHORIZATION=self.credentials['readonly']) - response = object_permissions_list_view(request) - self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) + request = factory.generic('METHOD_NOT_ALLOWED', '/', HTTP_AUTHORIZATION=self.credentials['readonly']) + response = object_permissions_list_view(request) + assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED class BasicPerm(permissions.BasePermission): @@ -507,203 +434,176 @@ denied_view_with_detail = DeniedViewWithDetail.as_view() denied_object_view = DeniedObjectView.as_view() denied_object_view_with_detail = DeniedObjectViewWithDetail.as_view() - - -class CustomPermissionsTests(TestCase): - def setUp(self): - BasicModel(text='foo').save() - User.objects.create_user('username', 'username@example.com', 'password') - credentials = basic_auth_header('username', 'password') - self.request = factory.get('/1', format='json', HTTP_AUTHORIZATION=credentials) - self.custom_message = 'Custom: You cannot access this resource' +def setUp(self): + BasicModel(text='foo').save() + User.objects.create_user('username', 'username@example.com', 'password') + credentials = basic_auth_header('username', 'password') + self.request = factory.get('/1', format='json', HTTP_AUTHORIZATION=credentials) + self.custom_message = 'Custom: You cannot access this resource' def test_permission_denied(self): - response = denied_view(self.request, pk=1) - detail = response.data.get('detail') - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - self.assertNotEqual(detail, self.custom_message) + response = denied_view(self.request, pk=1) + detail = response.data.get('detail') + assert response.status_code == status.HTTP_403_FORBIDDEN + assert detail != self.custom_message def test_permission_denied_with_custom_detail(self): - response = denied_view_with_detail(self.request, pk=1) - detail = response.data.get('detail') - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - self.assertEqual(detail, self.custom_message) + response = denied_view_with_detail(self.request, pk=1) + detail = response.data.get('detail') + assert response.status_code == status.HTTP_403_FORBIDDEN + assert detail == self.custom_message def test_permission_denied_for_object(self): - response = denied_object_view(self.request, pk=1) - detail = response.data.get('detail') - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - self.assertNotEqual(detail, self.custom_message) + response = denied_object_view(self.request, pk=1) + detail = response.data.get('detail') + assert response.status_code == status.HTTP_403_FORBIDDEN + assert detail != self.custom_message def test_permission_denied_for_object_with_custom_detail(self): - response = denied_object_view_with_detail(self.request, pk=1) - detail = response.data.get('detail') - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - self.assertEqual(detail, self.custom_message) + response = denied_object_view_with_detail(self.request, pk=1) + detail = response.data.get('detail') + assert response.status_code == status.HTTP_403_FORBIDDEN + assert detail == self.custom_message -class PermissionsCompositionTests(TestCase): - def setUp(self): - self.username = 'john' - self.email = 'lennon@thebeatles.com' - self.password = 'password' - self.user = User.objects.create_user( - self.username, - self.email, - self.password - ) - self.client.login(username=self.username, password=self.password) +def setUp(self): + self.username = 'john' + self.email = 'lennon@thebeatles.com' + self.password = 'password' + self.user = User.objects.create_user( self.username, self.email, self.password ) + self.client.login(username=self.username, password=self.password) def test_and_false(self): - request = factory.get('/1', format='json') - request.user = AnonymousUser() - composed_perm = permissions.IsAuthenticated & permissions.AllowAny - assert composed_perm().has_permission(request, None) is False + request = factory.get('/1', format='json') + request.user = AnonymousUser() + composed_perm = permissions.IsAuthenticated & permissions.AllowAny + assert composed_perm().has_permission(request, None) is False def test_and_true(self): - request = factory.get('/1', format='json') - request.user = self.user - composed_perm = permissions.IsAuthenticated & permissions.AllowAny - assert composed_perm().has_permission(request, None) is True + request = factory.get('/1', format='json') + request.user = self.user + composed_perm = permissions.IsAuthenticated & permissions.AllowAny + assert composed_perm().has_permission(request, None) is True def test_or_false(self): - request = factory.get('/1', format='json') - request.user = AnonymousUser() - composed_perm = permissions.IsAuthenticated | permissions.AllowAny - assert composed_perm().has_permission(request, None) is True + request = factory.get('/1', format='json') + request.user = AnonymousUser() + composed_perm = permissions.IsAuthenticated | permissions.AllowAny + assert composed_perm().has_permission(request, None) is True def test_or_true(self): - request = factory.get('/1', format='json') - request.user = self.user - composed_perm = permissions.IsAuthenticated | permissions.AllowAny - assert composed_perm().has_permission(request, None) is True + request = factory.get('/1', format='json') + request.user = self.user + composed_perm = permissions.IsAuthenticated | permissions.AllowAny + assert composed_perm().has_permission(request, None) is True def test_not_false(self): - request = factory.get('/1', format='json') - request.user = AnonymousUser() - composed_perm = ~permissions.IsAuthenticated - assert composed_perm().has_permission(request, None) is True + request = factory.get('/1', format='json') + request.user = AnonymousUser() + composed_perm = ~permissions.IsAuthenticated + assert composed_perm().has_permission(request, None) is True def test_not_true(self): - request = factory.get('/1', format='json') - request.user = self.user - composed_perm = ~permissions.AllowAny - assert composed_perm().has_permission(request, None) is False + request = factory.get('/1', format='json') + request.user = self.user + composed_perm = ~permissions.AllowAny + assert composed_perm().has_permission(request, None) is False def test_several_levels_without_negation(self): - request = factory.get('/1', format='json') - request.user = self.user - composed_perm = ( - permissions.IsAuthenticated & - permissions.IsAuthenticated & - permissions.IsAuthenticated & - permissions.IsAuthenticated - ) - assert composed_perm().has_permission(request, None) is True + request = factory.get('/1', format='json') + request.user = self.user + composed_perm = ( permissions.IsAuthenticated & permissions.IsAuthenticated & permissions.IsAuthenticated & permissions.IsAuthenticated ) + assert composed_perm().has_permission(request, None) is True def test_several_levels_and_precedence_with_negation(self): - request = factory.get('/1', format='json') - request.user = self.user - composed_perm = ( - permissions.IsAuthenticated & - ~ permissions.IsAdminUser & - permissions.IsAuthenticated & - ~(permissions.IsAdminUser & permissions.IsAdminUser) - ) - assert composed_perm().has_permission(request, None) is True + request = factory.get('/1', format='json') + request.user = self.user + composed_perm = ( permissions.IsAuthenticated & ~ permissions.IsAdminUser & permissions.IsAuthenticated & ~(permissions.IsAdminUser & permissions.IsAdminUser) ) + assert composed_perm().has_permission(request, None) is True def test_several_levels_and_precedence(self): - request = factory.get('/1', format='json') - request.user = self.user - composed_perm = ( - permissions.IsAuthenticated & - permissions.IsAuthenticated | - permissions.IsAuthenticated & - permissions.IsAuthenticated - ) - assert composed_perm().has_permission(request, None) is True + request = factory.get('/1', format='json') + request.user = self.user + composed_perm = ( permissions.IsAuthenticated & permissions.IsAuthenticated | permissions.IsAuthenticated & permissions.IsAuthenticated ) + assert composed_perm().has_permission(request, None) is True @pytest.mark.skipif(not PY36, reason="assert_called_once() not available") - def test_or_lazyness(self): - request = factory.get('/1', format='json') - request.user = AnonymousUser() +deftest_or_lazyness(self): + request = factory.get('/1', format='json') + request.user = AnonymousUser() + with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow: + with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny: + composed_perm = (permissions.AllowAny | permissions.IsAuthenticated) + hasperm = composed_perm().has_permission(request, None) + assert hasperm is True + mock_allow.assert_called_once() + mock_deny.assert_not_called() with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow: - with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny: - composed_perm = (permissions.AllowAny | permissions.IsAuthenticated) - hasperm = composed_perm().has_permission(request, None) - self.assertIs(hasperm, True) - mock_allow.assert_called_once() - mock_deny.assert_not_called() - - with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow: - with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny: - composed_perm = (permissions.IsAuthenticated | permissions.AllowAny) - hasperm = composed_perm().has_permission(request, None) - self.assertIs(hasperm, True) - mock_deny.assert_called_once() - mock_allow.assert_called_once() + with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny: + composed_perm = (permissions.IsAuthenticated | permissions.AllowAny) + hasperm = composed_perm().has_permission(request, None) + assert hasperm is True + mock_deny.assert_called_once() + mock_allow.assert_called_once() @pytest.mark.skipif(not PY36, reason="assert_called_once() not available") - def test_object_or_lazyness(self): - request = factory.get('/1', format='json') - request.user = AnonymousUser() +deftest_object_or_lazyness(self): + request = factory.get('/1', format='json') + request.user = AnonymousUser() + with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow: + with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny: + composed_perm = (permissions.AllowAny | permissions.IsAuthenticated) + hasperm = composed_perm().has_object_permission(request, None, None) + assert hasperm is True + mock_allow.assert_called_once() + mock_deny.assert_not_called() with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow: - with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny: - composed_perm = (permissions.AllowAny | permissions.IsAuthenticated) - hasperm = composed_perm().has_object_permission(request, None, None) - self.assertIs(hasperm, True) - mock_allow.assert_called_once() - mock_deny.assert_not_called() - - with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow: - with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny: - composed_perm = (permissions.IsAuthenticated | permissions.AllowAny) - hasperm = composed_perm().has_object_permission(request, None, None) - self.assertIs(hasperm, True) - mock_deny.assert_called_once() - mock_allow.assert_called_once() + with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny: + composed_perm = (permissions.IsAuthenticated | permissions.AllowAny) + hasperm = composed_perm().has_object_permission(request, None, None) + assert hasperm is True + mock_deny.assert_called_once() + mock_allow.assert_called_once() @pytest.mark.skipif(not PY36, reason="assert_called_once() not available") - def test_and_lazyness(self): - request = factory.get('/1', format='json') - request.user = AnonymousUser() +deftest_and_lazyness(self): + request = factory.get('/1', format='json') + request.user = AnonymousUser() + with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow: + with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny: + composed_perm = (permissions.AllowAny & permissions.IsAuthenticated) + hasperm = composed_perm().has_permission(request, None) + assert hasperm is False + mock_allow.assert_called_once() + mock_deny.assert_called_once() with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow: - with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny: - composed_perm = (permissions.AllowAny & permissions.IsAuthenticated) - hasperm = composed_perm().has_permission(request, None) - self.assertIs(hasperm, False) - mock_allow.assert_called_once() - mock_deny.assert_called_once() - - with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow: - with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny: - composed_perm = (permissions.IsAuthenticated & permissions.AllowAny) - hasperm = composed_perm().has_permission(request, None) - self.assertIs(hasperm, False) - mock_allow.assert_not_called() - mock_deny.assert_called_once() + with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny: + composed_perm = (permissions.IsAuthenticated & permissions.AllowAny) + hasperm = composed_perm().has_permission(request, None) + assert hasperm is False + mock_allow.assert_not_called() + mock_deny.assert_called_once() @pytest.mark.skipif(not PY36, reason="assert_called_once() not available") - def test_object_and_lazyness(self): - request = factory.get('/1', format='json') - request.user = AnonymousUser() +deftest_object_and_lazyness(self): + request = factory.get('/1', format='json') + request.user = AnonymousUser() + with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow: + with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny: + composed_perm = (permissions.AllowAny & permissions.IsAuthenticated) + hasperm = composed_perm().has_object_permission(request, None, None) + assert hasperm is False + mock_allow.assert_called_once() + mock_deny.assert_called_once() with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow: - with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny: - composed_perm = (permissions.AllowAny & permissions.IsAuthenticated) - hasperm = composed_perm().has_object_permission(request, None, None) - self.assertIs(hasperm, False) - mock_allow.assert_called_once() - mock_deny.assert_called_once() - - with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow: - with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny: - composed_perm = (permissions.IsAuthenticated & permissions.AllowAny) - hasperm = composed_perm().has_object_permission(request, None, None) - self.assertIs(hasperm, False) - mock_allow.assert_not_called() - mock_deny.assert_called_once() + with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny: + composed_perm = (permissions.IsAuthenticated & permissions.AllowAny) + hasperm = composed_perm().has_object_permission(request, None, None) + assert hasperm is False + mock_allow.assert_not_called() + mock_deny.assert_called_once() diff --git a/tests/test_prefetch_related.py b/tests/test_prefetch_related.py index b07087c97..e4162d9aa 100644 --- a/tests/test_prefetch_related.py +++ b/tests/test_prefetch_related.py @@ -18,41 +18,30 @@ class UserUpdate(generics.UpdateAPIView): serializer_class = UserSerializer -class TestPrefetchRelatedUpdates(TestCase): - def setUp(self): - self.user = User.objects.create(username='tom', email='tom@example.com') - self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')] - self.user.groups.set(self.groups) +def setUp(self): + self.user = User.objects.create(username='tom', email='tom@example.com') + self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')] + self.user.groups.set(self.groups) def test_prefetch_related_updates(self): - view = UserUpdate.as_view() - pk = self.user.pk - groups_pk = self.groups[0].pk - request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json') - response = view(request, pk=pk) - assert User.objects.get(pk=pk).groups.count() == 1 - expected = { - 'id': pk, - 'username': 'new', - 'groups': [1], - 'email': 'tom@example.com' - } - assert response.data == expected + view = UserUpdate.as_view() + pk = self.user.pk + groups_pk = self.groups[0].pk + request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json') + response = view(request, pk=pk) + assert User.objects.get(pk=pk).groups.count() == 1 + expected = { 'id': pk, 'username': 'new', 'groups': [1], 'email': 'tom@example.com' } + assert response.data == expected def test_prefetch_related_excluding_instance_from_original_queryset(self): - """ + """ Regression test for https://github.com/encode/django-rest-framework/issues/4661 """ - view = UserUpdate.as_view() - pk = self.user.pk - groups_pk = self.groups[0].pk - request = factory.put('/', {'username': 'exclude', 'groups': [groups_pk]}, format='json') - response = view(request, pk=pk) - assert User.objects.get(pk=pk).groups.count() == 1 - expected = { - 'id': pk, - 'username': 'exclude', - 'groups': [1], - 'email': 'tom@example.com' - } - assert response.data == expected + view = UserUpdate.as_view() + pk = self.user.pk + groups_pk = self.groups[0].pk + request = factory.put('/', {'username': 'exclude', 'groups': [groups_pk]}, format='json') + response = view(request, pk=pk) + assert User.objects.get(pk=pk).groups.count() == 1 + expected = { 'id': pk, 'username': 'exclude', 'groups': [1], 'email': 'tom@example.com' } + assert response.data == expected diff --git a/tests/test_relations_hyperlink.py b/tests/test_relations_hyperlink.py index 5ad0e31ff..735e432d1 100644 --- a/tests/test_relations_hyperlink.py +++ b/tests/test_relations_hyperlink.py @@ -70,380 +70,268 @@ class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer): # TODO: Add test that .data cannot be accessed prior to .is_valid @override_settings(ROOT_URLCONF='tests.test_relations_hyperlink') -class HyperlinkedManyToManyTests(TestCase): - def setUp(self): - for idx in range(1, 4): - target = ManyToManyTarget(name='target-%d' % idx) - target.save() - source = ManyToManySource(name='source-%d' % idx) - source.save() - for target in ManyToManyTarget.objects.all(): - source.targets.add(target) +def setUp(self): + for idx in range(1, 4): + target = ManyToManyTarget(name='target-%d' % idx) + target.save() + source = ManyToManySource(name='source-%d' % idx) + source.save() + for target in ManyToManyTarget.objects.all(): + source.targets.add(target) def test_relative_hyperlinks(self): - queryset = ManyToManySource.objects.all() - serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': None}) - expected = [ - {'url': '/manytomanysource/1/', 'name': 'source-1', 'targets': ['/manytomanytarget/1/']}, - {'url': '/manytomanysource/2/', 'name': 'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']}, - {'url': '/manytomanysource/3/', 'name': 'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']} - ] - with self.assertNumQueries(4): - assert serializer.data == expected + queryset = ManyToManySource.objects.all() + serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': None}) + expected = [ {'url': '/manytomanysource/1/', 'name': 'source-1', 'targets': ['/manytomanytarget/1/']}, {'url': '/manytomanysource/2/', 'name': 'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']}, {'url': '/manytomanysource/3/', 'name': 'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']} ] + with self.assertNumQueries(4): + assert serializer.data == expected def test_many_to_many_retrieve(self): - queryset = ManyToManySource.objects.all() - serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) - expected = [ - {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']}, - {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, - {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} - ] - with self.assertNumQueries(4): - assert serializer.data == expected + queryset = ManyToManySource.objects.all() + serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) + expected = [ {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']}, {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} ] + with self.assertNumQueries(4): + assert serializer.data == expected def test_many_to_many_retrieve_prefetch_related(self): - queryset = ManyToManySource.objects.all().prefetch_related('targets') - serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) - with self.assertNumQueries(2): - serializer.data + queryset = ManyToManySource.objects.all().prefetch_related('targets') + serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) + with self.assertNumQueries(2): + serializer.data def test_reverse_many_to_many_retrieve(self): - queryset = ManyToManyTarget.objects.all() - serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) - expected = [ - {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, - {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, - {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']} - ] - with self.assertNumQueries(4): - assert serializer.data == expected + queryset = ManyToManyTarget.objects.all() + serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) + expected = [ {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']} ] + with self.assertNumQueries(4): + assert serializer.data == expected def test_many_to_many_update(self): - data = {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} - instance = ManyToManySource.objects.get(pk=1) - serializer = ManyToManySourceSerializer(instance, data=data, context={'request': request}) - assert serializer.is_valid() - serializer.save() - assert serializer.data == data - - # Ensure source 1 is updated, and everything else is as expected - queryset = ManyToManySource.objects.all() - serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) - expected = [ - {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}, - {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, - {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} - ] - assert serializer.data == expected + data = {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} + instance = ManyToManySource.objects.get(pk=1) + serializer = ManyToManySourceSerializer(instance, data=data, context={'request': request}) + assert serializer.is_valid() + serializer.save() + assert serializer.data == data + queryset = ManyToManySource.objects.all() + serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) + expected = [ {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}, {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} ] + assert serializer.data == expected def test_reverse_many_to_many_update(self): - data = {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']} - instance = ManyToManyTarget.objects.get(pk=1) - serializer = ManyToManyTargetSerializer(instance, data=data, context={'request': request}) - assert serializer.is_valid() - serializer.save() - assert serializer.data == data - # Ensure target 1 is updated, and everything else is as expected - queryset = ManyToManyTarget.objects.all() - serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) - expected = [ - {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']}, - {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, - {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']} - - ] - assert serializer.data == expected + data = {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']} + instance = ManyToManyTarget.objects.get(pk=1) + serializer = ManyToManyTargetSerializer(instance, data=data, context={'request': request}) + assert serializer.is_valid() + serializer.save() + assert serializer.data == data + queryset = ManyToManyTarget.objects.all() + serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) + expected = [ {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']}, {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']} ] + assert serializer.data == expected def test_many_to_many_create(self): - data = {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']} - serializer = ManyToManySourceSerializer(data=data, context={'request': request}) - assert serializer.is_valid() - obj = serializer.save() - assert serializer.data == data - assert obj.name == 'source-4' - - # Ensure source 4 is added, and everything else is as expected - queryset = ManyToManySource.objects.all() - serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) - expected = [ - {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']}, - {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, - {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}, - {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']} - ] - assert serializer.data == expected + data = {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']} + serializer = ManyToManySourceSerializer(data=data, context={'request': request}) + assert serializer.is_valid() + obj = serializer.save() + assert serializer.data == data + assert obj.name == 'source-4' + queryset = ManyToManySource.objects.all() + serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) + expected = [ {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']}, {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}, {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']} ] + assert serializer.data == expected def test_reverse_many_to_many_create(self): - data = {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']} - serializer = ManyToManyTargetSerializer(data=data, context={'request': request}) - assert serializer.is_valid() - obj = serializer.save() - assert serializer.data == data - assert obj.name == 'target-4' - - # Ensure target 4 is added, and everything else is as expected - queryset = ManyToManyTarget.objects.all() - serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) - expected = [ - {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, - {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, - {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}, - {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']} - ] - assert serializer.data == expected + data = {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']} + serializer = ManyToManyTargetSerializer(data=data, context={'request': request}) + assert serializer.is_valid() + obj = serializer.save() + assert serializer.data == data + assert obj.name == 'target-4' + queryset = ManyToManyTarget.objects.all() + serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) + expected = [ {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']} ] + assert serializer.data == expected @override_settings(ROOT_URLCONF='tests.test_relations_hyperlink') -class HyperlinkedForeignKeyTests(TestCase): - def setUp(self): - target = ForeignKeyTarget(name='target-1') - target.save() - new_target = ForeignKeyTarget(name='target-2') - new_target.save() - for idx in range(1, 4): - source = ForeignKeySource(name='source-%d' % idx, target=target) - source.save() - - def test_foreign_key_retrieve(self): - queryset = ForeignKeySource.objects.all() - serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request}) - expected = [ - {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, - {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, - {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'} - ] - with self.assertNumQueries(1): - assert serializer.data == expected - - def test_reverse_foreign_key_retrieve(self): - queryset = ForeignKeyTarget.objects.all() - serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) - expected = [ - {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']}, - {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, - ] - with self.assertNumQueries(3): - assert serializer.data == expected - - def test_foreign_key_update(self): - data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'} - instance = ForeignKeySource.objects.get(pk=1) - serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request}) - assert serializer.is_valid() - serializer.save() - assert serializer.data == data - - # Ensure source 1 is updated, and everything else is as expected - queryset = ForeignKeySource.objects.all() - serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request}) - expected = [ - {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'}, - {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, - {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'} - ] - assert serializer.data == expected - - def test_foreign_key_update_incorrect_type(self): - data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 2} - instance = ForeignKeySource.objects.get(pk=1) - serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request}) - assert not serializer.is_valid() - assert serializer.errors == {'target': ['Incorrect type. Expected URL string, received int.']} - - def test_reverse_foreign_key_update(self): - data = {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']} - instance = ForeignKeyTarget.objects.get(pk=2) - serializer = ForeignKeyTargetSerializer(instance, data=data, context={'request': request}) - assert serializer.is_valid() - # We shouldn't have saved anything to the db yet since save - # hasn't been called. - queryset = ForeignKeyTarget.objects.all() - new_serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) - expected = [ - {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']}, - {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, - ] - assert new_serializer.data == expected - - serializer.save() - assert serializer.data == data - - # Ensure target 2 is update, and everything else is as expected - queryset = ForeignKeyTarget.objects.all() - serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) - expected = [ - {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']}, - {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}, - ] - assert serializer.data == expected - - def test_foreign_key_create(self): - data = {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'} - serializer = ForeignKeySourceSerializer(data=data, context={'request': request}) - assert serializer.is_valid() - obj = serializer.save() - assert serializer.data == data - assert obj.name == 'source-4' - - # Ensure source 1 is updated, and everything else is as expected - queryset = ForeignKeySource.objects.all() - serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request}) - expected = [ - {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, - {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, - {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}, - {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'}, - ] - assert serializer.data == expected - - def test_reverse_foreign_key_create(self): - data = {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']} - serializer = ForeignKeyTargetSerializer(data=data, context={'request': request}) - assert serializer.is_valid() - obj = serializer.save() - assert serializer.data == data - assert obj.name == 'target-3' - - # Ensure target 4 is added, and everything else is as expected - queryset = ForeignKeyTarget.objects.all() - serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) - expected = [ - {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']}, - {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, - {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}, - ] - assert serializer.data == expected - - def test_foreign_key_update_with_invalid_null(self): - data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': None} - instance = ForeignKeySource.objects.get(pk=1) - serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request}) - assert not serializer.is_valid() - assert serializer.errors == {'target': ['This field may not be null.']} - - -@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink') -class HyperlinkedNullableForeignKeyTests(TestCase): - def setUp(self): - target = ForeignKeyTarget(name='target-1') - target.save() - for idx in range(1, 4): - if idx == 3: - target = None - source = NullableForeignKeySource(name='source-%d' % idx, target=target) - source.save() - - def test_foreign_key_retrieve_with_null(self): - queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) - expected = [ - {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, - {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, - {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, - ] - assert serializer.data == expected - - def test_foreign_key_create_with_valid_null(self): - data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} - serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request}) - assert serializer.is_valid() - obj = serializer.save() - assert serializer.data == data - assert obj.name == 'source-4' - - # Ensure source 4 is created, and everything else is as expected - queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) - expected = [ - {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, - {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, - {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, - {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} - ] - assert serializer.data == expected - - def test_foreign_key_create_with_valid_emptystring(self): - """ - The emptystring should be interpreted as null in the context - of relationships. - """ - data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': ''} - expected_data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} - serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request}) - assert serializer.is_valid() - obj = serializer.save() - assert serializer.data == expected_data - assert obj.name == 'source-4' - - # Ensure source 4 is created, and everything else is as expected - queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) - expected = [ - {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, - {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, - {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, - {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} - ] - assert serializer.data == expected - - def test_foreign_key_update_with_valid_null(self): - data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None} - instance = NullableForeignKeySource.objects.get(pk=1) - serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request}) - assert serializer.is_valid() - serializer.save() - assert serializer.data == data - - # Ensure source 1 is updated, and everything else is as expected - queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) - expected = [ - {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}, - {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, - {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, - ] - assert serializer.data == expected - - def test_foreign_key_update_with_valid_emptystring(self): - """ - The emptystring should be interpreted as null in the context - of relationships. - """ - data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': ''} - expected_data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None} - instance = NullableForeignKeySource.objects.get(pk=1) - serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request}) - assert serializer.is_valid() - serializer.save() - assert serializer.data == expected_data - - # Ensure source 1 is updated, and everything else is as expected - queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) - expected = [ - {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}, - {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, - {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, - ] - assert serializer.data == expected - - -@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink') -class HyperlinkedNullableOneToOneTests(TestCase): - def setUp(self): - target = OneToOneTarget(name='target-1') - target.save() - new_target = OneToOneTarget(name='target-2') - new_target.save() - source = NullableOneToOneSource(name='source-1', target=target) +def setUp(self): + target = ForeignKeyTarget(name='target-1') + target.save() + new_target = ForeignKeyTarget(name='target-2') + new_target.save() + for idx in range(1, 4): + source = ForeignKeySource(name='source-%d' % idx, target=target) source.save() - def test_reverse_foreign_key_retrieve_with_null(self): - queryset = OneToOneTarget.objects.all() - serializer = NullableOneToOneTargetSerializer(queryset, many=True, context={'request': request}) - expected = [ - {'url': 'http://testserver/onetoonetarget/1/', 'name': 'target-1', 'nullable_source': 'http://testserver/nullableonetoonesource/1/'}, - {'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None}, - ] + def test_foreign_key_retrieve(self): + queryset = ForeignKeySource.objects.all() + serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request}) + expected = [ {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'} ] + with self.assertNumQueries(1): assert serializer.data == expected + + def test_reverse_foreign_key_retrieve(self): + queryset = ForeignKeyTarget.objects.all() + serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) + expected = [ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']}, {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, ] + with self.assertNumQueries(3): + assert serializer.data == expected + + def test_foreign_key_update(self): + data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'} + instance = ForeignKeySource.objects.get(pk=1) + serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request}) + assert serializer.is_valid() + serializer.save() + assert serializer.data == data + queryset = ForeignKeySource.objects.all() + serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request}) + expected = [ {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'}, {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'} ] + assert serializer.data == expected + + def test_foreign_key_update_incorrect_type(self): + data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 2} + instance = ForeignKeySource.objects.get(pk=1) + serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request}) + assert not serializer.is_valid() + assert serializer.errors == {'target': ['Incorrect type. Expected URL string, received int.']} + + def test_reverse_foreign_key_update(self): + data = {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']} + instance = ForeignKeyTarget.objects.get(pk=2) + serializer = ForeignKeyTargetSerializer(instance, data=data, context={'request': request}) + assert serializer.is_valid() + queryset = ForeignKeyTarget.objects.all() + new_serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) + expected = [ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']}, {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, ] + assert new_serializer.data == expected + serializer.save() + assert serializer.data == data + queryset = ForeignKeyTarget.objects.all() + serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) + expected = [ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']}, {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}, ] + assert serializer.data == expected + + def test_foreign_key_create(self): + data = {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'} + serializer = ForeignKeySourceSerializer(data=data, context={'request': request}) + assert serializer.is_valid() + obj = serializer.save() + assert serializer.data == data + assert obj.name == 'source-4' + queryset = ForeignKeySource.objects.all() + serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request}) + expected = [ {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'}, ] + assert serializer.data == expected + + def test_reverse_foreign_key_create(self): + data = {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']} + serializer = ForeignKeyTargetSerializer(data=data, context={'request': request}) + assert serializer.is_valid() + obj = serializer.save() + assert serializer.data == data + assert obj.name == 'target-3' + queryset = ForeignKeyTarget.objects.all() + serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) + expected = [ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']}, {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}, ] + assert serializer.data == expected + + def test_foreign_key_update_with_invalid_null(self): + data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': None} + instance = ForeignKeySource.objects.get(pk=1) + serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request}) + assert not serializer.is_valid() + assert serializer.errors == {'target': ['This field may not be null.']} + + +@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink') +def setUp(self): + target = ForeignKeyTarget(name='target-1') + target.save() + for idx in range(1, 4): + if idx == 3: + target = None + source = NullableForeignKeySource(name='source-%d' % idx, target=target) +source.save() + + def test_foreign_key_retrieve_with_null(self): + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) + expected = [ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, ] + assert serializer.data == expected + + def test_foreign_key_create_with_valid_null(self): + data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} + serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request}) + assert serializer.is_valid() + obj = serializer.save() + assert serializer.data == data + assert obj.name == 'source-4' + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) + expected = [ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} ] + assert serializer.data == expected + + def test_foreign_key_create_with_valid_emptystring(self): + """ + The emptystring should be interpreted as null in the context + of relationships. + """ + data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': ''} + expected_data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} + serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request}) + assert serializer.is_valid() + obj = serializer.save() + assert serializer.data == expected_data + assert obj.name == 'source-4' + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) + expected = [ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} ] + assert serializer.data == expected + + def test_foreign_key_update_with_valid_null(self): + data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None} + instance = NullableForeignKeySource.objects.get(pk=1) + serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request}) + assert serializer.is_valid() + serializer.save() + assert serializer.data == data + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) + expected = [ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}, {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, ] + assert serializer.data == expected + + def test_foreign_key_update_with_valid_emptystring(self): + """ + The emptystring should be interpreted as null in the context + of relationships. + """ + data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': ''} + expected_data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None} + instance = NullableForeignKeySource.objects.get(pk=1) + serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request}) + assert serializer.is_valid() + serializer.save() + assert serializer.data == expected_data + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) + expected = [ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}, {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, ] + assert serializer.data == expected + + +@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink') +def setUp(self): + target = OneToOneTarget(name='target-1') + target.save() + new_target = OneToOneTarget(name='target-2') + new_target.save() + source = NullableOneToOneSource(name='source-1', target=target) + source.save() + + def test_reverse_foreign_key_retrieve_with_null(self): + queryset = OneToOneTarget.objects.all() + serializer = NullableOneToOneTargetSerializer(queryset, many=True, context={'request': request}) + expected = [ {'url': 'http://testserver/onetoonetarget/1/', 'name': 'target-1', 'nullable_source': 'http://testserver/nullableonetoonesource/1/'}, {'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None}, ] + assert serializer.data == expected diff --git a/tests/test_relations_pk.py b/tests/test_relations_pk.py index 0da9da890..517b8bbb2 100644 --- a/tests/test_relations_pk.py +++ b/tests/test_relations_pk.py @@ -77,496 +77,373 @@ class OneToOnePKSourceSerializer(serializers.ModelSerializer): # TODO: Add test that .data cannot be accessed prior to .is_valid -class PKManyToManyTests(TestCase): - def setUp(self): - for idx in range(1, 4): - target = ManyToManyTarget(name='target-%d' % idx) - target.save() - source = ManyToManySource(name='source-%d' % idx) - source.save() - for target in ManyToManyTarget.objects.all(): - source.targets.add(target) +def setUp(self): + for idx in range(1, 4): + target = ManyToManyTarget(name='target-%d' % idx) + target.save() + source = ManyToManySource(name='source-%d' % idx) + source.save() + for target in ManyToManyTarget.objects.all(): + source.targets.add(target) def test_many_to_many_retrieve(self): - queryset = ManyToManySource.objects.all() - serializer = ManyToManySourceSerializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'source-1', 'targets': [1]}, - {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, - {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]} - ] - with self.assertNumQueries(4): - assert serializer.data == expected + queryset = ManyToManySource.objects.all() + serializer = ManyToManySourceSerializer(queryset, many=True) + expected = [ {'id': 1, 'name': 'source-1', 'targets': [1]}, {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]} ] + with self.assertNumQueries(4): + assert serializer.data == expected def test_many_to_many_retrieve_prefetch_related(self): - queryset = ManyToManySource.objects.all().prefetch_related('targets') - serializer = ManyToManySourceSerializer(queryset, many=True) - with self.assertNumQueries(2): - serializer.data + queryset = ManyToManySource.objects.all().prefetch_related('targets') + serializer = ManyToManySourceSerializer(queryset, many=True) + with self.assertNumQueries(2): + serializer.data def test_reverse_many_to_many_retrieve(self): - queryset = ManyToManyTarget.objects.all() - serializer = ManyToManyTargetSerializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, - {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, - {'id': 3, 'name': 'target-3', 'sources': [3]} - ] - with self.assertNumQueries(4): - assert serializer.data == expected + queryset = ManyToManyTarget.objects.all() + serializer = ManyToManyTargetSerializer(queryset, many=True) + expected = [ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, {'id': 3, 'name': 'target-3', 'sources': [3]} ] + with self.assertNumQueries(4): + assert serializer.data == expected def test_many_to_many_update(self): - data = {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]} - instance = ManyToManySource.objects.get(pk=1) - serializer = ManyToManySourceSerializer(instance, data=data) - assert serializer.is_valid() - serializer.save() - assert serializer.data == data - - # Ensure source 1 is updated, and everything else is as expected - queryset = ManyToManySource.objects.all() - serializer = ManyToManySourceSerializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]}, - {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, - {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]} - ] - assert serializer.data == expected + data = {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]} + instance = ManyToManySource.objects.get(pk=1) + serializer = ManyToManySourceSerializer(instance, data=data) + assert serializer.is_valid() + serializer.save() + assert serializer.data == data + queryset = ManyToManySource.objects.all() + serializer = ManyToManySourceSerializer(queryset, many=True) + expected = [ {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]}, {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]} ] + assert serializer.data == expected def test_reverse_many_to_many_update(self): - data = {'id': 1, 'name': 'target-1', 'sources': [1]} - instance = ManyToManyTarget.objects.get(pk=1) - serializer = ManyToManyTargetSerializer(instance, data=data) - assert serializer.is_valid() - serializer.save() - assert serializer.data == data - - # Ensure target 1 is updated, and everything else is as expected - queryset = ManyToManyTarget.objects.all() - serializer = ManyToManyTargetSerializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'target-1', 'sources': [1]}, - {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, - {'id': 3, 'name': 'target-3', 'sources': [3]} - ] - assert serializer.data == expected + data = {'id': 1, 'name': 'target-1', 'sources': [1]} + instance = ManyToManyTarget.objects.get(pk=1) + serializer = ManyToManyTargetSerializer(instance, data=data) + assert serializer.is_valid() + serializer.save() + assert serializer.data == data + queryset = ManyToManyTarget.objects.all() + serializer = ManyToManyTargetSerializer(queryset, many=True) + expected = [ {'id': 1, 'name': 'target-1', 'sources': [1]}, {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, {'id': 3, 'name': 'target-3', 'sources': [3]} ] + assert serializer.data == expected def test_many_to_many_create(self): - data = {'id': 4, 'name': 'source-4', 'targets': [1, 3]} - serializer = ManyToManySourceSerializer(data=data) - assert serializer.is_valid() - obj = serializer.save() - assert serializer.data == data - assert obj.name == 'source-4' - - # Ensure source 4 is added, and everything else is as expected - queryset = ManyToManySource.objects.all() - serializer = ManyToManySourceSerializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'source-1', 'targets': [1]}, - {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, - {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}, - {'id': 4, 'name': 'source-4', 'targets': [1, 3]}, - ] - assert serializer.data == expected + data = {'id': 4, 'name': 'source-4', 'targets': [1, 3]} + serializer = ManyToManySourceSerializer(data=data) + assert serializer.is_valid() + obj = serializer.save() + assert serializer.data == data + assert obj.name == 'source-4' + queryset = ManyToManySource.objects.all() + serializer = ManyToManySourceSerializer(queryset, many=True) + expected = [ {'id': 1, 'name': 'source-1', 'targets': [1]}, {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}, {'id': 4, 'name': 'source-4', 'targets': [1, 3]}, ] + assert serializer.data == expected def test_many_to_many_unsaved(self): - source = ManyToManySource(name='source-unsaved') - - serializer = ManyToManySourceSerializer(source) - - expected = {'id': None, 'name': 'source-unsaved', 'targets': []} - # no query if source hasn't been created yet - with self.assertNumQueries(0): - assert serializer.data == expected + source = ManyToManySource(name='source-unsaved') + serializer = ManyToManySourceSerializer(source) + expected = {'id': None, 'name': 'source-unsaved', 'targets': []} + with self.assertNumQueries(0): + assert serializer.data == expected def test_reverse_many_to_many_create(self): - data = {'id': 4, 'name': 'target-4', 'sources': [1, 3]} - serializer = ManyToManyTargetSerializer(data=data) - assert serializer.is_valid() - obj = serializer.save() - assert serializer.data == data - assert obj.name == 'target-4' - - # Ensure target 4 is added, and everything else is as expected - queryset = ManyToManyTarget.objects.all() - serializer = ManyToManyTargetSerializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, - {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, - {'id': 3, 'name': 'target-3', 'sources': [3]}, - {'id': 4, 'name': 'target-4', 'sources': [1, 3]} - ] - assert serializer.data == expected + data = {'id': 4, 'name': 'target-4', 'sources': [1, 3]} + serializer = ManyToManyTargetSerializer(data=data) + assert serializer.is_valid() + obj = serializer.save() + assert serializer.data == data + assert obj.name == 'target-4' + queryset = ManyToManyTarget.objects.all() + serializer = ManyToManyTargetSerializer(queryset, many=True) + expected = [ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, {'id': 3, 'name': 'target-3', 'sources': [3]}, {'id': 4, 'name': 'target-4', 'sources': [1, 3]} ] + assert serializer.data == expected -class PKForeignKeyTests(TestCase): - def setUp(self): - target = ForeignKeyTarget(name='target-1') - target.save() - new_target = ForeignKeyTarget(name='target-2') - new_target.save() - for idx in range(1, 4): - source = ForeignKeySource(name='source-%d' % idx, target=target) - source.save() +def setUp(self): + target = ForeignKeyTarget(name='target-1') + target.save() + new_target = ForeignKeyTarget(name='target-2') + new_target.save() + for idx in range(1, 4): + source = ForeignKeySource(name='source-%d' % idx, target=target) + source.save() def test_foreign_key_retrieve(self): - queryset = ForeignKeySource.objects.all() - serializer = ForeignKeySourceSerializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'source-1', 'target': 1}, - {'id': 2, 'name': 'source-2', 'target': 1}, - {'id': 3, 'name': 'source-3', 'target': 1} - ] - with self.assertNumQueries(1): - assert serializer.data == expected + queryset = ForeignKeySource.objects.all() + serializer = ForeignKeySourceSerializer(queryset, many=True) + expected = [ {'id': 1, 'name': 'source-1', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': 1} ] + with self.assertNumQueries(1): + assert serializer.data == expected def test_reverse_foreign_key_retrieve(self): - queryset = ForeignKeyTarget.objects.all() - serializer = ForeignKeyTargetSerializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, - {'id': 2, 'name': 'target-2', 'sources': []}, - ] - with self.assertNumQueries(3): - assert serializer.data == expected + queryset = ForeignKeyTarget.objects.all() + serializer = ForeignKeyTargetSerializer(queryset, many=True) + expected = [ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, {'id': 2, 'name': 'target-2', 'sources': []}, ] + with self.assertNumQueries(3): + assert serializer.data == expected def test_reverse_foreign_key_retrieve_prefetch_related(self): - queryset = ForeignKeyTarget.objects.all().prefetch_related('sources') - serializer = ForeignKeyTargetSerializer(queryset, many=True) - with self.assertNumQueries(2): - serializer.data + queryset = ForeignKeyTarget.objects.all().prefetch_related('sources') + serializer = ForeignKeyTargetSerializer(queryset, many=True) + with self.assertNumQueries(2): + serializer.data def test_foreign_key_update(self): - data = {'id': 1, 'name': 'source-1', 'target': 2} - instance = ForeignKeySource.objects.get(pk=1) - serializer = ForeignKeySourceSerializer(instance, data=data) - assert serializer.is_valid() - serializer.save() - assert serializer.data == data - - # Ensure source 1 is updated, and everything else is as expected - queryset = ForeignKeySource.objects.all() - serializer = ForeignKeySourceSerializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'source-1', 'target': 2}, - {'id': 2, 'name': 'source-2', 'target': 1}, - {'id': 3, 'name': 'source-3', 'target': 1} - ] - assert serializer.data == expected + data = {'id': 1, 'name': 'source-1', 'target': 2} + instance = ForeignKeySource.objects.get(pk=1) + serializer = ForeignKeySourceSerializer(instance, data=data) + assert serializer.is_valid() + serializer.save() + assert serializer.data == data + queryset = ForeignKeySource.objects.all() + serializer = ForeignKeySourceSerializer(queryset, many=True) + expected = [ {'id': 1, 'name': 'source-1', 'target': 2}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': 1} ] + assert serializer.data == expected def test_foreign_key_update_incorrect_type(self): - data = {'id': 1, 'name': 'source-1', 'target': 'foo'} - instance = ForeignKeySource.objects.get(pk=1) - serializer = ForeignKeySourceSerializer(instance, data=data) - assert not serializer.is_valid() - assert serializer.errors == {'target': ['Incorrect type. Expected pk value, received str.']} + data = {'id': 1, 'name': 'source-1', 'target': 'foo'} + instance = ForeignKeySource.objects.get(pk=1) + serializer = ForeignKeySourceSerializer(instance, data=data) + assert not serializer.is_valid() + assert serializer.errors == {'target': ['Incorrect type. Expected pk value, received str.']} def test_reverse_foreign_key_update(self): - data = {'id': 2, 'name': 'target-2', 'sources': [1, 3]} - instance = ForeignKeyTarget.objects.get(pk=2) - serializer = ForeignKeyTargetSerializer(instance, data=data) - assert serializer.is_valid() - # We shouldn't have saved anything to the db yet since save - # hasn't been called. - queryset = ForeignKeyTarget.objects.all() - new_serializer = ForeignKeyTargetSerializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, - {'id': 2, 'name': 'target-2', 'sources': []}, - ] - assert new_serializer.data == expected - - serializer.save() - assert serializer.data == data - - # Ensure target 2 is update, and everything else is as expected - queryset = ForeignKeyTarget.objects.all() - serializer = ForeignKeyTargetSerializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'target-1', 'sources': [2]}, - {'id': 2, 'name': 'target-2', 'sources': [1, 3]}, - ] - assert serializer.data == expected + data = {'id': 2, 'name': 'target-2', 'sources': [1, 3]} + instance = ForeignKeyTarget.objects.get(pk=2) + serializer = ForeignKeyTargetSerializer(instance, data=data) + assert serializer.is_valid() + queryset = ForeignKeyTarget.objects.all() + new_serializer = ForeignKeyTargetSerializer(queryset, many=True) + expected = [ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, {'id': 2, 'name': 'target-2', 'sources': []}, ] + assert new_serializer.data == expected + serializer.save() + assert serializer.data == data + queryset = ForeignKeyTarget.objects.all() + serializer = ForeignKeyTargetSerializer(queryset, many=True) + expected = [ {'id': 1, 'name': 'target-1', 'sources': [2]}, {'id': 2, 'name': 'target-2', 'sources': [1, 3]}, ] + assert serializer.data == expected def test_foreign_key_create(self): - data = {'id': 4, 'name': 'source-4', 'target': 2} - serializer = ForeignKeySourceSerializer(data=data) - assert serializer.is_valid() - obj = serializer.save() - assert serializer.data == data - assert obj.name == 'source-4' - - # Ensure source 4 is added, and everything else is as expected - queryset = ForeignKeySource.objects.all() - serializer = ForeignKeySourceSerializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'source-1', 'target': 1}, - {'id': 2, 'name': 'source-2', 'target': 1}, - {'id': 3, 'name': 'source-3', 'target': 1}, - {'id': 4, 'name': 'source-4', 'target': 2}, - ] - assert serializer.data == expected + data = {'id': 4, 'name': 'source-4', 'target': 2} + serializer = ForeignKeySourceSerializer(data=data) + assert serializer.is_valid() + obj = serializer.save() + assert serializer.data == data + assert obj.name == 'source-4' + queryset = ForeignKeySource.objects.all() + serializer = ForeignKeySourceSerializer(queryset, many=True) + expected = [ {'id': 1, 'name': 'source-1', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': 1}, {'id': 4, 'name': 'source-4', 'target': 2}, ] + assert serializer.data == expected def test_reverse_foreign_key_create(self): - data = {'id': 3, 'name': 'target-3', 'sources': [1, 3]} - serializer = ForeignKeyTargetSerializer(data=data) - assert serializer.is_valid() - obj = serializer.save() - assert serializer.data == data - assert obj.name == 'target-3' - - # Ensure target 3 is added, and everything else is as expected - queryset = ForeignKeyTarget.objects.all() - serializer = ForeignKeyTargetSerializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'target-1', 'sources': [2]}, - {'id': 2, 'name': 'target-2', 'sources': []}, - {'id': 3, 'name': 'target-3', 'sources': [1, 3]}, - ] - assert serializer.data == expected + data = {'id': 3, 'name': 'target-3', 'sources': [1, 3]} + serializer = ForeignKeyTargetSerializer(data=data) + assert serializer.is_valid() + obj = serializer.save() + assert serializer.data == data + assert obj.name == 'target-3' + queryset = ForeignKeyTarget.objects.all() + serializer = ForeignKeyTargetSerializer(queryset, many=True) + expected = [ {'id': 1, 'name': 'target-1', 'sources': [2]}, {'id': 2, 'name': 'target-2', 'sources': []}, {'id': 3, 'name': 'target-3', 'sources': [1, 3]}, ] + assert serializer.data == expected def test_foreign_key_update_with_invalid_null(self): - data = {'id': 1, 'name': 'source-1', 'target': None} - instance = ForeignKeySource.objects.get(pk=1) - serializer = ForeignKeySourceSerializer(instance, data=data) - assert not serializer.is_valid() - assert serializer.errors == {'target': ['This field may not be null.']} + data = {'id': 1, 'name': 'source-1', 'target': None} + instance = ForeignKeySource.objects.get(pk=1) + serializer = ForeignKeySourceSerializer(instance, data=data) + assert not serializer.is_valid() + assert serializer.errors == {'target': ['This field may not be null.']} def test_foreign_key_with_unsaved(self): - source = ForeignKeySource(name='source-unsaved') - expected = {'id': None, 'name': 'source-unsaved', 'target': None} - - serializer = ForeignKeySourceSerializer(source) - - # no query if source hasn't been created yet - with self.assertNumQueries(0): - assert serializer.data == expected + source = ForeignKeySource(name='source-unsaved') + expected = {'id': None, 'name': 'source-unsaved', 'target': None} + serializer = ForeignKeySourceSerializer(source) + with self.assertNumQueries(0): + assert serializer.data == expected def test_foreign_key_with_empty(self): - """ + """ Regression test for #1072 https://github.com/encode/django-rest-framework/issues/1072 """ - serializer = NullableForeignKeySourceSerializer() - assert serializer.data['target'] is None + serializer = NullableForeignKeySourceSerializer() + assert serializer.data['target'] is None def test_foreign_key_not_required(self): - """ + """ Let's say we wanted to fill the non-nullable model field inside Model.save(), we would make it empty and not required. """ - class ModelSerializer(ForeignKeySourceSerializer): - class Meta(ForeignKeySourceSerializer.Meta): - extra_kwargs = {'target': {'required': False}} + class ModelSerializer(ForeignKeySourceSerializer): + class Meta(ForeignKeySourceSerializer.Meta): + extra_kwargs = {'target': {'required': False}} serializer = ModelSerializer(data={'name': 'test'}) - serializer.is_valid(raise_exception=True) - assert 'target' not in serializer.validated_data +serializer.is_valid(raise_exception=True) +assert'target'notinserializer.validated_data def test_queryset_size_without_limited_choices(self): - limited_target = ForeignKeyTarget(name="limited-target") - limited_target.save() - queryset = ForeignKeySourceSerializer().fields["target"].get_queryset() - assert len(queryset) == 3 + limited_target = ForeignKeyTarget(name="limited-target") + limited_target.save() + queryset = ForeignKeySourceSerializer().fields["target"].get_queryset() + assert len(queryset) == 3 def test_queryset_size_with_limited_choices(self): - limited_target = ForeignKeyTarget(name="limited-target") - limited_target.save() - queryset = ForeignKeySourceWithLimitedChoicesSerializer().fields["target"].get_queryset() - assert len(queryset) == 1 + limited_target = ForeignKeyTarget(name="limited-target") + limited_target.save() + queryset = ForeignKeySourceWithLimitedChoicesSerializer().fields["target"].get_queryset() + assert len(queryset) == 1 def test_queryset_size_with_Q_limited_choices(self): - limited_target = ForeignKeyTarget(name="limited-target") - limited_target.save() - - class QLimitedChoicesSerializer(serializers.ModelSerializer): - class Meta: - model = ForeignKeySourceWithQLimitedChoices - fields = ("id", "target") + limited_target = ForeignKeyTarget(name="limited-target") + limited_target.save() + class QLimitedChoicesSerializer(serializers.ModelSerializer): + class Meta: + model = ForeignKeySourceWithQLimitedChoices + fields = ("id", "target") queryset = QLimitedChoicesSerializer().fields["target"].get_queryset() - assert len(queryset) == 1 +assertlen(queryset)==1 -class PKNullableForeignKeyTests(TestCase): - def setUp(self): - target = ForeignKeyTarget(name='target-1') - target.save() - for idx in range(1, 4): - if idx == 3: - target = None +def setUp(self): + target = ForeignKeyTarget(name='target-1') + target.save() + for idx in range(1, 4): + if idx == 3: + target = None source = NullableForeignKeySource(name='source-%d' % idx, target=target) - source.save() +source.save() def test_foreign_key_retrieve_with_null(self): - queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'source-1', 'target': 1}, - {'id': 2, 'name': 'source-2', 'target': 1}, - {'id': 3, 'name': 'source-3', 'target': None}, - ] - assert serializer.data == expected + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset, many=True) + expected = [ {'id': 1, 'name': 'source-1', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': None}, ] + assert serializer.data == expected def test_foreign_key_create_with_valid_null(self): - data = {'id': 4, 'name': 'source-4', 'target': None} - serializer = NullableForeignKeySourceSerializer(data=data) - assert serializer.is_valid() - obj = serializer.save() - assert serializer.data == data - assert obj.name == 'source-4' - - # Ensure source 4 is created, and everything else is as expected - queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'source-1', 'target': 1}, - {'id': 2, 'name': 'source-2', 'target': 1}, - {'id': 3, 'name': 'source-3', 'target': None}, - {'id': 4, 'name': 'source-4', 'target': None} - ] - assert serializer.data == expected + data = {'id': 4, 'name': 'source-4', 'target': None} + serializer = NullableForeignKeySourceSerializer(data=data) + assert serializer.is_valid() + obj = serializer.save() + assert serializer.data == data + assert obj.name == 'source-4' + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset, many=True) + expected = [ {'id': 1, 'name': 'source-1', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': None}, {'id': 4, 'name': 'source-4', 'target': None} ] + assert serializer.data == expected def test_foreign_key_create_with_valid_emptystring(self): - """ + """ The emptystring should be interpreted as null in the context of relationships. """ - data = {'id': 4, 'name': 'source-4', 'target': ''} - expected_data = {'id': 4, 'name': 'source-4', 'target': None} - serializer = NullableForeignKeySourceSerializer(data=data) - assert serializer.is_valid() - obj = serializer.save() - assert serializer.data == expected_data - assert obj.name == 'source-4' - - # Ensure source 4 is created, and everything else is as expected - queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'source-1', 'target': 1}, - {'id': 2, 'name': 'source-2', 'target': 1}, - {'id': 3, 'name': 'source-3', 'target': None}, - {'id': 4, 'name': 'source-4', 'target': None} - ] - assert serializer.data == expected + data = {'id': 4, 'name': 'source-4', 'target': ''} + expected_data = {'id': 4, 'name': 'source-4', 'target': None} + serializer = NullableForeignKeySourceSerializer(data=data) + assert serializer.is_valid() + obj = serializer.save() + assert serializer.data == expected_data + assert obj.name == 'source-4' + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset, many=True) + expected = [ {'id': 1, 'name': 'source-1', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': None}, {'id': 4, 'name': 'source-4', 'target': None} ] + assert serializer.data == expected def test_foreign_key_update_with_valid_null(self): - data = {'id': 1, 'name': 'source-1', 'target': None} - instance = NullableForeignKeySource.objects.get(pk=1) - serializer = NullableForeignKeySourceSerializer(instance, data=data) - assert serializer.is_valid() - serializer.save() - assert serializer.data == data - - # Ensure source 1 is updated, and everything else is as expected - queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'source-1', 'target': None}, - {'id': 2, 'name': 'source-2', 'target': 1}, - {'id': 3, 'name': 'source-3', 'target': None} - ] - assert serializer.data == expected + data = {'id': 1, 'name': 'source-1', 'target': None} + instance = NullableForeignKeySource.objects.get(pk=1) + serializer = NullableForeignKeySourceSerializer(instance, data=data) + assert serializer.is_valid() + serializer.save() + assert serializer.data == data + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset, many=True) + expected = [ {'id': 1, 'name': 'source-1', 'target': None}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': None} ] + assert serializer.data == expected def test_foreign_key_update_with_valid_emptystring(self): - """ + """ The emptystring should be interpreted as null in the context of relationships. """ - data = {'id': 1, 'name': 'source-1', 'target': ''} - expected_data = {'id': 1, 'name': 'source-1', 'target': None} - instance = NullableForeignKeySource.objects.get(pk=1) - serializer = NullableForeignKeySourceSerializer(instance, data=data) - assert serializer.is_valid() - serializer.save() - assert serializer.data == expected_data - - # Ensure source 1 is updated, and everything else is as expected - queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'source-1', 'target': None}, - {'id': 2, 'name': 'source-2', 'target': 1}, - {'id': 3, 'name': 'source-3', 'target': None} - ] - assert serializer.data == expected + data = {'id': 1, 'name': 'source-1', 'target': ''} + expected_data = {'id': 1, 'name': 'source-1', 'target': None} + instance = NullableForeignKeySource.objects.get(pk=1) + serializer = NullableForeignKeySourceSerializer(instance, data=data) + assert serializer.is_valid() + serializer.save() + assert serializer.data == expected_data + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset, many=True) + expected = [ {'id': 1, 'name': 'source-1', 'target': None}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': None} ] + assert serializer.data == expected def test_null_uuid_foreign_key_serializes_as_none(self): - source = NullableUUIDForeignKeySource(name='Source') - serializer = NullableUUIDForeignKeySourceSerializer(source) - data = serializer.data - assert data["target"] is None + source = NullableUUIDForeignKeySource(name='Source') + serializer = NullableUUIDForeignKeySourceSerializer(source) + data = serializer.data + assert data["target"] is None def test_nullable_uuid_foreign_key_is_valid_when_none(self): - data = {"name": "Source", "target": None} - serializer = NullableUUIDForeignKeySourceSerializer(data=data) - assert serializer.is_valid(), serializer.errors + data = {"name": "Source", "target": None} + serializer = NullableUUIDForeignKeySourceSerializer(data=data) + assert serializer.is_valid(), serializer.errors -class PKNullableOneToOneTests(TestCase): - def setUp(self): - target = OneToOneTarget(name='target-1') - target.save() - new_target = OneToOneTarget(name='target-2') - new_target.save() - source = NullableOneToOneSource(name='source-1', target=new_target) - source.save() +def setUp(self): + target = OneToOneTarget(name='target-1') + target.save() + new_target = OneToOneTarget(name='target-2') + new_target.save() + source = NullableOneToOneSource(name='source-1', target=new_target) + source.save() def test_reverse_foreign_key_retrieve_with_null(self): - queryset = OneToOneTarget.objects.all() - serializer = NullableOneToOneTargetSerializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'target-1', 'nullable_source': None}, - {'id': 2, 'name': 'target-2', 'nullable_source': 1}, - ] - assert serializer.data == expected + queryset = OneToOneTarget.objects.all() + serializer = NullableOneToOneTargetSerializer(queryset, many=True) + expected = [ {'id': 1, 'name': 'target-1', 'nullable_source': None}, {'id': 2, 'name': 'target-2', 'nullable_source': 1}, ] + assert serializer.data == expected -class OneToOnePrimaryKeyTests(TestCase): - def setUp(self): +def setUp(self): # Given: Some target models already exist - self.target = target = OneToOneTarget(name='target-1') - target.save() - self.alt_target = alt_target = OneToOneTarget(name='target-2') - alt_target.save() + self.target = target = OneToOneTarget(name='target-1') + target.save() + self.alt_target = alt_target = OneToOneTarget(name='target-2') + alt_target.save() def test_one_to_one_when_primary_key(self): # When: Creating a Source pointing at the id of the second Target - target_pk = self.alt_target.id - source = OneToOnePKSourceSerializer(data={'name': 'source-2', 'target': target_pk}) - # Then: The source is valid with the serializer - if not source.is_valid(): - self.fail("Expected OneToOnePKTargetSerializer to be valid but had errors: {}".format(source.errors)) + target_pk = self.alt_target.id + source = OneToOnePKSourceSerializer(data={'name': 'source-2', 'target': target_pk}) + if not source.is_valid(): + self.fail("Expected OneToOnePKTargetSerializer to be valid but had errors: {}".format(source.errors)) # Then: Saving the serializer creates a new object new_source = source.save() - # Then: The new object has the same pk as the target object - self.assertEqual(new_source.pk, target_pk) +assertnew_source.pk==target_pk def test_one_to_one_when_primary_key_no_duplicates(self): # When: Creating a Source pointing at the id of the second Target - target_pk = self.target.id - data = {'name': 'source-1', 'target': target_pk} - source = OneToOnePKSourceSerializer(data=data) - # Then: The source is valid with the serializer - self.assertTrue(source.is_valid()) - # Then: Saving the serializer creates a new object - new_source = source.save() - # Then: The new object has the same pk as the target object - self.assertEqual(new_source.pk, target_pk) - # When: Trying to create a second object - second_source = OneToOnePKSourceSerializer(data=data) - self.assertFalse(second_source.is_valid()) - expected = {'target': ['one to one pk source with this target already exists.']} - self.assertDictEqual(second_source.errors, expected) + target_pk = self.target.id + data = {'name': 'source-1', 'target': target_pk} + source = OneToOnePKSourceSerializer(data=data) + assert source.is_valid() + new_source = source.save() + assert new_source.pk == target_pk + second_source = OneToOnePKSourceSerializer(data=data) + assert not second_source.is_valid() + expected = {'target': ['one to one pk source with this target already exists.']} + assert second_source.errors == expected def test_one_to_one_when_primary_key_does_not_exist(self): # Given: a target PK that does not exist - target_pk = self.target.pk + self.alt_target.pk - source = OneToOnePKSourceSerializer(data={'name': 'source-2', 'target': target_pk}) - # Then: The source is not valid with the serializer - self.assertFalse(source.is_valid()) - self.assertIn("Invalid pk", source.errors['target'][0]) - self.assertIn("object does not exist", source.errors['target'][0]) + target_pk = self.target.pk + self.alt_target.pk + source = OneToOnePKSourceSerializer(data={'name': 'source-2', 'target': target_pk}) + assert not source.is_valid() + assert "Invalid pk" in source.errors['target'][0] + assert "object does not exist" in source.errors['target'][0] diff --git a/tests/test_relations_slug.py b/tests/test_relations_slug.py index 0b9ca79d3..44e8d9e47 100644 --- a/tests/test_relations_slug.py +++ b/tests/test_relations_slug.py @@ -42,246 +42,177 @@ class NullableForeignKeySourceSerializer(serializers.ModelSerializer): # TODO: M2M Tests, FKTests (Non-nullable), One2One -class SlugForeignKeyTests(TestCase): - def setUp(self): - target = ForeignKeyTarget(name='target-1') - target.save() - new_target = ForeignKeyTarget(name='target-2') - new_target.save() - for idx in range(1, 4): - source = ForeignKeySource(name='source-%d' % idx, target=target) - source.save() +def setUp(self): + target = ForeignKeyTarget(name='target-1') + target.save() + new_target = ForeignKeyTarget(name='target-2') + new_target.save() + for idx in range(1, 4): + source = ForeignKeySource(name='source-%d' % idx, target=target) + source.save() def test_foreign_key_retrieve(self): - queryset = ForeignKeySource.objects.all() - serializer = ForeignKeySourceSerializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'source-1', 'target': 'target-1'}, - {'id': 2, 'name': 'source-2', 'target': 'target-1'}, - {'id': 3, 'name': 'source-3', 'target': 'target-1'} - ] - with self.assertNumQueries(4): - assert serializer.data == expected + queryset = ForeignKeySource.objects.all() + serializer = ForeignKeySourceSerializer(queryset, many=True) + expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-1'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': 'target-1'} ] + with self.assertNumQueries(4): + assert serializer.data == expected def test_foreign_key_retrieve_select_related(self): - queryset = ForeignKeySource.objects.all().select_related('target') - serializer = ForeignKeySourceSerializer(queryset, many=True) - with self.assertNumQueries(1): - serializer.data + queryset = ForeignKeySource.objects.all().select_related('target') + serializer = ForeignKeySourceSerializer(queryset, many=True) + with self.assertNumQueries(1): + serializer.data def test_reverse_foreign_key_retrieve(self): - queryset = ForeignKeyTarget.objects.all() - serializer = ForeignKeyTargetSerializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, - {'id': 2, 'name': 'target-2', 'sources': []}, - ] - assert serializer.data == expected + queryset = ForeignKeyTarget.objects.all() + serializer = ForeignKeyTargetSerializer(queryset, many=True) + expected = [ {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, {'id': 2, 'name': 'target-2', 'sources': []}, ] + assert serializer.data == expected def test_reverse_foreign_key_retrieve_prefetch_related(self): - queryset = ForeignKeyTarget.objects.all().prefetch_related('sources') - serializer = ForeignKeyTargetSerializer(queryset, many=True) - with self.assertNumQueries(2): - serializer.data + queryset = ForeignKeyTarget.objects.all().prefetch_related('sources') + serializer = ForeignKeyTargetSerializer(queryset, many=True) + with self.assertNumQueries(2): + serializer.data def test_foreign_key_update(self): - data = {'id': 1, 'name': 'source-1', 'target': 'target-2'} - instance = ForeignKeySource.objects.get(pk=1) - serializer = ForeignKeySourceSerializer(instance, data=data) - assert serializer.is_valid() - serializer.save() - assert serializer.data == data - - # Ensure source 1 is updated, and everything else is as expected - queryset = ForeignKeySource.objects.all() - serializer = ForeignKeySourceSerializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'source-1', 'target': 'target-2'}, - {'id': 2, 'name': 'source-2', 'target': 'target-1'}, - {'id': 3, 'name': 'source-3', 'target': 'target-1'} - ] - assert serializer.data == expected + data = {'id': 1, 'name': 'source-1', 'target': 'target-2'} + instance = ForeignKeySource.objects.get(pk=1) + serializer = ForeignKeySourceSerializer(instance, data=data) + assert serializer.is_valid() + serializer.save() + assert serializer.data == data + queryset = ForeignKeySource.objects.all() + serializer = ForeignKeySourceSerializer(queryset, many=True) + expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-2'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': 'target-1'} ] + assert serializer.data == expected def test_foreign_key_update_incorrect_type(self): - data = {'id': 1, 'name': 'source-1', 'target': 123} - instance = ForeignKeySource.objects.get(pk=1) - serializer = ForeignKeySourceSerializer(instance, data=data) - assert not serializer.is_valid() - assert serializer.errors == {'target': ['Object with name=123 does not exist.']} + data = {'id': 1, 'name': 'source-1', 'target': 123} + instance = ForeignKeySource.objects.get(pk=1) + serializer = ForeignKeySourceSerializer(instance, data=data) + assert not serializer.is_valid() + assert serializer.errors == {'target': ['Object with name=123 does not exist.']} def test_reverse_foreign_key_update(self): - data = {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']} - instance = ForeignKeyTarget.objects.get(pk=2) - serializer = ForeignKeyTargetSerializer(instance, data=data) - assert serializer.is_valid() - # We shouldn't have saved anything to the db yet since save - # hasn't been called. - queryset = ForeignKeyTarget.objects.all() - new_serializer = ForeignKeyTargetSerializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, - {'id': 2, 'name': 'target-2', 'sources': []}, - ] - assert new_serializer.data == expected - - serializer.save() - assert serializer.data == data - - # Ensure target 2 is update, and everything else is as expected - queryset = ForeignKeyTarget.objects.all() - serializer = ForeignKeyTargetSerializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'target-1', 'sources': ['source-2']}, - {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']}, - ] - assert serializer.data == expected + data = {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']} + instance = ForeignKeyTarget.objects.get(pk=2) + serializer = ForeignKeyTargetSerializer(instance, data=data) + assert serializer.is_valid() + queryset = ForeignKeyTarget.objects.all() + new_serializer = ForeignKeyTargetSerializer(queryset, many=True) + expected = [ {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, {'id': 2, 'name': 'target-2', 'sources': []}, ] + assert new_serializer.data == expected + serializer.save() + assert serializer.data == data + queryset = ForeignKeyTarget.objects.all() + serializer = ForeignKeyTargetSerializer(queryset, many=True) + expected = [ {'id': 1, 'name': 'target-1', 'sources': ['source-2']}, {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']}, ] + assert serializer.data == expected def test_foreign_key_create(self): - data = {'id': 4, 'name': 'source-4', 'target': 'target-2'} - serializer = ForeignKeySourceSerializer(data=data) - serializer.is_valid() - assert serializer.is_valid() - obj = serializer.save() - assert serializer.data == data - assert obj.name == 'source-4' - - # Ensure source 4 is added, and everything else is as expected - queryset = ForeignKeySource.objects.all() - serializer = ForeignKeySourceSerializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'source-1', 'target': 'target-1'}, - {'id': 2, 'name': 'source-2', 'target': 'target-1'}, - {'id': 3, 'name': 'source-3', 'target': 'target-1'}, - {'id': 4, 'name': 'source-4', 'target': 'target-2'}, - ] - assert serializer.data == expected + data = {'id': 4, 'name': 'source-4', 'target': 'target-2'} + serializer = ForeignKeySourceSerializer(data=data) + serializer.is_valid() + assert serializer.is_valid() + obj = serializer.save() + assert serializer.data == data + assert obj.name == 'source-4' + queryset = ForeignKeySource.objects.all() + serializer = ForeignKeySourceSerializer(queryset, many=True) + expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-1'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': 'target-1'}, {'id': 4, 'name': 'source-4', 'target': 'target-2'}, ] + assert serializer.data == expected def test_reverse_foreign_key_create(self): - data = {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']} - serializer = ForeignKeyTargetSerializer(data=data) - assert serializer.is_valid() - obj = serializer.save() - assert serializer.data == data - assert obj.name == 'target-3' - - # Ensure target 3 is added, and everything else is as expected - queryset = ForeignKeyTarget.objects.all() - serializer = ForeignKeyTargetSerializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'target-1', 'sources': ['source-2']}, - {'id': 2, 'name': 'target-2', 'sources': []}, - {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']}, - ] - assert serializer.data == expected + data = {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']} + serializer = ForeignKeyTargetSerializer(data=data) + assert serializer.is_valid() + obj = serializer.save() + assert serializer.data == data + assert obj.name == 'target-3' + queryset = ForeignKeyTarget.objects.all() + serializer = ForeignKeyTargetSerializer(queryset, many=True) + expected = [ {'id': 1, 'name': 'target-1', 'sources': ['source-2']}, {'id': 2, 'name': 'target-2', 'sources': []}, {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']}, ] + assert serializer.data == expected def test_foreign_key_update_with_invalid_null(self): - data = {'id': 1, 'name': 'source-1', 'target': None} - instance = ForeignKeySource.objects.get(pk=1) - serializer = ForeignKeySourceSerializer(instance, data=data) - assert not serializer.is_valid() - assert serializer.errors == {'target': ['This field may not be null.']} + data = {'id': 1, 'name': 'source-1', 'target': None} + instance = ForeignKeySource.objects.get(pk=1) + serializer = ForeignKeySourceSerializer(instance, data=data) + assert not serializer.is_valid() + assert serializer.errors == {'target': ['This field may not be null.']} -class SlugNullableForeignKeyTests(TestCase): - def setUp(self): - target = ForeignKeyTarget(name='target-1') - target.save() - for idx in range(1, 4): - if idx == 3: - target = None +def setUp(self): + target = ForeignKeyTarget(name='target-1') + target.save() + for idx in range(1, 4): + if idx == 3: + target = None source = NullableForeignKeySource(name='source-%d' % idx, target=target) - source.save() +source.save() def test_foreign_key_retrieve_with_null(self): - queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'source-1', 'target': 'target-1'}, - {'id': 2, 'name': 'source-2', 'target': 'target-1'}, - {'id': 3, 'name': 'source-3', 'target': None}, - ] - assert serializer.data == expected + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset, many=True) + expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-1'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': None}, ] + assert serializer.data == expected def test_foreign_key_create_with_valid_null(self): - data = {'id': 4, 'name': 'source-4', 'target': None} - serializer = NullableForeignKeySourceSerializer(data=data) - assert serializer.is_valid() - obj = serializer.save() - assert serializer.data == data - assert obj.name == 'source-4' - - # Ensure source 4 is created, and everything else is as expected - queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'source-1', 'target': 'target-1'}, - {'id': 2, 'name': 'source-2', 'target': 'target-1'}, - {'id': 3, 'name': 'source-3', 'target': None}, - {'id': 4, 'name': 'source-4', 'target': None} - ] - assert serializer.data == expected + data = {'id': 4, 'name': 'source-4', 'target': None} + serializer = NullableForeignKeySourceSerializer(data=data) + assert serializer.is_valid() + obj = serializer.save() + assert serializer.data == data + assert obj.name == 'source-4' + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset, many=True) + expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-1'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': None}, {'id': 4, 'name': 'source-4', 'target': None} ] + assert serializer.data == expected def test_foreign_key_create_with_valid_emptystring(self): - """ + """ The emptystring should be interpreted as null in the context of relationships. """ - data = {'id': 4, 'name': 'source-4', 'target': ''} - expected_data = {'id': 4, 'name': 'source-4', 'target': None} - serializer = NullableForeignKeySourceSerializer(data=data) - assert serializer.is_valid() - obj = serializer.save() - assert serializer.data == expected_data - assert obj.name == 'source-4' - - # Ensure source 4 is created, and everything else is as expected - queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'source-1', 'target': 'target-1'}, - {'id': 2, 'name': 'source-2', 'target': 'target-1'}, - {'id': 3, 'name': 'source-3', 'target': None}, - {'id': 4, 'name': 'source-4', 'target': None} - ] - assert serializer.data == expected + data = {'id': 4, 'name': 'source-4', 'target': ''} + expected_data = {'id': 4, 'name': 'source-4', 'target': None} + serializer = NullableForeignKeySourceSerializer(data=data) + assert serializer.is_valid() + obj = serializer.save() + assert serializer.data == expected_data + assert obj.name == 'source-4' + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset, many=True) + expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-1'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': None}, {'id': 4, 'name': 'source-4', 'target': None} ] + assert serializer.data == expected def test_foreign_key_update_with_valid_null(self): - data = {'id': 1, 'name': 'source-1', 'target': None} - instance = NullableForeignKeySource.objects.get(pk=1) - serializer = NullableForeignKeySourceSerializer(instance, data=data) - assert serializer.is_valid() - serializer.save() - assert serializer.data == data - - # Ensure source 1 is updated, and everything else is as expected - queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'source-1', 'target': None}, - {'id': 2, 'name': 'source-2', 'target': 'target-1'}, - {'id': 3, 'name': 'source-3', 'target': None} - ] - assert serializer.data == expected + data = {'id': 1, 'name': 'source-1', 'target': None} + instance = NullableForeignKeySource.objects.get(pk=1) + serializer = NullableForeignKeySourceSerializer(instance, data=data) + assert serializer.is_valid() + serializer.save() + assert serializer.data == data + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset, many=True) + expected = [ {'id': 1, 'name': 'source-1', 'target': None}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': None} ] + assert serializer.data == expected def test_foreign_key_update_with_valid_emptystring(self): - """ + """ The emptystring should be interpreted as null in the context of relationships. """ - data = {'id': 1, 'name': 'source-1', 'target': ''} - expected_data = {'id': 1, 'name': 'source-1', 'target': None} - instance = NullableForeignKeySource.objects.get(pk=1) - serializer = NullableForeignKeySourceSerializer(instance, data=data) - assert serializer.is_valid() - serializer.save() - assert serializer.data == expected_data - - # Ensure source 1 is updated, and everything else is as expected - queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'source-1', 'target': None}, - {'id': 2, 'name': 'source-2', 'target': 'target-1'}, - {'id': 3, 'name': 'source-3', 'target': None} - ] - assert serializer.data == expected + data = {'id': 1, 'name': 'source-1', 'target': ''} + expected_data = {'id': 1, 'name': 'source-1', 'target': None} + instance = NullableForeignKeySource.objects.get(pk=1) + serializer = NullableForeignKeySourceSerializer(instance, data=data) + assert serializer.is_valid() + serializer.save() + assert serializer.data == expected_data + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset, many=True) + expected = [ {'id': 1, 'name': 'source-1', 'target': None}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': None} ] + assert serializer.data == expected diff --git a/tests/test_renderers.py b/tests/test_renderers.py index d63dbcb9c..d2c8dbef9 100644 --- a/tests/test_renderers.py +++ b/tests/test_renderers.py @@ -49,11 +49,10 @@ class DummyTestModel(models.Model): name = models.CharField(max_length=42, default='') -class BasicRendererTests(TestCase): - def test_expected_results(self): - for value, renderer_cls, expected in expected_results: - output = renderer_cls().render(value) - self.assertEqual(output, expected) +def test_expected_results(self): + for value, renderer_cls, expected in expected_results: + output = renderer_cls().render(value) + assert output == expected class RendererA(BaseRenderer): @@ -144,123 +143,114 @@ class POSTDeniedView(APIView): return Response() -class DocumentingRendererTests(TestCase): - def test_only_permitted_forms_are_displayed(self): - view = POSTDeniedView.as_view() - request = APIRequestFactory().get('/') - response = view(request).render() - self.assertNotContains(response, '>POST<') - self.assertContains(response, '>PUT<') - self.assertContains(response, '>PATCH<') +def test_only_permitted_forms_are_displayed(self): + view = POSTDeniedView.as_view() + request = APIRequestFactory().get('/') + response = view(request).render() + self.assertNotContains(response, '>POST<') + self.assertContains(response, '>PUT<') + self.assertContains(response, '>PATCH<') @override_settings(ROOT_URLCONF='tests.test_renderers') -class RendererEndToEndTests(TestCase): - """ +""" End-to-end testing of renderers using an RendererMixin on a generic view. """ - def test_default_renderer_serializes_content(self): - """If the Accept header is not set the default renderer should serialize the response.""" - resp = self.client.get('/') - self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') - self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) - self.assertEqual(resp.status_code, DUMMYSTATUS) +deftest_default_renderer_serializes_content(self): + """If the Accept header is not set the default renderer should serialize the response.""" + resp = self.client.get('/') + assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8' + assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT) + assert resp.status_code == DUMMYSTATUS def test_head_method_serializes_no_content(self): - """No response must be included in HEAD requests.""" - resp = self.client.head('/') - self.assertEqual(resp.status_code, DUMMYSTATUS) - self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') - self.assertEqual(resp.content, b'') + """No response must be included in HEAD requests.""" + resp = self.client.head('/') + assert resp.status_code == DUMMYSTATUS + assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8' + assert resp.content == b'' def test_default_renderer_serializes_content_on_accept_any(self): - """If the Accept header is set to */* the default renderer should serialize the response.""" - resp = self.client.get('/', HTTP_ACCEPT='*/*') - self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') - self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) - self.assertEqual(resp.status_code, DUMMYSTATUS) + """If the Accept header is set to */* the default renderer should serialize the response.""" + resp = self.client.get('/', HTTP_ACCEPT='*/*') + assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8' + assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT) + assert resp.status_code == DUMMYSTATUS def test_specified_renderer_serializes_content_default_case(self): - """If the Accept header is set the specified renderer should serialize the response. + """If the Accept header is set the specified renderer should serialize the response. (In this case we check that works for the default renderer)""" - resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type) - self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') - self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) - self.assertEqual(resp.status_code, DUMMYSTATUS) + resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type) + assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8' + assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT) + assert resp.status_code == DUMMYSTATUS def test_specified_renderer_serializes_content_non_default_case(self): - """If the Accept header is set the specified renderer should serialize the response. + """If the Accept header is set the specified renderer should serialize the response. (In this case we check that works for a non-default renderer)""" - resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type) - self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') - self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) - self.assertEqual(resp.status_code, DUMMYSTATUS) + resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type) + assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8' + assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT) + assert resp.status_code == DUMMYSTATUS def test_unsatisfiable_accept_header_on_request_returns_406_status(self): - """If the Accept header is unsatisfiable we should return a 406 Not Acceptable response.""" - resp = self.client.get('/', HTTP_ACCEPT='foo/bar') - self.assertEqual(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE) + """If the Accept header is unsatisfiable we should return a 406 Not Acceptable response.""" + resp = self.client.get('/', HTTP_ACCEPT='foo/bar') + assert resp.status_code == status.HTTP_406_NOT_ACCEPTABLE def test_specified_renderer_serializes_content_on_format_query(self): - """If a 'format' query is specified, the renderer with the matching + """If a 'format' query is specified, the renderer with the matching format attribute should serialize the response.""" - param = '?%s=%s' % ( - api_settings.URL_FORMAT_OVERRIDE, - RendererB.format - ) - resp = self.client.get('/' + param) - self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') - self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) - self.assertEqual(resp.status_code, DUMMYSTATUS) + param = '?%s=%s' % ( api_settings.URL_FORMAT_OVERRIDE, RendererB.format ) + resp = self.client.get('/' + param) + assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8' + assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT) + assert resp.status_code == DUMMYSTATUS def test_specified_renderer_serializes_content_on_format_kwargs(self): - """If a 'format' keyword arg is specified, the renderer with the matching + """If a 'format' keyword arg is specified, the renderer with the matching format attribute should serialize the response.""" - resp = self.client.get('/something.formatb') - self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') - self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) - self.assertEqual(resp.status_code, DUMMYSTATUS) + resp = self.client.get('/something.formatb') + assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8' + assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT) + assert resp.status_code == DUMMYSTATUS def test_specified_renderer_is_used_on_format_query_with_matching_accept(self): - """If both a 'format' query and a matching Accept header specified, + """If both a 'format' query and a matching Accept header specified, the renderer with the matching format attribute should serialize the response.""" - param = '?%s=%s' % ( - api_settings.URL_FORMAT_OVERRIDE, - RendererB.format - ) - resp = self.client.get('/' + param, - HTTP_ACCEPT=RendererB.media_type) - self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') - self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) - self.assertEqual(resp.status_code, DUMMYSTATUS) + param = '?%s=%s' % ( api_settings.URL_FORMAT_OVERRIDE, RendererB.format ) + resp = self.client.get('/' + param, HTTP_ACCEPT=RendererB.media_type) + assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8' + assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT) + assert resp.status_code == DUMMYSTATUS def test_parse_error_renderers_browsable_api(self): - """Invalid data should still render the browsable API correctly.""" - resp = self.client.post('/parseerror', data='foobar', content_type='application/json', HTTP_ACCEPT='text/html') - self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') - self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST) + """Invalid data should still render the browsable API correctly.""" + resp = self.client.post('/parseerror', data='foobar', content_type='application/json', HTTP_ACCEPT='text/html') + assert resp['Content-Type'] == 'text/html; charset=utf-8' + assert resp.status_code == status.HTTP_400_BAD_REQUEST def test_204_no_content_responses_have_no_content_type_set(self): - """ + """ Regression test for #1196 https://github.com/encode/django-rest-framework/issues/1196 """ - resp = self.client.get('/empty') - self.assertEqual(resp.get('Content-Type', None), None) - self.assertEqual(resp.status_code, status.HTTP_204_NO_CONTENT) + resp = self.client.get('/empty') + assert resp.get('Content-Type', None) == None + assert resp.status_code == status.HTTP_204_NO_CONTENT def test_contains_headers_of_api_response(self): - """ + """ Issue #1437 Test we display the headers of the API response and not those from the HTML response """ - resp = self.client.get('/html1') - self.assertContains(resp, '>GET, HEAD, OPTIONS<') - self.assertContains(resp, '>application/json<') - self.assertNotContains(resp, '>text/html; charset=utf-8<') + resp = self.client.get('/html1') + self.assertContains(resp, '>GET, HEAD, OPTIONS<') + self.assertContains(resp, '>application/json<') + self.assertNotContains(resp, '>text/html; charset=utf-8<') _flat_repr = '{"foo":["bar","baz"]}' @@ -275,183 +265,174 @@ def strip_trailing_whitespace(content): return re.sub(' +\n', '\n', content) -class BaseRendererTests(TestCase): - """ +""" Tests BaseRenderer """ - def test_render_raise_error(self): - """ +deftest_render_raise_error(self): + """ BaseRenderer.render should raise NotImplementedError """ - with pytest.raises(NotImplementedError): - BaseRenderer().render('test') + with pytest.raises(NotImplementedError): + BaseRenderer().render('test') -class JSONRendererTests(TestCase): - """ +""" Tests specific to the JSON Renderer """ - - def test_render_lazy_strings(self): - """ +deftest_render_lazy_strings(self): + """ JSONRenderer should deal with lazy translated strings. """ - ret = JSONRenderer().render(_('test')) - self.assertEqual(ret, b'"test"') + ret = JSONRenderer().render(_('test')) + assert ret == b'"test"' def test_render_queryset_values(self): - o = DummyTestModel.objects.create(name='dummy') - qs = DummyTestModel.objects.values('id', 'name') - ret = JSONRenderer().render(qs) - data = json.loads(ret.decode()) - self.assertEqual(data, [{'id': o.id, 'name': o.name}]) + o = DummyTestModel.objects.create(name='dummy') + qs = DummyTestModel.objects.values('id', 'name') + ret = JSONRenderer().render(qs) + data = json.loads(ret.decode()) + assert data == [{'id': o.id, 'name': o.name}] def test_render_queryset_values_list(self): - o = DummyTestModel.objects.create(name='dummy') - qs = DummyTestModel.objects.values_list('id', 'name') - ret = JSONRenderer().render(qs) - data = json.loads(ret.decode()) - self.assertEqual(data, [[o.id, o.name]]) + o = DummyTestModel.objects.create(name='dummy') + qs = DummyTestModel.objects.values_list('id', 'name') + ret = JSONRenderer().render(qs) + data = json.loads(ret.decode()) + assert data == [[o.id, o.name]] def test_render_dict_abc_obj(self): - class Dict(MutableMapping): - def __init__(self): - self._dict = {} + class Dict(MutableMapping): + def __init__(self): + self._dict = {} def __getitem__(self, key): - return self._dict.__getitem__(key) + return self._dict.__getitem__(key) def __setitem__(self, key, value): - return self._dict.__setitem__(key, value) + return self._dict.__setitem__(key, value) def __delitem__(self, key): - return self._dict.__delitem__(key) + return self._dict.__delitem__(key) def __iter__(self): - return self._dict.__iter__() + return self._dict.__iter__() def __len__(self): - return self._dict.__len__() + return self._dict.__len__() def keys(self): - return self._dict.keys() + return self._dict.keys() x = Dict() - x['key'] = 'string value' - x[2] = 3 - ret = JSONRenderer().render(x) - data = json.loads(ret.decode()) - self.assertEqual(data, {'key': 'string value', '2': 3}) +x['key']='string value' +x[2]=3 +ret=JSONRenderer().render(x) +data=json.loads(ret.decode()) +assertdata=={'key':'string value','2':3} def test_render_obj_with_getitem(self): - class DictLike: - def __init__(self): - self._dict = {} + class DictLike: + def __init__(self): + self._dict = {} def set(self, value): - self._dict = dict(value) + self._dict = dict(value) def __getitem__(self, key): - return self._dict[key] + return self._dict[key] x = DictLike() - x.set({'a': 1, 'b': 'string'}) - with self.assertRaises(TypeError): - JSONRenderer().render(x) +x.set({'a':1,'b':'string'}) +withpytest.raises(TypeError): + JSONRenderer().render(x) def test_float_strictness(self): - renderer = JSONRenderer() - - # Default to strict - for value in [float('inf'), float('-inf'), float('nan')]: - with pytest.raises(ValueError): - renderer.render(value) + renderer = JSONRenderer() + for value in [float('inf'), float('-inf'), float('nan')]: + with pytest.raises(ValueError): + renderer.render(value) renderer.strict = False - assert renderer.render(float('inf')) == b'Infinity' - assert renderer.render(float('-inf')) == b'-Infinity' - assert renderer.render(float('nan')) == b'NaN' +assertrenderer.render(float('inf'))==b'Infinity' +assertrenderer.render(float('-inf'))==b'-Infinity' +assertrenderer.render(float('nan'))==b'NaN' def test_without_content_type_args(self): - """ + """ Test basic JSON rendering. """ - obj = {'foo': ['bar', 'baz']} - renderer = JSONRenderer() - content = renderer.render(obj, 'application/json') - # Fix failing test case which depends on version of JSON library. - self.assertEqual(content.decode(), _flat_repr) + obj = {'foo': ['bar', 'baz']} + renderer = JSONRenderer() + content = renderer.render(obj, 'application/json') + assert content.decode() == _flat_repr def test_with_content_type_args(self): - """ + """ Test JSON rendering with additional content type arguments supplied. """ - obj = {'foo': ['bar', 'baz']} - renderer = JSONRenderer() - content = renderer.render(obj, 'application/json; indent=2') - self.assertEqual(strip_trailing_whitespace(content.decode()), _indented_repr) + obj = {'foo': ['bar', 'baz']} + renderer = JSONRenderer() + content = renderer.render(obj, 'application/json; indent=2') + assert strip_trailing_whitespace(content.decode()) == _indented_repr -class UnicodeJSONRendererTests(TestCase): - """ +""" Tests specific for the Unicode JSON Renderer """ - def test_proper_encoding(self): - obj = {'countries': ['United Kingdom', 'France', 'España']} - renderer = JSONRenderer() - content = renderer.render(obj, 'application/json') - self.assertEqual(content, '{"countries":["United Kingdom","France","España"]}'.encode()) +deftest_proper_encoding(self): + obj = {'countries': ['United Kingdom', 'France', 'España']} + renderer = JSONRenderer() + content = renderer.render(obj, 'application/json') + assert content == '{"countries":["United Kingdom","France","España"]}'.encode() def test_u2028_u2029(self): # The \u2028 and \u2029 characters should be escaped, # even when the non-escaping unicode representation is used. # Regression test for #2169 - obj = {'should_escape': '\u2028\u2029'} - renderer = JSONRenderer() - content = renderer.render(obj, 'application/json') - self.assertEqual(content, '{"should_escape":"\\u2028\\u2029"}'.encode()) + obj = {'should_escape': '\u2028\u2029'} + renderer = JSONRenderer() + content = renderer.render(obj, 'application/json') + assert content == '{"should_escape":"\\u2028\\u2029"}'.encode() -class AsciiJSONRendererTests(TestCase): - """ +""" Tests specific for the Unicode JSON Renderer """ - def test_proper_encoding(self): - class AsciiJSONRenderer(JSONRenderer): - ensure_ascii = True +deftest_proper_encoding(self): + class AsciiJSONRenderer(JSONRenderer): + ensure_ascii = True obj = {'countries': ['United Kingdom', 'France', 'España']} - renderer = AsciiJSONRenderer() - content = renderer.render(obj, 'application/json') - self.assertEqual(content, '{"countries":["United Kingdom","France","Espa\\u00f1a"]}'.encode()) +renderer=AsciiJSONRenderer() +content=renderer.render(obj,'application/json') +assertcontent=='{"countries":["United Kingdom","France","Espa\\u00f1a"]}'.encode() # Tests for caching issue, #346 @override_settings(ROOT_URLCONF='tests.test_renderers') -class CacheRenderTest(TestCase): - """ +""" Tests specific to caching responses """ - def test_head_caching(self): - """ +deftest_head_caching(self): + """ Test caching of HEAD requests """ - response = self.client.head('/cache') - cache.set('key', response) - cached_response = cache.get('key') - assert isinstance(cached_response, Response) - assert cached_response.content == response.content - assert cached_response.status_code == response.status_code + response = self.client.head('/cache') + cache.set('key', response) + cached_response = cache.get('key') + assert isinstance(cached_response, Response) + assert cached_response.content == response.content + assert cached_response.status_code == response.status_code def test_get_caching(self): - """ + """ Test caching of GET requests """ - response = self.client.get('/cache') - cache.set('key', response) - cached_response = cache.get('key') - assert isinstance(cached_response, Response) - assert cached_response.content == response.content - assert cached_response.status_code == response.status_code + response = self.client.get('/cache') + cache.set('key', response) + cached_response = cache.get('key') + assert isinstance(cached_response, Response) + assert cached_response.content == response.content + assert cached_response.status_code == response.status_code class TestJSONIndentationStyles: @@ -476,150 +457,116 @@ class TestJSONIndentationStyles: assert renderer.render(data) == b'{"a": 1, "b": 2}' -class TestHiddenFieldHTMLFormRenderer(TestCase): - def test_hidden_field_rendering(self): - class TestSerializer(serializers.Serializer): - published = serializers.HiddenField(default=True) +def test_hidden_field_rendering(self): + class TestSerializer(serializers.Serializer): + published = serializers.HiddenField(default=True) serializer = TestSerializer(data={}) - serializer.is_valid() - renderer = HTMLFormRenderer() - field = serializer['published'] - rendered = renderer.render_field(field, {}) - assert rendered == '' +serializer.is_valid() +renderer=HTMLFormRenderer() +field=serializer['published'] +rendered=renderer.render_field(field,{}) +assertrendered=='' -class TestHTMLFormRenderer(TestCase): - def setUp(self): - class TestSerializer(serializers.Serializer): - test_field = serializers.CharField() +def setUp(self): + class TestSerializer(serializers.Serializer): + test_field = serializers.CharField() self.renderer = HTMLFormRenderer() - self.serializer = TestSerializer(data={}) +self.serializer=TestSerializer(data={}) def test_render_with_default_args(self): - self.serializer.is_valid() - renderer = HTMLFormRenderer() - - result = renderer.render(self.serializer.data) - - self.assertIsInstance(result, SafeText) + self.serializer.is_valid() + renderer = HTMLFormRenderer() + result = renderer.render(self.serializer.data) + assert isinstance(result, SafeText) def test_render_with_provided_args(self): - self.serializer.is_valid() - renderer = HTMLFormRenderer() - - result = renderer.render(self.serializer.data, None, {}) - - self.assertIsInstance(result, SafeText) + self.serializer.is_valid() + renderer = HTMLFormRenderer() + result = renderer.render(self.serializer.data, None, {}) + assert isinstance(result, SafeText) -class TestChoiceFieldHTMLFormRenderer(TestCase): - """ +""" Test rendering ChoiceField with HTMLFormRenderer. """ - - def setUp(self): - choices = ((1, 'Option1'), (2, 'Option2'), (12, 'Option12')) - - class TestSerializer(serializers.Serializer): - test_field = serializers.ChoiceField(choices=choices, - initial=2) +defsetUp(self): + choices = ((1, 'Option1'), (2, 'Option2'), (12, 'Option12')) + class TestSerializer(serializers.Serializer): + test_field = serializers.ChoiceField(choices=choices, initial=2) self.TestSerializer = TestSerializer - self.renderer = HTMLFormRenderer() +self.renderer=HTMLFormRenderer() def test_render_initial_option(self): - serializer = self.TestSerializer() - result = self.renderer.render(serializer.data) - - self.assertIsInstance(result, SafeText) - - self.assertInHTML('', - result) - self.assertInHTML('', result) - self.assertInHTML('', result) + serializer = self.TestSerializer() + result = self.renderer.render(serializer.data) + assert isinstance(result, SafeText) + self.assertInHTML('', result) + self.assertInHTML('', result) + self.assertInHTML('', result) def test_render_selected_option(self): - serializer = self.TestSerializer(data={'test_field': '12'}) - - serializer.is_valid() - result = self.renderer.render(serializer.data) - - self.assertIsInstance(result, SafeText) - - self.assertInHTML('', - result) - self.assertInHTML('', result) - self.assertInHTML('', result) + serializer = self.TestSerializer(data={'test_field': '12'}) + serializer.is_valid() + result = self.renderer.render(serializer.data) + assert isinstance(result, SafeText) + self.assertInHTML('', result) + self.assertInHTML('', result) + self.assertInHTML('', result) -class TestMultipleChoiceFieldHTMLFormRenderer(TestCase): - """ +""" Test rendering MultipleChoiceField with HTMLFormRenderer. """ - - def setUp(self): - self.renderer = HTMLFormRenderer() +defsetUp(self): + self.renderer = HTMLFormRenderer() def test_render_selected_option_with_string_option_ids(self): - choices = (('1', 'Option1'), ('2', 'Option2'), ('12', 'Option12'), - ('}', 'OptionBrace')) - - class TestSerializer(serializers.Serializer): - test_field = serializers.MultipleChoiceField(choices=choices) + choices = (('1', 'Option1'), ('2', 'Option2'), ('12', 'Option12'), ('}', 'OptionBrace')) + class TestSerializer(serializers.Serializer): + test_field = serializers.MultipleChoiceField(choices=choices) serializer = TestSerializer(data={'test_field': ['12']}) - serializer.is_valid() - - result = self.renderer.render(serializer.data) - - self.assertIsInstance(result, SafeText) - - self.assertInHTML('', - result) - self.assertInHTML('', result) - self.assertInHTML('', result) - self.assertInHTML('', result) +serializer.is_valid() +result=self.renderer.render(serializer.data) +assertisinstance(result, SafeText) +self.assertInHTML('',result) +self.assertInHTML('',result) +self.assertInHTML('',result) +self.assertInHTML('',result) def test_render_selected_option_with_integer_option_ids(self): - choices = ((1, 'Option1'), (2, 'Option2'), (12, 'Option12')) - - class TestSerializer(serializers.Serializer): - test_field = serializers.MultipleChoiceField(choices=choices) + choices = ((1, 'Option1'), (2, 'Option2'), (12, 'Option12')) + class TestSerializer(serializers.Serializer): + test_field = serializers.MultipleChoiceField(choices=choices) serializer = TestSerializer(data={'test_field': ['12']}) - serializer.is_valid() - - result = self.renderer.render(serializer.data) - - self.assertIsInstance(result, SafeText) - - self.assertInHTML('', - result) - self.assertInHTML('', result) - self.assertInHTML('', result) +serializer.is_valid() +result=self.renderer.render(serializer.data) +assertisinstance(result, SafeText) +self.assertInHTML('',result) +self.assertInHTML('',result) +self.assertInHTML('',result) -class StaticHTMLRendererTests(TestCase): - """ +""" Tests specific for Static HTML Renderer """ - def setUp(self): - self.renderer = StaticHTMLRenderer() +defsetUp(self): + self.renderer = StaticHTMLRenderer() def test_static_renderer(self): - data = 'text' - result = self.renderer.render(data) - assert result == data + data = 'text' + result = self.renderer.render(data) + assert result == data def test_static_renderer_with_exception(self): - context = { - 'response': Response(status=500, exception=True), - 'request': Request(HttpRequest()) - } - result = self.renderer.render({}, renderer_context=context) - assert result == '500 Internal Server Error' + context = { 'response': Response(status=500, exception=True), 'request': Request(HttpRequest()) } + result = self.renderer.render({}, renderer_context=context) + assert result == '500 Internal Server Error' class BrowsableAPIRendererTests(URLPatternsTestCase): @@ -658,196 +605,142 @@ class BrowsableAPIRendererTests(URLPatternsTestCase): assert '>Extra list action<' in resp.content.decode() -class AdminRendererTests(TestCase): - def setUp(self): - self.renderer = AdminRenderer() +def setUp(self): + self.renderer = AdminRenderer() def test_render_when_resource_created(self): - class DummyView(APIView): - renderer_classes = (AdminRenderer, ) + class DummyView(APIView): + renderer_classes = (AdminRenderer, ) request = Request(HttpRequest()) - request.build_absolute_uri = lambda: 'http://example.com' - response = Response(status=201, headers={'Location': '/test'}) - context = { - 'view': DummyView(), - 'request': request, - 'response': response - } - - result = self.renderer.render(data={'test': 'test'}, - renderer_context=context) - assert result == '' - assert response.status_code == status.HTTP_303_SEE_OTHER - assert response['Location'] == 'http://example.com' +request.build_absolute_uri=lambda:'http://example.com' +response=Response(status=201,headers={'Location':'/test'}) +context={'view':DummyView(),'request':request,'response':response} +result=self.renderer.render(data={'test':'test'},renderer_context=context) +assertresult=='' +assertresponse.status_code==status.HTTP_303_SEE_OTHER +assertresponse['Location']=='http://example.com' def test_render_dict(self): - factory = APIRequestFactory() - - class DummyView(APIView): - renderer_classes = (AdminRenderer, ) - - def get(self, request): - return Response({'foo': 'a string'}) + factory = APIRequestFactory() + class DummyView(APIView): + renderer_classes = (AdminRenderer, ) + def get(self, request): + return Response({'foo': 'a string'}) view = DummyView.as_view() - request = factory.get('/') - response = view(request) - response.render() - self.assertContains(response, 'Fooa string', html=True) +request=factory.get('/') +response=view(request) +response.render() +self.assertContains(response,'Fooa string',html=True) def test_render_dict_with_items_key(self): - factory = APIRequestFactory() - - class DummyView(APIView): - renderer_classes = (AdminRenderer, ) - - def get(self, request): - return Response({'items': 'a string'}) + factory = APIRequestFactory() + class DummyView(APIView): + renderer_classes = (AdminRenderer, ) + def get(self, request): + return Response({'items': 'a string'}) view = DummyView.as_view() - request = factory.get('/') - response = view(request) - response.render() - self.assertContains(response, 'Itemsa string', html=True) +request=factory.get('/') +response=view(request) +response.render() +self.assertContains(response,'Itemsa string',html=True) def test_render_dict_with_iteritems_key(self): - factory = APIRequestFactory() - - class DummyView(APIView): - renderer_classes = (AdminRenderer, ) - - def get(self, request): - return Response({'iteritems': 'a string'}) + factory = APIRequestFactory() + class DummyView(APIView): + renderer_classes = (AdminRenderer, ) + def get(self, request): + return Response({'iteritems': 'a string'}) view = DummyView.as_view() - request = factory.get('/') - response = view(request) - response.render() - self.assertContains(response, 'Iteritemsa string', html=True) +request=factory.get('/') +response=view(request) +response.render() +self.assertContains(response,'Iteritemsa string',html=True) def test_get_result_url(self): - factory = APIRequestFactory() - - class DummyGenericViewsetLike(APIView): - lookup_field = 'test' - - def reverse_action(view, *args, **kwargs): - self.assertEqual(kwargs['kwargs']['test'], 1) - return '/example/' + factory = APIRequestFactory() + class DummyGenericViewsetLike(APIView): + lookup_field = 'test' + def reverse_action(view, *args, **kwargs): + assert kwargs['kwargs']['test'] == 1 + return '/example/' # get the view instance instead of the view function view = DummyGenericViewsetLike.as_view() - request = factory.get('/') - response = view(request) - view = response.renderer_context['view'] - - self.assertEqual(self.renderer.get_result_url({'test': 1}, view), '/example/') - self.assertIsNone(self.renderer.get_result_url({}, view)) +request=factory.get('/') +response=view(request) +view=response.renderer_context['view'] +assertself.renderer.get_result_url({'test':1},view)=='/example/' +assertself.renderer.get_result_url({},view)is None def test_get_result_url_no_result(self): - factory = APIRequestFactory() - - class DummyView(APIView): - lookup_field = 'test' + factory = APIRequestFactory() + class DummyView(APIView): + lookup_field = 'test' # get the view instance instead of the view function view = DummyView.as_view() - request = factory.get('/') - response = view(request) - view = response.renderer_context['view'] - - self.assertIsNone(self.renderer.get_result_url({'test': 1}, view)) - self.assertIsNone(self.renderer.get_result_url({}, view)) +request=factory.get('/') +response=view(request) +view=response.renderer_context['view'] +assertself.renderer.get_result_url({'test':1},view)is None +assertself.renderer.get_result_url({},view)is None def test_get_context_result_urls(self): - factory = APIRequestFactory() - - class DummyView(APIView): - lookup_field = 'test' - - def reverse_action(view, url_name, args=None, kwargs=None): - return '/%s/%d' % (url_name, kwargs['test']) + factory = APIRequestFactory() + class DummyView(APIView): + lookup_field = 'test' + def reverse_action(view, url_name, args=None, kwargs=None): + return '/%s/%d' % (url_name, kwargs['test']) # get the view instance instead of the view function view = DummyView.as_view() - request = factory.get('/') - response = view(request) - - data = [ - {'test': 1}, - {'url': '/example', 'test': 2}, - {'url': None, 'test': 3}, - {}, - ] - context = { - 'view': DummyView(), - 'request': Request(request), - 'response': response - } - - context = self.renderer.get_context(data, None, context) - results = context['results'] - - self.assertEqual(len(results), 4) - self.assertEqual(results[0]['url'], '/detail/1') - self.assertEqual(results[1]['url'], '/example') - self.assertEqual(results[2]['url'], None) - self.assertNotIn('url', results[3]) +request=factory.get('/') +response=view(request) +data=[{'test':1},{'url':'/example','test':2},{'url':None,'test':3},{},] +context={'view':DummyView(),'request':Request(request),'response':response} +context=self.renderer.get_context(data,None,context) +results=context['results'] +assertlen(results)==4 +assertresults[0]['url']=='/detail/1' +assertresults[1]['url']=='/example' +assertresults[2]['url']==None +assert'url'not inresults[3] @pytest.mark.skipif(not coreapi, reason='coreapi is not installed') -class TestDocumentationRenderer(TestCase): - def test_document_with_link_named_data(self): - """ +def test_document_with_link_named_data(self): + """ Ref #5395: Doc's `document.data` would fail with a Link named "data". As per #4972, use templatetag instead. """ - document = coreapi.Document( - title='Data Endpoint API', - url='https://api.example.org/', - content={ - 'data': coreapi.Link( - url='/data/', - action='get', - fields=[], - description='Return data.' - ) - } - ) - - factory = APIRequestFactory() - request = factory.get('/') - - renderer = DocumentationRenderer() - - html = renderer.render(document, accepted_media_type="text/html", renderer_context={"request": request}) - assert '

Data Endpoint API

' in html + document = coreapi.Document( title='Data Endpoint API', url='https://api.example.org/', content={ 'data': coreapi.Link( url='/data/', action='get', fields=[], description='Return data.' ) } ) + factory = APIRequestFactory() + request = factory.get('/') + renderer = DocumentationRenderer() + html = renderer.render(document, accepted_media_type="text/html", renderer_context={"request": request}) + assert '

Data Endpoint API

' in html def test_shell_code_example_rendering(self): - template = loader.get_template('rest_framework/docs/langs/shell.html') - context = { - 'document': coreapi.Document(url='https://api.example.org/'), - 'link_key': 'testcases > list', - 'link': coreapi.Link(url='/data/', action='get', fields=[]), - } - html = template.render(context) - assert 'testcases list' in html + template = loader.get_template('rest_framework/docs/langs/shell.html') + context = { 'document': coreapi.Document(url='https://api.example.org/'), 'link_key': 'testcases > list', 'link': coreapi.Link(url='/data/', action='get', fields=[]), } + html = template.render(context) + assert 'testcases list' in html @pytest.mark.skipif(not coreapi, reason='coreapi is not installed') -class TestSchemaJSRenderer(TestCase): - def test_schemajs_output(self): - """ +def test_schemajs_output(self): + """ Test output of the SchemaJS renderer as per #5608. Django 2.0 on Py3 prints binary data as b'xyz' in templates, and the base64 encoding used by SchemaJSRenderer outputs base64 as binary. Test fix. """ - factory = APIRequestFactory() - request = factory.get('/') - - renderer = SchemaJSRenderer() - - output = renderer.render('data', renderer_context={"request": request}) - assert "'ImRhdGEi'" in output - assert "'b'ImRhdGEi''" not in output + factory = APIRequestFactory() + request = factory.get('/') + renderer = SchemaJSRenderer() + output = renderer.render('data', renderer_context={"request": request}) + assert "'ImRhdGEi'" in output + assert "'b'ImRhdGEi''" not in output diff --git a/tests/test_request.py b/tests/test_request.py index 0f682deb0..d7f1d44cb 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -23,18 +23,11 @@ from rest_framework.test import APIClient, APIRequestFactory from rest_framework.views import APIView factory = APIRequestFactory() - - -class TestInitializer(TestCase): - def test_request_type(self): - request = Request(factory.get('/')) - - message = ( - 'The `request` argument must be an instance of ' - '`django.http.HttpRequest`, not `rest_framework.request.Request`.' - ) - with self.assertRaisesMessage(AssertionError, message): - Request(request) +def test_request_type(self): + request = Request(factory.get('/')) + message = ( 'The `request` argument must be an instance of ' '`django.http.HttpRequest`, not `rest_framework.request.Request`.' ) + with self.assertRaisesMessage(AssertionError, message): + Request(request) class PlainTextParser(BaseParser): @@ -50,79 +43,78 @@ class PlainTextParser(BaseParser): return stream.read() -class TestContentParsing(TestCase): - def test_standard_behaviour_determines_no_content_GET(self): - """ +def test_standard_behaviour_determines_no_content_GET(self): + """ Ensure request.data returns empty QueryDict for GET request. """ - request = Request(factory.get('/')) - assert request.data == {} + request = Request(factory.get('/')) + assert request.data == {} def test_standard_behaviour_determines_no_content_HEAD(self): - """ + """ Ensure request.data returns empty QueryDict for HEAD request. """ - request = Request(factory.head('/')) - assert request.data == {} + request = Request(factory.head('/')) + assert request.data == {} def test_request_DATA_with_form_content(self): - """ + """ Ensure request.data returns content for POST request with form content. """ - data = {'qwerty': 'uiop'} - request = Request(factory.post('/', data)) - request.parsers = (FormParser(), MultiPartParser()) - assert list(request.data.items()) == list(data.items()) + data = {'qwerty': 'uiop'} + request = Request(factory.post('/', data)) + request.parsers = (FormParser(), MultiPartParser()) + assert list(request.data.items()) == list(data.items()) def test_request_DATA_with_text_content(self): - """ + """ Ensure request.data returns content for POST request with non-form content. """ - content = b'qwerty' - content_type = 'text/plain' - request = Request(factory.post('/', content, content_type=content_type)) - request.parsers = (PlainTextParser(),) - assert request.data == content + content = b'qwerty' + content_type = 'text/plain' + request = Request(factory.post('/', content, content_type=content_type)) + request.parsers = (PlainTextParser(),) + assert request.data == content def test_request_POST_with_form_content(self): - """ + """ Ensure request.POST returns content for POST request with form content. """ - data = {'qwerty': 'uiop'} - request = Request(factory.post('/', data)) - request.parsers = (FormParser(), MultiPartParser()) - assert list(request.POST.items()) == list(data.items()) + data = {'qwerty': 'uiop'} + request = Request(factory.post('/', data)) + request.parsers = (FormParser(), MultiPartParser()) + assert list(request.POST.items()) == list(data.items()) def test_request_POST_with_files(self): - """ + """ Ensure request.POST returns no content for POST request with file content. """ - upload = SimpleUploadedFile("file.txt", b"file_content") - request = Request(factory.post('/', {'upload': upload})) - request.parsers = (FormParser(), MultiPartParser()) - assert list(request.POST) == [] - assert list(request.FILES) == ['upload'] + upload = SimpleUploadedFile("file.txt", b"file_content") + request = Request(factory.post('/', {'upload': upload})) + request.parsers = (FormParser(), MultiPartParser()) + assert list(request.POST) == [] + assert list(request.FILES) == ['upload'] def test_standard_behaviour_determines_form_content_PUT(self): - """ + """ Ensure request.data returns content for PUT request with form content. """ - data = {'qwerty': 'uiop'} - request = Request(factory.put('/', data)) - request.parsers = (FormParser(), MultiPartParser()) - assert list(request.data.items()) == list(data.items()) + data = {'qwerty': 'uiop'} + request = Request(factory.put('/', data)) + request.parsers = (FormParser(), MultiPartParser()) + assert list(request.data.items()) == list(data.items()) def test_standard_behaviour_determines_non_form_content_PUT(self): - """ + """ Ensure request.data returns content for PUT request with non-form content. """ - content = b'qwerty' - content_type = 'text/plain' - request = Request(factory.put('/', content, content_type=content_type)) - request.parsers = (PlainTextParser(), ) - assert request.data == content + content = b'qwerty' + content_type = 'text/plain' + request = Request(factory.put('/', content, content_type=content_type)) + request.parsers = (PlainTextParser(), ) + assert request.data == content class MockView(APIView): @@ -160,175 +152,148 @@ urlpatterns = [ @override_settings( ROOT_URLCONF='tests.test_request', FILE_UPLOAD_HANDLERS=['django.core.files.uploadhandler.TemporaryFileUploadHandler']) -class FileUploadTests(TestCase): - def test_fileuploads_closed_at_request_end(self): - with tempfile.NamedTemporaryFile() as f: - response = self.client.post('/upload/', {'file': f}) +def test_fileuploads_closed_at_request_end(self): + with tempfile.NamedTemporaryFile() as f: + response = self.client.post('/upload/', {'file': f}) # sanity check that file was processed assert len(response.data) == 1 - - for file in response.data: - assert not os.path.exists(file) +forfileinresponse.data: + assert not os.path.exists(file) @override_settings(ROOT_URLCONF='tests.test_request') -class TestContentParsingWithAuthentication(TestCase): - def setUp(self): - self.csrf_client = APIClient(enforce_csrf_checks=True) - self.username = 'john' - self.email = 'lennon@thebeatles.com' - self.password = 'password' - self.user = User.objects.create_user(self.username, self.email, self.password) +def setUp(self): + self.csrf_client = APIClient(enforce_csrf_checks=True) + self.username = 'john' + self.email = 'lennon@thebeatles.com' + self.password = 'password' + self.user = User.objects.create_user(self.username, self.email, self.password) def test_user_logged_in_authentication_has_POST_when_not_logged_in(self): - """ + """ Ensures request.POST exists after SessionAuthentication when user doesn't log in. """ - content = {'example': 'example'} - - response = self.client.post('/', content) - assert status.HTTP_200_OK == response.status_code - - response = self.csrf_client.post('/', content) - assert status.HTTP_200_OK == response.status_code + content = {'example': 'example'} + response = self.client.post('/', content) + assert status.HTTP_200_OK == response.status_code + response = self.csrf_client.post('/', content) + assert status.HTTP_200_OK == response.status_code -class TestUserSetter(TestCase): - def setUp(self): +def setUp(self): # Pass request object through session middleware so session is # available to login and logout functions - self.wrapped_request = factory.get('/') - self.request = Request(self.wrapped_request) - SessionMiddleware().process_request(self.wrapped_request) - AuthenticationMiddleware().process_request(self.wrapped_request) - - User.objects.create_user('ringo', 'starr@thebeatles.com', 'yellow') - self.user = authenticate(username='ringo', password='yellow') + self.wrapped_request = factory.get('/') + self.request = Request(self.wrapped_request) + SessionMiddleware().process_request(self.wrapped_request) + AuthenticationMiddleware().process_request(self.wrapped_request) + User.objects.create_user('ringo', 'starr@thebeatles.com', 'yellow') + self.user = authenticate(username='ringo', password='yellow') def test_user_can_be_set(self): - self.request.user = self.user - assert self.request.user == self.user + self.request.user = self.user + assert self.request.user == self.user def test_user_can_login(self): - login(self.request, self.user) - assert self.request.user == self.user + login(self.request, self.user) + assert self.request.user == self.user def test_user_can_logout(self): - self.request.user = self.user - assert not self.request.user.is_anonymous - logout(self.request) - assert self.request.user.is_anonymous + self.request.user = self.user + assert not self.request.user.is_anonymous + logout(self.request) + assert self.request.user.is_anonymous def test_logged_in_user_is_set_on_wrapped_request(self): - login(self.request, self.user) - assert self.wrapped_request.user == self.user + login(self.request, self.user) + assert self.wrapped_request.user == self.user def test_calling_user_fails_when_attribute_error_is_raised(self): - """ + """ This proves that when an AttributeError is raised inside of the request.user property, that we can handle this and report the true, underlying error. """ - class AuthRaisesAttributeError: - def authenticate(self, request): - self.MISSPELLED_NAME_THAT_DOESNT_EXIST + class AuthRaisesAttributeError: + def authenticate(self, request): + self.MISSPELLED_NAME_THAT_DOESNT_EXIST request = Request(self.wrapped_request, authenticators=(AuthRaisesAttributeError(),)) - - # The middleware processes the underlying Django request, sets anonymous user - assert self.wrapped_request.user.is_anonymous - - # The DRF request object does not have a user and should run authenticators - expected = r"no attribute 'MISSPELLED_NAME_THAT_DOESNT_EXIST'" - with pytest.raises(WrappedAttributeError, match=expected): - request.user +assertself.wrapped_request.user.is_anonymous +expected=r"no attribute 'MISSPELLED_NAME_THAT_DOESNT_EXIST'" +withpytest.raises(WrappedAttributeError,match=expected): + request.user with pytest.raises(WrappedAttributeError, match=expected): - hasattr(request, 'user') + hasattr(request, 'user') with pytest.raises(WrappedAttributeError, match=expected): - login(request, self.user) + login(request, self.user) -class TestAuthSetter(TestCase): - def test_auth_can_be_set(self): - request = Request(factory.get('/')) - request.auth = 'DUMMY' - assert request.auth == 'DUMMY' +def test_auth_can_be_set(self): + request = Request(factory.get('/')) + request.auth = 'DUMMY' + assert request.auth == 'DUMMY' -class TestSecure(TestCase): - def test_default_secure_false(self): - request = Request(factory.get('/', secure=False)) - assert request.scheme == 'http' +def test_default_secure_false(self): + request = Request(factory.get('/', secure=False)) + assert request.scheme == 'http' def test_default_secure_true(self): - request = Request(factory.get('/', secure=True)) - assert request.scheme == 'https' + request = Request(factory.get('/', secure=True)) + assert request.scheme == 'https' -class TestHttpRequest(TestCase): - def test_attribute_access_proxy(self): - http_request = factory.get('/') - request = Request(http_request) - - inner_sentinel = object() - http_request.inner_property = inner_sentinel - assert request.inner_property is inner_sentinel - - outer_sentinel = object() - request.inner_property = outer_sentinel - assert request.inner_property is outer_sentinel +def test_attribute_access_proxy(self): + http_request = factory.get('/') + request = Request(http_request) + inner_sentinel = object() + http_request.inner_property = inner_sentinel + assert request.inner_property is inner_sentinel + outer_sentinel = object() + request.inner_property = outer_sentinel + assert request.inner_property is outer_sentinel def test_exception_proxy(self): # ensure the exception message is not for the underlying WSGIRequest - http_request = factory.get('/') - request = Request(http_request) - - message = "'Request' object has no attribute 'inner_property'" - with self.assertRaisesMessage(AttributeError, message): - request.inner_property + http_request = factory.get('/') + request = Request(http_request) + message = "'Request' object has no attribute 'inner_property'" + with self.assertRaisesMessage(AttributeError, message): + request.inner_property @override_settings(ROOT_URLCONF='tests.test_request') - def test_duplicate_request_stream_parsing_exception(self): - """ +deftest_duplicate_request_stream_parsing_exception(self): + """ Check assumption that duplicate stream parsing will result in a `RawPostDataException` being raised. """ - response = APIClient().post('/echo/', data={'a': 'b'}, format='json') - request = response.renderer_context['request'] - - # ensure that request stream was consumed by json parser - assert request.content_type.startswith('application/json') - assert response.data == {'a': 'b'} - - # pass same HttpRequest to view, stream already consumed - with pytest.raises(RawPostDataException): - EchoView.as_view()(request._request) + response = APIClient().post('/echo/', data={'a': 'b'}, format='json') + request = response.renderer_context['request'] + assert request.content_type.startswith('application/json') + assert response.data == {'a': 'b'} + with pytest.raises(RawPostDataException): + EchoView.as_view()(request._request) @override_settings(ROOT_URLCONF='tests.test_request') - def test_duplicate_request_form_data_access(self): - """ +deftest_duplicate_request_form_data_access(self): + """ Form data is copied to the underlying django request for middleware and file closing reasons. Duplicate processing of a request with form data is 'safe' in so far as accessing `request.POST` does not trigger the duplicate stream parse exception. """ - response = APIClient().post('/echo/', data={'a': 'b'}) - request = response.renderer_context['request'] - - # ensure that request stream was consumed by form parser - assert request.content_type.startswith('multipart/form-data') - assert response.data == {'a': ['b']} - - # pass same HttpRequest to view, form data set on underlying request - response = EchoView.as_view()(request._request) - request = response.renderer_context['request'] - - # ensure that request stream was consumed by form parser - assert request.content_type.startswith('multipart/form-data') - assert response.data == {'a': ['b']} + response = APIClient().post('/echo/', data={'a': 'b'}) + request = response.renderer_context['request'] + assert request.content_type.startswith('multipart/form-data') + assert response.data == {'a': ['b']} + response = EchoView.as_view()(request._request) + request = response.renderer_context['request'] + assert request.content_type.startswith('multipart/form-data') + assert response.data == {'a': ['b']} diff --git a/tests/test_response.py b/tests/test_response.py index d3a56d01b..9f2027541 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -131,155 +131,146 @@ urlpatterns = [ # TODO: Clean tests bellow - remove duplicates with above, better unit testing, ... @override_settings(ROOT_URLCONF='tests.test_response') -class RendererIntegrationTests(TestCase): - """ +""" End-to-end testing of renderers using an ResponseMixin on a generic view. """ - def test_default_renderer_serializes_content(self): - """If the Accept header is not set the default renderer should serialize the response.""" - resp = self.client.get('/') - self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') - self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) - self.assertEqual(resp.status_code, DUMMYSTATUS) +deftest_default_renderer_serializes_content(self): + """If the Accept header is not set the default renderer should serialize the response.""" + resp = self.client.get('/') + assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8' + assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT) + assert resp.status_code == DUMMYSTATUS def test_head_method_serializes_no_content(self): - """No response must be included in HEAD requests.""" - resp = self.client.head('/') - self.assertEqual(resp.status_code, DUMMYSTATUS) - self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') - self.assertEqual(resp.content, b'') + """No response must be included in HEAD requests.""" + resp = self.client.head('/') + assert resp.status_code == DUMMYSTATUS + assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8' + assert resp.content == b'' def test_default_renderer_serializes_content_on_accept_any(self): - """If the Accept header is set to */* the default renderer should serialize the response.""" - resp = self.client.get('/', HTTP_ACCEPT='*/*') - self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') - self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) - self.assertEqual(resp.status_code, DUMMYSTATUS) + """If the Accept header is set to */* the default renderer should serialize the response.""" + resp = self.client.get('/', HTTP_ACCEPT='*/*') + assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8' + assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT) + assert resp.status_code == DUMMYSTATUS def test_specified_renderer_serializes_content_default_case(self): - """If the Accept header is set the specified renderer should serialize the response. + """If the Accept header is set the specified renderer should serialize the response. (In this case we check that works for the default renderer)""" - resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type) - self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') - self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) - self.assertEqual(resp.status_code, DUMMYSTATUS) + resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type) + assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8' + assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT) + assert resp.status_code == DUMMYSTATUS def test_specified_renderer_serializes_content_non_default_case(self): - """If the Accept header is set the specified renderer should serialize the response. + """If the Accept header is set the specified renderer should serialize the response. (In this case we check that works for a non-default renderer)""" - resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type) - self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') - self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) - self.assertEqual(resp.status_code, DUMMYSTATUS) + resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type) + assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8' + assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT) + assert resp.status_code == DUMMYSTATUS def test_specified_renderer_serializes_content_on_format_query(self): - """If a 'format' query is specified, the renderer with the matching + """If a 'format' query is specified, the renderer with the matching format attribute should serialize the response.""" - resp = self.client.get('/?format=%s' % RendererB.format) - self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') - self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) - self.assertEqual(resp.status_code, DUMMYSTATUS) + resp = self.client.get('/?format=%s' % RendererB.format) + assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8' + assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT) + assert resp.status_code == DUMMYSTATUS def test_specified_renderer_serializes_content_on_format_kwargs(self): - """If a 'format' keyword arg is specified, the renderer with the matching + """If a 'format' keyword arg is specified, the renderer with the matching format attribute should serialize the response.""" - resp = self.client.get('/something.formatb') - self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') - self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) - self.assertEqual(resp.status_code, DUMMYSTATUS) + resp = self.client.get('/something.formatb') + assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8' + assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT) + assert resp.status_code == DUMMYSTATUS def test_specified_renderer_is_used_on_format_query_with_matching_accept(self): - """If both a 'format' query and a matching Accept header specified, + """If both a 'format' query and a matching Accept header specified, the renderer with the matching format attribute should serialize the response.""" - resp = self.client.get('/?format=%s' % RendererB.format, - HTTP_ACCEPT=RendererB.media_type) - self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') - self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) - self.assertEqual(resp.status_code, DUMMYSTATUS) + resp = self.client.get('/?format=%s' % RendererB.format, HTTP_ACCEPT=RendererB.media_type) + assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8' + assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT) + assert resp.status_code == DUMMYSTATUS @override_settings(ROOT_URLCONF='tests.test_response') -class UnsupportedMediaTypeTests(TestCase): - def test_should_allow_posting_json(self): - response = self.client.post('/json', data='{"test": 123}', content_type='application/json') - - self.assertEqual(response.status_code, 200) +def test_should_allow_posting_json(self): + response = self.client.post('/json', data='{"test": 123}', content_type='application/json') + assert response.status_code == 200 def test_should_not_allow_posting_xml(self): - response = self.client.post('/json', data='123', content_type='application/xml') - - self.assertEqual(response.status_code, 415) + response = self.client.post('/json', data='123', content_type='application/xml') + assert response.status_code == 415 def test_should_not_allow_posting_a_form(self): - response = self.client.post('/json', data={'test': 123}) - - self.assertEqual(response.status_code, 415) + response = self.client.post('/json', data={'test': 123}) + assert response.status_code == 415 @override_settings(ROOT_URLCONF='tests.test_response') -class Issue122Tests(TestCase): - """ +""" Tests that covers #122. """ - def test_only_html_renderer(self): - """ +deftest_only_html_renderer(self): + """ Test if no infinite recursion occurs. """ - self.client.get('/html') + self.client.get('/html') def test_html_renderer_is_first(self): - """ + """ Test if no infinite recursion occurs. """ - self.client.get('/html1') + self.client.get('/html1') @override_settings(ROOT_URLCONF='tests.test_response') -class Issue467Tests(TestCase): - """ +""" Tests for #467 """ - def test_form_has_label_and_help_text(self): - resp = self.client.get('/html_new_model') - self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') +deftest_form_has_label_and_help_text(self): + resp = self.client.get('/html_new_model') + assert resp['Content-Type'] == 'text/html; charset=utf-8' # self.assertContains(resp, 'Text comes here') # self.assertContains(resp, 'Text description.') @override_settings(ROOT_URLCONF='tests.test_response') -class Issue807Tests(TestCase): - """ +""" Covers #807 """ - def test_does_not_append_charset_by_default(self): - """ +deftest_does_not_append_charset_by_default(self): + """ Renderers don't include a charset unless set explicitly. """ - headers = {"HTTP_ACCEPT": RendererA.media_type} - resp = self.client.get('/', **headers) - expected = "{}; charset={}".format(RendererA.media_type, 'utf-8') - self.assertEqual(expected, resp['Content-Type']) + headers = {"HTTP_ACCEPT": RendererA.media_type} + resp = self.client.get('/', **headers) + expected = "{}; charset={}".format(RendererA.media_type, 'utf-8') + assert expected == resp['Content-Type'] def test_if_there_is_charset_specified_on_renderer_it_gets_appended(self): - """ + """ If renderer class has charset attribute declared, it gets appended to Response's Content-Type """ - headers = {"HTTP_ACCEPT": RendererC.media_type} - resp = self.client.get('/', **headers) - expected = "{}; charset={}".format(RendererC.media_type, RendererC.charset) - self.assertEqual(expected, resp['Content-Type']) + headers = {"HTTP_ACCEPT": RendererC.media_type} + resp = self.client.get('/', **headers) + expected = "{}; charset={}".format(RendererC.media_type, RendererC.charset) + assert expected == resp['Content-Type'] def test_content_type_set_explicitly_on_response(self): - """ + """ The content type may be set explicitly on the response. """ - headers = {"HTTP_ACCEPT": RendererC.media_type} - resp = self.client.get('/setbyview', **headers) - self.assertEqual('setbyview', resp['Content-Type']) + headers = {"HTTP_ACCEPT": RendererC.media_type} + resp = self.client.get('/setbyview', **headers) + assert 'setbyview' == resp['Content-Type'] def test_form_has_label_and_help_text(self): - resp = self.client.get('/html_new_model') - self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') + resp = self.client.get('/html_new_model') + assert resp['Content-Type'] == 'text/html; charset=utf-8' # self.assertContains(resp, 'Text comes here') # self.assertContains(resp, 'Text description.') diff --git a/tests/test_reverse.py b/tests/test_reverse.py index 9ab1667c5..be360ed97 100644 --- a/tests/test_reverse.py +++ b/tests/test_reverse.py @@ -30,25 +30,22 @@ class MockVersioningScheme: @override_settings(ROOT_URLCONF='tests.test_reverse') -class ReverseTests(TestCase): - """ +""" Tests for fully qualified URLs when using `reverse`. """ - def test_reversed_urls_are_fully_qualified(self): - request = factory.get('/view') - url = reverse('view', request=request) - assert url == 'http://testserver/view' +deftest_reversed_urls_are_fully_qualified(self): + request = factory.get('/view') + url = reverse('view', request=request) + assert url == 'http://testserver/view' def test_reverse_with_versioning_scheme(self): - request = factory.get('/view') - request.versioning_scheme = MockVersioningScheme() - - url = reverse('view', request=request) - assert url == 'http://scheme-reversed/view' + request = factory.get('/view') + request.versioning_scheme = MockVersioningScheme() + url = reverse('view', request=request) + assert url == 'http://scheme-reversed/view' def test_reverse_with_versioning_scheme_fallback_to_default_on_error(self): - request = factory.get('/view') - request.versioning_scheme = MockVersioningScheme(raise_error=True) - - url = reverse('view', request=request) - assert url == 'http://testserver/view' + request = factory.get('/view') + request.versioning_scheme = MockVersioningScheme(raise_error=True) + url = reverse('view', request=request) + assert url == 'http://testserver/view' diff --git a/tests/test_routers.py b/tests/test_routers.py index 0f428e2a5..3341f8fd6 100644 --- a/tests/test_routers.py +++ b/tests/test_routers.py @@ -214,25 +214,24 @@ class TestCustomLookupFields(URLPatternsTestCase, TestCase): assert response.data == {"url": "http://testserver/example/notes/a%20b/", "uuid": "a b", "text": "baz qux"} -class TestLookupValueRegex(TestCase): - """ +""" Ensure the router honors lookup_value_regex when applied to the viewset. """ - def setUp(self): - class NoteViewSet(viewsets.ModelViewSet): - queryset = RouterTestModel.objects.all() - lookup_field = 'uuid' - lookup_value_regex = '[0-9a-f]{32}' +defsetUp(self): + class NoteViewSet(viewsets.ModelViewSet): + queryset = RouterTestModel.objects.all() + lookup_field = 'uuid' + lookup_value_regex = '[0-9a-f]{32}' self.router = SimpleRouter() - self.router.register(r'notes', NoteViewSet) - self.urls = self.router.urls +self.router.register(r'notes',NoteViewSet) +self.urls=self.router.urls def test_urls_limited_by_lookup_value_regex(self): - expected = ['^notes/$', '^notes/(?P[0-9a-f]{32})/$'] - for idx in range(len(expected)): - assert expected[idx] == get_regex_pattern(self.urls[idx]) + expected = ['^notes/$', '^notes/(?P[0-9a-f]{32})/$'] + for idx in range(len(expected)): + assert expected[idx] == get_regex_pattern(self.urls[idx]) @override_settings(ROOT_URLCONF='tests.test_routers') @@ -264,97 +263,84 @@ class TestLookupUrlKwargs(URLPatternsTestCase, TestCase): assert response.data == {"url": "http://testserver/example/notes/123/", "uuid": "123", "text": "foo bar"} -class TestTrailingSlashIncluded(TestCase): - def setUp(self): - class NoteViewSet(viewsets.ModelViewSet): - queryset = RouterTestModel.objects.all() +def setUp(self): + class NoteViewSet(viewsets.ModelViewSet): + queryset = RouterTestModel.objects.all() self.router = SimpleRouter() - self.router.register(r'notes', NoteViewSet) - self.urls = self.router.urls +self.router.register(r'notes',NoteViewSet) +self.urls=self.router.urls def test_urls_have_trailing_slash_by_default(self): - expected = ['^notes/$', '^notes/(?P[^/.]+)/$'] - for idx in range(len(expected)): - assert expected[idx] == get_regex_pattern(self.urls[idx]) + expected = ['^notes/$', '^notes/(?P[^/.]+)/$'] + for idx in range(len(expected)): + assert expected[idx] == get_regex_pattern(self.urls[idx]) -class TestTrailingSlashRemoved(TestCase): - def setUp(self): - class NoteViewSet(viewsets.ModelViewSet): - queryset = RouterTestModel.objects.all() +def setUp(self): + class NoteViewSet(viewsets.ModelViewSet): + queryset = RouterTestModel.objects.all() self.router = SimpleRouter(trailing_slash=False) - self.router.register(r'notes', NoteViewSet) - self.urls = self.router.urls +self.router.register(r'notes',NoteViewSet) +self.urls=self.router.urls def test_urls_can_have_trailing_slash_removed(self): - expected = ['^notes$', '^notes/(?P[^/.]+)$'] - for idx in range(len(expected)): - assert expected[idx] == get_regex_pattern(self.urls[idx]) + expected = ['^notes$', '^notes/(?P[^/.]+)$'] + for idx in range(len(expected)): + assert expected[idx] == get_regex_pattern(self.urls[idx]) -class TestNameableRoot(TestCase): - def setUp(self): - class NoteViewSet(viewsets.ModelViewSet): - queryset = RouterTestModel.objects.all() +def setUp(self): + class NoteViewSet(viewsets.ModelViewSet): + queryset = RouterTestModel.objects.all() self.router = DefaultRouter() - self.router.root_view_name = 'nameable-root' - self.router.register(r'notes', NoteViewSet) - self.urls = self.router.urls +self.router.root_view_name='nameable-root' +self.router.register(r'notes',NoteViewSet) +self.urls=self.router.urls def test_router_has_custom_name(self): - expected = 'nameable-root' - assert expected == self.urls[-1].name + expected = 'nameable-root' + assert expected == self.urls[-1].name -class TestActionKeywordArgs(TestCase): - """ +""" Ensure keyword arguments passed in the `@action` decorator are properly handled. Refs #940. """ - - def setUp(self): - class TestViewSet(viewsets.ModelViewSet): - permission_classes = [] - - @action(methods=['post'], detail=True, permission_classes=[permissions.AllowAny]) - def custom(self, request, *args, **kwargs): - return Response({ - 'permission_classes': self.permission_classes - }) +defsetUp(self): + class TestViewSet(viewsets.ModelViewSet): + permission_classes = [] + @action(methods=['post'], detail=True, permission_classes=[permissions.AllowAny]) + def custom(self, request, *args, **kwargs): + return Response({ 'permission_classes': self.permission_classes }) self.router = SimpleRouter() - self.router.register(r'test', TestViewSet, basename='test') - self.view = self.router.urls[-1].callback +self.router.register(r'test',TestViewSet,basename='test') +self.view=self.router.urls[-1].callback def test_action_kwargs(self): - request = factory.post('/test/0/custom/') - response = self.view(request) - assert response.data == {'permission_classes': [permissions.AllowAny]} + request = factory.post('/test/0/custom/') + response = self.view(request) + assert response.data == {'permission_classes': [permissions.AllowAny]} -class TestActionAppliedToExistingRoute(TestCase): - """ +""" Ensure `@action` decorator raises an except when applied to an existing route """ +deftest_exception_raised_when_action_applied_to_existing_route(self): + class TestViewSet(viewsets.ModelViewSet): - def test_exception_raised_when_action_applied_to_existing_route(self): - class TestViewSet(viewsets.ModelViewSet): - - @action(methods=['post'], detail=True) - def retrieve(self, request, *args, **kwargs): - return Response({ - 'hello': 'world' - }) + @action(methods=['post'], detail=True) + def retrieve(self, request, *args, **kwargs): + return Response({ 'hello': 'world' }) self.router = SimpleRouter() - self.router.register(r'test', TestViewSet, basename='test') - - with pytest.raises(ImproperlyConfigured): - self.router.urls +self.router.register(r'test',TestViewSet,basename='test') +withpytest.raises(ImproperlyConfigured): + self.router.urls class DynamicListAndDetailViewSet(viewsets.ViewSet): @@ -390,44 +376,33 @@ class SubDynamicListAndDetailViewSet(DynamicListAndDetailViewSet): pass -class TestDynamicListAndDetailRouter(TestCase): - def setUp(self): - self.router = SimpleRouter() +def setUp(self): + self.router = SimpleRouter() def _test_list_and_detail_route_decorators(self, viewset): - routes = self.router.get_routes(viewset) - decorator_routes = [r for r in routes if not (r.name.endswith('-list') or r.name.endswith('-detail'))] - - MethodNamesMap = namedtuple('MethodNamesMap', 'method_name url_path') - # Make sure all these endpoints exist and none have been clobbered - for i, endpoint in enumerate([MethodNamesMap('list_custom_route_get', 'list_custom-route'), - MethodNamesMap('list_route_get', 'list_route_get'), - MethodNamesMap('list_route_post', 'list_route_post'), - MethodNamesMap('detail_custom_route_get', 'detail_custom-route'), - MethodNamesMap('detail_route_get', 'detail_route_get'), - MethodNamesMap('detail_route_post', 'detail_route_post') - ]): - route = decorator_routes[i] - # check url listing - method_name = endpoint.method_name - url_path = endpoint.url_path - - if method_name.startswith('list_'): - assert route.url == '^{{prefix}}/{0}{{trailing_slash}}$'.format(url_path) + routes = self.router.get_routes(viewset) + decorator_routes = [r for r in routes if not (r.name.endswith('-list') or r.name.endswith('-detail'))] + MethodNamesMap = namedtuple('MethodNamesMap', 'method_name url_path') + for i, endpoint in enumerate([MethodNamesMap('list_custom_route_get', 'list_custom-route'), MethodNamesMap('list_route_get', 'list_route_get'), MethodNamesMap('list_route_post', 'list_route_post'), MethodNamesMap('detail_custom_route_get', 'detail_custom-route'), MethodNamesMap('detail_route_get', 'detail_route_get'), MethodNamesMap('detail_route_post', 'detail_route_post') ]): + route = decorator_routes[i] + method_name = endpoint.method_name + url_path = endpoint.url_path + if method_name.startswith('list_'): + assert route.url == '^{{prefix}}/{0}{{trailing_slash}}$'.format(url_path) else: - assert route.url == '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(url_path) + assert route.url == '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(url_path) # check method to function mapping if method_name.endswith('_post'): - method_map = 'post' + method_map = 'post' else: - method_map = 'get' + method_map = 'get' assert route.mapping[method_map] == method_name def test_list_and_detail_route_decorators(self): - self._test_list_and_detail_route_decorators(DynamicListAndDetailViewSet) + self._test_list_and_detail_route_decorators(DynamicListAndDetailViewSet) def test_inherited_list_and_detail_route_decorators(self): - self._test_list_and_detail_route_decorators(SubDynamicListAndDetailViewSet) + self._test_list_and_detail_route_decorators(SubDynamicListAndDetailViewSet) class TestEmptyPrefix(URLPatternsTestCase, TestCase): @@ -490,69 +465,52 @@ class TestViewInitkwargs(URLPatternsTestCase, TestCase): assert initkwargs['basename'] == 'routertestmodel' -class TestBaseNameRename(TestCase): - def test_base_name_and_basename_assertion(self): - router = SimpleRouter() - - msg = "Do not provide both the `basename` and `base_name` arguments." - with warnings.catch_warnings(record=True) as w, \ - self.assertRaisesMessage(AssertionError, msg): - warnings.simplefilter('always') - router.register('mock', MockViewSet, 'mock', base_name='mock') +def test_base_name_and_basename_assertion(self): + router = SimpleRouter() + msg = "Do not provide both the `basename` and `base_name` arguments." + with warnings.catch_warnings(record=True) as w, self.assertRaisesMessage(AssertionError, msg): + warnings.simplefilter('always') + router.register('mock', MockViewSet, 'mock', base_name='mock') msg = "The `base_name` argument is pending deprecation in favor of `basename`." - assert len(w) == 1 - assert str(w[0].message) == msg +assertlen(w)==1 +assertstr(w[0].message)==msg def test_base_name_argument_deprecation(self): - router = SimpleRouter() - - with pytest.warns(RemovedInDRF311Warning) as w: - warnings.simplefilter('always') - router.register('mock', MockViewSet, base_name='mock') + router = SimpleRouter() + with pytest.warns(RemovedInDRF311Warning) as w: + warnings.simplefilter('always') + router.register('mock', MockViewSet, base_name='mock') msg = "The `base_name` argument is pending deprecation in favor of `basename`." - assert len(w) == 1 - assert str(w[0].message) == msg - assert router.registry == [ - ('mock', MockViewSet, 'mock'), - ] +assertlen(w)==1 +assertstr(w[0].message)==msg +assertrouter.registry==[('mock',MockViewSet,'mock'),] def test_basename_argument_no_warnings(self): - router = SimpleRouter() - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') - router.register('mock', MockViewSet, basename='mock') + router = SimpleRouter() + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + router.register('mock', MockViewSet, basename='mock') assert len(w) == 0 - assert router.registry == [ - ('mock', MockViewSet, 'mock'), - ] +assertrouter.registry==[('mock',MockViewSet,'mock'),] def test_get_default_base_name_deprecation(self): - msg = "`CustomRouter.get_default_base_name` method should be renamed `get_default_basename`." - - # Class definition should raise a warning - with pytest.warns(RemovedInDRF311Warning) as w: - warnings.simplefilter('always') - - class CustomRouter(SimpleRouter): - def get_default_base_name(self, viewset): - return 'foo' + msg = "`CustomRouter.get_default_base_name` method should be renamed `get_default_basename`." + with pytest.warns(RemovedInDRF311Warning) as w: + warnings.simplefilter('always') + class CustomRouter(SimpleRouter): + def get_default_base_name(self, viewset): + return 'foo' assert len(w) == 1 - assert str(w[0].message) == msg - - # Deprecated method implementation should still be called - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') - - router = CustomRouter() - router.register('mock', MockViewSet) +assertstr(w[0].message)==msg +withwarnings.catch_warnings(record=True)asw: + warnings.simplefilter('always') + router = CustomRouter() + router.register('mock', MockViewSet) assert len(w) == 0 - assert router.registry == [ - ('mock', MockViewSet, 'foo'), - ] +assertrouter.registry==[('mock',MockViewSet,'foo'),] diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 230f8f012..8af8d85bc 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -149,206 +149,20 @@ urlpatterns = [ @unittest.skipUnless(coreapi, 'coreapi is not installed') @override_settings(ROOT_URLCONF='tests.test_schemas') -class TestRouterGeneratedSchema(TestCase): - def test_anonymous_request(self): - client = APIClient() - response = client.get('/', HTTP_ACCEPT='application/coreapi+json') - assert response.status_code == 200 - expected = coreapi.Document( - url='http://testserver/', - title='Example API', - content={ - 'example': { - 'list': coreapi.Link( - url='/example/', - action='get', - fields=[ - coreapi.Field('page', required=False, location='query', schema=coreschema.Integer(title='Page', description='A page number within the paginated result set.')), - coreapi.Field('page_size', required=False, location='query', schema=coreschema.Integer(title='Page size', description='Number of results to return per page.')), - coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) - ] - ), - 'custom_list_action': coreapi.Link( - url='/example/custom_list_action/', - action='get' - ), - 'custom_list_action_multiple_methods': { - 'read': coreapi.Link( - url='/example/custom_list_action_multiple_methods/', - action='get', - description='Custom description.', - ) - }, - 'documented_custom_action': { - 'read': coreapi.Link( - url='/example/documented_custom_action/', - action='get', - description='A description of the get method on the custom action.', - ) - }, - 'read': coreapi.Link( - url='/example/{id}/', - action='get', - fields=[ - coreapi.Field('id', required=True, location='path', schema=coreschema.String()), - coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) - ] - ) - } - } - ) - assert response.data == expected +def test_anonymous_request(self): + client = APIClient() + response = client.get('/', HTTP_ACCEPT='application/coreapi+json') + assert response.status_code == 200 + expected = coreapi.Document( url='http://testserver/', title='Example API', content={ 'example': { 'list': coreapi.Link( url='/example/', action='get', fields=[ coreapi.Field('page', required=False, location='query', schema=coreschema.Integer(title='Page', description='A page number within the paginated result set.')), coreapi.Field('page_size', required=False, location='query', schema=coreschema.Integer(title='Page size', description='Number of results to return per page.')), coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) ] ), 'custom_list_action': coreapi.Link( url='/example/custom_list_action/', action='get' ), 'custom_list_action_multiple_methods': { 'read': coreapi.Link( url='/example/custom_list_action_multiple_methods/', action='get', description='Custom description.', ) }, 'documented_custom_action': { 'read': coreapi.Link( url='/example/documented_custom_action/', action='get', description='A description of the get method on the custom action.', ) }, 'read': coreapi.Link( url='/example/{id}/', action='get', fields=[ coreapi.Field('id', required=True, location='path', schema=coreschema.String()), coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) ] ) } } ) + assert response.data == expected def test_authenticated_request(self): - client = APIClient() - client.force_authenticate(MockUser()) - response = client.get('/', HTTP_ACCEPT='application/coreapi+json') - assert response.status_code == 200 - expected = coreapi.Document( - url='http://testserver/', - title='Example API', - content={ - 'example': { - 'list': coreapi.Link( - url='/example/', - action='get', - fields=[ - coreapi.Field('page', required=False, location='query', schema=coreschema.Integer(title='Page', description='A page number within the paginated result set.')), - coreapi.Field('page_size', required=False, location='query', schema=coreschema.Integer(title='Page size', description='Number of results to return per page.')), - coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) - ] - ), - 'create': coreapi.Link( - url='/example/', - action='post', - encoding='application/json', - fields=[ - coreapi.Field('a', required=True, location='form', schema=coreschema.String(title='A', description='A field description')), - coreapi.Field('b', required=False, location='form', schema=coreschema.String(title='B')) - ] - ), - 'read': coreapi.Link( - url='/example/{id}/', - action='get', - fields=[ - coreapi.Field('id', required=True, location='path', schema=coreschema.String()), - coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) - ] - ), - 'custom_action': coreapi.Link( - url='/example/{id}/custom_action/', - action='post', - encoding='application/json', - description='A description of custom action.', - fields=[ - coreapi.Field('id', required=True, location='path', schema=coreschema.String()), - coreapi.Field('c', required=True, location='form', schema=coreschema.String(title='C')), - coreapi.Field('d', required=False, location='form', schema=coreschema.String(title='D')), - ] - ), - 'custom_action_with_dict_field': coreapi.Link( - url='/example/{id}/custom_action_with_dict_field/', - action='post', - encoding='application/json', - description='A custom action using a dict field in the serializer.', - fields=[ - coreapi.Field('id', required=True, location='path', schema=coreschema.String()), - coreapi.Field('a', required=True, location='form', schema=coreschema.Object(title='A')), - ] - ), - 'custom_action_with_list_fields': coreapi.Link( - url='/example/{id}/custom_action_with_list_fields/', - action='post', - encoding='application/json', - description='A custom action using both list field and list serializer in the serializer.', - fields=[ - coreapi.Field('id', required=True, location='path', schema=coreschema.String()), - coreapi.Field('a', required=True, location='form', schema=coreschema.Array(title='A', items=coreschema.Integer())), - coreapi.Field('b', required=True, location='form', schema=coreschema.Array(title='B', items=coreschema.String())), - ] - ), - 'custom_list_action': coreapi.Link( - url='/example/custom_list_action/', - action='get' - ), - 'custom_list_action_multiple_methods': { - 'read': coreapi.Link( - url='/example/custom_list_action_multiple_methods/', - action='get', - description='Custom description.', - ), - 'create': coreapi.Link( - url='/example/custom_list_action_multiple_methods/', - action='post', - description='Custom description.', - ), - 'delete': coreapi.Link( - url='/example/custom_list_action_multiple_methods/', - action='delete', - description='Deletion description.', - ), - }, - 'documented_custom_action': { - 'read': coreapi.Link( - url='/example/documented_custom_action/', - action='get', - description='A description of the get method on the custom action.', - ), - 'create': coreapi.Link( - url='/example/documented_custom_action/', - action='post', - description='A description of the post method on the custom action.', - encoding='application/json', - fields=[ - coreapi.Field('a', required=True, location='form', schema=coreschema.String(title='A', description='A field description')), - coreapi.Field('b', required=False, location='form', schema=coreschema.String(title='B')) - ] - ), - 'update': coreapi.Link( - url='/example/documented_custom_action/', - action='put', - description='A description of the put method on the custom action from mapping.', - encoding='application/json', - fields=[ - coreapi.Field('a', required=True, location='form', schema=coreschema.String(title='A', description='A field description')), - coreapi.Field('b', required=False, location='form', schema=coreschema.String(title='B')) - ] - ), - }, - 'update': coreapi.Link( - url='/example/{id}/', - action='put', - encoding='application/json', - fields=[ - coreapi.Field('id', required=True, location='path', schema=coreschema.String()), - coreapi.Field('a', required=True, location='form', schema=coreschema.String(title='A', description=('A field description'))), - coreapi.Field('b', required=False, location='form', schema=coreschema.String(title='B')), - coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) - ] - ), - 'partial_update': coreapi.Link( - url='/example/{id}/', - action='patch', - encoding='application/json', - fields=[ - coreapi.Field('id', required=True, location='path', schema=coreschema.String()), - coreapi.Field('a', required=False, location='form', schema=coreschema.String(title='A', description='A field description')), - coreapi.Field('b', required=False, location='form', schema=coreschema.String(title='B')), - coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) - ] - ), - 'delete': coreapi.Link( - url='/example/{id}/', - action='delete', - fields=[ - coreapi.Field('id', required=True, location='path', schema=coreschema.String()), - coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) - ] - ) - } - } - ) - assert response.data == expected + client = APIClient() + client.force_authenticate(MockUser()) + response = client.get('/', HTTP_ACCEPT='application/coreapi+json') + assert response.status_code == 200 + expected = coreapi.Document( url='http://testserver/', title='Example API', content={ 'example': { 'list': coreapi.Link( url='/example/', action='get', fields=[ coreapi.Field('page', required=False, location='query', schema=coreschema.Integer(title='Page', description='A page number within the paginated result set.')), coreapi.Field('page_size', required=False, location='query', schema=coreschema.Integer(title='Page size', description='Number of results to return per page.')), coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) ] ), 'create': coreapi.Link( url='/example/', action='post', encoding='application/json', fields=[ coreapi.Field('a', required=True, location='form', schema=coreschema.String(title='A', description='A field description')), coreapi.Field('b', required=False, location='form', schema=coreschema.String(title='B')) ] ), 'read': coreapi.Link( url='/example/{id}/', action='get', fields=[ coreapi.Field('id', required=True, location='path', schema=coreschema.String()), coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) ] ), 'custom_action': coreapi.Link( url='/example/{id}/custom_action/', action='post', encoding='application/json', description='A description of custom action.', fields=[ coreapi.Field('id', required=True, location='path', schema=coreschema.String()), coreapi.Field('c', required=True, location='form', schema=coreschema.String(title='C')), coreapi.Field('d', required=False, location='form', schema=coreschema.String(title='D')), ] ), 'custom_action_with_dict_field': coreapi.Link( url='/example/{id}/custom_action_with_dict_field/', action='post', encoding='application/json', description='A custom action using a dict field in the serializer.', fields=[ coreapi.Field('id', required=True, location='path', schema=coreschema.String()), coreapi.Field('a', required=True, location='form', schema=coreschema.Object(title='A')), ] ), 'custom_action_with_list_fields': coreapi.Link( url='/example/{id}/custom_action_with_list_fields/', action='post', encoding='application/json', description='A custom action using both list field and list serializer in the serializer.', fields=[ coreapi.Field('id', required=True, location='path', schema=coreschema.String()), coreapi.Field('a', required=True, location='form', schema=coreschema.Array(title='A', items=coreschema.Integer())), coreapi.Field('b', required=True, location='form', schema=coreschema.Array(title='B', items=coreschema.String())), ] ), 'custom_list_action': coreapi.Link( url='/example/custom_list_action/', action='get' ), 'custom_list_action_multiple_methods': { 'read': coreapi.Link( url='/example/custom_list_action_multiple_methods/', action='get', description='Custom description.', ), 'create': coreapi.Link( url='/example/custom_list_action_multiple_methods/', action='post', description='Custom description.', ), 'delete': coreapi.Link( url='/example/custom_list_action_multiple_methods/', action='delete', description='Deletion description.', ), }, 'documented_custom_action': { 'read': coreapi.Link( url='/example/documented_custom_action/', action='get', description='A description of the get method on the custom action.', ), 'create': coreapi.Link( url='/example/documented_custom_action/', action='post', description='A description of the post method on the custom action.', encoding='application/json', fields=[ coreapi.Field('a', required=True, location='form', schema=coreschema.String(title='A', description='A field description')), coreapi.Field('b', required=False, location='form', schema=coreschema.String(title='B')) ] ), 'update': coreapi.Link( url='/example/documented_custom_action/', action='put', description='A description of the put method on the custom action from mapping.', encoding='application/json', fields=[ coreapi.Field('a', required=True, location='form', schema=coreschema.String(title='A', description='A field description')), coreapi.Field('b', required=False, location='form', schema=coreschema.String(title='B')) ] ), }, 'update': coreapi.Link( url='/example/{id}/', action='put', encoding='application/json', fields=[ coreapi.Field('id', required=True, location='path', schema=coreschema.String()), coreapi.Field('a', required=True, location='form', schema=coreschema.String(title='A', description=('A field description'))), coreapi.Field('b', required=False, location='form', schema=coreschema.String(title='B')), coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) ] ), 'partial_update': coreapi.Link( url='/example/{id}/', action='patch', encoding='application/json', fields=[ coreapi.Field('id', required=True, location='path', schema=coreschema.String()), coreapi.Field('a', required=False, location='form', schema=coreschema.String(title='A', description='A field description')), coreapi.Field('b', required=False, location='form', schema=coreschema.String(title='B')), coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) ] ), 'delete': coreapi.Link( url='/example/{id}/', action='delete', fields=[ coreapi.Field('id', required=True, location='path', schema=coreschema.String()), coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) ] ) } } ) + assert response.data == expected class DenyAllUsingHttp404(permissions.BasePermission): @@ -400,260 +214,84 @@ class ExampleDetailView(APIView): @unittest.skipUnless(coreapi, 'coreapi is not installed') -class TestSchemaGenerator(TestCase): - def setUp(self): - self.patterns = [ - url(r'^example/?$', ExampleListView.as_view()), - url(r'^example/(?P\d+)/?$', ExampleDetailView.as_view()), - url(r'^example/(?P\d+)/sub/?$', ExampleDetailView.as_view()), - ] +def setUp(self): + self.patterns = [ url(r'^example/?$', ExampleListView.as_view()), url(r'^example/(?P\d+)/?$', ExampleDetailView.as_view()), url(r'^example/(?P\d+)/sub/?$', ExampleDetailView.as_view()), ] def test_schema_for_regular_views(self): - """ + """ Ensure that schema generation works for APIView classes. """ - generator = SchemaGenerator(title='Example API', patterns=self.patterns) - schema = generator.get_schema() - expected = coreapi.Document( - url='', - title='Example API', - content={ - 'example': { - 'create': coreapi.Link( - url='/example/', - action='post', - fields=[] - ), - 'list': coreapi.Link( - url='/example/', - action='get', - fields=[] - ), - 'read': coreapi.Link( - url='/example/{id}/', - action='get', - fields=[ - coreapi.Field('id', required=True, location='path', schema=coreschema.String()) - ] - ), - 'sub': { - 'list': coreapi.Link( - url='/example/{id}/sub/', - action='get', - fields=[ - coreapi.Field('id', required=True, location='path', schema=coreschema.String()) - ] - ) - } - } - } - ) - assert schema == expected + generator = SchemaGenerator(title='Example API', patterns=self.patterns) + schema = generator.get_schema() + expected = coreapi.Document( url='', title='Example API', content={ 'example': { 'create': coreapi.Link( url='/example/', action='post', fields=[] ), 'list': coreapi.Link( url='/example/', action='get', fields=[] ), 'read': coreapi.Link( url='/example/{id}/', action='get', fields=[ coreapi.Field('id', required=True, location='path', schema=coreschema.String()) ] ), 'sub': { 'list': coreapi.Link( url='/example/{id}/sub/', action='get', fields=[ coreapi.Field('id', required=True, location='path', schema=coreschema.String()) ] ) } } } ) + assert schema == expected @unittest.skipUnless(coreapi, 'coreapi is not installed') @unittest.skipUnless(path, 'needs Django 2') -class TestSchemaGeneratorDjango2(TestCase): - def setUp(self): - self.patterns = [ - path('example/', ExampleListView.as_view()), - path('example//', ExampleDetailView.as_view()), - path('example//sub/', ExampleDetailView.as_view()), - ] +def setUp(self): + self.patterns = [ path('example/', ExampleListView.as_view()), path('example//', ExampleDetailView.as_view()), path('example//sub/', ExampleDetailView.as_view()), ] def test_schema_for_regular_views(self): - """ + """ Ensure that schema generation works for APIView classes. """ - generator = SchemaGenerator(title='Example API', patterns=self.patterns) - schema = generator.get_schema() - expected = coreapi.Document( - url='', - title='Example API', - content={ - 'example': { - 'create': coreapi.Link( - url='/example/', - action='post', - fields=[] - ), - 'list': coreapi.Link( - url='/example/', - action='get', - fields=[] - ), - 'read': coreapi.Link( - url='/example/{id}/', - action='get', - fields=[ - coreapi.Field('id', required=True, location='path', schema=coreschema.String()) - ] - ), - 'sub': { - 'list': coreapi.Link( - url='/example/{id}/sub/', - action='get', - fields=[ - coreapi.Field('id', required=True, location='path', schema=coreschema.String()) - ] - ) - } - } - } - ) - assert schema == expected + generator = SchemaGenerator(title='Example API', patterns=self.patterns) + schema = generator.get_schema() + expected = coreapi.Document( url='', title='Example API', content={ 'example': { 'create': coreapi.Link( url='/example/', action='post', fields=[] ), 'list': coreapi.Link( url='/example/', action='get', fields=[] ), 'read': coreapi.Link( url='/example/{id}/', action='get', fields=[ coreapi.Field('id', required=True, location='path', schema=coreschema.String()) ] ), 'sub': { 'list': coreapi.Link( url='/example/{id}/sub/', action='get', fields=[ coreapi.Field('id', required=True, location='path', schema=coreschema.String()) ] ) } } } ) + assert schema == expected @unittest.skipUnless(coreapi, 'coreapi is not installed') -class TestSchemaGeneratorNotAtRoot(TestCase): - def setUp(self): - self.patterns = [ - url(r'^api/v1/example/?$', ExampleListView.as_view()), - url(r'^api/v1/example/(?P\d+)/?$', ExampleDetailView.as_view()), - url(r'^api/v1/example/(?P\d+)/sub/?$', ExampleDetailView.as_view()), - ] +def setUp(self): + self.patterns = [ url(r'^api/v1/example/?$', ExampleListView.as_view()), url(r'^api/v1/example/(?P\d+)/?$', ExampleDetailView.as_view()), url(r'^api/v1/example/(?P\d+)/sub/?$', ExampleDetailView.as_view()), ] def test_schema_for_regular_views(self): - """ + """ Ensure that schema generation with an API that is not at the URL root continues to use correct structure for link keys. """ - generator = SchemaGenerator(title='Example API', patterns=self.patterns) - schema = generator.get_schema() - expected = coreapi.Document( - url='', - title='Example API', - content={ - 'example': { - 'create': coreapi.Link( - url='/api/v1/example/', - action='post', - fields=[] - ), - 'list': coreapi.Link( - url='/api/v1/example/', - action='get', - fields=[] - ), - 'read': coreapi.Link( - url='/api/v1/example/{id}/', - action='get', - fields=[ - coreapi.Field('id', required=True, location='path', schema=coreschema.String()) - ] - ), - 'sub': { - 'list': coreapi.Link( - url='/api/v1/example/{id}/sub/', - action='get', - fields=[ - coreapi.Field('id', required=True, location='path', schema=coreschema.String()) - ] - ) - } - } - } - ) - assert schema == expected + generator = SchemaGenerator(title='Example API', patterns=self.patterns) + schema = generator.get_schema() + expected = coreapi.Document( url='', title='Example API', content={ 'example': { 'create': coreapi.Link( url='/api/v1/example/', action='post', fields=[] ), 'list': coreapi.Link( url='/api/v1/example/', action='get', fields=[] ), 'read': coreapi.Link( url='/api/v1/example/{id}/', action='get', fields=[ coreapi.Field('id', required=True, location='path', schema=coreschema.String()) ] ), 'sub': { 'list': coreapi.Link( url='/api/v1/example/{id}/sub/', action='get', fields=[ coreapi.Field('id', required=True, location='path', schema=coreschema.String()) ] ) } } } ) + assert schema == expected @unittest.skipUnless(coreapi, 'coreapi is not installed') -class TestSchemaGeneratorWithMethodLimitedViewSets(TestCase): - def setUp(self): - router = DefaultRouter() - router.register('example1', MethodLimitedViewSet, basename='example1') - self.patterns = [ - url(r'^', include(router.urls)) - ] +def setUp(self): + router = DefaultRouter() + router.register('example1', MethodLimitedViewSet, basename='example1') + self.patterns = [ url(r'^', include(router.urls)) ] def test_schema_for_regular_views(self): - """ + """ Ensure that schema generation works for ViewSet classes with method limitation by Django CBV's http_method_names attribute """ - generator = SchemaGenerator(title='Example API', patterns=self.patterns) - request = factory.get('/example1/') - schema = generator.get_schema(Request(request)) - - expected = coreapi.Document( - url='http://testserver/example1/', - title='Example API', - content={ - 'example1': { - 'list': coreapi.Link( - url='/example1/', - action='get', - fields=[ - coreapi.Field('page', required=False, location='query', schema=coreschema.Integer(title='Page', description='A page number within the paginated result set.')), - coreapi.Field('page_size', required=False, location='query', schema=coreschema.Integer(title='Page size', description='Number of results to return per page.')), - coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) - ] - ), - 'custom_list_action': coreapi.Link( - url='/example1/custom_list_action/', - action='get' - ), - 'custom_list_action_multiple_methods': { - 'read': coreapi.Link( - url='/example1/custom_list_action_multiple_methods/', - action='get', - description='Custom description.', - ) - }, - 'documented_custom_action': { - 'read': coreapi.Link( - url='/example1/documented_custom_action/', - action='get', - description='A description of the get method on the custom action.', - ), - }, - 'read': coreapi.Link( - url='/example1/{id}/', - action='get', - fields=[ - coreapi.Field('id', required=True, location='path', schema=coreschema.String()), - coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) - ] - ) - } - } - ) - assert schema == expected + generator = SchemaGenerator(title='Example API', patterns=self.patterns) + request = factory.get('/example1/') + schema = generator.get_schema(Request(request)) + expected = coreapi.Document( url='http://testserver/example1/', title='Example API', content={ 'example1': { 'list': coreapi.Link( url='/example1/', action='get', fields=[ coreapi.Field('page', required=False, location='query', schema=coreschema.Integer(title='Page', description='A page number within the paginated result set.')), coreapi.Field('page_size', required=False, location='query', schema=coreschema.Integer(title='Page size', description='Number of results to return per page.')), coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) ] ), 'custom_list_action': coreapi.Link( url='/example1/custom_list_action/', action='get' ), 'custom_list_action_multiple_methods': { 'read': coreapi.Link( url='/example1/custom_list_action_multiple_methods/', action='get', description='Custom description.', ) }, 'documented_custom_action': { 'read': coreapi.Link( url='/example1/documented_custom_action/', action='get', description='A description of the get method on the custom action.', ), }, 'read': coreapi.Link( url='/example1/{id}/', action='get', fields=[ coreapi.Field('id', required=True, location='path', schema=coreschema.String()), coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) ] ) } } ) + assert schema == expected @unittest.skipUnless(coreapi, 'coreapi is not installed') -class TestSchemaGeneratorWithRestrictedViewSets(TestCase): - def setUp(self): - router = DefaultRouter() - router.register('example1', Http404ExampleViewSet, basename='example1') - router.register('example2', PermissionDeniedExampleViewSet, basename='example2') - self.patterns = [ - url('^example/?$', ExampleListView.as_view()), - url(r'^', include(router.urls)) - ] +def setUp(self): + router = DefaultRouter() + router.register('example1', Http404ExampleViewSet, basename='example1') + router.register('example2', PermissionDeniedExampleViewSet, basename='example2') + self.patterns = [ url('^example/?$', ExampleListView.as_view()), url(r'^', include(router.urls)) ] def test_schema_for_regular_views(self): - """ + """ Ensure that schema generation works for ViewSet classes with permission classes raising exceptions. """ - generator = SchemaGenerator(title='Example API', patterns=self.patterns) - request = factory.get('/') - schema = generator.get_schema(Request(request)) - expected = coreapi.Document( - url='http://testserver/', - title='Example API', - content={ - 'example': { - 'list': coreapi.Link( - url='/example/', - action='get', - fields=[] - ), - }, - } - ) - assert schema == expected + generator = SchemaGenerator(title='Example API', patterns=self.patterns) + request = factory.get('/') + schema = generator.get_schema(Request(request)) + expected = coreapi.Document( url='http://testserver/', title='Example API', content={ 'example': { 'list': coreapi.Link( url='/example/', action='get', fields=[] ), }, } ) + assert schema == expected class ForeignKeySourceSerializer(serializers.ModelSerializer): @@ -668,37 +306,17 @@ class ForeignKeySourceView(generics.CreateAPIView): @unittest.skipUnless(coreapi, 'coreapi is not installed') -class TestSchemaGeneratorWithForeignKey(TestCase): - def setUp(self): - self.patterns = [ - url(r'^example/?$', ForeignKeySourceView.as_view()), - ] +def setUp(self): + self.patterns = [ url(r'^example/?$', ForeignKeySourceView.as_view()), ] def test_schema_for_regular_views(self): - """ + """ Ensure that AutoField foreign keys are output as Integer. """ - generator = SchemaGenerator(title='Example API', patterns=self.patterns) - schema = generator.get_schema() - - expected = coreapi.Document( - url='', - title='Example API', - content={ - 'example': { - 'create': coreapi.Link( - url='/example/', - action='post', - encoding='application/json', - fields=[ - coreapi.Field('name', required=True, location='form', schema=coreschema.String(title='Name')), - coreapi.Field('target', required=True, location='form', schema=coreschema.Integer(description='Target', title='Target')), - ] - ) - } - } - ) - assert schema == expected + generator = SchemaGenerator(title='Example API', patterns=self.patterns) + schema = generator.get_schema() + expected = coreapi.Document( url='', title='Example API', content={ 'example': { 'create': coreapi.Link( url='/example/', action='post', encoding='application/json', fields=[ coreapi.Field('name', required=True, location='form', schema=coreschema.String(title='Name')), coreapi.Field('target', required=True, location='form', schema=coreschema.Integer(description='Target', title='Target')), ] ) } } ) + assert schema == expected class ManyToManySourceSerializer(serializers.ModelSerializer): @@ -713,48 +331,24 @@ class ManyToManySourceView(generics.CreateAPIView): @unittest.skipUnless(coreapi, 'coreapi is not installed') -class TestSchemaGeneratorWithManyToMany(TestCase): - def setUp(self): - self.patterns = [ - url(r'^example/?$', ManyToManySourceView.as_view()), - ] +def setUp(self): + self.patterns = [ url(r'^example/?$', ManyToManySourceView.as_view()), ] def test_schema_for_regular_views(self): - """ + """ Ensure that AutoField many to many fields are output as Integer. """ - generator = SchemaGenerator(title='Example API', patterns=self.patterns) - schema = generator.get_schema() - - expected = coreapi.Document( - url='', - title='Example API', - content={ - 'example': { - 'create': coreapi.Link( - url='/example/', - action='post', - encoding='application/json', - fields=[ - coreapi.Field('name', required=True, location='form', schema=coreschema.String(title='Name')), - coreapi.Field('targets', required=True, location='form', schema=coreschema.Array(title='Targets', items=coreschema.Integer())), - ] - ) - } - } - ) - assert schema == expected + generator = SchemaGenerator(title='Example API', patterns=self.patterns) + schema = generator.get_schema() + expected = coreapi.Document( url='', title='Example API', content={ 'example': { 'create': coreapi.Link( url='/example/', action='post', encoding='application/json', fields=[ coreapi.Field('name', required=True, location='form', schema=coreschema.String(title='Name')), coreapi.Field('targets', required=True, location='form', schema=coreschema.Array(title='Targets', items=coreschema.Integer())), ] ) } } ) + assert schema == expected @unittest.skipUnless(coreapi, 'coreapi is not installed') -class Test4605Regression(TestCase): - def test_4605_regression(self): - generator = SchemaGenerator() - prefix = generator.determine_path_prefix([ - '/api/v1/items/', - '/auth/convert-token/' - ]) - assert prefix == '/' +def test_4605_regression(self): + generator = SchemaGenerator() + prefix = generator.determine_path_prefix([ '/api/v1/items/', '/auth/convert-token/' ]) + assert prefix == '/' class CustomViewInspector(AutoSchema): @@ -762,213 +356,113 @@ class CustomViewInspector(AutoSchema): pass -class TestAutoSchema(TestCase): - def test_apiview_schema_descriptor(self): - view = APIView() - assert hasattr(view, 'schema') - assert isinstance(view.schema, AutoSchema) +def test_apiview_schema_descriptor(self): + view = APIView() + assert hasattr(view, 'schema') + assert isinstance(view.schema, AutoSchema) def test_set_custom_inspector_class_on_view(self): - class CustomView(APIView): - schema = CustomViewInspector() + class CustomView(APIView): + schema = CustomViewInspector() view = CustomView() - assert isinstance(view.schema, CustomViewInspector) +assertisinstance(view.schema,CustomViewInspector) def test_set_custom_inspector_class_via_settings(self): - with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'tests.test_schemas.CustomViewInspector'}): - view = APIView() - assert isinstance(view.schema, CustomViewInspector) + with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'tests.test_schemas.CustomViewInspector'}): + view = APIView() + assert isinstance(view.schema, CustomViewInspector) def test_get_link_requires_instance(self): - descriptor = APIView.schema # Accessed from class - with pytest.raises(AssertionError): - descriptor.get_link(None, None, None) # ???: Do the dummy arguments require a tighter assert? + descriptor = APIView.schema # Accessed from class + with pytest.raises(AssertionError): + descriptor.get_link(None, None, None) # ???: Do the dummy arguments require a tighter assert? @pytest.mark.skipif(not coreapi, reason='coreapi is not installed') - def test_update_fields(self): - """ +deftest_update_fields(self): + """ That updating fields by-name helper is correct Recall: `update_fields(fields, update_with)` """ - schema = AutoSchema() - fields = [] - - # Adds a field... - fields = schema.update_fields(fields, [ - coreapi.Field( - "my_field", - required=True, - location="path", - schema=coreschema.String() - ), - ]) - - assert len(fields) == 1 - assert fields[0].name == "my_field" - - # Replaces a field... - fields = schema.update_fields(fields, [ - coreapi.Field( - "my_field", - required=False, - location="path", - schema=coreschema.String() - ), - ]) - - assert len(fields) == 1 - assert fields[0].required is False + schema = AutoSchema() + fields = [] + fields = schema.update_fields(fields, [ coreapi.Field( "my_field", required=True, location="path", schema=coreschema.String() ), ]) + assert len(fields) == 1 + assert fields[0].name == "my_field" + fields = schema.update_fields(fields, [ coreapi.Field( "my_field", required=False, location="path", schema=coreschema.String() ), ]) + assert len(fields) == 1 + assert fields[0].required is False @pytest.mark.skipif(not coreapi, reason='coreapi is not installed') - def test_get_manual_fields(self): - """That get_manual_fields is applied during get_link""" - - class CustomView(APIView): - schema = AutoSchema(manual_fields=[ - coreapi.Field( - "my_extra_field", - required=True, - location="path", - schema=coreschema.String() - ), - ]) +deftest_get_manual_fields(self): + """That get_manual_fields is applied during get_link""" + class CustomView(APIView): + schema = AutoSchema(manual_fields=[ coreapi.Field( "my_extra_field", required=True, location="path", schema=coreschema.String() ), ]) view = CustomView() - link = view.schema.get_link('/a/url/{id}/', 'GET', '') - fields = link.fields - - assert len(fields) == 2 - assert "my_extra_field" in [f.name for f in fields] +link=view.schema.get_link('/a/url/{id}/','GET','') +fields=link.fields +assertlen(fields)==2 +assert"my_extra_field"in[f.nameforfinfields] @pytest.mark.skipif(not coreapi, reason='coreapi is not installed') - def test_viewset_action_with_schema(self): - class CustomViewSet(GenericViewSet): - @action(detail=True, schema=AutoSchema(manual_fields=[ - coreapi.Field( - "my_extra_field", - required=True, - location="path", - schema=coreschema.String() - ), - ])) - def extra_action(self, pk, **kwargs): - pass +deftest_viewset_action_with_schema(self): + class CustomViewSet(GenericViewSet): + @action(detail=True, schema=AutoSchema(manual_fields=[ coreapi.Field( "my_extra_field", required=True, location="path", schema=coreschema.String() ), ])) + def extra_action(self, pk, **kwargs): + pass router = SimpleRouter() - router.register(r'detail', CustomViewSet, basename='detail') - - generator = SchemaGenerator() - view = generator.create_view(router.urls[0].callback, 'GET') - link = view.schema.get_link('/a/url/{id}/', 'GET', '') - fields = link.fields - - assert len(fields) == 2 - assert "my_extra_field" in [f.name for f in fields] +router.register(r'detail',CustomViewSet,basename='detail') +generator=SchemaGenerator() +view=generator.create_view(router.urls[0].callback,'GET') +link=view.schema.get_link('/a/url/{id}/','GET','') +fields=link.fields +assertlen(fields)==2 +assert"my_extra_field"in[f.nameforfinfields] @pytest.mark.skipif(not coreapi, reason='coreapi is not installed') - def test_viewset_action_with_null_schema(self): - class CustomViewSet(GenericViewSet): - @action(detail=True, schema=None) - def extra_action(self, pk, **kwargs): - pass +deftest_viewset_action_with_null_schema(self): + class CustomViewSet(GenericViewSet): + @action(detail=True, schema=None) + def extra_action(self, pk, **kwargs): + pass router = SimpleRouter() - router.register(r'detail', CustomViewSet, basename='detail') - - generator = SchemaGenerator() - view = generator.create_view(router.urls[0].callback, 'GET') - assert view.schema is None +router.register(r'detail',CustomViewSet,basename='detail') +generator=SchemaGenerator() +view=generator.create_view(router.urls[0].callback,'GET') +assertview.schemaisNone @pytest.mark.skipif(not coreapi, reason='coreapi is not installed') - def test_view_with_manual_schema(self): +deftest_view_with_manual_schema(self): - path = '/example' - method = 'get' - base_url = None - - fields = [ - coreapi.Field( - "first_field", - required=True, - location="path", - schema=coreschema.String() - ), - coreapi.Field( - "second_field", - required=True, - location="path", - schema=coreschema.String() - ), - coreapi.Field( - "third_field", - required=True, - location="path", - schema=coreschema.String() - ), - ] - description = "A test endpoint" - - class CustomView(APIView): - """ + path = '/example' + method = 'get' + base_url = None + fields = [ coreapi.Field( "first_field", required=True, location="path", schema=coreschema.String() ), coreapi.Field( "second_field", required=True, location="path", schema=coreschema.String() ), coreapi.Field( "third_field", required=True, location="path", schema=coreschema.String() ), ] + description = "A test endpoint" + class CustomView(APIView): + """ ManualSchema takes list of fields for endpoint. - Provides url and action, which are always dynamic """ - schema = ManualSchema(fields, description) + schema = ManualSchema(fields, description) - expected = coreapi.Link( - url=path, - action=method, - fields=fields, - description=description - ) - - view = CustomView() - link = view.schema.get_link(path, method, base_url) - assert link == expected + expected = coreapi.Link(url=path,action=method,fields=fields,description=description) +view=CustomView() +link=view.schema.get_link(path,method,base_url) +assertlink==expected @unittest.skipUnless(coreschema, 'coreschema is not installed') - def test_field_to_schema(self): - label = 'Test label' - help_text = 'This is a helpful test text' - - cases = [ - # tuples are ([field], [expected schema]) - # TODO: Add remaining cases - ( - serializers.BooleanField(label=label, help_text=help_text), - coreschema.Boolean(title=label, description=help_text) - ), - ( - serializers.DecimalField(1000, 1000, label=label, help_text=help_text), - coreschema.Number(title=label, description=help_text) - ), - ( - serializers.FloatField(label=label, help_text=help_text), - coreschema.Number(title=label, description=help_text) - ), - ( - serializers.IntegerField(label=label, help_text=help_text), - coreschema.Integer(title=label, description=help_text) - ), - ( - serializers.DateField(label=label, help_text=help_text), - coreschema.String(title=label, description=help_text, format='date') - ), - ( - serializers.DateTimeField(label=label, help_text=help_text), - coreschema.String(title=label, description=help_text, format='date-time') - ), - ( - serializers.JSONField(label=label, help_text=help_text), - coreschema.Object(title=label, description=help_text) - ), - ] - - for case in cases: - self.assertEqual(field_to_schema(case[0]), case[1]) +deftest_field_to_schema(self): + label = 'Test label' + help_text = 'This is a helpful test text' + cases = [ ( serializers.BooleanField(label=label, help_text=help_text), coreschema.Boolean(title=label, description=help_text) ), ( serializers.DecimalField(1000, 1000, label=label, help_text=help_text), coreschema.Number(title=label, description=help_text) ), ( serializers.FloatField(label=label, help_text=help_text), coreschema.Number(title=label, description=help_text) ), ( serializers.IntegerField(label=label, help_text=help_text), coreschema.Integer(title=label, description=help_text) ), ( serializers.DateField(label=label, help_text=help_text), coreschema.String(title=label, description=help_text, format='date') ), ( serializers.DateTimeField(label=label, help_text=help_text), coreschema.String(title=label, description=help_text, format='date-time') ), ( serializers.JSONField(label=label, help_text=help_text), coreschema.Object(title=label, description=help_text) ), ] + for case in cases: + assert field_to_schema(case[0]) == case[1] def test_docstring_is_not_stripped_by_get_description(): @@ -1026,56 +520,33 @@ def included_fbv(request): @unittest.skipUnless(coreapi, 'coreapi is not installed') -class SchemaGenerationExclusionTests(TestCase): - def setUp(self): - self.patterns = [ - url('^excluded-cbv/$', ExcludedAPIView.as_view()), - url('^excluded-fbv/$', excluded_fbv), - url('^included-fbv/$', included_fbv), - ] +def setUp(self): + self.patterns = [ url('^excluded-cbv/$', ExcludedAPIView.as_view()), url('^excluded-fbv/$', excluded_fbv), url('^included-fbv/$', included_fbv), ] def test_schema_generator_excludes_correctly(self): - """Schema should not include excluded views""" - generator = SchemaGenerator(title='Exclusions', patterns=self.patterns) - schema = generator.get_schema() - expected = coreapi.Document( - url='', - title='Exclusions', - content={ - 'included-fbv': { - 'list': coreapi.Link(url='/included-fbv/', action='get') - } - } - ) - - assert len(schema.data) == 1 - assert 'included-fbv' in schema.data - assert schema == expected + """Schema should not include excluded views""" + generator = SchemaGenerator(title='Exclusions', patterns=self.patterns) + schema = generator.get_schema() + expected = coreapi.Document( url='', title='Exclusions', content={ 'included-fbv': { 'list': coreapi.Link(url='/included-fbv/', action='get') } } ) + assert len(schema.data) == 1 + assert 'included-fbv' in schema.data + assert schema == expected def test_endpoint_enumerator_excludes_correctly(self): - """It is responsibility of EndpointEnumerator to exclude views""" - inspector = EndpointEnumerator(self.patterns) - endpoints = inspector.get_api_endpoints() - - assert len(endpoints) == 1 - path, method, callback = endpoints[0] - assert path == '/included-fbv/' + """It is responsibility of EndpointEnumerator to exclude views""" + inspector = EndpointEnumerator(self.patterns) + endpoints = inspector.get_api_endpoints() + assert len(endpoints) == 1 + path, method, callback = endpoints[0] + assert path == '/included-fbv/' def test_should_include_endpoint_excludes_correctly(self): - """This is the specific method that should handle the exclusion""" - inspector = EndpointEnumerator(self.patterns) - - # Not pretty. Mimics internals of EndpointEnumerator to put should_include_endpoint under test - pairs = [(inspector.get_path_from_regex(get_regex_pattern(pattern)), pattern.callback) - for pattern in self.patterns] - - should_include = [ - inspector.should_include_endpoint(*pair) for pair in pairs - ] - - expected = [False, False, True] - - assert should_include == expected + """This is the specific method that should handle the exclusion""" + inspector = EndpointEnumerator(self.patterns) + pairs = [(inspector.get_path_from_regex(get_regex_pattern(pattern)), pattern.callback) for pattern in self.patterns] + should_include = [ inspector.should_include_endpoint(*pair) for pair in pairs ] + expected = [False, False, True] + assert should_include == expected @api_view(["GET"]) @@ -1118,126 +589,63 @@ naming_collisions_router.register(r'collision', NamingCollisionViewSet, basename @pytest.mark.skipif(not coreapi, reason='coreapi is not installed') -class TestURLNamingCollisions(TestCase): - """ +""" Ref: https://github.com/encode/django-rest-framework/issues/4704 """ - def test_manually_routing_nested_routes(self): - patterns = [ - url(r'^test', simple_fbv), - url(r'^test/list/', simple_fbv), - ] - - generator = SchemaGenerator(title='Naming Colisions', patterns=patterns) - schema = generator.get_schema() - - expected = coreapi.Document( - url='', - title='Naming Colisions', - content={ - 'test': { - 'list': { - 'list': coreapi.Link(url='/test/list/', action='get') - }, - 'list_0': coreapi.Link(url='/test', action='get') - } - } - ) - - assert expected == schema +deftest_manually_routing_nested_routes(self): + patterns = [ url(r'^test', simple_fbv), url(r'^test/list/', simple_fbv), ] + generator = SchemaGenerator(title='Naming Colisions', patterns=patterns) + schema = generator.get_schema() + expected = coreapi.Document( url='', title='Naming Colisions', content={ 'test': { 'list': { 'list': coreapi.Link(url='/test/list/', action='get') }, 'list_0': coreapi.Link(url='/test', action='get') } } ) + assert expected == schema def _verify_cbv_links(self, loc, url, methods=None, suffixes=None): - if methods is None: - methods = ('read', 'update', 'partial_update', 'delete') + if methods is None: + methods = ('read', 'update', 'partial_update', 'delete') if suffixes is None: - suffixes = (None for m in methods) + suffixes = (None for m in methods) for method, suffix in zip(methods, suffixes): - if suffix is not None: - key = '{}_{}'.format(method, suffix) + if suffix is not None: + key = '{}_{}'.format(method, suffix) else: - key = method + key = method assert loc[key].url == url def test_manually_routing_generic_view(self): - patterns = [ - url(r'^test', NamingCollisionView.as_view()), - url(r'^test/retrieve/', NamingCollisionView.as_view()), - url(r'^test/update/', NamingCollisionView.as_view()), - - # Fails with method names: - url(r'^test/get/', NamingCollisionView.as_view()), - url(r'^test/put/', NamingCollisionView.as_view()), - url(r'^test/delete/', NamingCollisionView.as_view()), - ] - - generator = SchemaGenerator(title='Naming Colisions', patterns=patterns) - - schema = generator.get_schema() - - self._verify_cbv_links(schema['test']['delete'], '/test/delete/') - self._verify_cbv_links(schema['test']['put'], '/test/put/') - self._verify_cbv_links(schema['test']['get'], '/test/get/') - self._verify_cbv_links(schema['test']['update'], '/test/update/') - self._verify_cbv_links(schema['test']['retrieve'], '/test/retrieve/') - self._verify_cbv_links(schema['test'], '/test', suffixes=(None, '0', None, '0')) + patterns = [ url(r'^test', NamingCollisionView.as_view()), url(r'^test/retrieve/', NamingCollisionView.as_view()), url(r'^test/update/', NamingCollisionView.as_view()), url(r'^test/get/', NamingCollisionView.as_view()), url(r'^test/put/', NamingCollisionView.as_view()), url(r'^test/delete/', NamingCollisionView.as_view()), ] + generator = SchemaGenerator(title='Naming Colisions', patterns=patterns) + schema = generator.get_schema() + self._verify_cbv_links(schema['test']['delete'], '/test/delete/') + self._verify_cbv_links(schema['test']['put'], '/test/put/') + self._verify_cbv_links(schema['test']['get'], '/test/get/') + self._verify_cbv_links(schema['test']['update'], '/test/update/') + self._verify_cbv_links(schema['test']['retrieve'], '/test/retrieve/') + self._verify_cbv_links(schema['test'], '/test', suffixes=(None, '0', None, '0')) def test_from_router(self): - patterns = [ - url(r'from-router', include(naming_collisions_router.urls)), - ] - - generator = SchemaGenerator(title='Naming Colisions', patterns=patterns) - schema = generator.get_schema() - - # not important here - desc_0 = schema['detail']['detail_export'].description - desc_1 = schema['detail_0'].description - - expected = coreapi.Document( - url='', - title='Naming Colisions', - content={ - 'detail': { - 'detail_export': coreapi.Link( - url='/from-routercollision/detail/export/', - action='get', - description=desc_0) - }, - 'detail_0': coreapi.Link( - url='/from-routercollision/detail/', - action='get', - description=desc_1 - ) - } - ) - - assert schema == expected + patterns = [ url(r'from-router', include(naming_collisions_router.urls)), ] + generator = SchemaGenerator(title='Naming Colisions', patterns=patterns) + schema = generator.get_schema() + desc_0 = schema['detail']['detail_export'].description + desc_1 = schema['detail_0'].description + expected = coreapi.Document( url='', title='Naming Colisions', content={ 'detail': { 'detail_export': coreapi.Link( url='/from-routercollision/detail/export/', action='get', description=desc_0) }, 'detail_0': coreapi.Link( url='/from-routercollision/detail/', action='get', description=desc_1 ) } ) + assert schema == expected def test_url_under_same_key_not_replaced(self): - patterns = [ - url(r'example/(?P\d+)/$', BasicNamingCollisionView.as_view()), - url(r'example/(?P\w+)/$', BasicNamingCollisionView.as_view()), - ] - - generator = SchemaGenerator(title='Naming Colisions', patterns=patterns) - schema = generator.get_schema() - - assert schema['example']['read'].url == '/example/{id}/' - assert schema['example']['read_0'].url == '/example/{slug}/' + patterns = [ url(r'example/(?P\d+)/$', BasicNamingCollisionView.as_view()), url(r'example/(?P\w+)/$', BasicNamingCollisionView.as_view()), ] + generator = SchemaGenerator(title='Naming Colisions', patterns=patterns) + schema = generator.get_schema() + assert schema['example']['read'].url == '/example/{id}/' + assert schema['example']['read_0'].url == '/example/{slug}/' def test_url_under_same_key_not_replaced_another(self): - patterns = [ - url(r'^test/list/', simple_fbv), - url(r'^test/(?P\d+)/list/', simple_fbv), - ] - - generator = SchemaGenerator(title='Naming Colisions', patterns=patterns) - schema = generator.get_schema() - - assert schema['test']['list']['list'].url == '/test/list/' - assert schema['test']['list']['list_0'].url == '/test/{id}/list/' + patterns = [ url(r'^test/list/', simple_fbv), url(r'^test/(?P\d+)/list/', simple_fbv), ] + generator = SchemaGenerator(title='Naming Colisions', patterns=patterns) + schema = generator.get_schema() + assert schema['test']['list']['list'].url == '/test/list/' + assert schema['test']['list']['list_0'].url == '/test/{id}/list/' def test_is_list_view_recognises_retrieve_view_subclasses(): diff --git a/tests/test_serializer_bulk_update.py b/tests/test_serializer_bulk_update.py index 0465578bb..689aa05bd 100644 --- a/tests/test_serializer_bulk_update.py +++ b/tests/test_serializer_bulk_update.py @@ -4,120 +4,66 @@ Tests to cover bulk create and update using serializers. from django.test import TestCase from rest_framework import serializers - - -class BulkCreateSerializerTests(TestCase): - """ +""" Creating multiple instances using serializers. """ - - def setUp(self): - class BookSerializer(serializers.Serializer): - id = serializers.IntegerField() - title = serializers.CharField(max_length=100) - author = serializers.CharField(max_length=100) +defsetUp(self): + class BookSerializer(serializers.Serializer): + id = serializers.IntegerField() + title = serializers.CharField(max_length=100) + author = serializers.CharField(max_length=100) self.BookSerializer = BookSerializer def test_bulk_create_success(self): - """ + """ Correct bulk update serialization should return the input data. """ - - data = [ - { - 'id': 0, - 'title': 'The electric kool-aid acid test', - 'author': 'Tom Wolfe' - }, { - 'id': 1, - 'title': 'If this is a man', - 'author': 'Primo Levi' - }, { - 'id': 2, - 'title': 'The wind-up bird chronicle', - 'author': 'Haruki Murakami' - } - ] - - serializer = self.BookSerializer(data=data, many=True) - assert serializer.is_valid() is True - assert serializer.validated_data == data - assert serializer.errors == [] + data = [ { 'id': 0, 'title': 'The electric kool-aid acid test', 'author': 'Tom Wolfe' }, { 'id': 1, 'title': 'If this is a man', 'author': 'Primo Levi' }, { 'id': 2, 'title': 'The wind-up bird chronicle', 'author': 'Haruki Murakami' } ] + serializer = self.BookSerializer(data=data, many=True) + assert serializer.is_valid() is True + assert serializer.validated_data == data + assert serializer.errors == [] def test_bulk_create_errors(self): - """ + """ Incorrect bulk create serialization should return errors. """ - - data = [ - { - 'id': 0, - 'title': 'The electric kool-aid acid test', - 'author': 'Tom Wolfe' - }, { - 'id': 1, - 'title': 'If this is a man', - 'author': 'Primo Levi' - }, { - 'id': 'foo', - 'title': 'The wind-up bird chronicle', - 'author': 'Haruki Murakami' - } - ] - expected_errors = [ - {}, - {}, - {'id': ['A valid integer is required.']} - ] - - serializer = self.BookSerializer(data=data, many=True) - assert serializer.is_valid() is False - assert serializer.errors == expected_errors - assert serializer.validated_data == [] + data = [ { 'id': 0, 'title': 'The electric kool-aid acid test', 'author': 'Tom Wolfe' }, { 'id': 1, 'title': 'If this is a man', 'author': 'Primo Levi' }, { 'id': 'foo', 'title': 'The wind-up bird chronicle', 'author': 'Haruki Murakami' } ] + expected_errors = [ {}, {}, {'id': ['A valid integer is required.']} ] + serializer = self.BookSerializer(data=data, many=True) + assert serializer.is_valid() is False + assert serializer.errors == expected_errors + assert serializer.validated_data == [] def test_invalid_list_datatype(self): - """ + """ Data containing list of incorrect data type should return errors. """ - data = ['foo', 'bar', 'baz'] - serializer = self.BookSerializer(data=data, many=True) - assert serializer.is_valid() is False - - message = 'Invalid data. Expected a dictionary, but got str.' - expected_errors = [ - {'non_field_errors': [message]}, - {'non_field_errors': [message]}, - {'non_field_errors': [message]} - ] - - assert serializer.errors == expected_errors + data = ['foo', 'bar', 'baz'] + serializer = self.BookSerializer(data=data, many=True) + assert serializer.is_valid() is False + message = 'Invalid data. Expected a dictionary, but got str.' + expected_errors = [ {'non_field_errors': [message]}, {'non_field_errors': [message]}, {'non_field_errors': [message]} ] + assert serializer.errors == expected_errors def test_invalid_single_datatype(self): - """ + """ Data containing a single incorrect data type should return errors. """ - data = 123 - serializer = self.BookSerializer(data=data, many=True) - assert serializer.is_valid() is False - - expected_errors = {'non_field_errors': ['Expected a list of items but got type "int".']} - - assert serializer.errors == expected_errors + data = 123 + serializer = self.BookSerializer(data=data, many=True) + assert serializer.is_valid() is False + expected_errors = {'non_field_errors': ['Expected a list of items but got type "int".']} + assert serializer.errors == expected_errors def test_invalid_single_object(self): - """ + """ Data containing only a single object, instead of a list of objects should return errors. """ - data = { - 'id': 0, - 'title': 'The electric kool-aid acid test', - 'author': 'Tom Wolfe' - } - serializer = self.BookSerializer(data=data, many=True) - assert serializer.is_valid() is False - - expected_errors = {'non_field_errors': ['Expected a list of items but got type "dict".']} - - assert serializer.errors == expected_errors + data = { 'id': 0, 'title': 'The electric kool-aid acid test', 'author': 'Tom Wolfe' } + serializer = self.BookSerializer(data=data, many=True) + assert serializer.is_valid() is False + expected_errors = {'non_field_errors': ['Expected a list of items but got type "dict".']} + assert serializer.errors == expected_errors