unittest to pytest assertion convertions

This commit is contained in:
Asif Saif Uddin 2019-05-02 13:36:18 +06:00
parent 375f1b59c6
commit 3d28279d05
31 changed files with 3503 additions and 5311 deletions

2
.gitignore vendored
View File

@ -2,7 +2,7 @@
*.db
*~
.*
.py.bak
*.py.bak
/site/

View File

@ -11,6 +11,7 @@ from rest_framework.response import Response
from rest_framework.test import APIRequestFactory
from rest_framework.views import APIView
from tests.models import BasicModel
import pytest
factory = APIRequestFactory()
@ -52,51 +53,49 @@ urlpatterns = (
connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints."
)
class DBTransactionTests(TestCase):
def setUp(self):
self.view = BasicView.as_view()
connections.databases['default']['ATOMIC_REQUESTS'] = True
def setUp(self):
self.view = BasicView.as_view()
connections.databases['default']['ATOMIC_REQUESTS'] = True
def tearDown(self):
connections.databases['default']['ATOMIC_REQUESTS'] = False
connections.databases['default']['ATOMIC_REQUESTS'] = False
def test_no_exception_commit_transaction(self):
request = factory.post('/')
with self.assertNumQueries(1):
response = self.view(request)
request = factory.post('/')
with self.assertNumQueries(1):
response = self.view(request)
assert not transaction.get_rollback()
assert response.status_code == status.HTTP_200_OK
assert BasicModel.objects.count() == 1
assertresponse.status_code==status.HTTP_200_OK
assertBasicModel.objects.count()==1
@unittest.skipUnless(
connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints."
)
class DBTransactionErrorTests(TestCase):
def setUp(self):
self.view = ErrorView.as_view()
connections.databases['default']['ATOMIC_REQUESTS'] = True
def setUp(self):
self.view = ErrorView.as_view()
connections.databases['default']['ATOMIC_REQUESTS'] = True
def tearDown(self):
connections.databases['default']['ATOMIC_REQUESTS'] = False
connections.databases['default']['ATOMIC_REQUESTS'] = False
def test_generic_exception_delegate_transaction_management(self):
"""
"""
Transaction is eventually managed by outer-most transaction atomic
block. DRF do not try to interfere here.
We let django deal with the transaction when it will catch the Exception.
"""
request = factory.post('/')
with self.assertNumQueries(3):
request = factory.post('/')
with self.assertNumQueries(3):
# 1 - begin savepoint
# 2 - insert
# 3 - release savepoint
with transaction.atomic():
self.assertRaises(Exception, self.view, request)
assert not transaction.get_rollback()
with transaction.atomic():
with pytest.raises(Exception):
self.view(request)
assert not transaction.get_rollback()
assert BasicModel.objects.count() == 1
@ -104,30 +103,29 @@ class DBTransactionErrorTests(TestCase):
connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints."
)
class DBTransactionAPIExceptionTests(TestCase):
def setUp(self):
self.view = APIExceptionView.as_view()
connections.databases['default']['ATOMIC_REQUESTS'] = True
def setUp(self):
self.view = APIExceptionView.as_view()
connections.databases['default']['ATOMIC_REQUESTS'] = True
def tearDown(self):
connections.databases['default']['ATOMIC_REQUESTS'] = False
connections.databases['default']['ATOMIC_REQUESTS'] = False
def test_api_exception_rollback_transaction(self):
"""
"""
Transaction is rollbacked by our transaction atomic block.
"""
request = factory.post('/')
num_queries = 4 if connection.features.can_release_savepoints else 3
with self.assertNumQueries(num_queries):
request = factory.post('/')
num_queries = 4 if connection.features.can_release_savepoints else 3
with self.assertNumQueries(num_queries):
# 1 - begin savepoint
# 2 - insert
# 3 - rollback savepoint
# 4 - release savepoint
with transaction.atomic():
response = self.view(request)
assert transaction.get_rollback()
with transaction.atomic():
response = self.view(request)
assert transaction.get_rollback()
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert BasicModel.objects.count() == 0
assertBasicModel.objects.count()==0
@unittest.skipUnless(

View File

@ -13,74 +13,68 @@ from rest_framework.authtoken.models import Token
from rest_framework.authtoken.serializers import AuthTokenSerializer
from rest_framework.exceptions import ValidationError
class AuthTokenTests(TestCase):
def setUp(self):
self.site = site
self.user = User.objects.create_user(username='test_user')
self.token = Token.objects.create(key='test token', user=self.user)
def setUp(self):
self.site = site
self.user = User.objects.create_user(username='test_user')
self.token = Token.objects.create(key='test token', user=self.user)
def test_model_admin_displayed_fields(self):
mock_request = object()
token_admin = TokenAdmin(self.token, self.site)
assert token_admin.get_fields(mock_request) == ('user',)
mock_request = object()
token_admin = TokenAdmin(self.token, self.site)
assert token_admin.get_fields(mock_request) == ('user',)
def test_token_string_representation(self):
assert str(self.token) == 'test token'
assert str(self.token) == 'test token'
def test_validate_raise_error_if_no_credentials_provided(self):
with pytest.raises(ValidationError):
AuthTokenSerializer().validate({})
with pytest.raises(ValidationError):
AuthTokenSerializer().validate({})
def test_whitespace_in_password(self):
data = {'username': self.user.username, 'password': 'test pass '}
self.user.set_password(data['password'])
self.user.save()
assert AuthTokenSerializer(data=data).is_valid()
data = {'username': self.user.username, 'password': 'test pass '}
self.user.set_password(data['password'])
self.user.save()
assert AuthTokenSerializer(data=data).is_valid()
class AuthTokenCommandTests(TestCase):
def setUp(self):
self.site = site
self.user = User.objects.create_user(username='test_user')
def setUp(self):
self.site = site
self.user = User.objects.create_user(username='test_user')
def test_command_create_user_token(self):
token = AuthTokenCommand().create_user_token(self.user.username, False)
assert token is not None
token_saved = Token.objects.first()
assert token.key == token_saved.key
token = AuthTokenCommand().create_user_token(self.user.username, False)
assert token is not None
token_saved = Token.objects.first()
assert token.key == token_saved.key
def test_command_create_user_token_invalid_user(self):
with pytest.raises(User.DoesNotExist):
AuthTokenCommand().create_user_token('not_existing_user', False)
with pytest.raises(User.DoesNotExist):
AuthTokenCommand().create_user_token('not_existing_user', False)
def test_command_reset_user_token(self):
AuthTokenCommand().create_user_token(self.user.username, False)
first_token_key = Token.objects.first().key
AuthTokenCommand().create_user_token(self.user.username, True)
second_token_key = Token.objects.first().key
assert first_token_key != second_token_key
AuthTokenCommand().create_user_token(self.user.username, False)
first_token_key = Token.objects.first().key
AuthTokenCommand().create_user_token(self.user.username, True)
second_token_key = Token.objects.first().key
assert first_token_key != second_token_key
def test_command_do_not_reset_user_token(self):
AuthTokenCommand().create_user_token(self.user.username, False)
first_token_key = Token.objects.first().key
AuthTokenCommand().create_user_token(self.user.username, False)
second_token_key = Token.objects.first().key
assert first_token_key == second_token_key
AuthTokenCommand().create_user_token(self.user.username, False)
first_token_key = Token.objects.first().key
AuthTokenCommand().create_user_token(self.user.username, False)
second_token_key = Token.objects.first().key
assert first_token_key == second_token_key
def test_command_raising_error_for_invalid_user(self):
out = StringIO()
with pytest.raises(CommandError):
call_command('drf_create_token', 'not_existing_user', stdout=out)
out = StringIO()
with pytest.raises(CommandError):
call_command('drf_create_token', 'not_existing_user', stdout=out)
def test_command_output(self):
out = StringIO()
call_command('drf_create_token', self.user.username, stdout=out)
token_saved = Token.objects.first()
self.assertIn('Generated token', out.getvalue())
self.assertIn(self.user.username, out.getvalue())
self.assertIn(token_saved.key, out.getvalue())
out = StringIO()
call_command('drf_create_token', self.user.username, stdout=out)
token_saved = Token.objects.first()
assert 'Generated token' in out.getvalue()
assert self.user.username in out.getvalue()
assert token_saved.key in out.getvalue()

View File

@ -17,307 +17,270 @@ from rest_framework.test import APIRequestFactory
from rest_framework.throttling import UserRateThrottle
from rest_framework.views import APIView
class DecoratorTestCase(TestCase):
def setUp(self):
self.factory = APIRequestFactory()
def setUp(self):
self.factory = APIRequestFactory()
def _finalize_response(self, request, response, *args, **kwargs):
response.request = request
return APIView.finalize_response(self, request, response, *args, **kwargs)
response.request = request
return APIView.finalize_response(self, request, response, *args, **kwargs)
def test_api_view_incorrect(self):
"""
"""
If @api_view is not applied correct, we should raise an assertion.
"""
@api_view
def view(request):
return Response()
@api_view
request = self.factory.get('/')
withpytest.raises(AssertionError):
view(request)
def test_api_view_incorrect_arguments(self):
"""
If @api_view is missing arguments, we should raise an assertion.
"""
with pytest.raises(AssertionError):
@api_view('GET')
def view(request):
return Response()
request = self.factory.get('/')
self.assertRaises(AssertionError, view, request)
def test_api_view_incorrect_arguments(self):
"""
If @api_view is missing arguments, we should raise an assertion.
"""
with self.assertRaises(AssertionError):
@api_view('GET')
def view(request):
return Response()
def test_calling_method(self):
@api_view(['GET'])
def view(request):
return Response({})
@api_view(['GET'])
def view(request):
return Response({})
request = self.factory.get('/')
response = view(request)
assert response.status_code == status.HTTP_200_OK
request = self.factory.post('/')
response = view(request)
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
response=view(request)
assertresponse.status_code==status.HTTP_200_OK
request=self.factory.post('/')
response=view(request)
assertresponse.status_code==status.HTTP_405_METHOD_NOT_ALLOWED
def test_calling_put_method(self):
@api_view(['GET', 'PUT'])
def view(request):
return Response({})
@api_view(['GET', 'PUT'])
def view(request):
return Response({})
request = self.factory.put('/')
response = view(request)
assert response.status_code == status.HTTP_200_OK
request = self.factory.post('/')
response = view(request)
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
response=view(request)
assertresponse.status_code==status.HTTP_200_OK
request=self.factory.post('/')
response=view(request)
assertresponse.status_code==status.HTTP_405_METHOD_NOT_ALLOWED
def test_calling_patch_method(self):
@api_view(['GET', 'PATCH'])
def view(request):
return Response({})
@api_view(['GET', 'PATCH'])
def view(request):
return Response({})
request = self.factory.patch('/')
response = view(request)
assert response.status_code == status.HTTP_200_OK
request = self.factory.post('/')
response = view(request)
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
response=view(request)
assertresponse.status_code==status.HTTP_200_OK
request=self.factory.post('/')
response=view(request)
assertresponse.status_code==status.HTTP_405_METHOD_NOT_ALLOWED
def test_renderer_classes(self):
@api_view(['GET'])
@renderer_classes([JSONRenderer])
def view(request):
return Response({})
@api_view(['GET'])
@renderer_classes([JSONRenderer])
def view(request):
return Response({})
request = self.factory.get('/')
response = view(request)
assert isinstance(response.accepted_renderer, JSONRenderer)
response=view(request)
assertisinstance(response.accepted_renderer,JSONRenderer)
def test_parser_classes(self):
@api_view(['GET'])
@parser_classes([JSONParser])
def view(request):
assert len(request.parsers) == 1
assert isinstance(request.parsers[0], JSONParser)
return Response({})
@api_view(['GET'])
@parser_classes([JSONParser])
def view(request):
assert len(request.parsers) == 1
assert isinstance(request.parsers[0], JSONParser)
return Response({})
request = self.factory.get('/')
view(request)
view(request)
def test_authentication_classes(self):
@api_view(['GET'])
@authentication_classes([BasicAuthentication])
def view(request):
assert len(request.authenticators) == 1
assert isinstance(request.authenticators[0], BasicAuthentication)
return Response({})
@api_view(['GET'])
@authentication_classes([BasicAuthentication])
def view(request):
assert len(request.authenticators) == 1
assert isinstance(request.authenticators[0], BasicAuthentication)
return Response({})
request = self.factory.get('/')
view(request)
view(request)
def test_permission_classes(self):
@api_view(['GET'])
@permission_classes([IsAuthenticated])
def view(request):
return Response({})
@api_view(['GET'])
@permission_classes([IsAuthenticated])
def view(request):
return Response({})
request = self.factory.get('/')
response = view(request)
assert response.status_code == status.HTTP_403_FORBIDDEN
response=view(request)
assertresponse.status_code==status.HTTP_403_FORBIDDEN
def test_throttle_classes(self):
class OncePerDayUserThrottle(UserRateThrottle):
rate = '1/day'
class OncePerDayUserThrottle(UserRateThrottle):
rate = '1/day'
@api_view(['GET'])
@throttle_classes([OncePerDayUserThrottle])
def view(request):
return Response({})
@throttle_classes([OncePerDayUserThrottle])
defview(request):
return Response({})
request = self.factory.get('/')
response = view(request)
assert response.status_code == status.HTTP_200_OK
response = view(request)
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
response=view(request)
assertresponse.status_code==status.HTTP_200_OK
response=view(request)
assertresponse.status_code==status.HTTP_429_TOO_MANY_REQUESTS
def test_schema(self):
"""
"""
Checks CustomSchema class is set on view
"""
class CustomSchema(AutoSchema):
pass
class CustomSchema(AutoSchema):
pass
@api_view(['GET'])
@schema(CustomSchema())
def view(request):
return Response({})
@schema(CustomSchema())
defview(request):
return Response({})
assert isinstance(view.cls.schema, CustomSchema)
class ActionDecoratorTestCase(TestCase):
def test_defaults(self):
@action(detail=True)
def test_action(request):
"""Description"""
def test_defaults(self):
@action(detail=True)
def test_action(request):
"""Description"""
assert test_action.mapping == {'get': 'test_action'}
assert test_action.detail is True
assert test_action.url_path == 'test_action'
assert test_action.url_name == 'test-action'
assert test_action.kwargs == {
'name': 'Test action',
'description': 'Description',
}
asserttest_action.detailisTrue
asserttest_action.url_path=='test_action'
asserttest_action.url_name=='test-action'
asserttest_action.kwargs=={'name':'Test action','description':'Description',}
def test_detail_required(self):
with pytest.raises(AssertionError) as excinfo:
@action()
def test_action(request):
raise NotImplementedError
with pytest.raises(AssertionError) as excinfo:
@action()
def test_action(request):
raise NotImplementedError
assert str(excinfo.value) == "@action() missing required argument: 'detail'"
def test_method_mapping_http_methods(self):
# All HTTP methods should be mappable
@action(detail=False, methods=[])
def test_action():
raise NotImplementedError
@action(detail=False, methods=[])
def test_action():
raise NotImplementedError
for name in APIView.http_method_names:
def method():
raise NotImplementedError
def method():
raise NotImplementedError
method.__name__ = name
getattr(test_action.mapping, name)(method)
getattr(test_action.mapping,name)(method)
# ensure the mapping returns the correct method name
for name in APIView.http_method_names:
assert test_action.mapping[name] == name
assert test_action.mapping[name] == name
def test_view_name_kwargs(self):
"""
"""
'name' and 'suffix' are mutually exclusive kwargs used for generating
a view's display name.
"""
# by default, generate name from method
@action(detail=True)
def test_action(request):
raise NotImplementedError
@action(detail=True)
def test_action(request):
raise NotImplementedError
assert test_action.kwargs == {
'description': None,
'name': 'Test action',
}
assert test_action.kwargs == {'description':None,'name':'Test action',}
@action(detail=True,name='test name')
deftest_action(request):
raise NotImplementedError
# name kwarg supersedes name generation
@action(detail=True, name='test name')
def test_action(request):
raise NotImplementedError
assert test_action.kwargs == {'description':None,'name':'test name',}
@action(detail=True,suffix='Suffix')
deftest_action(request):
raise NotImplementedError
assert test_action.kwargs == {
'description': None,
'name': 'test name',
}
# suffix kwarg supersedes name generation
@action(detail=True, suffix='Suffix')
def test_action(request):
raise NotImplementedError
assert test_action.kwargs == {
'description': None,
'suffix': 'Suffix',
}
# name + suffix is a conflict.
with pytest.raises(TypeError) as excinfo:
action(detail=True, name='test name', suffix='Suffix')
assert test_action.kwargs == {'description':None,'suffix':'Suffix',}
withpytest.raises(TypeError)asexcinfo:
action(detail=True, name='test name', suffix='Suffix')
assert str(excinfo.value) == "`name` and `suffix` are mutually exclusive arguments."
def test_method_mapping(self):
@action(detail=False)
def test_action(request):
raise NotImplementedError
@action(detail=False)
def test_action(request):
raise NotImplementedError
@test_action.mapping.post
def test_action_post(request):
raise NotImplementedError
deftest_action_post(request):
raise NotImplementedError
# The secondary handler methods should not have the action attributes
for name in ['mapping', 'detail', 'url_path', 'url_name', 'kwargs']:
assert hasattr(test_action, name) and not hasattr(test_action_post, name)
assert hasattr(test_action, name) and not hasattr(test_action_post, name)
def test_method_mapping_already_mapped(self):
@action(detail=True)
def test_action(request):
raise NotImplementedError
@action(detail=True)
def test_action(request):
raise NotImplementedError
msg = "Method 'get' has already been mapped to '.test_action'."
with self.assertRaisesMessage(AssertionError, msg):
@test_action.mapping.get
def test_action_get(request):
raise NotImplementedError
withself.assertRaisesMessage(AssertionError,msg):
@test_action.mapping.get
def test_action_get(request):
raise NotImplementedError
def test_method_mapping_overwrite(self):
@action(detail=True)
@action(detail=True)
def test_action():
raise NotImplementedError
msg = ("Method mapping does not behave like the property decorator. You ""cannot use the same method name for each mapping declaration.")
withself.assertRaisesMessage(AssertionError,msg):
@test_action.mapping.post
def test_action():
raise NotImplementedError
msg = ("Method mapping does not behave like the property decorator. You "
"cannot use the same method name for each mapping declaration.")
with self.assertRaisesMessage(AssertionError, msg):
@test_action.mapping.post
def test_action():
raise NotImplementedError
def test_detail_route_deprecation(self):
with pytest.warns(RemovedInDRF310Warning) as record:
@detail_route()
def view(request):
raise NotImplementedError
with pytest.warns(RemovedInDRF310Warning) as record:
@detail_route()
def view(request):
raise NotImplementedError
assert len(record) == 1
assert str(record[0].message) == (
"`detail_route` is deprecated and will be removed in "
"3.10 in favor of `action`, which accepts a `detail` bool. Use "
"`@action(detail=True)` instead."
)
assertstr(record[0].message)==("`detail_route` is deprecated and will be removed in ""3.10 in favor of `action`, which accepts a `detail` bool. Use ""`@action(detail=True)` instead.")
def test_list_route_deprecation(self):
with pytest.warns(RemovedInDRF310Warning) as record:
@list_route()
def view(request):
raise NotImplementedError
with pytest.warns(RemovedInDRF310Warning) as record:
@list_route()
def view(request):
raise NotImplementedError
assert len(record) == 1
assert str(record[0].message) == (
"`list_route` is deprecated and will be removed in "
"3.10 in favor of `action`, which accepts a `detail` bool. Use "
"`@action(detail=False)` instead."
)
assertstr(record[0].message)==("`list_route` is deprecated and will be removed in ""3.10 in favor of `action`, which accepts a `detail` bool. Use ""`@action(detail=False)` instead.")
def test_route_url_name_from_path(self):
# pre-3.8 behavior was to base the `url_name` off of the `url_path`
with pytest.warns(RemovedInDRF310Warning):
@list_route(url_path='foo_bar')
def view(request):
raise NotImplementedError
with pytest.warns(RemovedInDRF310Warning):
@list_route(url_path='foo_bar')
def view(request):
raise NotImplementedError
assert view.url_path == 'foo_bar'
assert view.url_name == 'foo-bar'
assertview.url_name=='foo-bar'

View File

@ -69,37 +69,34 @@ MARKED_DOWN_gte_21 = """<h2 id="an-example-docstring">an example docstring</h2>
</code></pre>
<p>indented</p>
<h2 id="hash-style-header">hash style header</h2>%s"""
class TestViewNamesAndDescriptions(TestCase):
def test_view_name_uses_class_name(self):
"""
def test_view_name_uses_class_name(self):
"""
Ensure view names are based on the class name.
"""
class MockView(APIView):
pass
class MockView(APIView):
pass
assert MockView().get_view_name() == 'Mock'
def test_view_name_uses_name_attribute(self):
class MockView(APIView):
name = 'Foo'
class MockView(APIView):
name = 'Foo'
assert MockView().get_view_name() == 'Foo'
def test_view_name_uses_suffix_attribute(self):
class MockView(APIView):
suffix = 'List'
class MockView(APIView):
suffix = 'List'
assert MockView().get_view_name() == 'Mock List'
def test_view_name_preferences_name_over_suffix(self):
class MockView(APIView):
name = 'Foo'
suffix = 'List'
class MockView(APIView):
name = 'Foo'
suffix = 'List'
assert MockView().get_view_name() == 'Foo'
def test_view_description_uses_docstring(self):
"""Ensure view descriptions are based on the docstring."""
class MockView(APIView):
"""an example docstring
"""Ensure view descriptions are based on the docstring."""
class MockView(APIView):
"""an example docstring
====================
* list
@ -124,64 +121,53 @@ class TestViewNamesAndDescriptions(TestCase):
assert MockView().get_view_description() == DESCRIPTION
def test_view_description_uses_description_attribute(self):
class MockView(APIView):
description = 'Foo'
class MockView(APIView):
description = 'Foo'
assert MockView().get_view_description() == 'Foo'
def test_view_description_allows_empty_description(self):
class MockView(APIView):
"""Description."""
description = ''
class MockView(APIView):
"""Description."""
description = ''
assert MockView().get_view_description() == ''
def test_view_description_can_be_empty(self):
"""
"""
Ensure that if a view has no docstring,
then it's description is the empty string.
"""
class MockView(APIView):
pass
class MockView(APIView):
pass
assert MockView().get_view_description() == ''
def test_view_description_can_be_promise(self):
"""
"""
Ensure a view may have a docstring that is actually a lazily evaluated
class that can be converted to a string.
See: https://github.com/encode/django-rest-framework/issues/1708
"""
# use a mock object instead of gettext_lazy to ensure that we can't end
# up with a test case string in our l10n catalog
class MockLazyStr:
def __init__(self, string):
self.s = string
class MockLazyStr:
def __init__(self, string):
self.s = string
def __str__(self):
return self.s
return self.s
class MockView(APIView):
__doc__ = MockLazyStr("a gettext string")
__doc__ = MockLazyStr("a gettext string")
assert MockView().get_view_description() == 'a gettext string'
def test_markdown(self):
"""
"""
Ensure markdown to HTML works as expected.
"""
if apply_markdown:
md_applied = apply_markdown(DESCRIPTION)
gte_21_match = (
md_applied == (
MARKED_DOWN_gte_21 % MARKED_DOWN_HILITE) or
md_applied == (
MARKED_DOWN_gte_21 % MARKED_DOWN_NOT_HILITE))
lt_21_match = (
md_applied == (
MARKED_DOWN_lt_21 % MARKED_DOWN_HILITE) or
md_applied == (
MARKED_DOWN_lt_21 % MARKED_DOWN_NOT_HILITE))
assert gte_21_match or lt_21_match
if apply_markdown:
md_applied = apply_markdown(DESCRIPTION)
gte_21_match = ( md_applied == ( MARKED_DOWN_gte_21 % MARKED_DOWN_HILITE) or md_applied == ( MARKED_DOWN_gte_21 % MARKED_DOWN_NOT_HILITE))
lt_21_match = ( md_applied == ( MARKED_DOWN_lt_21 % MARKED_DOWN_HILITE) or md_applied == ( MARKED_DOWN_lt_21 % MARKED_DOWN_NOT_HILITE))
assert gte_21_match or lt_21_match
def test_dedent_tabs():

View File

@ -15,81 +15,79 @@ class MockList:
return [1, 2, 3]
class JSONEncoderTests(TestCase):
"""
"""
Tests the JSONEncoder method
"""
def setUp(self):
self.encoder = JSONEncoder()
defsetUp(self):
self.encoder = JSONEncoder()
def test_encode_decimal(self):
"""
"""
Tests encoding a decimal
"""
d = Decimal(3.14)
assert self.encoder.default(d) == float(d)
d = Decimal(3.14)
assert self.encoder.default(d) == float(d)
def test_encode_datetime(self):
"""
"""
Tests encoding a datetime object
"""
current_time = datetime.now()
assert self.encoder.default(current_time) == current_time.isoformat()
current_time_utc = current_time.replace(tzinfo=utc)
assert self.encoder.default(current_time_utc) == current_time.isoformat() + 'Z'
current_time = datetime.now()
assert self.encoder.default(current_time) == current_time.isoformat()
current_time_utc = current_time.replace(tzinfo=utc)
assert self.encoder.default(current_time_utc) == current_time.isoformat() + 'Z'
def test_encode_time(self):
"""
"""
Tests encoding a timezone
"""
current_time = datetime.now().time()
assert self.encoder.default(current_time) == current_time.isoformat()
current_time = datetime.now().time()
assert self.encoder.default(current_time) == current_time.isoformat()
def test_encode_time_tz(self):
"""
"""
Tests encoding a timezone aware timestamp
"""
current_time = datetime.now().time()
current_time = current_time.replace(tzinfo=utc)
with pytest.raises(ValueError):
self.encoder.default(current_time)
current_time = datetime.now().time()
current_time = current_time.replace(tzinfo=utc)
with pytest.raises(ValueError):
self.encoder.default(current_time)
def test_encode_date(self):
"""
"""
Tests encoding a date object
"""
current_date = date.today()
assert self.encoder.default(current_date) == current_date.isoformat()
current_date = date.today()
assert self.encoder.default(current_date) == current_date.isoformat()
def test_encode_timedelta(self):
"""
"""
Tests encoding a timedelta object
"""
delta = timedelta(hours=1)
assert self.encoder.default(delta) == str(delta.total_seconds())
delta = timedelta(hours=1)
assert self.encoder.default(delta) == str(delta.total_seconds())
def test_encode_uuid(self):
"""
"""
Tests encoding a UUID object
"""
unique_id = uuid4()
assert self.encoder.default(unique_id) == str(unique_id)
unique_id = uuid4()
assert self.encoder.default(unique_id) == str(unique_id)
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
def test_encode_coreapi_raises_error(self):
"""
deftest_encode_coreapi_raises_error(self):
"""
Tests encoding a coreapi objects raises proper error
"""
with pytest.raises(RuntimeError):
self.encoder.default(coreapi.Document())
with pytest.raises(RuntimeError):
self.encoder.default(coreapi.Document())
with pytest.raises(RuntimeError):
self.encoder.default(coreapi.Error())
self.encoder.default(coreapi.Error())
def test_encode_object_with_tolist(self):
"""
"""
Tests encoding a object with tolist method
"""
foo = MockList()
assert self.encoder.default(foo) == [1, 2, 3]
foo = MockList()
assert self.encoder.default(foo) == [1, 2, 3]

View File

@ -7,89 +7,58 @@ from rest_framework.exceptions import (
server_error
)
def test_get_error_details(self):
class ExceptionTestCase(TestCase):
def test_get_error_details(self):
example = "string"
lazy_example = _(example)
assert _get_error_details(lazy_example) == example
assert isinstance(
_get_error_details(lazy_example),
ErrorDetail
)
assert _get_error_details({'nested': lazy_example})['nested'] == example
assert isinstance(
_get_error_details({'nested': lazy_example})['nested'],
ErrorDetail
)
assert _get_error_details([[lazy_example]])[0][0] == example
assert isinstance(
_get_error_details([[lazy_example]])[0][0],
ErrorDetail
)
example = "string"
lazy_example = _(example)
assert _get_error_details(lazy_example) == example
assert isinstance( _get_error_details(lazy_example), ErrorDetail )
assert _get_error_details({'nested': lazy_example})['nested'] == example
assert isinstance( _get_error_details({'nested': lazy_example})['nested'], ErrorDetail )
assert _get_error_details([[lazy_example]])[0][0] == example
assert isinstance( _get_error_details([[lazy_example]])[0][0], ErrorDetail )
def test_get_full_details_with_throttling(self):
exception = Throttled()
assert exception.get_full_details() == {
'message': 'Request was throttled.', 'code': 'throttled'}
exception = Throttled(wait=2)
assert exception.get_full_details() == {
'message': 'Request was throttled. Expected available in {} seconds.'.format(2),
'code': 'throttled'}
exception = Throttled(wait=2, detail='Slow down!')
assert exception.get_full_details() == {
'message': 'Slow down! Expected available in {} seconds.'.format(2),
'code': 'throttled'}
exception = Throttled()
assert exception.get_full_details() == { 'message': 'Request was throttled.', 'code': 'throttled'}
exception = Throttled(wait=2)
assert exception.get_full_details() == { 'message': 'Request was throttled. Expected available in {} seconds.'.format(2), 'code': 'throttled'}
exception = Throttled(wait=2, detail='Slow down!')
assert exception.get_full_details() == { 'message': 'Slow down! Expected available in {} seconds.'.format(2), 'code': 'throttled'}
class ErrorDetailTests(TestCase):
def test_eq(self):
assert ErrorDetail('msg') == ErrorDetail('msg')
assert ErrorDetail('msg', 'code') == ErrorDetail('msg', code='code')
assert ErrorDetail('msg') == 'msg'
assert ErrorDetail('msg', 'code') == 'msg'
def test_eq(self):
assert ErrorDetail('msg') == ErrorDetail('msg')
assert ErrorDetail('msg', 'code') == ErrorDetail('msg', code='code')
assert ErrorDetail('msg') == 'msg'
assert ErrorDetail('msg', 'code') == 'msg'
def test_ne(self):
assert ErrorDetail('msg1') != ErrorDetail('msg2')
assert ErrorDetail('msg') != ErrorDetail('msg', code='invalid')
assert ErrorDetail('msg1') != 'msg2'
assert ErrorDetail('msg1', 'code') != 'msg2'
assert ErrorDetail('msg1') != ErrorDetail('msg2')
assert ErrorDetail('msg') != ErrorDetail('msg', code='invalid')
assert ErrorDetail('msg1') != 'msg2'
assert ErrorDetail('msg1', 'code') != 'msg2'
def test_repr(self):
assert repr(ErrorDetail('msg1')) == \
'ErrorDetail(string={!r}, code=None)'.format('msg1')
assert repr(ErrorDetail('msg1', 'code')) == \
'ErrorDetail(string={!r}, code={!r})'.format('msg1', 'code')
assert repr(ErrorDetail('msg1')) == 'ErrorDetail(string={!r}, code=None)'.format('msg1')
assert repr(ErrorDetail('msg1', 'code')) == 'ErrorDetail(string={!r}, code={!r})'.format('msg1', 'code')
def test_str(self):
assert str(ErrorDetail('msg1')) == 'msg1'
assert str(ErrorDetail('msg1', 'code')) == 'msg1'
assert str(ErrorDetail('msg1')) == 'msg1'
assert str(ErrorDetail('msg1', 'code')) == 'msg1'
def test_hash(self):
assert hash(ErrorDetail('msg')) == hash('msg')
assert hash(ErrorDetail('msg', 'code')) == hash('msg')
assert hash(ErrorDetail('msg')) == hash('msg')
assert hash(ErrorDetail('msg', 'code')) == hash('msg')
class TranslationTests(TestCase):
@translation.override('fr')
def test_message(self):
@translation.override('fr')
deftest_message(self):
# this test largely acts as a sanity test to ensure the translation files are present.
self.assertEqual(_('A server error occurred.'), 'Une erreur du serveur est survenue.')
self.assertEqual(str(APIException()), 'Une erreur du serveur est survenue.')
assert _('A server error occurred.') == 'Une erreur du serveur est survenue.'
assert str(APIException()) == 'Une erreur du serveur est survenue.'
def test_server_error():

View File

@ -1134,40 +1134,38 @@ class TestNoStringCoercionDecimalField(FieldValues):
)
class TestLocalizedDecimalField(TestCase):
@override_settings(USE_L10N=True, LANGUAGE_CODE='pl')
def test_to_internal_value(self):
field = serializers.DecimalField(max_digits=2, decimal_places=1, localize=True)
assert field.to_internal_value('1,1') == Decimal('1.1')
@override_settings(USE_L10N=True, LANGUAGE_CODE='pl')
deftest_to_internal_value(self):
field = serializers.DecimalField(max_digits=2, decimal_places=1, localize=True)
assert field.to_internal_value('1,1') == Decimal('1.1')
@override_settings(USE_L10N=True, LANGUAGE_CODE='pl')
def test_to_representation(self):
field = serializers.DecimalField(max_digits=2, decimal_places=1, localize=True)
assert field.to_representation(Decimal('1.1')) == '1,1'
deftest_to_representation(self):
field = serializers.DecimalField(max_digits=2, decimal_places=1, localize=True)
assert field.to_representation(Decimal('1.1')) == '1,1'
def test_localize_forces_coerce_to_string(self):
field = serializers.DecimalField(max_digits=2, decimal_places=1, coerce_to_string=False, localize=True)
assert isinstance(field.to_representation(Decimal('1.1')), str)
field = serializers.DecimalField(max_digits=2, decimal_places=1, coerce_to_string=False, localize=True)
assert isinstance(field.to_representation(Decimal('1.1')), str)
class TestQuantizedValueForDecimal(TestCase):
def test_int_quantized_value_for_decimal(self):
field = serializers.DecimalField(max_digits=4, decimal_places=2)
value = field.to_internal_value(12).as_tuple()
expected_digit_tuple = (0, (1, 2, 0, 0), -2)
assert value == expected_digit_tuple
def test_int_quantized_value_for_decimal(self):
field = serializers.DecimalField(max_digits=4, decimal_places=2)
value = field.to_internal_value(12).as_tuple()
expected_digit_tuple = (0, (1, 2, 0, 0), -2)
assert value == expected_digit_tuple
def test_string_quantized_value_for_decimal(self):
field = serializers.DecimalField(max_digits=4, decimal_places=2)
value = field.to_internal_value('12').as_tuple()
expected_digit_tuple = (0, (1, 2, 0, 0), -2)
assert value == expected_digit_tuple
field = serializers.DecimalField(max_digits=4, decimal_places=2)
value = field.to_internal_value('12').as_tuple()
expected_digit_tuple = (0, (1, 2, 0, 0), -2)
assert value == expected_digit_tuple
def test_part_precision_string_quantized_value_for_decimal(self):
field = serializers.DecimalField(max_digits=4, decimal_places=2)
value = field.to_internal_value('12.0').as_tuple()
expected_digit_tuple = (0, (1, 2, 0, 0), -2)
assert value == expected_digit_tuple
field = serializers.DecimalField(max_digits=4, decimal_places=2)
value = field.to_internal_value('12.0').as_tuple()
expected_digit_tuple = (0, (1, 2, 0, 0), -2)
assert value == expected_digit_tuple
class TestNoDecimalPlaces(FieldValues):
@ -1185,17 +1183,15 @@ class TestNoDecimalPlaces(FieldValues):
field = serializers.DecimalField(max_digits=6, decimal_places=None)
class TestRoundingDecimalField(TestCase):
def test_valid_rounding(self):
field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding=ROUND_UP)
assert field.to_representation(Decimal('1.234')) == '1.24'
field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding=ROUND_DOWN)
assert field.to_representation(Decimal('1.234')) == '1.23'
def test_valid_rounding(self):
field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding=ROUND_UP)
assert field.to_representation(Decimal('1.234')) == '1.24'
field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding=ROUND_DOWN)
assert field.to_representation(Decimal('1.234')) == '1.23'
def test_invalid_rounding(self):
with pytest.raises(AssertionError) as excinfo:
serializers.DecimalField(max_digits=1, decimal_places=1, rounding='ROUND_UNKNOWN')
with pytest.raises(AssertionError) as excinfo:
serializers.DecimalField(max_digits=1, decimal_places=1, rounding='ROUND_UNKNOWN')
assert 'Invalid rounding option' in str(excinfo.value)
@ -1369,46 +1365,41 @@ class TestTZWithDateTimeField(FieldValues):
@override_settings(TIME_ZONE='UTC', USE_TZ=True)
class TestDefaultTZDateTimeField(TestCase):
"""
"""
Test the current/default timezone handling in `DateTimeField`.
"""
@classmethod
def setup_class(cls):
cls.field = serializers.DateTimeField()
cls.kolkata = pytz.timezone('Asia/Kolkata')
@classmethod
defsetup_class(cls):
cls.field = serializers.DateTimeField()
cls.kolkata = pytz.timezone('Asia/Kolkata')
def test_default_timezone(self):
assert self.field.default_timezone() == utc
assert self.field.default_timezone() == utc
def test_current_timezone(self):
assert self.field.default_timezone() == utc
activate(self.kolkata)
assert self.field.default_timezone() == self.kolkata
deactivate()
assert self.field.default_timezone() == utc
assert self.field.default_timezone() == utc
activate(self.kolkata)
assert self.field.default_timezone() == self.kolkata
deactivate()
assert self.field.default_timezone() == utc
@pytest.mark.skipif(pytz is None, reason='pytz not installed')
@override_settings(TIME_ZONE='UTC', USE_TZ=True)
class TestCustomTimezoneForDateTimeField(TestCase):
@classmethod
def setup_class(cls):
cls.kolkata = pytz.timezone('Asia/Kolkata')
cls.date_format = '%d/%m/%Y %H:%M'
@classmethod
defsetup_class(cls):
cls.kolkata = pytz.timezone('Asia/Kolkata')
cls.date_format = '%d/%m/%Y %H:%M'
def test_should_render_date_time_in_default_timezone(self):
field = serializers.DateTimeField(default_timezone=self.kolkata, format=self.date_format)
dt = datetime.datetime(2018, 2, 8, 14, 15, 16, tzinfo=pytz.utc)
with override(self.kolkata):
rendered_date = field.to_representation(dt)
field = serializers.DateTimeField(default_timezone=self.kolkata, format=self.date_format)
dt = datetime.datetime(2018, 2, 8, 14, 15, 16, tzinfo=pytz.utc)
with override(self.kolkata):
rendered_date = field.to_representation(dt)
rendered_date_in_timezone = dt.astimezone(self.kolkata).strftime(self.date_format)
assert rendered_date == rendered_date_in_timezone
assertrendered_date==rendered_date_in_timezone
class TestNaiveDayLightSavingTimeTimeZoneDateTimeField(FieldValues):

View File

@ -13,28 +13,25 @@ from rest_framework.compat import coreschema
from rest_framework.test import APIRequestFactory
factory = APIRequestFactory()
class BaseFilterTests(TestCase):
def setUp(self):
self.original_coreapi = filters.coreapi
filters.coreapi = True # mock it, because not None value needed
self.filter_backend = filters.BaseFilterBackend()
def setUp(self):
self.original_coreapi = filters.coreapi
filters.coreapi = True
self.filter_backend = filters.BaseFilterBackend()
def tearDown(self):
filters.coreapi = self.original_coreapi
filters.coreapi = self.original_coreapi
def test_filter_queryset_raises_error(self):
with pytest.raises(NotImplementedError):
self.filter_backend.filter_queryset(None, None, None)
with pytest.raises(NotImplementedError):
self.filter_backend.filter_queryset(None, None, None)
@pytest.mark.skipif(not coreschema, reason='coreschema is not installed')
def test_get_schema_fields_checks_for_coreapi(self):
filters.coreapi = None
with pytest.raises(AssertionError):
self.filter_backend.get_schema_fields({})
deftest_get_schema_fields_checks_for_coreapi(self):
filters.coreapi = None
with pytest.raises(AssertionError):
self.filter_backend.get_schema_fields({})
filters.coreapi = True
assert self.filter_backend.get_schema_fields({}) == []
assertself.filter_backend.get_schema_fields({})==[]
class SearchFilterModel(models.Model):
@ -48,137 +45,115 @@ class SearchFilterSerializer(serializers.ModelSerializer):
fields = '__all__'
class SearchFilterTests(TestCase):
def setUp(self):
def setUp(self):
# Sequence of title/text is:
#
# z abc
# zz bcd
# zzz cde
# ...
for idx in range(10):
title = 'z' * (idx + 1)
text = (
chr(idx + ord('a')) +
chr(idx + ord('b')) +
chr(idx + ord('c'))
)
SearchFilterModel(title=title, text=text).save()
for idx in range(10):
title = 'z' * (idx + 1)
text = ( chr(idx + ord('a')) + chr(idx + ord('b')) + chr(idx + ord('c')) )
SearchFilterModel(title=title, text=text).save()
def test_search(self):
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('title', 'text')
view = SearchListView.as_view()
request=factory.get('/',{'search':'b'})
response=view(request)
assertresponse.data==[{'id':1,'title':'z','text':'abc'},{'id':2,'title':'zz','text':'bcd'}]
def test_search_returns_same_queryset_if_no_search_fields_or_terms_provided(self):
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
view = SearchListView.as_view()
request=factory.get('/')
response=view(request)
expected=SearchFilterSerializer(SearchFilterModel.objects.all(),many=True).data
assertresponse.data==expected
def test_exact_search(self):
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('=title', 'text')
view = SearchListView.as_view()
request=factory.get('/',{'search':'zzz'})
response=view(request)
assertresponse.data==[{'id':3,'title':'zzz','text':'cde'}]
def test_startswith_search(self):
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('title', '^text')
view = SearchListView.as_view()
request=factory.get('/',{'search':'b'})
response=view(request)
assertresponse.data==[{'id':2,'title':'zz','text':'bcd'}]
def test_regexp_search(self):
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('$title', '$text')
view = SearchListView.as_view()
request=factory.get('/',{'search':'z{2} ^b'})
response=view(request)
assertresponse.data==[{'id':2,'title':'zz','text':'bcd'}]
def test_search_with_nonstandard_search_param(self):
with override_settings(REST_FRAMEWORK={'SEARCH_PARAM': 'query'}):
reload_module(filters)
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('title', 'text')
view = SearchListView.as_view()
request = factory.get('/', {'search': 'b'})
response = view(request)
assert response.data == [
{'id': 1, 'title': 'z', 'text': 'abc'},
{'id': 2, 'title': 'zz', 'text': 'bcd'}
]
def test_search_returns_same_queryset_if_no_search_fields_or_terms_provided(self):
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
view = SearchListView.as_view()
request = factory.get('/')
response = view(request)
expected = SearchFilterSerializer(SearchFilterModel.objects.all(),
many=True).data
assert response.data == expected
def test_exact_search(self):
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('=title', 'text')
view = SearchListView.as_view()
request = factory.get('/', {'search': 'zzz'})
response = view(request)
assert response.data == [
{'id': 3, 'title': 'zzz', 'text': 'cde'}
]
def test_startswith_search(self):
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('title', '^text')
view = SearchListView.as_view()
request = factory.get('/', {'search': 'b'})
response = view(request)
assert response.data == [
{'id': 2, 'title': 'zz', 'text': 'bcd'}
]
def test_regexp_search(self):
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('$title', '$text')
view = SearchListView.as_view()
request = factory.get('/', {'search': 'z{2} ^b'})
response = view(request)
assert response.data == [
{'id': 2, 'title': 'zz', 'text': 'bcd'}
]
def test_search_with_nonstandard_search_param(self):
with override_settings(REST_FRAMEWORK={'SEARCH_PARAM': 'query'}):
reload_module(filters)
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('title', 'text')
view = SearchListView.as_view()
request = factory.get('/', {'query': 'b'})
response = view(request)
assert response.data == [
{'id': 1, 'title': 'z', 'text': 'abc'},
{'id': 2, 'title': 'zz', 'text': 'bcd'}
]
request=factory.get('/',{'query':'b'})
response=view(request)
assertresponse.data==[{'id':1,'title':'z','text':'abc'},{'id':2,'title':'zz','text':'bcd'}]
reload_module(filters)
def test_search_with_filter_subclass(self):
class CustomSearchFilter(filters.SearchFilter):
class CustomSearchFilter(filters.SearchFilter):
# Filter that dynamically changes search fields
def get_search_fields(self, view, request):
if request.query_params.get('title_only'):
return ('$title',)
def get_search_fields(self, view, request):
if request.query_params.get('title_only'):
return ('$title',)
return super().get_search_fields(view, request)
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (CustomSearchFilter,)
search_fields = ('$title', '$text')
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (CustomSearchFilter,)
search_fields = ('$title', '$text')
view = SearchListView.as_view()
request = factory.get('/', {'search': r'^\w{3}$'})
response = view(request)
assert len(response.data) == 10
request = factory.get('/', {'search': r'^\w{3}$', 'title_only': 'true'})
response = view(request)
assert response.data == [
{'id': 3, 'title': 'zzz', 'text': 'cde'}
]
request=factory.get('/',{'search':r'^\w{3}$'})
response=view(request)
assertlen(response.data)==10
request=factory.get('/',{'search':r'^\w{3}$','title_only':'true'})
response=view(request)
assertresponse.data==[{'id':3,'title':'zzz','text':'cde'}]
class AttributeModel(models.Model):
@ -196,31 +171,21 @@ class SearchFilterFkSerializer(serializers.ModelSerializer):
fields = '__all__'
class SearchFilterFkTests(TestCase):
def test_must_call_distinct(self):
filter_ = filters.SearchFilter()
prefixes = [''] + list(filter_.lookup_prefixes)
for prefix in prefixes:
assert not filter_.must_call_distinct(
SearchFilterModelFk._meta,
["%stitle" % prefix]
)
assert not filter_.must_call_distinct(
SearchFilterModelFk._meta,
["%stitle" % prefix, "%sattribute__label" % prefix]
)
def test_must_call_distinct(self):
filter_ = filters.SearchFilter()
prefixes = [''] + list(filter_.lookup_prefixes)
for prefix in prefixes:
assert not filter_.must_call_distinct( SearchFilterModelFk._meta, ["%stitle" % prefix] )
assert not filter_.must_call_distinct( SearchFilterModelFk._meta, ["%stitle" % prefix, "%sattribute__label" % prefix] )
def test_must_call_distinct_restores_meta_for_each_field(self):
# In this test case the attribute of the fk model comes first in the
# list of search fields.
filter_ = filters.SearchFilter()
prefixes = [''] + list(filter_.lookup_prefixes)
for prefix in prefixes:
assert not filter_.must_call_distinct(
SearchFilterModelFk._meta,
["%sattribute__label" % prefix, "%stitle" % prefix]
)
filter_ = filters.SearchFilter()
prefixes = [''] + list(filter_.lookup_prefixes)
for prefix in prefixes:
assert not filter_.must_call_distinct( SearchFilterModelFk._meta, ["%sattribute__label" % prefix, "%stitle" % prefix] )
class SearchFilterModelM2M(models.Model):
@ -235,53 +200,41 @@ class SearchFilterM2MSerializer(serializers.ModelSerializer):
fields = '__all__'
class SearchFilterM2MTests(TestCase):
def setUp(self):
def setUp(self):
# Sequence of title/text/attributes is:
#
# z abc [1, 2, 3]
# zz bcd [1, 2, 3]
# zzz cde [1, 2, 3]
# ...
for idx in range(3):
label = 'w' * (idx + 1)
AttributeModel.objects.create(label=label)
for idx in range(3):
label = 'w' * (idx + 1)
AttributeModel.objects.create(label=label)
for idx in range(10):
title = 'z' * (idx + 1)
text = (
chr(idx + ord('a')) +
chr(idx + ord('b')) +
chr(idx + ord('c'))
)
SearchFilterModelM2M(title=title, text=text).save()
title = 'z' * (idx + 1)
text = ( chr(idx + ord('a')) + chr(idx + ord('b')) + chr(idx + ord('c')) )
SearchFilterModelM2M(title=title, text=text).save()
SearchFilterModelM2M.objects.get(title='zz').attributes.add(1, 2, 3)
def test_m2m_search(self):
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModelM2M.objects.all()
serializer_class = SearchFilterM2MSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('=title', 'text', 'attributes__label')
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModelM2M.objects.all()
serializer_class = SearchFilterM2MSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('=title', 'text', 'attributes__label')
view = SearchListView.as_view()
request = factory.get('/', {'search': 'zz'})
response = view(request)
assert len(response.data) == 1
request=factory.get('/',{'search':'zz'})
response=view(request)
assertlen(response.data)==1
def test_must_call_distinct(self):
filter_ = filters.SearchFilter()
prefixes = [''] + list(filter_.lookup_prefixes)
for prefix in prefixes:
assert not filter_.must_call_distinct(
SearchFilterModelM2M._meta,
["%stitle" % prefix]
)
assert filter_.must_call_distinct(
SearchFilterModelM2M._meta,
["%stitle" % prefix, "%sattributes__label" % prefix]
)
filter_ = filters.SearchFilter()
prefixes = [''] + list(filter_.lookup_prefixes)
for prefix in prefixes:
assert not filter_.must_call_distinct( SearchFilterModelM2M._meta, ["%stitle" % prefix] )
assert filter_.must_call_distinct( SearchFilterModelM2M._meta, ["%stitle" % prefix, "%sattributes__label" % prefix] )
class Blog(models.Model):
@ -300,32 +253,27 @@ class BlogSerializer(serializers.ModelSerializer):
fields = '__all__'
class SearchFilterToManyTests(TestCase):
@classmethod
def setUpTestData(cls):
b1 = Blog.objects.create(name='Blog 1')
b2 = Blog.objects.create(name='Blog 2')
# Multiple entries on Lennon published in 1979 - distinct should deduplicate
Entry.objects.create(blog=b1, headline='Something about Lennon', pub_date=datetime.date(1979, 1, 1))
Entry.objects.create(blog=b1, headline='Another thing about Lennon', pub_date=datetime.date(1979, 6, 1))
# Entry on Lennon *and* a separate entry in 1979 - should not match
Entry.objects.create(blog=b2, headline='Something unrelated', pub_date=datetime.date(1979, 1, 1))
Entry.objects.create(blog=b2, headline='Retrospective on Lennon', pub_date=datetime.date(1990, 6, 1))
@classmethod
defsetUpTestData(cls):
b1 = Blog.objects.create(name='Blog 1')
b2 = Blog.objects.create(name='Blog 2')
Entry.objects.create(blog=b1, headline='Something about Lennon', pub_date=datetime.date(1979, 1, 1))
Entry.objects.create(blog=b1, headline='Another thing about Lennon', pub_date=datetime.date(1979, 6, 1))
Entry.objects.create(blog=b2, headline='Something unrelated', pub_date=datetime.date(1979, 1, 1))
Entry.objects.create(blog=b2, headline='Retrospective on Lennon', pub_date=datetime.date(1990, 6, 1))
def test_multiple_filter_conditions(self):
class SearchListView(generics.ListAPIView):
queryset = Blog.objects.all()
serializer_class = BlogSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('=name', 'entry__headline', '=entry__pub_date__year')
class SearchListView(generics.ListAPIView):
queryset = Blog.objects.all()
serializer_class = BlogSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('=name', 'entry__headline', '=entry__pub_date__year')
view = SearchListView.as_view()
request = factory.get('/', {'search': 'Lennon,1979'})
response = view(request)
assert len(response.data) == 1
request=factory.get('/',{'search':'Lennon,1979'})
response=view(request)
assertlen(response.data)==1
class SearchFilterAnnotatedSerializer(serializers.ModelSerializer):
@ -336,28 +284,23 @@ class SearchFilterAnnotatedSerializer(serializers.ModelSerializer):
fields = ('title', 'text', 'title_text')
class SearchFilterAnnotatedFieldTests(TestCase):
@classmethod
def setUpTestData(cls):
SearchFilterModel.objects.create(title='abc', text='def')
SearchFilterModel.objects.create(title='ghi', text='jkl')
@classmethod
defsetUpTestData(cls):
SearchFilterModel.objects.create(title='abc', text='def')
SearchFilterModel.objects.create(title='ghi', text='jkl')
def test_search_in_annotated_field(self):
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.annotate(
title_text=Upper(
Concat(models.F('title'), models.F('text'))
)
).all()
serializer_class = SearchFilterAnnotatedSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('title_text',)
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.annotate( title_text=Upper( Concat(models.F('title'), models.F('text')) ) ).all()
serializer_class = SearchFilterAnnotatedSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('title_text',)
view = SearchListView.as_view()
request = factory.get('/', {'search': 'ABCDEF'})
response = view(request)
assert len(response.data) == 1
assert response.data[0]['title_text'] == 'ABCDEF'
request=factory.get('/',{'search':'ABCDEF'})
response=view(request)
assertlen(response.data)==1
assertresponse.data[0]['title_text']=='ABCDEF'
class OrderingFilterModel(models.Model):
@ -403,253 +346,187 @@ class DjangoFilterOrderingSerializer(serializers.ModelSerializer):
fields = '__all__'
class OrderingFilterTests(TestCase):
def setUp(self):
def setUp(self):
# Sequence of title/text is:
#
# zyx abc
# yxw bcd
# xwv cde
for idx in range(3):
title = (
chr(ord('z') - idx) +
chr(ord('y') - idx) +
chr(ord('x') - idx)
)
text = (
chr(idx + ord('a')) +
chr(idx + ord('b')) +
chr(idx + ord('c'))
)
OrderingFilterModel(title=title, text=text).save()
for idx in range(3):
title = ( chr(ord('z') - idx) + chr(ord('y') - idx) + chr(ord('x') - idx) )
text = ( chr(idx + ord('a')) + chr(idx + ord('b')) + chr(idx + ord('c')) )
OrderingFilterModel(title=title, text=text).save()
def test_ordering(self):
class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
ordering_fields = ('text',)
class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
ordering_fields = ('text',)
view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'text'})
response = view(request)
assert response.data == [
{'id': 1, 'title': 'zyx', 'text': 'abc'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 3, 'title': 'xwv', 'text': 'cde'},
]
request=factory.get('/',{'ordering':'text'})
response=view(request)
assertresponse.data==[{'id':1,'title':'zyx','text':'abc'},{'id':2,'title':'yxw','text':'bcd'},{'id':3,'title':'xwv','text':'cde'},]
def test_reverse_ordering(self):
class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
ordering_fields = ('text',)
class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
ordering_fields = ('text',)
view = OrderingListView.as_view()
request = factory.get('/', {'ordering': '-text'})
response = view(request)
assert response.data == [
{'id': 3, 'title': 'xwv', 'text': 'cde'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 1, 'title': 'zyx', 'text': 'abc'},
]
request=factory.get('/',{'ordering':'-text'})
response=view(request)
assertresponse.data==[{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},{'id':1,'title':'zyx','text':'abc'},]
def test_incorrecturl_extrahyphens_ordering(self):
class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
ordering_fields = ('text',)
class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
ordering_fields = ('text',)
view = OrderingListView.as_view()
request = factory.get('/', {'ordering': '--text'})
response = view(request)
assert response.data == [
{'id': 3, 'title': 'xwv', 'text': 'cde'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 1, 'title': 'zyx', 'text': 'abc'},
]
request=factory.get('/',{'ordering':'--text'})
response=view(request)
assertresponse.data==[{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},{'id':1,'title':'zyx','text':'abc'},]
def test_incorrectfield_ordering(self):
class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
ordering_fields = ('text',)
class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
ordering_fields = ('text',)
view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'foobar'})
response = view(request)
assert response.data == [
{'id': 3, 'title': 'xwv', 'text': 'cde'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 1, 'title': 'zyx', 'text': 'abc'},
]
request=factory.get('/',{'ordering':'foobar'})
response=view(request)
assertresponse.data==[{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},{'id':1,'title':'zyx','text':'abc'},]
def test_default_ordering(self):
class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
ordering_fields = ('text',)
class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
ordering_fields = ('text',)
view = OrderingListView.as_view()
request = factory.get('')
response = view(request)
assert response.data == [
{'id': 3, 'title': 'xwv', 'text': 'cde'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 1, 'title': 'zyx', 'text': 'abc'},
]
request=factory.get('')
response=view(request)
assertresponse.data==[{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},{'id':1,'title':'zyx','text':'abc'},]
def test_default_ordering_using_string(self):
class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
ordering = 'title'
ordering_fields = ('text',)
class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
ordering = 'title'
ordering_fields = ('text',)
view = OrderingListView.as_view()
request = factory.get('')
response = view(request)
assert response.data == [
{'id': 3, 'title': 'xwv', 'text': 'cde'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 1, 'title': 'zyx', 'text': 'abc'},
]
request=factory.get('')
response=view(request)
assertresponse.data==[{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},{'id':1,'title':'zyx','text':'abc'},]
def test_ordering_by_aggregate_field(self):
# create some related models to aggregate order by
num_objs = [2, 5, 3]
for obj, num_relateds in zip(OrderingFilterModel.objects.all(),
num_objs):
for _ in range(num_relateds):
new_related = OrderingFilterRelatedModel(
related_object=obj
)
new_related.save()
num_objs = [2, 5, 3]
for obj, num_relateds in zip(OrderingFilterModel.objects.all(), num_objs):
for _ in range(num_relateds):
new_related = OrderingFilterRelatedModel( related_object=obj )
new_related.save()
class OrderingListView(generics.ListAPIView):
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
ordering = 'title'
ordering_fields = '__all__'
queryset = OrderingFilterModel.objects.all().annotate(
models.Count("relateds"))
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
ordering = 'title'
ordering_fields = '__all__'
queryset = OrderingFilterModel.objects.all().annotate( models.Count("relateds"))
view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'relateds__count'})
response = view(request)
assert response.data == [
{'id': 1, 'title': 'zyx', 'text': 'abc'},
{'id': 3, 'title': 'xwv', 'text': 'cde'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
]
request=factory.get('/',{'ordering':'relateds__count'})
response=view(request)
assertresponse.data==[{'id':1,'title':'zyx','text':'abc'},{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},]
def test_ordering_by_dotted_source(self):
for index, obj in enumerate(OrderingFilterModel.objects.all()):
OrderingFilterRelatedModel.objects.create(
related_object=obj,
index=index
)
for index, obj in enumerate(OrderingFilterModel.objects.all()):
OrderingFilterRelatedModel.objects.create( related_object=obj, index=index )
class OrderingListView(generics.ListAPIView):
serializer_class = OrderingDottedRelatedSerializer
filter_backends = (filters.OrderingFilter,)
queryset = OrderingFilterRelatedModel.objects.all()
serializer_class = OrderingDottedRelatedSerializer
filter_backends = (filters.OrderingFilter,)
queryset = OrderingFilterRelatedModel.objects.all()
view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'related_object__text'})
response = view(request)
assert response.data == [
{'related_title': 'zyx', 'related_text': 'abc', 'index': 0},
{'related_title': 'yxw', 'related_text': 'bcd', 'index': 1},
{'related_title': 'xwv', 'related_text': 'cde', 'index': 2},
]
request = factory.get('/', {'ordering': '-index'})
response = view(request)
assert response.data == [
{'related_title': 'xwv', 'related_text': 'cde', 'index': 2},
{'related_title': 'yxw', 'related_text': 'bcd', 'index': 1},
{'related_title': 'zyx', 'related_text': 'abc', 'index': 0},
]
request=factory.get('/',{'ordering':'related_object__text'})
response=view(request)
assertresponse.data==[{'related_title':'zyx','related_text':'abc','index':0},{'related_title':'yxw','related_text':'bcd','index':1},{'related_title':'xwv','related_text':'cde','index':2},]
request=factory.get('/',{'ordering':'-index'})
response=view(request)
assertresponse.data==[{'related_title':'xwv','related_text':'cde','index':2},{'related_title':'yxw','related_text':'bcd','index':1},{'related_title':'zyx','related_text':'abc','index':0},]
def test_ordering_with_nonstandard_ordering_param(self):
with override_settings(REST_FRAMEWORK={'ORDERING_PARAM': 'order'}):
reload_module(filters)
class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
ordering_fields = ('text',)
with override_settings(REST_FRAMEWORK={'ORDERING_PARAM': 'order'}):
reload_module(filters)
class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
ordering_fields = ('text',)
view = OrderingListView.as_view()
request = factory.get('/', {'order': 'text'})
response = view(request)
assert response.data == [
{'id': 1, 'title': 'zyx', 'text': 'abc'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 3, 'title': 'xwv', 'text': 'cde'},
]
request=factory.get('/',{'order':'text'})
response=view(request)
assertresponse.data==[{'id':1,'title':'zyx','text':'abc'},{'id':2,'title':'yxw','text':'bcd'},{'id':3,'title':'xwv','text':'cde'},]
reload_module(filters)
def test_get_template_context(self):
class OrderingListView(generics.ListAPIView):
ordering_fields = '__all__'
serializer_class = OrderingFilterSerializer
queryset = OrderingFilterModel.objects.all()
filter_backends = (filters.OrderingFilter,)
class OrderingListView(generics.ListAPIView):
ordering_fields = '__all__'
serializer_class = OrderingFilterSerializer
queryset = OrderingFilterModel.objects.all()
filter_backends = (filters.OrderingFilter,)
request = factory.get('/', {'ordering': 'title'}, HTTP_ACCEPT='text/html')
view = OrderingListView.as_view()
response = view(request)
self.assertContains(response, 'verbose title')
view=OrderingListView.as_view()
response=view(request)
self.assertContains(response,'verbose title')
def test_ordering_with_overridden_get_serializer_class(self):
class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all()
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
# note: no ordering_fields and serializer_class specified
def get_serializer_class(self):
return OrderingFilterSerializer
class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all()
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
def get_serializer_class(self):
return OrderingFilterSerializer
view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'text'})
response = view(request)
assert response.data == [
{'id': 1, 'title': 'zyx', 'text': 'abc'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 3, 'title': 'xwv', 'text': 'cde'},
]
request=factory.get('/',{'ordering':'text'})
response=view(request)
assertresponse.data==[{'id':1,'title':'zyx','text':'abc'},{'id':2,'title':'yxw','text':'bcd'},{'id':3,'title':'xwv','text':'cde'},]
def test_ordering_with_improper_configuration(self):
class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all()
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all()
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
# note: no ordering_fields and serializer_class
# or get_serializer_class specified
view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'text'})
with self.assertRaises(ImproperlyConfigured):
view(request)
request=factory.get('/',{'ordering':'text'})
withpytest.raises(ImproperlyConfigured):
view(request)
class SensitiveOrderingFilterModel(models.Model):
@ -684,63 +561,44 @@ class SensitiveDataSerializer3(serializers.ModelSerializer):
fields = ('id', 'user')
class SensitiveOrderingFilterTests(TestCase):
def setUp(self):
for idx in range(3):
username = {0: 'userA', 1: 'userB', 2: 'userC'}[idx]
password = {0: 'passA', 1: 'passC', 2: 'passB'}[idx]
SensitiveOrderingFilterModel(username=username, password=password).save()
def setUp(self):
for idx in range(3):
username = {0: 'userA', 1: 'userB', 2: 'userC'}[idx]
password = {0: 'passA', 1: 'passC', 2: 'passB'}[idx]
SensitiveOrderingFilterModel(username=username, password=password).save()
def test_order_by_serializer_fields(self):
for serializer_cls in [
SensitiveDataSerializer1,
SensitiveDataSerializer2,
SensitiveDataSerializer3
]:
class OrderingListView(generics.ListAPIView):
queryset = SensitiveOrderingFilterModel.objects.all().order_by('username')
filter_backends = (filters.OrderingFilter,)
serializer_class = serializer_cls
for serializer_cls in [ SensitiveDataSerializer1, SensitiveDataSerializer2, SensitiveDataSerializer3 ]:
class OrderingListView(generics.ListAPIView):
queryset = SensitiveOrderingFilterModel.objects.all().order_by('username')
filter_backends = (filters.OrderingFilter,)
serializer_class = serializer_cls
view = OrderingListView.as_view()
request = factory.get('/', {'ordering': '-username'})
response = view(request)
if serializer_cls == SensitiveDataSerializer3:
username_field = 'user'
request=factory.get('/',{'ordering':'-username'})
response=view(request)
ifserializer_cls==SensitiveDataSerializer3:
username_field = 'user'
else:
username_field = 'username'
username_field = 'username'
# Note: Inverse username ordering correctly applied.
assert response.data == [
{'id': 3, username_field: 'userC'},
{'id': 2, username_field: 'userB'},
{'id': 1, username_field: 'userA'},
]
assert response.data == [{'id':3,username_field:'userC'},{'id':2,username_field:'userB'},{'id':1,username_field:'userA'},]
def test_cannot_order_by_non_serializer_fields(self):
for serializer_cls in [
SensitiveDataSerializer1,
SensitiveDataSerializer2,
SensitiveDataSerializer3
]:
class OrderingListView(generics.ListAPIView):
queryset = SensitiveOrderingFilterModel.objects.all().order_by('username')
filter_backends = (filters.OrderingFilter,)
serializer_class = serializer_cls
for serializer_cls in [ SensitiveDataSerializer1, SensitiveDataSerializer2, SensitiveDataSerializer3 ]:
class OrderingListView(generics.ListAPIView):
queryset = SensitiveOrderingFilterModel.objects.all().order_by('username')
filter_backends = (filters.OrderingFilter,)
serializer_class = serializer_cls
view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'password'})
response = view(request)
if serializer_cls == SensitiveDataSerializer3:
username_field = 'user'
request=factory.get('/',{'ordering':'password'})
response=view(request)
ifserializer_cls==SensitiveDataSerializer3:
username_field = 'user'
else:
username_field = 'username'
username_field = 'username'
# Note: The passwords are not in order. Default ordering is used.
assert response.data == [
{'id': 1, username_field: 'userA'}, # PassB
{'id': 2, username_field: 'userB'}, # PassC
{'id': 3, username_field: 'userC'}, # PassA
]
assert response.data == [{'id':1,username_field:'userA'},{'id':2,username_field:'userB'},{'id':3,username_field:'userC'},]

