diff --git a/.gitignore b/.gitignore
index 52d63d225..82e885ede 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,7 +2,7 @@
*.db
*~
.*
-.py.bak
+*.py.bak
/site/
diff --git a/tests/test_atomic_requests.py b/tests/test_atomic_requests.py
index de04d2c06..4c88b3586 100644
--- a/tests/test_atomic_requests.py
+++ b/tests/test_atomic_requests.py
@@ -11,6 +11,7 @@ from rest_framework.response import Response
from rest_framework.test import APIRequestFactory
from rest_framework.views import APIView
from tests.models import BasicModel
+import pytest
factory = APIRequestFactory()
@@ -52,51 +53,49 @@ urlpatterns = (
connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints."
)
-class DBTransactionTests(TestCase):
- def setUp(self):
- self.view = BasicView.as_view()
- connections.databases['default']['ATOMIC_REQUESTS'] = True
+def setUp(self):
+ self.view = BasicView.as_view()
+ connections.databases['default']['ATOMIC_REQUESTS'] = True
def tearDown(self):
- connections.databases['default']['ATOMIC_REQUESTS'] = False
+ connections.databases['default']['ATOMIC_REQUESTS'] = False
def test_no_exception_commit_transaction(self):
- request = factory.post('/')
-
- with self.assertNumQueries(1):
- response = self.view(request)
+ request = factory.post('/')
+ with self.assertNumQueries(1):
+ response = self.view(request)
assert not transaction.get_rollback()
- assert response.status_code == status.HTTP_200_OK
- assert BasicModel.objects.count() == 1
+assertresponse.status_code==status.HTTP_200_OK
+assertBasicModel.objects.count()==1
@unittest.skipUnless(
connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints."
)
-class DBTransactionErrorTests(TestCase):
- def setUp(self):
- self.view = ErrorView.as_view()
- connections.databases['default']['ATOMIC_REQUESTS'] = True
+def setUp(self):
+ self.view = ErrorView.as_view()
+ connections.databases['default']['ATOMIC_REQUESTS'] = True
def tearDown(self):
- connections.databases['default']['ATOMIC_REQUESTS'] = False
+ connections.databases['default']['ATOMIC_REQUESTS'] = False
def test_generic_exception_delegate_transaction_management(self):
- """
+ """
Transaction is eventually managed by outer-most transaction atomic
block. DRF do not try to interfere here.
We let django deal with the transaction when it will catch the Exception.
"""
- request = factory.post('/')
- with self.assertNumQueries(3):
+ request = factory.post('/')
+ with self.assertNumQueries(3):
# 1 - begin savepoint
# 2 - insert
# 3 - release savepoint
- with transaction.atomic():
- self.assertRaises(Exception, self.view, request)
- assert not transaction.get_rollback()
+ with transaction.atomic():
+ with pytest.raises(Exception):
+ self.view(request)
+ assert not transaction.get_rollback()
assert BasicModel.objects.count() == 1
@@ -104,30 +103,29 @@ class DBTransactionErrorTests(TestCase):
connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints."
)
-class DBTransactionAPIExceptionTests(TestCase):
- def setUp(self):
- self.view = APIExceptionView.as_view()
- connections.databases['default']['ATOMIC_REQUESTS'] = True
+def setUp(self):
+ self.view = APIExceptionView.as_view()
+ connections.databases['default']['ATOMIC_REQUESTS'] = True
def tearDown(self):
- connections.databases['default']['ATOMIC_REQUESTS'] = False
+ connections.databases['default']['ATOMIC_REQUESTS'] = False
def test_api_exception_rollback_transaction(self):
- """
+ """
Transaction is rollbacked by our transaction atomic block.
"""
- request = factory.post('/')
- num_queries = 4 if connection.features.can_release_savepoints else 3
- with self.assertNumQueries(num_queries):
+ request = factory.post('/')
+ num_queries = 4 if connection.features.can_release_savepoints else 3
+ with self.assertNumQueries(num_queries):
# 1 - begin savepoint
# 2 - insert
# 3 - rollback savepoint
# 4 - release savepoint
- with transaction.atomic():
- response = self.view(request)
- assert transaction.get_rollback()
+ with transaction.atomic():
+ response = self.view(request)
+ assert transaction.get_rollback()
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
- assert BasicModel.objects.count() == 0
+assertBasicModel.objects.count()==0
@unittest.skipUnless(
diff --git a/tests/test_authtoken.py b/tests/test_authtoken.py
index 036e317ef..f22551b26 100644
--- a/tests/test_authtoken.py
+++ b/tests/test_authtoken.py
@@ -13,74 +13,68 @@ from rest_framework.authtoken.models import Token
from rest_framework.authtoken.serializers import AuthTokenSerializer
from rest_framework.exceptions import ValidationError
-
-class AuthTokenTests(TestCase):
-
- def setUp(self):
- self.site = site
- self.user = User.objects.create_user(username='test_user')
- self.token = Token.objects.create(key='test token', user=self.user)
+def setUp(self):
+ self.site = site
+ self.user = User.objects.create_user(username='test_user')
+ self.token = Token.objects.create(key='test token', user=self.user)
def test_model_admin_displayed_fields(self):
- mock_request = object()
- token_admin = TokenAdmin(self.token, self.site)
- assert token_admin.get_fields(mock_request) == ('user',)
+ mock_request = object()
+ token_admin = TokenAdmin(self.token, self.site)
+ assert token_admin.get_fields(mock_request) == ('user',)
def test_token_string_representation(self):
- assert str(self.token) == 'test token'
+ assert str(self.token) == 'test token'
def test_validate_raise_error_if_no_credentials_provided(self):
- with pytest.raises(ValidationError):
- AuthTokenSerializer().validate({})
+ with pytest.raises(ValidationError):
+ AuthTokenSerializer().validate({})
def test_whitespace_in_password(self):
- data = {'username': self.user.username, 'password': 'test pass '}
- self.user.set_password(data['password'])
- self.user.save()
- assert AuthTokenSerializer(data=data).is_valid()
+ data = {'username': self.user.username, 'password': 'test pass '}
+ self.user.set_password(data['password'])
+ self.user.save()
+ assert AuthTokenSerializer(data=data).is_valid()
-class AuthTokenCommandTests(TestCase):
- def setUp(self):
- self.site = site
- self.user = User.objects.create_user(username='test_user')
+def setUp(self):
+ self.site = site
+ self.user = User.objects.create_user(username='test_user')
def test_command_create_user_token(self):
- token = AuthTokenCommand().create_user_token(self.user.username, False)
- assert token is not None
- token_saved = Token.objects.first()
- assert token.key == token_saved.key
+ token = AuthTokenCommand().create_user_token(self.user.username, False)
+ assert token is not None
+ token_saved = Token.objects.first()
+ assert token.key == token_saved.key
def test_command_create_user_token_invalid_user(self):
- with pytest.raises(User.DoesNotExist):
- AuthTokenCommand().create_user_token('not_existing_user', False)
+ with pytest.raises(User.DoesNotExist):
+ AuthTokenCommand().create_user_token('not_existing_user', False)
def test_command_reset_user_token(self):
- AuthTokenCommand().create_user_token(self.user.username, False)
- first_token_key = Token.objects.first().key
- AuthTokenCommand().create_user_token(self.user.username, True)
- second_token_key = Token.objects.first().key
-
- assert first_token_key != second_token_key
+ AuthTokenCommand().create_user_token(self.user.username, False)
+ first_token_key = Token.objects.first().key
+ AuthTokenCommand().create_user_token(self.user.username, True)
+ second_token_key = Token.objects.first().key
+ assert first_token_key != second_token_key
def test_command_do_not_reset_user_token(self):
- AuthTokenCommand().create_user_token(self.user.username, False)
- first_token_key = Token.objects.first().key
- AuthTokenCommand().create_user_token(self.user.username, False)
- second_token_key = Token.objects.first().key
-
- assert first_token_key == second_token_key
+ AuthTokenCommand().create_user_token(self.user.username, False)
+ first_token_key = Token.objects.first().key
+ AuthTokenCommand().create_user_token(self.user.username, False)
+ second_token_key = Token.objects.first().key
+ assert first_token_key == second_token_key
def test_command_raising_error_for_invalid_user(self):
- out = StringIO()
- with pytest.raises(CommandError):
- call_command('drf_create_token', 'not_existing_user', stdout=out)
+ out = StringIO()
+ with pytest.raises(CommandError):
+ call_command('drf_create_token', 'not_existing_user', stdout=out)
def test_command_output(self):
- out = StringIO()
- call_command('drf_create_token', self.user.username, stdout=out)
- token_saved = Token.objects.first()
- self.assertIn('Generated token', out.getvalue())
- self.assertIn(self.user.username, out.getvalue())
- self.assertIn(token_saved.key, out.getvalue())
+ out = StringIO()
+ call_command('drf_create_token', self.user.username, stdout=out)
+ token_saved = Token.objects.first()
+ assert 'Generated token' in out.getvalue()
+ assert self.user.username in out.getvalue()
+ assert token_saved.key in out.getvalue()
diff --git a/tests/test_decorators.py b/tests/test_decorators.py
index bd30449e5..8b81ad461 100644
--- a/tests/test_decorators.py
+++ b/tests/test_decorators.py
@@ -17,307 +17,270 @@ from rest_framework.test import APIRequestFactory
from rest_framework.throttling import UserRateThrottle
from rest_framework.views import APIView
-
-class DecoratorTestCase(TestCase):
-
- def setUp(self):
- self.factory = APIRequestFactory()
+def setUp(self):
+ self.factory = APIRequestFactory()
def _finalize_response(self, request, response, *args, **kwargs):
- response.request = request
- return APIView.finalize_response(self, request, response, *args, **kwargs)
+ response.request = request
+ return APIView.finalize_response(self, request, response, *args, **kwargs)
def test_api_view_incorrect(self):
- """
+ """
If @api_view is not applied correct, we should raise an assertion.
"""
+ @api_view
+ def view(request):
+ return Response()
- @api_view
+ request = self.factory.get('/')
+withpytest.raises(AssertionError):
+view(request)
+
+ def test_api_view_incorrect_arguments(self):
+ """
+ If @api_view is missing arguments, we should raise an assertion.
+ """
+ with pytest.raises(AssertionError):
+ @api_view('GET')
def view(request):
return Response()
- request = self.factory.get('/')
- self.assertRaises(AssertionError, view, request)
-
- def test_api_view_incorrect_arguments(self):
- """
- If @api_view is missing arguments, we should raise an assertion.
- """
-
- with self.assertRaises(AssertionError):
- @api_view('GET')
- def view(request):
- return Response()
-
def test_calling_method(self):
- @api_view(['GET'])
- def view(request):
- return Response({})
+ @api_view(['GET'])
+ def view(request):
+ return Response({})
request = self.factory.get('/')
- response = view(request)
- assert response.status_code == status.HTTP_200_OK
-
- request = self.factory.post('/')
- response = view(request)
- assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
+response=view(request)
+assertresponse.status_code==status.HTTP_200_OK
+request=self.factory.post('/')
+response=view(request)
+assertresponse.status_code==status.HTTP_405_METHOD_NOT_ALLOWED
def test_calling_put_method(self):
- @api_view(['GET', 'PUT'])
- def view(request):
- return Response({})
+ @api_view(['GET', 'PUT'])
+ def view(request):
+ return Response({})
request = self.factory.put('/')
- response = view(request)
- assert response.status_code == status.HTTP_200_OK
-
- request = self.factory.post('/')
- response = view(request)
- assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
+response=view(request)
+assertresponse.status_code==status.HTTP_200_OK
+request=self.factory.post('/')
+response=view(request)
+assertresponse.status_code==status.HTTP_405_METHOD_NOT_ALLOWED
def test_calling_patch_method(self):
- @api_view(['GET', 'PATCH'])
- def view(request):
- return Response({})
+ @api_view(['GET', 'PATCH'])
+ def view(request):
+ return Response({})
request = self.factory.patch('/')
- response = view(request)
- assert response.status_code == status.HTTP_200_OK
-
- request = self.factory.post('/')
- response = view(request)
- assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
+response=view(request)
+assertresponse.status_code==status.HTTP_200_OK
+request=self.factory.post('/')
+response=view(request)
+assertresponse.status_code==status.HTTP_405_METHOD_NOT_ALLOWED
def test_renderer_classes(self):
- @api_view(['GET'])
- @renderer_classes([JSONRenderer])
- def view(request):
- return Response({})
+ @api_view(['GET'])
+ @renderer_classes([JSONRenderer])
+ def view(request):
+ return Response({})
request = self.factory.get('/')
- response = view(request)
- assert isinstance(response.accepted_renderer, JSONRenderer)
+response=view(request)
+assertisinstance(response.accepted_renderer,JSONRenderer)
def test_parser_classes(self):
- @api_view(['GET'])
- @parser_classes([JSONParser])
- def view(request):
- assert len(request.parsers) == 1
- assert isinstance(request.parsers[0], JSONParser)
- return Response({})
+ @api_view(['GET'])
+ @parser_classes([JSONParser])
+ def view(request):
+ assert len(request.parsers) == 1
+ assert isinstance(request.parsers[0], JSONParser)
+ return Response({})
request = self.factory.get('/')
- view(request)
+view(request)
def test_authentication_classes(self):
- @api_view(['GET'])
- @authentication_classes([BasicAuthentication])
- def view(request):
- assert len(request.authenticators) == 1
- assert isinstance(request.authenticators[0], BasicAuthentication)
- return Response({})
+ @api_view(['GET'])
+ @authentication_classes([BasicAuthentication])
+ def view(request):
+ assert len(request.authenticators) == 1
+ assert isinstance(request.authenticators[0], BasicAuthentication)
+ return Response({})
request = self.factory.get('/')
- view(request)
+view(request)
def test_permission_classes(self):
- @api_view(['GET'])
- @permission_classes([IsAuthenticated])
- def view(request):
- return Response({})
+ @api_view(['GET'])
+ @permission_classes([IsAuthenticated])
+ def view(request):
+ return Response({})
request = self.factory.get('/')
- response = view(request)
- assert response.status_code == status.HTTP_403_FORBIDDEN
+response=view(request)
+assertresponse.status_code==status.HTTP_403_FORBIDDEN
def test_throttle_classes(self):
- class OncePerDayUserThrottle(UserRateThrottle):
- rate = '1/day'
+ class OncePerDayUserThrottle(UserRateThrottle):
+ rate = '1/day'
@api_view(['GET'])
- @throttle_classes([OncePerDayUserThrottle])
- def view(request):
- return Response({})
+@throttle_classes([OncePerDayUserThrottle])
+defview(request):
+ return Response({})
request = self.factory.get('/')
- response = view(request)
- assert response.status_code == status.HTTP_200_OK
-
- response = view(request)
- assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
+response=view(request)
+assertresponse.status_code==status.HTTP_200_OK
+response=view(request)
+assertresponse.status_code==status.HTTP_429_TOO_MANY_REQUESTS
def test_schema(self):
- """
+ """
Checks CustomSchema class is set on view
"""
- class CustomSchema(AutoSchema):
- pass
+ class CustomSchema(AutoSchema):
+ pass
@api_view(['GET'])
- @schema(CustomSchema())
- def view(request):
- return Response({})
+@schema(CustomSchema())
+defview(request):
+ return Response({})
assert isinstance(view.cls.schema, CustomSchema)
-class ActionDecoratorTestCase(TestCase):
- def test_defaults(self):
- @action(detail=True)
- def test_action(request):
- """Description"""
+def test_defaults(self):
+ @action(detail=True)
+ def test_action(request):
+ """Description"""
assert test_action.mapping == {'get': 'test_action'}
- assert test_action.detail is True
- assert test_action.url_path == 'test_action'
- assert test_action.url_name == 'test-action'
- assert test_action.kwargs == {
- 'name': 'Test action',
- 'description': 'Description',
- }
+asserttest_action.detailisTrue
+asserttest_action.url_path=='test_action'
+asserttest_action.url_name=='test-action'
+asserttest_action.kwargs=={'name':'Test action','description':'Description',}
def test_detail_required(self):
- with pytest.raises(AssertionError) as excinfo:
- @action()
- def test_action(request):
- raise NotImplementedError
+ with pytest.raises(AssertionError) as excinfo:
+ @action()
+ def test_action(request):
+ raise NotImplementedError
assert str(excinfo.value) == "@action() missing required argument: 'detail'"
def test_method_mapping_http_methods(self):
# All HTTP methods should be mappable
- @action(detail=False, methods=[])
- def test_action():
- raise NotImplementedError
+ @action(detail=False, methods=[])
+ def test_action():
+ raise NotImplementedError
for name in APIView.http_method_names:
- def method():
- raise NotImplementedError
+ def method():
+ raise NotImplementedError
method.__name__ = name
- getattr(test_action.mapping, name)(method)
+getattr(test_action.mapping,name)(method)
# ensure the mapping returns the correct method name
for name in APIView.http_method_names:
- assert test_action.mapping[name] == name
+ assert test_action.mapping[name] == name
def test_view_name_kwargs(self):
- """
+ """
'name' and 'suffix' are mutually exclusive kwargs used for generating
a view's display name.
"""
- # by default, generate name from method
- @action(detail=True)
- def test_action(request):
- raise NotImplementedError
+ @action(detail=True)
+ def test_action(request):
+ raise NotImplementedError
- assert test_action.kwargs == {
- 'description': None,
- 'name': 'Test action',
- }
+ assert test_action.kwargs == {'description':None,'name':'Test action',}
+@action(detail=True,name='test name')
+deftest_action(request):
+ raise NotImplementedError
- # name kwarg supersedes name generation
- @action(detail=True, name='test name')
- def test_action(request):
- raise NotImplementedError
+ assert test_action.kwargs == {'description':None,'name':'test name',}
+@action(detail=True,suffix='Suffix')
+deftest_action(request):
+ raise NotImplementedError
- assert test_action.kwargs == {
- 'description': None,
- 'name': 'test name',
- }
-
- # suffix kwarg supersedes name generation
- @action(detail=True, suffix='Suffix')
- def test_action(request):
- raise NotImplementedError
-
- assert test_action.kwargs == {
- 'description': None,
- 'suffix': 'Suffix',
- }
-
- # name + suffix is a conflict.
- with pytest.raises(TypeError) as excinfo:
- action(detail=True, name='test name', suffix='Suffix')
+ assert test_action.kwargs == {'description':None,'suffix':'Suffix',}
+withpytest.raises(TypeError)asexcinfo:
+ action(detail=True, name='test name', suffix='Suffix')
assert str(excinfo.value) == "`name` and `suffix` are mutually exclusive arguments."
def test_method_mapping(self):
- @action(detail=False)
- def test_action(request):
- raise NotImplementedError
+ @action(detail=False)
+ def test_action(request):
+ raise NotImplementedError
@test_action.mapping.post
- def test_action_post(request):
- raise NotImplementedError
+deftest_action_post(request):
+ raise NotImplementedError
# The secondary handler methods should not have the action attributes
for name in ['mapping', 'detail', 'url_path', 'url_name', 'kwargs']:
- assert hasattr(test_action, name) and not hasattr(test_action_post, name)
+ assert hasattr(test_action, name) and not hasattr(test_action_post, name)
def test_method_mapping_already_mapped(self):
- @action(detail=True)
- def test_action(request):
- raise NotImplementedError
+ @action(detail=True)
+ def test_action(request):
+ raise NotImplementedError
msg = "Method 'get' has already been mapped to '.test_action'."
- with self.assertRaisesMessage(AssertionError, msg):
- @test_action.mapping.get
- def test_action_get(request):
- raise NotImplementedError
+withself.assertRaisesMessage(AssertionError,msg):
+ @test_action.mapping.get
+ def test_action_get(request):
+ raise NotImplementedError
def test_method_mapping_overwrite(self):
- @action(detail=True)
+ @action(detail=True)
+ def test_action():
+ raise NotImplementedError
+
+ msg = ("Method mapping does not behave like the property decorator. You ""cannot use the same method name for each mapping declaration.")
+withself.assertRaisesMessage(AssertionError,msg):
+ @test_action.mapping.post
def test_action():
raise NotImplementedError
- msg = ("Method mapping does not behave like the property decorator. You "
- "cannot use the same method name for each mapping declaration.")
- with self.assertRaisesMessage(AssertionError, msg):
- @test_action.mapping.post
- def test_action():
- raise NotImplementedError
-
def test_detail_route_deprecation(self):
- with pytest.warns(RemovedInDRF310Warning) as record:
- @detail_route()
- def view(request):
- raise NotImplementedError
+ with pytest.warns(RemovedInDRF310Warning) as record:
+ @detail_route()
+ def view(request):
+ raise NotImplementedError
assert len(record) == 1
- assert str(record[0].message) == (
- "`detail_route` is deprecated and will be removed in "
- "3.10 in favor of `action`, which accepts a `detail` bool. Use "
- "`@action(detail=True)` instead."
- )
+assertstr(record[0].message)==("`detail_route` is deprecated and will be removed in ""3.10 in favor of `action`, which accepts a `detail` bool. Use ""`@action(detail=True)` instead.")
def test_list_route_deprecation(self):
- with pytest.warns(RemovedInDRF310Warning) as record:
- @list_route()
- def view(request):
- raise NotImplementedError
+ with pytest.warns(RemovedInDRF310Warning) as record:
+ @list_route()
+ def view(request):
+ raise NotImplementedError
assert len(record) == 1
- assert str(record[0].message) == (
- "`list_route` is deprecated and will be removed in "
- "3.10 in favor of `action`, which accepts a `detail` bool. Use "
- "`@action(detail=False)` instead."
- )
+assertstr(record[0].message)==("`list_route` is deprecated and will be removed in ""3.10 in favor of `action`, which accepts a `detail` bool. Use ""`@action(detail=False)` instead.")
def test_route_url_name_from_path(self):
# pre-3.8 behavior was to base the `url_name` off of the `url_path`
- with pytest.warns(RemovedInDRF310Warning):
- @list_route(url_path='foo_bar')
- def view(request):
- raise NotImplementedError
+ with pytest.warns(RemovedInDRF310Warning):
+ @list_route(url_path='foo_bar')
+ def view(request):
+ raise NotImplementedError
assert view.url_path == 'foo_bar'
- assert view.url_name == 'foo-bar'
+assertview.url_name=='foo-bar'
diff --git a/tests/test_description.py b/tests/test_description.py
index ae00fe4a9..a585eabb5 100644
--- a/tests/test_description.py
+++ b/tests/test_description.py
@@ -69,37 +69,34 @@ MARKED_DOWN_gte_21 = """
an example docstring
indented
%s"""
-
-
-class TestViewNamesAndDescriptions(TestCase):
- def test_view_name_uses_class_name(self):
- """
+def test_view_name_uses_class_name(self):
+ """
Ensure view names are based on the class name.
"""
- class MockView(APIView):
- pass
+ class MockView(APIView):
+ pass
assert MockView().get_view_name() == 'Mock'
def test_view_name_uses_name_attribute(self):
- class MockView(APIView):
- name = 'Foo'
+ class MockView(APIView):
+ name = 'Foo'
assert MockView().get_view_name() == 'Foo'
def test_view_name_uses_suffix_attribute(self):
- class MockView(APIView):
- suffix = 'List'
+ class MockView(APIView):
+ suffix = 'List'
assert MockView().get_view_name() == 'Mock List'
def test_view_name_preferences_name_over_suffix(self):
- class MockView(APIView):
- name = 'Foo'
- suffix = 'List'
+ class MockView(APIView):
+ name = 'Foo'
+ suffix = 'List'
assert MockView().get_view_name() == 'Foo'
def test_view_description_uses_docstring(self):
- """Ensure view descriptions are based on the docstring."""
- class MockView(APIView):
- """an example docstring
+ """Ensure view descriptions are based on the docstring."""
+ class MockView(APIView):
+ """an example docstring
====================
* list
@@ -124,64 +121,53 @@ class TestViewNamesAndDescriptions(TestCase):
assert MockView().get_view_description() == DESCRIPTION
def test_view_description_uses_description_attribute(self):
- class MockView(APIView):
- description = 'Foo'
+ class MockView(APIView):
+ description = 'Foo'
assert MockView().get_view_description() == 'Foo'
def test_view_description_allows_empty_description(self):
- class MockView(APIView):
- """Description."""
- description = ''
+ class MockView(APIView):
+ """Description."""
+ description = ''
assert MockView().get_view_description() == ''
def test_view_description_can_be_empty(self):
- """
+ """
Ensure that if a view has no docstring,
then it's description is the empty string.
"""
- class MockView(APIView):
- pass
+ class MockView(APIView):
+ pass
assert MockView().get_view_description() == ''
def test_view_description_can_be_promise(self):
- """
+ """
Ensure a view may have a docstring that is actually a lazily evaluated
class that can be converted to a string.
See: https://github.com/encode/django-rest-framework/issues/1708
"""
- # use a mock object instead of gettext_lazy to ensure that we can't end
- # up with a test case string in our l10n catalog
-
- class MockLazyStr:
- def __init__(self, string):
- self.s = string
+ class MockLazyStr:
+ def __init__(self, string):
+ self.s = string
def __str__(self):
- return self.s
+ return self.s
class MockView(APIView):
- __doc__ = MockLazyStr("a gettext string")
+ __doc__ = MockLazyStr("a gettext string")
assert MockView().get_view_description() == 'a gettext string'
def test_markdown(self):
- """
+ """
Ensure markdown to HTML works as expected.
"""
- if apply_markdown:
- md_applied = apply_markdown(DESCRIPTION)
- gte_21_match = (
- md_applied == (
- MARKED_DOWN_gte_21 % MARKED_DOWN_HILITE) or
- md_applied == (
- MARKED_DOWN_gte_21 % MARKED_DOWN_NOT_HILITE))
- lt_21_match = (
- md_applied == (
- MARKED_DOWN_lt_21 % MARKED_DOWN_HILITE) or
- md_applied == (
- MARKED_DOWN_lt_21 % MARKED_DOWN_NOT_HILITE))
- assert gte_21_match or lt_21_match
+ if apply_markdown:
+ md_applied = apply_markdown(DESCRIPTION)
+ gte_21_match = ( md_applied == ( MARKED_DOWN_gte_21 % MARKED_DOWN_HILITE) or md_applied == ( MARKED_DOWN_gte_21 % MARKED_DOWN_NOT_HILITE))
+ lt_21_match = ( md_applied == ( MARKED_DOWN_lt_21 % MARKED_DOWN_HILITE) or md_applied == ( MARKED_DOWN_lt_21 % MARKED_DOWN_NOT_HILITE))
+ assert gte_21_match or lt_21_match
def test_dedent_tabs():
diff --git a/tests/test_encoders.py b/tests/test_encoders.py
index c66954b80..232725000 100644
--- a/tests/test_encoders.py
+++ b/tests/test_encoders.py
@@ -15,81 +15,79 @@ class MockList:
return [1, 2, 3]
-class JSONEncoderTests(TestCase):
- """
+"""
Tests the JSONEncoder method
"""
-
- def setUp(self):
- self.encoder = JSONEncoder()
+defsetUp(self):
+ self.encoder = JSONEncoder()
def test_encode_decimal(self):
- """
+ """
Tests encoding a decimal
"""
- d = Decimal(3.14)
- assert self.encoder.default(d) == float(d)
+ d = Decimal(3.14)
+ assert self.encoder.default(d) == float(d)
def test_encode_datetime(self):
- """
+ """
Tests encoding a datetime object
"""
- current_time = datetime.now()
- assert self.encoder.default(current_time) == current_time.isoformat()
- current_time_utc = current_time.replace(tzinfo=utc)
- assert self.encoder.default(current_time_utc) == current_time.isoformat() + 'Z'
+ current_time = datetime.now()
+ assert self.encoder.default(current_time) == current_time.isoformat()
+ current_time_utc = current_time.replace(tzinfo=utc)
+ assert self.encoder.default(current_time_utc) == current_time.isoformat() + 'Z'
def test_encode_time(self):
- """
+ """
Tests encoding a timezone
"""
- current_time = datetime.now().time()
- assert self.encoder.default(current_time) == current_time.isoformat()
+ current_time = datetime.now().time()
+ assert self.encoder.default(current_time) == current_time.isoformat()
def test_encode_time_tz(self):
- """
+ """
Tests encoding a timezone aware timestamp
"""
- current_time = datetime.now().time()
- current_time = current_time.replace(tzinfo=utc)
- with pytest.raises(ValueError):
- self.encoder.default(current_time)
+ current_time = datetime.now().time()
+ current_time = current_time.replace(tzinfo=utc)
+ with pytest.raises(ValueError):
+ self.encoder.default(current_time)
def test_encode_date(self):
- """
+ """
Tests encoding a date object
"""
- current_date = date.today()
- assert self.encoder.default(current_date) == current_date.isoformat()
+ current_date = date.today()
+ assert self.encoder.default(current_date) == current_date.isoformat()
def test_encode_timedelta(self):
- """
+ """
Tests encoding a timedelta object
"""
- delta = timedelta(hours=1)
- assert self.encoder.default(delta) == str(delta.total_seconds())
+ delta = timedelta(hours=1)
+ assert self.encoder.default(delta) == str(delta.total_seconds())
def test_encode_uuid(self):
- """
+ """
Tests encoding a UUID object
"""
- unique_id = uuid4()
- assert self.encoder.default(unique_id) == str(unique_id)
+ unique_id = uuid4()
+ assert self.encoder.default(unique_id) == str(unique_id)
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
- def test_encode_coreapi_raises_error(self):
- """
+deftest_encode_coreapi_raises_error(self):
+ """
Tests encoding a coreapi objects raises proper error
"""
- with pytest.raises(RuntimeError):
- self.encoder.default(coreapi.Document())
+ with pytest.raises(RuntimeError):
+ self.encoder.default(coreapi.Document())
with pytest.raises(RuntimeError):
- self.encoder.default(coreapi.Error())
+ self.encoder.default(coreapi.Error())
def test_encode_object_with_tolist(self):
- """
+ """
Tests encoding a object with tolist method
"""
- foo = MockList()
- assert self.encoder.default(foo) == [1, 2, 3]
+ foo = MockList()
+ assert self.encoder.default(foo) == [1, 2, 3]
diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py
index 9516bfec9..70ad77f1f 100644
--- a/tests/test_exceptions.py
+++ b/tests/test_exceptions.py
@@ -7,89 +7,58 @@ from rest_framework.exceptions import (
server_error
)
+def test_get_error_details(self):
-class ExceptionTestCase(TestCase):
-
- def test_get_error_details(self):
-
- example = "string"
- lazy_example = _(example)
-
- assert _get_error_details(lazy_example) == example
-
- assert isinstance(
- _get_error_details(lazy_example),
- ErrorDetail
- )
-
- assert _get_error_details({'nested': lazy_example})['nested'] == example
-
- assert isinstance(
- _get_error_details({'nested': lazy_example})['nested'],
- ErrorDetail
- )
-
- assert _get_error_details([[lazy_example]])[0][0] == example
-
- assert isinstance(
- _get_error_details([[lazy_example]])[0][0],
- ErrorDetail
- )
+ example = "string"
+ lazy_example = _(example)
+ assert _get_error_details(lazy_example) == example
+ assert isinstance( _get_error_details(lazy_example), ErrorDetail )
+ assert _get_error_details({'nested': lazy_example})['nested'] == example
+ assert isinstance( _get_error_details({'nested': lazy_example})['nested'], ErrorDetail )
+ assert _get_error_details([[lazy_example]])[0][0] == example
+ assert isinstance( _get_error_details([[lazy_example]])[0][0], ErrorDetail )
def test_get_full_details_with_throttling(self):
- exception = Throttled()
- assert exception.get_full_details() == {
- 'message': 'Request was throttled.', 'code': 'throttled'}
-
- exception = Throttled(wait=2)
- assert exception.get_full_details() == {
- 'message': 'Request was throttled. Expected available in {} seconds.'.format(2),
- 'code': 'throttled'}
-
- exception = Throttled(wait=2, detail='Slow down!')
- assert exception.get_full_details() == {
- 'message': 'Slow down! Expected available in {} seconds.'.format(2),
- 'code': 'throttled'}
+ exception = Throttled()
+ assert exception.get_full_details() == { 'message': 'Request was throttled.', 'code': 'throttled'}
+ exception = Throttled(wait=2)
+ assert exception.get_full_details() == { 'message': 'Request was throttled. Expected available in {} seconds.'.format(2), 'code': 'throttled'}
+ exception = Throttled(wait=2, detail='Slow down!')
+ assert exception.get_full_details() == { 'message': 'Slow down! Expected available in {} seconds.'.format(2), 'code': 'throttled'}
-class ErrorDetailTests(TestCase):
- def test_eq(self):
- assert ErrorDetail('msg') == ErrorDetail('msg')
- assert ErrorDetail('msg', 'code') == ErrorDetail('msg', code='code')
-
- assert ErrorDetail('msg') == 'msg'
- assert ErrorDetail('msg', 'code') == 'msg'
+def test_eq(self):
+ assert ErrorDetail('msg') == ErrorDetail('msg')
+ assert ErrorDetail('msg', 'code') == ErrorDetail('msg', code='code')
+ assert ErrorDetail('msg') == 'msg'
+ assert ErrorDetail('msg', 'code') == 'msg'
def test_ne(self):
- assert ErrorDetail('msg1') != ErrorDetail('msg2')
- assert ErrorDetail('msg') != ErrorDetail('msg', code='invalid')
-
- assert ErrorDetail('msg1') != 'msg2'
- assert ErrorDetail('msg1', 'code') != 'msg2'
+ assert ErrorDetail('msg1') != ErrorDetail('msg2')
+ assert ErrorDetail('msg') != ErrorDetail('msg', code='invalid')
+ assert ErrorDetail('msg1') != 'msg2'
+ assert ErrorDetail('msg1', 'code') != 'msg2'
def test_repr(self):
- assert repr(ErrorDetail('msg1')) == \
- 'ErrorDetail(string={!r}, code=None)'.format('msg1')
- assert repr(ErrorDetail('msg1', 'code')) == \
- 'ErrorDetail(string={!r}, code={!r})'.format('msg1', 'code')
+ assert repr(ErrorDetail('msg1')) == 'ErrorDetail(string={!r}, code=None)'.format('msg1')
+ assert repr(ErrorDetail('msg1', 'code')) == 'ErrorDetail(string={!r}, code={!r})'.format('msg1', 'code')
def test_str(self):
- assert str(ErrorDetail('msg1')) == 'msg1'
- assert str(ErrorDetail('msg1', 'code')) == 'msg1'
+ assert str(ErrorDetail('msg1')) == 'msg1'
+ assert str(ErrorDetail('msg1', 'code')) == 'msg1'
def test_hash(self):
- assert hash(ErrorDetail('msg')) == hash('msg')
- assert hash(ErrorDetail('msg', 'code')) == hash('msg')
+ assert hash(ErrorDetail('msg')) == hash('msg')
+ assert hash(ErrorDetail('msg', 'code')) == hash('msg')
-class TranslationTests(TestCase):
- @translation.override('fr')
- def test_message(self):
+@translation.override('fr')
+deftest_message(self):
# this test largely acts as a sanity test to ensure the translation files are present.
- self.assertEqual(_('A server error occurred.'), 'Une erreur du serveur est survenue.')
- self.assertEqual(str(APIException()), 'Une erreur du serveur est survenue.')
+ assert _('A server error occurred.') == 'Une erreur du serveur est survenue.'
+ assert str(APIException()) == 'Une erreur du serveur est survenue.'
def test_server_error():
diff --git a/tests/test_fields.py b/tests/test_fields.py
index e0833564b..909895f51 100644
--- a/tests/test_fields.py
+++ b/tests/test_fields.py
@@ -1134,40 +1134,38 @@ class TestNoStringCoercionDecimalField(FieldValues):
)
-class TestLocalizedDecimalField(TestCase):
- @override_settings(USE_L10N=True, LANGUAGE_CODE='pl')
- def test_to_internal_value(self):
- field = serializers.DecimalField(max_digits=2, decimal_places=1, localize=True)
- assert field.to_internal_value('1,1') == Decimal('1.1')
+@override_settings(USE_L10N=True, LANGUAGE_CODE='pl')
+deftest_to_internal_value(self):
+ field = serializers.DecimalField(max_digits=2, decimal_places=1, localize=True)
+ assert field.to_internal_value('1,1') == Decimal('1.1')
@override_settings(USE_L10N=True, LANGUAGE_CODE='pl')
- def test_to_representation(self):
- field = serializers.DecimalField(max_digits=2, decimal_places=1, localize=True)
- assert field.to_representation(Decimal('1.1')) == '1,1'
+deftest_to_representation(self):
+ field = serializers.DecimalField(max_digits=2, decimal_places=1, localize=True)
+ assert field.to_representation(Decimal('1.1')) == '1,1'
def test_localize_forces_coerce_to_string(self):
- field = serializers.DecimalField(max_digits=2, decimal_places=1, coerce_to_string=False, localize=True)
- assert isinstance(field.to_representation(Decimal('1.1')), str)
+ field = serializers.DecimalField(max_digits=2, decimal_places=1, coerce_to_string=False, localize=True)
+ assert isinstance(field.to_representation(Decimal('1.1')), str)
-class TestQuantizedValueForDecimal(TestCase):
- def test_int_quantized_value_for_decimal(self):
- field = serializers.DecimalField(max_digits=4, decimal_places=2)
- value = field.to_internal_value(12).as_tuple()
- expected_digit_tuple = (0, (1, 2, 0, 0), -2)
- assert value == expected_digit_tuple
+def test_int_quantized_value_for_decimal(self):
+ field = serializers.DecimalField(max_digits=4, decimal_places=2)
+ value = field.to_internal_value(12).as_tuple()
+ expected_digit_tuple = (0, (1, 2, 0, 0), -2)
+ assert value == expected_digit_tuple
def test_string_quantized_value_for_decimal(self):
- field = serializers.DecimalField(max_digits=4, decimal_places=2)
- value = field.to_internal_value('12').as_tuple()
- expected_digit_tuple = (0, (1, 2, 0, 0), -2)
- assert value == expected_digit_tuple
+ field = serializers.DecimalField(max_digits=4, decimal_places=2)
+ value = field.to_internal_value('12').as_tuple()
+ expected_digit_tuple = (0, (1, 2, 0, 0), -2)
+ assert value == expected_digit_tuple
def test_part_precision_string_quantized_value_for_decimal(self):
- field = serializers.DecimalField(max_digits=4, decimal_places=2)
- value = field.to_internal_value('12.0').as_tuple()
- expected_digit_tuple = (0, (1, 2, 0, 0), -2)
- assert value == expected_digit_tuple
+ field = serializers.DecimalField(max_digits=4, decimal_places=2)
+ value = field.to_internal_value('12.0').as_tuple()
+ expected_digit_tuple = (0, (1, 2, 0, 0), -2)
+ assert value == expected_digit_tuple
class TestNoDecimalPlaces(FieldValues):
@@ -1185,17 +1183,15 @@ class TestNoDecimalPlaces(FieldValues):
field = serializers.DecimalField(max_digits=6, decimal_places=None)
-class TestRoundingDecimalField(TestCase):
- def test_valid_rounding(self):
- field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding=ROUND_UP)
- assert field.to_representation(Decimal('1.234')) == '1.24'
-
- field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding=ROUND_DOWN)
- assert field.to_representation(Decimal('1.234')) == '1.23'
+def test_valid_rounding(self):
+ field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding=ROUND_UP)
+ assert field.to_representation(Decimal('1.234')) == '1.24'
+ field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding=ROUND_DOWN)
+ assert field.to_representation(Decimal('1.234')) == '1.23'
def test_invalid_rounding(self):
- with pytest.raises(AssertionError) as excinfo:
- serializers.DecimalField(max_digits=1, decimal_places=1, rounding='ROUND_UNKNOWN')
+ with pytest.raises(AssertionError) as excinfo:
+ serializers.DecimalField(max_digits=1, decimal_places=1, rounding='ROUND_UNKNOWN')
assert 'Invalid rounding option' in str(excinfo.value)
@@ -1369,46 +1365,41 @@ class TestTZWithDateTimeField(FieldValues):
@override_settings(TIME_ZONE='UTC', USE_TZ=True)
-class TestDefaultTZDateTimeField(TestCase):
- """
+"""
Test the current/default timezone handling in `DateTimeField`.
"""
-
- @classmethod
- def setup_class(cls):
- cls.field = serializers.DateTimeField()
- cls.kolkata = pytz.timezone('Asia/Kolkata')
+@classmethod
+defsetup_class(cls):
+ cls.field = serializers.DateTimeField()
+ cls.kolkata = pytz.timezone('Asia/Kolkata')
def test_default_timezone(self):
- assert self.field.default_timezone() == utc
+ assert self.field.default_timezone() == utc
def test_current_timezone(self):
- assert self.field.default_timezone() == utc
- activate(self.kolkata)
- assert self.field.default_timezone() == self.kolkata
- deactivate()
- assert self.field.default_timezone() == utc
+ assert self.field.default_timezone() == utc
+ activate(self.kolkata)
+ assert self.field.default_timezone() == self.kolkata
+ deactivate()
+ assert self.field.default_timezone() == utc
@pytest.mark.skipif(pytz is None, reason='pytz not installed')
@override_settings(TIME_ZONE='UTC', USE_TZ=True)
-class TestCustomTimezoneForDateTimeField(TestCase):
- @classmethod
- def setup_class(cls):
- cls.kolkata = pytz.timezone('Asia/Kolkata')
- cls.date_format = '%d/%m/%Y %H:%M'
+@classmethod
+defsetup_class(cls):
+ cls.kolkata = pytz.timezone('Asia/Kolkata')
+ cls.date_format = '%d/%m/%Y %H:%M'
def test_should_render_date_time_in_default_timezone(self):
- field = serializers.DateTimeField(default_timezone=self.kolkata, format=self.date_format)
- dt = datetime.datetime(2018, 2, 8, 14, 15, 16, tzinfo=pytz.utc)
-
- with override(self.kolkata):
- rendered_date = field.to_representation(dt)
+ field = serializers.DateTimeField(default_timezone=self.kolkata, format=self.date_format)
+ dt = datetime.datetime(2018, 2, 8, 14, 15, 16, tzinfo=pytz.utc)
+ with override(self.kolkata):
+ rendered_date = field.to_representation(dt)
rendered_date_in_timezone = dt.astimezone(self.kolkata).strftime(self.date_format)
-
- assert rendered_date == rendered_date_in_timezone
+assertrendered_date==rendered_date_in_timezone
class TestNaiveDayLightSavingTimeTimeZoneDateTimeField(FieldValues):
diff --git a/tests/test_filters.py b/tests/test_filters.py
index a52f40103..ec678eb09 100644
--- a/tests/test_filters.py
+++ b/tests/test_filters.py
@@ -13,28 +13,25 @@ from rest_framework.compat import coreschema
from rest_framework.test import APIRequestFactory
factory = APIRequestFactory()
-
-
-class BaseFilterTests(TestCase):
- def setUp(self):
- self.original_coreapi = filters.coreapi
- filters.coreapi = True # mock it, because not None value needed
- self.filter_backend = filters.BaseFilterBackend()
+def setUp(self):
+ self.original_coreapi = filters.coreapi
+ filters.coreapi = True
+ self.filter_backend = filters.BaseFilterBackend()
def tearDown(self):
- filters.coreapi = self.original_coreapi
+ filters.coreapi = self.original_coreapi
def test_filter_queryset_raises_error(self):
- with pytest.raises(NotImplementedError):
- self.filter_backend.filter_queryset(None, None, None)
+ with pytest.raises(NotImplementedError):
+ self.filter_backend.filter_queryset(None, None, None)
@pytest.mark.skipif(not coreschema, reason='coreschema is not installed')
- def test_get_schema_fields_checks_for_coreapi(self):
- filters.coreapi = None
- with pytest.raises(AssertionError):
- self.filter_backend.get_schema_fields({})
+deftest_get_schema_fields_checks_for_coreapi(self):
+ filters.coreapi = None
+ with pytest.raises(AssertionError):
+ self.filter_backend.get_schema_fields({})
filters.coreapi = True
- assert self.filter_backend.get_schema_fields({}) == []
+assertself.filter_backend.get_schema_fields({})==[]
class SearchFilterModel(models.Model):
@@ -48,137 +45,115 @@ class SearchFilterSerializer(serializers.ModelSerializer):
fields = '__all__'
-class SearchFilterTests(TestCase):
- def setUp(self):
+def setUp(self):
# Sequence of title/text is:
#
# z abc
# zz bcd
# zzz cde
# ...
- for idx in range(10):
- title = 'z' * (idx + 1)
- text = (
- chr(idx + ord('a')) +
- chr(idx + ord('b')) +
- chr(idx + ord('c'))
- )
- SearchFilterModel(title=title, text=text).save()
+ for idx in range(10):
+ title = 'z' * (idx + 1)
+ text = ( chr(idx + ord('a')) + chr(idx + ord('b')) + chr(idx + ord('c')) )
+ SearchFilterModel(title=title, text=text).save()
def test_search(self):
+ class SearchListView(generics.ListAPIView):
+ queryset = SearchFilterModel.objects.all()
+ serializer_class = SearchFilterSerializer
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('title', 'text')
+
+ view = SearchListView.as_view()
+request=factory.get('/',{'search':'b'})
+response=view(request)
+assertresponse.data==[{'id':1,'title':'z','text':'abc'},{'id':2,'title':'zz','text':'bcd'}]
+
+ def test_search_returns_same_queryset_if_no_search_fields_or_terms_provided(self):
+ class SearchListView(generics.ListAPIView):
+ queryset = SearchFilterModel.objects.all()
+ serializer_class = SearchFilterSerializer
+ filter_backends = (filters.SearchFilter,)
+
+ view = SearchListView.as_view()
+request=factory.get('/')
+response=view(request)
+expected=SearchFilterSerializer(SearchFilterModel.objects.all(),many=True).data
+assertresponse.data==expected
+
+ def test_exact_search(self):
+ class SearchListView(generics.ListAPIView):
+ queryset = SearchFilterModel.objects.all()
+ serializer_class = SearchFilterSerializer
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('=title', 'text')
+
+ view = SearchListView.as_view()
+request=factory.get('/',{'search':'zzz'})
+response=view(request)
+assertresponse.data==[{'id':3,'title':'zzz','text':'cde'}]
+
+ def test_startswith_search(self):
+ class SearchListView(generics.ListAPIView):
+ queryset = SearchFilterModel.objects.all()
+ serializer_class = SearchFilterSerializer
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('title', '^text')
+
+ view = SearchListView.as_view()
+request=factory.get('/',{'search':'b'})
+response=view(request)
+assertresponse.data==[{'id':2,'title':'zz','text':'bcd'}]
+
+ def test_regexp_search(self):
+ class SearchListView(generics.ListAPIView):
+ queryset = SearchFilterModel.objects.all()
+ serializer_class = SearchFilterSerializer
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('$title', '$text')
+
+ view = SearchListView.as_view()
+request=factory.get('/',{'search':'z{2} ^b'})
+response=view(request)
+assertresponse.data==[{'id':2,'title':'zz','text':'bcd'}]
+
+ def test_search_with_nonstandard_search_param(self):
+ with override_settings(REST_FRAMEWORK={'SEARCH_PARAM': 'query'}):
+ reload_module(filters)
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('title', 'text')
- view = SearchListView.as_view()
- request = factory.get('/', {'search': 'b'})
- response = view(request)
- assert response.data == [
- {'id': 1, 'title': 'z', 'text': 'abc'},
- {'id': 2, 'title': 'zz', 'text': 'bcd'}
- ]
-
- def test_search_returns_same_queryset_if_no_search_fields_or_terms_provided(self):
- class SearchListView(generics.ListAPIView):
- queryset = SearchFilterModel.objects.all()
- serializer_class = SearchFilterSerializer
- filter_backends = (filters.SearchFilter,)
-
- view = SearchListView.as_view()
- request = factory.get('/')
- response = view(request)
- expected = SearchFilterSerializer(SearchFilterModel.objects.all(),
- many=True).data
- assert response.data == expected
-
- def test_exact_search(self):
- class SearchListView(generics.ListAPIView):
- queryset = SearchFilterModel.objects.all()
- serializer_class = SearchFilterSerializer
- filter_backends = (filters.SearchFilter,)
- search_fields = ('=title', 'text')
-
- view = SearchListView.as_view()
- request = factory.get('/', {'search': 'zzz'})
- response = view(request)
- assert response.data == [
- {'id': 3, 'title': 'zzz', 'text': 'cde'}
- ]
-
- def test_startswith_search(self):
- class SearchListView(generics.ListAPIView):
- queryset = SearchFilterModel.objects.all()
- serializer_class = SearchFilterSerializer
- filter_backends = (filters.SearchFilter,)
- search_fields = ('title', '^text')
-
- view = SearchListView.as_view()
- request = factory.get('/', {'search': 'b'})
- response = view(request)
- assert response.data == [
- {'id': 2, 'title': 'zz', 'text': 'bcd'}
- ]
-
- def test_regexp_search(self):
- class SearchListView(generics.ListAPIView):
- queryset = SearchFilterModel.objects.all()
- serializer_class = SearchFilterSerializer
- filter_backends = (filters.SearchFilter,)
- search_fields = ('$title', '$text')
-
- view = SearchListView.as_view()
- request = factory.get('/', {'search': 'z{2} ^b'})
- response = view(request)
- assert response.data == [
- {'id': 2, 'title': 'zz', 'text': 'bcd'}
- ]
-
- def test_search_with_nonstandard_search_param(self):
- with override_settings(REST_FRAMEWORK={'SEARCH_PARAM': 'query'}):
- reload_module(filters)
-
- class SearchListView(generics.ListAPIView):
- queryset = SearchFilterModel.objects.all()
- serializer_class = SearchFilterSerializer
- filter_backends = (filters.SearchFilter,)
- search_fields = ('title', 'text')
-
view = SearchListView.as_view()
- request = factory.get('/', {'query': 'b'})
- response = view(request)
- assert response.data == [
- {'id': 1, 'title': 'z', 'text': 'abc'},
- {'id': 2, 'title': 'zz', 'text': 'bcd'}
- ]
+request=factory.get('/',{'query':'b'})
+response=view(request)
+assertresponse.data==[{'id':1,'title':'z','text':'abc'},{'id':2,'title':'zz','text':'bcd'}]
reload_module(filters)
def test_search_with_filter_subclass(self):
- class CustomSearchFilter(filters.SearchFilter):
+ class CustomSearchFilter(filters.SearchFilter):
# Filter that dynamically changes search fields
- def get_search_fields(self, view, request):
- if request.query_params.get('title_only'):
- return ('$title',)
+ def get_search_fields(self, view, request):
+ if request.query_params.get('title_only'):
+ return ('$title',)
return super().get_search_fields(view, request)
class SearchListView(generics.ListAPIView):
- queryset = SearchFilterModel.objects.all()
- serializer_class = SearchFilterSerializer
- filter_backends = (CustomSearchFilter,)
- search_fields = ('$title', '$text')
+ queryset = SearchFilterModel.objects.all()
+ serializer_class = SearchFilterSerializer
+ filter_backends = (CustomSearchFilter,)
+ search_fields = ('$title', '$text')
view = SearchListView.as_view()
- request = factory.get('/', {'search': r'^\w{3}$'})
- response = view(request)
- assert len(response.data) == 10
-
- request = factory.get('/', {'search': r'^\w{3}$', 'title_only': 'true'})
- response = view(request)
- assert response.data == [
- {'id': 3, 'title': 'zzz', 'text': 'cde'}
- ]
+request=factory.get('/',{'search':r'^\w{3}$'})
+response=view(request)
+assertlen(response.data)==10
+request=factory.get('/',{'search':r'^\w{3}$','title_only':'true'})
+response=view(request)
+assertresponse.data==[{'id':3,'title':'zzz','text':'cde'}]
class AttributeModel(models.Model):
@@ -196,31 +171,21 @@ class SearchFilterFkSerializer(serializers.ModelSerializer):
fields = '__all__'
-class SearchFilterFkTests(TestCase):
- def test_must_call_distinct(self):
- filter_ = filters.SearchFilter()
- prefixes = [''] + list(filter_.lookup_prefixes)
- for prefix in prefixes:
- assert not filter_.must_call_distinct(
- SearchFilterModelFk._meta,
- ["%stitle" % prefix]
- )
- assert not filter_.must_call_distinct(
- SearchFilterModelFk._meta,
- ["%stitle" % prefix, "%sattribute__label" % prefix]
- )
+def test_must_call_distinct(self):
+ filter_ = filters.SearchFilter()
+ prefixes = [''] + list(filter_.lookup_prefixes)
+ for prefix in prefixes:
+ assert not filter_.must_call_distinct( SearchFilterModelFk._meta, ["%stitle" % prefix] )
+ assert not filter_.must_call_distinct( SearchFilterModelFk._meta, ["%stitle" % prefix, "%sattribute__label" % prefix] )
def test_must_call_distinct_restores_meta_for_each_field(self):
# In this test case the attribute of the fk model comes first in the
# list of search fields.
- filter_ = filters.SearchFilter()
- prefixes = [''] + list(filter_.lookup_prefixes)
- for prefix in prefixes:
- assert not filter_.must_call_distinct(
- SearchFilterModelFk._meta,
- ["%sattribute__label" % prefix, "%stitle" % prefix]
- )
+ filter_ = filters.SearchFilter()
+ prefixes = [''] + list(filter_.lookup_prefixes)
+ for prefix in prefixes:
+ assert not filter_.must_call_distinct( SearchFilterModelFk._meta, ["%sattribute__label" % prefix, "%stitle" % prefix] )
class SearchFilterModelM2M(models.Model):
@@ -235,53 +200,41 @@ class SearchFilterM2MSerializer(serializers.ModelSerializer):
fields = '__all__'
-class SearchFilterM2MTests(TestCase):
- def setUp(self):
+def setUp(self):
# Sequence of title/text/attributes is:
#
# z abc [1, 2, 3]
# zz bcd [1, 2, 3]
# zzz cde [1, 2, 3]
# ...
- for idx in range(3):
- label = 'w' * (idx + 1)
- AttributeModel.objects.create(label=label)
+ for idx in range(3):
+ label = 'w' * (idx + 1)
+ AttributeModel.objects.create(label=label)
for idx in range(10):
- title = 'z' * (idx + 1)
- text = (
- chr(idx + ord('a')) +
- chr(idx + ord('b')) +
- chr(idx + ord('c'))
- )
- SearchFilterModelM2M(title=title, text=text).save()
+ title = 'z' * (idx + 1)
+ text = ( chr(idx + ord('a')) + chr(idx + ord('b')) + chr(idx + ord('c')) )
+ SearchFilterModelM2M(title=title, text=text).save()
SearchFilterModelM2M.objects.get(title='zz').attributes.add(1, 2, 3)
def test_m2m_search(self):
- class SearchListView(generics.ListAPIView):
- queryset = SearchFilterModelM2M.objects.all()
- serializer_class = SearchFilterM2MSerializer
- filter_backends = (filters.SearchFilter,)
- search_fields = ('=title', 'text', 'attributes__label')
+ class SearchListView(generics.ListAPIView):
+ queryset = SearchFilterModelM2M.objects.all()
+ serializer_class = SearchFilterM2MSerializer
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('=title', 'text', 'attributes__label')
view = SearchListView.as_view()
- request = factory.get('/', {'search': 'zz'})
- response = view(request)
- assert len(response.data) == 1
+request=factory.get('/',{'search':'zz'})
+response=view(request)
+assertlen(response.data)==1
def test_must_call_distinct(self):
- filter_ = filters.SearchFilter()
- prefixes = [''] + list(filter_.lookup_prefixes)
- for prefix in prefixes:
- assert not filter_.must_call_distinct(
- SearchFilterModelM2M._meta,
- ["%stitle" % prefix]
- )
-
- assert filter_.must_call_distinct(
- SearchFilterModelM2M._meta,
- ["%stitle" % prefix, "%sattributes__label" % prefix]
- )
+ filter_ = filters.SearchFilter()
+ prefixes = [''] + list(filter_.lookup_prefixes)
+ for prefix in prefixes:
+ assert not filter_.must_call_distinct( SearchFilterModelM2M._meta, ["%stitle" % prefix] )
+ assert filter_.must_call_distinct( SearchFilterModelM2M._meta, ["%stitle" % prefix, "%sattributes__label" % prefix] )
class Blog(models.Model):
@@ -300,32 +253,27 @@ class BlogSerializer(serializers.ModelSerializer):
fields = '__all__'
-class SearchFilterToManyTests(TestCase):
- @classmethod
- def setUpTestData(cls):
- b1 = Blog.objects.create(name='Blog 1')
- b2 = Blog.objects.create(name='Blog 2')
-
- # Multiple entries on Lennon published in 1979 - distinct should deduplicate
- Entry.objects.create(blog=b1, headline='Something about Lennon', pub_date=datetime.date(1979, 1, 1))
- Entry.objects.create(blog=b1, headline='Another thing about Lennon', pub_date=datetime.date(1979, 6, 1))
-
- # Entry on Lennon *and* a separate entry in 1979 - should not match
- Entry.objects.create(blog=b2, headline='Something unrelated', pub_date=datetime.date(1979, 1, 1))
- Entry.objects.create(blog=b2, headline='Retrospective on Lennon', pub_date=datetime.date(1990, 6, 1))
+@classmethod
+defsetUpTestData(cls):
+ b1 = Blog.objects.create(name='Blog 1')
+ b2 = Blog.objects.create(name='Blog 2')
+ Entry.objects.create(blog=b1, headline='Something about Lennon', pub_date=datetime.date(1979, 1, 1))
+ Entry.objects.create(blog=b1, headline='Another thing about Lennon', pub_date=datetime.date(1979, 6, 1))
+ Entry.objects.create(blog=b2, headline='Something unrelated', pub_date=datetime.date(1979, 1, 1))
+ Entry.objects.create(blog=b2, headline='Retrospective on Lennon', pub_date=datetime.date(1990, 6, 1))
def test_multiple_filter_conditions(self):
- class SearchListView(generics.ListAPIView):
- queryset = Blog.objects.all()
- serializer_class = BlogSerializer
- filter_backends = (filters.SearchFilter,)
- search_fields = ('=name', 'entry__headline', '=entry__pub_date__year')
+ class SearchListView(generics.ListAPIView):
+ queryset = Blog.objects.all()
+ serializer_class = BlogSerializer
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('=name', 'entry__headline', '=entry__pub_date__year')
view = SearchListView.as_view()
- request = factory.get('/', {'search': 'Lennon,1979'})
- response = view(request)
- assert len(response.data) == 1
+request=factory.get('/',{'search':'Lennon,1979'})
+response=view(request)
+assertlen(response.data)==1
class SearchFilterAnnotatedSerializer(serializers.ModelSerializer):
@@ -336,28 +284,23 @@ class SearchFilterAnnotatedSerializer(serializers.ModelSerializer):
fields = ('title', 'text', 'title_text')
-class SearchFilterAnnotatedFieldTests(TestCase):
- @classmethod
- def setUpTestData(cls):
- SearchFilterModel.objects.create(title='abc', text='def')
- SearchFilterModel.objects.create(title='ghi', text='jkl')
+@classmethod
+defsetUpTestData(cls):
+ SearchFilterModel.objects.create(title='abc', text='def')
+ SearchFilterModel.objects.create(title='ghi', text='jkl')
def test_search_in_annotated_field(self):
- class SearchListView(generics.ListAPIView):
- queryset = SearchFilterModel.objects.annotate(
- title_text=Upper(
- Concat(models.F('title'), models.F('text'))
- )
- ).all()
- serializer_class = SearchFilterAnnotatedSerializer
- filter_backends = (filters.SearchFilter,)
- search_fields = ('title_text',)
+ class SearchListView(generics.ListAPIView):
+ queryset = SearchFilterModel.objects.annotate( title_text=Upper( Concat(models.F('title'), models.F('text')) ) ).all()
+ serializer_class = SearchFilterAnnotatedSerializer
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('title_text',)
view = SearchListView.as_view()
- request = factory.get('/', {'search': 'ABCDEF'})
- response = view(request)
- assert len(response.data) == 1
- assert response.data[0]['title_text'] == 'ABCDEF'
+request=factory.get('/',{'search':'ABCDEF'})
+response=view(request)
+assertlen(response.data)==1
+assertresponse.data[0]['title_text']=='ABCDEF'
class OrderingFilterModel(models.Model):
@@ -403,253 +346,187 @@ class DjangoFilterOrderingSerializer(serializers.ModelSerializer):
fields = '__all__'
-class OrderingFilterTests(TestCase):
- def setUp(self):
+def setUp(self):
# Sequence of title/text is:
#
# zyx abc
# yxw bcd
# xwv cde
- for idx in range(3):
- title = (
- chr(ord('z') - idx) +
- chr(ord('y') - idx) +
- chr(ord('x') - idx)
- )
- text = (
- chr(idx + ord('a')) +
- chr(idx + ord('b')) +
- chr(idx + ord('c'))
- )
- OrderingFilterModel(title=title, text=text).save()
+ for idx in range(3):
+ title = ( chr(ord('z') - idx) + chr(ord('y') - idx) + chr(ord('x') - idx) )
+ text = ( chr(idx + ord('a')) + chr(idx + ord('b')) + chr(idx + ord('c')) )
+ OrderingFilterModel(title=title, text=text).save()
def test_ordering(self):
- class OrderingListView(generics.ListAPIView):
- queryset = OrderingFilterModel.objects.all()
- serializer_class = OrderingFilterSerializer
- filter_backends = (filters.OrderingFilter,)
- ordering = ('title',)
- ordering_fields = ('text',)
+ class OrderingListView(generics.ListAPIView):
+ queryset = OrderingFilterModel.objects.all()
+ serializer_class = OrderingFilterSerializer
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+ ordering_fields = ('text',)
view = OrderingListView.as_view()
- request = factory.get('/', {'ordering': 'text'})
- response = view(request)
- assert response.data == [
- {'id': 1, 'title': 'zyx', 'text': 'abc'},
- {'id': 2, 'title': 'yxw', 'text': 'bcd'},
- {'id': 3, 'title': 'xwv', 'text': 'cde'},
- ]
+request=factory.get('/',{'ordering':'text'})
+response=view(request)
+assertresponse.data==[{'id':1,'title':'zyx','text':'abc'},{'id':2,'title':'yxw','text':'bcd'},{'id':3,'title':'xwv','text':'cde'},]
def test_reverse_ordering(self):
- class OrderingListView(generics.ListAPIView):
- queryset = OrderingFilterModel.objects.all()
- serializer_class = OrderingFilterSerializer
- filter_backends = (filters.OrderingFilter,)
- ordering = ('title',)
- ordering_fields = ('text',)
+ class OrderingListView(generics.ListAPIView):
+ queryset = OrderingFilterModel.objects.all()
+ serializer_class = OrderingFilterSerializer
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+ ordering_fields = ('text',)
view = OrderingListView.as_view()
- request = factory.get('/', {'ordering': '-text'})
- response = view(request)
- assert response.data == [
- {'id': 3, 'title': 'xwv', 'text': 'cde'},
- {'id': 2, 'title': 'yxw', 'text': 'bcd'},
- {'id': 1, 'title': 'zyx', 'text': 'abc'},
- ]
+request=factory.get('/',{'ordering':'-text'})
+response=view(request)
+assertresponse.data==[{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},{'id':1,'title':'zyx','text':'abc'},]
def test_incorrecturl_extrahyphens_ordering(self):
- class OrderingListView(generics.ListAPIView):
- queryset = OrderingFilterModel.objects.all()
- serializer_class = OrderingFilterSerializer
- filter_backends = (filters.OrderingFilter,)
- ordering = ('title',)
- ordering_fields = ('text',)
+ class OrderingListView(generics.ListAPIView):
+ queryset = OrderingFilterModel.objects.all()
+ serializer_class = OrderingFilterSerializer
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+ ordering_fields = ('text',)
view = OrderingListView.as_view()
- request = factory.get('/', {'ordering': '--text'})
- response = view(request)
- assert response.data == [
- {'id': 3, 'title': 'xwv', 'text': 'cde'},
- {'id': 2, 'title': 'yxw', 'text': 'bcd'},
- {'id': 1, 'title': 'zyx', 'text': 'abc'},
- ]
+request=factory.get('/',{'ordering':'--text'})
+response=view(request)
+assertresponse.data==[{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},{'id':1,'title':'zyx','text':'abc'},]
def test_incorrectfield_ordering(self):
- class OrderingListView(generics.ListAPIView):
- queryset = OrderingFilterModel.objects.all()
- serializer_class = OrderingFilterSerializer
- filter_backends = (filters.OrderingFilter,)
- ordering = ('title',)
- ordering_fields = ('text',)
+ class OrderingListView(generics.ListAPIView):
+ queryset = OrderingFilterModel.objects.all()
+ serializer_class = OrderingFilterSerializer
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+ ordering_fields = ('text',)
view = OrderingListView.as_view()
- request = factory.get('/', {'ordering': 'foobar'})
- response = view(request)
- assert response.data == [
- {'id': 3, 'title': 'xwv', 'text': 'cde'},
- {'id': 2, 'title': 'yxw', 'text': 'bcd'},
- {'id': 1, 'title': 'zyx', 'text': 'abc'},
- ]
+request=factory.get('/',{'ordering':'foobar'})
+response=view(request)
+assertresponse.data==[{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},{'id':1,'title':'zyx','text':'abc'},]
def test_default_ordering(self):
- class OrderingListView(generics.ListAPIView):
- queryset = OrderingFilterModel.objects.all()
- serializer_class = OrderingFilterSerializer
- filter_backends = (filters.OrderingFilter,)
- ordering = ('title',)
- ordering_fields = ('text',)
+ class OrderingListView(generics.ListAPIView):
+ queryset = OrderingFilterModel.objects.all()
+ serializer_class = OrderingFilterSerializer
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+ ordering_fields = ('text',)
view = OrderingListView.as_view()
- request = factory.get('')
- response = view(request)
- assert response.data == [
- {'id': 3, 'title': 'xwv', 'text': 'cde'},
- {'id': 2, 'title': 'yxw', 'text': 'bcd'},
- {'id': 1, 'title': 'zyx', 'text': 'abc'},
- ]
+request=factory.get('')
+response=view(request)
+assertresponse.data==[{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},{'id':1,'title':'zyx','text':'abc'},]
def test_default_ordering_using_string(self):
- class OrderingListView(generics.ListAPIView):
- queryset = OrderingFilterModel.objects.all()
- serializer_class = OrderingFilterSerializer
- filter_backends = (filters.OrderingFilter,)
- ordering = 'title'
- ordering_fields = ('text',)
+ class OrderingListView(generics.ListAPIView):
+ queryset = OrderingFilterModel.objects.all()
+ serializer_class = OrderingFilterSerializer
+ filter_backends = (filters.OrderingFilter,)
+ ordering = 'title'
+ ordering_fields = ('text',)
view = OrderingListView.as_view()
- request = factory.get('')
- response = view(request)
- assert response.data == [
- {'id': 3, 'title': 'xwv', 'text': 'cde'},
- {'id': 2, 'title': 'yxw', 'text': 'bcd'},
- {'id': 1, 'title': 'zyx', 'text': 'abc'},
- ]
+request=factory.get('')
+response=view(request)
+assertresponse.data==[{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},{'id':1,'title':'zyx','text':'abc'},]
def test_ordering_by_aggregate_field(self):
# create some related models to aggregate order by
- num_objs = [2, 5, 3]
- for obj, num_relateds in zip(OrderingFilterModel.objects.all(),
- num_objs):
- for _ in range(num_relateds):
- new_related = OrderingFilterRelatedModel(
- related_object=obj
- )
- new_related.save()
+ num_objs = [2, 5, 3]
+ for obj, num_relateds in zip(OrderingFilterModel.objects.all(), num_objs):
+ for _ in range(num_relateds):
+ new_related = OrderingFilterRelatedModel( related_object=obj )
+ new_related.save()
class OrderingListView(generics.ListAPIView):
- serializer_class = OrderingFilterSerializer
- filter_backends = (filters.OrderingFilter,)
- ordering = 'title'
- ordering_fields = '__all__'
- queryset = OrderingFilterModel.objects.all().annotate(
- models.Count("relateds"))
+ serializer_class = OrderingFilterSerializer
+ filter_backends = (filters.OrderingFilter,)
+ ordering = 'title'
+ ordering_fields = '__all__'
+ queryset = OrderingFilterModel.objects.all().annotate( models.Count("relateds"))
view = OrderingListView.as_view()
- request = factory.get('/', {'ordering': 'relateds__count'})
- response = view(request)
- assert response.data == [
- {'id': 1, 'title': 'zyx', 'text': 'abc'},
- {'id': 3, 'title': 'xwv', 'text': 'cde'},
- {'id': 2, 'title': 'yxw', 'text': 'bcd'},
- ]
+request=factory.get('/',{'ordering':'relateds__count'})
+response=view(request)
+assertresponse.data==[{'id':1,'title':'zyx','text':'abc'},{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},]
def test_ordering_by_dotted_source(self):
- for index, obj in enumerate(OrderingFilterModel.objects.all()):
- OrderingFilterRelatedModel.objects.create(
- related_object=obj,
- index=index
- )
+ for index, obj in enumerate(OrderingFilterModel.objects.all()):
+ OrderingFilterRelatedModel.objects.create( related_object=obj, index=index )
class OrderingListView(generics.ListAPIView):
- serializer_class = OrderingDottedRelatedSerializer
- filter_backends = (filters.OrderingFilter,)
- queryset = OrderingFilterRelatedModel.objects.all()
+ serializer_class = OrderingDottedRelatedSerializer
+ filter_backends = (filters.OrderingFilter,)
+ queryset = OrderingFilterRelatedModel.objects.all()
view = OrderingListView.as_view()
- request = factory.get('/', {'ordering': 'related_object__text'})
- response = view(request)
- assert response.data == [
- {'related_title': 'zyx', 'related_text': 'abc', 'index': 0},
- {'related_title': 'yxw', 'related_text': 'bcd', 'index': 1},
- {'related_title': 'xwv', 'related_text': 'cde', 'index': 2},
- ]
-
- request = factory.get('/', {'ordering': '-index'})
- response = view(request)
- assert response.data == [
- {'related_title': 'xwv', 'related_text': 'cde', 'index': 2},
- {'related_title': 'yxw', 'related_text': 'bcd', 'index': 1},
- {'related_title': 'zyx', 'related_text': 'abc', 'index': 0},
- ]
+request=factory.get('/',{'ordering':'related_object__text'})
+response=view(request)
+assertresponse.data==[{'related_title':'zyx','related_text':'abc','index':0},{'related_title':'yxw','related_text':'bcd','index':1},{'related_title':'xwv','related_text':'cde','index':2},]
+request=factory.get('/',{'ordering':'-index'})
+response=view(request)
+assertresponse.data==[{'related_title':'xwv','related_text':'cde','index':2},{'related_title':'yxw','related_text':'bcd','index':1},{'related_title':'zyx','related_text':'abc','index':0},]
def test_ordering_with_nonstandard_ordering_param(self):
- with override_settings(REST_FRAMEWORK={'ORDERING_PARAM': 'order'}):
- reload_module(filters)
-
- class OrderingListView(generics.ListAPIView):
- queryset = OrderingFilterModel.objects.all()
- serializer_class = OrderingFilterSerializer
- filter_backends = (filters.OrderingFilter,)
- ordering = ('title',)
- ordering_fields = ('text',)
+ with override_settings(REST_FRAMEWORK={'ORDERING_PARAM': 'order'}):
+ reload_module(filters)
+ class OrderingListView(generics.ListAPIView):
+ queryset = OrderingFilterModel.objects.all()
+ serializer_class = OrderingFilterSerializer
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+ ordering_fields = ('text',)
view = OrderingListView.as_view()
- request = factory.get('/', {'order': 'text'})
- response = view(request)
- assert response.data == [
- {'id': 1, 'title': 'zyx', 'text': 'abc'},
- {'id': 2, 'title': 'yxw', 'text': 'bcd'},
- {'id': 3, 'title': 'xwv', 'text': 'cde'},
- ]
+request=factory.get('/',{'order':'text'})
+response=view(request)
+assertresponse.data==[{'id':1,'title':'zyx','text':'abc'},{'id':2,'title':'yxw','text':'bcd'},{'id':3,'title':'xwv','text':'cde'},]
reload_module(filters)
def test_get_template_context(self):
- class OrderingListView(generics.ListAPIView):
- ordering_fields = '__all__'
- serializer_class = OrderingFilterSerializer
- queryset = OrderingFilterModel.objects.all()
- filter_backends = (filters.OrderingFilter,)
+ class OrderingListView(generics.ListAPIView):
+ ordering_fields = '__all__'
+ serializer_class = OrderingFilterSerializer
+ queryset = OrderingFilterModel.objects.all()
+ filter_backends = (filters.OrderingFilter,)
request = factory.get('/', {'ordering': 'title'}, HTTP_ACCEPT='text/html')
- view = OrderingListView.as_view()
- response = view(request)
-
- self.assertContains(response, 'verbose title')
+view=OrderingListView.as_view()
+response=view(request)
+self.assertContains(response,'verbose title')
def test_ordering_with_overridden_get_serializer_class(self):
- class OrderingListView(generics.ListAPIView):
- queryset = OrderingFilterModel.objects.all()
- filter_backends = (filters.OrderingFilter,)
- ordering = ('title',)
-
- # note: no ordering_fields and serializer_class specified
-
- def get_serializer_class(self):
- return OrderingFilterSerializer
+ class OrderingListView(generics.ListAPIView):
+ queryset = OrderingFilterModel.objects.all()
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+ def get_serializer_class(self):
+ return OrderingFilterSerializer
view = OrderingListView.as_view()
- request = factory.get('/', {'ordering': 'text'})
- response = view(request)
- assert response.data == [
- {'id': 1, 'title': 'zyx', 'text': 'abc'},
- {'id': 2, 'title': 'yxw', 'text': 'bcd'},
- {'id': 3, 'title': 'xwv', 'text': 'cde'},
- ]
+request=factory.get('/',{'ordering':'text'})
+response=view(request)
+assertresponse.data==[{'id':1,'title':'zyx','text':'abc'},{'id':2,'title':'yxw','text':'bcd'},{'id':3,'title':'xwv','text':'cde'},]
def test_ordering_with_improper_configuration(self):
- class OrderingListView(generics.ListAPIView):
- queryset = OrderingFilterModel.objects.all()
- filter_backends = (filters.OrderingFilter,)
- ordering = ('title',)
+ class OrderingListView(generics.ListAPIView):
+ queryset = OrderingFilterModel.objects.all()
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
# note: no ordering_fields and serializer_class
# or get_serializer_class specified
view = OrderingListView.as_view()
- request = factory.get('/', {'ordering': 'text'})
- with self.assertRaises(ImproperlyConfigured):
- view(request)
+request=factory.get('/',{'ordering':'text'})
+withpytest.raises(ImproperlyConfigured):
+ view(request)
class SensitiveOrderingFilterModel(models.Model):
@@ -684,63 +561,44 @@ class SensitiveDataSerializer3(serializers.ModelSerializer):
fields = ('id', 'user')
-class SensitiveOrderingFilterTests(TestCase):
- def setUp(self):
- for idx in range(3):
- username = {0: 'userA', 1: 'userB', 2: 'userC'}[idx]
- password = {0: 'passA', 1: 'passC', 2: 'passB'}[idx]
- SensitiveOrderingFilterModel(username=username, password=password).save()
+def setUp(self):
+ for idx in range(3):
+ username = {0: 'userA', 1: 'userB', 2: 'userC'}[idx]
+ password = {0: 'passA', 1: 'passC', 2: 'passB'}[idx]
+ SensitiveOrderingFilterModel(username=username, password=password).save()
def test_order_by_serializer_fields(self):
- for serializer_cls in [
- SensitiveDataSerializer1,
- SensitiveDataSerializer2,
- SensitiveDataSerializer3
- ]:
- class OrderingListView(generics.ListAPIView):
- queryset = SensitiveOrderingFilterModel.objects.all().order_by('username')
- filter_backends = (filters.OrderingFilter,)
- serializer_class = serializer_cls
+ for serializer_cls in [ SensitiveDataSerializer1, SensitiveDataSerializer2, SensitiveDataSerializer3 ]:
+ class OrderingListView(generics.ListAPIView):
+ queryset = SensitiveOrderingFilterModel.objects.all().order_by('username')
+ filter_backends = (filters.OrderingFilter,)
+ serializer_class = serializer_cls
view = OrderingListView.as_view()
- request = factory.get('/', {'ordering': '-username'})
- response = view(request)
-
- if serializer_cls == SensitiveDataSerializer3:
- username_field = 'user'
+request=factory.get('/',{'ordering':'-username'})
+response=view(request)
+ifserializer_cls==SensitiveDataSerializer3:
+ username_field = 'user'
else:
- username_field = 'username'
+ username_field = 'username'
# Note: Inverse username ordering correctly applied.
- assert response.data == [
- {'id': 3, username_field: 'userC'},
- {'id': 2, username_field: 'userB'},
- {'id': 1, username_field: 'userA'},
- ]
+ assert response.data == [{'id':3,username_field:'userC'},{'id':2,username_field:'userB'},{'id':1,username_field:'userA'},]
def test_cannot_order_by_non_serializer_fields(self):
- for serializer_cls in [
- SensitiveDataSerializer1,
- SensitiveDataSerializer2,
- SensitiveDataSerializer3
- ]:
- class OrderingListView(generics.ListAPIView):
- queryset = SensitiveOrderingFilterModel.objects.all().order_by('username')
- filter_backends = (filters.OrderingFilter,)
- serializer_class = serializer_cls
+ for serializer_cls in [ SensitiveDataSerializer1, SensitiveDataSerializer2, SensitiveDataSerializer3 ]:
+ class OrderingListView(generics.ListAPIView):
+ queryset = SensitiveOrderingFilterModel.objects.all().order_by('username')
+ filter_backends = (filters.OrderingFilter,)
+ serializer_class = serializer_cls
view = OrderingListView.as_view()
- request = factory.get('/', {'ordering': 'password'})
- response = view(request)
-
- if serializer_cls == SensitiveDataSerializer3:
- username_field = 'user'
+request=factory.get('/',{'ordering':'password'})
+response=view(request)
+ifserializer_cls==SensitiveDataSerializer3:
+ username_field = 'user'
else:
- username_field = 'username'
+ username_field = 'username'
# Note: The passwords are not in order. Default ordering is used.
- assert response.data == [
- {'id': 1, username_field: 'userA'}, # PassB
- {'id': 2, username_field: 'userB'}, # PassC
- {'id': 3, username_field: 'userC'}, # PassA
- ]
+ assert response.data == [{'id':1,username_field:'userA'},{'id':2,username_field:'userB'},{'id':3,username_field:'userC'},]
diff --git a/tests/test_generateschema.py b/tests/test_generateschema.py
index a6a1f2bed..02180f681 100644
--- a/tests/test_generateschema.py
+++ b/tests/test_generateschema.py
@@ -23,14 +23,12 @@ urlpatterns = [
@override_settings(ROOT_URLCONF='tests.test_generateschema')
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
-class GenerateSchemaTests(TestCase):
- """Tests for management command generateschema."""
-
- def setUp(self):
- self.out = io.StringIO()
+"""Tests for management command generateschema."""
+defsetUp(self):
+ self.out = io.StringIO()
def test_renders_default_schema_with_custom_title_url_and_description(self):
- expected_out = """info:
+ expected_out = """info:
description: Sample description
title: SampleAPI
version: ''
@@ -42,45 +40,16 @@ class GenerateSchemaTests(TestCase):
servers:
- url: http://api.sample.com/
"""
- call_command('generateschema',
- '--title=SampleAPI',
- '--url=http://api.sample.com',
- '--description=Sample description',
- stdout=self.out)
-
- self.assertIn(formatting.dedent(expected_out), self.out.getvalue())
+ call_command('generateschema', '--title=SampleAPI', '--url=http://api.sample.com', '--description=Sample description', stdout=self.out)
+ assert formatting.dedent(expected_out) in self.out.getvalue()
def test_renders_openapi_json_schema(self):
- expected_out = {
- "openapi": "3.0.0",
- "info": {
- "version": "",
- "title": "",
- "description": ""
- },
- "servers": [
- {
- "url": ""
- }
- ],
- "paths": {
- "/": {
- "get": {
- "operationId": "list"
- }
- }
- }
- }
- call_command('generateschema',
- '--format=openapi-json',
- stdout=self.out)
- out_json = json.loads(self.out.getvalue())
-
- self.assertDictEqual(out_json, expected_out)
+ expected_out = { "openapi": "3.0.0", "info": { "version": "", "title": "", "description": "" }, "servers": [ { "url": "" } ], "paths": { "/": { "get": { "operationId": "list" } } } }
+ call_command('generateschema', '--format=openapi-json', stdout=self.out)
+ out_json = json.loads(self.out.getvalue())
+ assert out_json == expected_out
def test_renders_corejson_schema(self):
- expected_out = """{"_type":"document","":{"list":{"_type":"link","url":"/","action":"get"}}}"""
- call_command('generateschema',
- '--format=corejson',
- stdout=self.out)
- self.assertIn(expected_out, self.out.getvalue())
+ expected_out = """{"_type":"document","":{"list":{"_type":"link","url":"/","action":"get"}}}"""
+ call_command('generateschema', '--format=corejson', stdout=self.out)
+ assert expected_out in self.out.getvalue()
diff --git a/tests/test_generics.py b/tests/test_generics.py
index 0b91e3465..48a65cd5d 100644
--- a/tests/test_generics.py
+++ b/tests/test_generics.py
@@ -75,304 +75,282 @@ class SlugBasedInstanceView(InstanceView):
# Tests
-class TestRootView(TestCase):
- def setUp(self):
- """
+def setUp(self):
+ """
Create 3 BasicModel instances.
"""
- items = ['foo', 'bar', 'baz']
- for item in items:
- BasicModel(text=item).save()
+ items = ['foo', 'bar', 'baz']
+ for item in items:
+ BasicModel(text=item).save()
self.objects = BasicModel.objects
- self.data = [
- {'id': obj.id, 'text': obj.text}
- for obj in self.objects.all()
- ]
- self.view = RootView.as_view()
+self.data=[{'id':obj.id,'text':obj.text}forobjinself.objects.all()]
+self.view=RootView.as_view()
def test_get_root_view(self):
- """
+ """
GET requests to ListCreateAPIView should return list of objects.
"""
- request = factory.get('/')
- with self.assertNumQueries(1):
- response = self.view(request).render()
+ request = factory.get('/')
+ with self.assertNumQueries(1):
+ response = self.view(request).render()
assert response.status_code == status.HTTP_200_OK
- assert response.data == self.data
+assertresponse.data==self.data
def test_head_root_view(self):
- """
+ """
HEAD requests to ListCreateAPIView should return 200.
"""
- request = factory.head('/')
- with self.assertNumQueries(1):
- response = self.view(request).render()
+ request = factory.head('/')
+ with self.assertNumQueries(1):
+ response = self.view(request).render()
assert response.status_code == status.HTTP_200_OK
def test_post_root_view(self):
- """
+ """
POST requests to ListCreateAPIView should create a new object.
"""
- data = {'text': 'foobar'}
- request = factory.post('/', data, format='json')
- with self.assertNumQueries(1):
- response = self.view(request).render()
+ data = {'text': 'foobar'}
+ request = factory.post('/', data, format='json')
+ with self.assertNumQueries(1):
+ response = self.view(request).render()
assert response.status_code == status.HTTP_201_CREATED
- assert response.data == {'id': 4, 'text': 'foobar'}
- created = self.objects.get(id=4)
- assert created.text == 'foobar'
+assertresponse.data=={'id':4,'text':'foobar'}
+created=self.objects.get(id=4)
+assertcreated.text=='foobar'
def test_put_root_view(self):
- """
+ """
PUT requests to ListCreateAPIView should not be allowed
"""
- data = {'text': 'foobar'}
- request = factory.put('/', data, format='json')
- with self.assertNumQueries(0):
- response = self.view(request).render()
+ data = {'text': 'foobar'}
+ request = factory.put('/', data, format='json')
+ with self.assertNumQueries(0):
+ response = self.view(request).render()
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
- assert response.data == {"detail": 'Method "PUT" not allowed.'}
+assertresponse.data=={"detail":'Method "PUT" not allowed.'}
def test_delete_root_view(self):
- """
+ """
DELETE requests to ListCreateAPIView should not be allowed
"""
- request = factory.delete('/')
- with self.assertNumQueries(0):
- response = self.view(request).render()
+ request = factory.delete('/')
+ with self.assertNumQueries(0):
+ response = self.view(request).render()
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
- assert response.data == {"detail": 'Method "DELETE" not allowed.'}
+assertresponse.data=={"detail":'Method "DELETE" not allowed.'}
def test_post_cannot_set_id(self):
- """
+ """
POST requests to create a new object should not be able to set the id.
"""
- data = {'id': 999, 'text': 'foobar'}
- request = factory.post('/', data, format='json')
- with self.assertNumQueries(1):
- response = self.view(request).render()
+ data = {'id': 999, 'text': 'foobar'}
+ request = factory.post('/', data, format='json')
+ with self.assertNumQueries(1):
+ response = self.view(request).render()
assert response.status_code == status.HTTP_201_CREATED
- assert response.data == {'id': 4, 'text': 'foobar'}
- created = self.objects.get(id=4)
- assert created.text == 'foobar'
+assertresponse.data=={'id':4,'text':'foobar'}
+created=self.objects.get(id=4)
+assertcreated.text=='foobar'
def test_post_error_root_view(self):
- """
+ """
POST requests to ListCreateAPIView in HTML should include a form error.
"""
- data = {'text': 'foobar' * 100}
- request = factory.post('/', data, HTTP_ACCEPT='text/html')
- response = self.view(request).render()
- expected_error = 'Ensure this field has no more than 100 characters.'
- assert expected_error in response.rendered_content.decode()
+ data = {'text': 'foobar' * 100}
+ request = factory.post('/', data, HTTP_ACCEPT='text/html')
+ response = self.view(request).render()
+ expected_error = 'Ensure this field has no more than 100 characters.'
+ assert expected_error in response.rendered_content.decode()
EXPECTED_QUERIES_FOR_PUT = 2
-
-
-class TestInstanceView(TestCase):
- def setUp(self):
- """
+def setUp(self):
+ """
Create 3 BasicModel instances.
"""
- items = ['foo', 'bar', 'baz', 'filtered out']
- for item in items:
- BasicModel(text=item).save()
+ items = ['foo', 'bar', 'baz', 'filtered out']
+ for item in items:
+ BasicModel(text=item).save()
self.objects = BasicModel.objects.exclude(text='filtered out')
- self.data = [
- {'id': obj.id, 'text': obj.text}
- for obj in self.objects.all()
- ]
- self.view = InstanceView.as_view()
- self.slug_based_view = SlugBasedInstanceView.as_view()
+self.data=[{'id':obj.id,'text':obj.text}forobjinself.objects.all()]
+self.view=InstanceView.as_view()
+self.slug_based_view=SlugBasedInstanceView.as_view()
def test_get_instance_view(self):
- """
+ """
GET requests to RetrieveUpdateDestroyAPIView should return a single object.
"""
- request = factory.get('/1')
- with self.assertNumQueries(1):
- response = self.view(request, pk=1).render()
+ request = factory.get('/1')
+ with self.assertNumQueries(1):
+ response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK
- assert response.data == self.data[0]
+assertresponse.data==self.data[0]
def test_post_instance_view(self):
- """
+ """
POST requests to RetrieveUpdateDestroyAPIView should not be allowed
"""
- data = {'text': 'foobar'}
- request = factory.post('/', data, format='json')
- with self.assertNumQueries(0):
- response = self.view(request).render()
+ data = {'text': 'foobar'}
+ request = factory.post('/', data, format='json')
+ with self.assertNumQueries(0):
+ response = self.view(request).render()
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
- assert response.data == {"detail": 'Method "POST" not allowed.'}
+assertresponse.data=={"detail":'Method "POST" not allowed.'}
def test_put_instance_view(self):
- """
+ """
PUT requests to RetrieveUpdateDestroyAPIView should update an object.
"""
- data = {'text': 'foobar'}
- request = factory.put('/1', data, format='json')
- with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
- response = self.view(request, pk='1').render()
+ data = {'text': 'foobar'}
+ request = factory.put('/1', data, format='json')
+ with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
+ response = self.view(request, pk='1').render()
assert response.status_code == status.HTTP_200_OK
- assert dict(response.data) == {'id': 1, 'text': 'foobar'}
- updated = self.objects.get(id=1)
- assert updated.text == 'foobar'
+assertdict(response.data)=={'id':1,'text':'foobar'}
+updated=self.objects.get(id=1)
+assertupdated.text=='foobar'
def test_patch_instance_view(self):
- """
+ """
PATCH requests to RetrieveUpdateDestroyAPIView should update an object.
"""
- data = {'text': 'foobar'}
- request = factory.patch('/1', data, format='json')
-
- with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
- response = self.view(request, pk=1).render()
+ data = {'text': 'foobar'}
+ request = factory.patch('/1', data, format='json')
+ with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
+ response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK
- assert response.data == {'id': 1, 'text': 'foobar'}
- updated = self.objects.get(id=1)
- assert updated.text == 'foobar'
+assertresponse.data=={'id':1,'text':'foobar'}
+updated=self.objects.get(id=1)
+assertupdated.text=='foobar'
def test_delete_instance_view(self):
- """
+ """
DELETE requests to RetrieveUpdateDestroyAPIView should delete an object.
"""
- request = factory.delete('/1')
- with self.assertNumQueries(2):
- response = self.view(request, pk=1).render()
+ request = factory.delete('/1')
+ with self.assertNumQueries(2):
+ response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_204_NO_CONTENT
- assert response.content == b''
- ids = [obj.id for obj in self.objects.all()]
- assert ids == [2, 3]
+assertresponse.content==b''
+ids=[obj.idforobjinself.objects.all()]
+assertids==[2,3]
def test_get_instance_view_incorrect_arg(self):
- """
+ """
GET requests with an incorrect pk type, should raise 404, not 500.
Regression test for #890.
"""
- request = factory.get('/a')
- with self.assertNumQueries(0):
- response = self.view(request, pk='a').render()
+ request = factory.get('/a')
+ with self.assertNumQueries(0):
+ response = self.view(request, pk='a').render()
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_put_cannot_set_id(self):
- """
+ """
PUT requests to create a new object should not be able to set the id.
"""
- data = {'id': 999, 'text': 'foobar'}
- request = factory.put('/1', data, format='json')
- with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
- response = self.view(request, pk=1).render()
+ data = {'id': 999, 'text': 'foobar'}
+ request = factory.put('/1', data, format='json')
+ with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
+ response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK
- assert response.data == {'id': 1, 'text': 'foobar'}
- updated = self.objects.get(id=1)
- assert updated.text == 'foobar'
+assertresponse.data=={'id':1,'text':'foobar'}
+updated=self.objects.get(id=1)
+assertupdated.text=='foobar'
def test_put_to_deleted_instance(self):
- """
+ """
PUT requests to RetrieveUpdateDestroyAPIView should return 404 if
an object does not currently exist.
"""
- self.objects.get(id=1).delete()
- data = {'text': 'foobar'}
- request = factory.put('/1', data, format='json')
- with self.assertNumQueries(1):
- response = self.view(request, pk=1).render()
+ self.objects.get(id=1).delete()
+ data = {'text': 'foobar'}
+ request = factory.put('/1', data, format='json')
+ with self.assertNumQueries(1):
+ response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_put_to_filtered_out_instance(self):
- """
+ """
PUT requests to an URL of instance which is filtered out should not be
able to create new objects.
"""
- data = {'text': 'foo'}
- filtered_out_pk = BasicModel.objects.filter(text='filtered out')[0].pk
- request = factory.put('/{}'.format(filtered_out_pk), data, format='json')
- response = self.view(request, pk=filtered_out_pk).render()
- assert response.status_code == status.HTTP_404_NOT_FOUND
+ data = {'text': 'foo'}
+ filtered_out_pk = BasicModel.objects.filter(text='filtered out')[0].pk
+ request = factory.put('/{}'.format(filtered_out_pk), data, format='json')
+ response = self.view(request, pk=filtered_out_pk).render()
+ assert response.status_code == status.HTTP_404_NOT_FOUND
def test_patch_cannot_create_an_object(self):
- """
+ """
PATCH requests should not be able to create objects.
"""
- data = {'text': 'foobar'}
- request = factory.patch('/999', data, format='json')
- with self.assertNumQueries(1):
- response = self.view(request, pk=999).render()
+ data = {'text': 'foobar'}
+ request = factory.patch('/999', data, format='json')
+ with self.assertNumQueries(1):
+ response = self.view(request, pk=999).render()
assert response.status_code == status.HTTP_404_NOT_FOUND
- assert not self.objects.filter(id=999).exists()
+assertnotself.objects.filter(id=999).exists()
def test_put_error_instance_view(self):
- """
+ """
Incorrect PUT requests in HTML should include a form error.
"""
- data = {'text': 'foobar' * 100}
- request = factory.put('/', data, HTTP_ACCEPT='text/html')
- response = self.view(request, pk=1).render()
- expected_error = 'Ensure this field has no more than 100 characters.'
- assert expected_error in response.rendered_content.decode()
+ data = {'text': 'foobar' * 100}
+ request = factory.put('/', data, HTTP_ACCEPT='text/html')
+ response = self.view(request, pk=1).render()
+ expected_error = 'Ensure this field has no more than 100 characters.'
+ assert expected_error in response.rendered_content.decode()
-class TestFKInstanceView(TestCase):
- def setUp(self):
- """
+def setUp(self):
+ """
Create 3 BasicModel instances.
"""
- items = ['foo', 'bar', 'baz']
- for item in items:
- t = ForeignKeyTarget(name=item)
- t.save()
- ForeignKeySource(name='source_' + item, target=t).save()
+ items = ['foo', 'bar', 'baz']
+ for item in items:
+ t = ForeignKeyTarget(name=item)
+ t.save()
+ ForeignKeySource(name='source_' + item, target=t).save()
self.objects = ForeignKeySource.objects
- self.data = [
- {'id': obj.id, 'name': obj.name}
- for obj in self.objects.all()
- ]
- self.view = FKInstanceView.as_view()
+self.data=[{'id':obj.id,'name':obj.name}forobjinself.objects.all()]
+self.view=FKInstanceView.as_view()
-class TestOverriddenGetObject(TestCase):
- """
+"""
Test cases for a RetrieveUpdateDestroyAPIView that does NOT use the
queryset/model mechanism but instead overrides get_object()
"""
-
- def setUp(self):
- """
+defsetUp(self):
+ """
Create 3 BasicModel instances.
"""
- items = ['foo', 'bar', 'baz']
- for item in items:
- BasicModel(text=item).save()
+ items = ['foo', 'bar', 'baz']
+ for item in items:
+ BasicModel(text=item).save()
self.objects = BasicModel.objects
- self.data = [
- {'id': obj.id, 'text': obj.text}
- for obj in self.objects.all()
- ]
-
- class OverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView):
- """
+self.data=[{'id':obj.id,'text':obj.text}forobjinself.objects.all()]
+classOverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView):
+ """
Example detail view for override of get_object().
"""
- serializer_class = BasicSerializer
-
- def get_object(self):
- pk = int(self.kwargs['pk'])
- return get_object_or_404(BasicModel.objects.all(), id=pk)
+ serializer_class = BasicSerializer
+ def get_object(self):
+ pk = int(self.kwargs['pk'])
+ return get_object_or_404(BasicModel.objects.all(), id=pk)
self.view = OverriddenGetObjectView.as_view()
def test_overridden_get_object_view(self):
- """
+ """
GET requests to RetrieveUpdateDestroyAPIView should return a single object.
"""
- request = factory.get('/1')
- with self.assertNumQueries(1):
- response = self.view(request, pk=1).render()
+ request = factory.get('/1')
+ with self.assertNumQueries(1):
+ response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK
- assert response.data == self.data[0]
+assertresponse.data==self.data[0]
# Regression test for #285
@@ -388,23 +366,22 @@ class CommentView(generics.ListCreateAPIView):
model = Comment
-class TestCreateModelWithAutoNowAddField(TestCase):
- def setUp(self):
- self.objects = Comment.objects
- self.view = CommentView.as_view()
+def setUp(self):
+ self.objects = Comment.objects
+ self.view = CommentView.as_view()
def test_create_model_with_auto_now_add_field(self):
- """
+ """
Regression test for #285
https://github.com/encode/django-rest-framework/issues/285
"""
- data = {'email': 'foobar@example.com', 'content': 'foobar'}
- request = factory.post('/', data, format='json')
- response = self.view(request).render()
- assert response.status_code == status.HTTP_201_CREATED
- created = self.objects.get(id=1)
- assert created.content == 'foobar'
+ data = {'email': 'foobar@example.com', 'content': 'foobar'}
+ request = factory.post('/', data, format='json')
+ response = self.view(request).render()
+ assert response.status_code == status.HTTP_201_CREATED
+ created = self.objects.get(id=1)
+ assert created.content == 'foobar'
# Test for particularly ugly regression with m2m in browsable API
@@ -432,15 +409,14 @@ class ExampleView(generics.ListCreateAPIView):
queryset = ClassA.objects.all()
-class TestM2MBrowsableAPI(TestCase):
- def test_m2m_in_browsable_api(self):
- """
+def test_m2m_in_browsable_api(self):
+ """
Test for particularly ugly regression with m2m in browsable API
"""
- request = factory.get('/', HTTP_ACCEPT='text/html')
- view = ExampleView().as_view()
- response = view(request).render()
- assert response.status_code == status.HTTP_200_OK
+ request = factory.get('/', HTTP_ACCEPT='text/html')
+ view = ExampleView().as_view()
+ response = view(request).render()
+ assert response.status_code == status.HTTP_200_OK
class InclusiveFilterBackend:
@@ -476,189 +452,179 @@ class DynamicSerializerView(generics.ListCreateAPIView):
return DynamicSerializer
-class TestFilterBackendAppliedToViews(TestCase):
- def setUp(self):
- """
+def setUp(self):
+ """
Create 3 BasicModel instances to filter on.
"""
- items = ['foo', 'bar', 'baz']
- for item in items:
- BasicModel(text=item).save()
+ items = ['foo', 'bar', 'baz']
+ for item in items:
+ BasicModel(text=item).save()
self.objects = BasicModel.objects
- self.data = [
- {'id': obj.id, 'text': obj.text}
- for obj in self.objects.all()
- ]
+self.data=[{'id':obj.id,'text':obj.text}forobjinself.objects.all()]
def test_get_root_view_filters_by_name_with_filter_backend(self):
- """
+ """
GET requests to ListCreateAPIView should return filtered list.
"""
- root_view = RootView.as_view(filter_backends=(InclusiveFilterBackend,))
- request = factory.get('/')
- response = root_view(request).render()
- assert response.status_code == status.HTTP_200_OK
- assert len(response.data) == 1
- assert response.data == [{'id': 1, 'text': 'foo'}]
+ root_view = RootView.as_view(filter_backends=(InclusiveFilterBackend,))
+ request = factory.get('/')
+ response = root_view(request).render()
+ assert response.status_code == status.HTTP_200_OK
+ assert len(response.data) == 1
+ assert response.data == [{'id': 1, 'text': 'foo'}]
def test_get_root_view_filters_out_all_models_with_exclusive_filter_backend(self):
- """
+ """
GET requests to ListCreateAPIView should return empty list when all models are filtered out.
"""
- root_view = RootView.as_view(filter_backends=(ExclusiveFilterBackend,))
- request = factory.get('/')
- response = root_view(request).render()
- assert response.status_code == status.HTTP_200_OK
- assert response.data == []
+ root_view = RootView.as_view(filter_backends=(ExclusiveFilterBackend,))
+ request = factory.get('/')
+ response = root_view(request).render()
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == []
def test_get_instance_view_filters_out_name_with_filter_backend(self):
- """
+ """
GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out.
"""
- instance_view = InstanceView.as_view(filter_backends=(ExclusiveFilterBackend,))
- request = factory.get('/1')
- response = instance_view(request, pk=1).render()
- assert response.status_code == status.HTTP_404_NOT_FOUND
- assert response.data == {'detail': 'Not found.'}
+ instance_view = InstanceView.as_view(filter_backends=(ExclusiveFilterBackend,))
+ request = factory.get('/1')
+ response = instance_view(request, pk=1).render()
+ assert response.status_code == status.HTTP_404_NOT_FOUND
+ assert response.data == {'detail': 'Not found.'}
def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self):
- """
+ """
GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded
"""
- instance_view = InstanceView.as_view(filter_backends=(InclusiveFilterBackend,))
- request = factory.get('/1')
- response = instance_view(request, pk=1).render()
- assert response.status_code == status.HTTP_200_OK
- assert response.data == {'id': 1, 'text': 'foo'}
+ instance_view = InstanceView.as_view(filter_backends=(InclusiveFilterBackend,))
+ request = factory.get('/1')
+ response = instance_view(request, pk=1).render()
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == {'id': 1, 'text': 'foo'}
def test_dynamic_serializer_form_in_browsable_api(self):
- """
+ """
GET requests to ListCreateAPIView should return filtered list.
"""
- view = DynamicSerializerView.as_view()
- request = factory.get('/')
- response = view(request).render()
- content = response.content.decode()
- assert 'field_b' in content
- assert 'field_a' not in content
+ view = DynamicSerializerView.as_view()
+ request = factory.get('/')
+ response = view(request).render()
+ content = response.content.decode()
+ assert 'field_b' in content
+ assert 'field_a' not in content
-class TestGuardedQueryset(TestCase):
- def test_guarded_queryset(self):
- class QuerysetAccessError(generics.ListAPIView):
- queryset = BasicModel.objects.all()
-
- def get(self, request):
- return Response(list(self.queryset))
+def test_guarded_queryset(self):
+ class QuerysetAccessError(generics.ListAPIView):
+ queryset = BasicModel.objects.all()
+ def get(self, request):
+ return Response(list(self.queryset))
view = QuerysetAccessError.as_view()
- request = factory.get('/')
- with pytest.raises(RuntimeError):
- view(request).render()
+request=factory.get('/')
+withpytest.raises(RuntimeError):
+ view(request).render()
-class ApiViewsTests(TestCase):
- def test_create_api_view_post(self):
- class MockCreateApiView(generics.CreateAPIView):
- def create(self, request, *args, **kwargs):
- self.called = True
- self.call_args = (request, args, kwargs)
+def test_create_api_view_post(self):
+ class MockCreateApiView(generics.CreateAPIView):
+ def create(self, request, *args, **kwargs):
+ self.called = True
+ self.call_args = (request, args, kwargs)
view = MockCreateApiView()
- data = ('test request', ('test arg',), {'test_kwarg': 'test'})
- view.post('test request', 'test arg', test_kwarg='test')
- assert view.called is True
- assert view.call_args == data
+data=('test request',('test arg',),{'test_kwarg':'test'})
+view.post('test request','test arg',test_kwarg='test')
+assertview.calledisTrue
+assertview.call_args==data
def test_destroy_api_view_delete(self):
- class MockDestroyApiView(generics.DestroyAPIView):
- def destroy(self, request, *args, **kwargs):
- self.called = True
- self.call_args = (request, args, kwargs)
+ class MockDestroyApiView(generics.DestroyAPIView):
+ def destroy(self, request, *args, **kwargs):
+ self.called = True
+ self.call_args = (request, args, kwargs)
view = MockDestroyApiView()
- data = ('test request', ('test arg',), {'test_kwarg': 'test'})
- view.delete('test request', 'test arg', test_kwarg='test')
- assert view.called is True
- assert view.call_args == data
+data=('test request',('test arg',),{'test_kwarg':'test'})
+view.delete('test request','test arg',test_kwarg='test')
+assertview.calledisTrue
+assertview.call_args==data
def test_update_api_view_partial_update(self):
- class MockUpdateApiView(generics.UpdateAPIView):
- def partial_update(self, request, *args, **kwargs):
- self.called = True
- self.call_args = (request, args, kwargs)
+ class MockUpdateApiView(generics.UpdateAPIView):
+ def partial_update(self, request, *args, **kwargs):
+ self.called = True
+ self.call_args = (request, args, kwargs)
view = MockUpdateApiView()
- data = ('test request', ('test arg',), {'test_kwarg': 'test'})
- view.patch('test request', 'test arg', test_kwarg='test')
- assert view.called is True
- assert view.call_args == data
+data=('test request',('test arg',),{'test_kwarg':'test'})
+view.patch('test request','test arg',test_kwarg='test')
+assertview.calledisTrue
+assertview.call_args==data
def test_retrieve_update_api_view_get(self):
- class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView):
- def retrieve(self, request, *args, **kwargs):
- self.called = True
- self.call_args = (request, args, kwargs)
+ class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView):
+ def retrieve(self, request, *args, **kwargs):
+ self.called = True
+ self.call_args = (request, args, kwargs)
view = MockRetrieveUpdateApiView()
- data = ('test request', ('test arg',), {'test_kwarg': 'test'})
- view.get('test request', 'test arg', test_kwarg='test')
- assert view.called is True
- assert view.call_args == data
+data=('test request',('test arg',),{'test_kwarg':'test'})
+view.get('test request','test arg',test_kwarg='test')
+assertview.calledisTrue
+assertview.call_args==data
def test_retrieve_update_api_view_put(self):
- class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView):
- def update(self, request, *args, **kwargs):
- self.called = True
- self.call_args = (request, args, kwargs)
+ class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView):
+ def update(self, request, *args, **kwargs):
+ self.called = True
+ self.call_args = (request, args, kwargs)
view = MockRetrieveUpdateApiView()
- data = ('test request', ('test arg',), {'test_kwarg': 'test'})
- view.put('test request', 'test arg', test_kwarg='test')
- assert view.called is True
- assert view.call_args == data
+data=('test request',('test arg',),{'test_kwarg':'test'})
+view.put('test request','test arg',test_kwarg='test')
+assertview.calledisTrue
+assertview.call_args==data
def test_retrieve_update_api_view_patch(self):
- class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView):
- def partial_update(self, request, *args, **kwargs):
- self.called = True
- self.call_args = (request, args, kwargs)
+ class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView):
+ def partial_update(self, request, *args, **kwargs):
+ self.called = True
+ self.call_args = (request, args, kwargs)
view = MockRetrieveUpdateApiView()
- data = ('test request', ('test arg',), {'test_kwarg': 'test'})
- view.patch('test request', 'test arg', test_kwarg='test')
- assert view.called is True
- assert view.call_args == data
+data=('test request',('test arg',),{'test_kwarg':'test'})
+view.patch('test request','test arg',test_kwarg='test')
+assertview.calledisTrue
+assertview.call_args==data
def test_retrieve_destroy_api_view_get(self):
- class MockRetrieveDestroyUApiView(generics.RetrieveDestroyAPIView):
- def retrieve(self, request, *args, **kwargs):
- self.called = True
- self.call_args = (request, args, kwargs)
+ class MockRetrieveDestroyUApiView(generics.RetrieveDestroyAPIView):
+ def retrieve(self, request, *args, **kwargs):
+ self.called = True
+ self.call_args = (request, args, kwargs)
view = MockRetrieveDestroyUApiView()
- data = ('test request', ('test arg',), {'test_kwarg': 'test'})
- view.get('test request', 'test arg', test_kwarg='test')
- assert view.called is True
- assert view.call_args == data
+data=('test request',('test arg',),{'test_kwarg':'test'})
+view.get('test request','test arg',test_kwarg='test')
+assertview.calledisTrue
+assertview.call_args==data
def test_retrieve_destroy_api_view_delete(self):
- class MockRetrieveDestroyUApiView(generics.RetrieveDestroyAPIView):
- def destroy(self, request, *args, **kwargs):
- self.called = True
- self.call_args = (request, args, kwargs)
+ class MockRetrieveDestroyUApiView(generics.RetrieveDestroyAPIView):
+ def destroy(self, request, *args, **kwargs):
+ self.called = True
+ self.call_args = (request, args, kwargs)
view = MockRetrieveDestroyUApiView()
- data = ('test request', ('test arg',), {'test_kwarg': 'test'})
- view.delete('test request', 'test arg', test_kwarg='test')
- assert view.called is True
- assert view.call_args == data
+data=('test request',('test arg',),{'test_kwarg':'test'})
+view.delete('test request','test arg',test_kwarg='test')
+assertview.calledisTrue
+assertview.call_args==data
-class GetObjectOr404Tests(TestCase):
- def setUp(self):
- super().setUp()
- self.uuid_object = UUIDForeignKeyTarget.objects.create(name='bar')
+def setUp(self):
+ super().setUp()
+ self.uuid_object = UUIDForeignKeyTarget.objects.create(name='bar')
def test_get_object_or_404_with_valid_uuid(self):
- obj = generics.get_object_or_404(
- UUIDForeignKeyTarget, pk=self.uuid_object.pk
- )
- assert obj == self.uuid_object
+ obj = generics.get_object_or_404( UUIDForeignKeyTarget, pk=self.uuid_object.pk )
+ assert obj == self.uuid_object
def test_get_object_or_404_with_invalid_string_for_uuid(self):
- with pytest.raises(Http404):
- generics.get_object_or_404(UUIDForeignKeyTarget, pk='not-a-uuid')
+ with pytest.raises(Http404):
+ generics.get_object_or_404(UUIDForeignKeyTarget, pk='not-a-uuid')
diff --git a/tests/test_htmlrenderer.py b/tests/test_htmlrenderer.py
index e31a9ced5..ba4e112c9 100644
--- a/tests/test_htmlrenderer.py
+++ b/tests/test_htmlrenderer.py
@@ -42,122 +42,113 @@ urlpatterns = [
@override_settings(ROOT_URLCONF='tests.test_htmlrenderer')
-class TemplateHTMLRendererTests(TestCase):
- def setUp(self):
- class MockResponse:
- template_name = None
+def setUp(self):
+ class MockResponse:
+ template_name = None
self.mock_response = MockResponse()
- self._monkey_patch_get_template()
+self._monkey_patch_get_template()
def _monkey_patch_get_template(self):
- """
+ """
Monkeypatch get_template
"""
- self.get_template = django.template.loader.get_template
-
- def get_template(template_name, dirs=None):
- if template_name == 'example.html':
- return engines['django'].from_string("example: {{ object }}")
+ self.get_template = django.template.loader.get_template
+ def get_template(template_name, dirs=None):
+ if template_name == 'example.html':
+ return engines['django'].from_string("example: {{ object }}")
raise TemplateDoesNotExist(template_name)
def select_template(template_name_list, dirs=None, using=None):
- if template_name_list == ['example.html']:
- return engines['django'].from_string("example: {{ object }}")
+ if template_name_list == ['example.html']:
+ return engines['django'].from_string("example: {{ object }}")
raise TemplateDoesNotExist(template_name_list[0])
django.template.loader.get_template = get_template
- django.template.loader.select_template = select_template
+django.template.loader.select_template=select_template
def tearDown(self):
- """
+ """
Revert monkeypatching
"""
- django.template.loader.get_template = self.get_template
+ django.template.loader.get_template = self.get_template
def test_simple_html_view(self):
- response = self.client.get('/')
- self.assertContains(response, "example: foobar")
- self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
+ response = self.client.get('/')
+ self.assertContains(response, "example: foobar")
+ assert response['Content-Type'] == 'text/html; charset=utf-8'
def test_not_found_html_view(self):
- response = self.client.get('/not_found')
- self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
- self.assertEqual(response.content, b"404 Not Found")
- self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
+ response = self.client.get('/not_found')
+ assert response.status_code == status.HTTP_404_NOT_FOUND
+ assert response.content == b"404 Not Found"
+ assert response['Content-Type'] == 'text/html; charset=utf-8'
def test_permission_denied_html_view(self):
- response = self.client.get('/permission_denied')
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
- self.assertEqual(response.content, b"403 Forbidden")
- self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
+ response = self.client.get('/permission_denied')
+ assert response.status_code == status.HTTP_403_FORBIDDEN
+ assert response.content == b"403 Forbidden"
+ assert response['Content-Type'] == 'text/html; charset=utf-8'
# 2 tests below are based on order of if statements in corresponding method
# of TemplateHTMLRenderer
def test_get_template_names_returns_own_template_name(self):
- renderer = TemplateHTMLRenderer()
- renderer.template_name = 'test_template'
- template_name = renderer.get_template_names(self.mock_response, view={})
- assert template_name == ['test_template']
+ renderer = TemplateHTMLRenderer()
+ renderer.template_name = 'test_template'
+ template_name = renderer.get_template_names(self.mock_response, view={})
+ assert template_name == ['test_template']
def test_get_template_names_returns_view_template_name(self):
- renderer = TemplateHTMLRenderer()
-
- class MockResponse:
- template_name = None
+ renderer = TemplateHTMLRenderer()
+ class MockResponse:
+ template_name = None
class MockView:
- def get_template_names(self):
- return ['template from get_template_names method']
+ def get_template_names(self):
+ return ['template from get_template_names method']
class MockView2:
- template_name = 'template from template_name attribute'
+ template_name = 'template from template_name attribute'
- template_name = renderer.get_template_names(self.mock_response,
- MockView())
- assert template_name == ['template from get_template_names method']
-
- template_name = renderer.get_template_names(self.mock_response,
- MockView2())
- assert template_name == ['template from template_name attribute']
+ template_name = renderer.get_template_names(self.mock_response,MockView())
+asserttemplate_name==['template from get_template_names method']
+template_name=renderer.get_template_names(self.mock_response,MockView2())
+asserttemplate_name==['template from template_name attribute']
def test_get_template_names_raises_error_if_no_template_found(self):
- renderer = TemplateHTMLRenderer()
- with pytest.raises(ImproperlyConfigured):
- renderer.get_template_names(self.mock_response, view=object())
+ renderer = TemplateHTMLRenderer()
+ with pytest.raises(ImproperlyConfigured):
+ renderer.get_template_names(self.mock_response, view=object())
@override_settings(ROOT_URLCONF='tests.test_htmlrenderer')
-class TemplateHTMLRendererExceptionTests(TestCase):
- def setUp(self):
- """
+def setUp(self):
+ """
Monkeypatch get_template
"""
- self.get_template = django.template.loader.get_template
-
- def get_template(template_name):
- if template_name == '404.html':
- return engines['django'].from_string("404: {{ detail }}")
+ self.get_template = django.template.loader.get_template
+ def get_template(template_name):
+ if template_name == '404.html':
+ return engines['django'].from_string("404: {{ detail }}")
if template_name == '403.html':
- return engines['django'].from_string("403: {{ detail }}")
+ return engines['django'].from_string("403: {{ detail }}")
raise TemplateDoesNotExist(template_name)
django.template.loader.get_template = get_template
def tearDown(self):
- """
+ """
Revert monkeypatching
"""
- django.template.loader.get_template = self.get_template
+ django.template.loader.get_template = self.get_template
def test_not_found_html_view_with_template(self):
- response = self.client.get('/not_found')
- self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
- self.assertTrue(response.content in (
- b"404: Not found", b"404 Not Found"))
- self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
+ response = self.client.get('/not_found')
+ assert response.status_code == status.HTTP_404_NOT_FOUND
+ assert response.content in ( b"404: Not found", b"404 Not Found")
+ assert response['Content-Type'] == 'text/html; charset=utf-8'
def test_permission_denied_html_view_with_template(self):
- response = self.client.get('/permission_denied')
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
- self.assertTrue(response.content in (b"403: Permission denied", b"403 Forbidden"))
- self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
+ response = self.client.get('/permission_denied')
+ assert response.status_code == status.HTTP_403_FORBIDDEN
+ assert response.content in (b"403: Permission denied", b"403 Forbidden")
+ assert response['Content-Type'] == 'text/html; charset=utf-8'
diff --git a/tests/test_lazy_hyperlinks.py b/tests/test_lazy_hyperlinks.py
index cf3ee735f..a4287d8f0 100644
--- a/tests/test_lazy_hyperlinks.py
+++ b/tests/test_lazy_hyperlinks.py
@@ -34,16 +34,15 @@ urlpatterns = [
@override_settings(ROOT_URLCONF='tests.test_lazy_hyperlinks')
-class TestLazyHyperlinkNames(TestCase):
- def setUp(self):
- self.example = Example.objects.create(text='foo')
+def setUp(self):
+ self.example = Example.objects.create(text='foo')
def test_lazy_hyperlink_names(self):
- global str_called
- context = {'request': None}
- serializer = ExampleSerializer(self.example, context=context)
- JSONRenderer().render(serializer.data)
- assert not str_called
- hyperlink_string = format_value(serializer.data['url'])
- assert hyperlink_string == 'An example'
- assert str_called
+ global str_called
+ context = {'request': None}
+ serializer = ExampleSerializer(self.example, context=context)
+ JSONRenderer().render(serializer.data)
+ assert not str_called
+ hyperlink_string = format_value(serializer.data['url'])
+ assert hyperlink_string == 'An example'
+ assert str_called
diff --git a/tests/test_metadata.py b/tests/test_metadata.py
index e1a1fd352..d67272f1f 100644
--- a/tests/test_metadata.py
+++ b/tests/test_metadata.py
@@ -308,98 +308,49 @@ class TestMetadata:
assert options.get_serializer_info(list_serializer) == options.get_serializer_info(child_serializer)
-class TestSimpleMetadataFieldInfo(TestCase):
- def test_null_boolean_field_info_type(self):
- options = metadata.SimpleMetadata()
- field_info = options.get_field_info(serializers.NullBooleanField())
- assert field_info['type'] == 'boolean'
+def test_null_boolean_field_info_type(self):
+ options = metadata.SimpleMetadata()
+ field_info = options.get_field_info(serializers.NullBooleanField())
+ assert field_info['type'] == 'boolean'
def test_related_field_choices(self):
- options = metadata.SimpleMetadata()
- BasicModel.objects.create()
- with self.assertNumQueries(0):
- field_info = options.get_field_info(
- serializers.RelatedField(queryset=BasicModel.objects.all())
- )
+ options = metadata.SimpleMetadata()
+ BasicModel.objects.create()
+ with self.assertNumQueries(0):
+ field_info = options.get_field_info( serializers.RelatedField(queryset=BasicModel.objects.all()) )
assert 'choices' not in field_info
-class TestModelSerializerMetadata(TestCase):
- def test_read_only_primary_key_related_field(self):
- """
+def test_read_only_primary_key_related_field(self):
+ """
On generic views OPTIONS should return an 'actions' key with metadata
on the fields that may be supplied to PUT and POST requests. It should
not fail when a read_only PrimaryKeyRelatedField is present
"""
- class Parent(models.Model):
- integer_field = models.IntegerField(validators=[MinValueValidator(1), MaxValueValidator(1000)])
- children = models.ManyToManyField('Child')
- name = models.CharField(max_length=100, blank=True, null=True)
+ class Parent(models.Model):
+ integer_field = models.IntegerField(validators=[MinValueValidator(1), MaxValueValidator(1000)])
+ children = models.ManyToManyField('Child')
+ name = models.CharField(max_length=100, blank=True, null=True)
class Child(models.Model):
- name = models.CharField(max_length=100)
+ name = models.CharField(max_length=100)
class ExampleSerializer(serializers.ModelSerializer):
- children = serializers.PrimaryKeyRelatedField(read_only=True, many=True)
-
- class Meta:
- model = Parent
- fields = '__all__'
+ children = serializers.PrimaryKeyRelatedField(read_only=True, many=True)
+ class Meta:
+ model = Parent
+ fields = '__all__'
class ExampleView(views.APIView):
- """Example view."""
- def post(self, request):
- pass
+ """Example view."""
+ def post(self, request):
+ pass
def get_serializer(self):
- return ExampleSerializer()
+ return ExampleSerializer()
view = ExampleView.as_view()
- response = view(request=request)
- expected = {
- 'name': 'Example',
- 'description': 'Example view.',
- 'renders': [
- 'application/json',
- 'text/html'
- ],
- 'parses': [
- 'application/json',
- 'application/x-www-form-urlencoded',
- 'multipart/form-data'
- ],
- 'actions': {
- 'POST': {
- 'id': {
- 'type': 'integer',
- 'required': False,
- 'read_only': True,
- 'label': 'ID'
- },
- 'children': {
- 'type': 'field',
- 'required': False,
- 'read_only': True,
- 'label': 'Children'
- },
- 'integer_field': {
- 'type': 'integer',
- 'required': True,
- 'read_only': False,
- 'label': 'Integer field',
- 'min_value': 1,
- 'max_value': 1000
- },
- 'name': {
- 'type': 'string',
- 'required': False,
- 'read_only': False,
- 'label': 'Name',
- 'max_length': 100
- }
- }
- }
- }
-
- assert response.status_code == status.HTTP_200_OK
- assert response.data == expected
+response=view(request=request)
+expected={'name':'Example','description':'Example view.','renders':['application/json','text/html'],'parses':['application/json','application/x-www-form-urlencoded','multipart/form-data'],'actions':{'POST':{'id':{'type':'integer','required':False,'read_only':True,'label':'ID'},'children':{'type':'field','required':False,'read_only':True,'label':'Children'},'integer_field':{'type':'integer','required':True,'read_only':False,'label':'Integer field','min_value':1,'max_value':1000},'name':{'type':'string','required':False,'read_only':False,'label':'Name','max_length':100}}}}
+assertresponse.status_code==status.HTTP_200_OK
+assertresponse.data==expected
diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py
index 413d7885d..cde068f03 100644
--- a/tests/test_model_serializer.py
+++ b/tests/test_model_serializer.py
@@ -110,59 +110,48 @@ class UniqueChoiceModel(models.Model):
name = models.CharField(max_length=254, unique=True, choices=CHOICES)
-class TestModelSerializer(TestCase):
- def test_create_method(self):
- class TestSerializer(serializers.ModelSerializer):
- non_model_field = serializers.CharField()
+def test_create_method(self):
+ class TestSerializer(serializers.ModelSerializer):
+ non_model_field = serializers.CharField()
+ class Meta:
+ model = OneFieldModel
+ fields = ('char_field', 'non_model_field')
- class Meta:
- model = OneFieldModel
- fields = ('char_field', 'non_model_field')
-
- serializer = TestSerializer(data={
- 'char_field': 'foo',
- 'non_model_field': 'bar',
- })
- serializer.is_valid()
-
- msginitial = 'Got a `TypeError` when calling `OneFieldModel.objects.create()`.'
- with self.assertRaisesMessage(TypeError, msginitial):
- serializer.save()
+ serializer = TestSerializer(data={'char_field':'foo','non_model_field':'bar',})
+serializer.is_valid()
+msginitial='Got a `TypeError` when calling `OneFieldModel.objects.create()`.'
+withself.assertRaisesMessage(TypeError,msginitial):
+ serializer.save()
def test_abstract_model(self):
- """
+ """
Test that trying to use ModelSerializer with Abstract Models
throws a ValueError exception.
"""
- class AbstractModel(models.Model):
- afield = models.CharField(max_length=255)
-
- class Meta:
- abstract = True
+ class AbstractModel(models.Model):
+ afield = models.CharField(max_length=255)
+ class Meta:
+ abstract = True
class TestSerializer(serializers.ModelSerializer):
- class Meta:
- model = AbstractModel
- fields = ('afield',)
+ class Meta:
+ model = AbstractModel
+ fields = ('afield',)
- serializer = TestSerializer(data={
- 'afield': 'foo',
- })
-
- msginitial = 'Cannot use ModelSerializer with Abstract Models.'
- with self.assertRaisesMessage(ValueError, msginitial):
- serializer.is_valid()
+ serializer = TestSerializer(data={'afield':'foo',})
+msginitial='Cannot use ModelSerializer with Abstract Models.'
+withself.assertRaisesMessage(ValueError,msginitial):
+ serializer.is_valid()
-class TestRegularFieldMappings(TestCase):
- def test_regular_fields(self):
- """
+def test_regular_fields(self):
+ """
Model fields should map to their equivalent serializer fields.
"""
- class TestSerializer(serializers.ModelSerializer):
- class Meta:
- model = RegularFieldsModel
- fields = '__all__'
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = RegularFieldsModel
+ fields = '__all__'
expected = dedent("""
TestSerializer():
@@ -189,14 +178,13 @@ class TestRegularFieldMappings(TestCase):
custom_field = ModelField(model_field=)
file_path_field = FilePathField(path='/tmp/')
""")
-
- self.assertEqual(repr(TestSerializer()), expected)
+assertrepr(TestSerializer())==expected
def test_field_options(self):
- class TestSerializer(serializers.ModelSerializer):
- class Meta:
- model = FieldOptionsModel
- fields = '__all__'
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = FieldOptionsModel
+ fields = '__all__'
expected = dedent("""
TestSerializer():
@@ -209,260 +197,248 @@ class TestRegularFieldMappings(TestCase):
descriptive_field = IntegerField(help_text='Some help text', label='A label')
choices_field = ChoiceField(choices=(('red', 'Red'), ('blue', 'Blue'), ('green', 'Green')))
""")
- self.assertEqual(repr(TestSerializer()), expected)
+assertrepr(TestSerializer())==expected
# merge this into test_regular_fields / RegularFieldsModel when
# Django 2.1 is the minimum supported version
@pytest.mark.skipif(django.VERSION < (2, 1), reason='Django version < 2.1')
- def test_nullable_boolean_field(self):
- class NullableBooleanModel(models.Model):
- field = models.BooleanField(null=True, default=False)
+deftest_nullable_boolean_field(self):
+ class NullableBooleanModel(models.Model):
+ field = models.BooleanField(null=True, default=False)
class NullableBooleanSerializer(serializers.ModelSerializer):
- class Meta:
- model = NullableBooleanModel
- fields = ['field']
+ class Meta:
+ model = NullableBooleanModel
+ fields = ['field']
expected = dedent("""
NullableBooleanSerializer():
field = BooleanField(allow_null=True, required=False)
""")
-
- self.assertEqual(repr(NullableBooleanSerializer()), expected)
+assertrepr(NullableBooleanSerializer())==expected
def test_method_field(self):
- """
+ """
Properties and methods on the model should be allowed as `Meta.fields`
values, and should map to `ReadOnlyField`.
"""
- class TestSerializer(serializers.ModelSerializer):
- class Meta:
- model = RegularFieldsModel
- fields = ('auto_field', 'method')
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = RegularFieldsModel
+ fields = ('auto_field', 'method')
expected = dedent("""
TestSerializer():
auto_field = IntegerField(read_only=True)
method = ReadOnlyField()
""")
- self.assertEqual(repr(TestSerializer()), expected)
+assertrepr(TestSerializer())==expected
def test_pk_fields(self):
- """
+ """
Both `pk` and the actual primary key name are valid in `Meta.fields`.
"""
- class TestSerializer(serializers.ModelSerializer):
- class Meta:
- model = RegularFieldsModel
- fields = ('pk', 'auto_field')
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = RegularFieldsModel
+ fields = ('pk', 'auto_field')
expected = dedent("""
TestSerializer():
pk = IntegerField(label='Auto field', read_only=True)
auto_field = IntegerField(read_only=True)
""")
- self.assertEqual(repr(TestSerializer()), expected)
+assertrepr(TestSerializer())==expected
def test_extra_field_kwargs(self):
- """
+ """
Ensure `extra_kwargs` are passed to generated fields.
"""
- class TestSerializer(serializers.ModelSerializer):
- class Meta:
- model = RegularFieldsModel
- fields = ('auto_field', 'char_field')
- extra_kwargs = {'char_field': {'default': 'extra'}}
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = RegularFieldsModel
+ fields = ('auto_field', 'char_field')
+ extra_kwargs = {'char_field': {'default': 'extra'}}
expected = dedent("""
TestSerializer():
auto_field = IntegerField(read_only=True)
char_field = CharField(default='extra', max_length=100)
""")
- self.assertEqual(repr(TestSerializer()), expected)
+assertrepr(TestSerializer())==expected
def test_extra_field_kwargs_required(self):
- """
+ """
Ensure `extra_kwargs` are passed to generated fields.
"""
- class TestSerializer(serializers.ModelSerializer):
- class Meta:
- model = RegularFieldsModel
- fields = ('auto_field', 'char_field')
- extra_kwargs = {'auto_field': {'required': False, 'read_only': False}}
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = RegularFieldsModel
+ fields = ('auto_field', 'char_field')
+ extra_kwargs = {'auto_field': {'required': False, 'read_only': False}}
expected = dedent("""
TestSerializer():
auto_field = IntegerField(read_only=False, required=False)
char_field = CharField(max_length=100)
""")
- self.assertEqual(repr(TestSerializer()), expected)
+assertrepr(TestSerializer())==expected
def test_invalid_field(self):
- """
+ """
Field names that do not map to a model field or relationship should
raise a configuration errror.
"""
- class TestSerializer(serializers.ModelSerializer):
- class Meta:
- model = RegularFieldsModel
- fields = ('auto_field', 'invalid')
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = RegularFieldsModel
+ fields = ('auto_field', 'invalid')
expected = 'Field name `invalid` is not valid for model `RegularFieldsModel`.'
- with self.assertRaisesMessage(ImproperlyConfigured, expected):
- TestSerializer().fields
+withself.assertRaisesMessage(ImproperlyConfigured,expected):
+ TestSerializer().fields
def test_missing_field(self):
- """
+ """
Fields that have been declared on the serializer class must be included
in the `Meta.fields` if it exists.
"""
- class TestSerializer(serializers.ModelSerializer):
- missing = serializers.ReadOnlyField()
+ class TestSerializer(serializers.ModelSerializer):
+ missing = serializers.ReadOnlyField()
+ class Meta:
+ model = RegularFieldsModel
+ fields = ('auto_field',)
- class Meta:
- model = RegularFieldsModel
- fields = ('auto_field',)
-
- expected = (
- "The field 'missing' was declared on serializer TestSerializer, "
- "but has not been included in the 'fields' option."
- )
- with self.assertRaisesMessage(AssertionError, expected):
- TestSerializer().fields
+ expected = ("The field 'missing' was declared on serializer TestSerializer, ""but has not been included in the 'fields' option.")
+withself.assertRaisesMessage(AssertionError,expected):
+ TestSerializer().fields
def test_missing_superclass_field(self):
- """
+ """
Fields that have been declared on a parent of the serializer class may
be excluded from the `Meta.fields` option.
"""
- class TestSerializer(serializers.ModelSerializer):
- missing = serializers.ReadOnlyField()
+ class TestSerializer(serializers.ModelSerializer):
+ missing = serializers.ReadOnlyField()
class ChildSerializer(TestSerializer):
- class Meta:
- model = RegularFieldsModel
- fields = ('auto_field',)
+ class Meta:
+ model = RegularFieldsModel
+ fields = ('auto_field',)
ChildSerializer().fields
def test_choices_with_nonstandard_args(self):
- class ExampleSerializer(serializers.ModelSerializer):
- class Meta:
- model = ChoicesModel
- fields = '__all__'
+ class ExampleSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ChoicesModel
+ fields = '__all__'
ExampleSerializer()
-class TestDurationFieldMapping(TestCase):
- def test_duration_field(self):
- class DurationFieldModel(models.Model):
- """
+def test_duration_field(self):
+ class DurationFieldModel(models.Model):
+ """
A model that defines DurationField.
"""
- duration_field = models.DurationField()
+ duration_field = models.DurationField()
class TestSerializer(serializers.ModelSerializer):
- class Meta:
- model = DurationFieldModel
- fields = '__all__'
+ class Meta:
+ model = DurationFieldModel
+ fields = '__all__'
expected = dedent("""
TestSerializer():
id = IntegerField(label='ID', read_only=True)
duration_field = DurationField()
""")
- self.assertEqual(repr(TestSerializer()), expected)
+assertrepr(TestSerializer())==expected
def test_duration_field_with_validators(self):
- class ValidatedDurationFieldModel(models.Model):
- """
+ class ValidatedDurationFieldModel(models.Model):
+ """
A model that defines DurationField with validators.
"""
- duration_field = models.DurationField(
- validators=[MinValueValidator(datetime.timedelta(days=1)), MaxValueValidator(datetime.timedelta(days=3))]
- )
+ duration_field = models.DurationField( validators=[MinValueValidator(datetime.timedelta(days=1)), MaxValueValidator(datetime.timedelta(days=3))] )
class TestSerializer(serializers.ModelSerializer):
- class Meta:
- model = ValidatedDurationFieldModel
- fields = '__all__'
+ class Meta:
+ model = ValidatedDurationFieldModel
+ fields = '__all__'
expected = dedent("""
TestSerializer():
id = IntegerField(label='ID', read_only=True)
duration_field = DurationField(max_value=datetime.timedelta(3), min_value=datetime.timedelta(1))
- """) if sys.version_info < (3, 7) else dedent("""
+ """)ifsys.version_info<(3,7)elsededent("""
TestSerializer():
id = IntegerField(label='ID', read_only=True)
duration_field = DurationField(max_value=datetime.timedelta(days=3), min_value=datetime.timedelta(days=1))
""")
- self.assertEqual(repr(TestSerializer()), expected)
+assertrepr(TestSerializer())==expected
-class TestGenericIPAddressFieldValidation(TestCase):
- def test_ip_address_validation(self):
- class IPAddressFieldModel(models.Model):
- address = models.GenericIPAddressField()
+def test_ip_address_validation(self):
+ class IPAddressFieldModel(models.Model):
+ address = models.GenericIPAddressField()
class TestSerializer(serializers.ModelSerializer):
- class Meta:
- model = IPAddressFieldModel
- fields = '__all__'
+ class Meta:
+ model = IPAddressFieldModel
+ fields = '__all__'
s = TestSerializer(data={'address': 'not an ip address'})
- self.assertFalse(s.is_valid())
- self.assertEqual(1, len(s.errors['address']),
- 'Unexpected number of validation errors: '
- '{}'.format(s.errors))
+assertnots.is_valid()
+assert1==len(s.errors['address']),'Unexpected number of validation errors: ''{}'.format(s.errors)
@pytest.mark.skipif('not postgres_fields')
-class TestPosgresFieldsMapping(TestCase):
- def test_hstore_field(self):
- class HStoreFieldModel(models.Model):
- hstore_field = postgres_fields.HStoreField()
+def test_hstore_field(self):
+ class HStoreFieldModel(models.Model):
+ hstore_field = postgres_fields.HStoreField()
class TestSerializer(serializers.ModelSerializer):
- class Meta:
- model = HStoreFieldModel
- fields = ['hstore_field']
+ class Meta:
+ model = HStoreFieldModel
+ fields = ['hstore_field']
expected = dedent("""
TestSerializer():
hstore_field = HStoreField()
""")
- self.assertEqual(repr(TestSerializer()), expected)
+assertrepr(TestSerializer())==expected
def test_array_field(self):
- class ArrayFieldModel(models.Model):
- array_field = postgres_fields.ArrayField(base_field=models.CharField())
+ class ArrayFieldModel(models.Model):
+ array_field = postgres_fields.ArrayField(base_field=models.CharField())
class TestSerializer(serializers.ModelSerializer):
- class Meta:
- model = ArrayFieldModel
- fields = ['array_field']
+ class Meta:
+ model = ArrayFieldModel
+ fields = ['array_field']
expected = dedent("""
TestSerializer():
array_field = ListField(child=CharField(label='Array field', validators=[]))
""")
- self.assertEqual(repr(TestSerializer()), expected)
+assertrepr(TestSerializer())==expected
def test_json_field(self):
- class JSONFieldModel(models.Model):
- json_field = postgres_fields.JSONField()
+ class JSONFieldModel(models.Model):
+ json_field = postgres_fields.JSONField()
class TestSerializer(serializers.ModelSerializer):
- class Meta:
- model = JSONFieldModel
- fields = ['json_field']
+ class Meta:
+ model = JSONFieldModel
+ fields = ['json_field']
expected = dedent("""
TestSerializer():
json_field = JSONField(style={'base_template': 'textarea.html'})
""")
- self.assertEqual(repr(TestSerializer()), expected)
+assertrepr(TestSerializer())==expected
# Tests for relational field mappings.
@@ -505,12 +481,11 @@ class UniqueTogetherModel(models.Model):
unique_together = ("foreign_key", "one_to_one")
-class TestRelationalFieldMappings(TestCase):
- def test_pk_relations(self):
- class TestSerializer(serializers.ModelSerializer):
- class Meta:
- model = RelationalModel
- fields = '__all__'
+def test_pk_relations(self):
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = RelationalModel
+ fields = '__all__'
expected = dedent("""
TestSerializer():
@@ -520,14 +495,14 @@ class TestRelationalFieldMappings(TestCase):
many_to_many = PrimaryKeyRelatedField(allow_empty=False, many=True, queryset=ManyToManyTargetModel.objects.all())
through = PrimaryKeyRelatedField(many=True, read_only=True)
""")
- self.assertEqual(repr(TestSerializer()), expected)
+assertrepr(TestSerializer())==expected
def test_nested_relations(self):
- class TestSerializer(serializers.ModelSerializer):
- class Meta:
- model = RelationalModel
- depth = 1
- fields = '__all__'
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = RelationalModel
+ depth = 1
+ fields = '__all__'
expected = dedent("""
TestSerializer():
@@ -545,13 +520,13 @@ class TestRelationalFieldMappings(TestCase):
id = IntegerField(label='ID', read_only=True)
name = CharField(max_length=100)
""")
- self.assertEqual(repr(TestSerializer()), expected)
+assertrepr(TestSerializer())==expected
def test_hyperlinked_relations(self):
- class TestSerializer(serializers.HyperlinkedModelSerializer):
- class Meta:
- model = RelationalModel
- fields = '__all__'
+ class TestSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = RelationalModel
+ fields = '__all__'
expected = dedent("""
TestSerializer():
@@ -561,14 +536,14 @@ class TestRelationalFieldMappings(TestCase):
many_to_many = HyperlinkedRelatedField(allow_empty=False, many=True, queryset=ManyToManyTargetModel.objects.all(), view_name='manytomanytargetmodel-detail')
through = HyperlinkedRelatedField(many=True, read_only=True, view_name='throughtargetmodel-detail')
""")
- self.assertEqual(repr(TestSerializer()), expected)
+assertrepr(TestSerializer())==expected
def test_nested_hyperlinked_relations(self):
- class TestSerializer(serializers.HyperlinkedModelSerializer):
- class Meta:
- model = RelationalModel
- depth = 1
- fields = '__all__'
+ class TestSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = RelationalModel
+ depth = 1
+ fields = '__all__'
expected = dedent("""
TestSerializer():
@@ -586,19 +561,15 @@ class TestRelationalFieldMappings(TestCase):
url = HyperlinkedIdentityField(view_name='throughtargetmodel-detail')
name = CharField(max_length=100)
""")
- self.assertEqual(repr(TestSerializer()), expected)
+assertrepr(TestSerializer())==expected
def test_nested_hyperlinked_relations_starred_source(self):
- class TestSerializer(serializers.HyperlinkedModelSerializer):
- class Meta:
- model = RelationalModel
- depth = 1
- fields = '__all__'
-
- extra_kwargs = {
- 'url': {
- 'source': '*',
- }}
+ class TestSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = RelationalModel
+ depth = 1
+ fields = '__all__'
+ extra_kwargs = { 'url': { 'source': '*', }}
expected = dedent("""
TestSerializer():
@@ -616,15 +587,15 @@ class TestRelationalFieldMappings(TestCase):
url = HyperlinkedIdentityField(view_name='throughtargetmodel-detail')
name = CharField(max_length=100)
""")
- self.maxDiff = None
- self.assertEqual(repr(TestSerializer()), expected)
+self.maxDiff=None
+assertrepr(TestSerializer())==expected
def test_nested_unique_together_relations(self):
- class TestSerializer(serializers.HyperlinkedModelSerializer):
- class Meta:
- model = UniqueTogetherModel
- depth = 1
- fields = '__all__'
+ class TestSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = UniqueTogetherModel
+ depth = 1
+ fields = '__all__'
expected = dedent("""
TestSerializer():
@@ -636,13 +607,13 @@ class TestRelationalFieldMappings(TestCase):
url = HyperlinkedIdentityField(view_name='onetoonetargetmodel-detail')
name = CharField(max_length=100)
""")
- self.assertEqual(repr(TestSerializer()), expected)
+assertrepr(TestSerializer())==expected
def test_pk_reverse_foreign_key(self):
- class TestSerializer(serializers.ModelSerializer):
- class Meta:
- model = ForeignKeyTargetModel
- fields = ('id', 'name', 'reverse_foreign_key')
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ForeignKeyTargetModel
+ fields = ('id', 'name', 'reverse_foreign_key')
expected = dedent("""
TestSerializer():
@@ -650,13 +621,13 @@ class TestRelationalFieldMappings(TestCase):
name = CharField(max_length=100)
reverse_foreign_key = PrimaryKeyRelatedField(many=True, queryset=RelationalModel.objects.all())
""")
- self.assertEqual(repr(TestSerializer()), expected)
+assertrepr(TestSerializer())==expected
def test_pk_reverse_one_to_one(self):
- class TestSerializer(serializers.ModelSerializer):
- class Meta:
- model = OneToOneTargetModel
- fields = ('id', 'name', 'reverse_one_to_one')
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = OneToOneTargetModel
+ fields = ('id', 'name', 'reverse_one_to_one')
expected = dedent("""
TestSerializer():
@@ -664,13 +635,13 @@ class TestRelationalFieldMappings(TestCase):
name = CharField(max_length=100)
reverse_one_to_one = PrimaryKeyRelatedField(queryset=RelationalModel.objects.all())
""")
- self.assertEqual(repr(TestSerializer()), expected)
+assertrepr(TestSerializer())==expected
def test_pk_reverse_many_to_many(self):
- class TestSerializer(serializers.ModelSerializer):
- class Meta:
- model = ManyToManyTargetModel
- fields = ('id', 'name', 'reverse_many_to_many')
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ManyToManyTargetModel
+ fields = ('id', 'name', 'reverse_many_to_many')
expected = dedent("""
TestSerializer():
@@ -678,13 +649,13 @@ class TestRelationalFieldMappings(TestCase):
name = CharField(max_length=100)
reverse_many_to_many = PrimaryKeyRelatedField(many=True, queryset=RelationalModel.objects.all())
""")
- self.assertEqual(repr(TestSerializer()), expected)
+assertrepr(TestSerializer())==expected
def test_pk_reverse_through(self):
- class TestSerializer(serializers.ModelSerializer):
- class Meta:
- model = ThroughTargetModel
- fields = ('id', 'name', 'reverse_through')
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ThroughTargetModel
+ fields = ('id', 'name', 'reverse_through')
expected = dedent("""
TestSerializer():
@@ -692,7 +663,7 @@ class TestRelationalFieldMappings(TestCase):
name = CharField(max_length=100)
reverse_through = PrimaryKeyRelatedField(many=True, read_only=True)
""")
- self.assertEqual(repr(TestSerializer()), expected)
+assertrepr(TestSerializer())==expected
class DisplayValueTargetModel(models.Model):
@@ -706,171 +677,91 @@ class DisplayValueModel(models.Model):
color = models.ForeignKey(DisplayValueTargetModel, on_delete=models.CASCADE)
-class TestRelationalFieldDisplayValue(TestCase):
- def setUp(self):
- DisplayValueTargetModel.objects.bulk_create([
- DisplayValueTargetModel(name='Red'),
- DisplayValueTargetModel(name='Yellow'),
- DisplayValueTargetModel(name='Green'),
- ])
+def setUp(self):
+ DisplayValueTargetModel.objects.bulk_create([ DisplayValueTargetModel(name='Red'), DisplayValueTargetModel(name='Yellow'), DisplayValueTargetModel(name='Green'), ])
def test_default_display_value(self):
- class TestSerializer(serializers.ModelSerializer):
- class Meta:
- model = DisplayValueModel
- fields = '__all__'
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = DisplayValueModel
+ fields = '__all__'
serializer = TestSerializer()
- expected = OrderedDict([(1, 'Red Color'), (2, 'Yellow Color'), (3, 'Green Color')])
- self.assertEqual(serializer.fields['color'].choices, expected)
+expected=OrderedDict([(1,'Red Color'),(2,'Yellow Color'),(3,'Green Color')])
+assertserializer.fields['color'].choices==expected
def test_custom_display_value(self):
- class TestField(serializers.PrimaryKeyRelatedField):
- def display_value(self, instance):
- return 'My %s Color' % (instance.name)
+ class TestField(serializers.PrimaryKeyRelatedField):
+ def display_value(self, instance):
+ return 'My %s Color' % (instance.name)
class TestSerializer(serializers.ModelSerializer):
- color = TestField(queryset=DisplayValueTargetModel.objects.all())
-
- class Meta:
- model = DisplayValueModel
- fields = '__all__'
+ color = TestField(queryset=DisplayValueTargetModel.objects.all())
+ class Meta:
+ model = DisplayValueModel
+ fields = '__all__'
serializer = TestSerializer()
- expected = OrderedDict([(1, 'My Red Color'), (2, 'My Yellow Color'), (3, 'My Green Color')])
- self.assertEqual(serializer.fields['color'].choices, expected)
+expected=OrderedDict([(1,'My Red Color'),(2,'My Yellow Color'),(3,'My Green Color')])
+assertserializer.fields['color'].choices==expected
-class TestIntegration(TestCase):
- def setUp(self):
- self.foreign_key_target = ForeignKeyTargetModel.objects.create(
- name='foreign_key'
- )
- self.one_to_one_target = OneToOneTargetModel.objects.create(
- name='one_to_one'
- )
- self.many_to_many_targets = [
- ManyToManyTargetModel.objects.create(
- name='many_to_many (%d)' % idx
- ) for idx in range(3)
- ]
- self.instance = RelationalModel.objects.create(
- foreign_key=self.foreign_key_target,
- one_to_one=self.one_to_one_target,
- )
- self.instance.many_to_many.set(self.many_to_many_targets)
+def setUp(self):
+ self.foreign_key_target = ForeignKeyTargetModel.objects.create( name='foreign_key' )
+ self.one_to_one_target = OneToOneTargetModel.objects.create( name='one_to_one' )
+ self.many_to_many_targets = [ ManyToManyTargetModel.objects.create( name='many_to_many (%d)' % idx ) for idx in range(3) ]
+ self.instance = RelationalModel.objects.create( foreign_key=self.foreign_key_target, one_to_one=self.one_to_one_target, )
+ self.instance.many_to_many.set(self.many_to_many_targets)
def test_pk_retrival(self):
- class TestSerializer(serializers.ModelSerializer):
- class Meta:
- model = RelationalModel
- fields = '__all__'
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = RelationalModel
+ fields = '__all__'
serializer = TestSerializer(self.instance)
- expected = {
- 'id': self.instance.pk,
- 'foreign_key': self.foreign_key_target.pk,
- 'one_to_one': self.one_to_one_target.pk,
- 'many_to_many': [item.pk for item in self.many_to_many_targets],
- 'through': []
- }
- self.assertEqual(serializer.data, expected)
+expected={'id':self.instance.pk,'foreign_key':self.foreign_key_target.pk,'one_to_one':self.one_to_one_target.pk,'many_to_many':[item.pkforiteminself.many_to_many_targets],'through':[]}
+assertserializer.data==expected
def test_pk_create(self):
- class TestSerializer(serializers.ModelSerializer):
- class Meta:
- model = RelationalModel
- fields = '__all__'
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = RelationalModel
+ fields = '__all__'
- new_foreign_key = ForeignKeyTargetModel.objects.create(
- name='foreign_key'
- )
- new_one_to_one = OneToOneTargetModel.objects.create(
- name='one_to_one'
- )
- new_many_to_many = [
- ManyToManyTargetModel.objects.create(
- name='new many_to_many (%d)' % idx
- ) for idx in range(3)
- ]
- data = {
- 'foreign_key': new_foreign_key.pk,
- 'one_to_one': new_one_to_one.pk,
- 'many_to_many': [item.pk for item in new_many_to_many],
- }
-
- # Serializer should validate okay.
- serializer = TestSerializer(data=data)
- assert serializer.is_valid()
-
- # Creating the instance, relationship attributes should be set.
- instance = serializer.save()
- assert instance.foreign_key.pk == new_foreign_key.pk
- assert instance.one_to_one.pk == new_one_to_one.pk
- assert [
- item.pk for item in instance.many_to_many.all()
- ] == [
- item.pk for item in new_many_to_many
- ]
- assert list(instance.through.all()) == []
-
- # Representation should be correct.
- expected = {
- 'id': instance.pk,
- 'foreign_key': new_foreign_key.pk,
- 'one_to_one': new_one_to_one.pk,
- 'many_to_many': [item.pk for item in new_many_to_many],
- 'through': []
- }
- self.assertEqual(serializer.data, expected)
+ new_foreign_key = ForeignKeyTargetModel.objects.create(name='foreign_key')
+new_one_to_one=OneToOneTargetModel.objects.create(name='one_to_one')
+new_many_to_many=[ManyToManyTargetModel.objects.create(name='new many_to_many (%d)'%idx)foridxinrange(3)]
+data={'foreign_key':new_foreign_key.pk,'one_to_one':new_one_to_one.pk,'many_to_many':[item.pkforiteminnew_many_to_many],}
+serializer=TestSerializer(data=data)
+assertserializer.is_valid()
+instance=serializer.save()
+assertinstance.foreign_key.pk==new_foreign_key.pk
+assertinstance.one_to_one.pk==new_one_to_one.pk
+assert[item.pkforitemininstance.many_to_many.all()]==[item.pkforiteminnew_many_to_many]
+assertlist(instance.through.all())==[]
+expected={'id':instance.pk,'foreign_key':new_foreign_key.pk,'one_to_one':new_one_to_one.pk,'many_to_many':[item.pkforiteminnew_many_to_many],'through':[]}
+assertserializer.data==expected
def test_pk_update(self):
- class TestSerializer(serializers.ModelSerializer):
- class Meta:
- model = RelationalModel
- fields = '__all__'
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = RelationalModel
+ fields = '__all__'
- new_foreign_key = ForeignKeyTargetModel.objects.create(
- name='foreign_key'
- )
- new_one_to_one = OneToOneTargetModel.objects.create(
- name='one_to_one'
- )
- new_many_to_many = [
- ManyToManyTargetModel.objects.create(
- name='new many_to_many (%d)' % idx
- ) for idx in range(3)
- ]
- data = {
- 'foreign_key': new_foreign_key.pk,
- 'one_to_one': new_one_to_one.pk,
- 'many_to_many': [item.pk for item in new_many_to_many],
- }
-
- # Serializer should validate okay.
- serializer = TestSerializer(self.instance, data=data)
- assert serializer.is_valid()
-
- # Creating the instance, relationship attributes should be set.
- instance = serializer.save()
- assert instance.foreign_key.pk == new_foreign_key.pk
- assert instance.one_to_one.pk == new_one_to_one.pk
- assert [
- item.pk for item in instance.many_to_many.all()
- ] == [
- item.pk for item in new_many_to_many
- ]
- assert list(instance.through.all()) == []
-
- # Representation should be correct.
- expected = {
- 'id': self.instance.pk,
- 'foreign_key': new_foreign_key.pk,
- 'one_to_one': new_one_to_one.pk,
- 'many_to_many': [item.pk for item in new_many_to_many],
- 'through': []
- }
- self.assertEqual(serializer.data, expected)
+ new_foreign_key = ForeignKeyTargetModel.objects.create(name='foreign_key')
+new_one_to_one=OneToOneTargetModel.objects.create(name='one_to_one')
+new_many_to_many=[ManyToManyTargetModel.objects.create(name='new many_to_many (%d)'%idx)foridxinrange(3)]
+data={'foreign_key':new_foreign_key.pk,'one_to_one':new_one_to_one.pk,'many_to_many':[item.pkforiteminnew_many_to_many],}
+serializer=TestSerializer(self.instance,data=data)
+assertserializer.is_valid()
+instance=serializer.save()
+assertinstance.foreign_key.pk==new_foreign_key.pk
+assertinstance.one_to_one.pk==new_one_to_one.pk
+assert[item.pkforitemininstance.many_to_many.all()]==[item.pkforiteminnew_many_to_many]
+assertlist(instance.through.all())==[]
+expected={'id':self.instance.pk,'foreign_key':new_foreign_key.pk,'one_to_one':new_one_to_one.pk,'many_to_many':[item.pkforiteminnew_many_to_many],'through':[]}
+assertserializer.data==expected
# Tests for bulk create using `ListSerializer`.
@@ -879,109 +770,88 @@ class BulkCreateModel(models.Model):
name = models.CharField(max_length=10)
-class TestBulkCreate(TestCase):
- def test_bulk_create(self):
- class BasicModelSerializer(serializers.ModelSerializer):
- class Meta:
- model = BulkCreateModel
- fields = ('name',)
+def test_bulk_create(self):
+ class BasicModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BulkCreateModel
+ fields = ('name',)
class BulkCreateSerializer(serializers.ListSerializer):
- child = BasicModelSerializer()
+ child = BasicModelSerializer()
data = [{'name': 'a'}, {'name': 'b'}, {'name': 'c'}]
- serializer = BulkCreateSerializer(data=data)
- assert serializer.is_valid()
-
- # Objects are returned by save().
- instances = serializer.save()
- assert len(instances) == 3
- assert [item.name for item in instances] == ['a', 'b', 'c']
-
- # Objects have been created in the database.
- assert BulkCreateModel.objects.count() == 3
- assert list(BulkCreateModel.objects.values_list('name', flat=True)) == ['a', 'b', 'c']
-
- # Serializer returns correct data.
- assert serializer.data == data
+serializer=BulkCreateSerializer(data=data)
+assertserializer.is_valid()
+instances=serializer.save()
+assertlen(instances)==3
+assert[item.nameforitemininstances]==['a','b','c']
+assertBulkCreateModel.objects.count()==3
+assertlist(BulkCreateModel.objects.values_list('name',flat=True))==['a','b','c']
+assertserializer.data==data
class MetaClassTestModel(models.Model):
text = models.CharField(max_length=100)
-class TestSerializerMetaClass(TestCase):
- def test_meta_class_fields_option(self):
- class ExampleSerializer(serializers.ModelSerializer):
- class Meta:
- model = MetaClassTestModel
- fields = 'text'
+def test_meta_class_fields_option(self):
+ class ExampleSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = MetaClassTestModel
+ fields = 'text'
msginitial = "The `fields` option must be a list or tuple"
- with self.assertRaisesMessage(TypeError, msginitial):
- ExampleSerializer().fields
+withself.assertRaisesMessage(TypeError,msginitial):
+ ExampleSerializer().fields
def test_meta_class_exclude_option(self):
- class ExampleSerializer(serializers.ModelSerializer):
- class Meta:
- model = MetaClassTestModel
- exclude = 'text'
+ class ExampleSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = MetaClassTestModel
+ exclude = 'text'
msginitial = "The `exclude` option must be a list or tuple"
- with self.assertRaisesMessage(TypeError, msginitial):
- ExampleSerializer().fields
+withself.assertRaisesMessage(TypeError,msginitial):
+ ExampleSerializer().fields
def test_meta_class_fields_and_exclude_options(self):
- class ExampleSerializer(serializers.ModelSerializer):
- class Meta:
- model = MetaClassTestModel
- fields = ('text',)
- exclude = ('text',)
+ class ExampleSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = MetaClassTestModel
+ fields = ('text',)
+ exclude = ('text',)
msginitial = "Cannot set both 'fields' and 'exclude' options on serializer ExampleSerializer."
- with self.assertRaisesMessage(AssertionError, msginitial):
- ExampleSerializer().fields
+withself.assertRaisesMessage(AssertionError,msginitial):
+ ExampleSerializer().fields
def test_declared_fields_with_exclude_option(self):
- class ExampleSerializer(serializers.ModelSerializer):
- text = serializers.CharField()
+ class ExampleSerializer(serializers.ModelSerializer):
+ text = serializers.CharField()
+ class Meta:
+ model = MetaClassTestModel
+ exclude = ('text',)
- class Meta:
- model = MetaClassTestModel
- exclude = ('text',)
-
- expected = (
- "Cannot both declare the field 'text' and include it in the "
- "ExampleSerializer 'exclude' option. Remove the field or, if "
- "inherited from a parent serializer, disable with `text = None`."
- )
- with self.assertRaisesMessage(AssertionError, expected):
- ExampleSerializer().fields
+ expected = ("Cannot both declare the field 'text' and include it in the ""ExampleSerializer 'exclude' option. Remove the field or, if ""inherited from a parent serializer, disable with `text = None`.")
+withself.assertRaisesMessage(AssertionError,expected):
+ ExampleSerializer().fields
-class Issue2704TestCase(TestCase):
- def test_queryset_all(self):
- class TestSerializer(serializers.ModelSerializer):
- additional_attr = serializers.CharField()
-
- class Meta:
- model = OneFieldModel
- fields = ('char_field', 'additional_attr')
+def test_queryset_all(self):
+ class TestSerializer(serializers.ModelSerializer):
+ additional_attr = serializers.CharField()
+ class Meta:
+ model = OneFieldModel
+ fields = ('char_field', 'additional_attr')
OneFieldModel.objects.create(char_field='abc')
- qs = OneFieldModel.objects.all()
-
- for o in qs:
- o.additional_attr = '123'
+qs=OneFieldModel.objects.all()
+foroinqs:
+ o.additional_attr = '123'
serializer = TestSerializer(instance=qs, many=True)
-
- expected = [{
- 'char_field': 'abc',
- 'additional_attr': '123',
- }]
-
- assert serializer.data == expected
+expected=[{'char_field':'abc','additional_attr':'123',}]
+assertserializer.data==expected
class DecimalFieldModel(models.Model):
@@ -992,78 +862,71 @@ class DecimalFieldModel(models.Model):
)
-class TestDecimalFieldMappings(TestCase):
- def test_decimal_field_has_decimal_validator(self):
- """
+def test_decimal_field_has_decimal_validator(self):
+ """
Test that a `DecimalField` has no `DecimalValidator`.
"""
- class TestSerializer(serializers.ModelSerializer):
- class Meta:
- model = DecimalFieldModel
- fields = '__all__'
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = DecimalFieldModel
+ fields = '__all__'
serializer = TestSerializer()
-
- assert len(serializer.fields['decimal_field'].validators) == 2
+assertlen(serializer.fields['decimal_field'].validators)==2
def test_min_value_is_passed(self):
- """
+ """
Test that the `MinValueValidator` is converted to the `min_value`
argument for the field.
"""
- class TestSerializer(serializers.ModelSerializer):
- class Meta:
- model = DecimalFieldModel
- fields = '__all__'
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = DecimalFieldModel
+ fields = '__all__'
serializer = TestSerializer()
-
- assert serializer.fields['decimal_field'].min_value == 1
+assertserializer.fields['decimal_field'].min_value==1
def test_max_value_is_passed(self):
- """
+ """
Test that the `MaxValueValidator` is converted to the `max_value`
argument for the field.
"""
- class TestSerializer(serializers.ModelSerializer):
- class Meta:
- model = DecimalFieldModel
- fields = '__all__'
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = DecimalFieldModel
+ fields = '__all__'
serializer = TestSerializer()
-
- assert serializer.fields['decimal_field'].max_value == 3
+assertserializer.fields['decimal_field'].max_value==3
-class TestMetaInheritance(TestCase):
- def test_extra_kwargs_not_altered(self):
- class TestSerializer(serializers.ModelSerializer):
- non_model_field = serializers.CharField()
-
- class Meta:
- model = OneFieldModel
- read_only_fields = ('char_field', 'non_model_field')
- fields = read_only_fields
- extra_kwargs = {}
+def test_extra_kwargs_not_altered(self):
+ class TestSerializer(serializers.ModelSerializer):
+ non_model_field = serializers.CharField()
+ class Meta:
+ model = OneFieldModel
+ read_only_fields = ('char_field', 'non_model_field')
+ fields = read_only_fields
+ extra_kwargs = {}
class ChildSerializer(TestSerializer):
- class Meta(TestSerializer.Meta):
- read_only_fields = ()
+ class Meta(TestSerializer.Meta):
+ read_only_fields = ()
test_expected = dedent("""
TestSerializer():
char_field = CharField(read_only=True)
non_model_field = CharField()
""")
-
- child_expected = dedent("""
+child_expected=dedent("""
ChildSerializer():
char_field = CharField(max_length=100)
non_model_field = CharField()
""")
- self.assertEqual(repr(ChildSerializer()), child_expected)
- self.assertEqual(repr(TestSerializer()), test_expected)
- self.assertEqual(repr(ChildSerializer()), child_expected)
+assertrepr(ChildSerializer())==child_expected
+assertrepr(TestSerializer())==test_expected
+assertrepr(ChildSerializer())==child_expected
class OneToOneTargetTestModel(models.Model):
@@ -1074,57 +937,53 @@ class OneToOneSourceTestModel(models.Model):
target = models.OneToOneField(OneToOneTargetTestModel, primary_key=True, on_delete=models.CASCADE)
-class TestModelFieldValues(TestCase):
- def test_model_field(self):
- class ExampleSerializer(serializers.ModelSerializer):
- class Meta:
- model = OneToOneSourceTestModel
- fields = ('target',)
+def test_model_field(self):
+ class ExampleSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = OneToOneSourceTestModel
+ fields = ('target',)
target = OneToOneTargetTestModel(id=1, text='abc')
- source = OneToOneSourceTestModel(target=target)
- serializer = ExampleSerializer(source)
- self.assertEqual(serializer.data, {'target': 1})
+source=OneToOneSourceTestModel(target=target)
+serializer=ExampleSerializer(source)
+assertserializer.data=={'target':1}
-class TestUniquenessOverride(TestCase):
- def test_required_not_overwritten(self):
- class TestModel(models.Model):
- field_1 = models.IntegerField(null=True)
- field_2 = models.IntegerField()
-
- class Meta:
- unique_together = (('field_1', 'field_2'),)
+def test_required_not_overwritten(self):
+ class TestModel(models.Model):
+ field_1 = models.IntegerField(null=True)
+ field_2 = models.IntegerField()
+ class Meta:
+ unique_together = (('field_1', 'field_2'),)
class TestSerializer(serializers.ModelSerializer):
- class Meta:
- model = TestModel
- fields = '__all__'
- extra_kwargs = {'field_1': {'required': False}}
+ class Meta:
+ model = TestModel
+ fields = '__all__'
+ extra_kwargs = {'field_1': {'required': False}}
fields = TestSerializer().fields
- self.assertFalse(fields['field_1'].required)
- self.assertTrue(fields['field_2'].required)
+assertnotfields['field_1'].required
+assertfields['field_2'].required
-class Issue3674Test(TestCase):
- def test_nonPK_foreignkey_model_serializer(self):
- class TestParentModel(models.Model):
- title = models.CharField(max_length=64)
+def test_nonPK_foreignkey_model_serializer(self):
+ class TestParentModel(models.Model):
+ title = models.CharField(max_length=64)
class TestChildModel(models.Model):
- parent = models.ForeignKey(TestParentModel, related_name='children', on_delete=models.CASCADE)
- value = models.CharField(primary_key=True, max_length=64)
+ parent = models.ForeignKey(TestParentModel, related_name='children', on_delete=models.CASCADE)
+ value = models.CharField(primary_key=True, max_length=64)
class TestChildModelSerializer(serializers.ModelSerializer):
- class Meta:
- model = TestChildModel
- fields = ('value', 'parent')
+ class Meta:
+ model = TestChildModel
+ fields = ('value', 'parent')
class TestParentModelSerializer(serializers.ModelSerializer):
- class Meta:
- model = TestParentModel
- fields = ('id', 'title', 'children')
+ class Meta:
+ model = TestParentModel
+ fields = ('id', 'title', 'children')
parent_expected = dedent("""
TestParentModelSerializer():
@@ -1132,106 +991,91 @@ class Issue3674Test(TestCase):
title = CharField(max_length=64)
children = PrimaryKeyRelatedField(many=True, queryset=TestChildModel.objects.all())
""")
- self.assertEqual(repr(TestParentModelSerializer()), parent_expected)
-
- child_expected = dedent("""
+assertrepr(TestParentModelSerializer())==parent_expected
+child_expected=dedent("""
TestChildModelSerializer():
value = CharField(max_length=64, validators=[])
parent = PrimaryKeyRelatedField(queryset=TestParentModel.objects.all())
""")
- self.assertEqual(repr(TestChildModelSerializer()), child_expected)
+assertrepr(TestChildModelSerializer())==child_expected
def test_nonID_PK_foreignkey_model_serializer(self):
- class TestChildModelSerializer(serializers.ModelSerializer):
- class Meta:
- model = Issue3674ChildModel
- fields = ('value', 'parent')
+ class TestChildModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Issue3674ChildModel
+ fields = ('value', 'parent')
class TestParentModelSerializer(serializers.ModelSerializer):
- class Meta:
- model = Issue3674ParentModel
- fields = ('id', 'title', 'children')
+ class Meta:
+ model = Issue3674ParentModel
+ fields = ('id', 'title', 'children')
parent = Issue3674ParentModel.objects.create(title='abc')
- child = Issue3674ChildModel.objects.create(value='def', parent=parent)
-
- parent_serializer = TestParentModelSerializer(parent)
- child_serializer = TestChildModelSerializer(child)
-
- parent_expected = {'children': ['def'], 'id': 1, 'title': 'abc'}
- self.assertEqual(parent_serializer.data, parent_expected)
-
- child_expected = {'parent': 1, 'value': 'def'}
- self.assertEqual(child_serializer.data, child_expected)
+child=Issue3674ChildModel.objects.create(value='def',parent=parent)
+parent_serializer=TestParentModelSerializer(parent)
+child_serializer=TestChildModelSerializer(child)
+parent_expected={'children':['def'],'id':1,'title':'abc'}
+assertparent_serializer.data==parent_expected
+child_expected={'parent':1,'value':'def'}
+assertchild_serializer.data==child_expected
-class Issue4897TestCase(TestCase):
- def test_should_assert_if_writing_readonly_fields(self):
- class TestSerializer(serializers.ModelSerializer):
- class Meta:
- model = OneFieldModel
- fields = ('char_field',)
- readonly_fields = fields
+def test_should_assert_if_writing_readonly_fields(self):
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = OneFieldModel
+ fields = ('char_field',)
+ readonly_fields = fields
obj = OneFieldModel.objects.create(char_field='abc')
-
- with pytest.raises(AssertionError) as cm:
- TestSerializer(obj).fields
+withpytest.raises(AssertionError)ascm:
+ TestSerializer(obj).fields
cm.match(r'readonly_fields')
-class Test5004UniqueChoiceField(TestCase):
- def test_unique_choice_field(self):
- class TestUniqueChoiceSerializer(serializers.ModelSerializer):
- class Meta:
- model = UniqueChoiceModel
- fields = '__all__'
+def test_unique_choice_field(self):
+ class TestUniqueChoiceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = UniqueChoiceModel
+ fields = '__all__'
UniqueChoiceModel.objects.create(name='choice1')
- serializer = TestUniqueChoiceSerializer(data={'name': 'choice1'})
- assert not serializer.is_valid()
- assert serializer.errors == {'name': ['unique choice model with this name already exists.']}
+serializer=TestUniqueChoiceSerializer(data={'name':'choice1'})
+assertnotserializer.is_valid()
+assertserializer.errors=={'name':['unique choice model with this name already exists.']}
-class TestFieldSource(TestCase):
- def test_traverse_nullable_fk(self):
- """
+def test_traverse_nullable_fk(self):
+ """
A dotted source with nullable elements uses default when any item in the chain is None. #5849.
Similar to model example from test_serializer.py `test_default_for_multiple_dotted_source` method,
but using RelatedField, rather than CharField.
"""
- class TestSerializer(serializers.ModelSerializer):
- target = serializers.PrimaryKeyRelatedField(
- source='target.target', read_only=True, allow_null=True, default=None
- )
-
- class Meta:
- model = NestedForeignKeySource
- fields = ('target', )
+ class TestSerializer(serializers.ModelSerializer):
+ target = serializers.PrimaryKeyRelatedField( source='target.target', read_only=True, allow_null=True, default=None )
+ class Meta:
+ model = NestedForeignKeySource
+ fields = ('target', )
model = NestedForeignKeySource.objects.create()
- assert TestSerializer(model).data['target'] is None
+assertTestSerializer(model).data['target']isNone
def test_named_field_source(self):
- class TestSerializer(serializers.ModelSerializer):
+ class TestSerializer(serializers.ModelSerializer):
- class Meta:
- model = RegularFieldsModel
- fields = ('number_field',)
- extra_kwargs = {
- 'number_field': {
- 'source': 'integer_field'
- }
- }
+ class Meta:
+ model = RegularFieldsModel
+ fields = ('number_field',)
+ extra_kwargs = { 'number_field': { 'source': 'integer_field' } }
expected = dedent("""
TestSerializer():
number_field = IntegerField(source='integer_field')
""")
- self.maxDiff = None
- self.assertEqual(repr(TestSerializer()), expected)
+self.maxDiff=None
+assertrepr(TestSerializer())==expected
class Issue6110TestModel(models.Model):
@@ -1247,13 +1091,12 @@ class Issue6110ModelSerializer(serializers.ModelSerializer):
fields = ('name',)
-class Issue6110Test(TestCase):
- def test_model_serializer_custom_manager(self):
- instance = Issue6110ModelSerializer().create({'name': 'test_name'})
- self.assertEqual(instance.name, 'test_name')
+def test_model_serializer_custom_manager(self):
+ instance = Issue6110ModelSerializer().create({'name': 'test_name'})
+ assert instance.name == 'test_name'
def test_model_serializer_custom_manager_error_message(self):
- msginitial = ('Got a `TypeError` when calling `Issue6110TestModel.all_objects.create()`.')
- with self.assertRaisesMessage(TypeError, msginitial):
- Issue6110ModelSerializer().create({'wrong_param': 'wrong_param'})
+ msginitial = ('Got a `TypeError` when calling `Issue6110TestModel.all_objects.create()`.')
+ with self.assertRaisesMessage(TypeError, msginitial):
+ Issue6110ModelSerializer().create({'wrong_param': 'wrong_param'})
diff --git a/tests/test_multitable_inheritance.py b/tests/test_multitable_inheritance.py
index 1e8ab3448..52b6052b8 100644
--- a/tests/test_multitable_inheritance.py
+++ b/tests/test_multitable_inheritance.py
@@ -33,35 +33,31 @@ class AssociatedModelSerializer(serializers.ModelSerializer):
# Tests
-class InheritedModelSerializationTests(TestCase):
- def test_multitable_inherited_model_fields_as_expected(self):
- """
+def test_multitable_inherited_model_fields_as_expected(self):
+ """
Assert that the parent pointer field is not included in the fields
serialized fields
"""
- child = ChildModel(name1='parent name', name2='child name')
- serializer = DerivedModelSerializer(child)
- assert set(serializer.data) == {'name1', 'name2', 'id'}
+ child = ChildModel(name1='parent name', name2='child name')
+ serializer = DerivedModelSerializer(child)
+ assert set(serializer.data) == {'name1', 'name2', 'id'}
def test_onetoone_primary_key_model_fields_as_expected(self):
- """
+ """
Assert that a model with a onetoone field that is the primary key is
not treated like a derived model
"""
- parent = ParentModel.objects.create(name1='parent name')
- associate = AssociatedModel.objects.create(name='hello', ref=parent)
- serializer = AssociatedModelSerializer(associate)
- assert set(serializer.data) == {'name', 'ref'}
+ parent = ParentModel.objects.create(name1='parent name')
+ associate = AssociatedModel.objects.create(name='hello', ref=parent)
+ serializer = AssociatedModelSerializer(associate)
+ assert set(serializer.data) == {'name', 'ref'}
def test_data_is_valid_without_parent_ptr(self):
- """
+ """
Assert that the pointer to the parent table is not a required field
for input data
"""
- data = {
- 'name1': 'parent name',
- 'name2': 'child name',
- }
- serializer = DerivedModelSerializer(data=data)
- assert serializer.is_valid() is True
+ data = { 'name1': 'parent name', 'name2': 'child name', }
+ serializer = DerivedModelSerializer(data=data)
+ assert serializer.is_valid() is True
diff --git a/tests/test_negotiation.py b/tests/test_negotiation.py
index 089a86c62..eab8a9b88 100644
--- a/tests/test_negotiation.py
+++ b/tests/test_negotiation.py
@@ -30,70 +30,68 @@ class NoCharsetSpecifiedRenderer(BaseRenderer):
media_type = 'my/media'
-class TestAcceptedMediaType(TestCase):
- def setUp(self):
- self.renderers = [MockJSONRenderer(), MockHTMLRenderer(), MockOpenAPIRenderer()]
- self.negotiator = DefaultContentNegotiation()
+def setUp(self):
+ self.renderers = [MockJSONRenderer(), MockHTMLRenderer(), MockOpenAPIRenderer()]
+ self.negotiator = DefaultContentNegotiation()
def select_renderer(self, request):
- return self.negotiator.select_renderer(request, self.renderers)
+ return self.negotiator.select_renderer(request, self.renderers)
def test_client_without_accept_use_renderer(self):
- request = Request(factory.get('/'))
- accepted_renderer, accepted_media_type = self.select_renderer(request)
- assert accepted_media_type == 'application/json'
+ request = Request(factory.get('/'))
+ accepted_renderer, accepted_media_type = self.select_renderer(request)
+ assert accepted_media_type == 'application/json'
def test_client_underspecifies_accept_use_renderer(self):
- request = Request(factory.get('/', HTTP_ACCEPT='*/*'))
- accepted_renderer, accepted_media_type = self.select_renderer(request)
- assert accepted_media_type == 'application/json'
+ request = Request(factory.get('/', HTTP_ACCEPT='*/*'))
+ accepted_renderer, accepted_media_type = self.select_renderer(request)
+ assert accepted_media_type == 'application/json'
def test_client_overspecifies_accept_use_client(self):
- request = Request(factory.get('/', HTTP_ACCEPT='application/json; indent=8'))
- accepted_renderer, accepted_media_type = self.select_renderer(request)
- assert accepted_media_type == 'application/json; indent=8'
+ request = Request(factory.get('/', HTTP_ACCEPT='application/json; indent=8'))
+ accepted_renderer, accepted_media_type = self.select_renderer(request)
+ assert accepted_media_type == 'application/json; indent=8'
def test_client_specifies_parameter(self):
- request = Request(factory.get('/', HTTP_ACCEPT='application/openapi+json;version=2.0'))
- accepted_renderer, accepted_media_type = self.select_renderer(request)
- assert accepted_media_type == 'application/openapi+json;version=2.0'
- assert accepted_renderer.format == 'swagger'
+ request = Request(factory.get('/', HTTP_ACCEPT='application/openapi+json;version=2.0'))
+ accepted_renderer, accepted_media_type = self.select_renderer(request)
+ assert accepted_media_type == 'application/openapi+json;version=2.0'
+ assert accepted_renderer.format == 'swagger'
def test_match_is_false_if_main_types_not_match(self):
- mediatype = _MediaType('test_1')
- anoter_mediatype = _MediaType('test_2')
- assert mediatype.match(anoter_mediatype) is False
+ mediatype = _MediaType('test_1')
+ anoter_mediatype = _MediaType('test_2')
+ assert mediatype.match(anoter_mediatype) is False
def test_mediatype_match_is_false_if_keys_not_match(self):
- mediatype = _MediaType(';test_param=foo')
- another_mediatype = _MediaType(';test_param=bar')
- assert mediatype.match(another_mediatype) is False
+ mediatype = _MediaType(';test_param=foo')
+ another_mediatype = _MediaType(';test_param=bar')
+ assert mediatype.match(another_mediatype) is False
def test_mediatype_precedence_with_wildcard_subtype(self):
- mediatype = _MediaType('test/*')
- assert mediatype.precedence == 1
+ mediatype = _MediaType('test/*')
+ assert mediatype.precedence == 1
def test_mediatype_string_representation(self):
- mediatype = _MediaType('test/*; foo=bar')
- assert str(mediatype) == 'test/*; foo=bar'
+ mediatype = _MediaType('test/*; foo=bar')
+ assert str(mediatype) == 'test/*; foo=bar'
def test_raise_error_if_no_suitable_renderers_found(self):
- class MockRenderer:
- format = 'xml'
+ class MockRenderer:
+ format = 'xml'
renderers = [MockRenderer()]
- with pytest.raises(Http404):
- self.negotiator.filter_renderers(renderers, format='json')
+withpytest.raises(Http404):
+ self.negotiator.filter_renderers(renderers, format='json')
-class BaseContentNegotiationTests(TestCase):
- def setUp(self):
- self.negotiator = BaseContentNegotiation()
+def setUp(self):
+ self.negotiator = BaseContentNegotiation()
def test_raise_error_for_abstract_select_parser_method(self):
- with pytest.raises(NotImplementedError):
- self.negotiator.select_parser(None, None)
+ with pytest.raises(NotImplementedError):
+ self.negotiator.select_parser(None, None)
def test_raise_error_for_abstract_select_renderer_method(self):
- with pytest.raises(NotImplementedError):
- self.negotiator.select_renderer(None, None)
+ with pytest.raises(NotImplementedError):
+ self.negotiator.select_renderer(None, None)
diff --git a/tests/test_one_to_one_with_inheritance.py b/tests/test_one_to_one_with_inheritance.py
index 40793d7ca..7f67fbb7c 100644
--- a/tests/test_one_to_one_with_inheritance.py
+++ b/tests/test_one_to_one_with_inheritance.py
@@ -28,14 +28,12 @@ class ChildAssociatedModelSerializer(serializers.ModelSerializer):
# Tests
-class InheritedModelSerializationTests(TestCase):
- def test_multitable_inherited_model_fields_as_expected(self):
- """
+def test_multitable_inherited_model_fields_as_expected(self):
+ """
Assert that the parent pointer field is not included in the fields
serialized fields
"""
- child = ChildModel(name1='parent name', name2='child name')
- serializer = DerivedModelSerializer(child)
- self.assertEqual(set(serializer.data),
- {'name1', 'name2', 'id', 'childassociatedmodel'})
+ child = ChildModel(name1='parent name', name2='child name')
+ serializer = DerivedModelSerializer(child)
+ assert set(serializer.data) == {'name1', 'name2', 'id', 'childassociatedmodel'}
diff --git a/tests/test_parsers.py b/tests/test_parsers.py
index dcd62fac9..8040c160e 100644
--- a/tests/test_parsers.py
+++ b/tests/test_parsers.py
@@ -22,157 +22,138 @@ class Form(forms.Form):
field2 = forms.CharField()
-class TestFormParser(TestCase):
- def setUp(self):
- self.string = "field1=abc&field2=defghijk"
+def setUp(self):
+ self.string = "field1=abc&field2=defghijk"
def test_parse(self):
- """ Make sure the `QueryDict` works OK """
- parser = FormParser()
-
- stream = io.StringIO(self.string)
- data = parser.parse(stream)
-
- assert Form(data).is_valid() is True
+ """ Make sure the `QueryDict` works OK """
+ parser = FormParser()
+ stream = io.StringIO(self.string)
+ data = parser.parse(stream)
+ assert Form(data).is_valid() is True
-class TestFileUploadParser(TestCase):
- def setUp(self):
- class MockRequest:
- pass
+def setUp(self):
+ class MockRequest:
+ pass
self.stream = io.BytesIO(b"Test text file")
- request = MockRequest()
- request.upload_handlers = (MemoryFileUploadHandler(),)
- request.META = {
- 'HTTP_CONTENT_DISPOSITION': 'Content-Disposition: inline; filename=file.txt',
- 'HTTP_CONTENT_LENGTH': 14,
- }
- self.parser_context = {'request': request, 'kwargs': {}}
+request=MockRequest()
+request.upload_handlers=(MemoryFileUploadHandler(),)
+request.META={'HTTP_CONTENT_DISPOSITION':'Content-Disposition: inline; filename=file.txt','HTTP_CONTENT_LENGTH':14,}
+self.parser_context={'request':request,'kwargs':{}}
def test_parse(self):
- """
+ """
Parse raw file upload.
"""
- parser = FileUploadParser()
- self.stream.seek(0)
- data_and_files = parser.parse(self.stream, None, self.parser_context)
- file_obj = data_and_files.files['file']
- assert file_obj.size == 14
+ parser = FileUploadParser()
+ self.stream.seek(0)
+ data_and_files = parser.parse(self.stream, None, self.parser_context)
+ file_obj = data_and_files.files['file']
+ assert file_obj.size == 14
def test_parse_missing_filename(self):
- """
+ """
Parse raw file upload when filename is missing.
"""
- parser = FileUploadParser()
- self.stream.seek(0)
- self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = ''
- with pytest.raises(ParseError) as excinfo:
- parser.parse(self.stream, None, self.parser_context)
+ parser = FileUploadParser()
+ self.stream.seek(0)
+ self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = ''
+ with pytest.raises(ParseError) as excinfo:
+ parser.parse(self.stream, None, self.parser_context)
assert str(excinfo.value) == 'Missing filename. Request should include a Content-Disposition header with a filename parameter.'
def test_parse_missing_filename_multiple_upload_handlers(self):
- """
+ """
Parse raw file upload with multiple handlers when filename is missing.
Regression test for #2109.
"""
- parser = FileUploadParser()
- self.stream.seek(0)
- self.parser_context['request'].upload_handlers = (
- MemoryFileUploadHandler(),
- MemoryFileUploadHandler()
- )
- self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = ''
- with pytest.raises(ParseError) as excinfo:
- parser.parse(self.stream, None, self.parser_context)
+ parser = FileUploadParser()
+ self.stream.seek(0)
+ self.parser_context['request'].upload_handlers = ( MemoryFileUploadHandler(), MemoryFileUploadHandler() )
+ self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = ''
+ with pytest.raises(ParseError) as excinfo:
+ parser.parse(self.stream, None, self.parser_context)
assert str(excinfo.value) == 'Missing filename. Request should include a Content-Disposition header with a filename parameter.'
def test_parse_missing_filename_large_file(self):
- """
+ """
Parse raw file upload when filename is missing with TemporaryFileUploadHandler.
"""
- parser = FileUploadParser()
- self.stream.seek(0)
- self.parser_context['request'].upload_handlers = (
- TemporaryFileUploadHandler(),
- )
- self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = ''
- with pytest.raises(ParseError) as excinfo:
- parser.parse(self.stream, None, self.parser_context)
+ parser = FileUploadParser()
+ self.stream.seek(0)
+ self.parser_context['request'].upload_handlers = ( TemporaryFileUploadHandler(), )
+ self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = ''
+ with pytest.raises(ParseError) as excinfo:
+ parser.parse(self.stream, None, self.parser_context)
assert str(excinfo.value) == 'Missing filename. Request should include a Content-Disposition header with a filename parameter.'
def test_get_filename(self):
- parser = FileUploadParser()
- filename = parser.get_filename(self.stream, None, self.parser_context)
- assert filename == 'file.txt'
+ parser = FileUploadParser()
+ filename = parser.get_filename(self.stream, None, self.parser_context)
+ assert filename == 'file.txt'
def test_get_encoded_filename(self):
- parser = FileUploadParser()
-
- self.__replace_content_disposition('inline; filename*=utf-8\'\'ÀĥƦ.txt')
- filename = parser.get_filename(self.stream, None, self.parser_context)
- assert filename == 'ÀĥƦ.txt'
-
- self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'\'ÀĥƦ.txt')
- filename = parser.get_filename(self.stream, None, self.parser_context)
- assert filename == 'ÀĥƦ.txt'
-
- self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'en-us\'ÀĥƦ.txt')
- filename = parser.get_filename(self.stream, None, self.parser_context)
- assert filename == 'ÀĥƦ.txt'
+ parser = FileUploadParser()
+ self.__replace_content_disposition('inline; filename*=utf-8\'\'ÀĥƦ.txt')
+ filename = parser.get_filename(self.stream, None, self.parser_context)
+ assert filename == 'ÀĥƦ.txt'
+ self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'\'ÀĥƦ.txt')
+ filename = parser.get_filename(self.stream, None, self.parser_context)
+ assert filename == 'ÀĥƦ.txt'
+ self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'en-us\'ÀĥƦ.txt')
+ filename = parser.get_filename(self.stream, None, self.parser_context)
+ assert filename == 'ÀĥƦ.txt'
def __replace_content_disposition(self, disposition):
- self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = disposition
+ self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = disposition
-class TestJSONParser(TestCase):
- def bytes(self, value):
- return io.BytesIO(value.encode())
+def bytes(self, value):
+ return io.BytesIO(value.encode())
def test_float_strictness(self):
- parser = JSONParser()
-
- # Default to strict
- for value in ['Infinity', '-Infinity', 'NaN']:
- with pytest.raises(ParseError):
- parser.parse(self.bytes(value))
+ parser = JSONParser()
+ for value in ['Infinity', '-Infinity', 'NaN']:
+ with pytest.raises(ParseError):
+ parser.parse(self.bytes(value))
parser.strict = False
- assert parser.parse(self.bytes('Infinity')) == float('inf')
- assert parser.parse(self.bytes('-Infinity')) == float('-inf')
- assert math.isnan(parser.parse(self.bytes('NaN')))
+assertparser.parse(self.bytes('Infinity'))==float('inf')
+assertparser.parse(self.bytes('-Infinity'))==float('-inf')
+assertmath.isnan(parser.parse(self.bytes('NaN')))
-class TestPOSTAccessed(TestCase):
- def setUp(self):
- self.factory = APIRequestFactory()
+def setUp(self):
+ self.factory = APIRequestFactory()
def test_post_accessed_in_post_method(self):
- django_request = self.factory.post('/', {'foo': 'bar'})
- request = Request(django_request, parsers=[FormParser(), MultiPartParser()])
- django_request.POST
- assert request.POST == {'foo': ['bar']}
- assert request.data == {'foo': ['bar']}
+ django_request = self.factory.post('/', {'foo': 'bar'})
+ request = Request(django_request, parsers=[FormParser(), MultiPartParser()])
+ django_request.POST
+ assert request.POST == {'foo': ['bar']}
+ assert request.data == {'foo': ['bar']}
def test_post_accessed_in_post_method_with_json_parser(self):
- django_request = self.factory.post('/', {'foo': 'bar'})
- request = Request(django_request, parsers=[JSONParser()])
- django_request.POST
- assert request.POST == {}
- assert request.data == {}
+ django_request = self.factory.post('/', {'foo': 'bar'})
+ request = Request(django_request, parsers=[JSONParser()])
+ django_request.POST
+ assert request.POST == {}
+ assert request.data == {}
def test_post_accessed_in_put_method(self):
- django_request = self.factory.put('/', {'foo': 'bar'})
- request = Request(django_request, parsers=[FormParser(), MultiPartParser()])
- django_request.POST
- assert request.POST == {'foo': ['bar']}
- assert request.data == {'foo': ['bar']}
+ django_request = self.factory.put('/', {'foo': 'bar'})
+ request = Request(django_request, parsers=[FormParser(), MultiPartParser()])
+ django_request.POST
+ assert request.POST == {'foo': ['bar']}
+ assert request.data == {'foo': ['bar']}
def test_request_read_before_parsing(self):
- django_request = self.factory.put('/', {'foo': 'bar'})
- request = Request(django_request, parsers=[FormParser(), MultiPartParser()])
- django_request.read()
+ django_request = self.factory.put('/', {'foo': 'bar'})
+ request = Request(django_request, parsers=[FormParser(), MultiPartParser()])
+ django_request.read()
+ with pytest.raises(RawPostDataException):
+ request.POST
with pytest.raises(RawPostDataException):
- request.POST
- with pytest.raises(RawPostDataException):
- request.POST
- request.data
+ request.POST
+ request.data
diff --git a/tests/test_permissions.py b/tests/test_permissions.py
index 9c9300694..92d1d4c4c 100644
--- a/tests/test_permissions.py
+++ b/tests/test_permissions.py
@@ -72,177 +72,136 @@ def basic_auth_header(username, password):
return 'Basic %s' % base64_credentials
-class ModelPermissionsIntegrationTests(TestCase):
- def setUp(self):
- User.objects.create_user('disallowed', 'disallowed@example.com', 'password')
- user = User.objects.create_user('permitted', 'permitted@example.com', 'password')
- user.user_permissions.set([
- Permission.objects.get(codename='add_basicmodel'),
- Permission.objects.get(codename='change_basicmodel'),
- Permission.objects.get(codename='delete_basicmodel')
- ])
-
- user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password')
- user.user_permissions.set([
- Permission.objects.get(codename='change_basicmodel'),
- ])
-
- self.permitted_credentials = basic_auth_header('permitted', 'password')
- self.disallowed_credentials = basic_auth_header('disallowed', 'password')
- self.updateonly_credentials = basic_auth_header('updateonly', 'password')
-
- BasicModel(text='foo').save()
+def setUp(self):
+ User.objects.create_user('disallowed', 'disallowed@example.com', 'password')
+ user = User.objects.create_user('permitted', 'permitted@example.com', 'password')
+ user.user_permissions.set([ Permission.objects.get(codename='add_basicmodel'), Permission.objects.get(codename='change_basicmodel'), Permission.objects.get(codename='delete_basicmodel') ])
+ user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password')
+ user.user_permissions.set([ Permission.objects.get(codename='change_basicmodel'), ])
+ self.permitted_credentials = basic_auth_header('permitted', 'password')
+ self.disallowed_credentials = basic_auth_header('disallowed', 'password')
+ self.updateonly_credentials = basic_auth_header('updateonly', 'password')
+ BasicModel(text='foo').save()
def test_has_create_permissions(self):
- request = factory.post('/', {'text': 'foobar'}, format='json',
- HTTP_AUTHORIZATION=self.permitted_credentials)
- response = root_view(request, pk=1)
- self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ request = factory.post('/', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.permitted_credentials)
+ response = root_view(request, pk=1)
+ assert response.status_code == status.HTTP_201_CREATED
def test_api_root_view_discard_default_django_model_permission(self):
- """
+ """
We check that DEFAULT_PERMISSION_CLASSES can
apply to APIRoot view. More specifically we check expected behavior of
``_ignore_model_permissions`` attribute support.
"""
- request = factory.get('/', format='json',
- HTTP_AUTHORIZATION=self.permitted_credentials)
- request.resolver_match = ResolverMatch('get', (), {})
- response = api_root_view(request)
- self.assertEqual(response.status_code, status.HTTP_200_OK)
+ request = factory.get('/', format='json', HTTP_AUTHORIZATION=self.permitted_credentials)
+ request.resolver_match = ResolverMatch('get', (), {})
+ response = api_root_view(request)
+ assert response.status_code == status.HTTP_200_OK
def test_get_queryset_has_create_permissions(self):
- request = factory.post('/', {'text': 'foobar'}, format='json',
- HTTP_AUTHORIZATION=self.permitted_credentials)
- response = get_queryset_list_view(request, pk=1)
- self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ request = factory.post('/', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.permitted_credentials)
+ response = get_queryset_list_view(request, pk=1)
+ assert response.status_code == status.HTTP_201_CREATED
def test_has_put_permissions(self):
- request = factory.put('/1', {'text': 'foobar'}, format='json',
- HTTP_AUTHORIZATION=self.permitted_credentials)
- response = instance_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_200_OK)
+ request = factory.put('/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.permitted_credentials)
+ response = instance_view(request, pk='1')
+ assert response.status_code == status.HTTP_200_OK
def test_has_delete_permissions(self):
- request = factory.delete('/1', HTTP_AUTHORIZATION=self.permitted_credentials)
- response = instance_view(request, pk=1)
- self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
+ request = factory.delete('/1', HTTP_AUTHORIZATION=self.permitted_credentials)
+ response = instance_view(request, pk=1)
+ assert response.status_code == status.HTTP_204_NO_CONTENT
def test_does_not_have_create_permissions(self):
- request = factory.post('/', {'text': 'foobar'}, format='json',
- HTTP_AUTHORIZATION=self.disallowed_credentials)
- response = root_view(request, pk=1)
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+ request = factory.post('/', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.disallowed_credentials)
+ response = root_view(request, pk=1)
+ assert response.status_code == status.HTTP_403_FORBIDDEN
def test_does_not_have_put_permissions(self):
- request = factory.put('/1', {'text': 'foobar'}, format='json',
- HTTP_AUTHORIZATION=self.disallowed_credentials)
- response = instance_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+ request = factory.put('/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.disallowed_credentials)
+ response = instance_view(request, pk='1')
+ assert response.status_code == status.HTTP_403_FORBIDDEN
def test_does_not_have_delete_permissions(self):
- request = factory.delete('/1', HTTP_AUTHORIZATION=self.disallowed_credentials)
- response = instance_view(request, pk=1)
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+ request = factory.delete('/1', HTTP_AUTHORIZATION=self.disallowed_credentials)
+ response = instance_view(request, pk=1)
+ assert response.status_code == status.HTTP_403_FORBIDDEN
def test_options_permitted(self):
- request = factory.options(
- '/',
- HTTP_AUTHORIZATION=self.permitted_credentials
- )
- response = root_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertIn('actions', response.data)
- self.assertEqual(list(response.data['actions']), ['POST'])
-
- request = factory.options(
- '/1',
- HTTP_AUTHORIZATION=self.permitted_credentials
- )
- response = instance_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertIn('actions', response.data)
- self.assertEqual(list(response.data['actions']), ['PUT'])
+ request = factory.options( '/', HTTP_AUTHORIZATION=self.permitted_credentials )
+ response = root_view(request, pk='1')
+ assert response.status_code == status.HTTP_200_OK
+ assert 'actions' in response.data
+ assert list(response.data['actions']) == ['POST']
+ request = factory.options( '/1', HTTP_AUTHORIZATION=self.permitted_credentials )
+ response = instance_view(request, pk='1')
+ assert response.status_code == status.HTTP_200_OK
+ assert 'actions' in response.data
+ assert list(response.data['actions']) == ['PUT']
def test_options_disallowed(self):
- request = factory.options(
- '/',
- HTTP_AUTHORIZATION=self.disallowed_credentials
- )
- response = root_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertNotIn('actions', response.data)
-
- request = factory.options(
- '/1',
- HTTP_AUTHORIZATION=self.disallowed_credentials
- )
- response = instance_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertNotIn('actions', response.data)
+ request = factory.options( '/', HTTP_AUTHORIZATION=self.disallowed_credentials )
+ response = root_view(request, pk='1')
+ assert response.status_code == status.HTTP_200_OK
+ assert 'actions' not in response.data
+ request = factory.options( '/1', HTTP_AUTHORIZATION=self.disallowed_credentials )
+ response = instance_view(request, pk='1')
+ assert response.status_code == status.HTTP_200_OK
+ assert 'actions' not in response.data
def test_options_updateonly(self):
- request = factory.options(
- '/',
- HTTP_AUTHORIZATION=self.updateonly_credentials
- )
- response = root_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertNotIn('actions', response.data)
-
- request = factory.options(
- '/1',
- HTTP_AUTHORIZATION=self.updateonly_credentials
- )
- response = instance_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertIn('actions', response.data)
- self.assertEqual(list(response.data['actions']), ['PUT'])
+ request = factory.options( '/', HTTP_AUTHORIZATION=self.updateonly_credentials )
+ response = root_view(request, pk='1')
+ assert response.status_code == status.HTTP_200_OK
+ assert 'actions' not in response.data
+ request = factory.options( '/1', HTTP_AUTHORIZATION=self.updateonly_credentials )
+ response = instance_view(request, pk='1')
+ assert response.status_code == status.HTTP_200_OK
+ assert 'actions' in response.data
+ assert list(response.data['actions']) == ['PUT']
def test_empty_view_does_not_assert(self):
- request = factory.get('/1', HTTP_AUTHORIZATION=self.permitted_credentials)
- response = empty_list_view(request, pk=1)
- self.assertEqual(response.status_code, status.HTTP_200_OK)
+ request = factory.get('/1', HTTP_AUTHORIZATION=self.permitted_credentials)
+ response = empty_list_view(request, pk=1)
+ assert response.status_code == status.HTTP_200_OK
def test_calling_method_not_allowed(self):
- request = factory.generic('METHOD_NOT_ALLOWED', '/', HTTP_AUTHORIZATION=self.permitted_credentials)
- response = root_view(request)
- self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
-
- request = factory.generic('METHOD_NOT_ALLOWED', '/1', HTTP_AUTHORIZATION=self.permitted_credentials)
- response = instance_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
+ request = factory.generic('METHOD_NOT_ALLOWED', '/', HTTP_AUTHORIZATION=self.permitted_credentials)
+ response = root_view(request)
+ assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
+ request = factory.generic('METHOD_NOT_ALLOWED', '/1', HTTP_AUTHORIZATION=self.permitted_credentials)
+ response = instance_view(request, pk='1')
+ assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
def test_check_auth_before_queryset_call(self):
- class View(RootView):
- def get_queryset(_):
- self.fail('should not reach due to auth check')
+ class View(RootView):
+ def get_queryset(_):
+ self.fail('should not reach due to auth check')
view = View.as_view()
-
- request = factory.get('/', HTTP_AUTHORIZATION='')
- response = view(request)
- self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
+request=factory.get('/',HTTP_AUTHORIZATION='')
+response=view(request)
+assertresponse.status_code==status.HTTP_401_UNAUTHORIZED
def test_queryset_assertions(self):
- class View(views.APIView):
- authentication_classes = [authentication.BasicAuthentication]
- permission_classes = [permissions.DjangoModelPermissions]
+ class View(views.APIView):
+ authentication_classes = [authentication.BasicAuthentication]
+ permission_classes = [permissions.DjangoModelPermissions]
view = View.as_view()
-
- request = factory.get('/', HTTP_AUTHORIZATION=self.permitted_credentials)
- msg = 'Cannot apply DjangoModelPermissions on a view that does not set `.queryset` or have a `.get_queryset()` method.'
- with self.assertRaisesMessage(AssertionError, msg):
- view(request)
+request=factory.get('/',HTTP_AUTHORIZATION=self.permitted_credentials)
+msg='Cannot apply DjangoModelPermissions on a view that does not set `.queryset` or have a `.get_queryset()` method.'
+withself.assertRaisesMessage(AssertionError,msg):
+ view(request)
# Faulty `get_queryset()` methods should trigger the above "view does not have a queryset" assertion.
class View(RootView):
- def get_queryset(self):
- return None
+ def get_queryset(self):
+ return None
view = View.as_view()
-
- request = factory.get('/', HTTP_AUTHORIZATION=self.permitted_credentials)
- with self.assertRaisesMessage(AssertionError, 'View.get_queryset() returned None'):
- view(request)
+request=factory.get('/',HTTP_AUTHORIZATION=self.permitted_credentials)
+withself.assertRaisesMessage(AssertionError,'View.get_queryset() returned None'):
+ view(request)
class BasicPermModel(models.Model):
@@ -310,149 +269,117 @@ get_queryset_object_permissions_view = GetQuerysetObjectPermissionInstanceView.a
@unittest.skipUnless(is_guardian_installed(), 'django-guardian not installed')
-class ObjectPermissionsIntegrationTests(TestCase):
- """
+"""
Integration tests for the object level permissions API.
"""
- def setUp(self):
- from guardian.shortcuts import assign_perm
-
- # create users
- create = User.objects.create_user
- users = {
- 'fullaccess': create('fullaccess', 'fullaccess@example.com', 'password'),
- 'readonly': create('readonly', 'readonly@example.com', 'password'),
- 'writeonly': create('writeonly', 'writeonly@example.com', 'password'),
- 'deleteonly': create('deleteonly', 'deleteonly@example.com', 'password'),
- }
-
- # give everyone model level permissions, as we are not testing those
- everyone = Group.objects.create(name='everyone')
- model_name = BasicPermModel._meta.model_name
- app_label = BasicPermModel._meta.app_label
- f = '{}_{}'.format
- perms = {
- 'view': f('view', model_name),
- 'change': f('change', model_name),
- 'delete': f('delete', model_name)
- }
- for perm in perms.values():
- perm = '{}.{}'.format(app_label, perm)
- assign_perm(perm, everyone)
+defsetUp(self):
+ from guardian.shortcuts import assign_perm
+ create = User.objects.create_user
+ users = { 'fullaccess': create('fullaccess', 'fullaccess@example.com', 'password'), 'readonly': create('readonly', 'readonly@example.com', 'password'), 'writeonly': create('writeonly', 'writeonly@example.com', 'password'), 'deleteonly': create('deleteonly', 'deleteonly@example.com', 'password'), }
+ everyone = Group.objects.create(name='everyone')
+ model_name = BasicPermModel._meta.model_name
+ app_label = BasicPermModel._meta.app_label
+ f = '{}_{}'.format
+ perms = { 'view': f('view', model_name), 'change': f('change', model_name), 'delete': f('delete', model_name) }
+ for perm in perms.values():
+ perm = '{}.{}'.format(app_label, perm)
+ assign_perm(perm, everyone)
everyone.user_set.add(*users.values())
-
- # appropriate object level permissions
- readers = Group.objects.create(name='readers')
- writers = Group.objects.create(name='writers')
- deleters = Group.objects.create(name='deleters')
-
- model = BasicPermModel.objects.create(text='foo')
-
- assign_perm(perms['view'], readers, model)
- assign_perm(perms['change'], writers, model)
- assign_perm(perms['delete'], deleters, model)
-
- readers.user_set.add(users['fullaccess'], users['readonly'])
- writers.user_set.add(users['fullaccess'], users['writeonly'])
- deleters.user_set.add(users['fullaccess'], users['deleteonly'])
-
- self.credentials = {}
- for user in users.values():
- self.credentials[user.username] = basic_auth_header(user.username, 'password')
+readers=Group.objects.create(name='readers')
+writers=Group.objects.create(name='writers')
+deleters=Group.objects.create(name='deleters')
+model=BasicPermModel.objects.create(text='foo')
+assign_perm(perms['view'],readers,model)
+assign_perm(perms['change'],writers,model)
+assign_perm(perms['delete'],deleters,model)
+readers.user_set.add(users['fullaccess'],users['readonly'])
+writers.user_set.add(users['fullaccess'],users['writeonly'])
+deleters.user_set.add(users['fullaccess'],users['deleteonly'])
+self.credentials={}
+foruserinusers.values():
+ self.credentials[user.username] = basic_auth_header(user.username, 'password')
# Delete
def test_can_delete_permissions(self):
- request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['deleteonly'])
- response = object_permissions_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
+ request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['deleteonly'])
+ response = object_permissions_view(request, pk='1')
+ assert response.status_code == status.HTTP_204_NO_CONTENT
def test_cannot_delete_permissions(self):
- request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
- response = object_permissions_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+ request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
+ response = object_permissions_view(request, pk='1')
+ assert response.status_code == status.HTTP_403_FORBIDDEN
# Update
def test_can_update_permissions(self):
- request = factory.patch(
- '/1', {'text': 'foobar'}, format='json',
- HTTP_AUTHORIZATION=self.credentials['writeonly']
- )
- response = object_permissions_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data.get('text'), 'foobar')
+ request = factory.patch( '/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.credentials['writeonly'] )
+ response = object_permissions_view(request, pk='1')
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data.get('text') == 'foobar'
def test_cannot_update_permissions(self):
- request = factory.patch(
- '/1', {'text': 'foobar'}, format='json',
- HTTP_AUTHORIZATION=self.credentials['deleteonly']
- )
- response = object_permissions_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+ request = factory.patch( '/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.credentials['deleteonly'] )
+ response = object_permissions_view(request, pk='1')
+ assert response.status_code == status.HTTP_404_NOT_FOUND
def test_cannot_update_permissions_non_existing(self):
- request = factory.patch(
- '/999', {'text': 'foobar'}, format='json',
- HTTP_AUTHORIZATION=self.credentials['deleteonly']
- )
- response = object_permissions_view(request, pk='999')
- self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+ request = factory.patch( '/999', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.credentials['deleteonly'] )
+ response = object_permissions_view(request, pk='999')
+ assert response.status_code == status.HTTP_404_NOT_FOUND
# Read
def test_can_read_permissions(self):
- request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
- response = object_permissions_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_200_OK)
+ request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
+ response = object_permissions_view(request, pk='1')
+ assert response.status_code == status.HTTP_200_OK
def test_cannot_read_permissions(self):
- request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['writeonly'])
- response = object_permissions_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+ request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['writeonly'])
+ response = object_permissions_view(request, pk='1')
+ assert response.status_code == status.HTTP_404_NOT_FOUND
def test_can_read_get_queryset_permissions(self):
- """
+ """
same as ``test_can_read_permissions`` but with a view
that rely on ``.get_queryset()`` instead of ``.queryset``.
"""
- request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
- response = get_queryset_object_permissions_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_200_OK)
+ request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
+ response = get_queryset_object_permissions_view(request, pk='1')
+ assert response.status_code == status.HTTP_200_OK
# Read list
def test_django_object_permissions_filter_deprecated(self):
- with warnings.catch_warnings(record=True) as w:
- warnings.simplefilter("always")
- DjangoObjectPermissionsFilter()
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter("always")
+ DjangoObjectPermissionsFilter()
- message = ("`DjangoObjectPermissionsFilter` has been deprecated and moved "
- "to the 3rd-party django-rest-framework-guardian package.")
- self.assertEqual(len(w), 1)
- self.assertIs(w[-1].category, RemovedInDRF310Warning)
- self.assertEqual(str(w[-1].message), message)
+ message = ("`DjangoObjectPermissionsFilter` has been deprecated and moved ""to the 3rd-party django-rest-framework-guardian package.")
+assertlen(w)==1
+assertw[-1].categoryisRemovedInDRF310Warning
+assertstr(w[-1].message)==message
def test_can_read_list_permissions(self):
- request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['readonly'])
- object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,)
- # TODO: remove in version 3.10
- with warnings.catch_warnings(record=True):
- warnings.simplefilter("always")
- response = object_permissions_list_view(request)
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data[0].get('id'), 1)
+ request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['readonly'])
+ object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,)
+ with warnings.catch_warnings(record=True):
+ warnings.simplefilter("always")
+ response = object_permissions_list_view(request)
+ assert response.status_code== status.HTTP_200_OK
+assertresponse.data[0].get('id')==1
def test_cannot_read_list_permissions(self):
- request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['writeonly'])
- object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,)
- # TODO: remove in version 3.10
- with warnings.catch_warnings(record=True):
- warnings.simplefilter("always")
- response = object_permissions_list_view(request)
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertListEqual(response.data, [])
+ request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['writeonly'])
+ object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,)
+ with warnings.catch_warnings(record=True):
+ warnings.simplefilter("always")
+ response = object_permissions_list_view(request)
+ assert response.status_code== status.HTTP_200_OK
+assertresponse.data==[]
def test_cannot_method_not_allowed(self):
- request = factory.generic('METHOD_NOT_ALLOWED', '/', HTTP_AUTHORIZATION=self.credentials['readonly'])
- response = object_permissions_list_view(request)
- self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
+ request = factory.generic('METHOD_NOT_ALLOWED', '/', HTTP_AUTHORIZATION=self.credentials['readonly'])
+ response = object_permissions_list_view(request)
+ assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
class BasicPerm(permissions.BasePermission):
@@ -507,203 +434,176 @@ denied_view_with_detail = DeniedViewWithDetail.as_view()
denied_object_view = DeniedObjectView.as_view()
denied_object_view_with_detail = DeniedObjectViewWithDetail.as_view()
-
-
-class CustomPermissionsTests(TestCase):
- def setUp(self):
- BasicModel(text='foo').save()
- User.objects.create_user('username', 'username@example.com', 'password')
- credentials = basic_auth_header('username', 'password')
- self.request = factory.get('/1', format='json', HTTP_AUTHORIZATION=credentials)
- self.custom_message = 'Custom: You cannot access this resource'
+def setUp(self):
+ BasicModel(text='foo').save()
+ User.objects.create_user('username', 'username@example.com', 'password')
+ credentials = basic_auth_header('username', 'password')
+ self.request = factory.get('/1', format='json', HTTP_AUTHORIZATION=credentials)
+ self.custom_message = 'Custom: You cannot access this resource'
def test_permission_denied(self):
- response = denied_view(self.request, pk=1)
- detail = response.data.get('detail')
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
- self.assertNotEqual(detail, self.custom_message)
+ response = denied_view(self.request, pk=1)
+ detail = response.data.get('detail')
+ assert response.status_code == status.HTTP_403_FORBIDDEN
+ assert detail != self.custom_message
def test_permission_denied_with_custom_detail(self):
- response = denied_view_with_detail(self.request, pk=1)
- detail = response.data.get('detail')
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
- self.assertEqual(detail, self.custom_message)
+ response = denied_view_with_detail(self.request, pk=1)
+ detail = response.data.get('detail')
+ assert response.status_code == status.HTTP_403_FORBIDDEN
+ assert detail == self.custom_message
def test_permission_denied_for_object(self):
- response = denied_object_view(self.request, pk=1)
- detail = response.data.get('detail')
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
- self.assertNotEqual(detail, self.custom_message)
+ response = denied_object_view(self.request, pk=1)
+ detail = response.data.get('detail')
+ assert response.status_code == status.HTTP_403_FORBIDDEN
+ assert detail != self.custom_message
def test_permission_denied_for_object_with_custom_detail(self):
- response = denied_object_view_with_detail(self.request, pk=1)
- detail = response.data.get('detail')
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
- self.assertEqual(detail, self.custom_message)
+ response = denied_object_view_with_detail(self.request, pk=1)
+ detail = response.data.get('detail')
+ assert response.status_code == status.HTTP_403_FORBIDDEN
+ assert detail == self.custom_message
-class PermissionsCompositionTests(TestCase):
- def setUp(self):
- self.username = 'john'
- self.email = 'lennon@thebeatles.com'
- self.password = 'password'
- self.user = User.objects.create_user(
- self.username,
- self.email,
- self.password
- )
- self.client.login(username=self.username, password=self.password)
+def setUp(self):
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user( self.username, self.email, self.password )
+ self.client.login(username=self.username, password=self.password)
def test_and_false(self):
- request = factory.get('/1', format='json')
- request.user = AnonymousUser()
- composed_perm = permissions.IsAuthenticated & permissions.AllowAny
- assert composed_perm().has_permission(request, None) is False
+ request = factory.get('/1', format='json')
+ request.user = AnonymousUser()
+ composed_perm = permissions.IsAuthenticated & permissions.AllowAny
+ assert composed_perm().has_permission(request, None) is False
def test_and_true(self):
- request = factory.get('/1', format='json')
- request.user = self.user
- composed_perm = permissions.IsAuthenticated & permissions.AllowAny
- assert composed_perm().has_permission(request, None) is True
+ request = factory.get('/1', format='json')
+ request.user = self.user
+ composed_perm = permissions.IsAuthenticated & permissions.AllowAny
+ assert composed_perm().has_permission(request, None) is True
def test_or_false(self):
- request = factory.get('/1', format='json')
- request.user = AnonymousUser()
- composed_perm = permissions.IsAuthenticated | permissions.AllowAny
- assert composed_perm().has_permission(request, None) is True
+ request = factory.get('/1', format='json')
+ request.user = AnonymousUser()
+ composed_perm = permissions.IsAuthenticated | permissions.AllowAny
+ assert composed_perm().has_permission(request, None) is True
def test_or_true(self):
- request = factory.get('/1', format='json')
- request.user = self.user
- composed_perm = permissions.IsAuthenticated | permissions.AllowAny
- assert composed_perm().has_permission(request, None) is True
+ request = factory.get('/1', format='json')
+ request.user = self.user
+ composed_perm = permissions.IsAuthenticated | permissions.AllowAny
+ assert composed_perm().has_permission(request, None) is True
def test_not_false(self):
- request = factory.get('/1', format='json')
- request.user = AnonymousUser()
- composed_perm = ~permissions.IsAuthenticated
- assert composed_perm().has_permission(request, None) is True
+ request = factory.get('/1', format='json')
+ request.user = AnonymousUser()
+ composed_perm = ~permissions.IsAuthenticated
+ assert composed_perm().has_permission(request, None) is True
def test_not_true(self):
- request = factory.get('/1', format='json')
- request.user = self.user
- composed_perm = ~permissions.AllowAny
- assert composed_perm().has_permission(request, None) is False
+ request = factory.get('/1', format='json')
+ request.user = self.user
+ composed_perm = ~permissions.AllowAny
+ assert composed_perm().has_permission(request, None) is False
def test_several_levels_without_negation(self):
- request = factory.get('/1', format='json')
- request.user = self.user
- composed_perm = (
- permissions.IsAuthenticated &
- permissions.IsAuthenticated &
- permissions.IsAuthenticated &
- permissions.IsAuthenticated
- )
- assert composed_perm().has_permission(request, None) is True
+ request = factory.get('/1', format='json')
+ request.user = self.user
+ composed_perm = ( permissions.IsAuthenticated & permissions.IsAuthenticated & permissions.IsAuthenticated & permissions.IsAuthenticated )
+ assert composed_perm().has_permission(request, None) is True
def test_several_levels_and_precedence_with_negation(self):
- request = factory.get('/1', format='json')
- request.user = self.user
- composed_perm = (
- permissions.IsAuthenticated &
- ~ permissions.IsAdminUser &
- permissions.IsAuthenticated &
- ~(permissions.IsAdminUser & permissions.IsAdminUser)
- )
- assert composed_perm().has_permission(request, None) is True
+ request = factory.get('/1', format='json')
+ request.user = self.user
+ composed_perm = ( permissions.IsAuthenticated & ~ permissions.IsAdminUser & permissions.IsAuthenticated & ~(permissions.IsAdminUser & permissions.IsAdminUser) )
+ assert composed_perm().has_permission(request, None) is True
def test_several_levels_and_precedence(self):
- request = factory.get('/1', format='json')
- request.user = self.user
- composed_perm = (
- permissions.IsAuthenticated &
- permissions.IsAuthenticated |
- permissions.IsAuthenticated &
- permissions.IsAuthenticated
- )
- assert composed_perm().has_permission(request, None) is True
+ request = factory.get('/1', format='json')
+ request.user = self.user
+ composed_perm = ( permissions.IsAuthenticated & permissions.IsAuthenticated | permissions.IsAuthenticated & permissions.IsAuthenticated )
+ assert composed_perm().has_permission(request, None) is True
@pytest.mark.skipif(not PY36, reason="assert_called_once() not available")
- def test_or_lazyness(self):
- request = factory.get('/1', format='json')
- request.user = AnonymousUser()
+deftest_or_lazyness(self):
+ request = factory.get('/1', format='json')
+ request.user = AnonymousUser()
+ with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow:
+ with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
+ composed_perm = (permissions.AllowAny | permissions.IsAuthenticated)
+ hasperm = composed_perm().has_permission(request, None)
+ assert hasperm is True
+ mock_allow.assert_called_once()
+ mock_deny.assert_not_called()
with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow:
- with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
- composed_perm = (permissions.AllowAny | permissions.IsAuthenticated)
- hasperm = composed_perm().has_permission(request, None)
- self.assertIs(hasperm, True)
- mock_allow.assert_called_once()
- mock_deny.assert_not_called()
-
- with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow:
- with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
- composed_perm = (permissions.IsAuthenticated | permissions.AllowAny)
- hasperm = composed_perm().has_permission(request, None)
- self.assertIs(hasperm, True)
- mock_deny.assert_called_once()
- mock_allow.assert_called_once()
+ with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
+ composed_perm = (permissions.IsAuthenticated | permissions.AllowAny)
+ hasperm = composed_perm().has_permission(request, None)
+ assert hasperm is True
+ mock_deny.assert_called_once()
+ mock_allow.assert_called_once()
@pytest.mark.skipif(not PY36, reason="assert_called_once() not available")
- def test_object_or_lazyness(self):
- request = factory.get('/1', format='json')
- request.user = AnonymousUser()
+deftest_object_or_lazyness(self):
+ request = factory.get('/1', format='json')
+ request.user = AnonymousUser()
+ with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow:
+ with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
+ composed_perm = (permissions.AllowAny | permissions.IsAuthenticated)
+ hasperm = composed_perm().has_object_permission(request, None, None)
+ assert hasperm is True
+ mock_allow.assert_called_once()
+ mock_deny.assert_not_called()
with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow:
- with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
- composed_perm = (permissions.AllowAny | permissions.IsAuthenticated)
- hasperm = composed_perm().has_object_permission(request, None, None)
- self.assertIs(hasperm, True)
- mock_allow.assert_called_once()
- mock_deny.assert_not_called()
-
- with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow:
- with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
- composed_perm = (permissions.IsAuthenticated | permissions.AllowAny)
- hasperm = composed_perm().has_object_permission(request, None, None)
- self.assertIs(hasperm, True)
- mock_deny.assert_called_once()
- mock_allow.assert_called_once()
+ with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
+ composed_perm = (permissions.IsAuthenticated | permissions.AllowAny)
+ hasperm = composed_perm().has_object_permission(request, None, None)
+ assert hasperm is True
+ mock_deny.assert_called_once()
+ mock_allow.assert_called_once()
@pytest.mark.skipif(not PY36, reason="assert_called_once() not available")
- def test_and_lazyness(self):
- request = factory.get('/1', format='json')
- request.user = AnonymousUser()
+deftest_and_lazyness(self):
+ request = factory.get('/1', format='json')
+ request.user = AnonymousUser()
+ with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow:
+ with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
+ composed_perm = (permissions.AllowAny & permissions.IsAuthenticated)
+ hasperm = composed_perm().has_permission(request, None)
+ assert hasperm is False
+ mock_allow.assert_called_once()
+ mock_deny.assert_called_once()
with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow:
- with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
- composed_perm = (permissions.AllowAny & permissions.IsAuthenticated)
- hasperm = composed_perm().has_permission(request, None)
- self.assertIs(hasperm, False)
- mock_allow.assert_called_once()
- mock_deny.assert_called_once()
-
- with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow:
- with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
- composed_perm = (permissions.IsAuthenticated & permissions.AllowAny)
- hasperm = composed_perm().has_permission(request, None)
- self.assertIs(hasperm, False)
- mock_allow.assert_not_called()
- mock_deny.assert_called_once()
+ with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
+ composed_perm = (permissions.IsAuthenticated & permissions.AllowAny)
+ hasperm = composed_perm().has_permission(request, None)
+ assert hasperm is False
+ mock_allow.assert_not_called()
+ mock_deny.assert_called_once()
@pytest.mark.skipif(not PY36, reason="assert_called_once() not available")
- def test_object_and_lazyness(self):
- request = factory.get('/1', format='json')
- request.user = AnonymousUser()
+deftest_object_and_lazyness(self):
+ request = factory.get('/1', format='json')
+ request.user = AnonymousUser()
+ with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow:
+ with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
+ composed_perm = (permissions.AllowAny & permissions.IsAuthenticated)
+ hasperm = composed_perm().has_object_permission(request, None, None)
+ assert hasperm is False
+ mock_allow.assert_called_once()
+ mock_deny.assert_called_once()
with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow:
- with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
- composed_perm = (permissions.AllowAny & permissions.IsAuthenticated)
- hasperm = composed_perm().has_object_permission(request, None, None)
- self.assertIs(hasperm, False)
- mock_allow.assert_called_once()
- mock_deny.assert_called_once()
-
- with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow:
- with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
- composed_perm = (permissions.IsAuthenticated & permissions.AllowAny)
- hasperm = composed_perm().has_object_permission(request, None, None)
- self.assertIs(hasperm, False)
- mock_allow.assert_not_called()
- mock_deny.assert_called_once()
+ with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
+ composed_perm = (permissions.IsAuthenticated & permissions.AllowAny)
+ hasperm = composed_perm().has_object_permission(request, None, None)
+ assert hasperm is False
+ mock_allow.assert_not_called()
+ mock_deny.assert_called_once()
diff --git a/tests/test_prefetch_related.py b/tests/test_prefetch_related.py
index b07087c97..e4162d9aa 100644
--- a/tests/test_prefetch_related.py
+++ b/tests/test_prefetch_related.py
@@ -18,41 +18,30 @@ class UserUpdate(generics.UpdateAPIView):
serializer_class = UserSerializer
-class TestPrefetchRelatedUpdates(TestCase):
- def setUp(self):
- self.user = User.objects.create(username='tom', email='tom@example.com')
- self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')]
- self.user.groups.set(self.groups)
+def setUp(self):
+ self.user = User.objects.create(username='tom', email='tom@example.com')
+ self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')]
+ self.user.groups.set(self.groups)
def test_prefetch_related_updates(self):
- view = UserUpdate.as_view()
- pk = self.user.pk
- groups_pk = self.groups[0].pk
- request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json')
- response = view(request, pk=pk)
- assert User.objects.get(pk=pk).groups.count() == 1
- expected = {
- 'id': pk,
- 'username': 'new',
- 'groups': [1],
- 'email': 'tom@example.com'
- }
- assert response.data == expected
+ view = UserUpdate.as_view()
+ pk = self.user.pk
+ groups_pk = self.groups[0].pk
+ request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json')
+ response = view(request, pk=pk)
+ assert User.objects.get(pk=pk).groups.count() == 1
+ expected = { 'id': pk, 'username': 'new', 'groups': [1], 'email': 'tom@example.com' }
+ assert response.data == expected
def test_prefetch_related_excluding_instance_from_original_queryset(self):
- """
+ """
Regression test for https://github.com/encode/django-rest-framework/issues/4661
"""
- view = UserUpdate.as_view()
- pk = self.user.pk
- groups_pk = self.groups[0].pk
- request = factory.put('/', {'username': 'exclude', 'groups': [groups_pk]}, format='json')
- response = view(request, pk=pk)
- assert User.objects.get(pk=pk).groups.count() == 1
- expected = {
- 'id': pk,
- 'username': 'exclude',
- 'groups': [1],
- 'email': 'tom@example.com'
- }
- assert response.data == expected
+ view = UserUpdate.as_view()
+ pk = self.user.pk
+ groups_pk = self.groups[0].pk
+ request = factory.put('/', {'username': 'exclude', 'groups': [groups_pk]}, format='json')
+ response = view(request, pk=pk)
+ assert User.objects.get(pk=pk).groups.count() == 1
+ expected = { 'id': pk, 'username': 'exclude', 'groups': [1], 'email': 'tom@example.com' }
+ assert response.data == expected
diff --git a/tests/test_relations_hyperlink.py b/tests/test_relations_hyperlink.py
index 5ad0e31ff..735e432d1 100644
--- a/tests/test_relations_hyperlink.py
+++ b/tests/test_relations_hyperlink.py
@@ -70,380 +70,268 @@ class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer):
# TODO: Add test that .data cannot be accessed prior to .is_valid
@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink')
-class HyperlinkedManyToManyTests(TestCase):
- def setUp(self):
- for idx in range(1, 4):
- target = ManyToManyTarget(name='target-%d' % idx)
- target.save()
- source = ManyToManySource(name='source-%d' % idx)
- source.save()
- for target in ManyToManyTarget.objects.all():
- source.targets.add(target)
+def setUp(self):
+ for idx in range(1, 4):
+ target = ManyToManyTarget(name='target-%d' % idx)
+ target.save()
+ source = ManyToManySource(name='source-%d' % idx)
+ source.save()
+ for target in ManyToManyTarget.objects.all():
+ source.targets.add(target)
def test_relative_hyperlinks(self):
- queryset = ManyToManySource.objects.all()
- serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': None})
- expected = [
- {'url': '/manytomanysource/1/', 'name': 'source-1', 'targets': ['/manytomanytarget/1/']},
- {'url': '/manytomanysource/2/', 'name': 'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']},
- {'url': '/manytomanysource/3/', 'name': 'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']}
- ]
- with self.assertNumQueries(4):
- assert serializer.data == expected
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': None})
+ expected = [ {'url': '/manytomanysource/1/', 'name': 'source-1', 'targets': ['/manytomanytarget/1/']}, {'url': '/manytomanysource/2/', 'name': 'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']}, {'url': '/manytomanysource/3/', 'name': 'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']} ]
+ with self.assertNumQueries(4):
+ assert serializer.data == expected
def test_many_to_many_retrieve(self):
- queryset = ManyToManySource.objects.all()
- serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']},
- {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
- {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
- ]
- with self.assertNumQueries(4):
- assert serializer.data == expected
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [ {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']}, {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} ]
+ with self.assertNumQueries(4):
+ assert serializer.data == expected
def test_many_to_many_retrieve_prefetch_related(self):
- queryset = ManyToManySource.objects.all().prefetch_related('targets')
- serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
- with self.assertNumQueries(2):
- serializer.data
+ queryset = ManyToManySource.objects.all().prefetch_related('targets')
+ serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
+ with self.assertNumQueries(2):
+ serializer.data
def test_reverse_many_to_many_retrieve(self):
- queryset = ManyToManyTarget.objects.all()
- serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
- {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
- {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}
- ]
- with self.assertNumQueries(4):
- assert serializer.data == expected
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
+ expected = [ {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']} ]
+ with self.assertNumQueries(4):
+ assert serializer.data == expected
def test_many_to_many_update(self):
- data = {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
- instance = ManyToManySource.objects.get(pk=1)
- serializer = ManyToManySourceSerializer(instance, data=data, context={'request': request})
- assert serializer.is_valid()
- serializer.save()
- assert serializer.data == data
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = ManyToManySource.objects.all()
- serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']},
- {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
- {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
- ]
- assert serializer.data == expected
+ data = {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
+ instance = ManyToManySource.objects.get(pk=1)
+ serializer = ManyToManySourceSerializer(instance, data=data, context={'request': request})
+ assert serializer.is_valid()
+ serializer.save()
+ assert serializer.data == data
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [ {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}, {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} ]
+ assert serializer.data == expected
def test_reverse_many_to_many_update(self):
- data = {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']}
- instance = ManyToManyTarget.objects.get(pk=1)
- serializer = ManyToManyTargetSerializer(instance, data=data, context={'request': request})
- assert serializer.is_valid()
- serializer.save()
- assert serializer.data == data
- # Ensure target 1 is updated, and everything else is as expected
- queryset = ManyToManyTarget.objects.all()
- serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']},
- {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
- {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}
-
- ]
- assert serializer.data == expected
+ data = {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']}
+ instance = ManyToManyTarget.objects.get(pk=1)
+ serializer = ManyToManyTargetSerializer(instance, data=data, context={'request': request})
+ assert serializer.is_valid()
+ serializer.save()
+ assert serializer.data == data
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
+ expected = [ {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']}, {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']} ]
+ assert serializer.data == expected
def test_many_to_many_create(self):
- data = {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']}
- serializer = ManyToManySourceSerializer(data=data, context={'request': request})
- assert serializer.is_valid()
- obj = serializer.save()
- assert serializer.data == data
- assert obj.name == 'source-4'
-
- # Ensure source 4 is added, and everything else is as expected
- queryset = ManyToManySource.objects.all()
- serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']},
- {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
- {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']},
- {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']}
- ]
- assert serializer.data == expected
+ data = {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']}
+ serializer = ManyToManySourceSerializer(data=data, context={'request': request})
+ assert serializer.is_valid()
+ obj = serializer.save()
+ assert serializer.data == data
+ assert obj.name == 'source-4'
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [ {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']}, {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}, {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']} ]
+ assert serializer.data == expected
def test_reverse_many_to_many_create(self):
- data = {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']}
- serializer = ManyToManyTargetSerializer(data=data, context={'request': request})
- assert serializer.is_valid()
- obj = serializer.save()
- assert serializer.data == data
- assert obj.name == 'target-4'
-
- # Ensure target 4 is added, and everything else is as expected
- queryset = ManyToManyTarget.objects.all()
- serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
- {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
- {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']},
- {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']}
- ]
- assert serializer.data == expected
+ data = {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']}
+ serializer = ManyToManyTargetSerializer(data=data, context={'request': request})
+ assert serializer.is_valid()
+ obj = serializer.save()
+ assert serializer.data == data
+ assert obj.name == 'target-4'
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
+ expected = [ {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']} ]
+ assert serializer.data == expected
@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink')
-class HyperlinkedForeignKeyTests(TestCase):
- def setUp(self):
- target = ForeignKeyTarget(name='target-1')
- target.save()
- new_target = ForeignKeyTarget(name='target-2')
- new_target.save()
- for idx in range(1, 4):
- source = ForeignKeySource(name='source-%d' % idx, target=target)
- source.save()
-
- def test_foreign_key_retrieve(self):
- queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}
- ]
- with self.assertNumQueries(1):
- assert serializer.data == expected
-
- def test_reverse_foreign_key_retrieve(self):
- queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']},
- {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
- ]
- with self.assertNumQueries(3):
- assert serializer.data == expected
-
- def test_foreign_key_update(self):
- data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'}
- instance = ForeignKeySource.objects.get(pk=1)
- serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
- assert serializer.is_valid()
- serializer.save()
- assert serializer.data == data
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'},
- {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}
- ]
- assert serializer.data == expected
-
- def test_foreign_key_update_incorrect_type(self):
- data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 2}
- instance = ForeignKeySource.objects.get(pk=1)
- serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
- assert not serializer.is_valid()
- assert serializer.errors == {'target': ['Incorrect type. Expected URL string, received int.']}
-
- def test_reverse_foreign_key_update(self):
- data = {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}
- instance = ForeignKeyTarget.objects.get(pk=2)
- serializer = ForeignKeyTargetSerializer(instance, data=data, context={'request': request})
- assert serializer.is_valid()
- # We shouldn't have saved anything to the db yet since save
- # hasn't been called.
- queryset = ForeignKeyTarget.objects.all()
- new_serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']},
- {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
- ]
- assert new_serializer.data == expected
-
- serializer.save()
- assert serializer.data == data
-
- # Ensure target 2 is update, and everything else is as expected
- queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']},
- {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']},
- ]
- assert serializer.data == expected
-
- def test_foreign_key_create(self):
- data = {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'}
- serializer = ForeignKeySourceSerializer(data=data, context={'request': request})
- assert serializer.is_valid()
- obj = serializer.save()
- assert serializer.data == data
- assert obj.name == 'source-4'
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'},
- ]
- assert serializer.data == expected
-
- def test_reverse_foreign_key_create(self):
- data = {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}
- serializer = ForeignKeyTargetSerializer(data=data, context={'request': request})
- assert serializer.is_valid()
- obj = serializer.save()
- assert serializer.data == data
- assert obj.name == 'target-3'
-
- # Ensure target 4 is added, and everything else is as expected
- queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']},
- {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
- {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']},
- ]
- assert serializer.data == expected
-
- def test_foreign_key_update_with_invalid_null(self):
- data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': None}
- instance = ForeignKeySource.objects.get(pk=1)
- serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
- assert not serializer.is_valid()
- assert serializer.errors == {'target': ['This field may not be null.']}
-
-
-@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink')
-class HyperlinkedNullableForeignKeyTests(TestCase):
- def setUp(self):
- target = ForeignKeyTarget(name='target-1')
- target.save()
- for idx in range(1, 4):
- if idx == 3:
- target = None
- source = NullableForeignKeySource(name='source-%d' % idx, target=target)
- source.save()
-
- def test_foreign_key_retrieve_with_null(self):
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
- ]
- assert serializer.data == expected
-
- def test_foreign_key_create_with_valid_null(self):
- data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
- serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request})
- assert serializer.is_valid()
- obj = serializer.save()
- assert serializer.data == data
- assert obj.name == 'source-4'
-
- # Ensure source 4 is created, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
- {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
- ]
- assert serializer.data == expected
-
- def test_foreign_key_create_with_valid_emptystring(self):
- """
- The emptystring should be interpreted as null in the context
- of relationships.
- """
- data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': ''}
- expected_data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
- serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request})
- assert serializer.is_valid()
- obj = serializer.save()
- assert serializer.data == expected_data
- assert obj.name == 'source-4'
-
- # Ensure source 4 is created, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
- {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
- ]
- assert serializer.data == expected
-
- def test_foreign_key_update_with_valid_null(self):
- data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}
- instance = NullableForeignKeySource.objects.get(pk=1)
- serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request})
- assert serializer.is_valid()
- serializer.save()
- assert serializer.data == data
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None},
- {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
- ]
- assert serializer.data == expected
-
- def test_foreign_key_update_with_valid_emptystring(self):
- """
- The emptystring should be interpreted as null in the context
- of relationships.
- """
- data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': ''}
- expected_data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}
- instance = NullableForeignKeySource.objects.get(pk=1)
- serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request})
- assert serializer.is_valid()
- serializer.save()
- assert serializer.data == expected_data
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None},
- {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
- ]
- assert serializer.data == expected
-
-
-@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink')
-class HyperlinkedNullableOneToOneTests(TestCase):
- def setUp(self):
- target = OneToOneTarget(name='target-1')
- target.save()
- new_target = OneToOneTarget(name='target-2')
- new_target.save()
- source = NullableOneToOneSource(name='source-1', target=target)
+def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ new_target = ForeignKeyTarget(name='target-2')
+ new_target.save()
+ for idx in range(1, 4):
+ source = ForeignKeySource(name='source-%d' % idx, target=target)
source.save()
- def test_reverse_foreign_key_retrieve_with_null(self):
- queryset = OneToOneTarget.objects.all()
- serializer = NullableOneToOneTargetSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/onetoonetarget/1/', 'name': 'target-1', 'nullable_source': 'http://testserver/nullableonetoonesource/1/'},
- {'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None},
- ]
+ def test_foreign_key_retrieve(self):
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [ {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'} ]
+ with self.assertNumQueries(1):
assert serializer.data == expected
+
+ def test_reverse_foreign_key_retrieve(self):
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
+ expected = [ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']}, {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, ]
+ with self.assertNumQueries(3):
+ assert serializer.data == expected
+
+ def test_foreign_key_update(self):
+ data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
+ assert serializer.is_valid()
+ serializer.save()
+ assert serializer.data == data
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [ {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'}, {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'} ]
+ assert serializer.data == expected
+
+ def test_foreign_key_update_incorrect_type(self):
+ data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 2}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
+ assert not serializer.is_valid()
+ assert serializer.errors == {'target': ['Incorrect type. Expected URL string, received int.']}
+
+ def test_reverse_foreign_key_update(self):
+ data = {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}
+ instance = ForeignKeyTarget.objects.get(pk=2)
+ serializer = ForeignKeyTargetSerializer(instance, data=data, context={'request': request})
+ assert serializer.is_valid()
+ queryset = ForeignKeyTarget.objects.all()
+ new_serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
+ expected = [ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']}, {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, ]
+ assert new_serializer.data == expected
+ serializer.save()
+ assert serializer.data == data
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
+ expected = [ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']}, {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}, ]
+ assert serializer.data == expected
+
+ def test_foreign_key_create(self):
+ data = {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'}
+ serializer = ForeignKeySourceSerializer(data=data, context={'request': request})
+ assert serializer.is_valid()
+ obj = serializer.save()
+ assert serializer.data == data
+ assert obj.name == 'source-4'
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [ {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'}, ]
+ assert serializer.data == expected
+
+ def test_reverse_foreign_key_create(self):
+ data = {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}
+ serializer = ForeignKeyTargetSerializer(data=data, context={'request': request})
+ assert serializer.is_valid()
+ obj = serializer.save()
+ assert serializer.data == data
+ assert obj.name == 'target-3'
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
+ expected = [ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']}, {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}, ]
+ assert serializer.data == expected
+
+ def test_foreign_key_update_with_invalid_null(self):
+ data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': None}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
+ assert not serializer.is_valid()
+ assert serializer.errors == {'target': ['This field may not be null.']}
+
+
+@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink')
+def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ for idx in range(1, 4):
+ if idx == 3:
+ target = None
+ source = NullableForeignKeySource(name='source-%d' % idx, target=target)
+source.save()
+
+ def test_foreign_key_retrieve_with_null(self):
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, ]
+ assert serializer.data == expected
+
+ def test_foreign_key_create_with_valid_null(self):
+ data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request})
+ assert serializer.is_valid()
+ obj = serializer.save()
+ assert serializer.data == data
+ assert obj.name == 'source-4'
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} ]
+ assert serializer.data == expected
+
+ def test_foreign_key_create_with_valid_emptystring(self):
+ """
+ The emptystring should be interpreted as null in the context
+ of relationships.
+ """
+ data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': ''}
+ expected_data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request})
+ assert serializer.is_valid()
+ obj = serializer.save()
+ assert serializer.data == expected_data
+ assert obj.name == 'source-4'
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} ]
+ assert serializer.data == expected
+
+ def test_foreign_key_update_with_valid_null(self):
+ data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request})
+ assert serializer.is_valid()
+ serializer.save()
+ assert serializer.data == data
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}, {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, ]
+ assert serializer.data == expected
+
+ def test_foreign_key_update_with_valid_emptystring(self):
+ """
+ The emptystring should be interpreted as null in the context
+ of relationships.
+ """
+ data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': ''}
+ expected_data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request})
+ assert serializer.is_valid()
+ serializer.save()
+ assert serializer.data == expected_data
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}, {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, ]
+ assert serializer.data == expected
+
+
+@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink')
+def setUp(self):
+ target = OneToOneTarget(name='target-1')
+ target.save()
+ new_target = OneToOneTarget(name='target-2')
+ new_target.save()
+ source = NullableOneToOneSource(name='source-1', target=target)
+ source.save()
+
+ def test_reverse_foreign_key_retrieve_with_null(self):
+ queryset = OneToOneTarget.objects.all()
+ serializer = NullableOneToOneTargetSerializer(queryset, many=True, context={'request': request})
+ expected = [ {'url': 'http://testserver/onetoonetarget/1/', 'name': 'target-1', 'nullable_source': 'http://testserver/nullableonetoonesource/1/'}, {'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None}, ]
+ assert serializer.data == expected
diff --git a/tests/test_relations_pk.py b/tests/test_relations_pk.py
index 0da9da890..517b8bbb2 100644
--- a/tests/test_relations_pk.py
+++ b/tests/test_relations_pk.py
@@ -77,496 +77,373 @@ class OneToOnePKSourceSerializer(serializers.ModelSerializer):
# TODO: Add test that .data cannot be accessed prior to .is_valid
-class PKManyToManyTests(TestCase):
- def setUp(self):
- for idx in range(1, 4):
- target = ManyToManyTarget(name='target-%d' % idx)
- target.save()
- source = ManyToManySource(name='source-%d' % idx)
- source.save()
- for target in ManyToManyTarget.objects.all():
- source.targets.add(target)
+def setUp(self):
+ for idx in range(1, 4):
+ target = ManyToManyTarget(name='target-%d' % idx)
+ target.save()
+ source = ManyToManySource(name='source-%d' % idx)
+ source.save()
+ for target in ManyToManyTarget.objects.all():
+ source.targets.add(target)
def test_many_to_many_retrieve(self):
- queryset = ManyToManySource.objects.all()
- serializer = ManyToManySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'targets': [1]},
- {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
- {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
- ]
- with self.assertNumQueries(4):
- assert serializer.data == expected
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset, many=True)
+ expected = [ {'id': 1, 'name': 'source-1', 'targets': [1]}, {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]} ]
+ with self.assertNumQueries(4):
+ assert serializer.data == expected
def test_many_to_many_retrieve_prefetch_related(self):
- queryset = ManyToManySource.objects.all().prefetch_related('targets')
- serializer = ManyToManySourceSerializer(queryset, many=True)
- with self.assertNumQueries(2):
- serializer.data
+ queryset = ManyToManySource.objects.all().prefetch_related('targets')
+ serializer = ManyToManySourceSerializer(queryset, many=True)
+ with self.assertNumQueries(2):
+ serializer.data
def test_reverse_many_to_many_retrieve(self):
- queryset = ManyToManyTarget.objects.all()
- serializer = ManyToManyTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
- {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
- {'id': 3, 'name': 'target-3', 'sources': [3]}
- ]
- with self.assertNumQueries(4):
- assert serializer.data == expected
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset, many=True)
+ expected = [ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, {'id': 3, 'name': 'target-3', 'sources': [3]} ]
+ with self.assertNumQueries(4):
+ assert serializer.data == expected
def test_many_to_many_update(self):
- data = {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]}
- instance = ManyToManySource.objects.get(pk=1)
- serializer = ManyToManySourceSerializer(instance, data=data)
- assert serializer.is_valid()
- serializer.save()
- assert serializer.data == data
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = ManyToManySource.objects.all()
- serializer = ManyToManySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]},
- {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
- {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
- ]
- assert serializer.data == expected
+ data = {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]}
+ instance = ManyToManySource.objects.get(pk=1)
+ serializer = ManyToManySourceSerializer(instance, data=data)
+ assert serializer.is_valid()
+ serializer.save()
+ assert serializer.data == data
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset, many=True)
+ expected = [ {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]}, {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]} ]
+ assert serializer.data == expected
def test_reverse_many_to_many_update(self):
- data = {'id': 1, 'name': 'target-1', 'sources': [1]}
- instance = ManyToManyTarget.objects.get(pk=1)
- serializer = ManyToManyTargetSerializer(instance, data=data)
- assert serializer.is_valid()
- serializer.save()
- assert serializer.data == data
-
- # Ensure target 1 is updated, and everything else is as expected
- queryset = ManyToManyTarget.objects.all()
- serializer = ManyToManyTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': [1]},
- {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
- {'id': 3, 'name': 'target-3', 'sources': [3]}
- ]
- assert serializer.data == expected
+ data = {'id': 1, 'name': 'target-1', 'sources': [1]}
+ instance = ManyToManyTarget.objects.get(pk=1)
+ serializer = ManyToManyTargetSerializer(instance, data=data)
+ assert serializer.is_valid()
+ serializer.save()
+ assert serializer.data == data
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset, many=True)
+ expected = [ {'id': 1, 'name': 'target-1', 'sources': [1]}, {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, {'id': 3, 'name': 'target-3', 'sources': [3]} ]
+ assert serializer.data == expected
def test_many_to_many_create(self):
- data = {'id': 4, 'name': 'source-4', 'targets': [1, 3]}
- serializer = ManyToManySourceSerializer(data=data)
- assert serializer.is_valid()
- obj = serializer.save()
- assert serializer.data == data
- assert obj.name == 'source-4'
-
- # Ensure source 4 is added, and everything else is as expected
- queryset = ManyToManySource.objects.all()
- serializer = ManyToManySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'targets': [1]},
- {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
- {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]},
- {'id': 4, 'name': 'source-4', 'targets': [1, 3]},
- ]
- assert serializer.data == expected
+ data = {'id': 4, 'name': 'source-4', 'targets': [1, 3]}
+ serializer = ManyToManySourceSerializer(data=data)
+ assert serializer.is_valid()
+ obj = serializer.save()
+ assert serializer.data == data
+ assert obj.name == 'source-4'
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset, many=True)
+ expected = [ {'id': 1, 'name': 'source-1', 'targets': [1]}, {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}, {'id': 4, 'name': 'source-4', 'targets': [1, 3]}, ]
+ assert serializer.data == expected
def test_many_to_many_unsaved(self):
- source = ManyToManySource(name='source-unsaved')
-
- serializer = ManyToManySourceSerializer(source)
-
- expected = {'id': None, 'name': 'source-unsaved', 'targets': []}
- # no query if source hasn't been created yet
- with self.assertNumQueries(0):
- assert serializer.data == expected
+ source = ManyToManySource(name='source-unsaved')
+ serializer = ManyToManySourceSerializer(source)
+ expected = {'id': None, 'name': 'source-unsaved', 'targets': []}
+ with self.assertNumQueries(0):
+ assert serializer.data == expected
def test_reverse_many_to_many_create(self):
- data = {'id': 4, 'name': 'target-4', 'sources': [1, 3]}
- serializer = ManyToManyTargetSerializer(data=data)
- assert serializer.is_valid()
- obj = serializer.save()
- assert serializer.data == data
- assert obj.name == 'target-4'
-
- # Ensure target 4 is added, and everything else is as expected
- queryset = ManyToManyTarget.objects.all()
- serializer = ManyToManyTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
- {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
- {'id': 3, 'name': 'target-3', 'sources': [3]},
- {'id': 4, 'name': 'target-4', 'sources': [1, 3]}
- ]
- assert serializer.data == expected
+ data = {'id': 4, 'name': 'target-4', 'sources': [1, 3]}
+ serializer = ManyToManyTargetSerializer(data=data)
+ assert serializer.is_valid()
+ obj = serializer.save()
+ assert serializer.data == data
+ assert obj.name == 'target-4'
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset, many=True)
+ expected = [ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, {'id': 3, 'name': 'target-3', 'sources': [3]}, {'id': 4, 'name': 'target-4', 'sources': [1, 3]} ]
+ assert serializer.data == expected
-class PKForeignKeyTests(TestCase):
- def setUp(self):
- target = ForeignKeyTarget(name='target-1')
- target.save()
- new_target = ForeignKeyTarget(name='target-2')
- new_target.save()
- for idx in range(1, 4):
- source = ForeignKeySource(name='source-%d' % idx, target=target)
- source.save()
+def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ new_target = ForeignKeyTarget(name='target-2')
+ new_target.save()
+ for idx in range(1, 4):
+ source = ForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
def test_foreign_key_retrieve(self):
- queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 1},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': 1}
- ]
- with self.assertNumQueries(1):
- assert serializer.data == expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
+ expected = [ {'id': 1, 'name': 'source-1', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': 1} ]
+ with self.assertNumQueries(1):
+ assert serializer.data == expected
def test_reverse_foreign_key_retrieve(self):
- queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
- {'id': 2, 'name': 'target-2', 'sources': []},
- ]
- with self.assertNumQueries(3):
- assert serializer.data == expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, {'id': 2, 'name': 'target-2', 'sources': []}, ]
+ with self.assertNumQueries(3):
+ assert serializer.data == expected
def test_reverse_foreign_key_retrieve_prefetch_related(self):
- queryset = ForeignKeyTarget.objects.all().prefetch_related('sources')
- serializer = ForeignKeyTargetSerializer(queryset, many=True)
- with self.assertNumQueries(2):
- serializer.data
+ queryset = ForeignKeyTarget.objects.all().prefetch_related('sources')
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ with self.assertNumQueries(2):
+ serializer.data
def test_foreign_key_update(self):
- data = {'id': 1, 'name': 'source-1', 'target': 2}
- instance = ForeignKeySource.objects.get(pk=1)
- serializer = ForeignKeySourceSerializer(instance, data=data)
- assert serializer.is_valid()
- serializer.save()
- assert serializer.data == data
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 2},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': 1}
- ]
- assert serializer.data == expected
+ data = {'id': 1, 'name': 'source-1', 'target': 2}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ assert serializer.is_valid()
+ serializer.save()
+ assert serializer.data == data
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
+ expected = [ {'id': 1, 'name': 'source-1', 'target': 2}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': 1} ]
+ assert serializer.data == expected
def test_foreign_key_update_incorrect_type(self):
- data = {'id': 1, 'name': 'source-1', 'target': 'foo'}
- instance = ForeignKeySource.objects.get(pk=1)
- serializer = ForeignKeySourceSerializer(instance, data=data)
- assert not serializer.is_valid()
- assert serializer.errors == {'target': ['Incorrect type. Expected pk value, received str.']}
+ data = {'id': 1, 'name': 'source-1', 'target': 'foo'}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ assert not serializer.is_valid()
+ assert serializer.errors == {'target': ['Incorrect type. Expected pk value, received str.']}
def test_reverse_foreign_key_update(self):
- data = {'id': 2, 'name': 'target-2', 'sources': [1, 3]}
- instance = ForeignKeyTarget.objects.get(pk=2)
- serializer = ForeignKeyTargetSerializer(instance, data=data)
- assert serializer.is_valid()
- # We shouldn't have saved anything to the db yet since save
- # hasn't been called.
- queryset = ForeignKeyTarget.objects.all()
- new_serializer = ForeignKeyTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
- {'id': 2, 'name': 'target-2', 'sources': []},
- ]
- assert new_serializer.data == expected
-
- serializer.save()
- assert serializer.data == data
-
- # Ensure target 2 is update, and everything else is as expected
- queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': [2]},
- {'id': 2, 'name': 'target-2', 'sources': [1, 3]},
- ]
- assert serializer.data == expected
+ data = {'id': 2, 'name': 'target-2', 'sources': [1, 3]}
+ instance = ForeignKeyTarget.objects.get(pk=2)
+ serializer = ForeignKeyTargetSerializer(instance, data=data)
+ assert serializer.is_valid()
+ queryset = ForeignKeyTarget.objects.all()
+ new_serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, {'id': 2, 'name': 'target-2', 'sources': []}, ]
+ assert new_serializer.data == expected
+ serializer.save()
+ assert serializer.data == data
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [ {'id': 1, 'name': 'target-1', 'sources': [2]}, {'id': 2, 'name': 'target-2', 'sources': [1, 3]}, ]
+ assert serializer.data == expected
def test_foreign_key_create(self):
- data = {'id': 4, 'name': 'source-4', 'target': 2}
- serializer = ForeignKeySourceSerializer(data=data)
- assert serializer.is_valid()
- obj = serializer.save()
- assert serializer.data == data
- assert obj.name == 'source-4'
-
- # Ensure source 4 is added, and everything else is as expected
- queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 1},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': 1},
- {'id': 4, 'name': 'source-4', 'target': 2},
- ]
- assert serializer.data == expected
+ data = {'id': 4, 'name': 'source-4', 'target': 2}
+ serializer = ForeignKeySourceSerializer(data=data)
+ assert serializer.is_valid()
+ obj = serializer.save()
+ assert serializer.data == data
+ assert obj.name == 'source-4'
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
+ expected = [ {'id': 1, 'name': 'source-1', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': 1}, {'id': 4, 'name': 'source-4', 'target': 2}, ]
+ assert serializer.data == expected
def test_reverse_foreign_key_create(self):
- data = {'id': 3, 'name': 'target-3', 'sources': [1, 3]}
- serializer = ForeignKeyTargetSerializer(data=data)
- assert serializer.is_valid()
- obj = serializer.save()
- assert serializer.data == data
- assert obj.name == 'target-3'
-
- # Ensure target 3 is added, and everything else is as expected
- queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': [2]},
- {'id': 2, 'name': 'target-2', 'sources': []},
- {'id': 3, 'name': 'target-3', 'sources': [1, 3]},
- ]
- assert serializer.data == expected
+ data = {'id': 3, 'name': 'target-3', 'sources': [1, 3]}
+ serializer = ForeignKeyTargetSerializer(data=data)
+ assert serializer.is_valid()
+ obj = serializer.save()
+ assert serializer.data == data
+ assert obj.name == 'target-3'
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [ {'id': 1, 'name': 'target-1', 'sources': [2]}, {'id': 2, 'name': 'target-2', 'sources': []}, {'id': 3, 'name': 'target-3', 'sources': [1, 3]}, ]
+ assert serializer.data == expected
def test_foreign_key_update_with_invalid_null(self):
- data = {'id': 1, 'name': 'source-1', 'target': None}
- instance = ForeignKeySource.objects.get(pk=1)
- serializer = ForeignKeySourceSerializer(instance, data=data)
- assert not serializer.is_valid()
- assert serializer.errors == {'target': ['This field may not be null.']}
+ data = {'id': 1, 'name': 'source-1', 'target': None}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ assert not serializer.is_valid()
+ assert serializer.errors == {'target': ['This field may not be null.']}
def test_foreign_key_with_unsaved(self):
- source = ForeignKeySource(name='source-unsaved')
- expected = {'id': None, 'name': 'source-unsaved', 'target': None}
-
- serializer = ForeignKeySourceSerializer(source)
-
- # no query if source hasn't been created yet
- with self.assertNumQueries(0):
- assert serializer.data == expected
+ source = ForeignKeySource(name='source-unsaved')
+ expected = {'id': None, 'name': 'source-unsaved', 'target': None}
+ serializer = ForeignKeySourceSerializer(source)
+ with self.assertNumQueries(0):
+ assert serializer.data == expected
def test_foreign_key_with_empty(self):
- """
+ """
Regression test for #1072
https://github.com/encode/django-rest-framework/issues/1072
"""
- serializer = NullableForeignKeySourceSerializer()
- assert serializer.data['target'] is None
+ serializer = NullableForeignKeySourceSerializer()
+ assert serializer.data['target'] is None
def test_foreign_key_not_required(self):
- """
+ """
Let's say we wanted to fill the non-nullable model field inside
Model.save(), we would make it empty and not required.
"""
- class ModelSerializer(ForeignKeySourceSerializer):
- class Meta(ForeignKeySourceSerializer.Meta):
- extra_kwargs = {'target': {'required': False}}
+ class ModelSerializer(ForeignKeySourceSerializer):
+ class Meta(ForeignKeySourceSerializer.Meta):
+ extra_kwargs = {'target': {'required': False}}
serializer = ModelSerializer(data={'name': 'test'})
- serializer.is_valid(raise_exception=True)
- assert 'target' not in serializer.validated_data
+serializer.is_valid(raise_exception=True)
+assert'target'notinserializer.validated_data
def test_queryset_size_without_limited_choices(self):
- limited_target = ForeignKeyTarget(name="limited-target")
- limited_target.save()
- queryset = ForeignKeySourceSerializer().fields["target"].get_queryset()
- assert len(queryset) == 3
+ limited_target = ForeignKeyTarget(name="limited-target")
+ limited_target.save()
+ queryset = ForeignKeySourceSerializer().fields["target"].get_queryset()
+ assert len(queryset) == 3
def test_queryset_size_with_limited_choices(self):
- limited_target = ForeignKeyTarget(name="limited-target")
- limited_target.save()
- queryset = ForeignKeySourceWithLimitedChoicesSerializer().fields["target"].get_queryset()
- assert len(queryset) == 1
+ limited_target = ForeignKeyTarget(name="limited-target")
+ limited_target.save()
+ queryset = ForeignKeySourceWithLimitedChoicesSerializer().fields["target"].get_queryset()
+ assert len(queryset) == 1
def test_queryset_size_with_Q_limited_choices(self):
- limited_target = ForeignKeyTarget(name="limited-target")
- limited_target.save()
-
- class QLimitedChoicesSerializer(serializers.ModelSerializer):
- class Meta:
- model = ForeignKeySourceWithQLimitedChoices
- fields = ("id", "target")
+ limited_target = ForeignKeyTarget(name="limited-target")
+ limited_target.save()
+ class QLimitedChoicesSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ForeignKeySourceWithQLimitedChoices
+ fields = ("id", "target")
queryset = QLimitedChoicesSerializer().fields["target"].get_queryset()
- assert len(queryset) == 1
+assertlen(queryset)==1
-class PKNullableForeignKeyTests(TestCase):
- def setUp(self):
- target = ForeignKeyTarget(name='target-1')
- target.save()
- for idx in range(1, 4):
- if idx == 3:
- target = None
+def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ for idx in range(1, 4):
+ if idx == 3:
+ target = None
source = NullableForeignKeySource(name='source-%d' % idx, target=target)
- source.save()
+source.save()
def test_foreign_key_retrieve_with_null(self):
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 1},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': None},
- ]
- assert serializer.data == expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [ {'id': 1, 'name': 'source-1', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': None}, ]
+ assert serializer.data == expected
def test_foreign_key_create_with_valid_null(self):
- data = {'id': 4, 'name': 'source-4', 'target': None}
- serializer = NullableForeignKeySourceSerializer(data=data)
- assert serializer.is_valid()
- obj = serializer.save()
- assert serializer.data == data
- assert obj.name == 'source-4'
-
- # Ensure source 4 is created, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 1},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': None},
- {'id': 4, 'name': 'source-4', 'target': None}
- ]
- assert serializer.data == expected
+ data = {'id': 4, 'name': 'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data)
+ assert serializer.is_valid()
+ obj = serializer.save()
+ assert serializer.data == data
+ assert obj.name == 'source-4'
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [ {'id': 1, 'name': 'source-1', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': None}, {'id': 4, 'name': 'source-4', 'target': None} ]
+ assert serializer.data == expected
def test_foreign_key_create_with_valid_emptystring(self):
- """
+ """
The emptystring should be interpreted as null in the context
of relationships.
"""
- data = {'id': 4, 'name': 'source-4', 'target': ''}
- expected_data = {'id': 4, 'name': 'source-4', 'target': None}
- serializer = NullableForeignKeySourceSerializer(data=data)
- assert serializer.is_valid()
- obj = serializer.save()
- assert serializer.data == expected_data
- assert obj.name == 'source-4'
-
- # Ensure source 4 is created, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 1},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': None},
- {'id': 4, 'name': 'source-4', 'target': None}
- ]
- assert serializer.data == expected
+ data = {'id': 4, 'name': 'source-4', 'target': ''}
+ expected_data = {'id': 4, 'name': 'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data)
+ assert serializer.is_valid()
+ obj = serializer.save()
+ assert serializer.data == expected_data
+ assert obj.name == 'source-4'
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [ {'id': 1, 'name': 'source-1', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': None}, {'id': 4, 'name': 'source-4', 'target': None} ]
+ assert serializer.data == expected
def test_foreign_key_update_with_valid_null(self):
- data = {'id': 1, 'name': 'source-1', 'target': None}
- instance = NullableForeignKeySource.objects.get(pk=1)
- serializer = NullableForeignKeySourceSerializer(instance, data=data)
- assert serializer.is_valid()
- serializer.save()
- assert serializer.data == data
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': None},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': None}
- ]
- assert serializer.data == expected
+ data = {'id': 1, 'name': 'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data)
+ assert serializer.is_valid()
+ serializer.save()
+ assert serializer.data == data
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [ {'id': 1, 'name': 'source-1', 'target': None}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': None} ]
+ assert serializer.data == expected
def test_foreign_key_update_with_valid_emptystring(self):
- """
+ """
The emptystring should be interpreted as null in the context
of relationships.
"""
- data = {'id': 1, 'name': 'source-1', 'target': ''}
- expected_data = {'id': 1, 'name': 'source-1', 'target': None}
- instance = NullableForeignKeySource.objects.get(pk=1)
- serializer = NullableForeignKeySourceSerializer(instance, data=data)
- assert serializer.is_valid()
- serializer.save()
- assert serializer.data == expected_data
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': None},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': None}
- ]
- assert serializer.data == expected
+ data = {'id': 1, 'name': 'source-1', 'target': ''}
+ expected_data = {'id': 1, 'name': 'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data)
+ assert serializer.is_valid()
+ serializer.save()
+ assert serializer.data == expected_data
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [ {'id': 1, 'name': 'source-1', 'target': None}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': None} ]
+ assert serializer.data == expected
def test_null_uuid_foreign_key_serializes_as_none(self):
- source = NullableUUIDForeignKeySource(name='Source')
- serializer = NullableUUIDForeignKeySourceSerializer(source)
- data = serializer.data
- assert data["target"] is None
+ source = NullableUUIDForeignKeySource(name='Source')
+ serializer = NullableUUIDForeignKeySourceSerializer(source)
+ data = serializer.data
+ assert data["target"] is None
def test_nullable_uuid_foreign_key_is_valid_when_none(self):
- data = {"name": "Source", "target": None}
- serializer = NullableUUIDForeignKeySourceSerializer(data=data)
- assert serializer.is_valid(), serializer.errors
+ data = {"name": "Source", "target": None}
+ serializer = NullableUUIDForeignKeySourceSerializer(data=data)
+ assert serializer.is_valid(), serializer.errors
-class PKNullableOneToOneTests(TestCase):
- def setUp(self):
- target = OneToOneTarget(name='target-1')
- target.save()
- new_target = OneToOneTarget(name='target-2')
- new_target.save()
- source = NullableOneToOneSource(name='source-1', target=new_target)
- source.save()
+def setUp(self):
+ target = OneToOneTarget(name='target-1')
+ target.save()
+ new_target = OneToOneTarget(name='target-2')
+ new_target.save()
+ source = NullableOneToOneSource(name='source-1', target=new_target)
+ source.save()
def test_reverse_foreign_key_retrieve_with_null(self):
- queryset = OneToOneTarget.objects.all()
- serializer = NullableOneToOneTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'nullable_source': None},
- {'id': 2, 'name': 'target-2', 'nullable_source': 1},
- ]
- assert serializer.data == expected
+ queryset = OneToOneTarget.objects.all()
+ serializer = NullableOneToOneTargetSerializer(queryset, many=True)
+ expected = [ {'id': 1, 'name': 'target-1', 'nullable_source': None}, {'id': 2, 'name': 'target-2', 'nullable_source': 1}, ]
+ assert serializer.data == expected
-class OneToOnePrimaryKeyTests(TestCase):
- def setUp(self):
+def setUp(self):
# Given: Some target models already exist
- self.target = target = OneToOneTarget(name='target-1')
- target.save()
- self.alt_target = alt_target = OneToOneTarget(name='target-2')
- alt_target.save()
+ self.target = target = OneToOneTarget(name='target-1')
+ target.save()
+ self.alt_target = alt_target = OneToOneTarget(name='target-2')
+ alt_target.save()
def test_one_to_one_when_primary_key(self):
# When: Creating a Source pointing at the id of the second Target
- target_pk = self.alt_target.id
- source = OneToOnePKSourceSerializer(data={'name': 'source-2', 'target': target_pk})
- # Then: The source is valid with the serializer
- if not source.is_valid():
- self.fail("Expected OneToOnePKTargetSerializer to be valid but had errors: {}".format(source.errors))
+ target_pk = self.alt_target.id
+ source = OneToOnePKSourceSerializer(data={'name': 'source-2', 'target': target_pk})
+ if not source.is_valid():
+ self.fail("Expected OneToOnePKTargetSerializer to be valid but had errors: {}".format(source.errors))
# Then: Saving the serializer creates a new object
new_source = source.save()
- # Then: The new object has the same pk as the target object
- self.assertEqual(new_source.pk, target_pk)
+assertnew_source.pk==target_pk
def test_one_to_one_when_primary_key_no_duplicates(self):
# When: Creating a Source pointing at the id of the second Target
- target_pk = self.target.id
- data = {'name': 'source-1', 'target': target_pk}
- source = OneToOnePKSourceSerializer(data=data)
- # Then: The source is valid with the serializer
- self.assertTrue(source.is_valid())
- # Then: Saving the serializer creates a new object
- new_source = source.save()
- # Then: The new object has the same pk as the target object
- self.assertEqual(new_source.pk, target_pk)
- # When: Trying to create a second object
- second_source = OneToOnePKSourceSerializer(data=data)
- self.assertFalse(second_source.is_valid())
- expected = {'target': ['one to one pk source with this target already exists.']}
- self.assertDictEqual(second_source.errors, expected)
+ target_pk = self.target.id
+ data = {'name': 'source-1', 'target': target_pk}
+ source = OneToOnePKSourceSerializer(data=data)
+ assert source.is_valid()
+ new_source = source.save()
+ assert new_source.pk == target_pk
+ second_source = OneToOnePKSourceSerializer(data=data)
+ assert not second_source.is_valid()
+ expected = {'target': ['one to one pk source with this target already exists.']}
+ assert second_source.errors == expected
def test_one_to_one_when_primary_key_does_not_exist(self):
# Given: a target PK that does not exist
- target_pk = self.target.pk + self.alt_target.pk
- source = OneToOnePKSourceSerializer(data={'name': 'source-2', 'target': target_pk})
- # Then: The source is not valid with the serializer
- self.assertFalse(source.is_valid())
- self.assertIn("Invalid pk", source.errors['target'][0])
- self.assertIn("object does not exist", source.errors['target'][0])
+ target_pk = self.target.pk + self.alt_target.pk
+ source = OneToOnePKSourceSerializer(data={'name': 'source-2', 'target': target_pk})
+ assert not source.is_valid()
+ assert "Invalid pk" in source.errors['target'][0]
+ assert "object does not exist" in source.errors['target'][0]
diff --git a/tests/test_relations_slug.py b/tests/test_relations_slug.py
index 0b9ca79d3..44e8d9e47 100644
--- a/tests/test_relations_slug.py
+++ b/tests/test_relations_slug.py
@@ -42,246 +42,177 @@ class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
# TODO: M2M Tests, FKTests (Non-nullable), One2One
-class SlugForeignKeyTests(TestCase):
- def setUp(self):
- target = ForeignKeyTarget(name='target-1')
- target.save()
- new_target = ForeignKeyTarget(name='target-2')
- new_target.save()
- for idx in range(1, 4):
- source = ForeignKeySource(name='source-%d' % idx, target=target)
- source.save()
+def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ new_target = ForeignKeyTarget(name='target-2')
+ new_target.save()
+ for idx in range(1, 4):
+ source = ForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
def test_foreign_key_retrieve(self):
- queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 'target-1'},
- {'id': 2, 'name': 'source-2', 'target': 'target-1'},
- {'id': 3, 'name': 'source-3', 'target': 'target-1'}
- ]
- with self.assertNumQueries(4):
- assert serializer.data == expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
+ expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-1'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': 'target-1'} ]
+ with self.assertNumQueries(4):
+ assert serializer.data == expected
def test_foreign_key_retrieve_select_related(self):
- queryset = ForeignKeySource.objects.all().select_related('target')
- serializer = ForeignKeySourceSerializer(queryset, many=True)
- with self.assertNumQueries(1):
- serializer.data
+ queryset = ForeignKeySource.objects.all().select_related('target')
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
+ with self.assertNumQueries(1):
+ serializer.data
def test_reverse_foreign_key_retrieve(self):
- queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
- {'id': 2, 'name': 'target-2', 'sources': []},
- ]
- assert serializer.data == expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [ {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, {'id': 2, 'name': 'target-2', 'sources': []}, ]
+ assert serializer.data == expected
def test_reverse_foreign_key_retrieve_prefetch_related(self):
- queryset = ForeignKeyTarget.objects.all().prefetch_related('sources')
- serializer = ForeignKeyTargetSerializer(queryset, many=True)
- with self.assertNumQueries(2):
- serializer.data
+ queryset = ForeignKeyTarget.objects.all().prefetch_related('sources')
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ with self.assertNumQueries(2):
+ serializer.data
def test_foreign_key_update(self):
- data = {'id': 1, 'name': 'source-1', 'target': 'target-2'}
- instance = ForeignKeySource.objects.get(pk=1)
- serializer = ForeignKeySourceSerializer(instance, data=data)
- assert serializer.is_valid()
- serializer.save()
- assert serializer.data == data
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 'target-2'},
- {'id': 2, 'name': 'source-2', 'target': 'target-1'},
- {'id': 3, 'name': 'source-3', 'target': 'target-1'}
- ]
- assert serializer.data == expected
+ data = {'id': 1, 'name': 'source-1', 'target': 'target-2'}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ assert serializer.is_valid()
+ serializer.save()
+ assert serializer.data == data
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
+ expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-2'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': 'target-1'} ]
+ assert serializer.data == expected
def test_foreign_key_update_incorrect_type(self):
- data = {'id': 1, 'name': 'source-1', 'target': 123}
- instance = ForeignKeySource.objects.get(pk=1)
- serializer = ForeignKeySourceSerializer(instance, data=data)
- assert not serializer.is_valid()
- assert serializer.errors == {'target': ['Object with name=123 does not exist.']}
+ data = {'id': 1, 'name': 'source-1', 'target': 123}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ assert not serializer.is_valid()
+ assert serializer.errors == {'target': ['Object with name=123 does not exist.']}
def test_reverse_foreign_key_update(self):
- data = {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']}
- instance = ForeignKeyTarget.objects.get(pk=2)
- serializer = ForeignKeyTargetSerializer(instance, data=data)
- assert serializer.is_valid()
- # We shouldn't have saved anything to the db yet since save
- # hasn't been called.
- queryset = ForeignKeyTarget.objects.all()
- new_serializer = ForeignKeyTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
- {'id': 2, 'name': 'target-2', 'sources': []},
- ]
- assert new_serializer.data == expected
-
- serializer.save()
- assert serializer.data == data
-
- # Ensure target 2 is update, and everything else is as expected
- queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': ['source-2']},
- {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']},
- ]
- assert serializer.data == expected
+ data = {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']}
+ instance = ForeignKeyTarget.objects.get(pk=2)
+ serializer = ForeignKeyTargetSerializer(instance, data=data)
+ assert serializer.is_valid()
+ queryset = ForeignKeyTarget.objects.all()
+ new_serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [ {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, {'id': 2, 'name': 'target-2', 'sources': []}, ]
+ assert new_serializer.data == expected
+ serializer.save()
+ assert serializer.data == data
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [ {'id': 1, 'name': 'target-1', 'sources': ['source-2']}, {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']}, ]
+ assert serializer.data == expected
def test_foreign_key_create(self):
- data = {'id': 4, 'name': 'source-4', 'target': 'target-2'}
- serializer = ForeignKeySourceSerializer(data=data)
- serializer.is_valid()
- assert serializer.is_valid()
- obj = serializer.save()
- assert serializer.data == data
- assert obj.name == 'source-4'
-
- # Ensure source 4 is added, and everything else is as expected
- queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 'target-1'},
- {'id': 2, 'name': 'source-2', 'target': 'target-1'},
- {'id': 3, 'name': 'source-3', 'target': 'target-1'},
- {'id': 4, 'name': 'source-4', 'target': 'target-2'},
- ]
- assert serializer.data == expected
+ data = {'id': 4, 'name': 'source-4', 'target': 'target-2'}
+ serializer = ForeignKeySourceSerializer(data=data)
+ serializer.is_valid()
+ assert serializer.is_valid()
+ obj = serializer.save()
+ assert serializer.data == data
+ assert obj.name == 'source-4'
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
+ expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-1'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': 'target-1'}, {'id': 4, 'name': 'source-4', 'target': 'target-2'}, ]
+ assert serializer.data == expected
def test_reverse_foreign_key_create(self):
- data = {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']}
- serializer = ForeignKeyTargetSerializer(data=data)
- assert serializer.is_valid()
- obj = serializer.save()
- assert serializer.data == data
- assert obj.name == 'target-3'
-
- # Ensure target 3 is added, and everything else is as expected
- queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': ['source-2']},
- {'id': 2, 'name': 'target-2', 'sources': []},
- {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']},
- ]
- assert serializer.data == expected
+ data = {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']}
+ serializer = ForeignKeyTargetSerializer(data=data)
+ assert serializer.is_valid()
+ obj = serializer.save()
+ assert serializer.data == data
+ assert obj.name == 'target-3'
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [ {'id': 1, 'name': 'target-1', 'sources': ['source-2']}, {'id': 2, 'name': 'target-2', 'sources': []}, {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']}, ]
+ assert serializer.data == expected
def test_foreign_key_update_with_invalid_null(self):
- data = {'id': 1, 'name': 'source-1', 'target': None}
- instance = ForeignKeySource.objects.get(pk=1)
- serializer = ForeignKeySourceSerializer(instance, data=data)
- assert not serializer.is_valid()
- assert serializer.errors == {'target': ['This field may not be null.']}
+ data = {'id': 1, 'name': 'source-1', 'target': None}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ assert not serializer.is_valid()
+ assert serializer.errors == {'target': ['This field may not be null.']}
-class SlugNullableForeignKeyTests(TestCase):
- def setUp(self):
- target = ForeignKeyTarget(name='target-1')
- target.save()
- for idx in range(1, 4):
- if idx == 3:
- target = None
+def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ for idx in range(1, 4):
+ if idx == 3:
+ target = None
source = NullableForeignKeySource(name='source-%d' % idx, target=target)
- source.save()
+source.save()
def test_foreign_key_retrieve_with_null(self):
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 'target-1'},
- {'id': 2, 'name': 'source-2', 'target': 'target-1'},
- {'id': 3, 'name': 'source-3', 'target': None},
- ]
- assert serializer.data == expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-1'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': None}, ]
+ assert serializer.data == expected
def test_foreign_key_create_with_valid_null(self):
- data = {'id': 4, 'name': 'source-4', 'target': None}
- serializer = NullableForeignKeySourceSerializer(data=data)
- assert serializer.is_valid()
- obj = serializer.save()
- assert serializer.data == data
- assert obj.name == 'source-4'
-
- # Ensure source 4 is created, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 'target-1'},
- {'id': 2, 'name': 'source-2', 'target': 'target-1'},
- {'id': 3, 'name': 'source-3', 'target': None},
- {'id': 4, 'name': 'source-4', 'target': None}
- ]
- assert serializer.data == expected
+ data = {'id': 4, 'name': 'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data)
+ assert serializer.is_valid()
+ obj = serializer.save()
+ assert serializer.data == data
+ assert obj.name == 'source-4'
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-1'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': None}, {'id': 4, 'name': 'source-4', 'target': None} ]
+ assert serializer.data == expected
def test_foreign_key_create_with_valid_emptystring(self):
- """
+ """
The emptystring should be interpreted as null in the context
of relationships.
"""
- data = {'id': 4, 'name': 'source-4', 'target': ''}
- expected_data = {'id': 4, 'name': 'source-4', 'target': None}
- serializer = NullableForeignKeySourceSerializer(data=data)
- assert serializer.is_valid()
- obj = serializer.save()
- assert serializer.data == expected_data
- assert obj.name == 'source-4'
-
- # Ensure source 4 is created, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 'target-1'},
- {'id': 2, 'name': 'source-2', 'target': 'target-1'},
- {'id': 3, 'name': 'source-3', 'target': None},
- {'id': 4, 'name': 'source-4', 'target': None}
- ]
- assert serializer.data == expected
+ data = {'id': 4, 'name': 'source-4', 'target': ''}
+ expected_data = {'id': 4, 'name': 'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data)
+ assert serializer.is_valid()
+ obj = serializer.save()
+ assert serializer.data == expected_data
+ assert obj.name == 'source-4'
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-1'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': None}, {'id': 4, 'name': 'source-4', 'target': None} ]
+ assert serializer.data == expected
def test_foreign_key_update_with_valid_null(self):
- data = {'id': 1, 'name': 'source-1', 'target': None}
- instance = NullableForeignKeySource.objects.get(pk=1)
- serializer = NullableForeignKeySourceSerializer(instance, data=data)
- assert serializer.is_valid()
- serializer.save()
- assert serializer.data == data
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': None},
- {'id': 2, 'name': 'source-2', 'target': 'target-1'},
- {'id': 3, 'name': 'source-3', 'target': None}
- ]
- assert serializer.data == expected
+ data = {'id': 1, 'name': 'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data)
+ assert serializer.is_valid()
+ serializer.save()
+ assert serializer.data == data
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [ {'id': 1, 'name': 'source-1', 'target': None}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': None} ]
+ assert serializer.data == expected
def test_foreign_key_update_with_valid_emptystring(self):
- """
+ """
The emptystring should be interpreted as null in the context
of relationships.
"""
- data = {'id': 1, 'name': 'source-1', 'target': ''}
- expected_data = {'id': 1, 'name': 'source-1', 'target': None}
- instance = NullableForeignKeySource.objects.get(pk=1)
- serializer = NullableForeignKeySourceSerializer(instance, data=data)
- assert serializer.is_valid()
- serializer.save()
- assert serializer.data == expected_data
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': None},
- {'id': 2, 'name': 'source-2', 'target': 'target-1'},
- {'id': 3, 'name': 'source-3', 'target': None}
- ]
- assert serializer.data == expected
+ data = {'id': 1, 'name': 'source-1', 'target': ''}
+ expected_data = {'id': 1, 'name': 'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data)
+ assert serializer.is_valid()
+ serializer.save()
+ assert serializer.data == expected_data
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [ {'id': 1, 'name': 'source-1', 'target': None}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': None} ]
+ assert serializer.data == expected
diff --git a/tests/test_renderers.py b/tests/test_renderers.py
index d63dbcb9c..d2c8dbef9 100644
--- a/tests/test_renderers.py
+++ b/tests/test_renderers.py
@@ -49,11 +49,10 @@ class DummyTestModel(models.Model):
name = models.CharField(max_length=42, default='')
-class BasicRendererTests(TestCase):
- def test_expected_results(self):
- for value, renderer_cls, expected in expected_results:
- output = renderer_cls().render(value)
- self.assertEqual(output, expected)
+def test_expected_results(self):
+ for value, renderer_cls, expected in expected_results:
+ output = renderer_cls().render(value)
+ assert output == expected
class RendererA(BaseRenderer):
@@ -144,123 +143,114 @@ class POSTDeniedView(APIView):
return Response()
-class DocumentingRendererTests(TestCase):
- def test_only_permitted_forms_are_displayed(self):
- view = POSTDeniedView.as_view()
- request = APIRequestFactory().get('/')
- response = view(request).render()
- self.assertNotContains(response, '>POST<')
- self.assertContains(response, '>PUT<')
- self.assertContains(response, '>PATCH<')
+def test_only_permitted_forms_are_displayed(self):
+ view = POSTDeniedView.as_view()
+ request = APIRequestFactory().get('/')
+ response = view(request).render()
+ self.assertNotContains(response, '>POST<')
+ self.assertContains(response, '>PUT<')
+ self.assertContains(response, '>PATCH<')
@override_settings(ROOT_URLCONF='tests.test_renderers')
-class RendererEndToEndTests(TestCase):
- """
+"""
End-to-end testing of renderers using an RendererMixin on a generic view.
"""
- def test_default_renderer_serializes_content(self):
- """If the Accept header is not set the default renderer should serialize the response."""
- resp = self.client.get('/')
- self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
+deftest_default_renderer_serializes_content(self):
+ """If the Accept header is not set the default renderer should serialize the response."""
+ resp = self.client.get('/')
+ assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
+ assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT)
+ assert resp.status_code == DUMMYSTATUS
def test_head_method_serializes_no_content(self):
- """No response must be included in HEAD requests."""
- resp = self.client.head('/')
- self.assertEqual(resp.status_code, DUMMYSTATUS)
- self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, b'')
+ """No response must be included in HEAD requests."""
+ resp = self.client.head('/')
+ assert resp.status_code == DUMMYSTATUS
+ assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
+ assert resp.content == b''
def test_default_renderer_serializes_content_on_accept_any(self):
- """If the Accept header is set to */* the default renderer should serialize the response."""
- resp = self.client.get('/', HTTP_ACCEPT='*/*')
- self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
+ """If the Accept header is set to */* the default renderer should serialize the response."""
+ resp = self.client.get('/', HTTP_ACCEPT='*/*')
+ assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
+ assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT)
+ assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_serializes_content_default_case(self):
- """If the Accept header is set the specified renderer should serialize the response.
+ """If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for the default renderer)"""
- resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
- self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
+ resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
+ assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
+ assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT)
+ assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_serializes_content_non_default_case(self):
- """If the Accept header is set the specified renderer should serialize the response.
+ """If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for a non-default renderer)"""
- resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
+ resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
+ assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
+ assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
+ assert resp.status_code == DUMMYSTATUS
def test_unsatisfiable_accept_header_on_request_returns_406_status(self):
- """If the Accept header is unsatisfiable we should return a 406 Not Acceptable response."""
- resp = self.client.get('/', HTTP_ACCEPT='foo/bar')
- self.assertEqual(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE)
+ """If the Accept header is unsatisfiable we should return a 406 Not Acceptable response."""
+ resp = self.client.get('/', HTTP_ACCEPT='foo/bar')
+ assert resp.status_code == status.HTTP_406_NOT_ACCEPTABLE
def test_specified_renderer_serializes_content_on_format_query(self):
- """If a 'format' query is specified, the renderer with the matching
+ """If a 'format' query is specified, the renderer with the matching
format attribute should serialize the response."""
- param = '?%s=%s' % (
- api_settings.URL_FORMAT_OVERRIDE,
- RendererB.format
- )
- resp = self.client.get('/' + param)
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
+ param = '?%s=%s' % ( api_settings.URL_FORMAT_OVERRIDE, RendererB.format )
+ resp = self.client.get('/' + param)
+ assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
+ assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
+ assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_serializes_content_on_format_kwargs(self):
- """If a 'format' keyword arg is specified, the renderer with the matching
+ """If a 'format' keyword arg is specified, the renderer with the matching
format attribute should serialize the response."""
- resp = self.client.get('/something.formatb')
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
+ resp = self.client.get('/something.formatb')
+ assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
+ assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
+ assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_is_used_on_format_query_with_matching_accept(self):
- """If both a 'format' query and a matching Accept header specified,
+ """If both a 'format' query and a matching Accept header specified,
the renderer with the matching format attribute should serialize the response."""
- param = '?%s=%s' % (
- api_settings.URL_FORMAT_OVERRIDE,
- RendererB.format
- )
- resp = self.client.get('/' + param,
- HTTP_ACCEPT=RendererB.media_type)
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
+ param = '?%s=%s' % ( api_settings.URL_FORMAT_OVERRIDE, RendererB.format )
+ resp = self.client.get('/' + param, HTTP_ACCEPT=RendererB.media_type)
+ assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
+ assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
+ assert resp.status_code == DUMMYSTATUS
def test_parse_error_renderers_browsable_api(self):
- """Invalid data should still render the browsable API correctly."""
- resp = self.client.post('/parseerror', data='foobar', content_type='application/json', HTTP_ACCEPT='text/html')
- self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
- self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST)
+ """Invalid data should still render the browsable API correctly."""
+ resp = self.client.post('/parseerror', data='foobar', content_type='application/json', HTTP_ACCEPT='text/html')
+ assert resp['Content-Type'] == 'text/html; charset=utf-8'
+ assert resp.status_code == status.HTTP_400_BAD_REQUEST
def test_204_no_content_responses_have_no_content_type_set(self):
- """
+ """
Regression test for #1196
https://github.com/encode/django-rest-framework/issues/1196
"""
- resp = self.client.get('/empty')
- self.assertEqual(resp.get('Content-Type', None), None)
- self.assertEqual(resp.status_code, status.HTTP_204_NO_CONTENT)
+ resp = self.client.get('/empty')
+ assert resp.get('Content-Type', None) == None
+ assert resp.status_code == status.HTTP_204_NO_CONTENT
def test_contains_headers_of_api_response(self):
- """
+ """
Issue #1437
Test we display the headers of the API response and not those from the
HTML response
"""
- resp = self.client.get('/html1')
- self.assertContains(resp, '>GET, HEAD, OPTIONS<')
- self.assertContains(resp, '>application/json<')
- self.assertNotContains(resp, '>text/html; charset=utf-8<')
+ resp = self.client.get('/html1')
+ self.assertContains(resp, '>GET, HEAD, OPTIONS<')
+ self.assertContains(resp, '>application/json<')
+ self.assertNotContains(resp, '>text/html; charset=utf-8<')
_flat_repr = '{"foo":["bar","baz"]}'
@@ -275,183 +265,174 @@ def strip_trailing_whitespace(content):
return re.sub(' +\n', '\n', content)
-class BaseRendererTests(TestCase):
- """
+"""
Tests BaseRenderer
"""
- def test_render_raise_error(self):
- """
+deftest_render_raise_error(self):
+ """
BaseRenderer.render should raise NotImplementedError
"""
- with pytest.raises(NotImplementedError):
- BaseRenderer().render('test')
+ with pytest.raises(NotImplementedError):
+ BaseRenderer().render('test')
-class JSONRendererTests(TestCase):
- """
+"""
Tests specific to the JSON Renderer
"""
-
- def test_render_lazy_strings(self):
- """
+deftest_render_lazy_strings(self):
+ """
JSONRenderer should deal with lazy translated strings.
"""
- ret = JSONRenderer().render(_('test'))
- self.assertEqual(ret, b'"test"')
+ ret = JSONRenderer().render(_('test'))
+ assert ret == b'"test"'
def test_render_queryset_values(self):
- o = DummyTestModel.objects.create(name='dummy')
- qs = DummyTestModel.objects.values('id', 'name')
- ret = JSONRenderer().render(qs)
- data = json.loads(ret.decode())
- self.assertEqual(data, [{'id': o.id, 'name': o.name}])
+ o = DummyTestModel.objects.create(name='dummy')
+ qs = DummyTestModel.objects.values('id', 'name')
+ ret = JSONRenderer().render(qs)
+ data = json.loads(ret.decode())
+ assert data == [{'id': o.id, 'name': o.name}]
def test_render_queryset_values_list(self):
- o = DummyTestModel.objects.create(name='dummy')
- qs = DummyTestModel.objects.values_list('id', 'name')
- ret = JSONRenderer().render(qs)
- data = json.loads(ret.decode())
- self.assertEqual(data, [[o.id, o.name]])
+ o = DummyTestModel.objects.create(name='dummy')
+ qs = DummyTestModel.objects.values_list('id', 'name')
+ ret = JSONRenderer().render(qs)
+ data = json.loads(ret.decode())
+ assert data == [[o.id, o.name]]
def test_render_dict_abc_obj(self):
- class Dict(MutableMapping):
- def __init__(self):
- self._dict = {}
+ class Dict(MutableMapping):
+ def __init__(self):
+ self._dict = {}
def __getitem__(self, key):
- return self._dict.__getitem__(key)
+ return self._dict.__getitem__(key)
def __setitem__(self, key, value):
- return self._dict.__setitem__(key, value)
+ return self._dict.__setitem__(key, value)
def __delitem__(self, key):
- return self._dict.__delitem__(key)
+ return self._dict.__delitem__(key)
def __iter__(self):
- return self._dict.__iter__()
+ return self._dict.__iter__()
def __len__(self):
- return self._dict.__len__()
+ return self._dict.__len__()
def keys(self):
- return self._dict.keys()
+ return self._dict.keys()
x = Dict()
- x['key'] = 'string value'
- x[2] = 3
- ret = JSONRenderer().render(x)
- data = json.loads(ret.decode())
- self.assertEqual(data, {'key': 'string value', '2': 3})
+x['key']='string value'
+x[2]=3
+ret=JSONRenderer().render(x)
+data=json.loads(ret.decode())
+assertdata=={'key':'string value','2':3}
def test_render_obj_with_getitem(self):
- class DictLike:
- def __init__(self):
- self._dict = {}
+ class DictLike:
+ def __init__(self):
+ self._dict = {}
def set(self, value):
- self._dict = dict(value)
+ self._dict = dict(value)
def __getitem__(self, key):
- return self._dict[key]
+ return self._dict[key]
x = DictLike()
- x.set({'a': 1, 'b': 'string'})
- with self.assertRaises(TypeError):
- JSONRenderer().render(x)
+x.set({'a':1,'b':'string'})
+withpytest.raises(TypeError):
+ JSONRenderer().render(x)
def test_float_strictness(self):
- renderer = JSONRenderer()
-
- # Default to strict
- for value in [float('inf'), float('-inf'), float('nan')]:
- with pytest.raises(ValueError):
- renderer.render(value)
+ renderer = JSONRenderer()
+ for value in [float('inf'), float('-inf'), float('nan')]:
+ with pytest.raises(ValueError):
+ renderer.render(value)
renderer.strict = False
- assert renderer.render(float('inf')) == b'Infinity'
- assert renderer.render(float('-inf')) == b'-Infinity'
- assert renderer.render(float('nan')) == b'NaN'
+assertrenderer.render(float('inf'))==b'Infinity'
+assertrenderer.render(float('-inf'))==b'-Infinity'
+assertrenderer.render(float('nan'))==b'NaN'
def test_without_content_type_args(self):
- """
+ """
Test basic JSON rendering.
"""
- obj = {'foo': ['bar', 'baz']}
- renderer = JSONRenderer()
- content = renderer.render(obj, 'application/json')
- # Fix failing test case which depends on version of JSON library.
- self.assertEqual(content.decode(), _flat_repr)
+ obj = {'foo': ['bar', 'baz']}
+ renderer = JSONRenderer()
+ content = renderer.render(obj, 'application/json')
+ assert content.decode() == _flat_repr
def test_with_content_type_args(self):
- """
+ """
Test JSON rendering with additional content type arguments supplied.
"""
- obj = {'foo': ['bar', 'baz']}
- renderer = JSONRenderer()
- content = renderer.render(obj, 'application/json; indent=2')
- self.assertEqual(strip_trailing_whitespace(content.decode()), _indented_repr)
+ obj = {'foo': ['bar', 'baz']}
+ renderer = JSONRenderer()
+ content = renderer.render(obj, 'application/json; indent=2')
+ assert strip_trailing_whitespace(content.decode()) == _indented_repr
-class UnicodeJSONRendererTests(TestCase):
- """
+"""
Tests specific for the Unicode JSON Renderer
"""
- def test_proper_encoding(self):
- obj = {'countries': ['United Kingdom', 'France', 'España']}
- renderer = JSONRenderer()
- content = renderer.render(obj, 'application/json')
- self.assertEqual(content, '{"countries":["United Kingdom","France","España"]}'.encode())
+deftest_proper_encoding(self):
+ obj = {'countries': ['United Kingdom', 'France', 'España']}
+ renderer = JSONRenderer()
+ content = renderer.render(obj, 'application/json')
+ assert content == '{"countries":["United Kingdom","France","España"]}'.encode()
def test_u2028_u2029(self):
# The \u2028 and \u2029 characters should be escaped,
# even when the non-escaping unicode representation is used.
# Regression test for #2169
- obj = {'should_escape': '\u2028\u2029'}
- renderer = JSONRenderer()
- content = renderer.render(obj, 'application/json')
- self.assertEqual(content, '{"should_escape":"\\u2028\\u2029"}'.encode())
+ obj = {'should_escape': '\u2028\u2029'}
+ renderer = JSONRenderer()
+ content = renderer.render(obj, 'application/json')
+ assert content == '{"should_escape":"\\u2028\\u2029"}'.encode()
-class AsciiJSONRendererTests(TestCase):
- """
+"""
Tests specific for the Unicode JSON Renderer
"""
- def test_proper_encoding(self):
- class AsciiJSONRenderer(JSONRenderer):
- ensure_ascii = True
+deftest_proper_encoding(self):
+ class AsciiJSONRenderer(JSONRenderer):
+ ensure_ascii = True
obj = {'countries': ['United Kingdom', 'France', 'España']}
- renderer = AsciiJSONRenderer()
- content = renderer.render(obj, 'application/json')
- self.assertEqual(content, '{"countries":["United Kingdom","France","Espa\\u00f1a"]}'.encode())
+renderer=AsciiJSONRenderer()
+content=renderer.render(obj,'application/json')
+assertcontent=='{"countries":["United Kingdom","France","Espa\\u00f1a"]}'.encode()
# Tests for caching issue, #346
@override_settings(ROOT_URLCONF='tests.test_renderers')
-class CacheRenderTest(TestCase):
- """
+"""
Tests specific to caching responses
"""
- def test_head_caching(self):
- """
+deftest_head_caching(self):
+ """
Test caching of HEAD requests
"""
- response = self.client.head('/cache')
- cache.set('key', response)
- cached_response = cache.get('key')
- assert isinstance(cached_response, Response)
- assert cached_response.content == response.content
- assert cached_response.status_code == response.status_code
+ response = self.client.head('/cache')
+ cache.set('key', response)
+ cached_response = cache.get('key')
+ assert isinstance(cached_response, Response)
+ assert cached_response.content == response.content
+ assert cached_response.status_code == response.status_code
def test_get_caching(self):
- """
+ """
Test caching of GET requests
"""
- response = self.client.get('/cache')
- cache.set('key', response)
- cached_response = cache.get('key')
- assert isinstance(cached_response, Response)
- assert cached_response.content == response.content
- assert cached_response.status_code == response.status_code
+ response = self.client.get('/cache')
+ cache.set('key', response)
+ cached_response = cache.get('key')
+ assert isinstance(cached_response, Response)
+ assert cached_response.content == response.content
+ assert cached_response.status_code == response.status_code
class TestJSONIndentationStyles:
@@ -476,150 +457,116 @@ class TestJSONIndentationStyles:
assert renderer.render(data) == b'{"a": 1, "b": 2}'
-class TestHiddenFieldHTMLFormRenderer(TestCase):
- def test_hidden_field_rendering(self):
- class TestSerializer(serializers.Serializer):
- published = serializers.HiddenField(default=True)
+def test_hidden_field_rendering(self):
+ class TestSerializer(serializers.Serializer):
+ published = serializers.HiddenField(default=True)
serializer = TestSerializer(data={})
- serializer.is_valid()
- renderer = HTMLFormRenderer()
- field = serializer['published']
- rendered = renderer.render_field(field, {})
- assert rendered == ''
+serializer.is_valid()
+renderer=HTMLFormRenderer()
+field=serializer['published']
+rendered=renderer.render_field(field,{})
+assertrendered==''
-class TestHTMLFormRenderer(TestCase):
- def setUp(self):
- class TestSerializer(serializers.Serializer):
- test_field = serializers.CharField()
+def setUp(self):
+ class TestSerializer(serializers.Serializer):
+ test_field = serializers.CharField()
self.renderer = HTMLFormRenderer()
- self.serializer = TestSerializer(data={})
+self.serializer=TestSerializer(data={})
def test_render_with_default_args(self):
- self.serializer.is_valid()
- renderer = HTMLFormRenderer()
-
- result = renderer.render(self.serializer.data)
-
- self.assertIsInstance(result, SafeText)
+ self.serializer.is_valid()
+ renderer = HTMLFormRenderer()
+ result = renderer.render(self.serializer.data)
+ assert isinstance(result, SafeText)
def test_render_with_provided_args(self):
- self.serializer.is_valid()
- renderer = HTMLFormRenderer()
-
- result = renderer.render(self.serializer.data, None, {})
-
- self.assertIsInstance(result, SafeText)
+ self.serializer.is_valid()
+ renderer = HTMLFormRenderer()
+ result = renderer.render(self.serializer.data, None, {})
+ assert isinstance(result, SafeText)
-class TestChoiceFieldHTMLFormRenderer(TestCase):
- """
+"""
Test rendering ChoiceField with HTMLFormRenderer.
"""
-
- def setUp(self):
- choices = ((1, 'Option1'), (2, 'Option2'), (12, 'Option12'))
-
- class TestSerializer(serializers.Serializer):
- test_field = serializers.ChoiceField(choices=choices,
- initial=2)
+defsetUp(self):
+ choices = ((1, 'Option1'), (2, 'Option2'), (12, 'Option12'))
+ class TestSerializer(serializers.Serializer):
+ test_field = serializers.ChoiceField(choices=choices, initial=2)
self.TestSerializer = TestSerializer
- self.renderer = HTMLFormRenderer()
+self.renderer=HTMLFormRenderer()
def test_render_initial_option(self):
- serializer = self.TestSerializer()
- result = self.renderer.render(serializer.data)
-
- self.assertIsInstance(result, SafeText)
-
- self.assertInHTML('',
- result)
- self.assertInHTML('', result)
- self.assertInHTML('', result)
+ serializer = self.TestSerializer()
+ result = self.renderer.render(serializer.data)
+ assert isinstance(result, SafeText)
+ self.assertInHTML('', result)
+ self.assertInHTML('', result)
+ self.assertInHTML('', result)
def test_render_selected_option(self):
- serializer = self.TestSerializer(data={'test_field': '12'})
-
- serializer.is_valid()
- result = self.renderer.render(serializer.data)
-
- self.assertIsInstance(result, SafeText)
-
- self.assertInHTML('',
- result)
- self.assertInHTML('', result)
- self.assertInHTML('', result)
+ serializer = self.TestSerializer(data={'test_field': '12'})
+ serializer.is_valid()
+ result = self.renderer.render(serializer.data)
+ assert isinstance(result, SafeText)
+ self.assertInHTML('', result)
+ self.assertInHTML('', result)
+ self.assertInHTML('', result)
-class TestMultipleChoiceFieldHTMLFormRenderer(TestCase):
- """
+"""
Test rendering MultipleChoiceField with HTMLFormRenderer.
"""
-
- def setUp(self):
- self.renderer = HTMLFormRenderer()
+defsetUp(self):
+ self.renderer = HTMLFormRenderer()
def test_render_selected_option_with_string_option_ids(self):
- choices = (('1', 'Option1'), ('2', 'Option2'), ('12', 'Option12'),
- ('}', 'OptionBrace'))
-
- class TestSerializer(serializers.Serializer):
- test_field = serializers.MultipleChoiceField(choices=choices)
+ choices = (('1', 'Option1'), ('2', 'Option2'), ('12', 'Option12'), ('}', 'OptionBrace'))
+ class TestSerializer(serializers.Serializer):
+ test_field = serializers.MultipleChoiceField(choices=choices)
serializer = TestSerializer(data={'test_field': ['12']})
- serializer.is_valid()
-
- result = self.renderer.render(serializer.data)
-
- self.assertIsInstance(result, SafeText)
-
- self.assertInHTML('',
- result)
- self.assertInHTML('', result)
- self.assertInHTML('', result)
- self.assertInHTML('', result)
+serializer.is_valid()
+result=self.renderer.render(serializer.data)
+assertisinstance(result, SafeText)
+self.assertInHTML('',result)
+self.assertInHTML('',result)
+self.assertInHTML('',result)
+self.assertInHTML('',result)
def test_render_selected_option_with_integer_option_ids(self):
- choices = ((1, 'Option1'), (2, 'Option2'), (12, 'Option12'))
-
- class TestSerializer(serializers.Serializer):
- test_field = serializers.MultipleChoiceField(choices=choices)
+ choices = ((1, 'Option1'), (2, 'Option2'), (12, 'Option12'))
+ class TestSerializer(serializers.Serializer):
+ test_field = serializers.MultipleChoiceField(choices=choices)
serializer = TestSerializer(data={'test_field': ['12']})
- serializer.is_valid()
-
- result = self.renderer.render(serializer.data)
-
- self.assertIsInstance(result, SafeText)
-
- self.assertInHTML('',
- result)
- self.assertInHTML('', result)
- self.assertInHTML('', result)
+serializer.is_valid()
+result=self.renderer.render(serializer.data)
+assertisinstance(result, SafeText)
+self.assertInHTML('',result)
+self.assertInHTML('',result)
+self.assertInHTML('',result)
-class StaticHTMLRendererTests(TestCase):
- """
+"""
Tests specific for Static HTML Renderer
"""
- def setUp(self):
- self.renderer = StaticHTMLRenderer()
+defsetUp(self):
+ self.renderer = StaticHTMLRenderer()
def test_static_renderer(self):
- data = 'text'
- result = self.renderer.render(data)
- assert result == data
+ data = 'text'
+ result = self.renderer.render(data)
+ assert result == data
def test_static_renderer_with_exception(self):
- context = {
- 'response': Response(status=500, exception=True),
- 'request': Request(HttpRequest())
- }
- result = self.renderer.render({}, renderer_context=context)
- assert result == '500 Internal Server Error'
+ context = { 'response': Response(status=500, exception=True), 'request': Request(HttpRequest()) }
+ result = self.renderer.render({}, renderer_context=context)
+ assert result == '500 Internal Server Error'
class BrowsableAPIRendererTests(URLPatternsTestCase):
@@ -658,196 +605,142 @@ class BrowsableAPIRendererTests(URLPatternsTestCase):
assert '>Extra list action<' in resp.content.decode()
-class AdminRendererTests(TestCase):
- def setUp(self):
- self.renderer = AdminRenderer()
+def setUp(self):
+ self.renderer = AdminRenderer()
def test_render_when_resource_created(self):
- class DummyView(APIView):
- renderer_classes = (AdminRenderer, )
+ class DummyView(APIView):
+ renderer_classes = (AdminRenderer, )
request = Request(HttpRequest())
- request.build_absolute_uri = lambda: 'http://example.com'
- response = Response(status=201, headers={'Location': '/test'})
- context = {
- 'view': DummyView(),
- 'request': request,
- 'response': response
- }
-
- result = self.renderer.render(data={'test': 'test'},
- renderer_context=context)
- assert result == ''
- assert response.status_code == status.HTTP_303_SEE_OTHER
- assert response['Location'] == 'http://example.com'
+request.build_absolute_uri=lambda:'http://example.com'
+response=Response(status=201,headers={'Location':'/test'})
+context={'view':DummyView(),'request':request,'response':response}
+result=self.renderer.render(data={'test':'test'},renderer_context=context)
+assertresult==''
+assertresponse.status_code==status.HTTP_303_SEE_OTHER
+assertresponse['Location']=='http://example.com'
def test_render_dict(self):
- factory = APIRequestFactory()
-
- class DummyView(APIView):
- renderer_classes = (AdminRenderer, )
-
- def get(self, request):
- return Response({'foo': 'a string'})
+ factory = APIRequestFactory()
+ class DummyView(APIView):
+ renderer_classes = (AdminRenderer, )
+ def get(self, request):
+ return Response({'foo': 'a string'})
view = DummyView.as_view()
- request = factory.get('/')
- response = view(request)
- response.render()
- self.assertContains(response, 'Foo | a string |
', html=True)
+request=factory.get('/')
+response=view(request)
+response.render()
+self.assertContains(response,'Foo | a string |
',html=True)
def test_render_dict_with_items_key(self):
- factory = APIRequestFactory()
-
- class DummyView(APIView):
- renderer_classes = (AdminRenderer, )
-
- def get(self, request):
- return Response({'items': 'a string'})
+ factory = APIRequestFactory()
+ class DummyView(APIView):
+ renderer_classes = (AdminRenderer, )
+ def get(self, request):
+ return Response({'items': 'a string'})
view = DummyView.as_view()
- request = factory.get('/')
- response = view(request)
- response.render()
- self.assertContains(response, 'Items | a string |
', html=True)
+request=factory.get('/')
+response=view(request)
+response.render()
+self.assertContains(response,'Items | a string |
',html=True)
def test_render_dict_with_iteritems_key(self):
- factory = APIRequestFactory()
-
- class DummyView(APIView):
- renderer_classes = (AdminRenderer, )
-
- def get(self, request):
- return Response({'iteritems': 'a string'})
+ factory = APIRequestFactory()
+ class DummyView(APIView):
+ renderer_classes = (AdminRenderer, )
+ def get(self, request):
+ return Response({'iteritems': 'a string'})
view = DummyView.as_view()
- request = factory.get('/')
- response = view(request)
- response.render()
- self.assertContains(response, 'Iteritems | a string |
', html=True)
+request=factory.get('/')
+response=view(request)
+response.render()
+self.assertContains(response,'Iteritems | a string |
',html=True)
def test_get_result_url(self):
- factory = APIRequestFactory()
-
- class DummyGenericViewsetLike(APIView):
- lookup_field = 'test'
-
- def reverse_action(view, *args, **kwargs):
- self.assertEqual(kwargs['kwargs']['test'], 1)
- return '/example/'
+ factory = APIRequestFactory()
+ class DummyGenericViewsetLike(APIView):
+ lookup_field = 'test'
+ def reverse_action(view, *args, **kwargs):
+ assert kwargs['kwargs']['test'] == 1
+ return '/example/'
# get the view instance instead of the view function
view = DummyGenericViewsetLike.as_view()
- request = factory.get('/')
- response = view(request)
- view = response.renderer_context['view']
-
- self.assertEqual(self.renderer.get_result_url({'test': 1}, view), '/example/')
- self.assertIsNone(self.renderer.get_result_url({}, view))
+request=factory.get('/')
+response=view(request)
+view=response.renderer_context['view']
+assertself.renderer.get_result_url({'test':1},view)=='/example/'
+assertself.renderer.get_result_url({},view)is None
def test_get_result_url_no_result(self):
- factory = APIRequestFactory()
-
- class DummyView(APIView):
- lookup_field = 'test'
+ factory = APIRequestFactory()
+ class DummyView(APIView):
+ lookup_field = 'test'
# get the view instance instead of the view function
view = DummyView.as_view()
- request = factory.get('/')
- response = view(request)
- view = response.renderer_context['view']
-
- self.assertIsNone(self.renderer.get_result_url({'test': 1}, view))
- self.assertIsNone(self.renderer.get_result_url({}, view))
+request=factory.get('/')
+response=view(request)
+view=response.renderer_context['view']
+assertself.renderer.get_result_url({'test':1},view)is None
+assertself.renderer.get_result_url({},view)is None
def test_get_context_result_urls(self):
- factory = APIRequestFactory()
-
- class DummyView(APIView):
- lookup_field = 'test'
-
- def reverse_action(view, url_name, args=None, kwargs=None):
- return '/%s/%d' % (url_name, kwargs['test'])
+ factory = APIRequestFactory()
+ class DummyView(APIView):
+ lookup_field = 'test'
+ def reverse_action(view, url_name, args=None, kwargs=None):
+ return '/%s/%d' % (url_name, kwargs['test'])
# get the view instance instead of the view function
view = DummyView.as_view()
- request = factory.get('/')
- response = view(request)
-
- data = [
- {'test': 1},
- {'url': '/example', 'test': 2},
- {'url': None, 'test': 3},
- {},
- ]
- context = {
- 'view': DummyView(),
- 'request': Request(request),
- 'response': response
- }
-
- context = self.renderer.get_context(data, None, context)
- results = context['results']
-
- self.assertEqual(len(results), 4)
- self.assertEqual(results[0]['url'], '/detail/1')
- self.assertEqual(results[1]['url'], '/example')
- self.assertEqual(results[2]['url'], None)
- self.assertNotIn('url', results[3])
+request=factory.get('/')
+response=view(request)
+data=[{'test':1},{'url':'/example','test':2},{'url':None,'test':3},{},]
+context={'view':DummyView(),'request':Request(request),'response':response}
+context=self.renderer.get_context(data,None,context)
+results=context['results']
+assertlen(results)==4
+assertresults[0]['url']=='/detail/1'
+assertresults[1]['url']=='/example'
+assertresults[2]['url']==None
+assert'url'not inresults[3]
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
-class TestDocumentationRenderer(TestCase):
- def test_document_with_link_named_data(self):
- """
+def test_document_with_link_named_data(self):
+ """
Ref #5395: Doc's `document.data` would fail with a Link named "data".
As per #4972, use templatetag instead.
"""
- document = coreapi.Document(
- title='Data Endpoint API',
- url='https://api.example.org/',
- content={
- 'data': coreapi.Link(
- url='/data/',
- action='get',
- fields=[],
- description='Return data.'
- )
- }
- )
-
- factory = APIRequestFactory()
- request = factory.get('/')
-
- renderer = DocumentationRenderer()
-
- html = renderer.render(document, accepted_media_type="text/html", renderer_context={"request": request})
- assert 'Data Endpoint API
' in html
+ document = coreapi.Document( title='Data Endpoint API', url='https://api.example.org/', content={ 'data': coreapi.Link( url='/data/', action='get', fields=[], description='Return data.' ) } )
+ factory = APIRequestFactory()
+ request = factory.get('/')
+ renderer = DocumentationRenderer()
+ html = renderer.render(document, accepted_media_type="text/html", renderer_context={"request": request})
+ assert 'Data Endpoint API
' in html
def test_shell_code_example_rendering(self):
- template = loader.get_template('rest_framework/docs/langs/shell.html')
- context = {
- 'document': coreapi.Document(url='https://api.example.org/'),
- 'link_key': 'testcases > list',
- 'link': coreapi.Link(url='/data/', action='get', fields=[]),
- }
- html = template.render(context)
- assert 'testcases list' in html
+ template = loader.get_template('rest_framework/docs/langs/shell.html')
+ context = { 'document': coreapi.Document(url='https://api.example.org/'), 'link_key': 'testcases > list', 'link': coreapi.Link(url='/data/', action='get', fields=[]), }
+ html = template.render(context)
+ assert 'testcases list' in html
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
-class TestSchemaJSRenderer(TestCase):
- def test_schemajs_output(self):
- """
+def test_schemajs_output(self):
+ """
Test output of the SchemaJS renderer as per #5608. Django 2.0 on Py3 prints binary data as b'xyz' in templates,
and the base64 encoding used by SchemaJSRenderer outputs base64 as binary. Test fix.
"""
- factory = APIRequestFactory()
- request = factory.get('/')
-
- renderer = SchemaJSRenderer()
-
- output = renderer.render('data', renderer_context={"request": request})
- assert "'ImRhdGEi'" in output
- assert "'b'ImRhdGEi''" not in output
+ factory = APIRequestFactory()
+ request = factory.get('/')
+ renderer = SchemaJSRenderer()
+ output = renderer.render('data', renderer_context={"request": request})
+ assert "'ImRhdGEi'" in output
+ assert "'b'ImRhdGEi''" not in output
diff --git a/tests/test_request.py b/tests/test_request.py
index 0f682deb0..d7f1d44cb 100644
--- a/tests/test_request.py
+++ b/tests/test_request.py
@@ -23,18 +23,11 @@ from rest_framework.test import APIClient, APIRequestFactory
from rest_framework.views import APIView
factory = APIRequestFactory()
-
-
-class TestInitializer(TestCase):
- def test_request_type(self):
- request = Request(factory.get('/'))
-
- message = (
- 'The `request` argument must be an instance of '
- '`django.http.HttpRequest`, not `rest_framework.request.Request`.'
- )
- with self.assertRaisesMessage(AssertionError, message):
- Request(request)
+def test_request_type(self):
+ request = Request(factory.get('/'))
+ message = ( 'The `request` argument must be an instance of ' '`django.http.HttpRequest`, not `rest_framework.request.Request`.' )
+ with self.assertRaisesMessage(AssertionError, message):
+ Request(request)
class PlainTextParser(BaseParser):
@@ -50,79 +43,78 @@ class PlainTextParser(BaseParser):
return stream.read()
-class TestContentParsing(TestCase):
- def test_standard_behaviour_determines_no_content_GET(self):
- """
+def test_standard_behaviour_determines_no_content_GET(self):
+ """
Ensure request.data returns empty QueryDict for GET request.
"""
- request = Request(factory.get('/'))
- assert request.data == {}
+ request = Request(factory.get('/'))
+ assert request.data == {}
def test_standard_behaviour_determines_no_content_HEAD(self):
- """
+ """
Ensure request.data returns empty QueryDict for HEAD request.
"""
- request = Request(factory.head('/'))
- assert request.data == {}
+ request = Request(factory.head('/'))
+ assert request.data == {}
def test_request_DATA_with_form_content(self):
- """
+ """
Ensure request.data returns content for POST request with form content.
"""
- data = {'qwerty': 'uiop'}
- request = Request(factory.post('/', data))
- request.parsers = (FormParser(), MultiPartParser())
- assert list(request.data.items()) == list(data.items())
+ data = {'qwerty': 'uiop'}
+ request = Request(factory.post('/', data))
+ request.parsers = (FormParser(), MultiPartParser())
+ assert list(request.data.items()) == list(data.items())
def test_request_DATA_with_text_content(self):
- """
+ """
Ensure request.data returns content for POST request with
non-form content.
"""
- content = b'qwerty'
- content_type = 'text/plain'
- request = Request(factory.post('/', content, content_type=content_type))
- request.parsers = (PlainTextParser(),)
- assert request.data == content
+ content = b'qwerty'
+ content_type = 'text/plain'
+ request = Request(factory.post('/', content, content_type=content_type))
+ request.parsers = (PlainTextParser(),)
+ assert request.data == content
def test_request_POST_with_form_content(self):
- """
+ """
Ensure request.POST returns content for POST request with form content.
"""
- data = {'qwerty': 'uiop'}
- request = Request(factory.post('/', data))
- request.parsers = (FormParser(), MultiPartParser())
- assert list(request.POST.items()) == list(data.items())
+ data = {'qwerty': 'uiop'}
+ request = Request(factory.post('/', data))
+ request.parsers = (FormParser(), MultiPartParser())
+ assert list(request.POST.items()) == list(data.items())
def test_request_POST_with_files(self):
- """
+ """
Ensure request.POST returns no content for POST request with file content.
"""
- upload = SimpleUploadedFile("file.txt", b"file_content")
- request = Request(factory.post('/', {'upload': upload}))
- request.parsers = (FormParser(), MultiPartParser())
- assert list(request.POST) == []
- assert list(request.FILES) == ['upload']
+ upload = SimpleUploadedFile("file.txt", b"file_content")
+ request = Request(factory.post('/', {'upload': upload}))
+ request.parsers = (FormParser(), MultiPartParser())
+ assert list(request.POST) == []
+ assert list(request.FILES) == ['upload']
def test_standard_behaviour_determines_form_content_PUT(self):
- """
+ """
Ensure request.data returns content for PUT request with form content.
"""
- data = {'qwerty': 'uiop'}
- request = Request(factory.put('/', data))
- request.parsers = (FormParser(), MultiPartParser())
- assert list(request.data.items()) == list(data.items())
+ data = {'qwerty': 'uiop'}
+ request = Request(factory.put('/', data))
+ request.parsers = (FormParser(), MultiPartParser())
+ assert list(request.data.items()) == list(data.items())
def test_standard_behaviour_determines_non_form_content_PUT(self):
- """
+ """
Ensure request.data returns content for PUT request with
non-form content.
"""
- content = b'qwerty'
- content_type = 'text/plain'
- request = Request(factory.put('/', content, content_type=content_type))
- request.parsers = (PlainTextParser(), )
- assert request.data == content
+ content = b'qwerty'
+ content_type = 'text/plain'
+ request = Request(factory.put('/', content, content_type=content_type))
+ request.parsers = (PlainTextParser(), )
+ assert request.data == content
class MockView(APIView):
@@ -160,175 +152,148 @@ urlpatterns = [
@override_settings(
ROOT_URLCONF='tests.test_request',
FILE_UPLOAD_HANDLERS=['django.core.files.uploadhandler.TemporaryFileUploadHandler'])
-class FileUploadTests(TestCase):
- def test_fileuploads_closed_at_request_end(self):
- with tempfile.NamedTemporaryFile() as f:
- response = self.client.post('/upload/', {'file': f})
+def test_fileuploads_closed_at_request_end(self):
+ with tempfile.NamedTemporaryFile() as f:
+ response = self.client.post('/upload/', {'file': f})
# sanity check that file was processed
assert len(response.data) == 1
-
- for file in response.data:
- assert not os.path.exists(file)
+forfileinresponse.data:
+ assert not os.path.exists(file)
@override_settings(ROOT_URLCONF='tests.test_request')
-class TestContentParsingWithAuthentication(TestCase):
- def setUp(self):
- self.csrf_client = APIClient(enforce_csrf_checks=True)
- self.username = 'john'
- self.email = 'lennon@thebeatles.com'
- self.password = 'password'
- self.user = User.objects.create_user(self.username, self.email, self.password)
+def setUp(self):
+ self.csrf_client = APIClient(enforce_csrf_checks=True)
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
def test_user_logged_in_authentication_has_POST_when_not_logged_in(self):
- """
+ """
Ensures request.POST exists after SessionAuthentication when user
doesn't log in.
"""
- content = {'example': 'example'}
-
- response = self.client.post('/', content)
- assert status.HTTP_200_OK == response.status_code
-
- response = self.csrf_client.post('/', content)
- assert status.HTTP_200_OK == response.status_code
+ content = {'example': 'example'}
+ response = self.client.post('/', content)
+ assert status.HTTP_200_OK == response.status_code
+ response = self.csrf_client.post('/', content)
+ assert status.HTTP_200_OK == response.status_code
-class TestUserSetter(TestCase):
- def setUp(self):
+def setUp(self):
# Pass request object through session middleware so session is
# available to login and logout functions
- self.wrapped_request = factory.get('/')
- self.request = Request(self.wrapped_request)
- SessionMiddleware().process_request(self.wrapped_request)
- AuthenticationMiddleware().process_request(self.wrapped_request)
-
- User.objects.create_user('ringo', 'starr@thebeatles.com', 'yellow')
- self.user = authenticate(username='ringo', password='yellow')
+ self.wrapped_request = factory.get('/')
+ self.request = Request(self.wrapped_request)
+ SessionMiddleware().process_request(self.wrapped_request)
+ AuthenticationMiddleware().process_request(self.wrapped_request)
+ User.objects.create_user('ringo', 'starr@thebeatles.com', 'yellow')
+ self.user = authenticate(username='ringo', password='yellow')
def test_user_can_be_set(self):
- self.request.user = self.user
- assert self.request.user == self.user
+ self.request.user = self.user
+ assert self.request.user == self.user
def test_user_can_login(self):
- login(self.request, self.user)
- assert self.request.user == self.user
+ login(self.request, self.user)
+ assert self.request.user == self.user
def test_user_can_logout(self):
- self.request.user = self.user
- assert not self.request.user.is_anonymous
- logout(self.request)
- assert self.request.user.is_anonymous
+ self.request.user = self.user
+ assert not self.request.user.is_anonymous
+ logout(self.request)
+ assert self.request.user.is_anonymous
def test_logged_in_user_is_set_on_wrapped_request(self):
- login(self.request, self.user)
- assert self.wrapped_request.user == self.user
+ login(self.request, self.user)
+ assert self.wrapped_request.user == self.user
def test_calling_user_fails_when_attribute_error_is_raised(self):
- """
+ """
This proves that when an AttributeError is raised inside of the request.user
property, that we can handle this and report the true, underlying error.
"""
- class AuthRaisesAttributeError:
- def authenticate(self, request):
- self.MISSPELLED_NAME_THAT_DOESNT_EXIST
+ class AuthRaisesAttributeError:
+ def authenticate(self, request):
+ self.MISSPELLED_NAME_THAT_DOESNT_EXIST
request = Request(self.wrapped_request, authenticators=(AuthRaisesAttributeError(),))
-
- # The middleware processes the underlying Django request, sets anonymous user
- assert self.wrapped_request.user.is_anonymous
-
- # The DRF request object does not have a user and should run authenticators
- expected = r"no attribute 'MISSPELLED_NAME_THAT_DOESNT_EXIST'"
- with pytest.raises(WrappedAttributeError, match=expected):
- request.user
+assertself.wrapped_request.user.is_anonymous
+expected=r"no attribute 'MISSPELLED_NAME_THAT_DOESNT_EXIST'"
+withpytest.raises(WrappedAttributeError,match=expected):
+ request.user
with pytest.raises(WrappedAttributeError, match=expected):
- hasattr(request, 'user')
+ hasattr(request, 'user')
with pytest.raises(WrappedAttributeError, match=expected):
- login(request, self.user)
+ login(request, self.user)
-class TestAuthSetter(TestCase):
- def test_auth_can_be_set(self):
- request = Request(factory.get('/'))
- request.auth = 'DUMMY'
- assert request.auth == 'DUMMY'
+def test_auth_can_be_set(self):
+ request = Request(factory.get('/'))
+ request.auth = 'DUMMY'
+ assert request.auth == 'DUMMY'
-class TestSecure(TestCase):
- def test_default_secure_false(self):
- request = Request(factory.get('/', secure=False))
- assert request.scheme == 'http'
+def test_default_secure_false(self):
+ request = Request(factory.get('/', secure=False))
+ assert request.scheme == 'http'
def test_default_secure_true(self):
- request = Request(factory.get('/', secure=True))
- assert request.scheme == 'https'
+ request = Request(factory.get('/', secure=True))
+ assert request.scheme == 'https'
-class TestHttpRequest(TestCase):
- def test_attribute_access_proxy(self):
- http_request = factory.get('/')
- request = Request(http_request)
-
- inner_sentinel = object()
- http_request.inner_property = inner_sentinel
- assert request.inner_property is inner_sentinel
-
- outer_sentinel = object()
- request.inner_property = outer_sentinel
- assert request.inner_property is outer_sentinel
+def test_attribute_access_proxy(self):
+ http_request = factory.get('/')
+ request = Request(http_request)
+ inner_sentinel = object()
+ http_request.inner_property = inner_sentinel
+ assert request.inner_property is inner_sentinel
+ outer_sentinel = object()
+ request.inner_property = outer_sentinel
+ assert request.inner_property is outer_sentinel
def test_exception_proxy(self):
# ensure the exception message is not for the underlying WSGIRequest
- http_request = factory.get('/')
- request = Request(http_request)
-
- message = "'Request' object has no attribute 'inner_property'"
- with self.assertRaisesMessage(AttributeError, message):
- request.inner_property
+ http_request = factory.get('/')
+ request = Request(http_request)
+ message = "'Request' object has no attribute 'inner_property'"
+ with self.assertRaisesMessage(AttributeError, message):
+ request.inner_property
@override_settings(ROOT_URLCONF='tests.test_request')
- def test_duplicate_request_stream_parsing_exception(self):
- """
+deftest_duplicate_request_stream_parsing_exception(self):
+ """
Check assumption that duplicate stream parsing will result in a
`RawPostDataException` being raised.
"""
- response = APIClient().post('/echo/', data={'a': 'b'}, format='json')
- request = response.renderer_context['request']
-
- # ensure that request stream was consumed by json parser
- assert request.content_type.startswith('application/json')
- assert response.data == {'a': 'b'}
-
- # pass same HttpRequest to view, stream already consumed
- with pytest.raises(RawPostDataException):
- EchoView.as_view()(request._request)
+ response = APIClient().post('/echo/', data={'a': 'b'}, format='json')
+ request = response.renderer_context['request']
+ assert request.content_type.startswith('application/json')
+ assert response.data == {'a': 'b'}
+ with pytest.raises(RawPostDataException):
+ EchoView.as_view()(request._request)
@override_settings(ROOT_URLCONF='tests.test_request')
- def test_duplicate_request_form_data_access(self):
- """
+deftest_duplicate_request_form_data_access(self):
+ """
Form data is copied to the underlying django request for middleware
and file closing reasons. Duplicate processing of a request with form
data is 'safe' in so far as accessing `request.POST` does not trigger
the duplicate stream parse exception.
"""
- response = APIClient().post('/echo/', data={'a': 'b'})
- request = response.renderer_context['request']
-
- # ensure that request stream was consumed by form parser
- assert request.content_type.startswith('multipart/form-data')
- assert response.data == {'a': ['b']}
-
- # pass same HttpRequest to view, form data set on underlying request
- response = EchoView.as_view()(request._request)
- request = response.renderer_context['request']
-
- # ensure that request stream was consumed by form parser
- assert request.content_type.startswith('multipart/form-data')
- assert response.data == {'a': ['b']}
+ response = APIClient().post('/echo/', data={'a': 'b'})
+ request = response.renderer_context['request']
+ assert request.content_type.startswith('multipart/form-data')
+ assert response.data == {'a': ['b']}
+ response = EchoView.as_view()(request._request)
+ request = response.renderer_context['request']
+ assert request.content_type.startswith('multipart/form-data')
+ assert response.data == {'a': ['b']}
diff --git a/tests/test_response.py b/tests/test_response.py
index d3a56d01b..9f2027541 100644
--- a/tests/test_response.py
+++ b/tests/test_response.py
@@ -131,155 +131,146 @@ urlpatterns = [
# TODO: Clean tests bellow - remove duplicates with above, better unit testing, ...
@override_settings(ROOT_URLCONF='tests.test_response')
-class RendererIntegrationTests(TestCase):
- """
+"""
End-to-end testing of renderers using an ResponseMixin on a generic view.
"""
- def test_default_renderer_serializes_content(self):
- """If the Accept header is not set the default renderer should serialize the response."""
- resp = self.client.get('/')
- self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
+deftest_default_renderer_serializes_content(self):
+ """If the Accept header is not set the default renderer should serialize the response."""
+ resp = self.client.get('/')
+ assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
+ assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT)
+ assert resp.status_code == DUMMYSTATUS
def test_head_method_serializes_no_content(self):
- """No response must be included in HEAD requests."""
- resp = self.client.head('/')
- self.assertEqual(resp.status_code, DUMMYSTATUS)
- self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, b'')
+ """No response must be included in HEAD requests."""
+ resp = self.client.head('/')
+ assert resp.status_code == DUMMYSTATUS
+ assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
+ assert resp.content == b''
def test_default_renderer_serializes_content_on_accept_any(self):
- """If the Accept header is set to */* the default renderer should serialize the response."""
- resp = self.client.get('/', HTTP_ACCEPT='*/*')
- self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
+ """If the Accept header is set to */* the default renderer should serialize the response."""
+ resp = self.client.get('/', HTTP_ACCEPT='*/*')
+ assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
+ assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT)
+ assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_serializes_content_default_case(self):
- """If the Accept header is set the specified renderer should serialize the response.
+ """If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for the default renderer)"""
- resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
- self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
+ resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
+ assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
+ assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT)
+ assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_serializes_content_non_default_case(self):
- """If the Accept header is set the specified renderer should serialize the response.
+ """If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for a non-default renderer)"""
- resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
+ resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
+ assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
+ assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
+ assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_serializes_content_on_format_query(self):
- """If a 'format' query is specified, the renderer with the matching
+ """If a 'format' query is specified, the renderer with the matching
format attribute should serialize the response."""
- resp = self.client.get('/?format=%s' % RendererB.format)
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
+ resp = self.client.get('/?format=%s' % RendererB.format)
+ assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
+ assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
+ assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_serializes_content_on_format_kwargs(self):
- """If a 'format' keyword arg is specified, the renderer with the matching
+ """If a 'format' keyword arg is specified, the renderer with the matching
format attribute should serialize the response."""
- resp = self.client.get('/something.formatb')
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
+ resp = self.client.get('/something.formatb')
+ assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
+ assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
+ assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_is_used_on_format_query_with_matching_accept(self):
- """If both a 'format' query and a matching Accept header specified,
+ """If both a 'format' query and a matching Accept header specified,
the renderer with the matching format attribute should serialize the response."""
- resp = self.client.get('/?format=%s' % RendererB.format,
- HTTP_ACCEPT=RendererB.media_type)
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
+ resp = self.client.get('/?format=%s' % RendererB.format, HTTP_ACCEPT=RendererB.media_type)
+ assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
+ assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
+ assert resp.status_code == DUMMYSTATUS
@override_settings(ROOT_URLCONF='tests.test_response')
-class UnsupportedMediaTypeTests(TestCase):
- def test_should_allow_posting_json(self):
- response = self.client.post('/json', data='{"test": 123}', content_type='application/json')
-
- self.assertEqual(response.status_code, 200)
+def test_should_allow_posting_json(self):
+ response = self.client.post('/json', data='{"test": 123}', content_type='application/json')
+ assert response.status_code == 200
def test_should_not_allow_posting_xml(self):
- response = self.client.post('/json', data='123', content_type='application/xml')
-
- self.assertEqual(response.status_code, 415)
+ response = self.client.post('/json', data='123', content_type='application/xml')
+ assert response.status_code == 415
def test_should_not_allow_posting_a_form(self):
- response = self.client.post('/json', data={'test': 123})
-
- self.assertEqual(response.status_code, 415)
+ response = self.client.post('/json', data={'test': 123})
+ assert response.status_code == 415
@override_settings(ROOT_URLCONF='tests.test_response')
-class Issue122Tests(TestCase):
- """
+"""
Tests that covers #122.
"""
- def test_only_html_renderer(self):
- """
+deftest_only_html_renderer(self):
+ """
Test if no infinite recursion occurs.
"""
- self.client.get('/html')
+ self.client.get('/html')
def test_html_renderer_is_first(self):
- """
+ """
Test if no infinite recursion occurs.
"""
- self.client.get('/html1')
+ self.client.get('/html1')
@override_settings(ROOT_URLCONF='tests.test_response')
-class Issue467Tests(TestCase):
- """
+"""
Tests for #467
"""
- def test_form_has_label_and_help_text(self):
- resp = self.client.get('/html_new_model')
- self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
+deftest_form_has_label_and_help_text(self):
+ resp = self.client.get('/html_new_model')
+ assert resp['Content-Type'] == 'text/html; charset=utf-8'
# self.assertContains(resp, 'Text comes here')
# self.assertContains(resp, 'Text description.')
@override_settings(ROOT_URLCONF='tests.test_response')
-class Issue807Tests(TestCase):
- """
+"""
Covers #807
"""
- def test_does_not_append_charset_by_default(self):
- """
+deftest_does_not_append_charset_by_default(self):
+ """
Renderers don't include a charset unless set explicitly.
"""
- headers = {"HTTP_ACCEPT": RendererA.media_type}
- resp = self.client.get('/', **headers)
- expected = "{}; charset={}".format(RendererA.media_type, 'utf-8')
- self.assertEqual(expected, resp['Content-Type'])
+ headers = {"HTTP_ACCEPT": RendererA.media_type}
+ resp = self.client.get('/', **headers)
+ expected = "{}; charset={}".format(RendererA.media_type, 'utf-8')
+ assert expected == resp['Content-Type']
def test_if_there_is_charset_specified_on_renderer_it_gets_appended(self):
- """
+ """
If renderer class has charset attribute declared, it gets appended
to Response's Content-Type
"""
- headers = {"HTTP_ACCEPT": RendererC.media_type}
- resp = self.client.get('/', **headers)
- expected = "{}; charset={}".format(RendererC.media_type, RendererC.charset)
- self.assertEqual(expected, resp['Content-Type'])
+ headers = {"HTTP_ACCEPT": RendererC.media_type}
+ resp = self.client.get('/', **headers)
+ expected = "{}; charset={}".format(RendererC.media_type, RendererC.charset)
+ assert expected == resp['Content-Type']
def test_content_type_set_explicitly_on_response(self):
- """
+ """
The content type may be set explicitly on the response.
"""
- headers = {"HTTP_ACCEPT": RendererC.media_type}
- resp = self.client.get('/setbyview', **headers)
- self.assertEqual('setbyview', resp['Content-Type'])
+ headers = {"HTTP_ACCEPT": RendererC.media_type}
+ resp = self.client.get('/setbyview', **headers)
+ assert 'setbyview' == resp['Content-Type']
def test_form_has_label_and_help_text(self):
- resp = self.client.get('/html_new_model')
- self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
+ resp = self.client.get('/html_new_model')
+ assert resp['Content-Type'] == 'text/html; charset=utf-8'
# self.assertContains(resp, 'Text comes here')
# self.assertContains(resp, 'Text description.')
diff --git a/tests/test_reverse.py b/tests/test_reverse.py
index 9ab1667c5..be360ed97 100644
--- a/tests/test_reverse.py
+++ b/tests/test_reverse.py
@@ -30,25 +30,22 @@ class MockVersioningScheme:
@override_settings(ROOT_URLCONF='tests.test_reverse')
-class ReverseTests(TestCase):
- """
+"""
Tests for fully qualified URLs when using `reverse`.
"""
- def test_reversed_urls_are_fully_qualified(self):
- request = factory.get('/view')
- url = reverse('view', request=request)
- assert url == 'http://testserver/view'
+deftest_reversed_urls_are_fully_qualified(self):
+ request = factory.get('/view')
+ url = reverse('view', request=request)
+ assert url == 'http://testserver/view'
def test_reverse_with_versioning_scheme(self):
- request = factory.get('/view')
- request.versioning_scheme = MockVersioningScheme()
-
- url = reverse('view', request=request)
- assert url == 'http://scheme-reversed/view'
+ request = factory.get('/view')
+ request.versioning_scheme = MockVersioningScheme()
+ url = reverse('view', request=request)
+ assert url == 'http://scheme-reversed/view'
def test_reverse_with_versioning_scheme_fallback_to_default_on_error(self):
- request = factory.get('/view')
- request.versioning_scheme = MockVersioningScheme(raise_error=True)
-
- url = reverse('view', request=request)
- assert url == 'http://testserver/view'
+ request = factory.get('/view')
+ request.versioning_scheme = MockVersioningScheme(raise_error=True)
+ url = reverse('view', request=request)
+ assert url == 'http://testserver/view'
diff --git a/tests/test_routers.py b/tests/test_routers.py
index 0f428e2a5..3341f8fd6 100644
--- a/tests/test_routers.py
+++ b/tests/test_routers.py
@@ -214,25 +214,24 @@ class TestCustomLookupFields(URLPatternsTestCase, TestCase):
assert response.data == {"url": "http://testserver/example/notes/a%20b/", "uuid": "a b", "text": "baz qux"}
-class TestLookupValueRegex(TestCase):
- """
+"""
Ensure the router honors lookup_value_regex when applied
to the viewset.
"""
- def setUp(self):
- class NoteViewSet(viewsets.ModelViewSet):
- queryset = RouterTestModel.objects.all()
- lookup_field = 'uuid'
- lookup_value_regex = '[0-9a-f]{32}'
+defsetUp(self):
+ class NoteViewSet(viewsets.ModelViewSet):
+ queryset = RouterTestModel.objects.all()
+ lookup_field = 'uuid'
+ lookup_value_regex = '[0-9a-f]{32}'
self.router = SimpleRouter()
- self.router.register(r'notes', NoteViewSet)
- self.urls = self.router.urls
+self.router.register(r'notes',NoteViewSet)
+self.urls=self.router.urls
def test_urls_limited_by_lookup_value_regex(self):
- expected = ['^notes/$', '^notes/(?P[0-9a-f]{32})/$']
- for idx in range(len(expected)):
- assert expected[idx] == get_regex_pattern(self.urls[idx])
+ expected = ['^notes/$', '^notes/(?P[0-9a-f]{32})/$']
+ for idx in range(len(expected)):
+ assert expected[idx] == get_regex_pattern(self.urls[idx])
@override_settings(ROOT_URLCONF='tests.test_routers')
@@ -264,97 +263,84 @@ class TestLookupUrlKwargs(URLPatternsTestCase, TestCase):
assert response.data == {"url": "http://testserver/example/notes/123/", "uuid": "123", "text": "foo bar"}
-class TestTrailingSlashIncluded(TestCase):
- def setUp(self):
- class NoteViewSet(viewsets.ModelViewSet):
- queryset = RouterTestModel.objects.all()
+def setUp(self):
+ class NoteViewSet(viewsets.ModelViewSet):
+ queryset = RouterTestModel.objects.all()
self.router = SimpleRouter()
- self.router.register(r'notes', NoteViewSet)
- self.urls = self.router.urls
+self.router.register(r'notes',NoteViewSet)
+self.urls=self.router.urls
def test_urls_have_trailing_slash_by_default(self):
- expected = ['^notes/$', '^notes/(?P[^/.]+)/$']
- for idx in range(len(expected)):
- assert expected[idx] == get_regex_pattern(self.urls[idx])
+ expected = ['^notes/$', '^notes/(?P[^/.]+)/$']
+ for idx in range(len(expected)):
+ assert expected[idx] == get_regex_pattern(self.urls[idx])
-class TestTrailingSlashRemoved(TestCase):
- def setUp(self):
- class NoteViewSet(viewsets.ModelViewSet):
- queryset = RouterTestModel.objects.all()
+def setUp(self):
+ class NoteViewSet(viewsets.ModelViewSet):
+ queryset = RouterTestModel.objects.all()
self.router = SimpleRouter(trailing_slash=False)
- self.router.register(r'notes', NoteViewSet)
- self.urls = self.router.urls
+self.router.register(r'notes',NoteViewSet)
+self.urls=self.router.urls
def test_urls_can_have_trailing_slash_removed(self):
- expected = ['^notes$', '^notes/(?P[^/.]+)$']
- for idx in range(len(expected)):
- assert expected[idx] == get_regex_pattern(self.urls[idx])
+ expected = ['^notes$', '^notes/(?P[^/.]+)$']
+ for idx in range(len(expected)):
+ assert expected[idx] == get_regex_pattern(self.urls[idx])
-class TestNameableRoot(TestCase):
- def setUp(self):
- class NoteViewSet(viewsets.ModelViewSet):
- queryset = RouterTestModel.objects.all()
+def setUp(self):
+ class NoteViewSet(viewsets.ModelViewSet):
+ queryset = RouterTestModel.objects.all()
self.router = DefaultRouter()
- self.router.root_view_name = 'nameable-root'
- self.router.register(r'notes', NoteViewSet)
- self.urls = self.router.urls
+self.router.root_view_name='nameable-root'
+self.router.register(r'notes',NoteViewSet)
+self.urls=self.router.urls
def test_router_has_custom_name(self):
- expected = 'nameable-root'
- assert expected == self.urls[-1].name
+ expected = 'nameable-root'
+ assert expected == self.urls[-1].name
-class TestActionKeywordArgs(TestCase):
- """
+"""
Ensure keyword arguments passed in the `@action` decorator
are properly handled. Refs #940.
"""
-
- def setUp(self):
- class TestViewSet(viewsets.ModelViewSet):
- permission_classes = []
-
- @action(methods=['post'], detail=True, permission_classes=[permissions.AllowAny])
- def custom(self, request, *args, **kwargs):
- return Response({
- 'permission_classes': self.permission_classes
- })
+defsetUp(self):
+ class TestViewSet(viewsets.ModelViewSet):
+ permission_classes = []
+ @action(methods=['post'], detail=True, permission_classes=[permissions.AllowAny])
+ def custom(self, request, *args, **kwargs):
+ return Response({ 'permission_classes': self.permission_classes })
self.router = SimpleRouter()
- self.router.register(r'test', TestViewSet, basename='test')
- self.view = self.router.urls[-1].callback
+self.router.register(r'test',TestViewSet,basename='test')
+self.view=self.router.urls[-1].callback
def test_action_kwargs(self):
- request = factory.post('/test/0/custom/')
- response = self.view(request)
- assert response.data == {'permission_classes': [permissions.AllowAny]}
+ request = factory.post('/test/0/custom/')
+ response = self.view(request)
+ assert response.data == {'permission_classes': [permissions.AllowAny]}
-class TestActionAppliedToExistingRoute(TestCase):
- """
+"""
Ensure `@action` decorator raises an except when applied
to an existing route
"""
+deftest_exception_raised_when_action_applied_to_existing_route(self):
+ class TestViewSet(viewsets.ModelViewSet):
- def test_exception_raised_when_action_applied_to_existing_route(self):
- class TestViewSet(viewsets.ModelViewSet):
-
- @action(methods=['post'], detail=True)
- def retrieve(self, request, *args, **kwargs):
- return Response({
- 'hello': 'world'
- })
+ @action(methods=['post'], detail=True)
+ def retrieve(self, request, *args, **kwargs):
+ return Response({ 'hello': 'world' })
self.router = SimpleRouter()
- self.router.register(r'test', TestViewSet, basename='test')
-
- with pytest.raises(ImproperlyConfigured):
- self.router.urls
+self.router.register(r'test',TestViewSet,basename='test')
+withpytest.raises(ImproperlyConfigured):
+ self.router.urls
class DynamicListAndDetailViewSet(viewsets.ViewSet):
@@ -390,44 +376,33 @@ class SubDynamicListAndDetailViewSet(DynamicListAndDetailViewSet):
pass
-class TestDynamicListAndDetailRouter(TestCase):
- def setUp(self):
- self.router = SimpleRouter()
+def setUp(self):
+ self.router = SimpleRouter()
def _test_list_and_detail_route_decorators(self, viewset):
- routes = self.router.get_routes(viewset)
- decorator_routes = [r for r in routes if not (r.name.endswith('-list') or r.name.endswith('-detail'))]
-
- MethodNamesMap = namedtuple('MethodNamesMap', 'method_name url_path')
- # Make sure all these endpoints exist and none have been clobbered
- for i, endpoint in enumerate([MethodNamesMap('list_custom_route_get', 'list_custom-route'),
- MethodNamesMap('list_route_get', 'list_route_get'),
- MethodNamesMap('list_route_post', 'list_route_post'),
- MethodNamesMap('detail_custom_route_get', 'detail_custom-route'),
- MethodNamesMap('detail_route_get', 'detail_route_get'),
- MethodNamesMap('detail_route_post', 'detail_route_post')
- ]):
- route = decorator_routes[i]
- # check url listing
- method_name = endpoint.method_name
- url_path = endpoint.url_path
-
- if method_name.startswith('list_'):
- assert route.url == '^{{prefix}}/{0}{{trailing_slash}}$'.format(url_path)
+ routes = self.router.get_routes(viewset)
+ decorator_routes = [r for r in routes if not (r.name.endswith('-list') or r.name.endswith('-detail'))]
+ MethodNamesMap = namedtuple('MethodNamesMap', 'method_name url_path')
+ for i, endpoint in enumerate([MethodNamesMap('list_custom_route_get', 'list_custom-route'), MethodNamesMap('list_route_get', 'list_route_get'), MethodNamesMap('list_route_post', 'list_route_post'), MethodNamesMap('detail_custom_route_get', 'detail_custom-route'), MethodNamesMap('detail_route_get', 'detail_route_get'), MethodNamesMap('detail_route_post', 'detail_route_post') ]):
+ route = decorator_routes[i]
+ method_name = endpoint.method_name
+ url_path = endpoint.url_path
+ if method_name.startswith('list_'):
+ assert route.url == '^{{prefix}}/{0}{{trailing_slash}}$'.format(url_path)
else:
- assert route.url == '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(url_path)
+ assert route.url == '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(url_path)
# check method to function mapping
if method_name.endswith('_post'):
- method_map = 'post'
+ method_map = 'post'
else:
- method_map = 'get'
+ method_map = 'get'
assert route.mapping[method_map] == method_name
def test_list_and_detail_route_decorators(self):
- self._test_list_and_detail_route_decorators(DynamicListAndDetailViewSet)
+ self._test_list_and_detail_route_decorators(DynamicListAndDetailViewSet)
def test_inherited_list_and_detail_route_decorators(self):
- self._test_list_and_detail_route_decorators(SubDynamicListAndDetailViewSet)
+ self._test_list_and_detail_route_decorators(SubDynamicListAndDetailViewSet)
class TestEmptyPrefix(URLPatternsTestCase, TestCase):
@@ -490,69 +465,52 @@ class TestViewInitkwargs(URLPatternsTestCase, TestCase):
assert initkwargs['basename'] == 'routertestmodel'
-class TestBaseNameRename(TestCase):
- def test_base_name_and_basename_assertion(self):
- router = SimpleRouter()
-
- msg = "Do not provide both the `basename` and `base_name` arguments."
- with warnings.catch_warnings(record=True) as w, \
- self.assertRaisesMessage(AssertionError, msg):
- warnings.simplefilter('always')
- router.register('mock', MockViewSet, 'mock', base_name='mock')
+def test_base_name_and_basename_assertion(self):
+ router = SimpleRouter()
+ msg = "Do not provide both the `basename` and `base_name` arguments."
+ with warnings.catch_warnings(record=True) as w, self.assertRaisesMessage(AssertionError, msg):
+ warnings.simplefilter('always')
+ router.register('mock', MockViewSet, 'mock', base_name='mock')
msg = "The `base_name` argument is pending deprecation in favor of `basename`."
- assert len(w) == 1
- assert str(w[0].message) == msg
+assertlen(w)==1
+assertstr(w[0].message)==msg
def test_base_name_argument_deprecation(self):
- router = SimpleRouter()
-
- with pytest.warns(RemovedInDRF311Warning) as w:
- warnings.simplefilter('always')
- router.register('mock', MockViewSet, base_name='mock')
+ router = SimpleRouter()
+ with pytest.warns(RemovedInDRF311Warning) as w:
+ warnings.simplefilter('always')
+ router.register('mock', MockViewSet, base_name='mock')
msg = "The `base_name` argument is pending deprecation in favor of `basename`."
- assert len(w) == 1
- assert str(w[0].message) == msg
- assert router.registry == [
- ('mock', MockViewSet, 'mock'),
- ]
+assertlen(w)==1
+assertstr(w[0].message)==msg
+assertrouter.registry==[('mock',MockViewSet,'mock'),]
def test_basename_argument_no_warnings(self):
- router = SimpleRouter()
-
- with warnings.catch_warnings(record=True) as w:
- warnings.simplefilter('always')
- router.register('mock', MockViewSet, basename='mock')
+ router = SimpleRouter()
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter('always')
+ router.register('mock', MockViewSet, basename='mock')
assert len(w) == 0
- assert router.registry == [
- ('mock', MockViewSet, 'mock'),
- ]
+assertrouter.registry==[('mock',MockViewSet,'mock'),]
def test_get_default_base_name_deprecation(self):
- msg = "`CustomRouter.get_default_base_name` method should be renamed `get_default_basename`."
-
- # Class definition should raise a warning
- with pytest.warns(RemovedInDRF311Warning) as w:
- warnings.simplefilter('always')
-
- class CustomRouter(SimpleRouter):
- def get_default_base_name(self, viewset):
- return 'foo'
+ msg = "`CustomRouter.get_default_base_name` method should be renamed `get_default_basename`."
+ with pytest.warns(RemovedInDRF311Warning) as w:
+ warnings.simplefilter('always')
+ class CustomRouter(SimpleRouter):
+ def get_default_base_name(self, viewset):
+ return 'foo'
assert len(w) == 1
- assert str(w[0].message) == msg
-
- # Deprecated method implementation should still be called
- with warnings.catch_warnings(record=True) as w:
- warnings.simplefilter('always')
-
- router = CustomRouter()
- router.register('mock', MockViewSet)
+assertstr(w[0].message)==msg
+withwarnings.catch_warnings(record=True)asw:
+ warnings.simplefilter('always')
+ router = CustomRouter()
+ router.register('mock', MockViewSet)
assert len(w) == 0
- assert router.registry == [
- ('mock', MockViewSet, 'foo'),
- ]
+assertrouter.registry==[('mock',MockViewSet,'foo'),]
diff --git a/tests/test_schemas.py b/tests/test_schemas.py
index 230f8f012..8af8d85bc 100644
--- a/tests/test_schemas.py
+++ b/tests/test_schemas.py
@@ -149,206 +149,20 @@ urlpatterns = [
@unittest.skipUnless(coreapi, 'coreapi is not installed')
@override_settings(ROOT_URLCONF='tests.test_schemas')
-class TestRouterGeneratedSchema(TestCase):
- def test_anonymous_request(self):
- client = APIClient()
- response = client.get('/', HTTP_ACCEPT='application/coreapi+json')
- assert response.status_code == 200
- expected = coreapi.Document(
- url='http://testserver/',
- title='Example API',
- content={
- 'example': {
- 'list': coreapi.Link(
- url='/example/',
- action='get',
- fields=[
- coreapi.Field('page', required=False, location='query', schema=coreschema.Integer(title='Page', description='A page number within the paginated result set.')),
- coreapi.Field('page_size', required=False, location='query', schema=coreschema.Integer(title='Page size', description='Number of results to return per page.')),
- coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.'))
- ]
- ),
- 'custom_list_action': coreapi.Link(
- url='/example/custom_list_action/',
- action='get'
- ),
- 'custom_list_action_multiple_methods': {
- 'read': coreapi.Link(
- url='/example/custom_list_action_multiple_methods/',
- action='get',
- description='Custom description.',
- )
- },
- 'documented_custom_action': {
- 'read': coreapi.Link(
- url='/example/documented_custom_action/',
- action='get',
- description='A description of the get method on the custom action.',
- )
- },
- 'read': coreapi.Link(
- url='/example/{id}/',
- action='get',
- fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String()),
- coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.'))
- ]
- )
- }
- }
- )
- assert response.data == expected
+def test_anonymous_request(self):
+ client = APIClient()
+ response = client.get('/', HTTP_ACCEPT='application/coreapi+json')
+ assert response.status_code == 200
+ expected = coreapi.Document( url='http://testserver/', title='Example API', content={ 'example': { 'list': coreapi.Link( url='/example/', action='get', fields=[ coreapi.Field('page', required=False, location='query', schema=coreschema.Integer(title='Page', description='A page number within the paginated result set.')), coreapi.Field('page_size', required=False, location='query', schema=coreschema.Integer(title='Page size', description='Number of results to return per page.')), coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) ] ), 'custom_list_action': coreapi.Link( url='/example/custom_list_action/', action='get' ), 'custom_list_action_multiple_methods': { 'read': coreapi.Link( url='/example/custom_list_action_multiple_methods/', action='get', description='Custom description.', ) }, 'documented_custom_action': { 'read': coreapi.Link( url='/example/documented_custom_action/', action='get', description='A description of the get method on the custom action.', ) }, 'read': coreapi.Link( url='/example/{id}/', action='get', fields=[ coreapi.Field('id', required=True, location='path', schema=coreschema.String()), coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) ] ) } } )
+ assert response.data == expected
def test_authenticated_request(self):
- client = APIClient()
- client.force_authenticate(MockUser())
- response = client.get('/', HTTP_ACCEPT='application/coreapi+json')
- assert response.status_code == 200
- expected = coreapi.Document(
- url='http://testserver/',
- title='Example API',
- content={
- 'example': {
- 'list': coreapi.Link(
- url='/example/',
- action='get',
- fields=[
- coreapi.Field('page', required=False, location='query', schema=coreschema.Integer(title='Page', description='A page number within the paginated result set.')),
- coreapi.Field('page_size', required=False, location='query', schema=coreschema.Integer(title='Page size', description='Number of results to return per page.')),
- coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.'))
- ]
- ),
- 'create': coreapi.Link(
- url='/example/',
- action='post',
- encoding='application/json',
- fields=[
- coreapi.Field('a', required=True, location='form', schema=coreschema.String(title='A', description='A field description')),
- coreapi.Field('b', required=False, location='form', schema=coreschema.String(title='B'))
- ]
- ),
- 'read': coreapi.Link(
- url='/example/{id}/',
- action='get',
- fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String()),
- coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.'))
- ]
- ),
- 'custom_action': coreapi.Link(
- url='/example/{id}/custom_action/',
- action='post',
- encoding='application/json',
- description='A description of custom action.',
- fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String()),
- coreapi.Field('c', required=True, location='form', schema=coreschema.String(title='C')),
- coreapi.Field('d', required=False, location='form', schema=coreschema.String(title='D')),
- ]
- ),
- 'custom_action_with_dict_field': coreapi.Link(
- url='/example/{id}/custom_action_with_dict_field/',
- action='post',
- encoding='application/json',
- description='A custom action using a dict field in the serializer.',
- fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String()),
- coreapi.Field('a', required=True, location='form', schema=coreschema.Object(title='A')),
- ]
- ),
- 'custom_action_with_list_fields': coreapi.Link(
- url='/example/{id}/custom_action_with_list_fields/',
- action='post',
- encoding='application/json',
- description='A custom action using both list field and list serializer in the serializer.',
- fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String()),
- coreapi.Field('a', required=True, location='form', schema=coreschema.Array(title='A', items=coreschema.Integer())),
- coreapi.Field('b', required=True, location='form', schema=coreschema.Array(title='B', items=coreschema.String())),
- ]
- ),
- 'custom_list_action': coreapi.Link(
- url='/example/custom_list_action/',
- action='get'
- ),
- 'custom_list_action_multiple_methods': {
- 'read': coreapi.Link(
- url='/example/custom_list_action_multiple_methods/',
- action='get',
- description='Custom description.',
- ),
- 'create': coreapi.Link(
- url='/example/custom_list_action_multiple_methods/',
- action='post',
- description='Custom description.',
- ),
- 'delete': coreapi.Link(
- url='/example/custom_list_action_multiple_methods/',
- action='delete',
- description='Deletion description.',
- ),
- },
- 'documented_custom_action': {
- 'read': coreapi.Link(
- url='/example/documented_custom_action/',
- action='get',
- description='A description of the get method on the custom action.',
- ),
- 'create': coreapi.Link(
- url='/example/documented_custom_action/',
- action='post',
- description='A description of the post method on the custom action.',
- encoding='application/json',
- fields=[
- coreapi.Field('a', required=True, location='form', schema=coreschema.String(title='A', description='A field description')),
- coreapi.Field('b', required=False, location='form', schema=coreschema.String(title='B'))
- ]
- ),
- 'update': coreapi.Link(
- url='/example/documented_custom_action/',
- action='put',
- description='A description of the put method on the custom action from mapping.',
- encoding='application/json',
- fields=[
- coreapi.Field('a', required=True, location='form', schema=coreschema.String(title='A', description='A field description')),
- coreapi.Field('b', required=False, location='form', schema=coreschema.String(title='B'))
- ]
- ),
- },
- 'update': coreapi.Link(
- url='/example/{id}/',
- action='put',
- encoding='application/json',
- fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String()),
- coreapi.Field('a', required=True, location='form', schema=coreschema.String(title='A', description=('A field description'))),
- coreapi.Field('b', required=False, location='form', schema=coreschema.String(title='B')),
- coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.'))
- ]
- ),
- 'partial_update': coreapi.Link(
- url='/example/{id}/',
- action='patch',
- encoding='application/json',
- fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String()),
- coreapi.Field('a', required=False, location='form', schema=coreschema.String(title='A', description='A field description')),
- coreapi.Field('b', required=False, location='form', schema=coreschema.String(title='B')),
- coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.'))
- ]
- ),
- 'delete': coreapi.Link(
- url='/example/{id}/',
- action='delete',
- fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String()),
- coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.'))
- ]
- )
- }
- }
- )
- assert response.data == expected
+ client = APIClient()
+ client.force_authenticate(MockUser())
+ response = client.get('/', HTTP_ACCEPT='application/coreapi+json')
+ assert response.status_code == 200
+ expected = coreapi.Document( url='http://testserver/', title='Example API', content={ 'example': { 'list': coreapi.Link( url='/example/', action='get', fields=[ coreapi.Field('page', required=False, location='query', schema=coreschema.Integer(title='Page', description='A page number within the paginated result set.')), coreapi.Field('page_size', required=False, location='query', schema=coreschema.Integer(title='Page size', description='Number of results to return per page.')), coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) ] ), 'create': coreapi.Link( url='/example/', action='post', encoding='application/json', fields=[ coreapi.Field('a', required=True, location='form', schema=coreschema.String(title='A', description='A field description')), coreapi.Field('b', required=False, location='form', schema=coreschema.String(title='B')) ] ), 'read': coreapi.Link( url='/example/{id}/', action='get', fields=[ coreapi.Field('id', required=True, location='path', schema=coreschema.String()), coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) ] ), 'custom_action': coreapi.Link( url='/example/{id}/custom_action/', action='post', encoding='application/json', description='A description of custom action.', fields=[ coreapi.Field('id', required=True, location='path', schema=coreschema.String()), coreapi.Field('c', required=True, location='form', schema=coreschema.String(title='C')), coreapi.Field('d', required=False, location='form', schema=coreschema.String(title='D')), ] ), 'custom_action_with_dict_field': coreapi.Link( url='/example/{id}/custom_action_with_dict_field/', action='post', encoding='application/json', description='A custom action using a dict field in the serializer.', fields=[ coreapi.Field('id', required=True, location='path', schema=coreschema.String()), coreapi.Field('a', required=True, location='form', schema=coreschema.Object(title='A')), ] ), 'custom_action_with_list_fields': coreapi.Link( url='/example/{id}/custom_action_with_list_fields/', action='post', encoding='application/json', description='A custom action using both list field and list serializer in the serializer.', fields=[ coreapi.Field('id', required=True, location='path', schema=coreschema.String()), coreapi.Field('a', required=True, location='form', schema=coreschema.Array(title='A', items=coreschema.Integer())), coreapi.Field('b', required=True, location='form', schema=coreschema.Array(title='B', items=coreschema.String())), ] ), 'custom_list_action': coreapi.Link( url='/example/custom_list_action/', action='get' ), 'custom_list_action_multiple_methods': { 'read': coreapi.Link( url='/example/custom_list_action_multiple_methods/', action='get', description='Custom description.', ), 'create': coreapi.Link( url='/example/custom_list_action_multiple_methods/', action='post', description='Custom description.', ), 'delete': coreapi.Link( url='/example/custom_list_action_multiple_methods/', action='delete', description='Deletion description.', ), }, 'documented_custom_action': { 'read': coreapi.Link( url='/example/documented_custom_action/', action='get', description='A description of the get method on the custom action.', ), 'create': coreapi.Link( url='/example/documented_custom_action/', action='post', description='A description of the post method on the custom action.', encoding='application/json', fields=[ coreapi.Field('a', required=True, location='form', schema=coreschema.String(title='A', description='A field description')), coreapi.Field('b', required=False, location='form', schema=coreschema.String(title='B')) ] ), 'update': coreapi.Link( url='/example/documented_custom_action/', action='put', description='A description of the put method on the custom action from mapping.', encoding='application/json', fields=[ coreapi.Field('a', required=True, location='form', schema=coreschema.String(title='A', description='A field description')), coreapi.Field('b', required=False, location='form', schema=coreschema.String(title='B')) ] ), }, 'update': coreapi.Link( url='/example/{id}/', action='put', encoding='application/json', fields=[ coreapi.Field('id', required=True, location='path', schema=coreschema.String()), coreapi.Field('a', required=True, location='form', schema=coreschema.String(title='A', description=('A field description'))), coreapi.Field('b', required=False, location='form', schema=coreschema.String(title='B')), coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) ] ), 'partial_update': coreapi.Link( url='/example/{id}/', action='patch', encoding='application/json', fields=[ coreapi.Field('id', required=True, location='path', schema=coreschema.String()), coreapi.Field('a', required=False, location='form', schema=coreschema.String(title='A', description='A field description')), coreapi.Field('b', required=False, location='form', schema=coreschema.String(title='B')), coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) ] ), 'delete': coreapi.Link( url='/example/{id}/', action='delete', fields=[ coreapi.Field('id', required=True, location='path', schema=coreschema.String()), coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) ] ) } } )
+ assert response.data == expected
class DenyAllUsingHttp404(permissions.BasePermission):
@@ -400,260 +214,84 @@ class ExampleDetailView(APIView):
@unittest.skipUnless(coreapi, 'coreapi is not installed')
-class TestSchemaGenerator(TestCase):
- def setUp(self):
- self.patterns = [
- url(r'^example/?$', ExampleListView.as_view()),
- url(r'^example/(?P\d+)/?$', ExampleDetailView.as_view()),
- url(r'^example/(?P\d+)/sub/?$', ExampleDetailView.as_view()),
- ]
+def setUp(self):
+ self.patterns = [ url(r'^example/?$', ExampleListView.as_view()), url(r'^example/(?P\d+)/?$', ExampleDetailView.as_view()), url(r'^example/(?P\d+)/sub/?$', ExampleDetailView.as_view()), ]
def test_schema_for_regular_views(self):
- """
+ """
Ensure that schema generation works for APIView classes.
"""
- generator = SchemaGenerator(title='Example API', patterns=self.patterns)
- schema = generator.get_schema()
- expected = coreapi.Document(
- url='',
- title='Example API',
- content={
- 'example': {
- 'create': coreapi.Link(
- url='/example/',
- action='post',
- fields=[]
- ),
- 'list': coreapi.Link(
- url='/example/',
- action='get',
- fields=[]
- ),
- 'read': coreapi.Link(
- url='/example/{id}/',
- action='get',
- fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String())
- ]
- ),
- 'sub': {
- 'list': coreapi.Link(
- url='/example/{id}/sub/',
- action='get',
- fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String())
- ]
- )
- }
- }
- }
- )
- assert schema == expected
+ generator = SchemaGenerator(title='Example API', patterns=self.patterns)
+ schema = generator.get_schema()
+ expected = coreapi.Document( url='', title='Example API', content={ 'example': { 'create': coreapi.Link( url='/example/', action='post', fields=[] ), 'list': coreapi.Link( url='/example/', action='get', fields=[] ), 'read': coreapi.Link( url='/example/{id}/', action='get', fields=[ coreapi.Field('id', required=True, location='path', schema=coreschema.String()) ] ), 'sub': { 'list': coreapi.Link( url='/example/{id}/sub/', action='get', fields=[ coreapi.Field('id', required=True, location='path', schema=coreschema.String()) ] ) } } } )
+ assert schema == expected
@unittest.skipUnless(coreapi, 'coreapi is not installed')
@unittest.skipUnless(path, 'needs Django 2')
-class TestSchemaGeneratorDjango2(TestCase):
- def setUp(self):
- self.patterns = [
- path('example/', ExampleListView.as_view()),
- path('example//', ExampleDetailView.as_view()),
- path('example//sub/', ExampleDetailView.as_view()),
- ]
+def setUp(self):
+ self.patterns = [ path('example/', ExampleListView.as_view()), path('example//', ExampleDetailView.as_view()), path('example//sub/', ExampleDetailView.as_view()), ]
def test_schema_for_regular_views(self):
- """
+ """
Ensure that schema generation works for APIView classes.
"""
- generator = SchemaGenerator(title='Example API', patterns=self.patterns)
- schema = generator.get_schema()
- expected = coreapi.Document(
- url='',
- title='Example API',
- content={
- 'example': {
- 'create': coreapi.Link(
- url='/example/',
- action='post',
- fields=[]
- ),
- 'list': coreapi.Link(
- url='/example/',
- action='get',
- fields=[]
- ),
- 'read': coreapi.Link(
- url='/example/{id}/',
- action='get',
- fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String())
- ]
- ),
- 'sub': {
- 'list': coreapi.Link(
- url='/example/{id}/sub/',
- action='get',
- fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String())
- ]
- )
- }
- }
- }
- )
- assert schema == expected
+ generator = SchemaGenerator(title='Example API', patterns=self.patterns)
+ schema = generator.get_schema()
+ expected = coreapi.Document( url='', title='Example API', content={ 'example': { 'create': coreapi.Link( url='/example/', action='post', fields=[] ), 'list': coreapi.Link( url='/example/', action='get', fields=[] ), 'read': coreapi.Link( url='/example/{id}/', action='get', fields=[ coreapi.Field('id', required=True, location='path', schema=coreschema.String()) ] ), 'sub': { 'list': coreapi.Link( url='/example/{id}/sub/', action='get', fields=[ coreapi.Field('id', required=True, location='path', schema=coreschema.String()) ] ) } } } )
+ assert schema == expected
@unittest.skipUnless(coreapi, 'coreapi is not installed')
-class TestSchemaGeneratorNotAtRoot(TestCase):
- def setUp(self):
- self.patterns = [
- url(r'^api/v1/example/?$', ExampleListView.as_view()),
- url(r'^api/v1/example/(?P\d+)/?$', ExampleDetailView.as_view()),
- url(r'^api/v1/example/(?P\d+)/sub/?$', ExampleDetailView.as_view()),
- ]
+def setUp(self):
+ self.patterns = [ url(r'^api/v1/example/?$', ExampleListView.as_view()), url(r'^api/v1/example/(?P\d+)/?$', ExampleDetailView.as_view()), url(r'^api/v1/example/(?P\d+)/sub/?$', ExampleDetailView.as_view()), ]
def test_schema_for_regular_views(self):
- """
+ """
Ensure that schema generation with an API that is not at the URL
root continues to use correct structure for link keys.
"""
- generator = SchemaGenerator(title='Example API', patterns=self.patterns)
- schema = generator.get_schema()
- expected = coreapi.Document(
- url='',
- title='Example API',
- content={
- 'example': {
- 'create': coreapi.Link(
- url='/api/v1/example/',
- action='post',
- fields=[]
- ),
- 'list': coreapi.Link(
- url='/api/v1/example/',
- action='get',
- fields=[]
- ),
- 'read': coreapi.Link(
- url='/api/v1/example/{id}/',
- action='get',
- fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String())
- ]
- ),
- 'sub': {
- 'list': coreapi.Link(
- url='/api/v1/example/{id}/sub/',
- action='get',
- fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String())
- ]
- )
- }
- }
- }
- )
- assert schema == expected
+ generator = SchemaGenerator(title='Example API', patterns=self.patterns)
+ schema = generator.get_schema()
+ expected = coreapi.Document( url='', title='Example API', content={ 'example': { 'create': coreapi.Link( url='/api/v1/example/', action='post', fields=[] ), 'list': coreapi.Link( url='/api/v1/example/', action='get', fields=[] ), 'read': coreapi.Link( url='/api/v1/example/{id}/', action='get', fields=[ coreapi.Field('id', required=True, location='path', schema=coreschema.String()) ] ), 'sub': { 'list': coreapi.Link( url='/api/v1/example/{id}/sub/', action='get', fields=[ coreapi.Field('id', required=True, location='path', schema=coreschema.String()) ] ) } } } )
+ assert schema == expected
@unittest.skipUnless(coreapi, 'coreapi is not installed')
-class TestSchemaGeneratorWithMethodLimitedViewSets(TestCase):
- def setUp(self):
- router = DefaultRouter()
- router.register('example1', MethodLimitedViewSet, basename='example1')
- self.patterns = [
- url(r'^', include(router.urls))
- ]
+def setUp(self):
+ router = DefaultRouter()
+ router.register('example1', MethodLimitedViewSet, basename='example1')
+ self.patterns = [ url(r'^', include(router.urls)) ]
def test_schema_for_regular_views(self):
- """
+ """
Ensure that schema generation works for ViewSet classes
with method limitation by Django CBV's http_method_names attribute
"""
- generator = SchemaGenerator(title='Example API', patterns=self.patterns)
- request = factory.get('/example1/')
- schema = generator.get_schema(Request(request))
-
- expected = coreapi.Document(
- url='http://testserver/example1/',
- title='Example API',
- content={
- 'example1': {
- 'list': coreapi.Link(
- url='/example1/',
- action='get',
- fields=[
- coreapi.Field('page', required=False, location='query', schema=coreschema.Integer(title='Page', description='A page number within the paginated result set.')),
- coreapi.Field('page_size', required=False, location='query', schema=coreschema.Integer(title='Page size', description='Number of results to return per page.')),
- coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.'))
- ]
- ),
- 'custom_list_action': coreapi.Link(
- url='/example1/custom_list_action/',
- action='get'
- ),
- 'custom_list_action_multiple_methods': {
- 'read': coreapi.Link(
- url='/example1/custom_list_action_multiple_methods/',
- action='get',
- description='Custom description.',
- )
- },
- 'documented_custom_action': {
- 'read': coreapi.Link(
- url='/example1/documented_custom_action/',
- action='get',
- description='A description of the get method on the custom action.',
- ),
- },
- 'read': coreapi.Link(
- url='/example1/{id}/',
- action='get',
- fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String()),
- coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.'))
- ]
- )
- }
- }
- )
- assert schema == expected
+ generator = SchemaGenerator(title='Example API', patterns=self.patterns)
+ request = factory.get('/example1/')
+ schema = generator.get_schema(Request(request))
+ expected = coreapi.Document( url='http://testserver/example1/', title='Example API', content={ 'example1': { 'list': coreapi.Link( url='/example1/', action='get', fields=[ coreapi.Field('page', required=False, location='query', schema=coreschema.Integer(title='Page', description='A page number within the paginated result set.')), coreapi.Field('page_size', required=False, location='query', schema=coreschema.Integer(title='Page size', description='Number of results to return per page.')), coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) ] ), 'custom_list_action': coreapi.Link( url='/example1/custom_list_action/', action='get' ), 'custom_list_action_multiple_methods': { 'read': coreapi.Link( url='/example1/custom_list_action_multiple_methods/', action='get', description='Custom description.', ) }, 'documented_custom_action': { 'read': coreapi.Link( url='/example1/documented_custom_action/', action='get', description='A description of the get method on the custom action.', ), }, 'read': coreapi.Link( url='/example1/{id}/', action='get', fields=[ coreapi.Field('id', required=True, location='path', schema=coreschema.String()), coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) ] ) } } )
+ assert schema == expected
@unittest.skipUnless(coreapi, 'coreapi is not installed')
-class TestSchemaGeneratorWithRestrictedViewSets(TestCase):
- def setUp(self):
- router = DefaultRouter()
- router.register('example1', Http404ExampleViewSet, basename='example1')
- router.register('example2', PermissionDeniedExampleViewSet, basename='example2')
- self.patterns = [
- url('^example/?$', ExampleListView.as_view()),
- url(r'^', include(router.urls))
- ]
+def setUp(self):
+ router = DefaultRouter()
+ router.register('example1', Http404ExampleViewSet, basename='example1')
+ router.register('example2', PermissionDeniedExampleViewSet, basename='example2')
+ self.patterns = [ url('^example/?$', ExampleListView.as_view()), url(r'^', include(router.urls)) ]
def test_schema_for_regular_views(self):
- """
+ """
Ensure that schema generation works for ViewSet classes
with permission classes raising exceptions.
"""
- generator = SchemaGenerator(title='Example API', patterns=self.patterns)
- request = factory.get('/')
- schema = generator.get_schema(Request(request))
- expected = coreapi.Document(
- url='http://testserver/',
- title='Example API',
- content={
- 'example': {
- 'list': coreapi.Link(
- url='/example/',
- action='get',
- fields=[]
- ),
- },
- }
- )
- assert schema == expected
+ generator = SchemaGenerator(title='Example API', patterns=self.patterns)
+ request = factory.get('/')
+ schema = generator.get_schema(Request(request))
+ expected = coreapi.Document( url='http://testserver/', title='Example API', content={ 'example': { 'list': coreapi.Link( url='/example/', action='get', fields=[] ), }, } )
+ assert schema == expected
class ForeignKeySourceSerializer(serializers.ModelSerializer):
@@ -668,37 +306,17 @@ class ForeignKeySourceView(generics.CreateAPIView):
@unittest.skipUnless(coreapi, 'coreapi is not installed')
-class TestSchemaGeneratorWithForeignKey(TestCase):
- def setUp(self):
- self.patterns = [
- url(r'^example/?$', ForeignKeySourceView.as_view()),
- ]
+def setUp(self):
+ self.patterns = [ url(r'^example/?$', ForeignKeySourceView.as_view()), ]
def test_schema_for_regular_views(self):
- """
+ """
Ensure that AutoField foreign keys are output as Integer.
"""
- generator = SchemaGenerator(title='Example API', patterns=self.patterns)
- schema = generator.get_schema()
-
- expected = coreapi.Document(
- url='',
- title='Example API',
- content={
- 'example': {
- 'create': coreapi.Link(
- url='/example/',
- action='post',
- encoding='application/json',
- fields=[
- coreapi.Field('name', required=True, location='form', schema=coreschema.String(title='Name')),
- coreapi.Field('target', required=True, location='form', schema=coreschema.Integer(description='Target', title='Target')),
- ]
- )
- }
- }
- )
- assert schema == expected
+ generator = SchemaGenerator(title='Example API', patterns=self.patterns)
+ schema = generator.get_schema()
+ expected = coreapi.Document( url='', title='Example API', content={ 'example': { 'create': coreapi.Link( url='/example/', action='post', encoding='application/json', fields=[ coreapi.Field('name', required=True, location='form', schema=coreschema.String(title='Name')), coreapi.Field('target', required=True, location='form', schema=coreschema.Integer(description='Target', title='Target')), ] ) } } )
+ assert schema == expected
class ManyToManySourceSerializer(serializers.ModelSerializer):
@@ -713,48 +331,24 @@ class ManyToManySourceView(generics.CreateAPIView):
@unittest.skipUnless(coreapi, 'coreapi is not installed')
-class TestSchemaGeneratorWithManyToMany(TestCase):
- def setUp(self):
- self.patterns = [
- url(r'^example/?$', ManyToManySourceView.as_view()),
- ]
+def setUp(self):
+ self.patterns = [ url(r'^example/?$', ManyToManySourceView.as_view()), ]
def test_schema_for_regular_views(self):
- """
+ """
Ensure that AutoField many to many fields are output as Integer.
"""
- generator = SchemaGenerator(title='Example API', patterns=self.patterns)
- schema = generator.get_schema()
-
- expected = coreapi.Document(
- url='',
- title='Example API',
- content={
- 'example': {
- 'create': coreapi.Link(
- url='/example/',
- action='post',
- encoding='application/json',
- fields=[
- coreapi.Field('name', required=True, location='form', schema=coreschema.String(title='Name')),
- coreapi.Field('targets', required=True, location='form', schema=coreschema.Array(title='Targets', items=coreschema.Integer())),
- ]
- )
- }
- }
- )
- assert schema == expected
+ generator = SchemaGenerator(title='Example API', patterns=self.patterns)
+ schema = generator.get_schema()
+ expected = coreapi.Document( url='', title='Example API', content={ 'example': { 'create': coreapi.Link( url='/example/', action='post', encoding='application/json', fields=[ coreapi.Field('name', required=True, location='form', schema=coreschema.String(title='Name')), coreapi.Field('targets', required=True, location='form', schema=coreschema.Array(title='Targets', items=coreschema.Integer())), ] ) } } )
+ assert schema == expected
@unittest.skipUnless(coreapi, 'coreapi is not installed')
-class Test4605Regression(TestCase):
- def test_4605_regression(self):
- generator = SchemaGenerator()
- prefix = generator.determine_path_prefix([
- '/api/v1/items/',
- '/auth/convert-token/'
- ])
- assert prefix == '/'
+def test_4605_regression(self):
+ generator = SchemaGenerator()
+ prefix = generator.determine_path_prefix([ '/api/v1/items/', '/auth/convert-token/' ])
+ assert prefix == '/'
class CustomViewInspector(AutoSchema):
@@ -762,213 +356,113 @@ class CustomViewInspector(AutoSchema):
pass
-class TestAutoSchema(TestCase):
- def test_apiview_schema_descriptor(self):
- view = APIView()
- assert hasattr(view, 'schema')
- assert isinstance(view.schema, AutoSchema)
+def test_apiview_schema_descriptor(self):
+ view = APIView()
+ assert hasattr(view, 'schema')
+ assert isinstance(view.schema, AutoSchema)
def test_set_custom_inspector_class_on_view(self):
- class CustomView(APIView):
- schema = CustomViewInspector()
+ class CustomView(APIView):
+ schema = CustomViewInspector()
view = CustomView()
- assert isinstance(view.schema, CustomViewInspector)
+assertisinstance(view.schema,CustomViewInspector)
def test_set_custom_inspector_class_via_settings(self):
- with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'tests.test_schemas.CustomViewInspector'}):
- view = APIView()
- assert isinstance(view.schema, CustomViewInspector)
+ with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'tests.test_schemas.CustomViewInspector'}):
+ view = APIView()
+ assert isinstance(view.schema, CustomViewInspector)
def test_get_link_requires_instance(self):
- descriptor = APIView.schema # Accessed from class
- with pytest.raises(AssertionError):
- descriptor.get_link(None, None, None) # ???: Do the dummy arguments require a tighter assert?
+ descriptor = APIView.schema # Accessed from class
+ with pytest.raises(AssertionError):
+ descriptor.get_link(None, None, None) # ???: Do the dummy arguments require a tighter assert?
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
- def test_update_fields(self):
- """
+deftest_update_fields(self):
+ """
That updating fields by-name helper is correct
Recall: `update_fields(fields, update_with)`
"""
- schema = AutoSchema()
- fields = []
-
- # Adds a field...
- fields = schema.update_fields(fields, [
- coreapi.Field(
- "my_field",
- required=True,
- location="path",
- schema=coreschema.String()
- ),
- ])
-
- assert len(fields) == 1
- assert fields[0].name == "my_field"
-
- # Replaces a field...
- fields = schema.update_fields(fields, [
- coreapi.Field(
- "my_field",
- required=False,
- location="path",
- schema=coreschema.String()
- ),
- ])
-
- assert len(fields) == 1
- assert fields[0].required is False
+ schema = AutoSchema()
+ fields = []
+ fields = schema.update_fields(fields, [ coreapi.Field( "my_field", required=True, location="path", schema=coreschema.String() ), ])
+ assert len(fields) == 1
+ assert fields[0].name == "my_field"
+ fields = schema.update_fields(fields, [ coreapi.Field( "my_field", required=False, location="path", schema=coreschema.String() ), ])
+ assert len(fields) == 1
+ assert fields[0].required is False
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
- def test_get_manual_fields(self):
- """That get_manual_fields is applied during get_link"""
-
- class CustomView(APIView):
- schema = AutoSchema(manual_fields=[
- coreapi.Field(
- "my_extra_field",
- required=True,
- location="path",
- schema=coreschema.String()
- ),
- ])
+deftest_get_manual_fields(self):
+ """That get_manual_fields is applied during get_link"""
+ class CustomView(APIView):
+ schema = AutoSchema(manual_fields=[ coreapi.Field( "my_extra_field", required=True, location="path", schema=coreschema.String() ), ])
view = CustomView()
- link = view.schema.get_link('/a/url/{id}/', 'GET', '')
- fields = link.fields
-
- assert len(fields) == 2
- assert "my_extra_field" in [f.name for f in fields]
+link=view.schema.get_link('/a/url/{id}/','GET','')
+fields=link.fields
+assertlen(fields)==2
+assert"my_extra_field"in[f.nameforfinfields]
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
- def test_viewset_action_with_schema(self):
- class CustomViewSet(GenericViewSet):
- @action(detail=True, schema=AutoSchema(manual_fields=[
- coreapi.Field(
- "my_extra_field",
- required=True,
- location="path",
- schema=coreschema.String()
- ),
- ]))
- def extra_action(self, pk, **kwargs):
- pass
+deftest_viewset_action_with_schema(self):
+ class CustomViewSet(GenericViewSet):
+ @action(detail=True, schema=AutoSchema(manual_fields=[ coreapi.Field( "my_extra_field", required=True, location="path", schema=coreschema.String() ), ]))
+ def extra_action(self, pk, **kwargs):
+ pass
router = SimpleRouter()
- router.register(r'detail', CustomViewSet, basename='detail')
-
- generator = SchemaGenerator()
- view = generator.create_view(router.urls[0].callback, 'GET')
- link = view.schema.get_link('/a/url/{id}/', 'GET', '')
- fields = link.fields
-
- assert len(fields) == 2
- assert "my_extra_field" in [f.name for f in fields]
+router.register(r'detail',CustomViewSet,basename='detail')
+generator=SchemaGenerator()
+view=generator.create_view(router.urls[0].callback,'GET')
+link=view.schema.get_link('/a/url/{id}/','GET','')
+fields=link.fields
+assertlen(fields)==2
+assert"my_extra_field"in[f.nameforfinfields]
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
- def test_viewset_action_with_null_schema(self):
- class CustomViewSet(GenericViewSet):
- @action(detail=True, schema=None)
- def extra_action(self, pk, **kwargs):
- pass
+deftest_viewset_action_with_null_schema(self):
+ class CustomViewSet(GenericViewSet):
+ @action(detail=True, schema=None)
+ def extra_action(self, pk, **kwargs):
+ pass
router = SimpleRouter()
- router.register(r'detail', CustomViewSet, basename='detail')
-
- generator = SchemaGenerator()
- view = generator.create_view(router.urls[0].callback, 'GET')
- assert view.schema is None
+router.register(r'detail',CustomViewSet,basename='detail')
+generator=SchemaGenerator()
+view=generator.create_view(router.urls[0].callback,'GET')
+assertview.schemaisNone
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
- def test_view_with_manual_schema(self):
+deftest_view_with_manual_schema(self):
- path = '/example'
- method = 'get'
- base_url = None
-
- fields = [
- coreapi.Field(
- "first_field",
- required=True,
- location="path",
- schema=coreschema.String()
- ),
- coreapi.Field(
- "second_field",
- required=True,
- location="path",
- schema=coreschema.String()
- ),
- coreapi.Field(
- "third_field",
- required=True,
- location="path",
- schema=coreschema.String()
- ),
- ]
- description = "A test endpoint"
-
- class CustomView(APIView):
- """
+ path = '/example'
+ method = 'get'
+ base_url = None
+ fields = [ coreapi.Field( "first_field", required=True, location="path", schema=coreschema.String() ), coreapi.Field( "second_field", required=True, location="path", schema=coreschema.String() ), coreapi.Field( "third_field", required=True, location="path", schema=coreschema.String() ), ]
+ description = "A test endpoint"
+ class CustomView(APIView):
+ """
ManualSchema takes list of fields for endpoint.
- Provides url and action, which are always dynamic
"""
- schema = ManualSchema(fields, description)
+ schema = ManualSchema(fields, description)
- expected = coreapi.Link(
- url=path,
- action=method,
- fields=fields,
- description=description
- )
-
- view = CustomView()
- link = view.schema.get_link(path, method, base_url)
- assert link == expected
+ expected = coreapi.Link(url=path,action=method,fields=fields,description=description)
+view=CustomView()
+link=view.schema.get_link(path,method,base_url)
+assertlink==expected
@unittest.skipUnless(coreschema, 'coreschema is not installed')
- def test_field_to_schema(self):
- label = 'Test label'
- help_text = 'This is a helpful test text'
-
- cases = [
- # tuples are ([field], [expected schema])
- # TODO: Add remaining cases
- (
- serializers.BooleanField(label=label, help_text=help_text),
- coreschema.Boolean(title=label, description=help_text)
- ),
- (
- serializers.DecimalField(1000, 1000, label=label, help_text=help_text),
- coreschema.Number(title=label, description=help_text)
- ),
- (
- serializers.FloatField(label=label, help_text=help_text),
- coreschema.Number(title=label, description=help_text)
- ),
- (
- serializers.IntegerField(label=label, help_text=help_text),
- coreschema.Integer(title=label, description=help_text)
- ),
- (
- serializers.DateField(label=label, help_text=help_text),
- coreschema.String(title=label, description=help_text, format='date')
- ),
- (
- serializers.DateTimeField(label=label, help_text=help_text),
- coreschema.String(title=label, description=help_text, format='date-time')
- ),
- (
- serializers.JSONField(label=label, help_text=help_text),
- coreschema.Object(title=label, description=help_text)
- ),
- ]
-
- for case in cases:
- self.assertEqual(field_to_schema(case[0]), case[1])
+deftest_field_to_schema(self):
+ label = 'Test label'
+ help_text = 'This is a helpful test text'
+ cases = [ ( serializers.BooleanField(label=label, help_text=help_text), coreschema.Boolean(title=label, description=help_text) ), ( serializers.DecimalField(1000, 1000, label=label, help_text=help_text), coreschema.Number(title=label, description=help_text) ), ( serializers.FloatField(label=label, help_text=help_text), coreschema.Number(title=label, description=help_text) ), ( serializers.IntegerField(label=label, help_text=help_text), coreschema.Integer(title=label, description=help_text) ), ( serializers.DateField(label=label, help_text=help_text), coreschema.String(title=label, description=help_text, format='date') ), ( serializers.DateTimeField(label=label, help_text=help_text), coreschema.String(title=label, description=help_text, format='date-time') ), ( serializers.JSONField(label=label, help_text=help_text), coreschema.Object(title=label, description=help_text) ), ]
+ for case in cases:
+ assert field_to_schema(case[0]) == case[1]
def test_docstring_is_not_stripped_by_get_description():
@@ -1026,56 +520,33 @@ def included_fbv(request):
@unittest.skipUnless(coreapi, 'coreapi is not installed')
-class SchemaGenerationExclusionTests(TestCase):
- def setUp(self):
- self.patterns = [
- url('^excluded-cbv/$', ExcludedAPIView.as_view()),
- url('^excluded-fbv/$', excluded_fbv),
- url('^included-fbv/$', included_fbv),
- ]
+def setUp(self):
+ self.patterns = [ url('^excluded-cbv/$', ExcludedAPIView.as_view()), url('^excluded-fbv/$', excluded_fbv), url('^included-fbv/$', included_fbv), ]
def test_schema_generator_excludes_correctly(self):
- """Schema should not include excluded views"""
- generator = SchemaGenerator(title='Exclusions', patterns=self.patterns)
- schema = generator.get_schema()
- expected = coreapi.Document(
- url='',
- title='Exclusions',
- content={
- 'included-fbv': {
- 'list': coreapi.Link(url='/included-fbv/', action='get')
- }
- }
- )
-
- assert len(schema.data) == 1
- assert 'included-fbv' in schema.data
- assert schema == expected
+ """Schema should not include excluded views"""
+ generator = SchemaGenerator(title='Exclusions', patterns=self.patterns)
+ schema = generator.get_schema()
+ expected = coreapi.Document( url='', title='Exclusions', content={ 'included-fbv': { 'list': coreapi.Link(url='/included-fbv/', action='get') } } )
+ assert len(schema.data) == 1
+ assert 'included-fbv' in schema.data
+ assert schema == expected
def test_endpoint_enumerator_excludes_correctly(self):
- """It is responsibility of EndpointEnumerator to exclude views"""
- inspector = EndpointEnumerator(self.patterns)
- endpoints = inspector.get_api_endpoints()
-
- assert len(endpoints) == 1
- path, method, callback = endpoints[0]
- assert path == '/included-fbv/'
+ """It is responsibility of EndpointEnumerator to exclude views"""
+ inspector = EndpointEnumerator(self.patterns)
+ endpoints = inspector.get_api_endpoints()
+ assert len(endpoints) == 1
+ path, method, callback = endpoints[0]
+ assert path == '/included-fbv/'
def test_should_include_endpoint_excludes_correctly(self):
- """This is the specific method that should handle the exclusion"""
- inspector = EndpointEnumerator(self.patterns)
-
- # Not pretty. Mimics internals of EndpointEnumerator to put should_include_endpoint under test
- pairs = [(inspector.get_path_from_regex(get_regex_pattern(pattern)), pattern.callback)
- for pattern in self.patterns]
-
- should_include = [
- inspector.should_include_endpoint(*pair) for pair in pairs
- ]
-
- expected = [False, False, True]
-
- assert should_include == expected
+ """This is the specific method that should handle the exclusion"""
+ inspector = EndpointEnumerator(self.patterns)
+ pairs = [(inspector.get_path_from_regex(get_regex_pattern(pattern)), pattern.callback) for pattern in self.patterns]
+ should_include = [ inspector.should_include_endpoint(*pair) for pair in pairs ]
+ expected = [False, False, True]
+ assert should_include == expected
@api_view(["GET"])
@@ -1118,126 +589,63 @@ naming_collisions_router.register(r'collision', NamingCollisionViewSet, basename
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
-class TestURLNamingCollisions(TestCase):
- """
+"""
Ref: https://github.com/encode/django-rest-framework/issues/4704
"""
- def test_manually_routing_nested_routes(self):
- patterns = [
- url(r'^test', simple_fbv),
- url(r'^test/list/', simple_fbv),
- ]
-
- generator = SchemaGenerator(title='Naming Colisions', patterns=patterns)
- schema = generator.get_schema()
-
- expected = coreapi.Document(
- url='',
- title='Naming Colisions',
- content={
- 'test': {
- 'list': {
- 'list': coreapi.Link(url='/test/list/', action='get')
- },
- 'list_0': coreapi.Link(url='/test', action='get')
- }
- }
- )
-
- assert expected == schema
+deftest_manually_routing_nested_routes(self):
+ patterns = [ url(r'^test', simple_fbv), url(r'^test/list/', simple_fbv), ]
+ generator = SchemaGenerator(title='Naming Colisions', patterns=patterns)
+ schema = generator.get_schema()
+ expected = coreapi.Document( url='', title='Naming Colisions', content={ 'test': { 'list': { 'list': coreapi.Link(url='/test/list/', action='get') }, 'list_0': coreapi.Link(url='/test', action='get') } } )
+ assert expected == schema
def _verify_cbv_links(self, loc, url, methods=None, suffixes=None):
- if methods is None:
- methods = ('read', 'update', 'partial_update', 'delete')
+ if methods is None:
+ methods = ('read', 'update', 'partial_update', 'delete')
if suffixes is None:
- suffixes = (None for m in methods)
+ suffixes = (None for m in methods)
for method, suffix in zip(methods, suffixes):
- if suffix is not None:
- key = '{}_{}'.format(method, suffix)
+ if suffix is not None:
+ key = '{}_{}'.format(method, suffix)
else:
- key = method
+ key = method
assert loc[key].url == url
def test_manually_routing_generic_view(self):
- patterns = [
- url(r'^test', NamingCollisionView.as_view()),
- url(r'^test/retrieve/', NamingCollisionView.as_view()),
- url(r'^test/update/', NamingCollisionView.as_view()),
-
- # Fails with method names:
- url(r'^test/get/', NamingCollisionView.as_view()),
- url(r'^test/put/', NamingCollisionView.as_view()),
- url(r'^test/delete/', NamingCollisionView.as_view()),
- ]
-
- generator = SchemaGenerator(title='Naming Colisions', patterns=patterns)
-
- schema = generator.get_schema()
-
- self._verify_cbv_links(schema['test']['delete'], '/test/delete/')
- self._verify_cbv_links(schema['test']['put'], '/test/put/')
- self._verify_cbv_links(schema['test']['get'], '/test/get/')
- self._verify_cbv_links(schema['test']['update'], '/test/update/')
- self._verify_cbv_links(schema['test']['retrieve'], '/test/retrieve/')
- self._verify_cbv_links(schema['test'], '/test', suffixes=(None, '0', None, '0'))
+ patterns = [ url(r'^test', NamingCollisionView.as_view()), url(r'^test/retrieve/', NamingCollisionView.as_view()), url(r'^test/update/', NamingCollisionView.as_view()), url(r'^test/get/', NamingCollisionView.as_view()), url(r'^test/put/', NamingCollisionView.as_view()), url(r'^test/delete/', NamingCollisionView.as_view()), ]
+ generator = SchemaGenerator(title='Naming Colisions', patterns=patterns)
+ schema = generator.get_schema()
+ self._verify_cbv_links(schema['test']['delete'], '/test/delete/')
+ self._verify_cbv_links(schema['test']['put'], '/test/put/')
+ self._verify_cbv_links(schema['test']['get'], '/test/get/')
+ self._verify_cbv_links(schema['test']['update'], '/test/update/')
+ self._verify_cbv_links(schema['test']['retrieve'], '/test/retrieve/')
+ self._verify_cbv_links(schema['test'], '/test', suffixes=(None, '0', None, '0'))
def test_from_router(self):
- patterns = [
- url(r'from-router', include(naming_collisions_router.urls)),
- ]
-
- generator = SchemaGenerator(title='Naming Colisions', patterns=patterns)
- schema = generator.get_schema()
-
- # not important here
- desc_0 = schema['detail']['detail_export'].description
- desc_1 = schema['detail_0'].description
-
- expected = coreapi.Document(
- url='',
- title='Naming Colisions',
- content={
- 'detail': {
- 'detail_export': coreapi.Link(
- url='/from-routercollision/detail/export/',
- action='get',
- description=desc_0)
- },
- 'detail_0': coreapi.Link(
- url='/from-routercollision/detail/',
- action='get',
- description=desc_1
- )
- }
- )
-
- assert schema == expected
+ patterns = [ url(r'from-router', include(naming_collisions_router.urls)), ]
+ generator = SchemaGenerator(title='Naming Colisions', patterns=patterns)
+ schema = generator.get_schema()
+ desc_0 = schema['detail']['detail_export'].description
+ desc_1 = schema['detail_0'].description
+ expected = coreapi.Document( url='', title='Naming Colisions', content={ 'detail': { 'detail_export': coreapi.Link( url='/from-routercollision/detail/export/', action='get', description=desc_0) }, 'detail_0': coreapi.Link( url='/from-routercollision/detail/', action='get', description=desc_1 ) } )
+ assert schema == expected
def test_url_under_same_key_not_replaced(self):
- patterns = [
- url(r'example/(?P\d+)/$', BasicNamingCollisionView.as_view()),
- url(r'example/(?P\w+)/$', BasicNamingCollisionView.as_view()),
- ]
-
- generator = SchemaGenerator(title='Naming Colisions', patterns=patterns)
- schema = generator.get_schema()
-
- assert schema['example']['read'].url == '/example/{id}/'
- assert schema['example']['read_0'].url == '/example/{slug}/'
+ patterns = [ url(r'example/(?P\d+)/$', BasicNamingCollisionView.as_view()), url(r'example/(?P\w+)/$', BasicNamingCollisionView.as_view()), ]
+ generator = SchemaGenerator(title='Naming Colisions', patterns=patterns)
+ schema = generator.get_schema()
+ assert schema['example']['read'].url == '/example/{id}/'
+ assert schema['example']['read_0'].url == '/example/{slug}/'
def test_url_under_same_key_not_replaced_another(self):
- patterns = [
- url(r'^test/list/', simple_fbv),
- url(r'^test/(?P\d+)/list/', simple_fbv),
- ]
-
- generator = SchemaGenerator(title='Naming Colisions', patterns=patterns)
- schema = generator.get_schema()
-
- assert schema['test']['list']['list'].url == '/test/list/'
- assert schema['test']['list']['list_0'].url == '/test/{id}/list/'
+ patterns = [ url(r'^test/list/', simple_fbv), url(r'^test/(?P\d+)/list/', simple_fbv), ]
+ generator = SchemaGenerator(title='Naming Colisions', patterns=patterns)
+ schema = generator.get_schema()
+ assert schema['test']['list']['list'].url == '/test/list/'
+ assert schema['test']['list']['list_0'].url == '/test/{id}/list/'
def test_is_list_view_recognises_retrieve_view_subclasses():
diff --git a/tests/test_serializer_bulk_update.py b/tests/test_serializer_bulk_update.py
index 0465578bb..689aa05bd 100644
--- a/tests/test_serializer_bulk_update.py
+++ b/tests/test_serializer_bulk_update.py
@@ -4,120 +4,66 @@ Tests to cover bulk create and update using serializers.
from django.test import TestCase
from rest_framework import serializers
-
-
-class BulkCreateSerializerTests(TestCase):
- """
+"""
Creating multiple instances using serializers.
"""
-
- def setUp(self):
- class BookSerializer(serializers.Serializer):
- id = serializers.IntegerField()
- title = serializers.CharField(max_length=100)
- author = serializers.CharField(max_length=100)
+defsetUp(self):
+ class BookSerializer(serializers.Serializer):
+ id = serializers.IntegerField()
+ title = serializers.CharField(max_length=100)
+ author = serializers.CharField(max_length=100)
self.BookSerializer = BookSerializer
def test_bulk_create_success(self):
- """
+ """
Correct bulk update serialization should return the input data.
"""
-
- data = [
- {
- 'id': 0,
- 'title': 'The electric kool-aid acid test',
- 'author': 'Tom Wolfe'
- }, {
- 'id': 1,
- 'title': 'If this is a man',
- 'author': 'Primo Levi'
- }, {
- 'id': 2,
- 'title': 'The wind-up bird chronicle',
- 'author': 'Haruki Murakami'
- }
- ]
-
- serializer = self.BookSerializer(data=data, many=True)
- assert serializer.is_valid() is True
- assert serializer.validated_data == data
- assert serializer.errors == []
+ data = [ { 'id': 0, 'title': 'The electric kool-aid acid test', 'author': 'Tom Wolfe' }, { 'id': 1, 'title': 'If this is a man', 'author': 'Primo Levi' }, { 'id': 2, 'title': 'The wind-up bird chronicle', 'author': 'Haruki Murakami' } ]
+ serializer = self.BookSerializer(data=data, many=True)
+ assert serializer.is_valid() is True
+ assert serializer.validated_data == data
+ assert serializer.errors == []
def test_bulk_create_errors(self):
- """
+ """
Incorrect bulk create serialization should return errors.
"""
-
- data = [
- {
- 'id': 0,
- 'title': 'The electric kool-aid acid test',
- 'author': 'Tom Wolfe'
- }, {
- 'id': 1,
- 'title': 'If this is a man',
- 'author': 'Primo Levi'
- }, {
- 'id': 'foo',
- 'title': 'The wind-up bird chronicle',
- 'author': 'Haruki Murakami'
- }
- ]
- expected_errors = [
- {},
- {},
- {'id': ['A valid integer is required.']}
- ]
-
- serializer = self.BookSerializer(data=data, many=True)
- assert serializer.is_valid() is False
- assert serializer.errors == expected_errors
- assert serializer.validated_data == []
+ data = [ { 'id': 0, 'title': 'The electric kool-aid acid test', 'author': 'Tom Wolfe' }, { 'id': 1, 'title': 'If this is a man', 'author': 'Primo Levi' }, { 'id': 'foo', 'title': 'The wind-up bird chronicle', 'author': 'Haruki Murakami' } ]
+ expected_errors = [ {}, {}, {'id': ['A valid integer is required.']} ]
+ serializer = self.BookSerializer(data=data, many=True)
+ assert serializer.is_valid() is False
+ assert serializer.errors == expected_errors
+ assert serializer.validated_data == []
def test_invalid_list_datatype(self):
- """
+ """
Data containing list of incorrect data type should return errors.
"""
- data = ['foo', 'bar', 'baz']
- serializer = self.BookSerializer(data=data, many=True)
- assert serializer.is_valid() is False
-
- message = 'Invalid data. Expected a dictionary, but got str.'
- expected_errors = [
- {'non_field_errors': [message]},
- {'non_field_errors': [message]},
- {'non_field_errors': [message]}
- ]
-
- assert serializer.errors == expected_errors
+ data = ['foo', 'bar', 'baz']
+ serializer = self.BookSerializer(data=data, many=True)
+ assert serializer.is_valid() is False
+ message = 'Invalid data. Expected a dictionary, but got str.'
+ expected_errors = [ {'non_field_errors': [message]}, {'non_field_errors': [message]}, {'non_field_errors': [message]} ]
+ assert serializer.errors == expected_errors
def test_invalid_single_datatype(self):
- """
+ """
Data containing a single incorrect data type should return errors.
"""
- data = 123
- serializer = self.BookSerializer(data=data, many=True)
- assert serializer.is_valid() is False
-
- expected_errors = {'non_field_errors': ['Expected a list of items but got type "int".']}
-
- assert serializer.errors == expected_errors
+ data = 123
+ serializer = self.BookSerializer(data=data, many=True)
+ assert serializer.is_valid() is False
+ expected_errors = {'non_field_errors': ['Expected a list of items but got type "int".']}
+ assert serializer.errors == expected_errors
def test_invalid_single_object(self):
- """
+ """
Data containing only a single object, instead of a list of objects
should return errors.
"""
- data = {
- 'id': 0,
- 'title': 'The electric kool-aid acid test',
- 'author': 'Tom Wolfe'
- }
- serializer = self.BookSerializer(data=data, many=True)
- assert serializer.is_valid() is False
-
- expected_errors = {'non_field_errors': ['Expected a list of items but got type "dict".']}
-
- assert serializer.errors == expected_errors
+ data = { 'id': 0, 'title': 'The electric kool-aid acid test', 'author': 'Tom Wolfe' }
+ serializer = self.BookSerializer(data=data, many=True)
+ assert serializer.is_valid() is False
+ expected_errors = {'non_field_errors': ['Expected a list of items but got type "dict".']}
+ assert serializer.errors == expected_errors