View File

@ -23,14 +23,12 @@ urlpatterns = [
@override_settings(ROOT_URLCONF='tests.test_generateschema')
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
class GenerateSchemaTests(TestCase):
"""Tests for management command generateschema."""
def setUp(self):
self.out = io.StringIO()
"""Tests for management command generateschema."""
defsetUp(self):
self.out = io.StringIO()
def test_renders_default_schema_with_custom_title_url_and_description(self):
expected_out = """info:
expected_out = """info:
description: Sample description
title: SampleAPI
version: ''
@ -42,45 +40,16 @@ class GenerateSchemaTests(TestCase):
servers:
- url: http://api.sample.com/
"""
call_command('generateschema',
'--title=SampleAPI',
'--url=http://api.sample.com',
'--description=Sample description',
stdout=self.out)
self.assertIn(formatting.dedent(expected_out), self.out.getvalue())
call_command('generateschema', '--title=SampleAPI', '--url=http://api.sample.com', '--description=Sample description', stdout=self.out)
assert formatting.dedent(expected_out) in self.out.getvalue()
def test_renders_openapi_json_schema(self):
expected_out = {
"openapi": "3.0.0",
"info": {
"version": "",
"title": "",
"description": ""
},
"servers": [
{
"url": ""
}
],
"paths": {
"/": {
"get": {
"operationId": "list"
}
}
}
}
call_command('generateschema',
'--format=openapi-json',
stdout=self.out)
out_json = json.loads(self.out.getvalue())
self.assertDictEqual(out_json, expected_out)
expected_out = { "openapi": "3.0.0", "info": { "version": "", "title": "", "description": "" }, "servers": [ { "url": "" } ], "paths": { "/": { "get": { "operationId": "list" } } } }
call_command('generateschema', '--format=openapi-json', stdout=self.out)
out_json = json.loads(self.out.getvalue())
assert out_json == expected_out
def test_renders_corejson_schema(self):
expected_out = """{"_type":"document","":{"list":{"_type":"link","url":"/","action":"get"}}}"""
call_command('generateschema',
'--format=corejson',
stdout=self.out)
self.assertIn(expected_out, self.out.getvalue())
expected_out = """{"_type":"document","":{"list":{"_type":"link","url":"/","action":"get"}}}"""
call_command('generateschema', '--format=corejson', stdout=self.out)
assert expected_out in self.out.getvalue()

View File

@ -75,304 +75,282 @@ class SlugBasedInstanceView(InstanceView):
# Tests
class TestRootView(TestCase):
def setUp(self):
"""
def setUp(self):
"""
Create 3 BasicModel instances.
"""
items = ['foo', 'bar', 'baz']
for item in items:
BasicModel(text=item).save()
items = ['foo', 'bar', 'baz']
for item in items:
BasicModel(text=item).save()
self.objects = BasicModel.objects
self.data = [
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
self.view = RootView.as_view()
self.data=[{'id':obj.id,'text':obj.text}forobjinself.objects.all()]
self.view=RootView.as_view()
def test_get_root_view(self):
"""
"""
GET requests to ListCreateAPIView should return list of objects.
"""
request = factory.get('/')
with self.assertNumQueries(1):
response = self.view(request).render()
request = factory.get('/')
with self.assertNumQueries(1):
response = self.view(request).render()
assert response.status_code == status.HTTP_200_OK
assert response.data == self.data
assertresponse.data==self.data
def test_head_root_view(self):
"""
"""
HEAD requests to ListCreateAPIView should return 200.
"""
request = factory.head('/')
with self.assertNumQueries(1):
response = self.view(request).render()
request = factory.head('/')
with self.assertNumQueries(1):
response = self.view(request).render()
assert response.status_code == status.HTTP_200_OK
def test_post_root_view(self):
"""
"""
POST requests to ListCreateAPIView should create a new object.
"""
data = {'text': 'foobar'}
request = factory.post('/', data, format='json')
with self.assertNumQueries(1):
response = self.view(request).render()
data = {'text': 'foobar'}
request = factory.post('/', data, format='json')
with self.assertNumQueries(1):
response = self.view(request).render()
assert response.status_code == status.HTTP_201_CREATED
assert response.data == {'id': 4, 'text': 'foobar'}
created = self.objects.get(id=4)
assert created.text == 'foobar'
assertresponse.data=={'id':4,'text':'foobar'}
created=self.objects.get(id=4)
assertcreated.text=='foobar'
def test_put_root_view(self):
"""
"""
PUT requests to ListCreateAPIView should not be allowed
"""
data = {'text': 'foobar'}
request = factory.put('/', data, format='json')
with self.assertNumQueries(0):
response = self.view(request).render()
data = {'text': 'foobar'}
request = factory.put('/', data, format='json')
with self.assertNumQueries(0):
response = self.view(request).render()
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
assert response.data == {"detail": 'Method "PUT" not allowed.'}
assertresponse.data=={"detail":'Method "PUT" not allowed.'}
def test_delete_root_view(self):
"""
"""
DELETE requests to ListCreateAPIView should not be allowed
"""
request = factory.delete('/')
with self.assertNumQueries(0):
response = self.view(request).render()
request = factory.delete('/')
with self.assertNumQueries(0):
response = self.view(request).render()
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
assert response.data == {"detail": 'Method "DELETE" not allowed.'}
assertresponse.data=={"detail":'Method "DELETE" not allowed.'}
def test_post_cannot_set_id(self):
"""
"""
POST requests to create a new object should not be able to set the id.
"""
data = {'id': 999, 'text': 'foobar'}
request = factory.post('/', data, format='json')
with self.assertNumQueries(1):
response = self.view(request).render()
data = {'id': 999, 'text': 'foobar'}
request = factory.post('/', data, format='json')
with self.assertNumQueries(1):
response = self.view(request).render()
assert response.status_code == status.HTTP_201_CREATED
assert response.data == {'id': 4, 'text': 'foobar'}
created = self.objects.get(id=4)
assert created.text == 'foobar'
assertresponse.data=={'id':4,'text':'foobar'}
created=self.objects.get(id=4)
assertcreated.text=='foobar'
def test_post_error_root_view(self):
"""
"""
POST requests to ListCreateAPIView in HTML should include a form error.
"""
data = {'text': 'foobar' * 100}
request = factory.post('/', data, HTTP_ACCEPT='text/html')
response = self.view(request).render()
expected_error = '<span class="help-block">Ensure this field has no more than 100 characters.</span>'
assert expected_error in response.rendered_content.decode()
data = {'text': 'foobar' * 100}
request = factory.post('/', data, HTTP_ACCEPT='text/html')
response = self.view(request).render()
expected_error = '<span class="help-block">Ensure this field has no more than 100 characters.</span>'
assert expected_error in response.rendered_content.decode()
EXPECTED_QUERIES_FOR_PUT = 2
class TestInstanceView(TestCase):
def setUp(self):
"""
def setUp(self):
"""
Create 3 BasicModel instances.
"""
items = ['foo', 'bar', 'baz', 'filtered out']
for item in items:
BasicModel(text=item).save()
items = ['foo', 'bar', 'baz', 'filtered out']
for item in items:
BasicModel(text=item).save()
self.objects = BasicModel.objects.exclude(text='filtered out')
self.data = [
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
self.view = InstanceView.as_view()
self.slug_based_view = SlugBasedInstanceView.as_view()
self.data=[{'id':obj.id,'text':obj.text}forobjinself.objects.all()]
self.view=InstanceView.as_view()
self.slug_based_view=SlugBasedInstanceView.as_view()
def test_get_instance_view(self):
"""
"""
GET requests to RetrieveUpdateDestroyAPIView should return a single object.
"""
request = factory.get('/1')
with self.assertNumQueries(1):
response = self.view(request, pk=1).render()
request = factory.get('/1')
with self.assertNumQueries(1):
response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK
assert response.data == self.data[0]
assertresponse.data==self.data[0]
def test_post_instance_view(self):
"""
"""
POST requests to RetrieveUpdateDestroyAPIView should not be allowed
"""
data = {'text': 'foobar'}
request = factory.post('/', data, format='json')
with self.assertNumQueries(0):
response = self.view(request).render()
data = {'text': 'foobar'}
request = factory.post('/', data, format='json')
with self.assertNumQueries(0):
response = self.view(request).render()
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
assert response.data == {"detail": 'Method "POST" not allowed.'}
assertresponse.data=={"detail":'Method "POST" not allowed.'}
def test_put_instance_view(self):
"""
"""
PUT requests to RetrieveUpdateDestroyAPIView should update an object.
"""
data = {'text': 'foobar'}
request = factory.put('/1', data, format='json')
with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
response = self.view(request, pk='1').render()
data = {'text': 'foobar'}
request = factory.put('/1', data, format='json')
with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
response = self.view(request, pk='1').render()
assert response.status_code == status.HTTP_200_OK
assert dict(response.data) == {'id': 1, 'text': 'foobar'}
updated = self.objects.get(id=1)
assert updated.text == 'foobar'
assertdict(response.data)=={'id':1,'text':'foobar'}
updated=self.objects.get(id=1)
assertupdated.text=='foobar'
def test_patch_instance_view(self):
"""
"""
PATCH requests to RetrieveUpdateDestroyAPIView should update an object.
"""
data = {'text': 'foobar'}
request = factory.patch('/1', data, format='json')
with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
response = self.view(request, pk=1).render()
data = {'text': 'foobar'}
request = factory.patch('/1', data, format='json')
with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK
assert response.data == {'id': 1, 'text': 'foobar'}
updated = self.objects.get(id=1)
assert updated.text == 'foobar'
assertresponse.data=={'id':1,'text':'foobar'}
updated=self.objects.get(id=1)
assertupdated.text=='foobar'
def test_delete_instance_view(self):
"""
"""
DELETE requests to RetrieveUpdateDestroyAPIView should delete an object.
"""
request = factory.delete('/1')
with self.assertNumQueries(2):
response = self.view(request, pk=1).render()
request = factory.delete('/1')
with self.assertNumQueries(2):
response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_204_NO_CONTENT
assert response.content == b''
ids = [obj.id for obj in self.objects.all()]
assert ids == [2, 3]
assertresponse.content==b''
ids=[obj.idforobjinself.objects.all()]
assertids==[2,3]
def test_get_instance_view_incorrect_arg(self):
"""
"""
GET requests with an incorrect pk type, should raise 404, not 500.
Regression test for #890.
"""
request = factory.get('/a')
with self.assertNumQueries(0):
response = self.view(request, pk='a').render()
request = factory.get('/a')
with self.assertNumQueries(0):
response = self.view(request, pk='a').render()
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_put_cannot_set_id(self):
"""
"""
PUT requests to create a new object should not be able to set the id.
"""
data = {'id': 999, 'text': 'foobar'}
request = factory.put('/1', data, format='json')
with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
response = self.view(request, pk=1).render()
data = {'id': 999, 'text': 'foobar'}
request = factory.put('/1', data, format='json')
with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK
assert response.data == {'id': 1, 'text': 'foobar'}
updated = self.objects.get(id=1)
assert updated.text == 'foobar'
assertresponse.data=={'id':1,'text':'foobar'}
updated=self.objects.get(id=1)
assertupdated.text=='foobar'
def test_put_to_deleted_instance(self):
"""
"""
PUT requests to RetrieveUpdateDestroyAPIView should return 404 if
an object does not currently exist.
"""
self.objects.get(id=1).delete()
data = {'text': 'foobar'}
request = factory.put('/1', data, format='json')
with self.assertNumQueries(1):
response = self.view(request, pk=1).render()
self.objects.get(id=1).delete()
data = {'text': 'foobar'}
request = factory.put('/1', data, format='json')
with self.assertNumQueries(1):
response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_put_to_filtered_out_instance(self):
"""
"""
PUT requests to an URL of instance which is filtered out should not be
able to create new objects.
"""
data = {'text': 'foo'}
filtered_out_pk = BasicModel.objects.filter(text='filtered out')[0].pk
request = factory.put('/{}'.format(filtered_out_pk), data, format='json')
response = self.view(request, pk=filtered_out_pk).render()
assert response.status_code == status.HTTP_404_NOT_FOUND
data = {'text': 'foo'}
filtered_out_pk = BasicModel.objects.filter(text='filtered out')[0].pk
request = factory.put('/{}'.format(filtered_out_pk), data, format='json')
response = self.view(request, pk=filtered_out_pk).render()
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_patch_cannot_create_an_object(self):
"""
"""
PATCH requests should not be able to create objects.
"""
data = {'text': 'foobar'}
request = factory.patch('/999', data, format='json')
with self.assertNumQueries(1):
response = self.view(request, pk=999).render()
data = {'text': 'foobar'}
request = factory.patch('/999', data, format='json')
with self.assertNumQueries(1):
response = self.view(request, pk=999).render()
assert response.status_code == status.HTTP_404_NOT_FOUND
assert not self.objects.filter(id=999).exists()
assertnotself.objects.filter(id=999).exists()
def test_put_error_instance_view(self):
"""
"""
Incorrect PUT requests in HTML should include a form error.
"""
data = {'text': 'foobar' * 100}
request = factory.put('/', data, HTTP_ACCEPT='text/html')
response = self.view(request, pk=1).render()
expected_error = '<span class="help-block">Ensure this field has no more than 100 characters.</span>'
assert expected_error in response.rendered_content.decode()
data = {'text': 'foobar' * 100}
request = factory.put('/', data, HTTP_ACCEPT='text/html')
response = self.view(request, pk=1).render()
expected_error = '<span class="help-block">Ensure this field has no more than 100 characters.</span>'
assert expected_error in response.rendered_content.decode()
class TestFKInstanceView(TestCase):
def setUp(self):
"""
def setUp(self):
"""
Create 3 BasicModel instances.
"""
items = ['foo', 'bar', 'baz']
for item in items:
t = ForeignKeyTarget(name=item)
t.save()
ForeignKeySource(name='source_' + item, target=t).save()
items = ['foo', 'bar', 'baz']
for item in items:
t = ForeignKeyTarget(name=item)
t.save()
ForeignKeySource(name='source_' + item, target=t).save()
self.objects = ForeignKeySource.objects
self.data = [
{'id': obj.id, 'name': obj.name}
for obj in self.objects.all()
]
self.view = FKInstanceView.as_view()
self.data=[{'id':obj.id,'name':obj.name}forobjinself.objects.all()]
self.view=FKInstanceView.as_view()
class TestOverriddenGetObject(TestCase):
"""
"""
Test cases for a RetrieveUpdateDestroyAPIView that does NOT use the
queryset/model mechanism but instead overrides get_object()
"""
def setUp(self):
"""
defsetUp(self):
"""
Create 3 BasicModel instances.
"""
items = ['foo', 'bar', 'baz']
for item in items:
BasicModel(text=item).save()
items = ['foo', 'bar', 'baz']
for item in items:
BasicModel(text=item).save()
self.objects = BasicModel.objects
self.data = [
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
class OverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView):
"""
self.data=[{'id':obj.id,'text':obj.text}forobjinself.objects.all()]
classOverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView):
"""
Example detail view for override of get_object().
"""
serializer_class = BasicSerializer
def get_object(self):
pk = int(self.kwargs['pk'])
return get_object_or_404(BasicModel.objects.all(), id=pk)
serializer_class = BasicSerializer
def get_object(self):
pk = int(self.kwargs['pk'])
return get_object_or_404(BasicModel.objects.all(), id=pk)
self.view = OverriddenGetObjectView.as_view()
def test_overridden_get_object_view(self):
"""
"""
GET requests to RetrieveUpdateDestroyAPIView should return a single object.
"""
request = factory.get('/1')
with self.assertNumQueries(1):
response = self.view(request, pk=1).render()
request = factory.get('/1')
with self.assertNumQueries(1):
response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK
assert response.data == self.data[0]
assertresponse.data==self.data[0]
# Regression test for #285
@ -388,23 +366,22 @@ class CommentView(generics.ListCreateAPIView):
model = Comment
class TestCreateModelWithAutoNowAddField(TestCase):
def setUp(self):
self.objects = Comment.objects
self.view = CommentView.as_view()
def setUp(self):
self.objects = Comment.objects
self.view = CommentView.as_view()
def test_create_model_with_auto_now_add_field(self):
"""
"""
Regression test for #285
https://github.com/encode/django-rest-framework/issues/285
"""
data = {'email': 'foobar@example.com', 'content': 'foobar'}
request = factory.post('/', data, format='json')
response = self.view(request).render()
assert response.status_code == status.HTTP_201_CREATED
created = self.objects.get(id=1)
assert created.content == 'foobar'
data = {'email': 'foobar@example.com', 'content': 'foobar'}
request = factory.post('/', data, format='json')
response = self.view(request).render()
assert response.status_code == status.HTTP_201_CREATED
created = self.objects.get(id=1)
assert created.content == 'foobar'
# Test for particularly ugly regression with m2m in browsable API
@ -432,15 +409,14 @@ class ExampleView(generics.ListCreateAPIView):
queryset = ClassA.objects.all()
class TestM2MBrowsableAPI(TestCase):
def test_m2m_in_browsable_api(self):
"""
def test_m2m_in_browsable_api(self):
"""
Test for particularly ugly regression with m2m in browsable API
"""
request = factory.get('/', HTTP_ACCEPT='text/html')
view = ExampleView().as_view()
response = view(request).render()
assert response.status_code == status.HTTP_200_OK
request = factory.get('/', HTTP_ACCEPT='text/html')
view = ExampleView().as_view()
response = view(request).render()
assert response.status_code == status.HTTP_200_OK
class InclusiveFilterBackend:
@ -476,189 +452,179 @@ class DynamicSerializerView(generics.ListCreateAPIView):
return DynamicSerializer
class TestFilterBackendAppliedToViews(TestCase):
def setUp(self):
"""
def setUp(self):
"""
Create 3 BasicModel instances to filter on.
"""
items = ['foo', 'bar', 'baz']
for item in items:
BasicModel(text=item).save()
items = ['foo', 'bar', 'baz']
for item in items:
BasicModel(text=item).save()
self.objects = BasicModel.objects
self.data = [
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
self.data=[{'id':obj.id,'text':obj.text}forobjinself.objects.all()]
def test_get_root_view_filters_by_name_with_filter_backend(self):
"""
"""
GET requests to ListCreateAPIView should return filtered list.
"""
root_view = RootView.as_view(filter_backends=(InclusiveFilterBackend,))
request = factory.get('/')
response = root_view(request).render()
assert response.status_code == status.HTTP_200_OK
assert len(response.data) == 1
assert response.data == [{'id': 1, 'text': 'foo'}]
root_view = RootView.as_view(filter_backends=(InclusiveFilterBackend,))
request = factory.get('/')
response = root_view(request).render()
assert response.status_code == status.HTTP_200_OK
assert len(response.data) == 1
assert response.data == [{'id': 1, 'text': 'foo'}]
def test_get_root_view_filters_out_all_models_with_exclusive_filter_backend(self):
"""
"""
GET requests to ListCreateAPIView should return empty list when all models are filtered out.
"""
root_view = RootView.as_view(filter_backends=(ExclusiveFilterBackend,))
request = factory.get('/')
response = root_view(request).render()
assert response.status_code == status.HTTP_200_OK
assert response.data == []
root_view = RootView.as_view(filter_backends=(ExclusiveFilterBackend,))
request = factory.get('/')
response = root_view(request).render()
assert response.status_code == status.HTTP_200_OK
assert response.data == []
def test_get_instance_view_filters_out_name_with_filter_backend(self):
"""
"""
GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out.
"""
instance_view = InstanceView.as_view(filter_backends=(ExclusiveFilterBackend,))
request = factory.get('/1')
response = instance_view(request, pk=1).render()
assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.data == {'detail': 'Not found.'}
instance_view = InstanceView.as_view(filter_backends=(ExclusiveFilterBackend,))
request = factory.get('/1')
response = instance_view(request, pk=1).render()
assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.data == {'detail': 'Not found.'}
def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self):
"""
"""
GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded
"""
instance_view = InstanceView.as_view(filter_backends=(InclusiveFilterBackend,))
request = factory.get('/1')
response = instance_view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK
assert response.data == {'id': 1, 'text': 'foo'}
instance_view = InstanceView.as_view(filter_backends=(InclusiveFilterBackend,))
request = factory.get('/1')
response = instance_view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK
assert response.data == {'id': 1, 'text': 'foo'}
def test_dynamic_serializer_form_in_browsable_api(self):
"""
"""
GET requests to ListCreateAPIView should return filtered list.
"""
view = DynamicSerializerView.as_view()
request = factory.get('/')
response = view(request).render()
content = response.content.decode()
assert 'field_b' in content
assert 'field_a' not in content
view = DynamicSerializerView.as_view()
request = factory.get('/')
response = view(request).render()
content = response.content.decode()
assert 'field_b' in content
assert 'field_a' not in content
class TestGuardedQueryset(TestCase):
def test_guarded_queryset(self):
class QuerysetAccessError(generics.ListAPIView):
queryset = BasicModel.objects.all()
def get(self, request):
return Response(list(self.queryset))
def test_guarded_queryset(self):
class QuerysetAccessError(generics.ListAPIView):
queryset = BasicModel.objects.all()
def get(self, request):
return Response(list(self.queryset))
view = QuerysetAccessError.as_view()
request = factory.get('/')
with pytest.raises(RuntimeError):
view(request).render()
request=factory.get('/')
withpytest.raises(RuntimeError):
view(request).render()
class ApiViewsTests(TestCase):
def test_create_api_view_post(self):
class MockCreateApiView(generics.CreateAPIView):
def create(self, request, *args, **kwargs):
self.called = True
self.call_args = (request, args, kwargs)
def test_create_api_view_post(self):
class MockCreateApiView(generics.CreateAPIView):
def create(self, request, *args, **kwargs):
self.called = True
self.call_args = (request, args, kwargs)
view = MockCreateApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'})
view.post('test request', 'test arg', test_kwarg='test')
assert view.called is True
assert view.call_args == data
data=('test request',('test arg',),{'test_kwarg':'test'})
view.post('test request','test arg',test_kwarg='test')
assertview.calledisTrue
assertview.call_args==data
def test_destroy_api_view_delete(self):
class MockDestroyApiView(generics.DestroyAPIView):
def destroy(self, request, *args, **kwargs):
self.called = True
self.call_args = (request, args, kwargs)
class MockDestroyApiView(generics.DestroyAPIView):
def destroy(self, request, *args, **kwargs):
self.called = True
self.call_args = (request, args, kwargs)
view = MockDestroyApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'})
view.delete('test request', 'test arg', test_kwarg='test')
assert view.called is True
assert view.call_args == data
data=('test request',('test arg',),{'test_kwarg':'test'})
view.delete('test request','test arg',test_kwarg='test')
assertview.calledisTrue
assertview.call_args==data
def test_update_api_view_partial_update(self):
class MockUpdateApiView(generics.UpdateAPIView):
def partial_update(self, request, *args, **kwargs):
self.called = True
self.call_args = (request, args, kwargs)
class MockUpdateApiView(generics.UpdateAPIView):
def partial_update(self, request, *args, **kwargs):
self.called = True
self.call_args = (request, args, kwargs)
view = MockUpdateApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'})
view.patch('test request', 'test arg', test_kwarg='test')
assert view.called is True
assert view.call_args == data
data=('test request',('test arg',),{'test_kwarg':'test'})
view.patch('test request','test arg',test_kwarg='test')
assertview.calledisTrue
assertview.call_args==data
def test_retrieve_update_api_view_get(self):
class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView):
def retrieve(self, request, *args, **kwargs):
self.called = True
self.call_args = (request, args, kwargs)
class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView):
def retrieve(self, request, *args, **kwargs):
self.called = True
self.call_args = (request, args, kwargs)
view = MockRetrieveUpdateApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'})
view.get('test request', 'test arg', test_kwarg='test')
assert view.called is True
assert view.call_args == data
data=('test request',('test arg',),{'test_kwarg':'test'})
view.get('test request','test arg',test_kwarg='test')
assertview.calledisTrue
assertview.call_args==data
def test_retrieve_update_api_view_put(self):
class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView):
def update(self, request, *args, **kwargs):
self.called = True
self.call_args = (request, args, kwargs)
class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView):
def update(self, request, *args, **kwargs):
self.called = True
self.call_args = (request, args, kwargs)
view = MockRetrieveUpdateApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'})
view.put('test request', 'test arg', test_kwarg='test')
assert view.called is True
assert view.call_args == data
data=('test request',('test arg',),{'test_kwarg':'test'})
view.put('test request','test arg',test_kwarg='test')
assertview.calledisTrue
assertview.call_args==data
def test_retrieve_update_api_view_patch(self):
class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView):
def partial_update(self, request, *args, **kwargs):
self.called = True
self.call_args = (request, args, kwargs)
class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView):
def partial_update(self, request, *args, **kwargs):
self.called = True
self.call_args = (request, args, kwargs)
view = MockRetrieveUpdateApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'})
view.patch('test request', 'test arg', test_kwarg='test')
assert view.called is True
assert view.call_args == data
data=('test request',('test arg',),{'test_kwarg':'test'})
view.patch('test request','test arg',test_kwarg='test')
assertview.calledisTrue
assertview.call_args==data
def test_retrieve_destroy_api_view_get(self):
class MockRetrieveDestroyUApiView(generics.RetrieveDestroyAPIView):
def retrieve(self, request, *args, **kwargs):
self.called = True
self.call_args = (request, args, kwargs)
class MockRetrieveDestroyUApiView(generics.RetrieveDestroyAPIView):
def retrieve(self, request, *args, **kwargs):
self.called = True
self.call_args = (request, args, kwargs)
view = MockRetrieveDestroyUApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'})
view.get('test request', 'test arg', test_kwarg='test')
assert view.called is True
assert view.call_args == data
data=('test request',('test arg',),{'test_kwarg':'test'})
view.get('test request','test arg',test_kwarg='test')
assertview.calledisTrue
assertview.call_args==data
def test_retrieve_destroy_api_view_delete(self):
class MockRetrieveDestroyUApiView(generics.RetrieveDestroyAPIView):
def destroy(self, request, *args, **kwargs):
self.called = True
self.call_args = (request, args, kwargs)
class MockRetrieveDestroyUApiView(generics.RetrieveDestroyAPIView):
def destroy(self, request, *args, **kwargs):
self.called = True
self.call_args = (request, args, kwargs)
view = MockRetrieveDestroyUApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'})
view.delete('test request', 'test arg', test_kwarg='test')
assert view.called is True
assert view.call_args == data
data=('test request',('test arg',),{'test_kwarg':'test'})
view.delete('test request','test arg',test_kwarg='test')
assertview.calledisTrue
assertview.call_args==data
class GetObjectOr404Tests(TestCase):
def setUp(self):
super().setUp()
self.uuid_object = UUIDForeignKeyTarget.objects.create(name='bar')
def setUp(self):
super().setUp()
self.uuid_object = UUIDForeignKeyTarget.objects.create(name='bar')
def test_get_object_or_404_with_valid_uuid(self):
obj = generics.get_object_or_404(
UUIDForeignKeyTarget, pk=self.uuid_object.pk
)
assert obj == self.uuid_object
obj = generics.get_object_or_404( UUIDForeignKeyTarget, pk=self.uuid_object.pk )
assert obj == self.uuid_object
def test_get_object_or_404_with_invalid_string_for_uuid(self):
with pytest.raises(Http404):
generics.get_object_or_404(UUIDForeignKeyTarget, pk='not-a-uuid')
with pytest.raises(Http404):
generics.get_object_or_404(UUIDForeignKeyTarget, pk='not-a-uuid')

View File

@ -42,122 +42,113 @@ urlpatterns = [
@override_settings(ROOT_URLCONF='tests.test_htmlrenderer')
class TemplateHTMLRendererTests(TestCase):
def setUp(self):
class MockResponse:
template_name = None
def setUp(self):
class MockResponse:
template_name = None
self.mock_response = MockResponse()
self._monkey_patch_get_template()
self._monkey_patch_get_template()
def _monkey_patch_get_template(self):
"""
"""
Monkeypatch get_template
"""
self.get_template = django.template.loader.get_template
def get_template(template_name, dirs=None):
if template_name == 'example.html':
return engines['django'].from_string("example: {{ object }}")
self.get_template = django.template.loader.get_template
def get_template(template_name, dirs=None):
if template_name == 'example.html':
return engines['django'].from_string("example: {{ object }}")
raise TemplateDoesNotExist(template_name)
def select_template(template_name_list, dirs=None, using=None):
if template_name_list == ['example.html']:
return engines['django'].from_string("example: {{ object }}")
if template_name_list == ['example.html']:
return engines['django'].from_string("example: {{ object }}")
raise TemplateDoesNotExist(template_name_list[0])
django.template.loader.get_template = get_template
django.template.loader.select_template = select_template
django.template.loader.select_template=select_template
def tearDown(self):
"""
"""
Revert monkeypatching
"""
django.template.loader.get_template = self.get_template
django.template.loader.get_template = self.get_template
def test_simple_html_view(self):
response = self.client.get('/')
self.assertContains(response, "example: foobar")
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
response = self.client.get('/')
self.assertContains(response, "example: foobar")
assert response['Content-Type'] == 'text/html; charset=utf-8'
def test_not_found_html_view(self):
response = self.client.get('/not_found')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertEqual(response.content, b"404 Not Found")
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
response = self.client.get('/not_found')
assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.content == b"404 Not Found"
assert response['Content-Type'] == 'text/html; charset=utf-8'
def test_permission_denied_html_view(self):
response = self.client.get('/permission_denied')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(response.content, b"403 Forbidden")
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
response = self.client.get('/permission_denied')
assert response.status_code == status.HTTP_403_FORBIDDEN
assert response.content == b"403 Forbidden"
assert response['Content-Type'] == 'text/html; charset=utf-8'
# 2 tests below are based on order of if statements in corresponding method
# of TemplateHTMLRenderer
def test_get_template_names_returns_own_template_name(self):
renderer = TemplateHTMLRenderer()
renderer.template_name = 'test_template'
template_name = renderer.get_template_names(self.mock_response, view={})
assert template_name == ['test_template']
renderer = TemplateHTMLRenderer()
renderer.template_name = 'test_template'
template_name = renderer.get_template_names(self.mock_response, view={})
assert template_name == ['test_template']
def test_get_template_names_returns_view_template_name(self):
renderer = TemplateHTMLRenderer()
class MockResponse:
template_name = None
renderer = TemplateHTMLRenderer()
class MockResponse:
template_name = None
class MockView:
def get_template_names(self):
return ['template from get_template_names method']
def get_template_names(self):
return ['template from get_template_names method']
class MockView2:
template_name = 'template from template_name attribute'
template_name = 'template from template_name attribute'
template_name = renderer.get_template_names(self.mock_response,
MockView())
assert template_name == ['template from get_template_names method']
template_name = renderer.get_template_names(self.mock_response,
MockView2())
assert template_name == ['template from template_name attribute']
template_name = renderer.get_template_names(self.mock_response,MockView())
asserttemplate_name==['template from get_template_names method']
template_name=renderer.get_template_names(self.mock_response,MockView2())
asserttemplate_name==['template from template_name attribute']
def test_get_template_names_raises_error_if_no_template_found(self):
renderer = TemplateHTMLRenderer()
with pytest.raises(ImproperlyConfigured):
renderer.get_template_names(self.mock_response, view=object())
renderer = TemplateHTMLRenderer()
with pytest.raises(ImproperlyConfigured):
renderer.get_template_names(self.mock_response, view=object())
@override_settings(ROOT_URLCONF='tests.test_htmlrenderer')
class TemplateHTMLRendererExceptionTests(TestCase):
def setUp(self):
"""
def setUp(self):
"""
Monkeypatch get_template
"""
self.get_template = django.template.loader.get_template
def get_template(template_name):
if template_name == '404.html':
return engines['django'].from_string("404: {{ detail }}")
self.get_template = django.template.loader.get_template
def get_template(template_name):
if template_name == '404.html':
return engines['django'].from_string("404: {{ detail }}")
if template_name == '403.html':
return engines['django'].from_string("403: {{ detail }}")
return engines['django'].from_string("403: {{ detail }}")
raise TemplateDoesNotExist(template_name)
django.template.loader.get_template = get_template
def tearDown(self):
"""
"""
Revert monkeypatching
"""
django.template.loader.get_template = self.get_template
django.template.loader.get_template = self.get_template
def test_not_found_html_view_with_template(self):
response = self.client.get('/not_found')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertTrue(response.content in (
b"404: Not found", b"404 Not Found"))
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
response = self.client.get('/not_found')
assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.content in ( b"404: Not found", b"404 Not Found")
assert response['Content-Type'] == 'text/html; charset=utf-8'
def test_permission_denied_html_view_with_template(self):
response = self.client.get('/permission_denied')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertTrue(response.content in (b"403: Permission denied", b"403 Forbidden"))
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
response = self.client.get('/permission_denied')
assert response.status_code == status.HTTP_403_FORBIDDEN
assert response.content in (b"403: Permission denied", b"403 Forbidden")
assert response['Content-Type'] == 'text/html; charset=utf-8'

View File

@ -34,16 +34,15 @@ urlpatterns = [
@override_settings(ROOT_URLCONF='tests.test_lazy_hyperlinks')
class TestLazyHyperlinkNames(TestCase):
def setUp(self):
self.example = Example.objects.create(text='foo')
def setUp(self):
self.example = Example.objects.create(text='foo')
def test_lazy_hyperlink_names(self):
global str_called
context = {'request': None}
serializer = ExampleSerializer(self.example, context=context)
JSONRenderer().render(serializer.data)
assert not str_called
hyperlink_string = format_value(serializer.data['url'])
assert hyperlink_string == '<a href=/example/1/>An example</a>'
assert str_called
global str_called
context = {'request': None}
serializer = ExampleSerializer(self.example, context=context)
JSONRenderer().render(serializer.data)
assert not str_called
hyperlink_string = format_value(serializer.data['url'])
assert hyperlink_string == '<a href=/example/1/>An example</a>'
assert str_called

View File

@ -308,98 +308,49 @@ class TestMetadata:
assert options.get_serializer_info(list_serializer) == options.get_serializer_info(child_serializer)
class TestSimpleMetadataFieldInfo(TestCase):
def test_null_boolean_field_info_type(self):
options = metadata.SimpleMetadata()
field_info = options.get_field_info(serializers.NullBooleanField())
assert field_info['type'] == 'boolean'
def test_null_boolean_field_info_type(self):
options = metadata.SimpleMetadata()
field_info = options.get_field_info(serializers.NullBooleanField())
assert field_info['type'] == 'boolean'
def test_related_field_choices(self):
options = metadata.SimpleMetadata()
BasicModel.objects.create()
with self.assertNumQueries(0):
field_info = options.get_field_info(
serializers.RelatedField(queryset=BasicModel.objects.all())
)
options = metadata.SimpleMetadata()
BasicModel.objects.create()
with self.assertNumQueries(0):
field_info = options.get_field_info( serializers.RelatedField(queryset=BasicModel.objects.all()) )
assert 'choices' not in field_info
class TestModelSerializerMetadata(TestCase):
def test_read_only_primary_key_related_field(self):
"""
def test_read_only_primary_key_related_field(self):
"""
On generic views OPTIONS should return an 'actions' key with metadata
on the fields that may be supplied to PUT and POST requests. It should
not fail when a read_only PrimaryKeyRelatedField is present
"""
class Parent(models.Model):
integer_field = models.IntegerField(validators=[MinValueValidator(1), MaxValueValidator(1000)])
children = models.ManyToManyField('Child')
name = models.CharField(max_length=100, blank=True, null=True)
class Parent(models.Model):
integer_field = models.IntegerField(validators=[MinValueValidator(1), MaxValueValidator(1000)])
children = models.ManyToManyField('Child')
name = models.CharField(max_length=100, blank=True, null=True)
class Child(models.Model):
name = models.CharField(max_length=100)
name = models.CharField(max_length=100)
class ExampleSerializer(serializers.ModelSerializer):
children = serializers.PrimaryKeyRelatedField(read_only=True, many=True)
class Meta:
model = Parent
fields = '__all__'
children = serializers.PrimaryKeyRelatedField(read_only=True, many=True)
class Meta:
model = Parent
fields = '__all__'
class ExampleView(views.APIView):
"""Example view."""
def post(self, request):
pass
"""Example view."""
def post(self, request):
pass
def get_serializer(self):
return ExampleSerializer()
return ExampleSerializer()
view = ExampleView.as_view()
response = view(request=request)
expected = {
'name': 'Example',
'description': 'Example view.',
'renders': [
'application/json',
'text/html'
],
'parses': [
'application/json',
'application/x-www-form-urlencoded',
'multipart/form-data'
],
'actions': {
'POST': {
'id': {
'type': 'integer',
'required': False,
'read_only': True,
'label': 'ID'
},
'children': {
'type': 'field',
'required': False,
'read_only': True,
'label': 'Children'
},
'integer_field': {
'type': 'integer',
'required': True,
'read_only': False,
'label': 'Integer field',
'min_value': 1,
'max_value': 1000
},
'name': {
'type': 'string',
'required': False,
'read_only': False,
'label': 'Name',
'max_length': 100
}
}
}
}
assert response.status_code == status.HTTP_200_OK
assert response.data == expected
response=view(request=request)
expected={'name':'Example','description':'Example view.','renders':['application/json','text/html'],'parses':['application/json','application/x-www-form-urlencoded','multipart/form-data'],'actions':{'POST':{'id':{'type':'integer','required':False,'read_only':True,'label':'ID'},'children':{'type':'field','required':False,'read_only':True,'label':'Children'},'integer_field':{'type':'integer','required':True,'read_only':False,'label':'Integer field','min_value':1,'max_value':1000},'name':{'type':'string','required':False,'read_only':False,'label':'Name','max_length':100}}}}
assertresponse.status_code==status.HTTP_200_OK
assertresponse.data==expected

File diff suppressed because it is too large Load Diff

View File

@ -33,35 +33,31 @@ class AssociatedModelSerializer(serializers.ModelSerializer):
# Tests
class InheritedModelSerializationTests(TestCase):
def test_multitable_inherited_model_fields_as_expected(self):
"""
def test_multitable_inherited_model_fields_as_expected(self):
"""
Assert that the parent pointer field is not included in the fields
serialized fields
"""
child = ChildModel(name1='parent name', name2='child name')
serializer = DerivedModelSerializer(child)
assert set(serializer.data) == {'name1', 'name2', 'id'}
child = ChildModel(name1='parent name', name2='child name')
serializer = DerivedModelSerializer(child)
assert set(serializer.data) == {'name1', 'name2', 'id'}
def test_onetoone_primary_key_model_fields_as_expected(self):
"""
"""
Assert that a model with a onetoone field that is the primary key is
not treated like a derived model
"""
parent = ParentModel.objects.create(name1='parent name')
associate = AssociatedModel.objects.create(name='hello', ref=parent)
serializer = AssociatedModelSerializer(associate)
assert set(serializer.data) == {'name', 'ref'}
parent = ParentModel.objects.create(name1='parent name')
associate = AssociatedModel.objects.create(name='hello', ref=parent)
serializer = AssociatedModelSerializer(associate)
assert set(serializer.data) == {'name', 'ref'}
def test_data_is_valid_without_parent_ptr(self):
"""
"""
Assert that the pointer to the parent table is not a required field
for input data
"""
data = {
'name1': 'parent name',
'name2': 'child name',
}
serializer = DerivedModelSerializer(data=data)
assert serializer.is_valid() is True
data = { 'name1': 'parent name', 'name2': 'child name', }
serializer = DerivedModelSerializer(data=data)
assert serializer.is_valid() is True

View File

@ -30,70 +30,68 @@ class NoCharsetSpecifiedRenderer(BaseRenderer):
media_type = 'my/media'
class TestAcceptedMediaType(TestCase):
def setUp(self):
self.renderers = [MockJSONRenderer(), MockHTMLRenderer(), MockOpenAPIRenderer()]
self.negotiator = DefaultContentNegotiation()
def setUp(self):
self.renderers = [MockJSONRenderer(), MockHTMLRenderer(), MockOpenAPIRenderer()]
self.negotiator = DefaultContentNegotiation()
def select_renderer(self, request):
return self.negotiator.select_renderer(request, self.renderers)
return self.negotiator.select_renderer(request, self.renderers)
def test_client_without_accept_use_renderer(self):
request = Request(factory.get('/'))
accepted_renderer, accepted_media_type = self.select_renderer(request)
assert accepted_media_type == 'application/json'
request = Request(factory.get('/'))
accepted_renderer, accepted_media_type = self.select_renderer(request)
assert accepted_media_type == 'application/json'
def test_client_underspecifies_accept_use_renderer(self):
request = Request(factory.get('/', HTTP_ACCEPT='*/*'))
accepted_renderer, accepted_media_type = self.select_renderer(request)
assert accepted_media_type == 'application/json'
request = Request(factory.get('/', HTTP_ACCEPT='*/*'))
accepted_renderer, accepted_media_type = self.select_renderer(request)
assert accepted_media_type == 'application/json'
def test_client_overspecifies_accept_use_client(self):
request = Request(factory.get('/', HTTP_ACCEPT='application/json; indent=8'))
accepted_renderer, accepted_media_type = self.select_renderer(request)
assert accepted_media_type == 'application/json; indent=8'
request = Request(factory.get('/', HTTP_ACCEPT='application/json; indent=8'))
accepted_renderer, accepted_media_type = self.select_renderer(request)
assert accepted_media_type == 'application/json; indent=8'
def test_client_specifies_parameter(self):
request = Request(factory.get('/', HTTP_ACCEPT='application/openapi+json;version=2.0'))
accepted_renderer, accepted_media_type = self.select_renderer(request)
assert accepted_media_type == 'application/openapi+json;version=2.0'
assert accepted_renderer.format == 'swagger'
request = Request(factory.get('/', HTTP_ACCEPT='application/openapi+json;version=2.0'))
accepted_renderer, accepted_media_type = self.select_renderer(request)
assert accepted_media_type == 'application/openapi+json;version=2.0'
assert accepted_renderer.format == 'swagger'
def test_match_is_false_if_main_types_not_match(self):
mediatype = _MediaType('test_1')
anoter_mediatype = _MediaType('test_2')
assert mediatype.match(anoter_mediatype) is False
mediatype = _MediaType('test_1')
anoter_mediatype = _MediaType('test_2')
assert mediatype.match(anoter_mediatype) is False
def test_mediatype_match_is_false_if_keys_not_match(self):
mediatype = _MediaType(';test_param=foo')
another_mediatype = _MediaType(';test_param=bar')
assert mediatype.match(another_mediatype) is False
mediatype = _MediaType(';test_param=foo')
another_mediatype = _MediaType(';test_param=bar')
assert mediatype.match(another_mediatype) is False
def test_mediatype_precedence_with_wildcard_subtype(self):
mediatype = _MediaType('test/*')
assert mediatype.precedence == 1
mediatype = _MediaType('test/*')
assert mediatype.precedence == 1
def test_mediatype_string_representation(self):
mediatype = _MediaType('test/*; foo=bar')
assert str(mediatype) == 'test/*; foo=bar'
mediatype = _MediaType('test/*; foo=bar')
assert str(mediatype) == 'test/*; foo=bar'
def test_raise_error_if_no_suitable_renderers_found(self):
class MockRenderer:
format = 'xml'
class MockRenderer:
format = 'xml'
renderers = [MockRenderer()]
with pytest.raises(Http404):
self.negotiator.filter_renderers(renderers, format='json')
withpytest.raises(Http404):
self.negotiator.filter_renderers(renderers, format='json')
class BaseContentNegotiationTests(TestCase):
def setUp(self):
self.negotiator = BaseContentNegotiation()
def setUp(self):
self.negotiator = BaseContentNegotiation()
def test_raise_error_for_abstract_select_parser_method(self):
with pytest.raises(NotImplementedError):
self.negotiator.select_parser(None, None)
with pytest.raises(NotImplementedError):
self.negotiator.select_parser(None, None)
def test_raise_error_for_abstract_select_renderer_method(self):
with pytest.raises(NotImplementedError):
self.negotiator.select_renderer(None, None)
with pytest.raises(NotImplementedError):
self.negotiator.select_renderer(None, None)

View File

@ -28,14 +28,12 @@ class ChildAssociatedModelSerializer(serializers.ModelSerializer):
# Tests
class InheritedModelSerializationTests(TestCase):
def test_multitable_inherited_model_fields_as_expected(self):
"""
def test_multitable_inherited_model_fields_as_expected(self):
"""
Assert that the parent pointer field is not included in the fields
serialized fields
"""
child = ChildModel(name1='parent name', name2='child name')
serializer = DerivedModelSerializer(child)
self.assertEqual(set(serializer.data),
{'name1', 'name2', 'id', 'childassociatedmodel'})
child = ChildModel(name1='parent name', name2='child name')
serializer = DerivedModelSerializer(child)
assert set(serializer.data) == {'name1', 'name2', 'id', 'childassociatedmodel'}

View File

@ -22,157 +22,138 @@ class Form(forms.Form):
field2 = forms.CharField()
class TestFormParser(TestCase):
def setUp(self):
self.string = "field1=abc&field2=defghijk"
def setUp(self):
self.string = "field1=abc&field2=defghijk"
def test_parse(self):
""" Make sure the `QueryDict` works OK """
parser = FormParser()
stream = io.StringIO(self.string)
data = parser.parse(stream)
assert Form(data).is_valid() is True
""" Make sure the `QueryDict` works OK """
parser = FormParser()
stream = io.StringIO(self.string)
data = parser.parse(stream)
assert Form(data).is_valid() is True
class TestFileUploadParser(TestCase):
def setUp(self):
class MockRequest:
pass
def setUp(self):
class MockRequest:
pass
self.stream = io.BytesIO(b"Test text file")
request = MockRequest()
request.upload_handlers = (MemoryFileUploadHandler(),)
request.META = {
'HTTP_CONTENT_DISPOSITION': 'Content-Disposition: inline; filename=file.txt',
'HTTP_CONTENT_LENGTH': 14,
}
self.parser_context = {'request': request, 'kwargs': {}}
request=MockRequest()
request.upload_handlers=(MemoryFileUploadHandler(),)
request.META={'HTTP_CONTENT_DISPOSITION':'Content-Disposition: inline; filename=file.txt','HTTP_CONTENT_LENGTH':14,}
self.parser_context={'request':request,'kwargs':{}}
def test_parse(self):
"""
"""
Parse raw file upload.
"""
parser = FileUploadParser()
self.stream.seek(0)
data_and_files = parser.parse(self.stream, None, self.parser_context)
file_obj = data_and_files.files['file']
assert file_obj.size == 14
parser = FileUploadParser()
self.stream.seek(0)
data_and_files = parser.parse(self.stream, None, self.parser_context)
file_obj = data_and_files.files['file']
assert file_obj.size == 14
def test_parse_missing_filename(self):
"""
"""
Parse raw file upload when filename is missing.
"""
parser = FileUploadParser()
self.stream.seek(0)
self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = ''
with pytest.raises(ParseError) as excinfo:
parser.parse(self.stream, None, self.parser_context)
parser = FileUploadParser()
self.stream.seek(0)
self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = ''
with pytest.raises(ParseError) as excinfo:
parser.parse(self.stream, None, self.parser_context)
assert str(excinfo.value) == 'Missing filename. Request should include a Content-Disposition header with a filename parameter.'
def test_parse_missing_filename_multiple_upload_handlers(self):
"""
"""
Parse raw file upload with multiple handlers when filename is missing.
Regression test for #2109.
"""
parser = FileUploadParser()
self.stream.seek(0)
self.parser_context['request'].upload_handlers = (
MemoryFileUploadHandler(),
MemoryFileUploadHandler()
)
self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = ''
with pytest.raises(ParseError) as excinfo:
parser.parse(self.stream, None, self.parser_context)
parser = FileUploadParser()
self.stream.seek(0)
self.parser_context['request'].upload_handlers = ( MemoryFileUploadHandler(), MemoryFileUploadHandler() )
self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = ''
with pytest.raises(ParseError) as excinfo:
parser.parse(self.stream, None, self.parser_context)
assert str(excinfo.value) == 'Missing filename. Request should include a Content-Disposition header with a filename parameter.'
def test_parse_missing_filename_large_file(self):
"""
"""
Parse raw file upload when filename is missing with TemporaryFileUploadHandler.
"""
parser = FileUploadParser()
self.stream.seek(0)
self.parser_context['request'].upload_handlers = (
TemporaryFileUploadHandler(),
)
self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = ''
with pytest.raises(ParseError) as excinfo:
parser.parse(self.stream, None, self.parser_context)
parser = FileUploadParser()
self.stream.seek(0)
self.parser_context['request'].upload_handlers = ( TemporaryFileUploadHandler(), )
self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = ''
with pytest.raises(ParseError) as excinfo:
parser.parse(self.stream, None, self.parser_context)
assert str(excinfo.value) == 'Missing filename. Request should include a Content-Disposition header with a filename parameter.'
def test_get_filename(self):
parser = FileUploadParser()
filename = parser.get_filename(self.stream, None, self.parser_context)
assert filename == 'file.txt'
parser = FileUploadParser()
filename = parser.get_filename(self.stream, None, self.parser_context)
assert filename == 'file.txt'
def test_get_encoded_filename(self):
parser = FileUploadParser()
self.__replace_content_disposition('inline; filename*=utf-8\'\'ÀĥƦ.txt')
filename = parser.get_filename(self.stream, None, self.parser_context)
assert filename == 'ÀĥƦ.txt'
self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'\'ÀĥƦ.txt')
filename = parser.get_filename(self.stream, None, self.parser_context)
assert filename == 'ÀĥƦ.txt'
self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'en-us\'ÀĥƦ.txt')
filename = parser.get_filename(self.stream, None, self.parser_context)
assert filename == 'ÀĥƦ.txt'
parser = FileUploadParser()
self.__replace_content_disposition('inline; filename*=utf-8\'\'ÀĥƦ.txt')
filename = parser.get_filename(self.stream, None, self.parser_context)
assert filename == 'ÀĥƦ.txt'
self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'\'ÀĥƦ.txt')
filename = parser.get_filename(self.stream, None, self.parser_context)
assert filename == 'ÀĥƦ.txt'
self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'en-us\'ÀĥƦ.txt')
filename = parser.get_filename(self.stream, None, self.parser_context)
assert filename == 'ÀĥƦ.txt'
def __replace_content_disposition(self, disposition):
self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = disposition
self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = disposition
class TestJSONParser(TestCase):
def bytes(self, value):
return io.BytesIO(value.encode())
def bytes(self, value):
return io.BytesIO(value.encode())
def test_float_strictness(self):
parser = JSONParser()
# Default to strict
for value in ['Infinity', '-Infinity', 'NaN']:
with pytest.raises(ParseError):
parser.parse(self.bytes(value))
parser = JSONParser()
for value in ['Infinity', '-Infinity', 'NaN']:
with pytest.raises(ParseError):
parser.parse(self.bytes(value))
parser.strict = False
assert parser.parse(self.bytes('Infinity')) == float('inf')
assert parser.parse(self.bytes('-Infinity')) == float('-inf')
assert math.isnan(parser.parse(self.bytes('NaN')))
assertparser.parse(self.bytes('Infinity'))==float('inf')
assertparser.parse(self.bytes('-Infinity'))==float('-inf')
assertmath.isnan(parser.parse(self.bytes('NaN')))
class TestPOSTAccessed(TestCase):
def setUp(self):
self.factory = APIRequestFactory()
def setUp(self):
self.factory = APIRequestFactory()
def test_post_accessed_in_post_method(self):
django_request = self.factory.post('/', {'foo': 'bar'})
request = Request(django_request, parsers=[FormParser(), MultiPartParser()])
django_request.POST
assert request.POST == {'foo': ['bar']}
assert request.data == {'foo': ['bar']}
django_request = self.factory.post('/', {'foo': 'bar'})
request = Request(django_request, parsers=[FormParser(), MultiPartParser()])
django_request.POST
assert request.POST == {'foo': ['bar']}
assert request.data == {'foo': ['bar']}
def test_post_accessed_in_post_method_with_json_parser(self):
django_request = self.factory.post('/', {'foo': 'bar'})
request = Request(django_request, parsers=[JSONParser()])
django_request.POST
assert request.POST == {}
assert request.data == {}
django_request = self.factory.post('/', {'foo': 'bar'})
request = Request(django_request, parsers=[JSONParser()])
django_request.POST
assert request.POST == {}
assert request.data == {}
def test_post_accessed_in_put_method(self):
django_request = self.factory.put('/', {'foo': 'bar'})
request = Request(django_request, parsers=[FormParser(), MultiPartParser()])
django_request.POST
assert request.POST == {'foo': ['bar']}
assert request.data == {'foo': ['bar']}
django_request = self.factory.put('/', {'foo': 'bar'})
request = Request(django_request, parsers=[FormParser(), MultiPartParser()])
django_request.POST
assert request.POST == {'foo': ['bar']}
assert request.data == {'foo': ['bar']}
def test_request_read_before_parsing(self):
django_request = self.factory.put('/', {'foo': 'bar'})
request = Request(django_request, parsers=[FormParser(), MultiPartParser()])
django_request.read()
django_request = self.factory.put('/', {'foo': 'bar'})
request = Request(django_request, parsers=[FormParser(), MultiPartParser()])
django_request.read()
with pytest.raises(RawPostDataException):
request.POST
with pytest.raises(RawPostDataException):
request.POST
with pytest.raises(RawPostDataException):
request.POST
request.data
request.POST
request.data

View File

@ -72,177 +72,136 @@ def basic_auth_header(username, password):
return 'Basic %s' % base64_credentials
class ModelPermissionsIntegrationTests(TestCase):
def setUp(self):
User.objects.create_user('disallowed', 'disallowed@example.com', 'password')
user = User.objects.create_user('permitted', 'permitted@example.com', 'password')
user.user_permissions.set([
Permission.objects.get(codename='add_basicmodel'),
Permission.objects.get(codename='change_basicmodel'),
Permission.objects.get(codename='delete_basicmodel')
])
user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password')
user.user_permissions.set([
Permission.objects.get(codename='change_basicmodel'),
])
self.permitted_credentials = basic_auth_header('permitted', 'password')
self.disallowed_credentials = basic_auth_header('disallowed', 'password')
self.updateonly_credentials = basic_auth_header('updateonly', 'password')
BasicModel(text='foo').save()
def setUp(self):
User.objects.create_user('disallowed', 'disallowed@example.com', 'password')
user = User.objects.create_user('permitted', 'permitted@example.com', 'password')
user.user_permissions.set([ Permission.objects.get(codename='add_basicmodel'), Permission.objects.get(codename='change_basicmodel'), Permission.objects.get(codename='delete_basicmodel') ])
user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password')
user.user_permissions.set([ Permission.objects.get(codename='change_basicmodel'), ])
self.permitted_credentials = basic_auth_header('permitted', 'password')
self.disallowed_credentials = basic_auth_header('disallowed', 'password')
self.updateonly_credentials = basic_auth_header('updateonly', 'password')
BasicModel(text='foo').save()
def test_has_create_permissions(self):
request = factory.post('/', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.permitted_credentials)
response = root_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
request = factory.post('/', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.permitted_credentials)
response = root_view(request, pk=1)
assert response.status_code == status.HTTP_201_CREATED
def test_api_root_view_discard_default_django_model_permission(self):
"""
"""
We check that DEFAULT_PERMISSION_CLASSES can
apply to APIRoot view. More specifically we check expected behavior of
``_ignore_model_permissions`` attribute support.
"""
request = factory.get('/', format='json',
HTTP_AUTHORIZATION=self.permitted_credentials)
request.resolver_match = ResolverMatch('get', (), {})
response = api_root_view(request)
self.assertEqual(response.status_code, status.HTTP_200_OK)
request = factory.get('/', format='json', HTTP_AUTHORIZATION=self.permitted_credentials)
request.resolver_match = ResolverMatch('get', (), {})
response = api_root_view(request)
assert response.status_code == status.HTTP_200_OK
def test_get_queryset_has_create_permissions(self):
request = factory.post('/', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.permitted_credentials)
response = get_queryset_list_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
request = factory.post('/', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.permitted_credentials)
response = get_queryset_list_view(request, pk=1)
assert response.status_code == status.HTTP_201_CREATED
def test_has_put_permissions(self):
request = factory.put('/1', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.permitted_credentials)
response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
request = factory.put('/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.permitted_credentials)
response = instance_view(request, pk='1')
assert response.status_code == status.HTTP_200_OK
def test_has_delete_permissions(self):
request = factory.delete('/1', HTTP_AUTHORIZATION=self.permitted_credentials)
response = instance_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
request = factory.delete('/1', HTTP_AUTHORIZATION=self.permitted_credentials)
response = instance_view(request, pk=1)
assert response.status_code == status.HTTP_204_NO_CONTENT
def test_does_not_have_create_permissions(self):
request = factory.post('/', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.disallowed_credentials)
response = root_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
request = factory.post('/', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.disallowed_credentials)
response = root_view(request, pk=1)
assert response.status_code == status.HTTP_403_FORBIDDEN
def test_does_not_have_put_permissions(self):
request = factory.put('/1', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.disallowed_credentials)
response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
request = factory.put('/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.disallowed_credentials)
response = instance_view(request, pk='1')
assert response.status_code == status.HTTP_403_FORBIDDEN
def test_does_not_have_delete_permissions(self):
request = factory.delete('/1', HTTP_AUTHORIZATION=self.disallowed_credentials)
response = instance_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
request = factory.delete('/1', HTTP_AUTHORIZATION=self.disallowed_credentials)
response = instance_view(request, pk=1)
assert response.status_code == status.HTTP_403_FORBIDDEN
def test_options_permitted(self):
request = factory.options(
'/',
HTTP_AUTHORIZATION=self.permitted_credentials
)
response = root_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn('actions', response.data)
self.assertEqual(list(response.data['actions']), ['POST'])
request = factory.options(
'/1',
HTTP_AUTHORIZATION=self.permitted_credentials
)
response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn('actions', response.data)
self.assertEqual(list(response.data['actions']), ['PUT'])
request = factory.options( '/', HTTP_AUTHORIZATION=self.permitted_credentials )
response = root_view(request, pk='1')
assert response.status_code == status.HTTP_200_OK
assert 'actions' in response.data
assert list(response.data['actions']) == ['POST']
request = factory.options( '/1', HTTP_AUTHORIZATION=self.permitted_credentials )
response = instance_view(request, pk='1')
assert response.status_code == status.HTTP_200_OK
assert 'actions' in response.data
assert list(response.data['actions']) == ['PUT']
def test_options_disallowed(self):
request = factory.options(
'/',
HTTP_AUTHORIZATION=self.disallowed_credentials
)
response = root_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertNotIn('actions', response.data)
request = factory.options(
'/1',
HTTP_AUTHORIZATION=self.disallowed_credentials
)
response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertNotIn('actions', response.data)
request = factory.options( '/', HTTP_AUTHORIZATION=self.disallowed_credentials )
response = root_view(request, pk='1')
assert response.status_code == status.HTTP_200_OK
assert 'actions' not in response.data
request = factory.options( '/1', HTTP_AUTHORIZATION=self.disallowed_credentials )
response = instance_view(request, pk='1')
assert response.status_code == status.HTTP_200_OK
assert 'actions' not in response.data
def test_options_updateonly(self):
request = factory.options(
'/',
HTTP_AUTHORIZATION=self.updateonly_credentials
)
response = root_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertNotIn('actions', response.data)
request = factory.options(
'/1',
HTTP_AUTHORIZATION=self.updateonly_credentials
)
response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn('actions', response.data)
self.assertEqual(list(response.data['actions']), ['PUT'])
request = factory.options( '/', HTTP_AUTHORIZATION=self.updateonly_credentials )
response = root_view(request, pk='1')
assert response.status_code == status.HTTP_200_OK
assert 'actions' not in response.data
request = factory.options( '/1', HTTP_AUTHORIZATION=self.updateonly_credentials )
response = instance_view(request, pk='1')
assert response.status_code == status.HTTP_200_OK
assert 'actions' in response.data
assert list(response.data['actions']) == ['PUT']
def test_empty_view_does_not_assert(self):
request = factory.get('/1', HTTP_AUTHORIZATION=self.permitted_credentials)
response = empty_list_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_200_OK)
request = factory.get('/1', HTTP_AUTHORIZATION=self.permitted_credentials)
response = empty_list_view(request, pk=1)
assert response.status_code == status.HTTP_200_OK
def test_calling_method_not_allowed(self):
request = factory.generic('METHOD_NOT_ALLOWED', '/', HTTP_AUTHORIZATION=self.permitted_credentials)
response = root_view(request)
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
request = factory.generic('METHOD_NOT_ALLOWED', '/1', HTTP_AUTHORIZATION=self.permitted_credentials)
response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
request = factory.generic('METHOD_NOT_ALLOWED', '/', HTTP_AUTHORIZATION=self.permitted_credentials)
response = root_view(request)
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
request = factory.generic('METHOD_NOT_ALLOWED', '/1', HTTP_AUTHORIZATION=self.permitted_credentials)
response = instance_view(request, pk='1')
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
def test_check_auth_before_queryset_call(self):
class View(RootView):
def get_queryset(_):
self.fail('should not reach due to auth check')
class View(RootView):
def get_queryset(_):
self.fail('should not reach due to auth check')
view = View.as_view()
request = factory.get('/', HTTP_AUTHORIZATION='')
response = view(request)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
request=factory.get('/',HTTP_AUTHORIZATION='')
response=view(request)
assertresponse.status_code==status.HTTP_401_UNAUTHORIZED
def test_queryset_assertions(self):
class View(views.APIView):
authentication_classes = [authentication.BasicAuthentication]
permission_classes = [permissions.DjangoModelPermissions]
class View(views.APIView):
authentication_classes = [authentication.BasicAuthentication]
permission_classes = [permissions.DjangoModelPermissions]
view = View.as_view()
request = factory.get('/', HTTP_AUTHORIZATION=self.permitted_credentials)
msg = 'Cannot apply DjangoModelPermissions on a view that does not set `.queryset` or have a `.get_queryset()` method.'
with self.assertRaisesMessage(AssertionError, msg):
view(request)
request=factory.get('/',HTTP_AUTHORIZATION=self.permitted_credentials)
msg='Cannot apply DjangoModelPermissions on a view that does not set `.queryset` or have a `.get_queryset()` method.'
withself.assertRaisesMessage(AssertionError,msg):
view(request)
# Faulty `get_queryset()` methods should trigger the above "view does not have a queryset" assertion.
class View(RootView):
def get_queryset(self):
return None
def get_queryset(self):
return None
view = View.as_view()
request = factory.get('/', HTTP_AUTHORIZATION=self.permitted_credentials)
with self.assertRaisesMessage(AssertionError, 'View.get_queryset() returned None'):
view(request)
request=factory.get('/',HTTP_AUTHORIZATION=self.permitted_credentials)
withself.assertRaisesMessage(AssertionError,'View.get_queryset() returned None'):
view(request)
class BasicPermModel(models.Model):
@ -310,149 +269,117 @@ get_queryset_object_permissions_view = GetQuerysetObjectPermissionInstanceView.a
@unittest.skipUnless(is_guardian_installed(), 'django-guardian not installed')
class ObjectPermissionsIntegrationTests(TestCase):
"""
"""
Integration tests for the object level permissions API.
"""
def setUp(self):
from guardian.shortcuts import assign_perm
# create users
create = User.objects.create_user
users = {
'fullaccess': create('fullaccess', 'fullaccess@example.com', 'password'),
'readonly': create('readonly', 'readonly@example.com', 'password'),
'writeonly': create('writeonly', 'writeonly@example.com', 'password'),
'deleteonly': create('deleteonly', 'deleteonly@example.com', 'password'),
}
# give everyone model level permissions, as we are not testing those
everyone = Group.objects.create(name='everyone')
model_name = BasicPermModel._meta.model_name
app_label = BasicPermModel._meta.app_label
f = '{}_{}'.format
perms = {
'view': f('view', model_name),
'change': f('change', model_name),
'delete': f('delete', model_name)
}
for perm in perms.values():
perm = '{}.{}'.format(app_label, perm)
assign_perm(perm, everyone)
defsetUp(self):
from guardian.shortcuts import assign_perm
create = User.objects.create_user
users = { 'fullaccess': create('fullaccess', 'fullaccess@example.com', 'password'), 'readonly': create('readonly', 'readonly@example.com', 'password'), 'writeonly': create('writeonly', 'writeonly@example.com', 'password'), 'deleteonly': create('deleteonly', 'deleteonly@example.com', 'password'), }
everyone = Group.objects.create(name='everyone')
model_name = BasicPermModel._meta.model_name
app_label = BasicPermModel._meta.app_label
f = '{}_{}'.format
perms = { 'view': f('view', model_name), 'change': f('change', model_name), 'delete': f('delete', model_name) }
for perm in perms.values():
perm = '{}.{}'.format(app_label, perm)
assign_perm(perm, everyone)
everyone.user_set.add(*users.values())
# appropriate object level permissions
readers = Group.objects.create(name='readers')
writers = Group.objects.create(name='writers')
deleters = Group.objects.create(name='deleters')
model = BasicPermModel.objects.create(text='foo')
assign_perm(perms['view'], readers, model)
assign_perm(perms['change'], writers, model)
assign_perm(perms['delete'], deleters, model)
readers.user_set.add(users['fullaccess'], users['readonly'])
writers.user_set.add(users['fullaccess'], users['writeonly'])
deleters.user_set.add(users['fullaccess'], users['deleteonly'])
self.credentials = {}
for user in users.values():
self.credentials[user.username] = basic_auth_header(user.username, 'password')
readers=Group.objects.create(name='readers')
writers=Group.objects.create(name='writers')
deleters=Group.objects.create(name='deleters')
model=BasicPermModel.objects.create(text='foo')
assign_perm(perms['view'],readers,model)
assign_perm(perms['change'],writers,model)
assign_perm(perms['delete'],deleters,model)
readers.user_set.add(users['fullaccess'],users['readonly'])
writers.user_set.add(users['fullaccess'],users['writeonly'])
deleters.user_set.add(users['fullaccess'],users['deleteonly'])
self.credentials={}
foruserinusers.values():
self.credentials[user.username] = basic_auth_header(user.username, 'password')
# Delete
def test_can_delete_permissions(self):
request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['deleteonly'])
response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['deleteonly'])
response = object_permissions_view(request, pk='1')
assert response.status_code == status.HTTP_204_NO_CONTENT
def test_cannot_delete_permissions(self):
request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
response = object_permissions_view(request, pk='1')
assert response.status_code == status.HTTP_403_FORBIDDEN
# Update
def test_can_update_permissions(self):
request = factory.patch(
'/1', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.credentials['writeonly']
)
response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data.get('text'), 'foobar')
request = factory.patch( '/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.credentials['writeonly'] )
response = object_permissions_view(request, pk='1')
assert response.status_code == status.HTTP_200_OK
assert response.data.get('text') == 'foobar'
def test_cannot_update_permissions(self):
request = factory.patch(
'/1', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.credentials['deleteonly']
)
response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
request = factory.patch( '/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.credentials['deleteonly'] )
response = object_permissions_view(request, pk='1')
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_cannot_update_permissions_non_existing(self):
request = factory.patch(
'/999', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.credentials['deleteonly']
)
response = object_permissions_view(request, pk='999')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
request = factory.patch( '/999', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.credentials['deleteonly'] )
response = object_permissions_view(request, pk='999')
assert response.status_code == status.HTTP_404_NOT_FOUND
# Read
def test_can_read_permissions(self):
request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
response = object_permissions_view(request, pk='1')
assert response.status_code == status.HTTP_200_OK
def test_cannot_read_permissions(self):
request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['writeonly'])
response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['writeonly'])
response = object_permissions_view(request, pk='1')
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_can_read_get_queryset_permissions(self):
"""
"""
same as ``test_can_read_permissions`` but with a view
that rely on ``.get_queryset()`` instead of ``.queryset``.
"""
request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
response = get_queryset_object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
response = get_queryset_object_permissions_view(request, pk='1')
assert response.status_code == status.HTTP_200_OK
# Read list
def test_django_object_permissions_filter_deprecated(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
DjangoObjectPermissionsFilter()
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
DjangoObjectPermissionsFilter()
message = ("`DjangoObjectPermissionsFilter` has been deprecated and moved "
"to the 3rd-party django-rest-framework-guardian package.")
self.assertEqual(len(w), 1)
self.assertIs(w[-1].category, RemovedInDRF310Warning)
self.assertEqual(str(w[-1].message), message)
message = ("`DjangoObjectPermissionsFilter` has been deprecated and moved ""to the 3rd-party django-rest-framework-guardian package.")
assertlen(w)==1
assertw[-1].categoryisRemovedInDRF310Warning
assertstr(w[-1].message)==message
def test_can_read_list_permissions(self):
request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['readonly'])
object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,)
# TODO: remove in version 3.10
with warnings.catch_warnings(record=True):
warnings.simplefilter("always")
response = object_permissions_list_view(request)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data[0].get('id'), 1)
request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['readonly'])
object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,)
with warnings.catch_warnings(record=True):
warnings.simplefilter("always")
response = object_permissions_list_view(request)
assert response.status_code== status.HTTP_200_OK
assertresponse.data[0].get('id')==1
def test_cannot_read_list_permissions(self):
request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['writeonly'])
object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,)
# TODO: remove in version 3.10
with warnings.catch_warnings(record=True):
warnings.simplefilter("always")
response = object_permissions_list_view(request)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertListEqual(response.data, [])
request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['writeonly'])
object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,)
with warnings.catch_warnings(record=True):
warnings.simplefilter("always")
response = object_permissions_list_view(request)
assert response.status_code== status.HTTP_200_OK
assertresponse.data==[]
def test_cannot_method_not_allowed(self):
request = factory.generic('METHOD_NOT_ALLOWED', '/', HTTP_AUTHORIZATION=self.credentials['readonly'])
response = object_permissions_list_view(request)
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
request = factory.generic('METHOD_NOT_ALLOWED', '/', HTTP_AUTHORIZATION=self.credentials['readonly'])
response = object_permissions_list_view(request)
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
class BasicPerm(permissions.BasePermission):
@ -507,203 +434,176 @@ denied_view_with_detail = DeniedViewWithDetail.as_view()
denied_object_view = DeniedObjectView.as_view()
denied_object_view_with_detail = DeniedObjectViewWithDetail.as_view()
class CustomPermissionsTests(TestCase):
def setUp(self):
BasicModel(text='foo').save()
User.objects.create_user('username', 'username@example.com', 'password')
credentials = basic_auth_header('username', 'password')
self.request = factory.get('/1', format='json', HTTP_AUTHORIZATION=credentials)
self.custom_message = 'Custom: You cannot access this resource'
def setUp(self):
BasicModel(text='foo').save()
User.objects.create_user('username', 'username@example.com', 'password')
credentials = basic_auth_header('username', 'password')
self.request = factory.get('/1', format='json', HTTP_AUTHORIZATION=credentials)
self.custom_message = 'Custom: You cannot access this resource'
def test_permission_denied(self):
response = denied_view(self.request, pk=1)
detail = response.data.get('detail')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertNotEqual(detail, self.custom_message)
response = denied_view(self.request, pk=1)
detail = response.data.get('detail')
assert response.status_code == status.HTTP_403_FORBIDDEN
assert detail != self.custom_message
def test_permission_denied_with_custom_detail(self):
response = denied_view_with_detail(self.request, pk=1)
detail = response.data.get('detail')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(detail, self.custom_message)
response = denied_view_with_detail(self.request, pk=1)
detail = response.data.get('detail')
assert response.status_code == status.HTTP_403_FORBIDDEN
assert detail == self.custom_message
def test_permission_denied_for_object(self):
response = denied_object_view(self.request, pk=1)
detail = response.data.get('detail')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertNotEqual(detail, self.custom_message)
response = denied_object_view(self.request, pk=1)
detail = response.data.get('detail')
assert response.status_code == status.HTTP_403_FORBIDDEN
assert detail != self.custom_message
def test_permission_denied_for_object_with_custom_detail(self):
response = denied_object_view_with_detail(self.request, pk=1)
detail = response.data.get('detail')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(detail, self.custom_message)
response = denied_object_view_with_detail(self.request, pk=1)
detail = response.data.get('detail')
assert response.status_code == status.HTTP_403_FORBIDDEN
assert detail == self.custom_message
class PermissionsCompositionTests(TestCase):
def setUp(self):
self.username = 'john'
self.email = 'lennon@thebeatles.com'
self.password = 'password'
self.user = User.objects.create_user(
self.username,
self.email,
self.password
)
self.client.login(username=self.username, password=self.password)
def setUp(self):
self.username = 'john'
self.email = 'lennon@thebeatles.com'
self.password = 'password'
self.user = User.objects.create_user( self.username, self.email, self.password )
self.client.login(username=self.username, password=self.password)
def test_and_false(self):
request = factory.get('/1', format='json')
request.user = AnonymousUser()
composed_perm = permissions.IsAuthenticated & permissions.AllowAny
assert composed_perm().has_permission(request, None) is False
request = factory.get('/1', format='json')
request.user = AnonymousUser()
composed_perm = permissions.IsAuthenticated & permissions.AllowAny
assert composed_perm().has_permission(request, None) is False
def test_and_true(self):
request = factory.get('/1', format='json')
request.user = self.user
composed_perm = permissions.IsAuthenticated & permissions.AllowAny
assert composed_perm().has_permission(request, None) is True
request = factory.get('/1', format='json')
request.user = self.user
composed_perm = permissions.IsAuthenticated & permissions.AllowAny
assert composed_perm().has_permission(request, None) is True
def test_or_false(self):
request = factory.get('/1', format='json')
request.user = AnonymousUser()
composed_perm = permissions.IsAuthenticated | permissions.AllowAny
assert composed_perm().has_permission(request, None) is True
request = factory.get('/1', format='json')
request.user = AnonymousUser()
composed_perm = permissions.IsAuthenticated | permissions.AllowAny
assert composed_perm().has_permission(request, None) is True
def test_or_true(self):
request = factory.get('/1', format='json')
request.user = self.user
composed_perm = permissions.IsAuthenticated | permissions.AllowAny
assert composed_perm().has_permission(request, None) is True
request = factory.get('/1', format='json')
request.user = self.user
composed_perm = permissions.IsAuthenticated | permissions.AllowAny
assert composed_perm().has_permission(request, None) is True
def test_not_false(self):
request = factory.get('/1', format='json')
request.user = AnonymousUser()
composed_perm = ~permissions.IsAuthenticated
assert composed_perm().has_permission(request, None) is True
request = factory.get('/1', format='json')
request.user = AnonymousUser()
composed_perm = ~permissions.IsAuthenticated
assert composed_perm().has_permission(request, None) is True
def test_not_true(self):
request = factory.get('/1', format='json')
request.user = self.user
composed_perm = ~permissions.AllowAny
assert composed_perm().has_permission(request, None) is False
request = factory.get('/1', format='json')
request.user = self.user
composed_perm = ~permissions.AllowAny
assert composed_perm().has_permission(request, None) is False
def test_several_levels_without_negation(self):
request = factory.get('/1', format='json')
request.user = self.user
composed_perm = (
permissions.IsAuthenticated &
permissions.IsAuthenticated &
permissions.IsAuthenticated &
permissions.IsAuthenticated
)
assert composed_perm().has_permission(request, None) is True
request = factory.get('/1', format='json')
request.user = self.user
composed_perm = ( permissions.IsAuthenticated & permissions.IsAuthenticated & permissions.IsAuthenticated & permissions.IsAuthenticated )
assert composed_perm().has_permission(request, None) is True
def test_several_levels_and_precedence_with_negation(self):
request = factory.get('/1', format='json')
request.user = self.user
composed_perm = (
permissions.IsAuthenticated &
~ permissions.IsAdminUser &
permissions.IsAuthenticated &
~(permissions.IsAdminUser & permissions.IsAdminUser)
)
assert composed_perm().has_permission(request, None) is True
request = factory.get('/1', format='json')
request.user = self.user
composed_perm = ( permissions.IsAuthenticated & ~ permissions.IsAdminUser & permissions.IsAuthenticated & ~(permissions.IsAdminUser & permissions.IsAdminUser) )
assert composed_perm().has_permission(request, None) is True
def test_several_levels_and_precedence(self):
request = factory.get('/1', format='json')
request.user = self.user
composed_perm = (
permissions.IsAuthenticated &
permissions.IsAuthenticated |
permissions.IsAuthenticated &
permissions.IsAuthenticated
)
assert composed_perm().has_permission(request, None) is True
request = factory.get('/1', format='json')
request.user = self.user
composed_perm = ( permissions.IsAuthenticated & permissions.IsAuthenticated | permissions.IsAuthenticated & permissions.IsAuthenticated )
assert composed_perm().has_permission(request, None) is True
@pytest.mark.skipif(not PY36, reason="assert_called_once() not available")
def test_or_lazyness(self):
request = factory.get('/1', format='json')
request.user = AnonymousUser()
deftest_or_lazyness(self):
request = factory.get('/1', format='json')
request.user = AnonymousUser()
with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
composed_perm = (permissions.AllowAny | permissions.IsAuthenticated)
hasperm = composed_perm().has_permission(request, None)
assert hasperm is True
mock_allow.assert_called_once()
mock_deny.assert_not_called()
with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
composed_perm = (permissions.AllowAny | permissions.IsAuthenticated)
hasperm = composed_perm().has_permission(request, None)
self.assertIs(hasperm, True)
mock_allow.assert_called_once()
mock_deny.assert_not_called()
with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
composed_perm = (permissions.IsAuthenticated | permissions.AllowAny)
hasperm = composed_perm().has_permission(request, None)
self.assertIs(hasperm, True)
mock_deny.assert_called_once()
mock_allow.assert_called_once()
with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
composed_perm = (permissions.IsAuthenticated | permissions.AllowAny)
hasperm = composed_perm().has_permission(request, None)
assert hasperm is True
mock_deny.assert_called_once()
mock_allow.assert_called_once()
@pytest.mark.skipif(not PY36, reason="assert_called_once() not available")
def test_object_or_lazyness(self):
request = factory.get('/1', format='json')
request.user = AnonymousUser()
deftest_object_or_lazyness(self):
request = factory.get('/1', format='json')
request.user = AnonymousUser()
with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
composed_perm = (permissions.AllowAny | permissions.IsAuthenticated)
hasperm = composed_perm().has_object_permission(request, None, None)
assert hasperm is True
mock_allow.assert_called_once()
mock_deny.assert_not_called()
with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
composed_perm = (permissions.AllowAny | permissions.IsAuthenticated)
hasperm = composed_perm().has_object_permission(request, None, None)
self.assertIs(hasperm, True)
mock_allow.assert_called_once()
mock_deny.assert_not_called()
with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
composed_perm = (permissions.IsAuthenticated | permissions.AllowAny)
hasperm = composed_perm().has_object_permission(request, None, None)
self.assertIs(hasperm, True)
mock_deny.assert_called_once()
mock_allow.assert_called_once()
with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
composed_perm = (permissions.IsAuthenticated | permissions.AllowAny)
hasperm = composed_perm().has_object_permission(request, None, None)
assert hasperm is True
mock_deny.assert_called_once()
mock_allow.assert_called_once()
@pytest.mark.skipif(not PY36, reason="assert_called_once() not available")
def test_and_lazyness(self):
request = factory.get('/1', format='json')
request.user = AnonymousUser()
deftest_and_lazyness(self):
request = factory.get('/1', format='json')
request.user = AnonymousUser()
with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
composed_perm = (permissions.AllowAny & permissions.IsAuthenticated)
hasperm = composed_perm().has_permission(request, None)
assert hasperm is False
mock_allow.assert_called_once()
mock_deny.assert_called_once()
with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
composed_perm = (permissions.AllowAny & permissions.IsAuthenticated)
hasperm = composed_perm().has_permission(request, None)
self.assertIs(hasperm, False)
mock_allow.assert_called_once()
mock_deny.assert_called_once()
with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
composed_perm = (permissions.IsAuthenticated & permissions.AllowAny)
hasperm = composed_perm().has_permission(request, None)
self.assertIs(hasperm, False)
mock_allow.assert_not_called()
mock_deny.assert_called_once()
with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
composed_perm = (permissions.IsAuthenticated & permissions.AllowAny)
hasperm = composed_perm().has_permission(request, None)
assert hasperm is False
mock_allow.assert_not_called()
mock_deny.assert_called_once()
@pytest.mark.skipif(not PY36, reason="assert_called_once() not available")
def test_object_and_lazyness(self):
request = factory.get('/1', format='json')
request.user = AnonymousUser()
deftest_object_and_lazyness(self):
request = factory.get('/1', format='json')
request.user = AnonymousUser()
with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
composed_perm = (permissions.AllowAny & permissions.IsAuthenticated)
hasperm = composed_perm().has_object_permission(request, None, None)
assert hasperm is False
mock_allow.assert_called_once()
mock_deny.assert_called_once()
with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
composed_perm = (permissions.AllowAny & permissions.IsAuthenticated)
hasperm = composed_perm().has_object_permission(request, None, None)
self.assertIs(hasperm, False)
mock_allow.assert_called_once()
mock_deny.assert_called_once()
with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
composed_perm = (permissions.IsAuthenticated & permissions.AllowAny)
hasperm = composed_perm().has_object_permission(request, None, None)
self.assertIs(hasperm, False)
mock_allow.assert_not_called()
mock_deny.assert_called_once()
with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
composed_perm = (permissions.IsAuthenticated & permissions.AllowAny)
hasperm = composed_perm().has_object_permission(request, None, None)
assert hasperm is False
mock_allow.assert_not_called()
mock_deny.assert_called_once()

View File

@ -18,41 +18,30 @@ class UserUpdate(generics.UpdateAPIView):
serializer_class = UserSerializer
class TestPrefetchRelatedUpdates(TestCase):
def setUp(self):
self.user = User.objects.create(username='tom', email='tom@example.com')
self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')]
self.user.groups.set(self.groups)
def setUp(self):
self.user = User.objects.create(username='tom', email='tom@example.com')
self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')]
self.user.groups.set(self.groups)
def test_prefetch_related_updates(self):
view = UserUpdate.as_view()
pk = self.user.pk
groups_pk = self.groups[0].pk
request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json')
response = view(request, pk=pk)
assert User.objects.get(pk=pk).groups.count() == 1
expected = {
'id': pk,
'username': 'new',
'groups': [1],
'email': 'tom@example.com'
}
assert response.data == expected
view = UserUpdate.as_view()
pk = self.user.pk
groups_pk = self.groups[0].pk
request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json')
response = view(request, pk=pk)
assert User.objects.get(pk=pk).groups.count() == 1
expected = { 'id': pk, 'username': 'new', 'groups': [1], 'email': 'tom@example.com' }
assert response.data == expected
def test_prefetch_related_excluding_instance_from_original_queryset(self):
"""
"""
Regression test for https://github.com/encode/django-rest-framework/issues/4661
"""
view = UserUpdate.as_view()
pk = self.user.pk
groups_pk = self.groups[0].pk
request = factory.put('/', {'username': 'exclude', 'groups': [groups_pk]}, format='json')
response = view(request, pk=pk)
assert User.objects.get(pk=pk).groups.count() == 1
expected = {
'id': pk,
'username': 'exclude',
'groups': [1],
'email': 'tom@example.com'
}
assert response.data == expected
view = UserUpdate.as_view()
pk = self.user.pk
groups_pk = self.groups[0].pk
request = factory.put('/', {'username': 'exclude', 'groups': [groups_pk]}, format='json')
response = view(request, pk=pk)
assert User.objects.get(pk=pk).groups.count() == 1
expected = { 'id': pk, 'username': 'exclude', 'groups': [1], 'email': 'tom@example.com' }
assert response.data == expected

View File

@ -70,380 +70,268 @@ class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer):
# TODO: Add test that .data cannot be accessed prior to .is_valid
@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink')
class HyperlinkedManyToManyTests(TestCase):
def setUp(self):
for idx in range(1, 4):
target = ManyToManyTarget(name='target-%d' % idx)
target.save()
source = ManyToManySource(name='source-%d' % idx)
source.save()
for target in ManyToManyTarget.objects.all():
source.targets.add(target)
def setUp(self):
for idx in range(1, 4):
target = ManyToManyTarget(name='target-%d' % idx)
target.save()
source = ManyToManySource(name='source-%d' % idx)
source.save()
for target in ManyToManyTarget.objects.all():
source.targets.add(target)
def test_relative_hyperlinks(self):
queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': None})
expected = [
{'url': '/manytomanysource/1/', 'name': 'source-1', 'targets': ['/manytomanytarget/1/']},
{'url': '/manytomanysource/2/', 'name': 'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']},
{'url': '/manytomanysource/3/', 'name': 'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']}
]
with self.assertNumQueries(4):
assert serializer.data == expected
queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': None})
expected = [ {'url': '/manytomanysource/1/', 'name': 'source-1', 'targets': ['/manytomanytarget/1/']}, {'url': '/manytomanysource/2/', 'name': 'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']}, {'url': '/manytomanysource/3/', 'name': 'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']} ]
with self.assertNumQueries(4):
assert serializer.data == expected
def test_many_to_many_retrieve(self):
queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
expected = [
{'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']},
{'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
{'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
]
with self.assertNumQueries(4):
assert serializer.data == expected
queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
expected = [ {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']}, {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} ]
with self.assertNumQueries(4):
assert serializer.data == expected
def test_many_to_many_retrieve_prefetch_related(self):
queryset = ManyToManySource.objects.all().prefetch_related('targets')
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
with self.assertNumQueries(2):
serializer.data
queryset = ManyToManySource.objects.all().prefetch_related('targets')
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
with self.assertNumQueries(2):
serializer.data
def test_reverse_many_to_many_retrieve(self):
queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
expected = [
{'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}
]
with self.assertNumQueries(4):
assert serializer.data == expected
queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
expected = [ {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']} ]
with self.assertNumQueries(4):
assert serializer.data == expected
def test_many_to_many_update(self):
data = {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
instance = ManyToManySource.objects.get(pk=1)
serializer = ManyToManySourceSerializer(instance, data=data, context={'request': request})
assert serializer.is_valid()
serializer.save()
assert serializer.data == data
# Ensure source 1 is updated, and everything else is as expected
queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
expected = [
{'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']},
{'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
{'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
]
assert serializer.data == expected
data = {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
instance = ManyToManySource.objects.get(pk=1)
serializer = ManyToManySourceSerializer(instance, data=data, context={'request': request})
assert serializer.is_valid()
serializer.save()
assert serializer.data == data
queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
expected = [ {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}, {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} ]
assert serializer.data == expected
def test_reverse_many_to_many_update(self):
data = {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']}
instance = ManyToManyTarget.objects.get(pk=1)
serializer = ManyToManyTargetSerializer(instance, data=data, context={'request': request})
assert serializer.is_valid()
serializer.save()
assert serializer.data == data
# Ensure target 1 is updated, and everything else is as expected
queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
expected = [
{'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']},
{'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}
]
assert serializer.data == expected
data = {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']}
instance = ManyToManyTarget.objects.get(pk=1)
serializer = ManyToManyTargetSerializer(instance, data=data, context={'request': request})
assert serializer.is_valid()
serializer.save()
assert serializer.data == data
queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
expected = [ {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']}, {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']} ]
assert serializer.data == expected
def test_many_to_many_create(self):
data = {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']}
serializer = ManyToManySourceSerializer(data=data, context={'request': request})
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
assert obj.name == 'source-4'
# Ensure source 4 is added, and everything else is as expected
queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
expected = [
{'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']},
{'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
{'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']},
{'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']}
]
assert serializer.data == expected
data = {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']}
serializer = ManyToManySourceSerializer(data=data, context={'request': request})
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
assert obj.name == 'source-4'
queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
expected = [ {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']}, {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}, {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']} ]
assert serializer.data == expected
def test_reverse_many_to_many_create(self):
data = {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']}
serializer = ManyToManyTargetSerializer(data=data, context={'request': request})
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
assert obj.name == 'target-4'
# Ensure target 4 is added, and everything else is as expected
queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
expected = [
{'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']}
]
assert serializer.data == expected
data = {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']}
serializer = ManyToManyTargetSerializer(data=data, context={'request': request})
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
assert obj.name == 'target-4'
queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
expected = [ {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']} ]
assert serializer.data == expected
@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink')
class HyperlinkedForeignKeyTests(TestCase):
def setUp(self):
target = ForeignKeyTarget(name='target-1')
target.save()
new_target = ForeignKeyTarget(name='target-2')
new_target.save()
for idx in range(1, 4):
source = ForeignKeySource(name='source-%d' % idx, target=target)
source.save()
def test_foreign_key_retrieve(self):
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [
{'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}
]
with self.assertNumQueries(1):
assert serializer.data == expected
def test_reverse_foreign_key_retrieve(self):
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [
{'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']},
{'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
]
with self.assertNumQueries(3):
assert serializer.data == expected
def test_foreign_key_update(self):
data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
assert serializer.is_valid()
serializer.save()
assert serializer.data == data
# Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [
{'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'},
{'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}
]
assert serializer.data == expected
def test_foreign_key_update_incorrect_type(self):
data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 2}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
assert not serializer.is_valid()
assert serializer.errors == {'target': ['Incorrect type. Expected URL string, received int.']}
def test_reverse_foreign_key_update(self):
data = {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}
instance = ForeignKeyTarget.objects.get(pk=2)
serializer = ForeignKeyTargetSerializer(instance, data=data, context={'request': request})
assert serializer.is_valid()
# We shouldn't have saved anything to the db yet since save
# hasn't been called.
queryset = ForeignKeyTarget.objects.all()
new_serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [
{'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']},
{'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
]
assert new_serializer.data == expected
serializer.save()
assert serializer.data == data
# Ensure target 2 is update, and everything else is as expected
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [
{'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']},
{'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']},
]
assert serializer.data == expected
def test_foreign_key_create(self):
data = {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'}
serializer = ForeignKeySourceSerializer(data=data, context={'request': request})
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
assert obj.name == 'source-4'
# Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [
{'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'},
]
assert serializer.data == expected
def test_reverse_foreign_key_create(self):
data = {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}
serializer = ForeignKeyTargetSerializer(data=data, context={'request': request})
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
assert obj.name == 'target-3'
# Ensure target 4 is added, and everything else is as expected
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [
{'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']},
{'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
{'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']},
]
assert serializer.data == expected
def test_foreign_key_update_with_invalid_null(self):
data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': None}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
assert not serializer.is_valid()
assert serializer.errors == {'target': ['This field may not be null.']}
@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink')
class HyperlinkedNullableForeignKeyTests(TestCase):
def setUp(self):
target = ForeignKeyTarget(name='target-1')
target.save()
for idx in range(1, 4):
if idx == 3:
target = None
source = NullableForeignKeySource(name='source-%d' % idx, target=target)
source.save()
def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [
{'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
]
assert serializer.data == expected
def test_foreign_key_create_with_valid_null(self):
data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request})
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
assert obj.name == 'source-4'
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [
{'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
{'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
]
assert serializer.data == expected
def test_foreign_key_create_with_valid_emptystring(self):
"""
The emptystring should be interpreted as null in the context
of relationships.
"""
data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': ''}
expected_data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request})
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == expected_data
assert obj.name == 'source-4'
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [
{'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
{'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
]
assert serializer.data == expected
def test_foreign_key_update_with_valid_null(self):
data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request})
assert serializer.is_valid()
serializer.save()
assert serializer.data == data
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [
{'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None},
{'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
]
assert serializer.data == expected
def test_foreign_key_update_with_valid_emptystring(self):
"""
The emptystring should be interpreted as null in the context
of relationships.
"""
data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': ''}
expected_data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request})
assert serializer.is_valid()
serializer.save()
assert serializer.data == expected_data
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [
{'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None},
{'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
]
assert serializer.data == expected
@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink')
class HyperlinkedNullableOneToOneTests(TestCase):
def setUp(self):
target = OneToOneTarget(name='target-1')
target.save()
new_target = OneToOneTarget(name='target-2')
new_target.save()
source = NullableOneToOneSource(name='source-1', target=target)
def setUp(self):
target = ForeignKeyTarget(name='target-1')
target.save()
new_target = ForeignKeyTarget(name='target-2')
new_target.save()
for idx in range(1, 4):
source = ForeignKeySource(name='source-%d' % idx, target=target)
source.save()
def test_reverse_foreign_key_retrieve_with_null(self):
queryset = OneToOneTarget.objects.all()
serializer = NullableOneToOneTargetSerializer(queryset, many=True, context={'request': request})
expected = [
{'url': 'http://testserver/onetoonetarget/1/', 'name': 'target-1', 'nullable_source': 'http://testserver/nullableonetoonesource/1/'},
{'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None},
]
def test_foreign_key_retrieve(self):
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'} ]
with self.assertNumQueries(1):
assert serializer.data == expected
def test_reverse_foreign_key_retrieve(self):
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']}, {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, ]
with self.assertNumQueries(3):
assert serializer.data == expected
def test_foreign_key_update(self):
data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
assert serializer.is_valid()
serializer.save()
assert serializer.data == data
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'}, {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'} ]
assert serializer.data == expected
def test_foreign_key_update_incorrect_type(self):
data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 2}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
assert not serializer.is_valid()
assert serializer.errors == {'target': ['Incorrect type. Expected URL string, received int.']}
def test_reverse_foreign_key_update(self):
data = {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}
instance = ForeignKeyTarget.objects.get(pk=2)
serializer = ForeignKeyTargetSerializer(instance, data=data, context={'request': request})
assert serializer.is_valid()
queryset = ForeignKeyTarget.objects.all()
new_serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']}, {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, ]
assert new_serializer.data == expected
serializer.save()
assert serializer.data == data
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']}, {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}, ]
assert serializer.data == expected
def test_foreign_key_create(self):
data = {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'}
serializer = ForeignKeySourceSerializer(data=data, context={'request': request})
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
assert obj.name == 'source-4'
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'}, ]
assert serializer.data == expected
def test_reverse_foreign_key_create(self):
data = {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}
serializer = ForeignKeyTargetSerializer(data=data, context={'request': request})
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
assert obj.name == 'target-3'
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']}, {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}, ]
assert serializer.data == expected
def test_foreign_key_update_with_invalid_null(self):
data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': None}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
assert not serializer.is_valid()
assert serializer.errors == {'target': ['This field may not be null.']}
@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink')
def setUp(self):
target = ForeignKeyTarget(name='target-1')
target.save()
for idx in range(1, 4):
if idx == 3:
target = None
source = NullableForeignKeySource(name='source-%d' % idx, target=target)
source.save()
def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, ]
assert serializer.data == expected
def test_foreign_key_create_with_valid_null(self):
data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request})
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
assert obj.name == 'source-4'
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} ]
assert serializer.data == expected
def test_foreign_key_create_with_valid_emptystring(self):
"""
The emptystring should be interpreted as null in the context
of relationships.
"""
data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': ''}
expected_data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request})
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == expected_data
assert obj.name == 'source-4'
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} ]
assert serializer.data == expected
def test_foreign_key_update_with_valid_null(self):
data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request})
assert serializer.is_valid()
serializer.save()
assert serializer.data == data
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}, {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, ]
assert serializer.data == expected
def test_foreign_key_update_with_valid_emptystring(self):
"""
The emptystring should be interpreted as null in the context
of relationships.
"""
data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': ''}
expected_data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request})
assert serializer.is_valid()
serializer.save()
assert serializer.data == expected_data
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}, {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, ]
assert serializer.data == expected
@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink')
def setUp(self):
target = OneToOneTarget(name='target-1')
target.save()
new_target = OneToOneTarget(name='target-2')
new_target.save()
source = NullableOneToOneSource(name='source-1', target=target)
source.save()
def test_reverse_foreign_key_retrieve_with_null(self):
queryset = OneToOneTarget.objects.all()
serializer = NullableOneToOneTargetSerializer(queryset, many=True, context={'request': request})
expected = [ {'url': 'http://testserver/onetoonetarget/1/', 'name': 'target-1', 'nullable_source': 'http://testserver/nullableonetoonesource/1/'}, {'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None}, ]
assert serializer.data == expected

View File

@ -77,496 +77,373 @@ class OneToOnePKSourceSerializer(serializers.ModelSerializer):
# TODO: Add test that .data cannot be accessed prior to .is_valid
class PKManyToManyTests(TestCase):
def setUp(self):
for idx in range(1, 4):
target = ManyToManyTarget(name='target-%d' % idx)
target.save()
source = ManyToManySource(name='source-%d' % idx)
source.save()
for target in ManyToManyTarget.objects.all():
source.targets.add(target)
def setUp(self):
for idx in range(1, 4):
target = ManyToManyTarget(name='target-%d' % idx)
target.save()
source = ManyToManySource(name='source-%d' % idx)
source.save()
for target in ManyToManyTarget.objects.all():
source.targets.add(target)
def test_many_to_many_retrieve(self):
queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'source-1', 'targets': [1]},
{'id': 2, 'name': 'source-2', 'targets': [1, 2]},
{'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
]
with self.assertNumQueries(4):
assert serializer.data == expected
queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True)
expected = [ {'id': 1, 'name': 'source-1', 'targets': [1]}, {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]} ]
with self.assertNumQueries(4):
assert serializer.data == expected
def test_many_to_many_retrieve_prefetch_related(self):
queryset = ManyToManySource.objects.all().prefetch_related('targets')
serializer = ManyToManySourceSerializer(queryset, many=True)
with self.assertNumQueries(2):
serializer.data
queryset = ManyToManySource.objects.all().prefetch_related('targets')
serializer = ManyToManySourceSerializer(queryset, many=True)
with self.assertNumQueries(2):
serializer.data
def test_reverse_many_to_many_retrieve(self):
queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
{'id': 2, 'name': 'target-2', 'sources': [2, 3]},
{'id': 3, 'name': 'target-3', 'sources': [3]}
]
with self.assertNumQueries(4):
assert serializer.data == expected
queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True)
expected = [ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, {'id': 3, 'name': 'target-3', 'sources': [3]} ]
with self.assertNumQueries(4):
assert serializer.data == expected
def test_many_to_many_update(self):
data = {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]}
instance = ManyToManySource.objects.get(pk=1)
serializer = ManyToManySourceSerializer(instance, data=data)
assert serializer.is_valid()
serializer.save()
assert serializer.data == data
# Ensure source 1 is updated, and everything else is as expected
queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]},
{'id': 2, 'name': 'source-2', 'targets': [1, 2]},
{'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
]
assert serializer.data == expected
data = {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]}
instance = ManyToManySource.objects.get(pk=1)
serializer = ManyToManySourceSerializer(instance, data=data)
assert serializer.is_valid()
serializer.save()
assert serializer.data == data
queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True)
expected = [ {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]}, {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]} ]
assert serializer.data == expected
def test_reverse_many_to_many_update(self):
data = {'id': 1, 'name': 'target-1', 'sources': [1]}
instance = ManyToManyTarget.objects.get(pk=1)
serializer = ManyToManyTargetSerializer(instance, data=data)
assert serializer.is_valid()
serializer.save()
assert serializer.data == data
# Ensure target 1 is updated, and everything else is as expected
queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'target-1', 'sources': [1]},
{'id': 2, 'name': 'target-2', 'sources': [2, 3]},
{'id': 3, 'name': 'target-3', 'sources': [3]}
]
assert serializer.data == expected
data = {'id': 1, 'name': 'target-1', 'sources': [1]}
instance = ManyToManyTarget.objects.get(pk=1)
serializer = ManyToManyTargetSerializer(instance, data=data)
assert serializer.is_valid()
serializer.save()
assert serializer.data == data
queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True)
expected = [ {'id': 1, 'name': 'target-1', 'sources': [1]}, {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, {'id': 3, 'name': 'target-3', 'sources': [3]} ]
assert serializer.data == expected
def test_many_to_many_create(self):
data = {'id': 4, 'name': 'source-4', 'targets': [1, 3]}
serializer = ManyToManySourceSerializer(data=data)
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
assert obj.name == 'source-4'
# Ensure source 4 is added, and everything else is as expected
queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'source-1', 'targets': [1]},
{'id': 2, 'name': 'source-2', 'targets': [1, 2]},
{'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]},
{'id': 4, 'name': 'source-4', 'targets': [1, 3]},
]
assert serializer.data == expected
data = {'id': 4, 'name': 'source-4', 'targets': [1, 3]}
serializer = ManyToManySourceSerializer(data=data)
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
assert obj.name == 'source-4'
queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True)
expected = [ {'id': 1, 'name': 'source-1', 'targets': [1]}, {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}, {'id': 4, 'name': 'source-4', 'targets': [1, 3]}, ]
assert serializer.data == expected
def test_many_to_many_unsaved(self):
source = ManyToManySource(name='source-unsaved')
serializer = ManyToManySourceSerializer(source)
expected = {'id': None, 'name': 'source-unsaved', 'targets': []}
# no query if source hasn't been created yet
with self.assertNumQueries(0):
assert serializer.data == expected
source = ManyToManySource(name='source-unsaved')
serializer = ManyToManySourceSerializer(source)
expected = {'id': None, 'name': 'source-unsaved', 'targets': []}
with self.assertNumQueries(0):
assert serializer.data == expected
def test_reverse_many_to_many_create(self):
data = {'id': 4, 'name': 'target-4', 'sources': [1, 3]}
serializer = ManyToManyTargetSerializer(data=data)
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
assert obj.name == 'target-4'
# Ensure target 4 is added, and everything else is as expected
queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
{'id': 2, 'name': 'target-2', 'sources': [2, 3]},
{'id': 3, 'name': 'target-3', 'sources': [3]},
{'id': 4, 'name': 'target-4', 'sources': [1, 3]}
]
assert serializer.data == expected
data = {'id': 4, 'name': 'target-4', 'sources': [1, 3]}
serializer = ManyToManyTargetSerializer(data=data)
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
assert obj.name == 'target-4'
queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True)
expected = [ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, {'id': 3, 'name': 'target-3', 'sources': [3]}, {'id': 4, 'name': 'target-4', 'sources': [1, 3]} ]
assert serializer.data == expected
class PKForeignKeyTests(TestCase):
def setUp(self):
target = ForeignKeyTarget(name='target-1')
target.save()
new_target = ForeignKeyTarget(name='target-2')
new_target.save()
for idx in range(1, 4):
source = ForeignKeySource(name='source-%d' % idx, target=target)
source.save()
def setUp(self):
target = ForeignKeyTarget(name='target-1')
target.save()
new_target = ForeignKeyTarget(name='target-2')
new_target.save()
for idx in range(1, 4):
source = ForeignKeySource(name='source-%d' % idx, target=target)
source.save()
def test_foreign_key_retrieve(self):
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'source-1', 'target': 1},
{'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': 1}
]
with self.assertNumQueries(1):
assert serializer.data == expected
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [ {'id': 1, 'name': 'source-1', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': 1} ]
with self.assertNumQueries(1):
assert serializer.data == expected
def test_reverse_foreign_key_retrieve(self):
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
{'id': 2, 'name': 'target-2', 'sources': []},
]
with self.assertNumQueries(3):
assert serializer.data == expected
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, {'id': 2, 'name': 'target-2', 'sources': []}, ]
with self.assertNumQueries(3):
assert serializer.data == expected
def test_reverse_foreign_key_retrieve_prefetch_related(self):
queryset = ForeignKeyTarget.objects.all().prefetch_related('sources')
serializer = ForeignKeyTargetSerializer(queryset, many=True)
with self.assertNumQueries(2):
serializer.data
queryset = ForeignKeyTarget.objects.all().prefetch_related('sources')
serializer = ForeignKeyTargetSerializer(queryset, many=True)
with self.assertNumQueries(2):
serializer.data
def test_foreign_key_update(self):
data = {'id': 1, 'name': 'source-1', 'target': 2}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data)
assert serializer.is_valid()
serializer.save()
assert serializer.data == data
# Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'source-1', 'target': 2},
{'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': 1}
]
assert serializer.data == expected
data = {'id': 1, 'name': 'source-1', 'target': 2}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data)
assert serializer.is_valid()
serializer.save()
assert serializer.data == data
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [ {'id': 1, 'name': 'source-1', 'target': 2}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': 1} ]
assert serializer.data == expected
def test_foreign_key_update_incorrect_type(self):
data = {'id': 1, 'name': 'source-1', 'target': 'foo'}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data)
assert not serializer.is_valid()
assert serializer.errors == {'target': ['Incorrect type. Expected pk value, received str.']}
data = {'id': 1, 'name': 'source-1', 'target': 'foo'}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data)
assert not serializer.is_valid()
assert serializer.errors == {'target': ['Incorrect type. Expected pk value, received str.']}
def test_reverse_foreign_key_update(self):
data = {'id': 2, 'name': 'target-2', 'sources': [1, 3]}
instance = ForeignKeyTarget.objects.get(pk=2)
serializer = ForeignKeyTargetSerializer(instance, data=data)
assert serializer.is_valid()
# We shouldn't have saved anything to the db yet since save
# hasn't been called.
queryset = ForeignKeyTarget.objects.all()
new_serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
{'id': 2, 'name': 'target-2', 'sources': []},
]
assert new_serializer.data == expected
serializer.save()
assert serializer.data == data
# Ensure target 2 is update, and everything else is as expected
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'target-1', 'sources': [2]},
{'id': 2, 'name': 'target-2', 'sources': [1, 3]},
]
assert serializer.data == expected
data = {'id': 2, 'name': 'target-2', 'sources': [1, 3]}
instance = ForeignKeyTarget.objects.get(pk=2)
serializer = ForeignKeyTargetSerializer(instance, data=data)
assert serializer.is_valid()
queryset = ForeignKeyTarget.objects.all()
new_serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, {'id': 2, 'name': 'target-2', 'sources': []}, ]
assert new_serializer.data == expected
serializer.save()
assert serializer.data == data
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ {'id': 1, 'name': 'target-1', 'sources': [2]}, {'id': 2, 'name': 'target-2', 'sources': [1, 3]}, ]
assert serializer.data == expected
def test_foreign_key_create(self):
data = {'id': 4, 'name': 'source-4', 'target': 2}
serializer = ForeignKeySourceSerializer(data=data)
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
assert obj.name == 'source-4'
# Ensure source 4 is added, and everything else is as expected
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'source-1', 'target': 1},
{'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': 1},
{'id': 4, 'name': 'source-4', 'target': 2},
]
assert serializer.data == expected
data = {'id': 4, 'name': 'source-4', 'target': 2}
serializer = ForeignKeySourceSerializer(data=data)
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
assert obj.name == 'source-4'
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [ {'id': 1, 'name': 'source-1', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': 1}, {'id': 4, 'name': 'source-4', 'target': 2}, ]
assert serializer.data == expected
def test_reverse_foreign_key_create(self):
data = {'id': 3, 'name': 'target-3', 'sources': [1, 3]}
serializer = ForeignKeyTargetSerializer(data=data)
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
assert obj.name == 'target-3'
# Ensure target 3 is added, and everything else is as expected
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'target-1', 'sources': [2]},
{'id': 2, 'name': 'target-2', 'sources': []},
{'id': 3, 'name': 'target-3', 'sources': [1, 3]},
]
assert serializer.data == expected
data = {'id': 3, 'name': 'target-3', 'sources': [1, 3]}
serializer = ForeignKeyTargetSerializer(data=data)
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
assert obj.name == 'target-3'
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ {'id': 1, 'name': 'target-1', 'sources': [2]}, {'id': 2, 'name': 'target-2', 'sources': []}, {'id': 3, 'name': 'target-3', 'sources': [1, 3]}, ]
assert serializer.data == expected
def test_foreign_key_update_with_invalid_null(self):
data = {'id': 1, 'name': 'source-1', 'target': None}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data)
assert not serializer.is_valid()
assert serializer.errors == {'target': ['This field may not be null.']}
data = {'id': 1, 'name': 'source-1', 'target': None}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data)
assert not serializer.is_valid()
assert serializer.errors == {'target': ['This field may not be null.']}
def test_foreign_key_with_unsaved(self):
source = ForeignKeySource(name='source-unsaved')
expected = {'id': None, 'name': 'source-unsaved', 'target': None}
serializer = ForeignKeySourceSerializer(source)
# no query if source hasn't been created yet
with self.assertNumQueries(0):
assert serializer.data == expected
source = ForeignKeySource(name='source-unsaved')
expected = {'id': None, 'name': 'source-unsaved', 'target': None}
serializer = ForeignKeySourceSerializer(source)
with self.assertNumQueries(0):
assert serializer.data == expected
def test_foreign_key_with_empty(self):
"""
"""
Regression test for #1072
https://github.com/encode/django-rest-framework/issues/1072
"""
serializer = NullableForeignKeySourceSerializer()
assert serializer.data['target'] is None
serializer = NullableForeignKeySourceSerializer()
assert serializer.data['target'] is None
def test_foreign_key_not_required(self):
"""
"""
Let's say we wanted to fill the non-nullable model field inside
Model.save(), we would make it empty and not required.
"""
class ModelSerializer(ForeignKeySourceSerializer):
class Meta(ForeignKeySourceSerializer.Meta):
extra_kwargs = {'target': {'required': False}}
class ModelSerializer(ForeignKeySourceSerializer):
class Meta(ForeignKeySourceSerializer.Meta):
extra_kwargs = {'target': {'required': False}}
serializer = ModelSerializer(data={'name': 'test'})
serializer.is_valid(raise_exception=True)
assert 'target' not in serializer.validated_data
serializer.is_valid(raise_exception=True)
assert'target'notinserializer.validated_data
def test_queryset_size_without_limited_choices(self):
limited_target = ForeignKeyTarget(name="limited-target")
limited_target.save()
queryset = ForeignKeySourceSerializer().fields["target"].get_queryset()
assert len(queryset) == 3
limited_target = ForeignKeyTarget(name="limited-target")
limited_target.save()
queryset = ForeignKeySourceSerializer().fields["target"].get_queryset()
assert len(queryset) == 3
def test_queryset_size_with_limited_choices(self):
limited_target = ForeignKeyTarget(name="limited-target")
limited_target.save()
queryset = ForeignKeySourceWithLimitedChoicesSerializer().fields["target"].get_queryset()
assert len(queryset) == 1
limited_target = ForeignKeyTarget(name="limited-target")
limited_target.save()
queryset = ForeignKeySourceWithLimitedChoicesSerializer().fields["target"].get_queryset()
assert len(queryset) == 1
def test_queryset_size_with_Q_limited_choices(self):
limited_target = ForeignKeyTarget(name="limited-target")
limited_target.save()
class QLimitedChoicesSerializer(serializers.ModelSerializer):
class Meta:
model = ForeignKeySourceWithQLimitedChoices
fields = ("id", "target")
limited_target = ForeignKeyTarget(name="limited-target")
limited_target.save()
class QLimitedChoicesSerializer(serializers.ModelSerializer):
class Meta:
model = ForeignKeySourceWithQLimitedChoices
fields = ("id", "target")
queryset = QLimitedChoicesSerializer().fields["target"].get_queryset()
assert len(queryset) == 1
assertlen(queryset)==1
class PKNullableForeignKeyTests(TestCase):
def setUp(self):
target = ForeignKeyTarget(name='target-1')
target.save()
for idx in range(1, 4):
if idx == 3:
target = None
def setUp(self):
target = ForeignKeyTarget(name='target-1')
target.save()
for idx in range(1, 4):
if idx == 3:
target = None
source = NullableForeignKeySource(name='source-%d' % idx, target=target)
source.save()
source.save()
def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'source-1', 'target': 1},
{'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': None},
]
assert serializer.data == expected
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ {'id': 1, 'name': 'source-1', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': None}, ]
assert serializer.data == expected
def test_foreign_key_create_with_valid_null(self):
data = {'id': 4, 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data)
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
assert obj.name == 'source-4'
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'source-1', 'target': 1},
{'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': None},
{'id': 4, 'name': 'source-4', 'target': None}
]
assert serializer.data == expected
data = {'id': 4, 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data)
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
assert obj.name == 'source-4'
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ {'id': 1, 'name': 'source-1', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': None}, {'id': 4, 'name': 'source-4', 'target': None} ]
assert serializer.data == expected
def test_foreign_key_create_with_valid_emptystring(self):
"""
"""
The emptystring should be interpreted as null in the context
of relationships.
"""
data = {'id': 4, 'name': 'source-4', 'target': ''}
expected_data = {'id': 4, 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data)
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == expected_data
assert obj.name == 'source-4'
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'source-1', 'target': 1},
{'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': None},
{'id': 4, 'name': 'source-4', 'target': None}
]
assert serializer.data == expected
data = {'id': 4, 'name': 'source-4', 'target': ''}
expected_data = {'id': 4, 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data)
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == expected_data
assert obj.name == 'source-4'
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ {'id': 1, 'name': 'source-1', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': None}, {'id': 4, 'name': 'source-4', 'target': None} ]
assert serializer.data == expected
def test_foreign_key_update_with_valid_null(self):
data = {'id': 1, 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data)
assert serializer.is_valid()
serializer.save()
assert serializer.data == data
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'source-1', 'target': None},
{'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': None}
]
assert serializer.data == expected
data = {'id': 1, 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data)
assert serializer.is_valid()
serializer.save()
assert serializer.data == data
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ {'id': 1, 'name': 'source-1', 'target': None}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': None} ]
assert serializer.data == expected
def test_foreign_key_update_with_valid_emptystring(self):
"""
"""
The emptystring should be interpreted as null in the context
of relationships.
"""
data = {'id': 1, 'name': 'source-1', 'target': ''}
expected_data = {'id': 1, 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data)
assert serializer.is_valid()
serializer.save()
assert serializer.data == expected_data
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'source-1', 'target': None},
{'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': None}
]
assert serializer.data == expected
data = {'id': 1, 'name': 'source-1', 'target': ''}
expected_data = {'id': 1, 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data)
assert serializer.is_valid()
serializer.save()
assert serializer.data == expected_data
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ {'id': 1, 'name': 'source-1', 'target': None}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': None} ]
assert serializer.data == expected
def test_null_uuid_foreign_key_serializes_as_none(self):
source = NullableUUIDForeignKeySource(name='Source')
serializer = NullableUUIDForeignKeySourceSerializer(source)
data = serializer.data
assert data["target"] is None
source = NullableUUIDForeignKeySource(name='Source')
serializer = NullableUUIDForeignKeySourceSerializer(source)
data = serializer.data
assert data["target"] is None
def test_nullable_uuid_foreign_key_is_valid_when_none(self):
data = {"name": "Source", "target": None}
serializer = NullableUUIDForeignKeySourceSerializer(data=data)
assert serializer.is_valid(), serializer.errors
data = {"name": "Source", "target": None}
serializer = NullableUUIDForeignKeySourceSerializer(data=data)
assert serializer.is_valid(), serializer.errors
class PKNullableOneToOneTests(TestCase):
def setUp(self):
target = OneToOneTarget(name='target-1')
target.save()
new_target = OneToOneTarget(name='target-2')
new_target.save()
source = NullableOneToOneSource(name='source-1', target=new_target)
source.save()
def setUp(self):
target = OneToOneTarget(name='target-1')
target.save()
new_target = OneToOneTarget(name='target-2')
new_target.save()
source = NullableOneToOneSource(name='source-1', target=new_target)
source.save()
def test_reverse_foreign_key_retrieve_with_null(self):
queryset = OneToOneTarget.objects.all()
serializer = NullableOneToOneTargetSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'target-1', 'nullable_source': None},
{'id': 2, 'name': 'target-2', 'nullable_source': 1},
]
assert serializer.data == expected
queryset = OneToOneTarget.objects.all()
serializer = NullableOneToOneTargetSerializer(queryset, many=True)
expected = [ {'id': 1, 'name': 'target-1', 'nullable_source': None}, {'id': 2, 'name': 'target-2', 'nullable_source': 1}, ]
assert serializer.data == expected
class OneToOnePrimaryKeyTests(TestCase):
def setUp(self):
def setUp(self):
# Given: Some target models already exist
self.target = target = OneToOneTarget(name='target-1')
target.save()
self.alt_target = alt_target = OneToOneTarget(name='target-2')
alt_target.save()
self.target = target = OneToOneTarget(name='target-1')
target.save()
self.alt_target = alt_target = OneToOneTarget(name='target-2')
alt_target.save()
def test_one_to_one_when_primary_key(self):
# When: Creating a Source pointing at the id of the second Target
target_pk = self.alt_target.id
source = OneToOnePKSourceSerializer(data={'name': 'source-2', 'target': target_pk})
# Then: The source is valid with the serializer
if not source.is_valid():
self.fail("Expected OneToOnePKTargetSerializer to be valid but had errors: {}".format(source.errors))
target_pk = self.alt_target.id
source = OneToOnePKSourceSerializer(data={'name': 'source-2', 'target': target_pk})
if not source.is_valid():
self.fail("Expected OneToOnePKTargetSerializer to be valid but had errors: {}".format(source.errors))
# Then: Saving the serializer creates a new object
new_source = source.save()
# Then: The new object has the same pk as the target object
self.assertEqual(new_source.pk, target_pk)
assertnew_source.pk==target_pk
def test_one_to_one_when_primary_key_no_duplicates(self):
# When: Creating a Source pointing at the id of the second Target
target_pk = self.target.id
data = {'name': 'source-1', 'target': target_pk}
source = OneToOnePKSourceSerializer(data=data)
# Then: The source is valid with the serializer
self.assertTrue(source.is_valid())
# Then: Saving the serializer creates a new object
new_source = source.save()
# Then: The new object has the same pk as the target object
self.assertEqual(new_source.pk, target_pk)
# When: Trying to create a second object
second_source = OneToOnePKSourceSerializer(data=data)
self.assertFalse(second_source.is_valid())
expected = {'target': ['one to one pk source with this target already exists.']}
self.assertDictEqual(second_source.errors, expected)
target_pk = self.target.id
data = {'name': 'source-1', 'target': target_pk}
source = OneToOnePKSourceSerializer(data=data)
assert source.is_valid()
new_source = source.save()
assert new_source.pk == target_pk
second_source = OneToOnePKSourceSerializer(data=data)
assert not second_source.is_valid()
expected = {'target': ['one to one pk source with this target already exists.']}
assert second_source.errors == expected
def test_one_to_one_when_primary_key_does_not_exist(self):
# Given: a target PK that does not exist
target_pk = self.target.pk + self.alt_target.pk
source = OneToOnePKSourceSerializer(data={'name': 'source-2', 'target': target_pk})
# Then: The source is not valid with the serializer
self.assertFalse(source.is_valid())
self.assertIn("Invalid pk", source.errors['target'][0])
self.assertIn("object does not exist", source.errors['target'][0])
target_pk = self.target.pk + self.alt_target.pk
source = OneToOnePKSourceSerializer(data={'name': 'source-2', 'target': target_pk})
assert not source.is_valid()
assert "Invalid pk" in source.errors['target'][0]
assert "object does not exist" in source.errors['target'][0]

View File

@ -42,246 +42,177 @@ class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
# TODO: M2M Tests, FKTests (Non-nullable), One2One
class SlugForeignKeyTests(TestCase):
def setUp(self):
target = ForeignKeyTarget(name='target-1')
target.save()
new_target = ForeignKeyTarget(name='target-2')
new_target.save()
for idx in range(1, 4):
source = ForeignKeySource(name='source-%d' % idx, target=target)
source.save()
def setUp(self):
target = ForeignKeyTarget(name='target-1')
target.save()
new_target = ForeignKeyTarget(name='target-2')
new_target.save()
for idx in range(1, 4):
source = ForeignKeySource(name='source-%d' % idx, target=target)
source.save()
def test_foreign_key_retrieve(self):
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'source-1', 'target': 'target-1'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': 'target-1'}
]
with self.assertNumQueries(4):
assert serializer.data == expected
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-1'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': 'target-1'} ]
with self.assertNumQueries(4):
assert serializer.data == expected
def test_foreign_key_retrieve_select_related(self):
queryset = ForeignKeySource.objects.all().select_related('target')
serializer = ForeignKeySourceSerializer(queryset, many=True)
with self.assertNumQueries(1):
serializer.data
queryset = ForeignKeySource.objects.all().select_related('target')
serializer = ForeignKeySourceSerializer(queryset, many=True)
with self.assertNumQueries(1):
serializer.data
def test_reverse_foreign_key_retrieve(self):
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
{'id': 2, 'name': 'target-2', 'sources': []},
]
assert serializer.data == expected
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, {'id': 2, 'name': 'target-2', 'sources': []}, ]
assert serializer.data == expected
def test_reverse_foreign_key_retrieve_prefetch_related(self):
queryset = ForeignKeyTarget.objects.all().prefetch_related('sources')
serializer = ForeignKeyTargetSerializer(queryset, many=True)
with self.assertNumQueries(2):
serializer.data
queryset = ForeignKeyTarget.objects.all().prefetch_related('sources')
serializer = ForeignKeyTargetSerializer(queryset, many=True)
with self.assertNumQueries(2):
serializer.data
def test_foreign_key_update(self):
data = {'id': 1, 'name': 'source-1', 'target': 'target-2'}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data)
assert serializer.is_valid()
serializer.save()
assert serializer.data == data
# Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'source-1', 'target': 'target-2'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': 'target-1'}
]
assert serializer.data == expected
data = {'id': 1, 'name': 'source-1', 'target': 'target-2'}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data)
assert serializer.is_valid()
serializer.save()
assert serializer.data == data
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-2'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': 'target-1'} ]
assert serializer.data == expected
def test_foreign_key_update_incorrect_type(self):
data = {'id': 1, 'name': 'source-1', 'target': 123}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data)
assert not serializer.is_valid()
assert serializer.errors == {'target': ['Object with name=123 does not exist.']}
data = {'id': 1, 'name': 'source-1', 'target': 123}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data)
assert not serializer.is_valid()
assert serializer.errors == {'target': ['Object with name=123 does not exist.']}
def test_reverse_foreign_key_update(self):
data = {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']}
instance = ForeignKeyTarget.objects.get(pk=2)
serializer = ForeignKeyTargetSerializer(instance, data=data)
assert serializer.is_valid()
# We shouldn't have saved anything to the db yet since save
# hasn't been called.
queryset = ForeignKeyTarget.objects.all()
new_serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
{'id': 2, 'name': 'target-2', 'sources': []},
]
assert new_serializer.data == expected
serializer.save()
assert serializer.data == data
# Ensure target 2 is update, and everything else is as expected
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'target-1', 'sources': ['source-2']},
{'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']},
]
assert serializer.data == expected
data = {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']}
instance = ForeignKeyTarget.objects.get(pk=2)
serializer = ForeignKeyTargetSerializer(instance, data=data)
assert serializer.is_valid()
queryset = ForeignKeyTarget.objects.all()
new_serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, {'id': 2, 'name': 'target-2', 'sources': []}, ]
assert new_serializer.data == expected
serializer.save()
assert serializer.data == data
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ {'id': 1, 'name': 'target-1', 'sources': ['source-2']}, {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']}, ]
assert serializer.data == expected
def test_foreign_key_create(self):
data = {'id': 4, 'name': 'source-4', 'target': 'target-2'}
serializer = ForeignKeySourceSerializer(data=data)
serializer.is_valid()
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
assert obj.name == 'source-4'
# Ensure source 4 is added, and everything else is as expected
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'source-1', 'target': 'target-1'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': 'target-1'},
{'id': 4, 'name': 'source-4', 'target': 'target-2'},
]
assert serializer.data == expected
data = {'id': 4, 'name': 'source-4', 'target': 'target-2'}
serializer = ForeignKeySourceSerializer(data=data)
serializer.is_valid()
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
assert obj.name == 'source-4'
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-1'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': 'target-1'}, {'id': 4, 'name': 'source-4', 'target': 'target-2'}, ]
assert serializer.data == expected
def test_reverse_foreign_key_create(self):
data = {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']}
serializer = ForeignKeyTargetSerializer(data=data)
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
assert obj.name == 'target-3'
# Ensure target 3 is added, and everything else is as expected
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'target-1', 'sources': ['source-2']},
{'id': 2, 'name': 'target-2', 'sources': []},
{'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']},
]
assert serializer.data == expected
data = {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']}
serializer = ForeignKeyTargetSerializer(data=data)
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
assert obj.name == 'target-3'
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ {'id': 1, 'name': 'target-1', 'sources': ['source-2']}, {'id': 2, 'name': 'target-2', 'sources': []}, {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']}, ]
assert serializer.data == expected
def test_foreign_key_update_with_invalid_null(self):
data = {'id': 1, 'name': 'source-1', 'target': None}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data)
assert not serializer.is_valid()
assert serializer.errors == {'target': ['This field may not be null.']}
data = {'id': 1, 'name': 'source-1', 'target': None}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data)
assert not serializer.is_valid()
assert serializer.errors == {'target': ['This field may not be null.']}
class SlugNullableForeignKeyTests(TestCase):
def setUp(self):
target = ForeignKeyTarget(name='target-1')
target.save()
for idx in range(1, 4):
if idx == 3:
target = None
def setUp(self):
target = ForeignKeyTarget(name='target-1')
target.save()
for idx in range(1, 4):
if idx == 3:
target = None
source = NullableForeignKeySource(name='source-%d' % idx, target=target)
source.save()
source.save()
def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'source-1', 'target': 'target-1'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': None},
]
assert serializer.data == expected
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-1'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': None}, ]
assert serializer.data == expected
def test_foreign_key_create_with_valid_null(self):
data = {'id': 4, 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data)
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
assert obj.name == 'source-4'
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'source-1', 'target': 'target-1'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': None},
{'id': 4, 'name': 'source-4', 'target': None}
]
assert serializer.data == expected
data = {'id': 4, 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data)
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
assert obj.name == 'source-4'
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-1'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': None}, {'id': 4, 'name': 'source-4', 'target': None} ]
assert serializer.data == expected
def test_foreign_key_create_with_valid_emptystring(self):
"""
"""
The emptystring should be interpreted as null in the context
of relationships.
"""
data = {'id': 4, 'name': 'source-4', 'target': ''}
expected_data = {'id': 4, 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data)
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == expected_data
assert obj.name == 'source-4'
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'source-1', 'target': 'target-1'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': None},
{'id': 4, 'name': 'source-4', 'target': None}
]
assert serializer.data == expected
data = {'id': 4, 'name': 'source-4', 'target': ''}
expected_data = {'id': 4, 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data)
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == expected_data
assert obj.name == 'source-4'
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-1'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': None}, {'id': 4, 'name': 'source-4', 'target': None} ]
assert serializer.data == expected
def test_foreign_key_update_with_valid_null(self):
data = {'id': 1, 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data)
assert serializer.is_valid()
serializer.save()
assert serializer.data == data
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'source-1', 'target': None},
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': None}
]
assert serializer.data == expected
data = {'id': 1, 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data)
assert serializer.is_valid()
serializer.save()
assert serializer.data == data
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ {'id': 1, 'name': 'source-1', 'target': None}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': None} ]
assert serializer.data == expected
def test_foreign_key_update_with_valid_emptystring(self):
"""
"""
The emptystring should be interpreted as null in the context
of relationships.
"""
data = {'id': 1, 'name': 'source-1', 'target': ''}
expected_data = {'id': 1, 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data)
assert serializer.is_valid()
serializer.save()
assert serializer.data == expected_data
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'source-1', 'target': None},
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': None}
]
assert serializer.data == expected
data = {'id': 1, 'name': 'source-1', 'target': ''}
expected_data = {'id': 1, 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data)
assert serializer.is_valid()
serializer.save()
assert serializer.data == expected_data
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ {'id': 1, 'name': 'source-1', 'target': None}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': None} ]
assert serializer.data == expected

View File

@ -49,11 +49,10 @@ class DummyTestModel(models.Model):
name = models.CharField(max_length=42, default='')
class BasicRendererTests(TestCase):
def test_expected_results(self):
for value, renderer_cls, expected in expected_results:
output = renderer_cls().render(value)
self.assertEqual(output, expected)
def test_expected_results(self):
for value, renderer_cls, expected in expected_results:
output = renderer_cls().render(value)
assert output == expected
class RendererA(BaseRenderer):
@ -144,123 +143,114 @@ class POSTDeniedView(APIView):
return Response()
class DocumentingRendererTests(TestCase):
def test_only_permitted_forms_are_displayed(self):
view = POSTDeniedView.as_view()
request = APIRequestFactory().get('/')
response = view(request).render()
self.assertNotContains(response, '>POST<')
self.assertContains(response, '>PUT<')
self.assertContains(response, '>PATCH<')
def test_only_permitted_forms_are_displayed(self):
view = POSTDeniedView.as_view()
request = APIRequestFactory().get('/')
response = view(request).render()
self.assertNotContains(response, '>POST<')
self.assertContains(response, '>PUT<')
self.assertContains(response, '>PATCH<')
@override_settings(ROOT_URLCONF='tests.test_renderers')
class RendererEndToEndTests(TestCase):
"""
"""
End-to-end testing of renderers using an RendererMixin on a generic view.
"""
def test_default_renderer_serializes_content(self):
"""If the Accept header is not set the default renderer should serialize the response."""
resp = self.client.get('/')
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
deftest_default_renderer_serializes_content(self):
"""If the Accept header is not set the default renderer should serialize the response."""
resp = self.client.get('/')
assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT)
assert resp.status_code == DUMMYSTATUS
def test_head_method_serializes_no_content(self):
"""No response must be included in HEAD requests."""
resp = self.client.head('/')
self.assertEqual(resp.status_code, DUMMYSTATUS)
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, b'')
"""No response must be included in HEAD requests."""
resp = self.client.head('/')
assert resp.status_code == DUMMYSTATUS
assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
assert resp.content == b''
def test_default_renderer_serializes_content_on_accept_any(self):
"""If the Accept header is set to */* the default renderer should serialize the response."""
resp = self.client.get('/', HTTP_ACCEPT='*/*')
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
"""If the Accept header is set to */* the default renderer should serialize the response."""
resp = self.client.get('/', HTTP_ACCEPT='*/*')
assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT)
assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_serializes_content_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response.
"""If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for the default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT)
assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_serializes_content_non_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response.
"""If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for a non-default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
assert resp.status_code == DUMMYSTATUS
def test_unsatisfiable_accept_header_on_request_returns_406_status(self):
"""If the Accept header is unsatisfiable we should return a 406 Not Acceptable response."""
resp = self.client.get('/', HTTP_ACCEPT='foo/bar')
self.assertEqual(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE)
"""If the Accept header is unsatisfiable we should return a 406 Not Acceptable response."""
resp = self.client.get('/', HTTP_ACCEPT='foo/bar')
assert resp.status_code == status.HTTP_406_NOT_ACCEPTABLE
def test_specified_renderer_serializes_content_on_format_query(self):
"""If a 'format' query is specified, the renderer with the matching
"""If a 'format' query is specified, the renderer with the matching
format attribute should serialize the response."""
param = '?%s=%s' % (
api_settings.URL_FORMAT_OVERRIDE,
RendererB.format
)
resp = self.client.get('/' + param)
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
param = '?%s=%s' % ( api_settings.URL_FORMAT_OVERRIDE, RendererB.format )
resp = self.client.get('/' + param)
assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_serializes_content_on_format_kwargs(self):
"""If a 'format' keyword arg is specified, the renderer with the matching
"""If a 'format' keyword arg is specified, the renderer with the matching
format attribute should serialize the response."""
resp = self.client.get('/something.formatb')
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
resp = self.client.get('/something.formatb')
assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_is_used_on_format_query_with_matching_accept(self):
"""If both a 'format' query and a matching Accept header specified,
"""If both a 'format' query and a matching Accept header specified,
the renderer with the matching format attribute should serialize the response."""
param = '?%s=%s' % (
api_settings.URL_FORMAT_OVERRIDE,
RendererB.format
)
resp = self.client.get('/' + param,
HTTP_ACCEPT=RendererB.media_type)
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
param = '?%s=%s' % ( api_settings.URL_FORMAT_OVERRIDE, RendererB.format )
resp = self.client.get('/' + param, HTTP_ACCEPT=RendererB.media_type)
assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
assert resp.status_code == DUMMYSTATUS
def test_parse_error_renderers_browsable_api(self):
"""Invalid data should still render the browsable API correctly."""
resp = self.client.post('/parseerror', data='foobar', content_type='application/json', HTTP_ACCEPT='text/html')
self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST)
"""Invalid data should still render the browsable API correctly."""
resp = self.client.post('/parseerror', data='foobar', content_type='application/json', HTTP_ACCEPT='text/html')
assert resp['Content-Type'] == 'text/html; charset=utf-8'
assert resp.status_code == status.HTTP_400_BAD_REQUEST
def test_204_no_content_responses_have_no_content_type_set(self):
"""
"""
Regression test for #1196
https://github.com/encode/django-rest-framework/issues/1196
"""
resp = self.client.get('/empty')
self.assertEqual(resp.get('Content-Type', None), None)
self.assertEqual(resp.status_code, status.HTTP_204_NO_CONTENT)
resp = self.client.get('/empty')
assert resp.get('Content-Type', None) == None
assert resp.status_code == status.HTTP_204_NO_CONTENT
def test_contains_headers_of_api_response(self):
"""
"""
Issue #1437
Test we display the headers of the API response and not those from the
HTML response
"""
resp = self.client.get('/html1')
self.assertContains(resp, '>GET, HEAD, OPTIONS<')
self.assertContains(resp, '>application/json<')
self.assertNotContains(resp, '>text/html; charset=utf-8<')
resp = self.client.get('/html1')
self.assertContains(resp, '>GET, HEAD, OPTIONS<')
self.assertContains(resp, '>application/json<')
self.assertNotContains(resp, '>text/html; charset=utf-8<')
_flat_repr = '{"foo":["bar","baz"]}'
@ -275,183 +265,174 @@ def strip_trailing_whitespace(content):
return re.sub(' +\n', '\n', content)
class BaseRendererTests(TestCase):
"""
"""
Tests BaseRenderer
"""
def test_render_raise_error(self):
"""
deftest_render_raise_error(self):
"""
BaseRenderer.render should raise NotImplementedError
"""
with pytest.raises(NotImplementedError):
BaseRenderer().render('test')
with pytest.raises(NotImplementedError):
BaseRenderer().render('test')
class JSONRendererTests(TestCase):
"""
"""
Tests specific to the JSON Renderer
"""
def test_render_lazy_strings(self):
"""
deftest_render_lazy_strings(self):
"""
JSONRenderer should deal with lazy translated strings.
"""
ret = JSONRenderer().render(_('test'))
self.assertEqual(ret, b'"test"')
ret = JSONRenderer().render(_('test'))
assert ret == b'"test"'
def test_render_queryset_values(self):
o = DummyTestModel.objects.create(name='dummy')
qs = DummyTestModel.objects.values('id', 'name')
ret = JSONRenderer().render(qs)
data = json.loads(ret.decode())
self.assertEqual(data, [{'id': o.id, 'name': o.name}])
o = DummyTestModel.objects.create(name='dummy')
qs = DummyTestModel.objects.values('id', 'name')
ret = JSONRenderer().render(qs)
data = json.loads(ret.decode())
assert data == [{'id': o.id, 'name': o.name}]
def test_render_queryset_values_list(self):
o = DummyTestModel.objects.create(name='dummy')
qs = DummyTestModel.objects.values_list('id', 'name')
ret = JSONRenderer().render(qs)
data = json.loads(ret.decode())
self.assertEqual(data, [[o.id, o.name]])
o = DummyTestModel.objects.create(name='dummy')
qs = DummyTestModel.objects.values_list('id', 'name')
ret = JSONRenderer().render(qs)
data = json.loads(ret.decode())
assert data == [[o.id, o.name]]
def test_render_dict_abc_obj(self):
class Dict(MutableMapping):
def __init__(self):
self._dict = {}
class Dict(MutableMapping):
def __init__(self):
self._dict = {}
def __getitem__(self, key):
return self._dict.__getitem__(key)
return self._dict.__getitem__(key)
def __setitem__(self, key, value):
return self._dict.__setitem__(key, value)
return self._dict.__setitem__(key, value)
def __delitem__(self, key):
return self._dict.__delitem__(key)
return self._dict.__delitem__(key)
def __iter__(self):
return self._dict.__iter__()
return self._dict.__iter__()
def __len__(self):
return self._dict.__len__()
return self._dict.__len__()
def keys(self):
return self._dict.keys()
return self._dict.keys()
x = Dict()
x['key'] = 'string value'
x[2] = 3
ret = JSONRenderer().render(x)
data = json.loads(ret.decode())
self.assertEqual(data, {'key': 'string value', '2': 3})
x['key']='string value'
x[2]=3
ret=JSONRenderer().render(x)
data=json.loads(ret.decode())
assertdata=={'key':'string value','2':3}
def test_render_obj_with_getitem(self):
class DictLike:
def __init__(self):
self._dict = {}
class DictLike:
def __init__(self):
self._dict = {}
def set(self, value):
self._dict = dict(value)
self._dict = dict(value)
def __getitem__(self, key):
return self._dict[key]
return self._dict[key]
x = DictLike()
x.set({'a': 1, 'b': 'string'})
with self.assertRaises(TypeError):
JSONRenderer().render(x)
x.set({'a':1,'b':'string'})
withpytest.raises(TypeError):
JSONRenderer().render(x)
def test_float_strictness(self):
renderer = JSONRenderer()
# Default to strict
for value in [float('inf'), float('-inf'), float('nan')]:
with pytest.raises(ValueError):
renderer.render(value)
renderer = JSONRenderer()
for value in [float('inf'), float('-inf'), float('nan')]:
with pytest.raises(ValueError):
renderer.render(value)
renderer.strict = False
assert renderer.render(float('inf')) == b'Infinity'
assert renderer.render(float('-inf')) == b'-Infinity'
assert renderer.render(float('nan')) == b'NaN'
assertrenderer.render(float('inf'))==b'Infinity'
assertrenderer.render(float('-inf'))==b'-Infinity'
assertrenderer.render(float('nan'))==b'NaN'
def test_without_content_type_args(self):
"""
"""
Test basic JSON rendering.
"""
obj = {'foo': ['bar', 'baz']}
renderer = JSONRenderer()
content = renderer.render(obj, 'application/json')
# Fix failing test case which depends on version of JSON library.
self.assertEqual(content.decode(), _flat_repr)
obj = {'foo': ['bar', 'baz']}
renderer = JSONRenderer()
content = renderer.render(obj, 'application/json')
assert content.decode() == _flat_repr
def test_with_content_type_args(self):
"""
"""
Test JSON rendering with additional content type arguments supplied.
"""
obj = {'foo': ['bar', 'baz']}
renderer = JSONRenderer()
content = renderer.render(obj, 'application/json; indent=2')
self.assertEqual(strip_trailing_whitespace(content.decode()), _indented_repr)
obj = {'foo': ['bar', 'baz']}
renderer = JSONRenderer()
content = renderer.render(obj, 'application/json; indent=2')
assert strip_trailing_whitespace(content.decode()) == _indented_repr
class UnicodeJSONRendererTests(TestCase):
"""
"""
Tests specific for the Unicode JSON Renderer
"""
def test_proper_encoding(self):
obj = {'countries': ['United Kingdom', 'France', 'España']}
renderer = JSONRenderer()
content = renderer.render(obj, 'application/json')
self.assertEqual(content, '{"countries":["United Kingdom","France","España"]}'.encode())
deftest_proper_encoding(self):
obj = {'countries': ['United Kingdom', 'France', 'España']}
renderer = JSONRenderer()
content = renderer.render(obj, 'application/json')
assert content == '{"countries":["United Kingdom","France","España"]}'.encode()
def test_u2028_u2029(self):
# The \u2028 and \u2029 characters should be escaped,
# even when the non-escaping unicode representation is used.
# Regression test for #2169
obj = {'should_escape': '\u2028\u2029'}
renderer = JSONRenderer()
content = renderer.render(obj, 'application/json')
self.assertEqual(content, '{"should_escape":"\\u2028\\u2029"}'.encode())
obj = {'should_escape': '\u2028\u2029'}
renderer = JSONRenderer()
content = renderer.render(obj, 'application/json')
assert content == '{"should_escape":"\\u2028\\u2029"}'.encode()
class AsciiJSONRendererTests(TestCase):
"""
"""
Tests specific for the Unicode JSON Renderer
"""
def test_proper_encoding(self):
class AsciiJSONRenderer(JSONRenderer):
ensure_ascii = True
deftest_proper_encoding(self):
class AsciiJSONRenderer(JSONRenderer):
ensure_ascii = True
obj = {'countries': ['United Kingdom', 'France', 'España']}
renderer = AsciiJSONRenderer()
content = renderer.render(obj, 'application/json')
self.assertEqual(content, '{"countries":["United Kingdom","France","Espa\\u00f1a"]}'.encode())
renderer=AsciiJSONRenderer()
content=renderer.render(obj,'application/json')
assertcontent=='{"countries":["United Kingdom","France","Espa\\u00f1a"]}'.encode()
# Tests for caching issue, #346
@override_settings(ROOT_URLCONF='tests.test_renderers')
class CacheRenderTest(TestCase):
"""
"""
Tests specific to caching responses
"""
def test_head_caching(self):
"""
deftest_head_caching(self):
"""
Test caching of HEAD requests
"""
response = self.client.head('/cache')
cache.set('key', response)
cached_response = cache.get('key')
assert isinstance(cached_response, Response)
assert cached_response.content == response.content
assert cached_response.status_code == response.status_code
response = self.client.head('/cache')
cache.set('key', response)
cached_response = cache.get('key')
assert isinstance(cached_response, Response)
assert cached_response.content == response.content
assert cached_response.status_code == response.status_code
def test_get_caching(self):
"""
"""
Test caching of GET requests
"""
response = self.client.get('/cache')
cache.set('key', response)
cached_response = cache.get('key')
assert isinstance(cached_response, Response)
assert cached_response.content == response.content
assert cached_response.status_code == response.status_code
response = self.client.get('/cache')
cache.set('key', response)
cached_response = cache.get('key')
assert isinstance(cached_response, Response)
assert cached_response.content == response.content
assert cached_response.status_code == response.status_code
class TestJSONIndentationStyles:
@ -476,150 +457,116 @@ class TestJSONIndentationStyles:
assert renderer.render(data) == b'{"a": 1, "b": 2}'
class TestHiddenFieldHTMLFormRenderer(TestCase):
def test_hidden_field_rendering(self):
class TestSerializer(serializers.Serializer):
published = serializers.HiddenField(default=True)
def test_hidden_field_rendering(self):
class TestSerializer(serializers.Serializer):
published = serializers.HiddenField(default=True)
serializer = TestSerializer(data={})
serializer.is_valid()
renderer = HTMLFormRenderer()
field = serializer['published']
rendered = renderer.render_field(field, {})
assert rendered == ''
serializer.is_valid()
renderer=HTMLFormRenderer()
field=serializer['published']
rendered=renderer.render_field(field,{})
assertrendered==''
class TestHTMLFormRenderer(TestCase):
def setUp(self):
class TestSerializer(serializers.Serializer):
test_field = serializers.CharField()
def setUp(self):
class TestSerializer(serializers.Serializer):
test_field = serializers.CharField()
self.renderer = HTMLFormRenderer()
self.serializer = TestSerializer(data={})
self.serializer=TestSerializer(data={})
def test_render_with_default_args(self):
self.serializer.is_valid()
renderer = HTMLFormRenderer()
result = renderer.render(self.serializer.data)
self.assertIsInstance(result, SafeText)
self.serializer.is_valid()
renderer = HTMLFormRenderer()
result = renderer.render(self.serializer.data)
assert isinstance(result, SafeText)
def test_render_with_provided_args(self):
self.serializer.is_valid()
renderer = HTMLFormRenderer()
result = renderer.render(self.serializer.data, None, {})
self.assertIsInstance(result, SafeText)
self.serializer.is_valid()
renderer = HTMLFormRenderer()
result = renderer.render(self.serializer.data, None, {})
assert isinstance(result, SafeText)
class TestChoiceFieldHTMLFormRenderer(TestCase):
"""
"""
Test rendering ChoiceField with HTMLFormRenderer.
"""
def setUp(self):
choices = ((1, 'Option1'), (2, 'Option2'), (12, 'Option12'))
class TestSerializer(serializers.Serializer):
test_field = serializers.ChoiceField(choices=choices,
initial=2)
defsetUp(self):
choices = ((1, 'Option1'), (2, 'Option2'), (12, 'Option12'))
class TestSerializer(serializers.Serializer):
test_field = serializers.ChoiceField(choices=choices, initial=2)
self.TestSerializer = TestSerializer
self.renderer = HTMLFormRenderer()
self.renderer=HTMLFormRenderer()
def test_render_initial_option(self):
serializer = self.TestSerializer()
result = self.renderer.render(serializer.data)
self.assertIsInstance(result, SafeText)
self.assertInHTML('<option value="2" selected>Option2</option>',
result)
self.assertInHTML('<option value="1">Option1</option>', result)
self.assertInHTML('<option value="12">Option12</option>', result)
serializer = self.TestSerializer()
result = self.renderer.render(serializer.data)
assert isinstance(result, SafeText)
self.assertInHTML('<option value="2" selected>Option2</option>', result)
self.assertInHTML('<option value="1">Option1</option>', result)
self.assertInHTML('<option value="12">Option12</option>', result)
def test_render_selected_option(self):
serializer = self.TestSerializer(data={'test_field': '12'})
serializer.is_valid()
result = self.renderer.render(serializer.data)
self.assertIsInstance(result, SafeText)
self.assertInHTML('<option value="12" selected>Option12</option>',
result)
self.assertInHTML('<option value="1">Option1</option>', result)
self.assertInHTML('<option value="2">Option2</option>', result)
serializer = self.TestSerializer(data={'test_field': '12'})
serializer.is_valid()
result = self.renderer.render(serializer.data)
assert isinstance(result, SafeText)
self.assertInHTML('<option value="12" selected>Option12</option>', result)
self.assertInHTML('<option value="1">Option1</option>', result)
self.assertInHTML('<option value="2">Option2</option>', result)
class TestMultipleChoiceFieldHTMLFormRenderer(TestCase):
"""
"""
Test rendering MultipleChoiceField with HTMLFormRenderer.
"""
def setUp(self):
self.renderer = HTMLFormRenderer()
defsetUp(self):
self.renderer = HTMLFormRenderer()
def test_render_selected_option_with_string_option_ids(self):
choices = (('1', 'Option1'), ('2', 'Option2'), ('12', 'Option12'),
('}', 'OptionBrace'))
class TestSerializer(serializers.Serializer):
test_field = serializers.MultipleChoiceField(choices=choices)
choices = (('1', 'Option1'), ('2', 'Option2'), ('12', 'Option12'), ('}', 'OptionBrace'))
class TestSerializer(serializers.Serializer):
test_field = serializers.MultipleChoiceField(choices=choices)
serializer = TestSerializer(data={'test_field': ['12']})
serializer.is_valid()
result = self.renderer.render(serializer.data)
self.assertIsInstance(result, SafeText)
self.assertInHTML('<option value="12" selected>Option12</option>',
result)
self.assertInHTML('<option value="1">Option1</option>', result)
self.assertInHTML('<option value="2">Option2</option>', result)
self.assertInHTML('<option value="}">OptionBrace</option>', result)
serializer.is_valid()
result=self.renderer.render(serializer.data)
assertisinstance(result, SafeText)
self.assertInHTML('<option value="12" selected>Option12</option>',result)
self.assertInHTML('<option value="1">Option1</option>',result)
self.assertInHTML('<option value="2">Option2</option>',result)
self.assertInHTML('<option value="}">OptionBrace</option>',result)
def test_render_selected_option_with_integer_option_ids(self):
choices = ((1, 'Option1'), (2, 'Option2'), (12, 'Option12'))
class TestSerializer(serializers.Serializer):
test_field = serializers.MultipleChoiceField(choices=choices)
choices = ((1, 'Option1'), (2, 'Option2'), (12, 'Option12'))
class TestSerializer(serializers.Serializer):
test_field = serializers.MultipleChoiceField(choices=choices)
serializer = TestSerializer(data={'test_field': ['12']})
serializer.is_valid()
result = self.renderer.render(serializer.data)
self.assertIsInstance(result, SafeText)
self.assertInHTML('<option value="12" selected>Option12</option>',
result)
self.assertInHTML('<option value="1">Option1</option>', result)
self.assertInHTML('<option value="2">Option2</option>', result)
serializer.is_valid()
result=self.renderer.render(serializer.data)
assertisinstance(result, SafeText)
self.assertInHTML('<option value="12" selected>Option12</option>',result)
self.assertInHTML('<option value="1">Option1</option>',result)
self.assertInHTML('<option value="2">Option2</option>',result)
class StaticHTMLRendererTests(TestCase):
"""
"""
Tests specific for Static HTML Renderer
"""
def setUp(self):
self.renderer = StaticHTMLRenderer()
defsetUp(self):
self.renderer = StaticHTMLRenderer()
def test_static_renderer(self):
data = '<html><body>text</body></html>'
result = self.renderer.render(data)
assert result == data
data = '<html><body>text</body></html>'
result = self.renderer.render(data)
assert result == data
def test_static_renderer_with_exception(self):
context = {
'response': Response(status=500, exception=True),
'request': Request(HttpRequest())
}
result = self.renderer.render({}, renderer_context=context)
assert result == '500 Internal Server Error'
context = { 'response': Response(status=500, exception=True), 'request': Request(HttpRequest()) }
result = self.renderer.render({}, renderer_context=context)
assert result == '500 Internal Server Error'
class BrowsableAPIRendererTests(URLPatternsTestCase):
@ -658,196 +605,142 @@ class BrowsableAPIRendererTests(URLPatternsTestCase):
assert '>Extra list action<' in resp.content.decode()
class AdminRendererTests(TestCase):
def setUp(self):
self.renderer = AdminRenderer()
def setUp(self):
self.renderer = AdminRenderer()
def test_render_when_resource_created(self):
class DummyView(APIView):
renderer_classes = (AdminRenderer, )
class DummyView(APIView):
renderer_classes = (AdminRenderer, )
request = Request(HttpRequest())
request.build_absolute_uri = lambda: 'http://example.com'
response = Response(status=201, headers={'Location': '/test'})
context = {
'view': DummyView(),
'request': request,
'response': response
}
result = self.renderer.render(data={'test': 'test'},
renderer_context=context)
assert result == ''
assert response.status_code == status.HTTP_303_SEE_OTHER
assert response['Location'] == 'http://example.com'
request.build_absolute_uri=lambda:'http://example.com'
response=Response(status=201,headers={'Location':'/test'})
context={'view':DummyView(),'request':request,'response':response}
result=self.renderer.render(data={'test':'test'},renderer_context=context)
assertresult==''
assertresponse.status_code==status.HTTP_303_SEE_OTHER
assertresponse['Location']=='http://example.com'
def test_render_dict(self):
factory = APIRequestFactory()
class DummyView(APIView):
renderer_classes = (AdminRenderer, )
def get(self, request):
return Response({'foo': 'a string'})
factory = APIRequestFactory()
class DummyView(APIView):
renderer_classes = (AdminRenderer, )
def get(self, request):
return Response({'foo': 'a string'})
view = DummyView.as_view()
request = factory.get('/')
response = view(request)
response.render()
self.assertContains(response, '<tr><th>Foo</th><td>a string</td></tr>', html=True)
request=factory.get('/')
response=view(request)
response.render()
self.assertContains(response,'<tr><th>Foo</th><td>a string</td></tr>',html=True)
def test_render_dict_with_items_key(self):
factory = APIRequestFactory()
class DummyView(APIView):
renderer_classes = (AdminRenderer, )
def get(self, request):
return Response({'items': 'a string'})
factory = APIRequestFactory()
class DummyView(APIView):
renderer_classes = (AdminRenderer, )
def get(self, request):
return Response({'items': 'a string'})
view = DummyView.as_view()
request = factory.get('/')
response = view(request)
response.render()
self.assertContains(response, '<tr><th>Items</th><td>a string</td></tr>', html=True)
request=factory.get('/')
response=view(request)
response.render()
self.assertContains(response,'<tr><th>Items</th><td>a string</td></tr>',html=True)
def test_render_dict_with_iteritems_key(self):
factory = APIRequestFactory()
class DummyView(APIView):
renderer_classes = (AdminRenderer, )
def get(self, request):
return Response({'iteritems': 'a string'})
factory = APIRequestFactory()
class DummyView(APIView):
renderer_classes = (AdminRenderer, )
def get(self, request):
return Response({'iteritems': 'a string'})
view = DummyView.as_view()
request = factory.get('/')
response = view(request)
response.render()
self.assertContains(response, '<tr><th>Iteritems</th><td>a string</td></tr>', html=True)
request=factory.get('/')
response=view(request)
response.render()
self.assertContains(response,'<tr><th>Iteritems</th><td>a string</td></tr>',html=True)
def test_get_result_url(self):
factory = APIRequestFactory()
class DummyGenericViewsetLike(APIView):
lookup_field = 'test'
def reverse_action(view, *args, **kwargs):
self.assertEqual(kwargs['kwargs']['test'], 1)
return '/example/'
factory = APIRequestFactory()
class DummyGenericViewsetLike(APIView):
lookup_field = 'test'
def reverse_action(view, *args, **kwargs):
assert kwargs['kwargs']['test'] == 1
return '/example/'
# get the view instance instead of the view function
view = DummyGenericViewsetLike.as_view()
request = factory.get('/')
response = view(request)
view = response.renderer_context['view']
self.assertEqual(self.renderer.get_result_url({'test': 1}, view), '/example/')
self.assertIsNone(self.renderer.get_result_url({}, view))
request=factory.get('/')
response=view(request)
view=response.renderer_context['view']
assertself.renderer.get_result_url({'test':1},view)=='/example/'
assertself.renderer.get_result_url({},view)is None
def test_get_result_url_no_result(self):
factory = APIRequestFactory()
class DummyView(APIView):
lookup_field = 'test'
factory = APIRequestFactory()
class DummyView(APIView):
lookup_field = 'test'
# get the view instance instead of the view function
view = DummyView.as_view()
request = factory.get('/')
response = view(request)
view = response.renderer_context['view']
self.assertIsNone(self.renderer.get_result_url({'test': 1}, view))
self.assertIsNone(self.renderer.get_result_url({}, view))
request=factory.get('/')
response=view(request)
view=response.renderer_context['view']
assertself.renderer.get_result_url({'test':1},view)is None
assertself.renderer.get_result_url({},view)is None
def test_get_context_result_urls(self):
factory = APIRequestFactory()
class DummyView(APIView):
lookup_field = 'test'
def reverse_action(view, url_name, args=None, kwargs=None):
return '/%s/%d' % (url_name, kwargs['test'])
factory = APIRequestFactory()
class DummyView(APIView):
lookup_field = 'test'
def reverse_action(view, url_name, args=None, kwargs=None):
return '/%s/%d' % (url_name, kwargs['test'])
# get the view instance instead of the view function
view = DummyView.as_view()
request = factory.get('/')
response = view(request)
data = [
{'test': 1},
{'url': '/example', 'test': 2},
{'url': None, 'test': 3},
{},
]
context = {
'view': DummyView(),
'request': Request(request),
'response': response
}
context = self.renderer.get_context(data, None, context)
results = context['results']
self.assertEqual(len(results), 4)
self.assertEqual(results[0]['url'], '/detail/1')
self.assertEqual(results[1]['url'], '/example')
self.assertEqual(results[2]['url'], None)
self.assertNotIn('url', results[3])
request=factory.get('/')
response=view(request)
data=[{'test':1},{'url':'/example','test':2},{'url':None,'test':3},{},]
context={'view':DummyView(),'request':Request(request),'response':response}
context=self.renderer.get_context(data,None,context)
results=context['results']
assertlen(results)==4
assertresults[0]['url']=='/detail/1'
assertresults[1]['url']=='/example'
assertresults[2]['url']==None
assert'url'not inresults[3]
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
class TestDocumentationRenderer(TestCase):
def test_document_with_link_named_data(self):
"""
def test_document_with_link_named_data(self):
"""
Ref #5395: Doc's `document.data` would fail with a Link named "data".
As per #4972, use templatetag instead.
"""
document = coreapi.Document(
title='Data Endpoint API',
url='https://api.example.org/',
content={
'data': coreapi.Link(
url='/data/',
action='get',
fields=[],
description='Return data.'
)
}
)
factory = APIRequestFactory()
request = factory.get('/')
renderer = DocumentationRenderer()
html = renderer.render(document, accepted_media_type="text/html", renderer_context={"request": request})
assert '<h1>Data Endpoint API</h1>' in html
document = coreapi.Document( title='Data Endpoint API', url='https://api.example.org/', content={ 'data': coreapi.Link( url='/data/', action='get', fields=[], description='Return data.' ) } )
factory = APIRequestFactory()
request = factory.get('/')
renderer = DocumentationRenderer()
html = renderer.render(document, accepted_media_type="text/html", renderer_context={"request": request})
assert '<h1>Data Endpoint API</h1>' in html
def test_shell_code_example_rendering(self):
template = loader.get_template('rest_framework/docs/langs/shell.html')
context = {
'document': coreapi.Document(url='https://api.example.org/'),
'link_key': 'testcases > list',
'link': coreapi.Link(url='/data/', action='get', fields=[]),
}
html = template.render(context)
assert 'testcases list' in html
template = loader.get_template('rest_framework/docs/langs/shell.html')
context = { 'document': coreapi.Document(url='https://api.example.org/'), 'link_key': 'testcases > list', 'link': coreapi.Link(url='/data/', action='get', fields=[]), }
html = template.render(context)
assert 'testcases list' in html
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
class TestSchemaJSRenderer(TestCase):
def test_schemajs_output(self):
"""
def test_schemajs_output(self):
"""
Test output of the SchemaJS renderer as per #5608. Django 2.0 on Py3 prints binary data as b'xyz' in templates,
and the base64 encoding used by SchemaJSRenderer outputs base64 as binary. Test fix.
"""
factory = APIRequestFactory()
request = factory.get('/')
renderer = SchemaJSRenderer()
output = renderer.render('data', renderer_context={"request": request})
assert "'ImRhdGEi'" in output
assert "'b'ImRhdGEi''" not in output
factory = APIRequestFactory()
request = factory.get('/')
renderer = SchemaJSRenderer()
output = renderer.render('data', renderer_context={"request": request})
assert "'ImRhdGEi'" in output
assert "'b'ImRhdGEi''" not in output

View File

@ -23,18 +23,11 @@ from rest_framework.test import APIClient, APIRequestFactory
from rest_framework.views import APIView
factory = APIRequestFactory()
class TestInitializer(TestCase):
def test_request_type(self):
request = Request(factory.get('/'))
message = (
'The `request` argument must be an instance of '
'`django.http.HttpRequest`, not `rest_framework.request.Request`.'
)
with self.assertRaisesMessage(AssertionError, message):
Request(request)
def test_request_type(self):
request = Request(factory.get('/'))
message = ( 'The `request` argument must be an instance of ' '`django.http.HttpRequest`, not `rest_framework.request.Request`.' )
with self.assertRaisesMessage(AssertionError, message):
Request(request)
class PlainTextParser(BaseParser):
@ -50,79 +43,78 @@ class PlainTextParser(BaseParser):
return stream.read()
class TestContentParsing(TestCase):
def test_standard_behaviour_determines_no_content_GET(self):
"""
def test_standard_behaviour_determines_no_content_GET(self):
"""
Ensure request.data returns empty QueryDict for GET request.
"""
request = Request(factory.get('/'))
assert request.data == {}
request = Request(factory.get('/'))
assert request.data == {}
def test_standard_behaviour_determines_no_content_HEAD(self):
"""
"""
Ensure request.data returns empty QueryDict for HEAD request.
"""
request = Request(factory.head('/'))
assert request.data == {}
request = Request(factory.head('/'))
assert request.data == {}
def test_request_DATA_with_form_content(self):
"""
"""
Ensure request.data returns content for POST request with form content.
"""
data = {'qwerty': 'uiop'}
request = Request(factory.post('/', data))
request.parsers = (FormParser(), MultiPartParser())
assert list(request.data.items()) == list(data.items())
data = {'qwerty': 'uiop'}
request = Request(factory.post('/', data))
request.parsers = (FormParser(), MultiPartParser())
assert list(request.data.items()) == list(data.items())
def test_request_DATA_with_text_content(self):
"""
"""
Ensure request.data returns content for POST request with
non-form content.
"""
content = b'qwerty'
content_type = 'text/plain'
request = Request(factory.post('/', content, content_type=content_type))
request.parsers = (PlainTextParser(),)
assert request.data == content
content = b'qwerty'
content_type = 'text/plain'
request = Request(factory.post('/', content, content_type=content_type))
request.parsers = (PlainTextParser(),)
assert request.data == content
def test_request_POST_with_form_content(self):
"""
"""
Ensure request.POST returns content for POST request with form content.
"""
data = {'qwerty': 'uiop'}
request = Request(factory.post('/', data))
request.parsers = (FormParser(), MultiPartParser())
assert list(request.POST.items()) == list(data.items())
data = {'qwerty': 'uiop'}
request = Request(factory.post('/', data))
request.parsers = (FormParser(), MultiPartParser())
assert list(request.POST.items()) == list(data.items())
def test_request_POST_with_files(self):
"""
"""
Ensure request.POST returns no content for POST request with file content.
"""
upload = SimpleUploadedFile("file.txt", b"file_content")
request = Request(factory.post('/', {'upload': upload}))
request.parsers = (FormParser(), MultiPartParser())
assert list(request.POST) == []
assert list(request.FILES) == ['upload']
upload = SimpleUploadedFile("file.txt", b"file_content")
request = Request(factory.post('/', {'upload': upload}))
request.parsers = (FormParser(), MultiPartParser())
assert list(request.POST) == []
assert list(request.FILES) == ['upload']
def test_standard_behaviour_determines_form_content_PUT(self):
"""
"""
Ensure request.data returns content for PUT request with form content.
"""
data = {'qwerty': 'uiop'}
request = Request(factory.put('/', data))
request.parsers = (FormParser(), MultiPartParser())
assert list(request.data.items()) == list(data.items())
data = {'qwerty': 'uiop'}
request = Request(factory.put('/', data))
request.parsers = (FormParser(), MultiPartParser())
assert list(request.data.items()) == list(data.items())
def test_standard_behaviour_determines_non_form_content_PUT(self):
"""
"""
Ensure request.data returns content for PUT request with
non-form content.
"""
content = b'qwerty'
content_type = 'text/plain'
request = Request(factory.put('/', content, content_type=content_type))
request.parsers = (PlainTextParser(), )
assert request.data == content
content = b'qwerty'
content_type = 'text/plain'
request = Request(factory.put('/', content, content_type=content_type))
request.parsers = (PlainTextParser(), )
assert request.data == content
class MockView(APIView):
@ -160,175 +152,148 @@ urlpatterns = [
@override_settings(
ROOT_URLCONF='tests.test_request',
FILE_UPLOAD_HANDLERS=['django.core.files.uploadhandler.TemporaryFileUploadHandler'])
class FileUploadTests(TestCase):
def test_fileuploads_closed_at_request_end(self):
with tempfile.NamedTemporaryFile() as f:
response = self.client.post('/upload/', {'file': f})
def test_fileuploads_closed_at_request_end(self):
with tempfile.NamedTemporaryFile() as f:
response = self.client.post('/upload/', {'file': f})
# sanity check that file was processed
assert len(response.data) == 1
for file in response.data:
assert not os.path.exists(file)
forfileinresponse.data:
assert not os.path.exists(file)
@override_settings(ROOT_URLCONF='tests.test_request')
class TestContentParsingWithAuthentication(TestCase):
def setUp(self):
self.csrf_client = APIClient(enforce_csrf_checks=True)
self.username = 'john'
self.email = 'lennon@thebeatles.com'
self.password = 'password'
self.user = User.objects.create_user(self.username, self.email, self.password)
def setUp(self):
self.csrf_client = APIClient(enforce_csrf_checks=True)
self.username = 'john'
self.email = 'lennon@thebeatles.com'
self.password = 'password'
self.user = User.objects.create_user(self.username, self.email, self.password)
def test_user_logged_in_authentication_has_POST_when_not_logged_in(self):
"""
"""
Ensures request.POST exists after SessionAuthentication when user
doesn't log in.
"""
content = {'example': 'example'}
response = self.client.post('/', content)
assert status.HTTP_200_OK == response.status_code
response = self.csrf_client.post('/', content)
assert status.HTTP_200_OK == response.status_code
content = {'example': 'example'}
response = self.client.post('/', content)
assert status.HTTP_200_OK == response.status_code
response = self.csrf_client.post('/', content)
assert status.HTTP_200_OK == response.status_code
class TestUserSetter(TestCase):
def setUp(self):
def setUp(self):
# Pass request object through session middleware so session is
# available to login and logout functions
self.wrapped_request = factory.get('/')
self.request = Request(self.wrapped_request)
SessionMiddleware().process_request(self.wrapped_request)
AuthenticationMiddleware().process_request(self.wrapped_request)
User.objects.create_user('ringo', 'starr@thebeatles.com', 'yellow')
self.user = authenticate(username='ringo', password='yellow')
self.wrapped_request = factory.get('/')
self.request = Request(self.wrapped_request)
SessionMiddleware().process_request(self.wrapped_request)
AuthenticationMiddleware().process_request(self.wrapped_request)
User.objects.create_user('ringo', 'starr@thebeatles.com', 'yellow')
self.user = authenticate(username='ringo', password='yellow')
def test_user_can_be_set(self):
self.request.user = self.user
assert self.request.user == self.user
self.request.user = self.user
assert self.request.user == self.user
def test_user_can_login(self):
login(self.request, self.user)
assert self.request.user == self.user
login(self.request, self.user)
assert self.request.user == self.user
def test_user_can_logout(self):
self.request.user = self.user
assert not self.request.user.is_anonymous
logout(self.request)
assert self.request.user.is_anonymous
self.request.user = self.user
assert not self.request.user.is_anonymous
logout(self.request)
assert self.request.user.is_anonymous
def test_logged_in_user_is_set_on_wrapped_request(self):
login(self.request, self.user)
assert self.wrapped_request.user == self.user
login(self.request, self.user)
assert self.wrapped_request.user == self.user
def test_calling_user_fails_when_attribute_error_is_raised(self):
"""
"""
This proves that when an AttributeError is raised inside of the request.user
property, that we can handle this and report the true, underlying error.
"""
class AuthRaisesAttributeError:
def authenticate(self, request):
self.MISSPELLED_NAME_THAT_DOESNT_EXIST
class AuthRaisesAttributeError:
def authenticate(self, request):
self.MISSPELLED_NAME_THAT_DOESNT_EXIST
request = Request(self.wrapped_request, authenticators=(AuthRaisesAttributeError(),))
# The middleware processes the underlying Django request, sets anonymous user
assert self.wrapped_request.user.is_anonymous
# The DRF request object does not have a user and should run authenticators
expected = r"no attribute 'MISSPELLED_NAME_THAT_DOESNT_EXIST'"
with pytest.raises(WrappedAttributeError, match=expected):
request.user
assertself.wrapped_request.user.is_anonymous
expected=r"no attribute 'MISSPELLED_NAME_THAT_DOESNT_EXIST'"
withpytest.raises(WrappedAttributeError,match=expected):
request.user
with pytest.raises(WrappedAttributeError, match=expected):
hasattr(request, 'user')
hasattr(request, 'user')
with pytest.raises(WrappedAttributeError, match=expected):
login(request, self.user)
login(request, self.user)
class TestAuthSetter(TestCase):
def test_auth_can_be_set(self):
request = Request(factory.get('/'))
request.auth = 'DUMMY'
assert request.auth == 'DUMMY'
def test_auth_can_be_set(self):
request = Request(factory.get('/'))
request.auth = 'DUMMY'
assert request.auth == 'DUMMY'
class TestSecure(TestCase):
def test_default_secure_false(self):
request = Request(factory.get('/', secure=False))
assert request.scheme == 'http'
def test_default_secure_false(self):
request = Request(factory.get('/', secure=False))
assert request.scheme == 'http'
def test_default_secure_true(self):
request = Request(factory.get('/', secure=True))
assert request.scheme == 'https'
request = Request(factory.get('/', secure=True))
assert request.scheme == 'https'
class TestHttpRequest(TestCase):
def test_attribute_access_proxy(self):
http_request = factory.get('/')
request = Request(http_request)
inner_sentinel = object()
http_request.inner_property = inner_sentinel
assert request.inner_property is inner_sentinel
outer_sentinel = object()
request.inner_property = outer_sentinel
assert request.inner_property is outer_sentinel
def test_attribute_access_proxy(self):
http_request = factory.get('/')
request = Request(http_request)
inner_sentinel = object()
http_request.inner_property = inner_sentinel
assert request.inner_property is inner_sentinel
outer_sentinel = object()
request.inner_property = outer_sentinel
assert request.inner_property is outer_sentinel
def test_exception_proxy(self):
# ensure the exception message is not for the underlying WSGIRequest
http_request = factory.get('/')
request = Request(http_request)
message = "'Request' object has no attribute 'inner_property'"
with self.assertRaisesMessage(AttributeError, message):
request.inner_property
http_request = factory.get('/')
request = Request(http_request)
message = "'Request' object has no attribute 'inner_property'"
with self.assertRaisesMessage(AttributeError, message):
request.inner_property
@override_settings(ROOT_URLCONF='tests.test_request')
def test_duplicate_request_stream_parsing_exception(self):
"""
deftest_duplicate_request_stream_parsing_exception(self):
"""
Check assumption that duplicate stream parsing will result in a
`RawPostDataException` being raised.
"""
response = APIClient().post('/echo/', data={'a': 'b'}, format='json')
request = response.renderer_context['request']
# ensure that request stream was consumed by json parser
assert request.content_type.startswith('application/json')
assert response.data == {'a': 'b'}
# pass same HttpRequest to view, stream already consumed
with pytest.raises(RawPostDataException):
EchoView.as_view()(request._request)
response = APIClient().post('/echo/', data={'a': 'b'}, format='json')
request = response.renderer_context['request']
assert request.content_type.startswith('application/json')
assert response.data == {'a': 'b'}
with pytest.raises(RawPostDataException):
EchoView.as_view()(request._request)
@override_settings(ROOT_URLCONF='tests.test_request')
def test_duplicate_request_form_data_access(self):
"""
deftest_duplicate_request_form_data_access(self):
"""
Form data is copied to the underlying django request for middleware
and file closing reasons. Duplicate processing of a request with form
data is 'safe' in so far as accessing `request.POST` does not trigger
the duplicate stream parse exception.
"""
response = APIClient().post('/echo/', data={'a': 'b'})
request = response.renderer_context['request']
# ensure that request stream was consumed by form parser
assert request.content_type.startswith('multipart/form-data')
assert response.data == {'a': ['b']}
# pass same HttpRequest to view, form data set on underlying request
response = EchoView.as_view()(request._request)
request = response.renderer_context['request']
# ensure that request stream was consumed by form parser
assert request.content_type.startswith('multipart/form-data')
assert response.data == {'a': ['b']}
response = APIClient().post('/echo/', data={'a': 'b'})
request = response.renderer_context['request']
assert request.content_type.startswith('multipart/form-data')
assert response.data == {'a': ['b']}
response = EchoView.as_view()(request._request)
request = response.renderer_context['request']
assert request.content_type.startswith('multipart/form-data')
assert response.data == {'a': ['b']}

View File

@ -131,155 +131,146 @@ urlpatterns = [
# TODO: Clean tests bellow - remove duplicates with above, better unit testing, ...
@override_settings(ROOT_URLCONF='tests.test_response')
class RendererIntegrationTests(TestCase):
"""
"""
End-to-end testing of renderers using an ResponseMixin on a generic view.
"""
def test_default_renderer_serializes_content(self):
"""If the Accept header is not set the default renderer should serialize the response."""
resp = self.client.get('/')
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
deftest_default_renderer_serializes_content(self):
"""If the Accept header is not set the default renderer should serialize the response."""
resp = self.client.get('/')
assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT)
assert resp.status_code == DUMMYSTATUS
def test_head_method_serializes_no_content(self):
"""No response must be included in HEAD requests."""
resp = self.client.head('/')
self.assertEqual(resp.status_code, DUMMYSTATUS)
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, b'')
"""No response must be included in HEAD requests."""
resp = self.client.head('/')
assert resp.status_code == DUMMYSTATUS
assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
assert resp.content == b''
def test_default_renderer_serializes_content_on_accept_any(self):
"""If the Accept header is set to */* the default renderer should serialize the response."""
resp = self.client.get('/', HTTP_ACCEPT='*/*')
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
"""If the Accept header is set to */* the default renderer should serialize the response."""
resp = self.client.get('/', HTTP_ACCEPT='*/*')
assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT)
assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_serializes_content_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response.
"""If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for the default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT)
assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_serializes_content_non_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response.
"""If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for a non-default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_serializes_content_on_format_query(self):
"""If a 'format' query is specified, the renderer with the matching
"""If a 'format' query is specified, the renderer with the matching
format attribute should serialize the response."""
resp = self.client.get('/?format=%s' % RendererB.format)
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
resp = self.client.get('/?format=%s' % RendererB.format)
assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_serializes_content_on_format_kwargs(self):
"""If a 'format' keyword arg is specified, the renderer with the matching
"""If a 'format' keyword arg is specified, the renderer with the matching
format attribute should serialize the response."""
resp = self.client.get('/something.formatb')
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
resp = self.client.get('/something.formatb')
assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_is_used_on_format_query_with_matching_accept(self):
"""If both a 'format' query and a matching Accept header specified,
"""If both a 'format' query and a matching Accept header specified,
the renderer with the matching format attribute should serialize the response."""
resp = self.client.get('/?format=%s' % RendererB.format,
HTTP_ACCEPT=RendererB.media_type)
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
resp = self.client.get('/?format=%s' % RendererB.format, HTTP_ACCEPT=RendererB.media_type)
assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
assert resp.status_code == DUMMYSTATUS
@override_settings(ROOT_URLCONF='tests.test_response')
class UnsupportedMediaTypeTests(TestCase):
def test_should_allow_posting_json(self):
response = self.client.post('/json', data='{"test": 123}', content_type='application/json')
self.assertEqual(response.status_code, 200)
def test_should_allow_posting_json(self):
response = self.client.post('/json', data='{"test": 123}', content_type='application/json')
assert response.status_code == 200
def test_should_not_allow_posting_xml(self):
response = self.client.post('/json', data='<test>123</test>', content_type='application/xml')
self.assertEqual(response.status_code, 415)
response = self.client.post('/json', data='<test>123</test>', content_type='application/xml')
assert response.status_code == 415
def test_should_not_allow_posting_a_form(self):
response = self.client.post('/json', data={'test': 123})
self.assertEqual(response.status_code, 415)
response = self.client.post('/json', data={'test': 123})
assert response.status_code == 415
@override_settings(ROOT_URLCONF='tests.test_response')
class Issue122Tests(TestCase):
"""
"""
Tests that covers #122.
"""
def test_only_html_renderer(self):
"""
deftest_only_html_renderer(self):
"""
Test if no infinite recursion occurs.
"""
self.client.get('/html')
self.client.get('/html')
def test_html_renderer_is_first(self):
"""
"""
Test if no infinite recursion occurs.
"""
self.client.get('/html1')
self.client.get('/html1')
@override_settings(ROOT_URLCONF='tests.test_response')
class Issue467Tests(TestCase):
"""
"""
Tests for #467
"""
def test_form_has_label_and_help_text(self):
resp = self.client.get('/html_new_model')
self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
deftest_form_has_label_and_help_text(self):
resp = self.client.get('/html_new_model')
assert resp['Content-Type'] == 'text/html; charset=utf-8'
# self.assertContains(resp, 'Text comes here')
# self.assertContains(resp, 'Text description.')
@override_settings(ROOT_URLCONF='tests.test_response')
class Issue807Tests(TestCase):
"""
"""
Covers #807
"""
def test_does_not_append_charset_by_default(self):
"""
deftest_does_not_append_charset_by_default(self):
"""
Renderers don't include a charset unless set explicitly.
"""
headers = {"HTTP_ACCEPT": RendererA.media_type}
resp = self.client.get('/', **headers)
expected = "{}; charset={}".format(RendererA.media_type, 'utf-8')
self.assertEqual(expected, resp['Content-Type'])
headers = {"HTTP_ACCEPT": RendererA.media_type}
resp = self.client.get('/', **headers)
expected = "{}; charset={}".format(RendererA.media_type, 'utf-8')
assert expected == resp['Content-Type']
def test_if_there_is_charset_specified_on_renderer_it_gets_appended(self):
"""
"""
If renderer class has charset attribute declared, it gets appended
to Response's Content-Type
"""
headers = {"HTTP_ACCEPT": RendererC.media_type}
resp = self.client.get('/', **headers)
expected = "{}; charset={}".format(RendererC.media_type, RendererC.charset)
self.assertEqual(expected, resp['Content-Type'])
headers = {"HTTP_ACCEPT": RendererC.media_type}
resp = self.client.get('/', **headers)
expected = "{}; charset={}".format(RendererC.media_type, RendererC.charset)
assert expected == resp['Content-Type']
def test_content_type_set_explicitly_on_response(self):
"""
"""
The content type may be set explicitly on the response.
"""
headers = {"HTTP_ACCEPT": RendererC.media_type}
resp = self.client.get('/setbyview', **headers)
self.assertEqual('setbyview', resp['Content-Type'])
headers = {"HTTP_ACCEPT": RendererC.media_type}
resp = self.client.get('/setbyview', **headers)
assert 'setbyview' == resp['Content-Type']
def test_form_has_label_and_help_text(self):
resp = self.client.get('/html_new_model')
self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
resp = self.client.get('/html_new_model')
assert resp['Content-Type'] == 'text/html; charset=utf-8'
# self.assertContains(resp, 'Text comes here')
# self.assertContains(resp, 'Text description.')

View File

@ -30,25 +30,22 @@ class MockVersioningScheme:
@override_settings(ROOT_URLCONF='tests.test_reverse')
class ReverseTests(TestCase):
"""
"""
Tests for fully qualified URLs when using `reverse`.
"""
def test_reversed_urls_are_fully_qualified(self):
request = factory.get('/view')
url = reverse('view', request=request)
assert url == 'http://testserver/view'
deftest_reversed_urls_are_fully_qualified(self):
request = factory.get('/view')
url = reverse('view', request=request)
assert url == 'http://testserver/view'
def test_reverse_with_versioning_scheme(self):
request = factory.get('/view')
request.versioning_scheme = MockVersioningScheme()
url = reverse('view', request=request)
assert url == 'http://scheme-reversed/view'
request = factory.get('/view')
request.versioning_scheme = MockVersioningScheme()
url = reverse('view', request=request)
assert url == 'http://scheme-reversed/view'
def test_reverse_with_versioning_scheme_fallback_to_default_on_error(self):
request = factory.get('/view')
request.versioning_scheme = MockVersioningScheme(raise_error=True)
url = reverse('view', request=request)
assert url == 'http://testserver/view'
request = factory.get('/view')
request.versioning_scheme = MockVersioningScheme(raise_error=True)
url = reverse('view', request=request)
assert url == 'http://testserver/view'

View File

@ -214,25 +214,24 @@ class TestCustomLookupFields(URLPatternsTestCase, TestCase):
assert response.data == {"url": "http://testserver/example/notes/a%20b/", "uuid": "a b", "text": "baz qux"}
class TestLookupValueRegex(TestCase):
"""
"""
Ensure the router honors lookup_value_regex when applied
to the viewset.
"""
def setUp(self):
class NoteViewSet(viewsets.ModelViewSet):
queryset = RouterTestModel.objects.all()
lookup_field = 'uuid'
lookup_value_regex = '[0-9a-f]{32}'
defsetUp(self):
class NoteViewSet(viewsets.ModelViewSet):
queryset = RouterTestModel.objects.all()
lookup_field = 'uuid'
lookup_value_regex = '[0-9a-f]{32}'
self.router = SimpleRouter()
self.router.register(r'notes', NoteViewSet)
self.urls = self.router.urls
self.router.register(r'notes',NoteViewSet)
self.urls=self.router.urls
def test_urls_limited_by_lookup_value_regex(self):
expected = ['^notes/$', '^notes/(?P<uuid>[0-9a-f]{32})/$']
for idx in range(len(expected)):
assert expected[idx] == get_regex_pattern(self.urls[idx])
expected = ['^notes/$', '^notes/(?P<uuid>[0-9a-f]{32})/$']
for idx in range(len(expected)):
assert expected[idx] == get_regex_pattern(self.urls[idx])
@override_settings(ROOT_URLCONF='tests.test_routers')
@ -264,97 +263,84 @@ class TestLookupUrlKwargs(URLPatternsTestCase, TestCase):
assert response.data == {"url": "http://testserver/example/notes/123/", "uuid": "123", "text": "foo bar"}
class TestTrailingSlashIncluded(TestCase):
def setUp(self):
class NoteViewSet(viewsets.ModelViewSet):
queryset = RouterTestModel.objects.all()
def setUp(self):
class NoteViewSet(viewsets.ModelViewSet):
queryset = RouterTestModel.objects.all()
self.router = SimpleRouter()
self.router.register(r'notes', NoteViewSet)
self.urls = self.router.urls
self.router.register(r'notes',NoteViewSet)
self.urls=self.router.urls
def test_urls_have_trailing_slash_by_default(self):
expected = ['^notes/$', '^notes/(?P<pk>[^/.]+)/$']
for idx in range(len(expected)):
assert expected[idx] == get_regex_pattern(self.urls[idx])
expected = ['^notes/$', '^notes/(?P<pk>[^/.]+)/$']
for idx in range(len(expected)):
assert expected[idx] == get_regex_pattern(self.urls[idx])
class TestTrailingSlashRemoved(TestCase):
def setUp(self):
class NoteViewSet(viewsets.ModelViewSet):
queryset = RouterTestModel.objects.all()
def setUp(self):
class NoteViewSet(viewsets.ModelViewSet):
queryset = RouterTestModel.objects.all()
self.router = SimpleRouter(trailing_slash=False)
self.router.register(r'notes', NoteViewSet)
self.urls = self.router.urls
self.router.register(r'notes',NoteViewSet)
self.urls=self.router.urls
def test_urls_can_have_trailing_slash_removed(self):
expected = ['^notes$', '^notes/(?P<pk>[^/.]+)$']
for idx in range(len(expected)):
assert expected[idx] == get_regex_pattern(self.urls[idx])
expected = ['^notes$', '^notes/(?P<pk>[^/.]+)$']
for idx in range(len(expected)):
assert expected[idx] == get_regex_pattern(self.urls[idx])
class TestNameableRoot(TestCase):
def setUp(self):
class NoteViewSet(viewsets.ModelViewSet):
queryset = RouterTestModel.objects.all()
def setUp(self):
class NoteViewSet(viewsets.ModelViewSet):
queryset = RouterTestModel.objects.all()
self.router = DefaultRouter()
self.router.root_view_name = 'nameable-root'
self.router.register(r'notes', NoteViewSet)
self.urls = self.router.urls
self.router.root_view_name='nameable-root'
self.router.register(r'notes',NoteViewSet)
self.urls=self.router.urls
def test_router_has_custom_name(self):
expected = 'nameable-root'
assert expected == self.urls[-1].name
expected = 'nameable-root'
assert expected == self.urls[-1].name
class TestActionKeywordArgs(TestCase):
"""
"""
Ensure keyword arguments passed in the `@action` decorator
are properly handled. Refs #940.
"""
def setUp(self):
class TestViewSet(viewsets.ModelViewSet):
permission_classes = []
@action(methods=['post'], detail=True, permission_classes=[permissions.AllowAny])
def custom(self, request, *args, **kwargs):
return Response({
'permission_classes': self.permission_classes
})
defsetUp(self):
class TestViewSet(viewsets.ModelViewSet):
permission_classes = []
@action(methods=['post'], detail=True, permission_classes=[permissions.AllowAny])
def custom(self, request, *args, **kwargs):
return Response({ 'permission_classes': self.permission_classes })
self.router = SimpleRouter()
self.router.register(r'test', TestViewSet, basename='test')
self.view = self.router.urls[-1].callback
self.router.register(r'test',TestViewSet,basename='test')
self.view=self.router.urls[-1].callback
def test_action_kwargs(self):
request = factory.post('/test/0/custom/')
response = self.view(request)
assert response.data == {'permission_classes': [permissions.AllowAny]}
request = factory.post('/test/0/custom/')
response = self.view(request)
assert response.data == {'permission_classes': [permissions.AllowAny]}
class TestActionAppliedToExistingRoute(TestCase):
"""
"""
Ensure `@action` decorator raises an except when applied
to an existing route
"""
deftest_exception_raised_when_action_applied_to_existing_route(self):
class TestViewSet(viewsets.ModelViewSet):
def test_exception_raised_when_action_applied_to_existing_route(self):
class TestViewSet(viewsets.ModelViewSet):
@action(methods=['post'], detail=True)
def retrieve(self, request, *args, **kwargs):
return Response({
'hello': 'world'
})
@action(methods=['post'], detail=True)
def retrieve(self, request, *args, **kwargs):
return Response({ 'hello': 'world' })
self.router = SimpleRouter()
self.router.register(r'test', TestViewSet, basename='test')
with pytest.raises(ImproperlyConfigured):
self.router.urls
self.router.register(r'test',TestViewSet,basename='test')
withpytest.raises(ImproperlyConfigured):
self.router.urls
class DynamicListAndDetailViewSet(viewsets.ViewSet):
@ -390,44 +376,33 @@ class SubDynamicListAndDetailViewSet(DynamicListAndDetailViewSet):
pass
class TestDynamicListAndDetailRouter(TestCase):
def setUp(self):
self.router = SimpleRouter()
def setUp(self):
self.router = SimpleRouter()
def _test_list_and_detail_route_decorators(self, viewset):
routes = self.router.get_routes(viewset)
decorator_routes = [r for r in routes if not (r.name.endswith('-list') or r.name.endswith('-detail'))]
MethodNamesMap = namedtuple('MethodNamesMap', 'method_name url_path')
# Make sure all these endpoints exist and none have been clobbered
for i, endpoint in enumerate([MethodNamesMap('list_custom_route_get', 'list_custom-route'),
MethodNamesMap('list_route_get', 'list_route_get'),
MethodNamesMap('list_route_post', 'list_route_post'),
MethodNamesMap('detail_custom_route_get', 'detail_custom-route'),
MethodNamesMap('detail_route_get', 'detail_route_get'),
MethodNamesMap('detail_route_post', 'detail_route_post')
]):
route = decorator_routes[i]
# check url listing
method_name = endpoint.method_name
url_path = endpoint.url_path
if method_name.startswith('list_'):
assert route.url == '^{{prefix}}/{0}{{trailing_slash}}$'.format(url_path)
routes = self.router.get_routes(viewset)
decorator_routes = [r for r in routes if not (r.name.endswith('-list') or r.name.endswith('-detail'))]
MethodNamesMap = namedtuple('MethodNamesMap', 'method_name url_path')
for i, endpoint in enumerate([MethodNamesMap('list_custom_route_get', 'list_custom-route'), MethodNamesMap('list_route_get', 'list_route_get'), MethodNamesMap('list_route_post', 'list_route_post'), MethodNamesMap('detail_custom_route_get', 'detail_custom-route'), MethodNamesMap('detail_route_get', 'detail_route_get'), MethodNamesMap('detail_route_post', 'detail_route_post') ]):
route = decorator_routes[i]
method_name = endpoint.method_name
url_path = endpoint.url_path
if method_name.startswith('list_'):
assert route.url == '^{{prefix}}/{0}{{trailing_slash}}$'.format(url_path)
else:
assert route.url == '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(url_path)
assert route.url == '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(url_path)
# check method to function mapping
if method_name.endswith('_post'):
method_map = 'post'
method_map = 'post'
else:
method_map = 'get'
method_map = 'get'
assert route.mapping[method_map] == method_name
def test_list_and_detail_route_decorators(self):
self._test_list_and_detail_route_decorators(DynamicListAndDetailViewSet)
self._test_list_and_detail_route_decorators(DynamicListAndDetailViewSet)
def test_inherited_list_and_detail_route_decorators(self):
self._test_list_and_detail_route_decorators(SubDynamicListAndDetailViewSet)
self._test_list_and_detail_route_decorators(SubDynamicListAndDetailViewSet)
class TestEmptyPrefix(URLPatternsTestCase, TestCase):
@ -490,69 +465,52 @@ class TestViewInitkwargs(URLPatternsTestCase, TestCase):
assert initkwargs['basename'] == 'routertestmodel'
class TestBaseNameRename(TestCase):
def test_base_name_and_basename_assertion(self):
router = SimpleRouter()
msg = "Do not provide both the `basename` and `base_name` arguments."
with warnings.catch_warnings(record=True) as w, \
self.assertRaisesMessage(AssertionError, msg):
warnings.simplefilter('always')
router.register('mock', MockViewSet, 'mock', base_name='mock')
def test_base_name_and_basename_assertion(self):
router = SimpleRouter()
msg = "Do not provide both the `basename` and `base_name` arguments."
with warnings.catch_warnings(record=True) as w, self.assertRaisesMessage(AssertionError, msg):
warnings.simplefilter('always')
router.register('mock', MockViewSet, 'mock', base_name='mock')
msg = "The `base_name` argument is pending deprecation in favor of `basename`."
assert len(w) == 1
assert str(w[0].message) == msg
assertlen(w)==1
assertstr(w[0].message)==msg
def test_base_name_argument_deprecation(self):
router = SimpleRouter()
with pytest.warns(RemovedInDRF311Warning) as w:
warnings.simplefilter('always')
router.register('mock', MockViewSet, base_name='mock')
router = SimpleRouter()
with pytest.warns(RemovedInDRF311Warning) as w:
warnings.simplefilter('always')
router.register('mock', MockViewSet, base_name='mock')
msg = "The `base_name` argument is pending deprecation in favor of `basename`."
assert len(w) == 1
assert str(w[0].message) == msg
assert router.registry == [
('mock', MockViewSet, 'mock'),
]
assertlen(w)==1
assertstr(w[0].message)==msg
assertrouter.registry==[('mock',MockViewSet,'mock'),]
def test_basename_argument_no_warnings(self):
router = SimpleRouter()
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
router.register('mock', MockViewSet, basename='mock')
router = SimpleRouter()
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
router.register('mock', MockViewSet, basename='mock')
assert len(w) == 0
assert router.registry == [
('mock', MockViewSet, 'mock'),
]
assertrouter.registry==[('mock',MockViewSet,'mock'),]
def test_get_default_base_name_deprecation(self):
msg = "`CustomRouter.get_default_base_name` method should be renamed `get_default_basename`."
# Class definition should raise a warning
with pytest.warns(RemovedInDRF311Warning) as w:
warnings.simplefilter('always')
class CustomRouter(SimpleRouter):
def get_default_base_name(self, viewset):
return 'foo'
msg = "`CustomRouter.get_default_base_name` method should be renamed `get_default_basename`."
with pytest.warns(RemovedInDRF311Warning) as w:
warnings.simplefilter('always')
class CustomRouter(SimpleRouter):
def get_default_base_name(self, viewset):
return 'foo'
assert len(w) == 1
assert str(w[0].message) == msg
# Deprecated method implementation should still be called
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
router = CustomRouter()
router.register('mock', MockViewSet)
assertstr(w[0].message)==msg
withwarnings.catch_warnings(record=True)asw:
warnings.simplefilter('always')
router = CustomRouter()
router.register('mock', MockViewSet)
assert len(w) == 0
assert router.registry == [
('mock', MockViewSet, 'foo'),
]
assertrouter.registry==[('mock',MockViewSet,'foo'),]

File diff suppressed because one or more lines are too long

View File

@ -4,120 +4,66 @@ Tests to cover bulk create and update using serializers.
from django.test import TestCase
from rest_framework import serializers
class BulkCreateSerializerTests(TestCase):
"""
"""
Creating multiple instances using serializers.
"""
def setUp(self):
class BookSerializer(serializers.Serializer):
id = serializers.IntegerField()
title = serializers.CharField(max_length=100)
author = serializers.CharField(max_length=100)
defsetUp(self):
class BookSerializer(serializers.Serializer):
id = serializers.IntegerField()
title = serializers.CharField(max_length=100)
author = serializers.CharField(max_length=100)
self.BookSerializer = BookSerializer
def test_bulk_create_success(self):
"""
"""
Correct bulk update serialization should return the input data.
"""
data = [
{
'id': 0,
'title': 'The electric kool-aid acid test',
'author': 'Tom Wolfe'
}, {
'id': 1,
'title': 'If this is a man',
'author': 'Primo Levi'
}, {
'id': 2,
'title': 'The wind-up bird chronicle',
'author': 'Haruki Murakami'
}
]
serializer = self.BookSerializer(data=data, many=True)
assert serializer.is_valid() is True
assert serializer.validated_data == data
assert serializer.errors == []
data = [ { 'id': 0, 'title': 'The electric kool-aid acid test', 'author': 'Tom Wolfe' }, { 'id': 1, 'title': 'If this is a man', 'author': 'Primo Levi' }, { 'id': 2, 'title': 'The wind-up bird chronicle', 'author': 'Haruki Murakami' } ]
serializer = self.BookSerializer(data=data, many=True)
assert serializer.is_valid() is True
assert serializer.validated_data == data
assert serializer.errors == []
def test_bulk_create_errors(self):
"""
"""
Incorrect bulk create serialization should return errors.
"""
data = [
{
'id': 0,
'title': 'The electric kool-aid acid test',
'author': 'Tom Wolfe'
}, {
'id': 1,
'title': 'If this is a man',
'author': 'Primo Levi'
}, {
'id': 'foo',
'title': 'The wind-up bird chronicle',
'author': 'Haruki Murakami'
}
]
expected_errors = [
{},
{},
{'id': ['A valid integer is required.']}
]
serializer = self.BookSerializer(data=data, many=True)
assert serializer.is_valid() is False
assert serializer.errors == expected_errors
assert serializer.validated_data == []
data = [ { 'id': 0, 'title': 'The electric kool-aid acid test', 'author': 'Tom Wolfe' }, { 'id': 1, 'title': 'If this is a man', 'author': 'Primo Levi' }, { 'id': 'foo', 'title': 'The wind-up bird chronicle', 'author': 'Haruki Murakami' } ]
expected_errors = [ {}, {}, {'id': ['A valid integer is required.']} ]
serializer = self.BookSerializer(data=data, many=True)
assert serializer.is_valid() is False
assert serializer.errors == expected_errors
assert serializer.validated_data == []
def test_invalid_list_datatype(self):
"""
"""
Data containing list of incorrect data type should return errors.
"""
data = ['foo', 'bar', 'baz']
serializer = self.BookSerializer(data=data, many=True)
assert serializer.is_valid() is False
message = 'Invalid data. Expected a dictionary, but got str.'
expected_errors = [
{'non_field_errors': [message]},
{'non_field_errors': [message]},
{'non_field_errors': [message]}
]
assert serializer.errors == expected_errors
data = ['foo', 'bar', 'baz']
serializer = self.BookSerializer(data=data, many=True)
assert serializer.is_valid() is False
message = 'Invalid data. Expected a dictionary, but got str.'
expected_errors = [ {'non_field_errors': [message]}, {'non_field_errors': [message]}, {'non_field_errors': [message]} ]
assert serializer.errors == expected_errors
def test_invalid_single_datatype(self):
"""
"""
Data containing a single incorrect data type should return errors.
"""
data = 123
serializer = self.BookSerializer(data=data, many=True)
assert serializer.is_valid() is False
expected_errors = {'non_field_errors': ['Expected a list of items but got type "int".']}
assert serializer.errors == expected_errors
data = 123
serializer = self.BookSerializer(data=data, many=True)
assert serializer.is_valid() is False
expected_errors = {'non_field_errors': ['Expected a list of items but got type "int".']}
assert serializer.errors == expected_errors
def test_invalid_single_object(self):
"""
"""
Data containing only a single object, instead of a list of objects
should return errors.
"""
data = {
'id': 0,
'title': 'The electric kool-aid acid test',
'author': 'Tom Wolfe'
}
serializer = self.BookSerializer(data=data, many=True)
assert serializer.is_valid() is False
expected_errors = {'non_field_errors': ['Expected a list of items but got type "dict".']}
assert serializer.errors == expected_errors
data = { 'id': 0, 'title': 'The electric kool-aid acid test', 'author': 'Tom Wolfe' }
serializer = self.BookSerializer(data=data, many=True)
assert serializer.is_valid() is False
expected_errors = {'non_field_errors': ['Expected a list of items but got type "dict".']}
assert serializer.errors == expected_errors