unittest to pytest assertion convertions

This commit is contained in:
Asif Saif Uddin 2019-05-02 13:36:18 +06:00
parent 375f1b59c6
commit 3d28279d05
31 changed files with 3503 additions and 5311 deletions

2
.gitignore vendored
View File

@ -2,7 +2,7 @@
*.db *.db
*~ *~
.* .*
.py.bak *.py.bak
/site/ /site/

View File

@ -11,6 +11,7 @@ from rest_framework.response import Response
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from rest_framework.views import APIView from rest_framework.views import APIView
from tests.models import BasicModel from tests.models import BasicModel
import pytest
factory = APIRequestFactory() factory = APIRequestFactory()
@ -52,7 +53,6 @@ urlpatterns = (
connection.features.uses_savepoints, connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints." "'atomic' requires transactions and savepoints."
) )
class DBTransactionTests(TestCase):
def setUp(self): def setUp(self):
self.view = BasicView.as_view() self.view = BasicView.as_view()
connections.databases['default']['ATOMIC_REQUESTS'] = True connections.databases['default']['ATOMIC_REQUESTS'] = True
@ -62,7 +62,6 @@ class DBTransactionTests(TestCase):
def test_no_exception_commit_transaction(self): def test_no_exception_commit_transaction(self):
request = factory.post('/') request = factory.post('/')
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request) response = self.view(request)
assert not transaction.get_rollback() assert not transaction.get_rollback()
@ -74,7 +73,6 @@ class DBTransactionTests(TestCase):
connection.features.uses_savepoints, connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints." "'atomic' requires transactions and savepoints."
) )
class DBTransactionErrorTests(TestCase):
def setUp(self): def setUp(self):
self.view = ErrorView.as_view() self.view = ErrorView.as_view()
connections.databases['default']['ATOMIC_REQUESTS'] = True connections.databases['default']['ATOMIC_REQUESTS'] = True
@ -95,7 +93,8 @@ class DBTransactionErrorTests(TestCase):
# 2 - insert # 2 - insert
# 3 - release savepoint # 3 - release savepoint
with transaction.atomic(): with transaction.atomic():
self.assertRaises(Exception, self.view, request) with pytest.raises(Exception):
self.view(request)
assert not transaction.get_rollback() assert not transaction.get_rollback()
assert BasicModel.objects.count() == 1 assert BasicModel.objects.count() == 1
@ -104,7 +103,6 @@ class DBTransactionErrorTests(TestCase):
connection.features.uses_savepoints, connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints." "'atomic' requires transactions and savepoints."
) )
class DBTransactionAPIExceptionTests(TestCase):
def setUp(self): def setUp(self):
self.view = APIExceptionView.as_view() self.view = APIExceptionView.as_view()
connections.databases['default']['ATOMIC_REQUESTS'] = True connections.databases['default']['ATOMIC_REQUESTS'] = True

View File

@ -13,9 +13,6 @@ from rest_framework.authtoken.models import Token
from rest_framework.authtoken.serializers import AuthTokenSerializer from rest_framework.authtoken.serializers import AuthTokenSerializer
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError
class AuthTokenTests(TestCase):
def setUp(self): def setUp(self):
self.site = site self.site = site
self.user = User.objects.create_user(username='test_user') self.user = User.objects.create_user(username='test_user')
@ -40,7 +37,6 @@ class AuthTokenTests(TestCase):
assert AuthTokenSerializer(data=data).is_valid() assert AuthTokenSerializer(data=data).is_valid()
class AuthTokenCommandTests(TestCase):
def setUp(self): def setUp(self):
self.site = site self.site = site
@ -61,7 +57,6 @@ class AuthTokenCommandTests(TestCase):
first_token_key = Token.objects.first().key first_token_key = Token.objects.first().key
AuthTokenCommand().create_user_token(self.user.username, True) AuthTokenCommand().create_user_token(self.user.username, True)
second_token_key = Token.objects.first().key second_token_key = Token.objects.first().key
assert first_token_key != second_token_key assert first_token_key != second_token_key
def test_command_do_not_reset_user_token(self): def test_command_do_not_reset_user_token(self):
@ -69,7 +64,6 @@ class AuthTokenCommandTests(TestCase):
first_token_key = Token.objects.first().key first_token_key = Token.objects.first().key
AuthTokenCommand().create_user_token(self.user.username, False) AuthTokenCommand().create_user_token(self.user.username, False)
second_token_key = Token.objects.first().key second_token_key = Token.objects.first().key
assert first_token_key == second_token_key assert first_token_key == second_token_key
def test_command_raising_error_for_invalid_user(self): def test_command_raising_error_for_invalid_user(self):
@ -81,6 +75,6 @@ class AuthTokenCommandTests(TestCase):
out = StringIO() out = StringIO()
call_command('drf_create_token', self.user.username, stdout=out) call_command('drf_create_token', self.user.username, stdout=out)
token_saved = Token.objects.first() token_saved = Token.objects.first()
self.assertIn('Generated token', out.getvalue()) assert 'Generated token' in out.getvalue()
self.assertIn(self.user.username, out.getvalue()) assert self.user.username in out.getvalue()
self.assertIn(token_saved.key, out.getvalue()) assert token_saved.key in out.getvalue()

View File

@ -17,9 +17,6 @@ from rest_framework.test import APIRequestFactory
from rest_framework.throttling import UserRateThrottle from rest_framework.throttling import UserRateThrottle
from rest_framework.views import APIView from rest_framework.views import APIView
class DecoratorTestCase(TestCase):
def setUp(self): def setUp(self):
self.factory = APIRequestFactory() self.factory = APIRequestFactory()
@ -31,20 +28,19 @@ class DecoratorTestCase(TestCase):
""" """
If @api_view is not applied correct, we should raise an assertion. If @api_view is not applied correct, we should raise an assertion.
""" """
@api_view @api_view
def view(request): def view(request):
return Response() return Response()
request = self.factory.get('/') request = self.factory.get('/')
self.assertRaises(AssertionError, view, request) withpytest.raises(AssertionError):
view(request)
def test_api_view_incorrect_arguments(self): def test_api_view_incorrect_arguments(self):
""" """
If @api_view is missing arguments, we should raise an assertion. If @api_view is missing arguments, we should raise an assertion.
""" """
with pytest.raises(AssertionError):
with self.assertRaises(AssertionError):
@api_view('GET') @api_view('GET')
def view(request): def view(request):
return Response() return Response()
@ -58,7 +54,6 @@ class DecoratorTestCase(TestCase):
request = self.factory.get('/') request = self.factory.get('/')
response=view(request) response=view(request)
assertresponse.status_code==status.HTTP_200_OK assertresponse.status_code==status.HTTP_200_OK
request=self.factory.post('/') request=self.factory.post('/')
response=view(request) response=view(request)
assertresponse.status_code==status.HTTP_405_METHOD_NOT_ALLOWED assertresponse.status_code==status.HTTP_405_METHOD_NOT_ALLOWED
@ -72,7 +67,6 @@ class DecoratorTestCase(TestCase):
request = self.factory.put('/') request = self.factory.put('/')
response=view(request) response=view(request)
assertresponse.status_code==status.HTTP_200_OK assertresponse.status_code==status.HTTP_200_OK
request=self.factory.post('/') request=self.factory.post('/')
response=view(request) response=view(request)
assertresponse.status_code==status.HTTP_405_METHOD_NOT_ALLOWED assertresponse.status_code==status.HTTP_405_METHOD_NOT_ALLOWED
@ -86,7 +80,6 @@ class DecoratorTestCase(TestCase):
request = self.factory.patch('/') request = self.factory.patch('/')
response=view(request) response=view(request)
assertresponse.status_code==status.HTTP_200_OK assertresponse.status_code==status.HTTP_200_OK
request=self.factory.post('/') request=self.factory.post('/')
response=view(request) response=view(request)
assertresponse.status_code==status.HTTP_405_METHOD_NOT_ALLOWED assertresponse.status_code==status.HTTP_405_METHOD_NOT_ALLOWED
@ -149,7 +142,6 @@ class DecoratorTestCase(TestCase):
request = self.factory.get('/') request = self.factory.get('/')
response=view(request) response=view(request)
assertresponse.status_code==status.HTTP_200_OK assertresponse.status_code==status.HTTP_200_OK
response=view(request) response=view(request)
assertresponse.status_code==status.HTTP_429_TOO_MANY_REQUESTS assertresponse.status_code==status.HTTP_429_TOO_MANY_REQUESTS
@ -168,7 +160,6 @@ class DecoratorTestCase(TestCase):
assert isinstance(view.cls.schema, CustomSchema) assert isinstance(view.cls.schema, CustomSchema)
class ActionDecoratorTestCase(TestCase):
def test_defaults(self): def test_defaults(self):
@action(detail=True) @action(detail=True)
@ -179,10 +170,7 @@ class ActionDecoratorTestCase(TestCase):
asserttest_action.detailisTrue asserttest_action.detailisTrue
asserttest_action.url_path=='test_action' asserttest_action.url_path=='test_action'
asserttest_action.url_name=='test-action' asserttest_action.url_name=='test-action'
assert test_action.kwargs == { asserttest_action.kwargs=={'name':'Test action','description':'Description',}
'name': 'Test action',
'description': 'Description',
}
def test_detail_required(self): def test_detail_required(self):
with pytest.raises(AssertionError) as excinfo: with pytest.raises(AssertionError) as excinfo:
@ -214,37 +202,21 @@ class ActionDecoratorTestCase(TestCase):
'name' and 'suffix' are mutually exclusive kwargs used for generating 'name' and 'suffix' are mutually exclusive kwargs used for generating
a view's display name. a view's display name.
""" """
# by default, generate name from method
@action(detail=True) @action(detail=True)
def test_action(request): def test_action(request):
raise NotImplementedError raise NotImplementedError
assert test_action.kwargs == { assert test_action.kwargs == {'description':None,'name':'Test action',}
'description': None,
'name': 'Test action',
}
# name kwarg supersedes name generation
@action(detail=True,name='test name') @action(detail=True,name='test name')
deftest_action(request): deftest_action(request):
raise NotImplementedError raise NotImplementedError
assert test_action.kwargs == { assert test_action.kwargs == {'description':None,'name':'test name',}
'description': None,
'name': 'test name',
}
# suffix kwarg supersedes name generation
@action(detail=True,suffix='Suffix') @action(detail=True,suffix='Suffix')
deftest_action(request): deftest_action(request):
raise NotImplementedError raise NotImplementedError
assert test_action.kwargs == { assert test_action.kwargs == {'description':None,'suffix':'Suffix',}
'description': None,
'suffix': 'Suffix',
}
# name + suffix is a conflict.
withpytest.raises(TypeError)asexcinfo: withpytest.raises(TypeError)asexcinfo:
action(detail=True, name='test name', suffix='Suffix') action(detail=True, name='test name', suffix='Suffix')
@ -279,8 +251,7 @@ class ActionDecoratorTestCase(TestCase):
def test_action(): def test_action():
raise NotImplementedError raise NotImplementedError
msg = ("Method mapping does not behave like the property decorator. You " msg = ("Method mapping does not behave like the property decorator. You ""cannot use the same method name for each mapping declaration.")
"cannot use the same method name for each mapping declaration.")
withself.assertRaisesMessage(AssertionError,msg): withself.assertRaisesMessage(AssertionError,msg):
@test_action.mapping.post @test_action.mapping.post
def test_action(): def test_action():
@ -293,11 +264,7 @@ class ActionDecoratorTestCase(TestCase):
raise NotImplementedError raise NotImplementedError
assert len(record) == 1 assert len(record) == 1
assert str(record[0].message) == ( assertstr(record[0].message)==("`detail_route` is deprecated and will be removed in ""3.10 in favor of `action`, which accepts a `detail` bool. Use ""`@action(detail=True)` instead.")
"`detail_route` is deprecated and will be removed in "
"3.10 in favor of `action`, which accepts a `detail` bool. Use "
"`@action(detail=True)` instead."
)
def test_list_route_deprecation(self): def test_list_route_deprecation(self):
with pytest.warns(RemovedInDRF310Warning) as record: with pytest.warns(RemovedInDRF310Warning) as record:
@ -306,11 +273,7 @@ class ActionDecoratorTestCase(TestCase):
raise NotImplementedError raise NotImplementedError
assert len(record) == 1 assert len(record) == 1
assert str(record[0].message) == ( assertstr(record[0].message)==("`list_route` is deprecated and will be removed in ""3.10 in favor of `action`, which accepts a `detail` bool. Use ""`@action(detail=False)` instead.")
"`list_route` is deprecated and will be removed in "
"3.10 in favor of `action`, which accepts a `detail` bool. Use "
"`@action(detail=False)` instead."
)
def test_route_url_name_from_path(self): def test_route_url_name_from_path(self):
# pre-3.8 behavior was to base the `url_name` off of the `url_path` # pre-3.8 behavior was to base the `url_name` off of the `url_path`

View File

@ -69,9 +69,6 @@ MARKED_DOWN_gte_21 = """<h2 id="an-example-docstring">an example docstring</h2>
</code></pre> </code></pre>
<p>indented</p> <p>indented</p>
<h2 id="hash-style-header">hash style header</h2>%s""" <h2 id="hash-style-header">hash style header</h2>%s"""
class TestViewNamesAndDescriptions(TestCase):
def test_view_name_uses_class_name(self): def test_view_name_uses_class_name(self):
""" """
Ensure view names are based on the class name. Ensure view names are based on the class name.
@ -150,9 +147,6 @@ class TestViewNamesAndDescriptions(TestCase):
See: https://github.com/encode/django-rest-framework/issues/1708 See: https://github.com/encode/django-rest-framework/issues/1708
""" """
# use a mock object instead of gettext_lazy to ensure that we can't end
# up with a test case string in our l10n catalog
class MockLazyStr: class MockLazyStr:
def __init__(self, string): def __init__(self, string):
self.s = string self.s = string
@ -171,16 +165,8 @@ class TestViewNamesAndDescriptions(TestCase):
""" """
if apply_markdown: if apply_markdown:
md_applied = apply_markdown(DESCRIPTION) md_applied = apply_markdown(DESCRIPTION)
gte_21_match = ( gte_21_match = ( md_applied == ( MARKED_DOWN_gte_21 % MARKED_DOWN_HILITE) or md_applied == ( MARKED_DOWN_gte_21 % MARKED_DOWN_NOT_HILITE))
md_applied == ( lt_21_match = ( md_applied == ( MARKED_DOWN_lt_21 % MARKED_DOWN_HILITE) or md_applied == ( MARKED_DOWN_lt_21 % MARKED_DOWN_NOT_HILITE))
MARKED_DOWN_gte_21 % MARKED_DOWN_HILITE) or
md_applied == (
MARKED_DOWN_gte_21 % MARKED_DOWN_NOT_HILITE))
lt_21_match = (
md_applied == (
MARKED_DOWN_lt_21 % MARKED_DOWN_HILITE) or
md_applied == (
MARKED_DOWN_lt_21 % MARKED_DOWN_NOT_HILITE))
assert gte_21_match or lt_21_match assert gte_21_match or lt_21_match

View File

@ -15,11 +15,9 @@ class MockList:
return [1, 2, 3] return [1, 2, 3]
class JSONEncoderTests(TestCase):
""" """
Tests the JSONEncoder method Tests the JSONEncoder method
""" """
defsetUp(self): defsetUp(self):
self.encoder = JSONEncoder() self.encoder = JSONEncoder()

View File

@ -7,72 +7,42 @@ from rest_framework.exceptions import (
server_error server_error
) )
class ExceptionTestCase(TestCase):
def test_get_error_details(self): def test_get_error_details(self):
example = "string" example = "string"
lazy_example = _(example) lazy_example = _(example)
assert _get_error_details(lazy_example) == example assert _get_error_details(lazy_example) == example
assert isinstance( _get_error_details(lazy_example), ErrorDetail )
assert isinstance(
_get_error_details(lazy_example),
ErrorDetail
)
assert _get_error_details({'nested': lazy_example})['nested'] == example assert _get_error_details({'nested': lazy_example})['nested'] == example
assert isinstance( _get_error_details({'nested': lazy_example})['nested'], ErrorDetail )
assert isinstance(
_get_error_details({'nested': lazy_example})['nested'],
ErrorDetail
)
assert _get_error_details([[lazy_example]])[0][0] == example assert _get_error_details([[lazy_example]])[0][0] == example
assert isinstance( _get_error_details([[lazy_example]])[0][0], ErrorDetail )
assert isinstance(
_get_error_details([[lazy_example]])[0][0],
ErrorDetail
)
def test_get_full_details_with_throttling(self): def test_get_full_details_with_throttling(self):
exception = Throttled() exception = Throttled()
assert exception.get_full_details() == { assert exception.get_full_details() == { 'message': 'Request was throttled.', 'code': 'throttled'}
'message': 'Request was throttled.', 'code': 'throttled'}
exception = Throttled(wait=2) exception = Throttled(wait=2)
assert exception.get_full_details() == { assert exception.get_full_details() == { 'message': 'Request was throttled. Expected available in {} seconds.'.format(2), 'code': 'throttled'}
'message': 'Request was throttled. Expected available in {} seconds.'.format(2),
'code': 'throttled'}
exception = Throttled(wait=2, detail='Slow down!') exception = Throttled(wait=2, detail='Slow down!')
assert exception.get_full_details() == { assert exception.get_full_details() == { 'message': 'Slow down! Expected available in {} seconds.'.format(2), 'code': 'throttled'}
'message': 'Slow down! Expected available in {} seconds.'.format(2),
'code': 'throttled'}
class ErrorDetailTests(TestCase):
def test_eq(self): def test_eq(self):
assert ErrorDetail('msg') == ErrorDetail('msg') assert ErrorDetail('msg') == ErrorDetail('msg')
assert ErrorDetail('msg', 'code') == ErrorDetail('msg', code='code') assert ErrorDetail('msg', 'code') == ErrorDetail('msg', code='code')
assert ErrorDetail('msg') == 'msg' assert ErrorDetail('msg') == 'msg'
assert ErrorDetail('msg', 'code') == 'msg' assert ErrorDetail('msg', 'code') == 'msg'
def test_ne(self): def test_ne(self):
assert ErrorDetail('msg1') != ErrorDetail('msg2') assert ErrorDetail('msg1') != ErrorDetail('msg2')
assert ErrorDetail('msg') != ErrorDetail('msg', code='invalid') assert ErrorDetail('msg') != ErrorDetail('msg', code='invalid')
assert ErrorDetail('msg1') != 'msg2' assert ErrorDetail('msg1') != 'msg2'
assert ErrorDetail('msg1', 'code') != 'msg2' assert ErrorDetail('msg1', 'code') != 'msg2'
def test_repr(self): def test_repr(self):
assert repr(ErrorDetail('msg1')) == \ assert repr(ErrorDetail('msg1')) == 'ErrorDetail(string={!r}, code=None)'.format('msg1')
'ErrorDetail(string={!r}, code=None)'.format('msg1') assert repr(ErrorDetail('msg1', 'code')) == 'ErrorDetail(string={!r}, code={!r})'.format('msg1', 'code')
assert repr(ErrorDetail('msg1', 'code')) == \
'ErrorDetail(string={!r}, code={!r})'.format('msg1', 'code')
def test_str(self): def test_str(self):
assert str(ErrorDetail('msg1')) == 'msg1' assert str(ErrorDetail('msg1')) == 'msg1'
@ -83,13 +53,12 @@ class ErrorDetailTests(TestCase):
assert hash(ErrorDetail('msg', 'code')) == hash('msg') assert hash(ErrorDetail('msg', 'code')) == hash('msg')
class TranslationTests(TestCase):
@translation.override('fr') @translation.override('fr')
deftest_message(self): deftest_message(self):
# this test largely acts as a sanity test to ensure the translation files are present. # this test largely acts as a sanity test to ensure the translation files are present.
self.assertEqual(_('A server error occurred.'), 'Une erreur du serveur est survenue.') assert _('A server error occurred.') == 'Une erreur du serveur est survenue.'
self.assertEqual(str(APIException()), 'Une erreur du serveur est survenue.') assert str(APIException()) == 'Une erreur du serveur est survenue.'
def test_server_error(): def test_server_error():

View File

@ -1134,7 +1134,6 @@ class TestNoStringCoercionDecimalField(FieldValues):
) )
class TestLocalizedDecimalField(TestCase):
@override_settings(USE_L10N=True, LANGUAGE_CODE='pl') @override_settings(USE_L10N=True, LANGUAGE_CODE='pl')
deftest_to_internal_value(self): deftest_to_internal_value(self):
field = serializers.DecimalField(max_digits=2, decimal_places=1, localize=True) field = serializers.DecimalField(max_digits=2, decimal_places=1, localize=True)
@ -1150,7 +1149,6 @@ class TestLocalizedDecimalField(TestCase):
assert isinstance(field.to_representation(Decimal('1.1')), str) assert isinstance(field.to_representation(Decimal('1.1')), str)
class TestQuantizedValueForDecimal(TestCase):
def test_int_quantized_value_for_decimal(self): def test_int_quantized_value_for_decimal(self):
field = serializers.DecimalField(max_digits=4, decimal_places=2) field = serializers.DecimalField(max_digits=4, decimal_places=2)
value = field.to_internal_value(12).as_tuple() value = field.to_internal_value(12).as_tuple()
@ -1185,11 +1183,9 @@ class TestNoDecimalPlaces(FieldValues):
field = serializers.DecimalField(max_digits=6, decimal_places=None) field = serializers.DecimalField(max_digits=6, decimal_places=None)
class TestRoundingDecimalField(TestCase):
def test_valid_rounding(self): def test_valid_rounding(self):
field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding=ROUND_UP) field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding=ROUND_UP)
assert field.to_representation(Decimal('1.234')) == '1.24' assert field.to_representation(Decimal('1.234')) == '1.24'
field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding=ROUND_DOWN) field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding=ROUND_DOWN)
assert field.to_representation(Decimal('1.234')) == '1.23' assert field.to_representation(Decimal('1.234')) == '1.23'
@ -1369,11 +1365,9 @@ class TestTZWithDateTimeField(FieldValues):
@override_settings(TIME_ZONE='UTC', USE_TZ=True) @override_settings(TIME_ZONE='UTC', USE_TZ=True)
class TestDefaultTZDateTimeField(TestCase):
""" """
Test the current/default timezone handling in `DateTimeField`. Test the current/default timezone handling in `DateTimeField`.
""" """
@classmethod @classmethod
defsetup_class(cls): defsetup_class(cls):
cls.field = serializers.DateTimeField() cls.field = serializers.DateTimeField()
@ -1392,7 +1386,6 @@ class TestDefaultTZDateTimeField(TestCase):
@pytest.mark.skipif(pytz is None, reason='pytz not installed') @pytest.mark.skipif(pytz is None, reason='pytz not installed')
@override_settings(TIME_ZONE='UTC', USE_TZ=True) @override_settings(TIME_ZONE='UTC', USE_TZ=True)
class TestCustomTimezoneForDateTimeField(TestCase):
@classmethod @classmethod
defsetup_class(cls): defsetup_class(cls):
@ -1402,12 +1395,10 @@ class TestCustomTimezoneForDateTimeField(TestCase):
def test_should_render_date_time_in_default_timezone(self): def test_should_render_date_time_in_default_timezone(self):
field = serializers.DateTimeField(default_timezone=self.kolkata, format=self.date_format) field = serializers.DateTimeField(default_timezone=self.kolkata, format=self.date_format)
dt = datetime.datetime(2018, 2, 8, 14, 15, 16, tzinfo=pytz.utc) dt = datetime.datetime(2018, 2, 8, 14, 15, 16, tzinfo=pytz.utc)
with override(self.kolkata): with override(self.kolkata):
rendered_date = field.to_representation(dt) rendered_date = field.to_representation(dt)
rendered_date_in_timezone = dt.astimezone(self.kolkata).strftime(self.date_format) rendered_date_in_timezone = dt.astimezone(self.kolkata).strftime(self.date_format)
assertrendered_date==rendered_date_in_timezone assertrendered_date==rendered_date_in_timezone

View File

@ -13,12 +13,9 @@ from rest_framework.compat import coreschema
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
factory = APIRequestFactory() factory = APIRequestFactory()
class BaseFilterTests(TestCase):
def setUp(self): def setUp(self):
self.original_coreapi = filters.coreapi self.original_coreapi = filters.coreapi
filters.coreapi = True # mock it, because not None value needed filters.coreapi = True
self.filter_backend = filters.BaseFilterBackend() self.filter_backend = filters.BaseFilterBackend()
def tearDown(self): def tearDown(self):
@ -48,7 +45,6 @@ class SearchFilterSerializer(serializers.ModelSerializer):
fields = '__all__' fields = '__all__'
class SearchFilterTests(TestCase):
def setUp(self): def setUp(self):
# Sequence of title/text is: # Sequence of title/text is:
# #
@ -58,11 +54,7 @@ class SearchFilterTests(TestCase):
# ... # ...
for idx in range(10): for idx in range(10):
title = 'z' * (idx + 1) title = 'z' * (idx + 1)
text = ( text = ( chr(idx + ord('a')) + chr(idx + ord('b')) + chr(idx + ord('c')) )
chr(idx + ord('a')) +
chr(idx + ord('b')) +
chr(idx + ord('c'))
)
SearchFilterModel(title=title, text=text).save() SearchFilterModel(title=title, text=text).save()
def test_search(self): def test_search(self):
@ -75,10 +67,7 @@ class SearchFilterTests(TestCase):
view = SearchListView.as_view() view = SearchListView.as_view()
request=factory.get('/',{'search':'b'}) request=factory.get('/',{'search':'b'})
response=view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':1,'title':'z','text':'abc'},{'id':2,'title':'zz','text':'bcd'}]
{'id': 1, 'title': 'z', 'text': 'abc'},
{'id': 2, 'title': 'zz', 'text': 'bcd'}
]
def test_search_returns_same_queryset_if_no_search_fields_or_terms_provided(self): def test_search_returns_same_queryset_if_no_search_fields_or_terms_provided(self):
class SearchListView(generics.ListAPIView): class SearchListView(generics.ListAPIView):
@ -89,8 +78,7 @@ class SearchFilterTests(TestCase):
view = SearchListView.as_view() view = SearchListView.as_view()
request=factory.get('/') request=factory.get('/')
response=view(request) response=view(request)
expected = SearchFilterSerializer(SearchFilterModel.objects.all(), expected=SearchFilterSerializer(SearchFilterModel.objects.all(),many=True).data
many=True).data
assertresponse.data==expected assertresponse.data==expected
def test_exact_search(self): def test_exact_search(self):
@ -103,9 +91,7 @@ class SearchFilterTests(TestCase):
view = SearchListView.as_view() view = SearchListView.as_view()
request=factory.get('/',{'search':'zzz'}) request=factory.get('/',{'search':'zzz'})
response=view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':3,'title':'zzz','text':'cde'}]
{'id': 3, 'title': 'zzz', 'text': 'cde'}
]
def test_startswith_search(self): def test_startswith_search(self):
class SearchListView(generics.ListAPIView): class SearchListView(generics.ListAPIView):
@ -117,9 +103,7 @@ class SearchFilterTests(TestCase):
view = SearchListView.as_view() view = SearchListView.as_view()
request=factory.get('/',{'search':'b'}) request=factory.get('/',{'search':'b'})
response=view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':2,'title':'zz','text':'bcd'}]
{'id': 2, 'title': 'zz', 'text': 'bcd'}
]
def test_regexp_search(self): def test_regexp_search(self):
class SearchListView(generics.ListAPIView): class SearchListView(generics.ListAPIView):
@ -131,14 +115,11 @@ class SearchFilterTests(TestCase):
view = SearchListView.as_view() view = SearchListView.as_view()
request=factory.get('/',{'search':'z{2} ^b'}) request=factory.get('/',{'search':'z{2} ^b'})
response=view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':2,'title':'zz','text':'bcd'}]
{'id': 2, 'title': 'zz', 'text': 'bcd'}
]
def test_search_with_nonstandard_search_param(self): def test_search_with_nonstandard_search_param(self):
with override_settings(REST_FRAMEWORK={'SEARCH_PARAM': 'query'}): with override_settings(REST_FRAMEWORK={'SEARCH_PARAM': 'query'}):
reload_module(filters) reload_module(filters)
class SearchListView(generics.ListAPIView): class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all() queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer serializer_class = SearchFilterSerializer
@ -148,10 +129,7 @@ class SearchFilterTests(TestCase):
view = SearchListView.as_view() view = SearchListView.as_view()
request=factory.get('/',{'query':'b'}) request=factory.get('/',{'query':'b'})
response=view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':1,'title':'z','text':'abc'},{'id':2,'title':'zz','text':'bcd'}]
{'id': 1, 'title': 'z', 'text': 'abc'},
{'id': 2, 'title': 'zz', 'text': 'bcd'}
]
reload_module(filters) reload_module(filters)
@ -173,12 +151,9 @@ class SearchFilterTests(TestCase):
request=factory.get('/',{'search':r'^\w{3}$'}) request=factory.get('/',{'search':r'^\w{3}$'})
response=view(request) response=view(request)
assertlen(response.data)==10 assertlen(response.data)==10
request=factory.get('/',{'search':r'^\w{3}$','title_only':'true'}) request=factory.get('/',{'search':r'^\w{3}$','title_only':'true'})
response=view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':3,'title':'zzz','text':'cde'}]
{'id': 3, 'title': 'zzz', 'text': 'cde'}
]
class AttributeModel(models.Model): class AttributeModel(models.Model):
@ -196,20 +171,13 @@ class SearchFilterFkSerializer(serializers.ModelSerializer):
fields = '__all__' fields = '__all__'
class SearchFilterFkTests(TestCase):
def test_must_call_distinct(self): def test_must_call_distinct(self):
filter_ = filters.SearchFilter() filter_ = filters.SearchFilter()
prefixes = [''] + list(filter_.lookup_prefixes) prefixes = [''] + list(filter_.lookup_prefixes)
for prefix in prefixes: for prefix in prefixes:
assert not filter_.must_call_distinct( assert not filter_.must_call_distinct( SearchFilterModelFk._meta, ["%stitle" % prefix] )
SearchFilterModelFk._meta, assert not filter_.must_call_distinct( SearchFilterModelFk._meta, ["%stitle" % prefix, "%sattribute__label" % prefix] )
["%stitle" % prefix]
)
assert not filter_.must_call_distinct(
SearchFilterModelFk._meta,
["%stitle" % prefix, "%sattribute__label" % prefix]
)
def test_must_call_distinct_restores_meta_for_each_field(self): def test_must_call_distinct_restores_meta_for_each_field(self):
# In this test case the attribute of the fk model comes first in the # In this test case the attribute of the fk model comes first in the
@ -217,10 +185,7 @@ class SearchFilterFkTests(TestCase):
filter_ = filters.SearchFilter() filter_ = filters.SearchFilter()
prefixes = [''] + list(filter_.lookup_prefixes) prefixes = [''] + list(filter_.lookup_prefixes)
for prefix in prefixes: for prefix in prefixes:
assert not filter_.must_call_distinct( assert not filter_.must_call_distinct( SearchFilterModelFk._meta, ["%sattribute__label" % prefix, "%stitle" % prefix] )
SearchFilterModelFk._meta,
["%sattribute__label" % prefix, "%stitle" % prefix]
)
class SearchFilterModelM2M(models.Model): class SearchFilterModelM2M(models.Model):
@ -235,7 +200,6 @@ class SearchFilterM2MSerializer(serializers.ModelSerializer):
fields = '__all__' fields = '__all__'
class SearchFilterM2MTests(TestCase):
def setUp(self): def setUp(self):
# Sequence of title/text/attributes is: # Sequence of title/text/attributes is:
# #
@ -249,11 +213,7 @@ class SearchFilterM2MTests(TestCase):
for idx in range(10): for idx in range(10):
title = 'z' * (idx + 1) title = 'z' * (idx + 1)
text = ( text = ( chr(idx + ord('a')) + chr(idx + ord('b')) + chr(idx + ord('c')) )
chr(idx + ord('a')) +
chr(idx + ord('b')) +
chr(idx + ord('c'))
)
SearchFilterModelM2M(title=title, text=text).save() SearchFilterModelM2M(title=title, text=text).save()
SearchFilterModelM2M.objects.get(title='zz').attributes.add(1, 2, 3) SearchFilterModelM2M.objects.get(title='zz').attributes.add(1, 2, 3)
@ -273,15 +233,8 @@ class SearchFilterM2MTests(TestCase):
filter_ = filters.SearchFilter() filter_ = filters.SearchFilter()
prefixes = [''] + list(filter_.lookup_prefixes) prefixes = [''] + list(filter_.lookup_prefixes)
for prefix in prefixes: for prefix in prefixes:
assert not filter_.must_call_distinct( assert not filter_.must_call_distinct( SearchFilterModelM2M._meta, ["%stitle" % prefix] )
SearchFilterModelM2M._meta, assert filter_.must_call_distinct( SearchFilterModelM2M._meta, ["%stitle" % prefix, "%sattributes__label" % prefix] )
["%stitle" % prefix]
)
assert filter_.must_call_distinct(
SearchFilterModelM2M._meta,
["%stitle" % prefix, "%sattributes__label" % prefix]
)
class Blog(models.Model): class Blog(models.Model):
@ -300,18 +253,13 @@ class BlogSerializer(serializers.ModelSerializer):
fields = '__all__' fields = '__all__'
class SearchFilterToManyTests(TestCase):
@classmethod @classmethod
defsetUpTestData(cls): defsetUpTestData(cls):
b1 = Blog.objects.create(name='Blog 1') b1 = Blog.objects.create(name='Blog 1')
b2 = Blog.objects.create(name='Blog 2') b2 = Blog.objects.create(name='Blog 2')
# Multiple entries on Lennon published in 1979 - distinct should deduplicate
Entry.objects.create(blog=b1, headline='Something about Lennon', pub_date=datetime.date(1979, 1, 1)) Entry.objects.create(blog=b1, headline='Something about Lennon', pub_date=datetime.date(1979, 1, 1))
Entry.objects.create(blog=b1, headline='Another thing about Lennon', pub_date=datetime.date(1979, 6, 1)) Entry.objects.create(blog=b1, headline='Another thing about Lennon', pub_date=datetime.date(1979, 6, 1))
# Entry on Lennon *and* a separate entry in 1979 - should not match
Entry.objects.create(blog=b2, headline='Something unrelated', pub_date=datetime.date(1979, 1, 1)) Entry.objects.create(blog=b2, headline='Something unrelated', pub_date=datetime.date(1979, 1, 1))
Entry.objects.create(blog=b2, headline='Retrospective on Lennon', pub_date=datetime.date(1990, 6, 1)) Entry.objects.create(blog=b2, headline='Retrospective on Lennon', pub_date=datetime.date(1990, 6, 1))
@ -336,7 +284,6 @@ class SearchFilterAnnotatedSerializer(serializers.ModelSerializer):
fields = ('title', 'text', 'title_text') fields = ('title', 'text', 'title_text')
class SearchFilterAnnotatedFieldTests(TestCase):
@classmethod @classmethod
defsetUpTestData(cls): defsetUpTestData(cls):
SearchFilterModel.objects.create(title='abc', text='def') SearchFilterModel.objects.create(title='abc', text='def')
@ -344,11 +291,7 @@ class SearchFilterAnnotatedFieldTests(TestCase):
def test_search_in_annotated_field(self): def test_search_in_annotated_field(self):
class SearchListView(generics.ListAPIView): class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.annotate( queryset = SearchFilterModel.objects.annotate( title_text=Upper( Concat(models.F('title'), models.F('text')) ) ).all()
title_text=Upper(
Concat(models.F('title'), models.F('text'))
)
).all()
serializer_class = SearchFilterAnnotatedSerializer serializer_class = SearchFilterAnnotatedSerializer
filter_backends = (filters.SearchFilter,) filter_backends = (filters.SearchFilter,)
search_fields = ('title_text',) search_fields = ('title_text',)
@ -403,7 +346,6 @@ class DjangoFilterOrderingSerializer(serializers.ModelSerializer):
fields = '__all__' fields = '__all__'
class OrderingFilterTests(TestCase):
def setUp(self): def setUp(self):
# Sequence of title/text is: # Sequence of title/text is:
# #
@ -411,16 +353,8 @@ class OrderingFilterTests(TestCase):
# yxw bcd # yxw bcd
# xwv cde # xwv cde
for idx in range(3): for idx in range(3):
title = ( title = ( chr(ord('z') - idx) + chr(ord('y') - idx) + chr(ord('x') - idx) )
chr(ord('z') - idx) + text = ( chr(idx + ord('a')) + chr(idx + ord('b')) + chr(idx + ord('c')) )
chr(ord('y') - idx) +
chr(ord('x') - idx)
)
text = (
chr(idx + ord('a')) +
chr(idx + ord('b')) +
chr(idx + ord('c'))
)
OrderingFilterModel(title=title, text=text).save() OrderingFilterModel(title=title, text=text).save()
def test_ordering(self): def test_ordering(self):
@ -434,11 +368,7 @@ class OrderingFilterTests(TestCase):
view = OrderingListView.as_view() view = OrderingListView.as_view()
request=factory.get('/',{'ordering':'text'}) request=factory.get('/',{'ordering':'text'})
response=view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':1,'title':'zyx','text':'abc'},{'id':2,'title':'yxw','text':'bcd'},{'id':3,'title':'xwv','text':'cde'},]
{'id': 1, 'title': 'zyx', 'text': 'abc'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 3, 'title': 'xwv', 'text': 'cde'},
]
def test_reverse_ordering(self): def test_reverse_ordering(self):
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
@ -451,11 +381,7 @@ class OrderingFilterTests(TestCase):
view = OrderingListView.as_view() view = OrderingListView.as_view()
request=factory.get('/',{'ordering':'-text'}) request=factory.get('/',{'ordering':'-text'})
response=view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},{'id':1,'title':'zyx','text':'abc'},]
{'id': 3, 'title': 'xwv', 'text': 'cde'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 1, 'title': 'zyx', 'text': 'abc'},
]
def test_incorrecturl_extrahyphens_ordering(self): def test_incorrecturl_extrahyphens_ordering(self):
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
@ -468,11 +394,7 @@ class OrderingFilterTests(TestCase):
view = OrderingListView.as_view() view = OrderingListView.as_view()
request=factory.get('/',{'ordering':'--text'}) request=factory.get('/',{'ordering':'--text'})
response=view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},{'id':1,'title':'zyx','text':'abc'},]
{'id': 3, 'title': 'xwv', 'text': 'cde'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 1, 'title': 'zyx', 'text': 'abc'},
]
def test_incorrectfield_ordering(self): def test_incorrectfield_ordering(self):
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
@ -485,11 +407,7 @@ class OrderingFilterTests(TestCase):
view = OrderingListView.as_view() view = OrderingListView.as_view()
request=factory.get('/',{'ordering':'foobar'}) request=factory.get('/',{'ordering':'foobar'})
response=view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},{'id':1,'title':'zyx','text':'abc'},]
{'id': 3, 'title': 'xwv', 'text': 'cde'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 1, 'title': 'zyx', 'text': 'abc'},
]
def test_default_ordering(self): def test_default_ordering(self):
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
@ -502,11 +420,7 @@ class OrderingFilterTests(TestCase):
view = OrderingListView.as_view() view = OrderingListView.as_view()
request=factory.get('') request=factory.get('')
response=view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},{'id':1,'title':'zyx','text':'abc'},]
{'id': 3, 'title': 'xwv', 'text': 'cde'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 1, 'title': 'zyx', 'text': 'abc'},
]
def test_default_ordering_using_string(self): def test_default_ordering_using_string(self):
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
@ -519,21 +433,14 @@ class OrderingFilterTests(TestCase):
view = OrderingListView.as_view() view = OrderingListView.as_view()
request=factory.get('') request=factory.get('')
response=view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},{'id':1,'title':'zyx','text':'abc'},]
{'id': 3, 'title': 'xwv', 'text': 'cde'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 1, 'title': 'zyx', 'text': 'abc'},
]
def test_ordering_by_aggregate_field(self): def test_ordering_by_aggregate_field(self):
# create some related models to aggregate order by # create some related models to aggregate order by
num_objs = [2, 5, 3] num_objs = [2, 5, 3]
for obj, num_relateds in zip(OrderingFilterModel.objects.all(), for obj, num_relateds in zip(OrderingFilterModel.objects.all(), num_objs):
num_objs):
for _ in range(num_relateds): for _ in range(num_relateds):
new_related = OrderingFilterRelatedModel( new_related = OrderingFilterRelatedModel( related_object=obj )
related_object=obj
)
new_related.save() new_related.save()
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
@ -541,25 +448,17 @@ class OrderingFilterTests(TestCase):
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
ordering = 'title' ordering = 'title'
ordering_fields = '__all__' ordering_fields = '__all__'
queryset = OrderingFilterModel.objects.all().annotate( queryset = OrderingFilterModel.objects.all().annotate( models.Count("relateds"))
models.Count("relateds"))
view = OrderingListView.as_view() view = OrderingListView.as_view()
request=factory.get('/',{'ordering':'relateds__count'}) request=factory.get('/',{'ordering':'relateds__count'})
response=view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':1,'title':'zyx','text':'abc'},{'id':3,'title':'xwv','text':'cde'},{'id':2,'title':'yxw','text':'bcd'},]
{'id': 1, 'title': 'zyx', 'text': 'abc'},
{'id': 3, 'title': 'xwv', 'text': 'cde'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
]
def test_ordering_by_dotted_source(self): def test_ordering_by_dotted_source(self):
for index, obj in enumerate(OrderingFilterModel.objects.all()): for index, obj in enumerate(OrderingFilterModel.objects.all()):
OrderingFilterRelatedModel.objects.create( OrderingFilterRelatedModel.objects.create( related_object=obj, index=index )
related_object=obj,
index=index
)
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
serializer_class = OrderingDottedRelatedSerializer serializer_class = OrderingDottedRelatedSerializer
@ -569,24 +468,14 @@ class OrderingFilterTests(TestCase):
view = OrderingListView.as_view() view = OrderingListView.as_view()
request=factory.get('/',{'ordering':'related_object__text'}) request=factory.get('/',{'ordering':'related_object__text'})
response=view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'related_title':'zyx','related_text':'abc','index':0},{'related_title':'yxw','related_text':'bcd','index':1},{'related_title':'xwv','related_text':'cde','index':2},]
{'related_title': 'zyx', 'related_text': 'abc', 'index': 0},
{'related_title': 'yxw', 'related_text': 'bcd', 'index': 1},
{'related_title': 'xwv', 'related_text': 'cde', 'index': 2},
]
request=factory.get('/',{'ordering':'-index'}) request=factory.get('/',{'ordering':'-index'})
response=view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'related_title':'xwv','related_text':'cde','index':2},{'related_title':'yxw','related_text':'bcd','index':1},{'related_title':'zyx','related_text':'abc','index':0},]
{'related_title': 'xwv', 'related_text': 'cde', 'index': 2},
{'related_title': 'yxw', 'related_text': 'bcd', 'index': 1},
{'related_title': 'zyx', 'related_text': 'abc', 'index': 0},
]
def test_ordering_with_nonstandard_ordering_param(self): def test_ordering_with_nonstandard_ordering_param(self):
with override_settings(REST_FRAMEWORK={'ORDERING_PARAM': 'order'}): with override_settings(REST_FRAMEWORK={'ORDERING_PARAM': 'order'}):
reload_module(filters) reload_module(filters)
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all() queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer serializer_class = OrderingFilterSerializer
@ -597,11 +486,7 @@ class OrderingFilterTests(TestCase):
view = OrderingListView.as_view() view = OrderingListView.as_view()
request=factory.get('/',{'order':'text'}) request=factory.get('/',{'order':'text'})
response=view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':1,'title':'zyx','text':'abc'},{'id':2,'title':'yxw','text':'bcd'},{'id':3,'title':'xwv','text':'cde'},]
{'id': 1, 'title': 'zyx', 'text': 'abc'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 3, 'title': 'xwv', 'text': 'cde'},
]
reload_module(filters) reload_module(filters)
@ -615,7 +500,6 @@ class OrderingFilterTests(TestCase):
request = factory.get('/', {'ordering': 'title'}, HTTP_ACCEPT='text/html') request = factory.get('/', {'ordering': 'title'}, HTTP_ACCEPT='text/html')
view=OrderingListView.as_view() view=OrderingListView.as_view()
response=view(request) response=view(request)
self.assertContains(response,'verbose title') self.assertContains(response,'verbose title')
def test_ordering_with_overridden_get_serializer_class(self): def test_ordering_with_overridden_get_serializer_class(self):
@ -623,20 +507,13 @@ class OrderingFilterTests(TestCase):
queryset = OrderingFilterModel.objects.all() queryset = OrderingFilterModel.objects.all()
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
ordering = ('title',) ordering = ('title',)
# note: no ordering_fields and serializer_class specified
def get_serializer_class(self): def get_serializer_class(self):
return OrderingFilterSerializer return OrderingFilterSerializer
view = OrderingListView.as_view() view = OrderingListView.as_view()
request=factory.get('/',{'ordering':'text'}) request=factory.get('/',{'ordering':'text'})
response=view(request) response=view(request)
assert response.data == [ assertresponse.data==[{'id':1,'title':'zyx','text':'abc'},{'id':2,'title':'yxw','text':'bcd'},{'id':3,'title':'xwv','text':'cde'},]
{'id': 1, 'title': 'zyx', 'text': 'abc'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 3, 'title': 'xwv', 'text': 'cde'},
]
def test_ordering_with_improper_configuration(self): def test_ordering_with_improper_configuration(self):
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
@ -648,7 +525,7 @@ class OrderingFilterTests(TestCase):
view = OrderingListView.as_view() view = OrderingListView.as_view()
request=factory.get('/',{'ordering':'text'}) request=factory.get('/',{'ordering':'text'})
with self.assertRaises(ImproperlyConfigured): withpytest.raises(ImproperlyConfigured):
view(request) view(request)
@ -684,7 +561,6 @@ class SensitiveDataSerializer3(serializers.ModelSerializer):
fields = ('id', 'user') fields = ('id', 'user')
class SensitiveOrderingFilterTests(TestCase):
def setUp(self): def setUp(self):
for idx in range(3): for idx in range(3):
username = {0: 'userA', 1: 'userB', 2: 'userC'}[idx] username = {0: 'userA', 1: 'userB', 2: 'userC'}[idx]
@ -692,11 +568,7 @@ class SensitiveOrderingFilterTests(TestCase):
SensitiveOrderingFilterModel(username=username, password=password).save() SensitiveOrderingFilterModel(username=username, password=password).save()
def test_order_by_serializer_fields(self): def test_order_by_serializer_fields(self):
for serializer_cls in [ for serializer_cls in [ SensitiveDataSerializer1, SensitiveDataSerializer2, SensitiveDataSerializer3 ]:
SensitiveDataSerializer1,
SensitiveDataSerializer2,
SensitiveDataSerializer3
]:
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
queryset = SensitiveOrderingFilterModel.objects.all().order_by('username') queryset = SensitiveOrderingFilterModel.objects.all().order_by('username')
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
@ -705,25 +577,16 @@ class SensitiveOrderingFilterTests(TestCase):
view = OrderingListView.as_view() view = OrderingListView.as_view()
request=factory.get('/',{'ordering':'-username'}) request=factory.get('/',{'ordering':'-username'})
response=view(request) response=view(request)
ifserializer_cls==SensitiveDataSerializer3: ifserializer_cls==SensitiveDataSerializer3:
username_field = 'user' username_field = 'user'
else: else:
username_field = 'username' username_field = 'username'
# Note: Inverse username ordering correctly applied. # Note: Inverse username ordering correctly applied.
assert response.data == [ assert response.data == [{'id':3,username_field:'userC'},{'id':2,username_field:'userB'},{'id':1,username_field:'userA'},]
{'id': 3, username_field: 'userC'},
{'id': 2, username_field: 'userB'},
{'id': 1, username_field: 'userA'},
]
def test_cannot_order_by_non_serializer_fields(self): def test_cannot_order_by_non_serializer_fields(self):
for serializer_cls in [ for serializer_cls in [ SensitiveDataSerializer1, SensitiveDataSerializer2, SensitiveDataSerializer3 ]:
SensitiveDataSerializer1,
SensitiveDataSerializer2,
SensitiveDataSerializer3
]:
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
queryset = SensitiveOrderingFilterModel.objects.all().order_by('username') queryset = SensitiveOrderingFilterModel.objects.all().order_by('username')
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
@ -732,15 +595,10 @@ class SensitiveOrderingFilterTests(TestCase):
view = OrderingListView.as_view() view = OrderingListView.as_view()
request=factory.get('/',{'ordering':'password'}) request=factory.get('/',{'ordering':'password'})
response=view(request) response=view(request)
ifserializer_cls==SensitiveDataSerializer3: ifserializer_cls==SensitiveDataSerializer3:
username_field = 'user' username_field = 'user'
else: else:
username_field = 'username' username_field = 'username'
# Note: The passwords are not in order. Default ordering is used. # Note: The passwords are not in order. Default ordering is used.
assert response.data == [ assert response.data == [{'id':1,username_field:'userA'},{'id':2,username_field:'userB'},{'id':3,username_field:'userC'},]
{'id': 1, username_field: 'userA'}, # PassB
{'id': 2, username_field: 'userB'}, # PassC
{'id': 3, username_field: 'userC'}, # PassA
]

View File

@ -23,9 +23,7 @@ urlpatterns = [
@override_settings(ROOT_URLCONF='tests.test_generateschema') @override_settings(ROOT_URLCONF='tests.test_generateschema')
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed') @pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
class GenerateSchemaTests(TestCase):
"""Tests for management command generateschema.""" """Tests for management command generateschema."""
defsetUp(self): defsetUp(self):
self.out = io.StringIO() self.out = io.StringIO()
@ -42,45 +40,16 @@ class GenerateSchemaTests(TestCase):
servers: servers:
- url: http://api.sample.com/ - url: http://api.sample.com/
""" """
call_command('generateschema', call_command('generateschema', '--title=SampleAPI', '--url=http://api.sample.com', '--description=Sample description', stdout=self.out)
'--title=SampleAPI', assert formatting.dedent(expected_out) in self.out.getvalue()
'--url=http://api.sample.com',
'--description=Sample description',
stdout=self.out)
self.assertIn(formatting.dedent(expected_out), self.out.getvalue())
def test_renders_openapi_json_schema(self): def test_renders_openapi_json_schema(self):
expected_out = { expected_out = { "openapi": "3.0.0", "info": { "version": "", "title": "", "description": "" }, "servers": [ { "url": "" } ], "paths": { "/": { "get": { "operationId": "list" } } } }
"openapi": "3.0.0", call_command('generateschema', '--format=openapi-json', stdout=self.out)
"info": {
"version": "",
"title": "",
"description": ""
},
"servers": [
{
"url": ""
}
],
"paths": {
"/": {
"get": {
"operationId": "list"
}
}
}
}
call_command('generateschema',
'--format=openapi-json',
stdout=self.out)
out_json = json.loads(self.out.getvalue()) out_json = json.loads(self.out.getvalue())
assert out_json == expected_out
self.assertDictEqual(out_json, expected_out)
def test_renders_corejson_schema(self): def test_renders_corejson_schema(self):
expected_out = """{"_type":"document","":{"list":{"_type":"link","url":"/","action":"get"}}}""" expected_out = """{"_type":"document","":{"list":{"_type":"link","url":"/","action":"get"}}}"""
call_command('generateschema', call_command('generateschema', '--format=corejson', stdout=self.out)
'--format=corejson', assert expected_out in self.out.getvalue()
stdout=self.out)
self.assertIn(expected_out, self.out.getvalue())

View File

@ -75,7 +75,6 @@ class SlugBasedInstanceView(InstanceView):
# Tests # Tests
class TestRootView(TestCase):
def setUp(self): def setUp(self):
""" """
Create 3 BasicModel instances. Create 3 BasicModel instances.
@ -84,10 +83,7 @@ class TestRootView(TestCase):
for item in items: for item in items:
BasicModel(text=item).save() BasicModel(text=item).save()
self.objects = BasicModel.objects self.objects = BasicModel.objects
self.data = [ self.data=[{'id':obj.id,'text':obj.text}forobjinself.objects.all()]
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
self.view=RootView.as_view() self.view=RootView.as_view()
def test_get_root_view(self): def test_get_root_view(self):
@ -168,9 +164,6 @@ class TestRootView(TestCase):
EXPECTED_QUERIES_FOR_PUT = 2 EXPECTED_QUERIES_FOR_PUT = 2
class TestInstanceView(TestCase):
def setUp(self): def setUp(self):
""" """
Create 3 BasicModel instances. Create 3 BasicModel instances.
@ -179,10 +172,7 @@ class TestInstanceView(TestCase):
for item in items: for item in items:
BasicModel(text=item).save() BasicModel(text=item).save()
self.objects = BasicModel.objects.exclude(text='filtered out') self.objects = BasicModel.objects.exclude(text='filtered out')
self.data = [ self.data=[{'id':obj.id,'text':obj.text}forobjinself.objects.all()]
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
self.view=InstanceView.as_view() self.view=InstanceView.as_view()
self.slug_based_view=SlugBasedInstanceView.as_view() self.slug_based_view=SlugBasedInstanceView.as_view()
@ -226,7 +216,6 @@ class TestInstanceView(TestCase):
""" """
data = {'text': 'foobar'} data = {'text': 'foobar'}
request = factory.patch('/1', data, format='json') request = factory.patch('/1', data, format='json')
with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT): with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@ -314,7 +303,6 @@ class TestInstanceView(TestCase):
assert expected_error in response.rendered_content.decode() assert expected_error in response.rendered_content.decode()
class TestFKInstanceView(TestCase):
def setUp(self): def setUp(self):
""" """
Create 3 BasicModel instances. Create 3 BasicModel instances.
@ -326,19 +314,14 @@ class TestFKInstanceView(TestCase):
ForeignKeySource(name='source_' + item, target=t).save() ForeignKeySource(name='source_' + item, target=t).save()
self.objects = ForeignKeySource.objects self.objects = ForeignKeySource.objects
self.data = [ self.data=[{'id':obj.id,'name':obj.name}forobjinself.objects.all()]
{'id': obj.id, 'name': obj.name}
for obj in self.objects.all()
]
self.view=FKInstanceView.as_view() self.view=FKInstanceView.as_view()
class TestOverriddenGetObject(TestCase):
""" """
Test cases for a RetrieveUpdateDestroyAPIView that does NOT use the Test cases for a RetrieveUpdateDestroyAPIView that does NOT use the
queryset/model mechanism but instead overrides get_object() queryset/model mechanism but instead overrides get_object()
""" """
defsetUp(self): defsetUp(self):
""" """
Create 3 BasicModel instances. Create 3 BasicModel instances.
@ -347,17 +330,12 @@ class TestOverriddenGetObject(TestCase):
for item in items: for item in items:
BasicModel(text=item).save() BasicModel(text=item).save()
self.objects = BasicModel.objects self.objects = BasicModel.objects
self.data = [ self.data=[{'id':obj.id,'text':obj.text}forobjinself.objects.all()]
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
classOverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView): classOverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView):
""" """
Example detail view for override of get_object(). Example detail view for override of get_object().
""" """
serializer_class = BasicSerializer serializer_class = BasicSerializer
def get_object(self): def get_object(self):
pk = int(self.kwargs['pk']) pk = int(self.kwargs['pk'])
return get_object_or_404(BasicModel.objects.all(), id=pk) return get_object_or_404(BasicModel.objects.all(), id=pk)
@ -388,7 +366,6 @@ class CommentView(generics.ListCreateAPIView):
model = Comment model = Comment
class TestCreateModelWithAutoNowAddField(TestCase):
def setUp(self): def setUp(self):
self.objects = Comment.objects self.objects = Comment.objects
self.view = CommentView.as_view() self.view = CommentView.as_view()
@ -432,7 +409,6 @@ class ExampleView(generics.ListCreateAPIView):
queryset = ClassA.objects.all() queryset = ClassA.objects.all()
class TestM2MBrowsableAPI(TestCase):
def test_m2m_in_browsable_api(self): def test_m2m_in_browsable_api(self):
""" """
Test for particularly ugly regression with m2m in browsable API Test for particularly ugly regression with m2m in browsable API
@ -476,7 +452,6 @@ class DynamicSerializerView(generics.ListCreateAPIView):
return DynamicSerializer return DynamicSerializer
class TestFilterBackendAppliedToViews(TestCase):
def setUp(self): def setUp(self):
""" """
Create 3 BasicModel instances to filter on. Create 3 BasicModel instances to filter on.
@ -485,10 +460,7 @@ class TestFilterBackendAppliedToViews(TestCase):
for item in items: for item in items:
BasicModel(text=item).save() BasicModel(text=item).save()
self.objects = BasicModel.objects self.objects = BasicModel.objects
self.data = [ self.data=[{'id':obj.id,'text':obj.text}forobjinself.objects.all()]
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
def test_get_root_view_filters_by_name_with_filter_backend(self): def test_get_root_view_filters_by_name_with_filter_backend(self):
""" """
@ -543,11 +515,9 @@ class TestFilterBackendAppliedToViews(TestCase):
assert 'field_a' not in content assert 'field_a' not in content
class TestGuardedQueryset(TestCase):
def test_guarded_queryset(self): def test_guarded_queryset(self):
class QuerysetAccessError(generics.ListAPIView): class QuerysetAccessError(generics.ListAPIView):
queryset = BasicModel.objects.all() queryset = BasicModel.objects.all()
def get(self, request): def get(self, request):
return Response(list(self.queryset)) return Response(list(self.queryset))
@ -557,7 +527,6 @@ class TestGuardedQueryset(TestCase):
view(request).render() view(request).render()
class ApiViewsTests(TestCase):
def test_create_api_view_post(self): def test_create_api_view_post(self):
class MockCreateApiView(generics.CreateAPIView): class MockCreateApiView(generics.CreateAPIView):
@ -648,15 +617,12 @@ class ApiViewsTests(TestCase):
assertview.call_args==data assertview.call_args==data
class GetObjectOr404Tests(TestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.uuid_object = UUIDForeignKeyTarget.objects.create(name='bar') self.uuid_object = UUIDForeignKeyTarget.objects.create(name='bar')
def test_get_object_or_404_with_valid_uuid(self): def test_get_object_or_404_with_valid_uuid(self):
obj = generics.get_object_or_404( obj = generics.get_object_or_404( UUIDForeignKeyTarget, pk=self.uuid_object.pk )
UUIDForeignKeyTarget, pk=self.uuid_object.pk
)
assert obj == self.uuid_object assert obj == self.uuid_object
def test_get_object_or_404_with_invalid_string_for_uuid(self): def test_get_object_or_404_with_invalid_string_for_uuid(self):

View File

@ -42,7 +42,6 @@ urlpatterns = [
@override_settings(ROOT_URLCONF='tests.test_htmlrenderer') @override_settings(ROOT_URLCONF='tests.test_htmlrenderer')
class TemplateHTMLRendererTests(TestCase):
def setUp(self): def setUp(self):
class MockResponse: class MockResponse:
template_name = None template_name = None
@ -54,7 +53,6 @@ class TemplateHTMLRendererTests(TestCase):
Monkeypatch get_template Monkeypatch get_template
""" """
self.get_template = django.template.loader.get_template self.get_template = django.template.loader.get_template
def get_template(template_name, dirs=None): def get_template(template_name, dirs=None):
if template_name == 'example.html': if template_name == 'example.html':
return engines['django'].from_string("example: {{ object }}") return engines['django'].from_string("example: {{ object }}")
@ -77,19 +75,19 @@ class TemplateHTMLRendererTests(TestCase):
def test_simple_html_view(self): def test_simple_html_view(self):
response = self.client.get('/') response = self.client.get('/')
self.assertContains(response, "example: foobar") self.assertContains(response, "example: foobar")
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') assert response['Content-Type'] == 'text/html; charset=utf-8'
def test_not_found_html_view(self): def test_not_found_html_view(self):
response = self.client.get('/not_found') response = self.client.get('/not_found')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) assert response.status_code == status.HTTP_404_NOT_FOUND
self.assertEqual(response.content, b"404 Not Found") assert response.content == b"404 Not Found"
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') assert response['Content-Type'] == 'text/html; charset=utf-8'
def test_permission_denied_html_view(self): def test_permission_denied_html_view(self):
response = self.client.get('/permission_denied') response = self.client.get('/permission_denied')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) assert response.status_code == status.HTTP_403_FORBIDDEN
self.assertEqual(response.content, b"403 Forbidden") assert response.content == b"403 Forbidden"
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') assert response['Content-Type'] == 'text/html; charset=utf-8'
# 2 tests below are based on order of if statements in corresponding method # 2 tests below are based on order of if statements in corresponding method
# of TemplateHTMLRenderer # of TemplateHTMLRenderer
@ -101,7 +99,6 @@ class TemplateHTMLRendererTests(TestCase):
def test_get_template_names_returns_view_template_name(self): def test_get_template_names_returns_view_template_name(self):
renderer = TemplateHTMLRenderer() renderer = TemplateHTMLRenderer()
class MockResponse: class MockResponse:
template_name = None template_name = None
@ -112,12 +109,9 @@ class TemplateHTMLRendererTests(TestCase):
class MockView2: class MockView2:
template_name = 'template from template_name attribute' template_name = 'template from template_name attribute'
template_name = renderer.get_template_names(self.mock_response, template_name = renderer.get_template_names(self.mock_response,MockView())
MockView())
asserttemplate_name==['template from get_template_names method'] asserttemplate_name==['template from get_template_names method']
template_name=renderer.get_template_names(self.mock_response,MockView2())
template_name = renderer.get_template_names(self.mock_response,
MockView2())
asserttemplate_name==['template from template_name attribute'] asserttemplate_name==['template from template_name attribute']
def test_get_template_names_raises_error_if_no_template_found(self): def test_get_template_names_raises_error_if_no_template_found(self):
@ -127,13 +121,11 @@ class TemplateHTMLRendererTests(TestCase):
@override_settings(ROOT_URLCONF='tests.test_htmlrenderer') @override_settings(ROOT_URLCONF='tests.test_htmlrenderer')
class TemplateHTMLRendererExceptionTests(TestCase):
def setUp(self): def setUp(self):
""" """
Monkeypatch get_template Monkeypatch get_template
""" """
self.get_template = django.template.loader.get_template self.get_template = django.template.loader.get_template
def get_template(template_name): def get_template(template_name):
if template_name == '404.html': if template_name == '404.html':
return engines['django'].from_string("404: {{ detail }}") return engines['django'].from_string("404: {{ detail }}")
@ -151,13 +143,12 @@ class TemplateHTMLRendererExceptionTests(TestCase):
def test_not_found_html_view_with_template(self): def test_not_found_html_view_with_template(self):
response = self.client.get('/not_found') response = self.client.get('/not_found')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) assert response.status_code == status.HTTP_404_NOT_FOUND
self.assertTrue(response.content in ( assert response.content in ( b"404: Not found", b"404 Not Found")
b"404: Not found", b"404 Not Found")) assert response['Content-Type'] == 'text/html; charset=utf-8'
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
def test_permission_denied_html_view_with_template(self): def test_permission_denied_html_view_with_template(self):
response = self.client.get('/permission_denied') response = self.client.get('/permission_denied')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) assert response.status_code == status.HTTP_403_FORBIDDEN
self.assertTrue(response.content in (b"403: Permission denied", b"403 Forbidden")) assert response.content in (b"403: Permission denied", b"403 Forbidden")
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') assert response['Content-Type'] == 'text/html; charset=utf-8'

View File

@ -34,7 +34,6 @@ urlpatterns = [
@override_settings(ROOT_URLCONF='tests.test_lazy_hyperlinks') @override_settings(ROOT_URLCONF='tests.test_lazy_hyperlinks')
class TestLazyHyperlinkNames(TestCase):
def setUp(self): def setUp(self):
self.example = Example.objects.create(text='foo') self.example = Example.objects.create(text='foo')

View File

@ -308,7 +308,6 @@ class TestMetadata:
assert options.get_serializer_info(list_serializer) == options.get_serializer_info(child_serializer) assert options.get_serializer_info(list_serializer) == options.get_serializer_info(child_serializer)
class TestSimpleMetadataFieldInfo(TestCase):
def test_null_boolean_field_info_type(self): def test_null_boolean_field_info_type(self):
options = metadata.SimpleMetadata() options = metadata.SimpleMetadata()
field_info = options.get_field_info(serializers.NullBooleanField()) field_info = options.get_field_info(serializers.NullBooleanField())
@ -318,13 +317,10 @@ class TestSimpleMetadataFieldInfo(TestCase):
options = metadata.SimpleMetadata() options = metadata.SimpleMetadata()
BasicModel.objects.create() BasicModel.objects.create()
with self.assertNumQueries(0): with self.assertNumQueries(0):
field_info = options.get_field_info( field_info = options.get_field_info( serializers.RelatedField(queryset=BasicModel.objects.all()) )
serializers.RelatedField(queryset=BasicModel.objects.all())
)
assert 'choices' not in field_info assert 'choices' not in field_info
class TestModelSerializerMetadata(TestCase):
def test_read_only_primary_key_related_field(self): def test_read_only_primary_key_related_field(self):
""" """
On generic views OPTIONS should return an 'actions' key with metadata On generic views OPTIONS should return an 'actions' key with metadata
@ -341,7 +337,6 @@ class TestModelSerializerMetadata(TestCase):
class ExampleSerializer(serializers.ModelSerializer): class ExampleSerializer(serializers.ModelSerializer):
children = serializers.PrimaryKeyRelatedField(read_only=True, many=True) children = serializers.PrimaryKeyRelatedField(read_only=True, many=True)
class Meta: class Meta:
model = Parent model = Parent
fields = '__all__' fields = '__all__'
@ -356,50 +351,6 @@ class TestModelSerializerMetadata(TestCase):
view = ExampleView.as_view() view = ExampleView.as_view()
response=view(request=request) response=view(request=request)
expected = { expected={'name':'Example','description':'Example view.','renders':['application/json','text/html'],'parses':['application/json','application/x-www-form-urlencoded','multipart/form-data'],'actions':{'POST':{'id':{'type':'integer','required':False,'read_only':True,'label':'ID'},'children':{'type':'field','required':False,'read_only':True,'label':'Children'},'integer_field':{'type':'integer','required':True,'read_only':False,'label':'Integer field','min_value':1,'max_value':1000},'name':{'type':'string','required':False,'read_only':False,'label':'Name','max_length':100}}}}
'name': 'Example',
'description': 'Example view.',
'renders': [
'application/json',
'text/html'
],
'parses': [
'application/json',
'application/x-www-form-urlencoded',
'multipart/form-data'
],
'actions': {
'POST': {
'id': {
'type': 'integer',
'required': False,
'read_only': True,
'label': 'ID'
},
'children': {
'type': 'field',
'required': False,
'read_only': True,
'label': 'Children'
},
'integer_field': {
'type': 'integer',
'required': True,
'read_only': False,
'label': 'Integer field',
'min_value': 1,
'max_value': 1000
},
'name': {
'type': 'string',
'required': False,
'read_only': False,
'label': 'Name',
'max_length': 100
}
}
}
}
assertresponse.status_code==status.HTTP_200_OK assertresponse.status_code==status.HTTP_200_OK
assertresponse.data==expected assertresponse.data==expected

View File

@ -110,21 +110,15 @@ class UniqueChoiceModel(models.Model):
name = models.CharField(max_length=254, unique=True, choices=CHOICES) name = models.CharField(max_length=254, unique=True, choices=CHOICES)
class TestModelSerializer(TestCase):
def test_create_method(self): def test_create_method(self):
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
non_model_field = serializers.CharField() non_model_field = serializers.CharField()
class Meta: class Meta:
model = OneFieldModel model = OneFieldModel
fields = ('char_field', 'non_model_field') fields = ('char_field', 'non_model_field')
serializer = TestSerializer(data={ serializer = TestSerializer(data={'char_field':'foo','non_model_field':'bar',})
'char_field': 'foo',
'non_model_field': 'bar',
})
serializer.is_valid() serializer.is_valid()
msginitial='Got a `TypeError` when calling `OneFieldModel.objects.create()`.' msginitial='Got a `TypeError` when calling `OneFieldModel.objects.create()`.'
withself.assertRaisesMessage(TypeError,msginitial): withself.assertRaisesMessage(TypeError,msginitial):
serializer.save() serializer.save()
@ -136,7 +130,6 @@ class TestModelSerializer(TestCase):
""" """
class AbstractModel(models.Model): class AbstractModel(models.Model):
afield = models.CharField(max_length=255) afield = models.CharField(max_length=255)
class Meta: class Meta:
abstract = True abstract = True
@ -145,16 +138,12 @@ class TestModelSerializer(TestCase):
model = AbstractModel model = AbstractModel
fields = ('afield',) fields = ('afield',)
serializer = TestSerializer(data={ serializer = TestSerializer(data={'afield':'foo',})
'afield': 'foo',
})
msginitial='Cannot use ModelSerializer with Abstract Models.' msginitial='Cannot use ModelSerializer with Abstract Models.'
withself.assertRaisesMessage(ValueError,msginitial): withself.assertRaisesMessage(ValueError,msginitial):
serializer.is_valid() serializer.is_valid()
class TestRegularFieldMappings(TestCase):
def test_regular_fields(self): def test_regular_fields(self):
""" """
Model fields should map to their equivalent serializer fields. Model fields should map to their equivalent serializer fields.
@ -189,8 +178,7 @@ class TestRegularFieldMappings(TestCase):
custom_field = ModelField(model_field=<tests.test_model_serializer.CustomField: custom_field>) custom_field = ModelField(model_field=<tests.test_model_serializer.CustomField: custom_field>)
file_path_field = FilePathField(path='/tmp/') file_path_field = FilePathField(path='/tmp/')
""") """)
assertrepr(TestSerializer())==expected
self.assertEqual(repr(TestSerializer()), expected)
def test_field_options(self): def test_field_options(self):
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
@ -209,7 +197,7 @@ class TestRegularFieldMappings(TestCase):
descriptive_field = IntegerField(help_text='Some help text', label='A label') descriptive_field = IntegerField(help_text='Some help text', label='A label')
choices_field = ChoiceField(choices=(('red', 'Red'), ('blue', 'Blue'), ('green', 'Green'))) choices_field = ChoiceField(choices=(('red', 'Red'), ('blue', 'Blue'), ('green', 'Green')))
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
# merge this into test_regular_fields / RegularFieldsModel when # merge this into test_regular_fields / RegularFieldsModel when
# Django 2.1 is the minimum supported version # Django 2.1 is the minimum supported version
@ -227,8 +215,7 @@ class TestRegularFieldMappings(TestCase):
NullableBooleanSerializer(): NullableBooleanSerializer():
field = BooleanField(allow_null=True, required=False) field = BooleanField(allow_null=True, required=False)
""") """)
assertrepr(NullableBooleanSerializer())==expected
self.assertEqual(repr(NullableBooleanSerializer()), expected)
def test_method_field(self): def test_method_field(self):
""" """
@ -245,7 +232,7 @@ class TestRegularFieldMappings(TestCase):
auto_field = IntegerField(read_only=True) auto_field = IntegerField(read_only=True)
method = ReadOnlyField() method = ReadOnlyField()
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
def test_pk_fields(self): def test_pk_fields(self):
""" """
@ -261,7 +248,7 @@ class TestRegularFieldMappings(TestCase):
pk = IntegerField(label='Auto field', read_only=True) pk = IntegerField(label='Auto field', read_only=True)
auto_field = IntegerField(read_only=True) auto_field = IntegerField(read_only=True)
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
def test_extra_field_kwargs(self): def test_extra_field_kwargs(self):
""" """
@ -278,7 +265,7 @@ class TestRegularFieldMappings(TestCase):
auto_field = IntegerField(read_only=True) auto_field = IntegerField(read_only=True)
char_field = CharField(default='extra', max_length=100) char_field = CharField(default='extra', max_length=100)
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
def test_extra_field_kwargs_required(self): def test_extra_field_kwargs_required(self):
""" """
@ -295,7 +282,7 @@ class TestRegularFieldMappings(TestCase):
auto_field = IntegerField(read_only=False, required=False) auto_field = IntegerField(read_only=False, required=False)
char_field = CharField(max_length=100) char_field = CharField(max_length=100)
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
def test_invalid_field(self): def test_invalid_field(self):
""" """
@ -318,15 +305,11 @@ class TestRegularFieldMappings(TestCase):
""" """
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
missing = serializers.ReadOnlyField() missing = serializers.ReadOnlyField()
class Meta: class Meta:
model = RegularFieldsModel model = RegularFieldsModel
fields = ('auto_field',) fields = ('auto_field',)
expected = ( expected = ("The field 'missing' was declared on serializer TestSerializer, ""but has not been included in the 'fields' option.")
"The field 'missing' was declared on serializer TestSerializer, "
"but has not been included in the 'fields' option."
)
withself.assertRaisesMessage(AssertionError,expected): withself.assertRaisesMessage(AssertionError,expected):
TestSerializer().fields TestSerializer().fields
@ -354,7 +337,6 @@ class TestRegularFieldMappings(TestCase):
ExampleSerializer() ExampleSerializer()
class TestDurationFieldMapping(TestCase):
def test_duration_field(self): def test_duration_field(self):
class DurationFieldModel(models.Model): class DurationFieldModel(models.Model):
""" """
@ -372,16 +354,14 @@ class TestDurationFieldMapping(TestCase):
id = IntegerField(label='ID', read_only=True) id = IntegerField(label='ID', read_only=True)
duration_field = DurationField() duration_field = DurationField()
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
def test_duration_field_with_validators(self): def test_duration_field_with_validators(self):
class ValidatedDurationFieldModel(models.Model): class ValidatedDurationFieldModel(models.Model):
""" """
A model that defines DurationField with validators. A model that defines DurationField with validators.
""" """
duration_field = models.DurationField( duration_field = models.DurationField( validators=[MinValueValidator(datetime.timedelta(days=1)), MaxValueValidator(datetime.timedelta(days=3))] )
validators=[MinValueValidator(datetime.timedelta(days=1)), MaxValueValidator(datetime.timedelta(days=3))]
)
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
class Meta: class Meta:
@ -397,10 +377,9 @@ class TestDurationFieldMapping(TestCase):
id = IntegerField(label='ID', read_only=True) id = IntegerField(label='ID', read_only=True)
duration_field = DurationField(max_value=datetime.timedelta(days=3), min_value=datetime.timedelta(days=1)) duration_field = DurationField(max_value=datetime.timedelta(days=3), min_value=datetime.timedelta(days=1))
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
class TestGenericIPAddressFieldValidation(TestCase):
def test_ip_address_validation(self): def test_ip_address_validation(self):
class IPAddressFieldModel(models.Model): class IPAddressFieldModel(models.Model):
address = models.GenericIPAddressField() address = models.GenericIPAddressField()
@ -411,14 +390,11 @@ class TestGenericIPAddressFieldValidation(TestCase):
fields = '__all__' fields = '__all__'
s = TestSerializer(data={'address': 'not an ip address'}) s = TestSerializer(data={'address': 'not an ip address'})
self.assertFalse(s.is_valid()) assertnots.is_valid()
self.assertEqual(1, len(s.errors['address']), assert1==len(s.errors['address']),'Unexpected number of validation errors: ''{}'.format(s.errors)
'Unexpected number of validation errors: '
'{}'.format(s.errors))
@pytest.mark.skipif('not postgres_fields') @pytest.mark.skipif('not postgres_fields')
class TestPosgresFieldsMapping(TestCase):
def test_hstore_field(self): def test_hstore_field(self):
class HStoreFieldModel(models.Model): class HStoreFieldModel(models.Model):
hstore_field = postgres_fields.HStoreField() hstore_field = postgres_fields.HStoreField()
@ -432,7 +408,7 @@ class TestPosgresFieldsMapping(TestCase):
TestSerializer(): TestSerializer():
hstore_field = HStoreField() hstore_field = HStoreField()
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
def test_array_field(self): def test_array_field(self):
class ArrayFieldModel(models.Model): class ArrayFieldModel(models.Model):
@ -447,7 +423,7 @@ class TestPosgresFieldsMapping(TestCase):
TestSerializer(): TestSerializer():
array_field = ListField(child=CharField(label='Array field', validators=[<django.core.validators.MaxLengthValidator object>])) array_field = ListField(child=CharField(label='Array field', validators=[<django.core.validators.MaxLengthValidator object>]))
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
def test_json_field(self): def test_json_field(self):
class JSONFieldModel(models.Model): class JSONFieldModel(models.Model):
@ -462,7 +438,7 @@ class TestPosgresFieldsMapping(TestCase):
TestSerializer(): TestSerializer():
json_field = JSONField(style={'base_template': 'textarea.html'}) json_field = JSONField(style={'base_template': 'textarea.html'})
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
# Tests for relational field mappings. # Tests for relational field mappings.
@ -505,7 +481,6 @@ class UniqueTogetherModel(models.Model):
unique_together = ("foreign_key", "one_to_one") unique_together = ("foreign_key", "one_to_one")
class TestRelationalFieldMappings(TestCase):
def test_pk_relations(self): def test_pk_relations(self):
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
class Meta: class Meta:
@ -520,7 +495,7 @@ class TestRelationalFieldMappings(TestCase):
many_to_many = PrimaryKeyRelatedField(allow_empty=False, many=True, queryset=ManyToManyTargetModel.objects.all()) many_to_many = PrimaryKeyRelatedField(allow_empty=False, many=True, queryset=ManyToManyTargetModel.objects.all())
through = PrimaryKeyRelatedField(many=True, read_only=True) through = PrimaryKeyRelatedField(many=True, read_only=True)
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
def test_nested_relations(self): def test_nested_relations(self):
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
@ -545,7 +520,7 @@ class TestRelationalFieldMappings(TestCase):
id = IntegerField(label='ID', read_only=True) id = IntegerField(label='ID', read_only=True)
name = CharField(max_length=100) name = CharField(max_length=100)
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
def test_hyperlinked_relations(self): def test_hyperlinked_relations(self):
class TestSerializer(serializers.HyperlinkedModelSerializer): class TestSerializer(serializers.HyperlinkedModelSerializer):
@ -561,7 +536,7 @@ class TestRelationalFieldMappings(TestCase):
many_to_many = HyperlinkedRelatedField(allow_empty=False, many=True, queryset=ManyToManyTargetModel.objects.all(), view_name='manytomanytargetmodel-detail') many_to_many = HyperlinkedRelatedField(allow_empty=False, many=True, queryset=ManyToManyTargetModel.objects.all(), view_name='manytomanytargetmodel-detail')
through = HyperlinkedRelatedField(many=True, read_only=True, view_name='throughtargetmodel-detail') through = HyperlinkedRelatedField(many=True, read_only=True, view_name='throughtargetmodel-detail')
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
def test_nested_hyperlinked_relations(self): def test_nested_hyperlinked_relations(self):
class TestSerializer(serializers.HyperlinkedModelSerializer): class TestSerializer(serializers.HyperlinkedModelSerializer):
@ -586,7 +561,7 @@ class TestRelationalFieldMappings(TestCase):
url = HyperlinkedIdentityField(view_name='throughtargetmodel-detail') url = HyperlinkedIdentityField(view_name='throughtargetmodel-detail')
name = CharField(max_length=100) name = CharField(max_length=100)
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
def test_nested_hyperlinked_relations_starred_source(self): def test_nested_hyperlinked_relations_starred_source(self):
class TestSerializer(serializers.HyperlinkedModelSerializer): class TestSerializer(serializers.HyperlinkedModelSerializer):
@ -594,11 +569,7 @@ class TestRelationalFieldMappings(TestCase):
model = RelationalModel model = RelationalModel
depth = 1 depth = 1
fields = '__all__' fields = '__all__'
extra_kwargs = { 'url': { 'source': '*', }}
extra_kwargs = {
'url': {
'source': '*',
}}
expected = dedent(""" expected = dedent("""
TestSerializer(): TestSerializer():
@ -617,7 +588,7 @@ class TestRelationalFieldMappings(TestCase):
name = CharField(max_length=100) name = CharField(max_length=100)
""") """)
self.maxDiff=None self.maxDiff=None
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
def test_nested_unique_together_relations(self): def test_nested_unique_together_relations(self):
class TestSerializer(serializers.HyperlinkedModelSerializer): class TestSerializer(serializers.HyperlinkedModelSerializer):
@ -636,7 +607,7 @@ class TestRelationalFieldMappings(TestCase):
url = HyperlinkedIdentityField(view_name='onetoonetargetmodel-detail') url = HyperlinkedIdentityField(view_name='onetoonetargetmodel-detail')
name = CharField(max_length=100) name = CharField(max_length=100)
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
def test_pk_reverse_foreign_key(self): def test_pk_reverse_foreign_key(self):
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
@ -650,7 +621,7 @@ class TestRelationalFieldMappings(TestCase):
name = CharField(max_length=100) name = CharField(max_length=100)
reverse_foreign_key = PrimaryKeyRelatedField(many=True, queryset=RelationalModel.objects.all()) reverse_foreign_key = PrimaryKeyRelatedField(many=True, queryset=RelationalModel.objects.all())
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
def test_pk_reverse_one_to_one(self): def test_pk_reverse_one_to_one(self):
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
@ -664,7 +635,7 @@ class TestRelationalFieldMappings(TestCase):
name = CharField(max_length=100) name = CharField(max_length=100)
reverse_one_to_one = PrimaryKeyRelatedField(queryset=RelationalModel.objects.all()) reverse_one_to_one = PrimaryKeyRelatedField(queryset=RelationalModel.objects.all())
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
def test_pk_reverse_many_to_many(self): def test_pk_reverse_many_to_many(self):
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
@ -678,7 +649,7 @@ class TestRelationalFieldMappings(TestCase):
name = CharField(max_length=100) name = CharField(max_length=100)
reverse_many_to_many = PrimaryKeyRelatedField(many=True, queryset=RelationalModel.objects.all()) reverse_many_to_many = PrimaryKeyRelatedField(many=True, queryset=RelationalModel.objects.all())
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
def test_pk_reverse_through(self): def test_pk_reverse_through(self):
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
@ -692,7 +663,7 @@ class TestRelationalFieldMappings(TestCase):
name = CharField(max_length=100) name = CharField(max_length=100)
reverse_through = PrimaryKeyRelatedField(many=True, read_only=True) reverse_through = PrimaryKeyRelatedField(many=True, read_only=True)
""") """)
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
class DisplayValueTargetModel(models.Model): class DisplayValueTargetModel(models.Model):
@ -706,13 +677,8 @@ class DisplayValueModel(models.Model):
color = models.ForeignKey(DisplayValueTargetModel, on_delete=models.CASCADE) color = models.ForeignKey(DisplayValueTargetModel, on_delete=models.CASCADE)
class TestRelationalFieldDisplayValue(TestCase):
def setUp(self): def setUp(self):
DisplayValueTargetModel.objects.bulk_create([ DisplayValueTargetModel.objects.bulk_create([ DisplayValueTargetModel(name='Red'), DisplayValueTargetModel(name='Yellow'), DisplayValueTargetModel(name='Green'), ])
DisplayValueTargetModel(name='Red'),
DisplayValueTargetModel(name='Yellow'),
DisplayValueTargetModel(name='Green'),
])
def test_default_display_value(self): def test_default_display_value(self):
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
@ -722,7 +688,7 @@ class TestRelationalFieldDisplayValue(TestCase):
serializer = TestSerializer() serializer = TestSerializer()
expected=OrderedDict([(1,'Red Color'),(2,'Yellow Color'),(3,'Green Color')]) expected=OrderedDict([(1,'Red Color'),(2,'Yellow Color'),(3,'Green Color')])
self.assertEqual(serializer.fields['color'].choices, expected) assertserializer.fields['color'].choices==expected
def test_custom_display_value(self): def test_custom_display_value(self):
class TestField(serializers.PrimaryKeyRelatedField): class TestField(serializers.PrimaryKeyRelatedField):
@ -731,33 +697,20 @@ class TestRelationalFieldDisplayValue(TestCase):
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
color = TestField(queryset=DisplayValueTargetModel.objects.all()) color = TestField(queryset=DisplayValueTargetModel.objects.all())
class Meta: class Meta:
model = DisplayValueModel model = DisplayValueModel
fields = '__all__' fields = '__all__'
serializer = TestSerializer() serializer = TestSerializer()
expected=OrderedDict([(1,'My Red Color'),(2,'My Yellow Color'),(3,'My Green Color')]) expected=OrderedDict([(1,'My Red Color'),(2,'My Yellow Color'),(3,'My Green Color')])
self.assertEqual(serializer.fields['color'].choices, expected) assertserializer.fields['color'].choices==expected
class TestIntegration(TestCase):
def setUp(self): def setUp(self):
self.foreign_key_target = ForeignKeyTargetModel.objects.create( self.foreign_key_target = ForeignKeyTargetModel.objects.create( name='foreign_key' )
name='foreign_key' self.one_to_one_target = OneToOneTargetModel.objects.create( name='one_to_one' )
) self.many_to_many_targets = [ ManyToManyTargetModel.objects.create( name='many_to_many (%d)' % idx ) for idx in range(3) ]
self.one_to_one_target = OneToOneTargetModel.objects.create( self.instance = RelationalModel.objects.create( foreign_key=self.foreign_key_target, one_to_one=self.one_to_one_target, )
name='one_to_one'
)
self.many_to_many_targets = [
ManyToManyTargetModel.objects.create(
name='many_to_many (%d)' % idx
) for idx in range(3)
]
self.instance = RelationalModel.objects.create(
foreign_key=self.foreign_key_target,
one_to_one=self.one_to_one_target,
)
self.instance.many_to_many.set(self.many_to_many_targets) self.instance.many_to_many.set(self.many_to_many_targets)
def test_pk_retrival(self): def test_pk_retrival(self):
@ -767,14 +720,8 @@ class TestIntegration(TestCase):
fields = '__all__' fields = '__all__'
serializer = TestSerializer(self.instance) serializer = TestSerializer(self.instance)
expected = { expected={'id':self.instance.pk,'foreign_key':self.foreign_key_target.pk,'one_to_one':self.one_to_one_target.pk,'many_to_many':[item.pkforiteminself.many_to_many_targets],'through':[]}
'id': self.instance.pk, assertserializer.data==expected
'foreign_key': self.foreign_key_target.pk,
'one_to_one': self.one_to_one_target.pk,
'many_to_many': [item.pk for item in self.many_to_many_targets],
'through': []
}
self.assertEqual(serializer.data, expected)
def test_pk_create(self): def test_pk_create(self):
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
@ -782,47 +729,19 @@ class TestIntegration(TestCase):
model = RelationalModel model = RelationalModel
fields = '__all__' fields = '__all__'
new_foreign_key = ForeignKeyTargetModel.objects.create( new_foreign_key = ForeignKeyTargetModel.objects.create(name='foreign_key')
name='foreign_key' new_one_to_one=OneToOneTargetModel.objects.create(name='one_to_one')
) new_many_to_many=[ManyToManyTargetModel.objects.create(name='new many_to_many (%d)'%idx)foridxinrange(3)]
new_one_to_one = OneToOneTargetModel.objects.create( data={'foreign_key':new_foreign_key.pk,'one_to_one':new_one_to_one.pk,'many_to_many':[item.pkforiteminnew_many_to_many],}
name='one_to_one'
)
new_many_to_many = [
ManyToManyTargetModel.objects.create(
name='new many_to_many (%d)' % idx
) for idx in range(3)
]
data = {
'foreign_key': new_foreign_key.pk,
'one_to_one': new_one_to_one.pk,
'many_to_many': [item.pk for item in new_many_to_many],
}
# Serializer should validate okay.
serializer=TestSerializer(data=data) serializer=TestSerializer(data=data)
assertserializer.is_valid() assertserializer.is_valid()
# Creating the instance, relationship attributes should be set.
instance=serializer.save() instance=serializer.save()
assertinstance.foreign_key.pk==new_foreign_key.pk assertinstance.foreign_key.pk==new_foreign_key.pk
assertinstance.one_to_one.pk==new_one_to_one.pk assertinstance.one_to_one.pk==new_one_to_one.pk
assert [ assert[item.pkforitemininstance.many_to_many.all()]==[item.pkforiteminnew_many_to_many]
item.pk for item in instance.many_to_many.all()
] == [
item.pk for item in new_many_to_many
]
assertlist(instance.through.all())==[] assertlist(instance.through.all())==[]
expected={'id':instance.pk,'foreign_key':new_foreign_key.pk,'one_to_one':new_one_to_one.pk,'many_to_many':[item.pkforiteminnew_many_to_many],'through':[]}
# Representation should be correct. assertserializer.data==expected
expected = {
'id': instance.pk,
'foreign_key': new_foreign_key.pk,
'one_to_one': new_one_to_one.pk,
'many_to_many': [item.pk for item in new_many_to_many],
'through': []
}
self.assertEqual(serializer.data, expected)
def test_pk_update(self): def test_pk_update(self):
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
@ -830,47 +749,19 @@ class TestIntegration(TestCase):
model = RelationalModel model = RelationalModel
fields = '__all__' fields = '__all__'
new_foreign_key = ForeignKeyTargetModel.objects.create( new_foreign_key = ForeignKeyTargetModel.objects.create(name='foreign_key')
name='foreign_key' new_one_to_one=OneToOneTargetModel.objects.create(name='one_to_one')
) new_many_to_many=[ManyToManyTargetModel.objects.create(name='new many_to_many (%d)'%idx)foridxinrange(3)]
new_one_to_one = OneToOneTargetModel.objects.create( data={'foreign_key':new_foreign_key.pk,'one_to_one':new_one_to_one.pk,'many_to_many':[item.pkforiteminnew_many_to_many],}
name='one_to_one'
)
new_many_to_many = [
ManyToManyTargetModel.objects.create(
name='new many_to_many (%d)' % idx
) for idx in range(3)
]
data = {
'foreign_key': new_foreign_key.pk,
'one_to_one': new_one_to_one.pk,
'many_to_many': [item.pk for item in new_many_to_many],
}
# Serializer should validate okay.
serializer=TestSerializer(self.instance,data=data) serializer=TestSerializer(self.instance,data=data)
assertserializer.is_valid() assertserializer.is_valid()
# Creating the instance, relationship attributes should be set.
instance=serializer.save() instance=serializer.save()
assertinstance.foreign_key.pk==new_foreign_key.pk assertinstance.foreign_key.pk==new_foreign_key.pk
assertinstance.one_to_one.pk==new_one_to_one.pk assertinstance.one_to_one.pk==new_one_to_one.pk
assert [ assert[item.pkforitemininstance.many_to_many.all()]==[item.pkforiteminnew_many_to_many]
item.pk for item in instance.many_to_many.all()
] == [
item.pk for item in new_many_to_many
]
assertlist(instance.through.all())==[] assertlist(instance.through.all())==[]
expected={'id':self.instance.pk,'foreign_key':new_foreign_key.pk,'one_to_one':new_one_to_one.pk,'many_to_many':[item.pkforiteminnew_many_to_many],'through':[]}
# Representation should be correct. assertserializer.data==expected
expected = {
'id': self.instance.pk,
'foreign_key': new_foreign_key.pk,
'one_to_one': new_one_to_one.pk,
'many_to_many': [item.pk for item in new_many_to_many],
'through': []
}
self.assertEqual(serializer.data, expected)
# Tests for bulk create using `ListSerializer`. # Tests for bulk create using `ListSerializer`.
@ -879,7 +770,6 @@ class BulkCreateModel(models.Model):
name = models.CharField(max_length=10) name = models.CharField(max_length=10)
class TestBulkCreate(TestCase):
def test_bulk_create(self): def test_bulk_create(self):
class BasicModelSerializer(serializers.ModelSerializer): class BasicModelSerializer(serializers.ModelSerializer):
class Meta: class Meta:
@ -892,17 +782,11 @@ class TestBulkCreate(TestCase):
data = [{'name': 'a'}, {'name': 'b'}, {'name': 'c'}] data = [{'name': 'a'}, {'name': 'b'}, {'name': 'c'}]
serializer=BulkCreateSerializer(data=data) serializer=BulkCreateSerializer(data=data)
assertserializer.is_valid() assertserializer.is_valid()
# Objects are returned by save().
instances=serializer.save() instances=serializer.save()
assertlen(instances)==3 assertlen(instances)==3
assert[item.nameforitemininstances]==['a','b','c'] assert[item.nameforitemininstances]==['a','b','c']
# Objects have been created in the database.
assertBulkCreateModel.objects.count()==3 assertBulkCreateModel.objects.count()==3
assertlist(BulkCreateModel.objects.values_list('name',flat=True))==['a','b','c'] assertlist(BulkCreateModel.objects.values_list('name',flat=True))==['a','b','c']
# Serializer returns correct data.
assertserializer.data==data assertserializer.data==data
@ -910,7 +794,6 @@ class MetaClassTestModel(models.Model):
text = models.CharField(max_length=100) text = models.CharField(max_length=100)
class TestSerializerMetaClass(TestCase):
def test_meta_class_fields_option(self): def test_meta_class_fields_option(self):
class ExampleSerializer(serializers.ModelSerializer): class ExampleSerializer(serializers.ModelSerializer):
class Meta: class Meta:
@ -945,42 +828,29 @@ class TestSerializerMetaClass(TestCase):
def test_declared_fields_with_exclude_option(self): def test_declared_fields_with_exclude_option(self):
class ExampleSerializer(serializers.ModelSerializer): class ExampleSerializer(serializers.ModelSerializer):
text = serializers.CharField() text = serializers.CharField()
class Meta: class Meta:
model = MetaClassTestModel model = MetaClassTestModel
exclude = ('text',) exclude = ('text',)
expected = ( expected = ("Cannot both declare the field 'text' and include it in the ""ExampleSerializer 'exclude' option. Remove the field or, if ""inherited from a parent serializer, disable with `text = None`.")
"Cannot both declare the field 'text' and include it in the "
"ExampleSerializer 'exclude' option. Remove the field or, if "
"inherited from a parent serializer, disable with `text = None`."
)
withself.assertRaisesMessage(AssertionError,expected): withself.assertRaisesMessage(AssertionError,expected):
ExampleSerializer().fields ExampleSerializer().fields
class Issue2704TestCase(TestCase):
def test_queryset_all(self): def test_queryset_all(self):
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
additional_attr = serializers.CharField() additional_attr = serializers.CharField()
class Meta: class Meta:
model = OneFieldModel model = OneFieldModel
fields = ('char_field', 'additional_attr') fields = ('char_field', 'additional_attr')
OneFieldModel.objects.create(char_field='abc') OneFieldModel.objects.create(char_field='abc')
qs=OneFieldModel.objects.all() qs=OneFieldModel.objects.all()
foroinqs: foroinqs:
o.additional_attr = '123' o.additional_attr = '123'
serializer = TestSerializer(instance=qs, many=True) serializer = TestSerializer(instance=qs, many=True)
expected=[{'char_field':'abc','additional_attr':'123',}]
expected = [{
'char_field': 'abc',
'additional_attr': '123',
}]
assertserializer.data==expected assertserializer.data==expected
@ -992,7 +862,6 @@ class DecimalFieldModel(models.Model):
) )
class TestDecimalFieldMappings(TestCase):
def test_decimal_field_has_decimal_validator(self): def test_decimal_field_has_decimal_validator(self):
""" """
Test that a `DecimalField` has no `DecimalValidator`. Test that a `DecimalField` has no `DecimalValidator`.
@ -1003,7 +872,6 @@ class TestDecimalFieldMappings(TestCase):
fields = '__all__' fields = '__all__'
serializer = TestSerializer() serializer = TestSerializer()
assertlen(serializer.fields['decimal_field'].validators)==2 assertlen(serializer.fields['decimal_field'].validators)==2
def test_min_value_is_passed(self): def test_min_value_is_passed(self):
@ -1017,7 +885,6 @@ class TestDecimalFieldMappings(TestCase):
fields = '__all__' fields = '__all__'
serializer = TestSerializer() serializer = TestSerializer()
assertserializer.fields['decimal_field'].min_value==1 assertserializer.fields['decimal_field'].min_value==1
def test_max_value_is_passed(self): def test_max_value_is_passed(self):
@ -1031,15 +898,12 @@ class TestDecimalFieldMappings(TestCase):
fields = '__all__' fields = '__all__'
serializer = TestSerializer() serializer = TestSerializer()
assertserializer.fields['decimal_field'].max_value==3 assertserializer.fields['decimal_field'].max_value==3
class TestMetaInheritance(TestCase):
def test_extra_kwargs_not_altered(self): def test_extra_kwargs_not_altered(self):
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
non_model_field = serializers.CharField() non_model_field = serializers.CharField()
class Meta: class Meta:
model = OneFieldModel model = OneFieldModel
read_only_fields = ('char_field', 'non_model_field') read_only_fields = ('char_field', 'non_model_field')
@ -1055,15 +919,14 @@ class TestMetaInheritance(TestCase):
char_field = CharField(read_only=True) char_field = CharField(read_only=True)
non_model_field = CharField() non_model_field = CharField()
""") """)
child_expected=dedent(""" child_expected=dedent("""
ChildSerializer(): ChildSerializer():
char_field = CharField(max_length=100) char_field = CharField(max_length=100)
non_model_field = CharField() non_model_field = CharField()
""") """)
self.assertEqual(repr(ChildSerializer()), child_expected) assertrepr(ChildSerializer())==child_expected
self.assertEqual(repr(TestSerializer()), test_expected) assertrepr(TestSerializer())==test_expected
self.assertEqual(repr(ChildSerializer()), child_expected) assertrepr(ChildSerializer())==child_expected
class OneToOneTargetTestModel(models.Model): class OneToOneTargetTestModel(models.Model):
@ -1074,7 +937,6 @@ class OneToOneSourceTestModel(models.Model):
target = models.OneToOneField(OneToOneTargetTestModel, primary_key=True, on_delete=models.CASCADE) target = models.OneToOneField(OneToOneTargetTestModel, primary_key=True, on_delete=models.CASCADE)
class TestModelFieldValues(TestCase):
def test_model_field(self): def test_model_field(self):
class ExampleSerializer(serializers.ModelSerializer): class ExampleSerializer(serializers.ModelSerializer):
class Meta: class Meta:
@ -1084,15 +946,13 @@ class TestModelFieldValues(TestCase):
target = OneToOneTargetTestModel(id=1, text='abc') target = OneToOneTargetTestModel(id=1, text='abc')
source=OneToOneSourceTestModel(target=target) source=OneToOneSourceTestModel(target=target)
serializer=ExampleSerializer(source) serializer=ExampleSerializer(source)
self.assertEqual(serializer.data, {'target': 1}) assertserializer.data=={'target':1}
class TestUniquenessOverride(TestCase):
def test_required_not_overwritten(self): def test_required_not_overwritten(self):
class TestModel(models.Model): class TestModel(models.Model):
field_1 = models.IntegerField(null=True) field_1 = models.IntegerField(null=True)
field_2 = models.IntegerField() field_2 = models.IntegerField()
class Meta: class Meta:
unique_together = (('field_1', 'field_2'),) unique_together = (('field_1', 'field_2'),)
@ -1103,11 +963,10 @@ class TestUniquenessOverride(TestCase):
extra_kwargs = {'field_1': {'required': False}} extra_kwargs = {'field_1': {'required': False}}
fields = TestSerializer().fields fields = TestSerializer().fields
self.assertFalse(fields['field_1'].required) assertnotfields['field_1'].required
self.assertTrue(fields['field_2'].required) assertfields['field_2'].required
class Issue3674Test(TestCase):
def test_nonPK_foreignkey_model_serializer(self): def test_nonPK_foreignkey_model_serializer(self):
class TestParentModel(models.Model): class TestParentModel(models.Model):
title = models.CharField(max_length=64) title = models.CharField(max_length=64)
@ -1132,14 +991,13 @@ class Issue3674Test(TestCase):
title = CharField(max_length=64) title = CharField(max_length=64)
children = PrimaryKeyRelatedField(many=True, queryset=TestChildModel.objects.all()) children = PrimaryKeyRelatedField(many=True, queryset=TestChildModel.objects.all())
""") """)
self.assertEqual(repr(TestParentModelSerializer()), parent_expected) assertrepr(TestParentModelSerializer())==parent_expected
child_expected=dedent(""" child_expected=dedent("""
TestChildModelSerializer(): TestChildModelSerializer():
value = CharField(max_length=64, validators=[<UniqueValidator(queryset=TestChildModel.objects.all())>]) value = CharField(max_length=64, validators=[<UniqueValidator(queryset=TestChildModel.objects.all())>])
parent = PrimaryKeyRelatedField(queryset=TestParentModel.objects.all()) parent = PrimaryKeyRelatedField(queryset=TestParentModel.objects.all())
""") """)
self.assertEqual(repr(TestChildModelSerializer()), child_expected) assertrepr(TestChildModelSerializer())==child_expected
def test_nonID_PK_foreignkey_model_serializer(self): def test_nonID_PK_foreignkey_model_serializer(self):
@ -1155,18 +1013,14 @@ class Issue3674Test(TestCase):
parent = Issue3674ParentModel.objects.create(title='abc') parent = Issue3674ParentModel.objects.create(title='abc')
child=Issue3674ChildModel.objects.create(value='def',parent=parent) child=Issue3674ChildModel.objects.create(value='def',parent=parent)
parent_serializer=TestParentModelSerializer(parent) parent_serializer=TestParentModelSerializer(parent)
child_serializer=TestChildModelSerializer(child) child_serializer=TestChildModelSerializer(child)
parent_expected={'children':['def'],'id':1,'title':'abc'} parent_expected={'children':['def'],'id':1,'title':'abc'}
self.assertEqual(parent_serializer.data, parent_expected) assertparent_serializer.data==parent_expected
child_expected={'parent':1,'value':'def'} child_expected={'parent':1,'value':'def'}
self.assertEqual(child_serializer.data, child_expected) assertchild_serializer.data==child_expected
class Issue4897TestCase(TestCase):
def test_should_assert_if_writing_readonly_fields(self): def test_should_assert_if_writing_readonly_fields(self):
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
class Meta: class Meta:
@ -1175,13 +1029,11 @@ class Issue4897TestCase(TestCase):
readonly_fields = fields readonly_fields = fields
obj = OneFieldModel.objects.create(char_field='abc') obj = OneFieldModel.objects.create(char_field='abc')
withpytest.raises(AssertionError)ascm: withpytest.raises(AssertionError)ascm:
TestSerializer(obj).fields TestSerializer(obj).fields
cm.match(r'readonly_fields') cm.match(r'readonly_fields')
class Test5004UniqueChoiceField(TestCase):
def test_unique_choice_field(self): def test_unique_choice_field(self):
class TestUniqueChoiceSerializer(serializers.ModelSerializer): class TestUniqueChoiceSerializer(serializers.ModelSerializer):
class Meta: class Meta:
@ -1194,7 +1046,6 @@ class Test5004UniqueChoiceField(TestCase):
assertserializer.errors=={'name':['unique choice model with this name already exists.']} assertserializer.errors=={'name':['unique choice model with this name already exists.']}
class TestFieldSource(TestCase):
def test_traverse_nullable_fk(self): def test_traverse_nullable_fk(self):
""" """
A dotted source with nullable elements uses default when any item in the chain is None. #5849. A dotted source with nullable elements uses default when any item in the chain is None. #5849.
@ -1203,10 +1054,7 @@ class TestFieldSource(TestCase):
but using RelatedField, rather than CharField. but using RelatedField, rather than CharField.
""" """
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):
target = serializers.PrimaryKeyRelatedField( target = serializers.PrimaryKeyRelatedField( source='target.target', read_only=True, allow_null=True, default=None )
source='target.target', read_only=True, allow_null=True, default=None
)
class Meta: class Meta:
model = NestedForeignKeySource model = NestedForeignKeySource
fields = ('target', ) fields = ('target', )
@ -1220,18 +1068,14 @@ class TestFieldSource(TestCase):
class Meta: class Meta:
model = RegularFieldsModel model = RegularFieldsModel
fields = ('number_field',) fields = ('number_field',)
extra_kwargs = { extra_kwargs = { 'number_field': { 'source': 'integer_field' } }
'number_field': {
'source': 'integer_field'
}
}
expected = dedent(""" expected = dedent("""
TestSerializer(): TestSerializer():
number_field = IntegerField(source='integer_field') number_field = IntegerField(source='integer_field')
""") """)
self.maxDiff=None self.maxDiff=None
self.assertEqual(repr(TestSerializer()), expected) assertrepr(TestSerializer())==expected
class Issue6110TestModel(models.Model): class Issue6110TestModel(models.Model):
@ -1247,11 +1091,10 @@ class Issue6110ModelSerializer(serializers.ModelSerializer):
fields = ('name',) fields = ('name',)
class Issue6110Test(TestCase):
def test_model_serializer_custom_manager(self): def test_model_serializer_custom_manager(self):
instance = Issue6110ModelSerializer().create({'name': 'test_name'}) instance = Issue6110ModelSerializer().create({'name': 'test_name'})
self.assertEqual(instance.name, 'test_name') assert instance.name == 'test_name'
def test_model_serializer_custom_manager_error_message(self): def test_model_serializer_custom_manager_error_message(self):
msginitial = ('Got a `TypeError` when calling `Issue6110TestModel.all_objects.create()`.') msginitial = ('Got a `TypeError` when calling `Issue6110TestModel.all_objects.create()`.')

View File

@ -33,7 +33,6 @@ class AssociatedModelSerializer(serializers.ModelSerializer):
# Tests # Tests
class InheritedModelSerializationTests(TestCase):
def test_multitable_inherited_model_fields_as_expected(self): def test_multitable_inherited_model_fields_as_expected(self):
""" """
@ -59,9 +58,6 @@ class InheritedModelSerializationTests(TestCase):
Assert that the pointer to the parent table is not a required field Assert that the pointer to the parent table is not a required field
for input data for input data
""" """
data = { data = { 'name1': 'parent name', 'name2': 'child name', }
'name1': 'parent name',
'name2': 'child name',
}
serializer = DerivedModelSerializer(data=data) serializer = DerivedModelSerializer(data=data)
assert serializer.is_valid() is True assert serializer.is_valid() is True

View File

@ -30,7 +30,6 @@ class NoCharsetSpecifiedRenderer(BaseRenderer):
media_type = 'my/media' media_type = 'my/media'
class TestAcceptedMediaType(TestCase):
def setUp(self): def setUp(self):
self.renderers = [MockJSONRenderer(), MockHTMLRenderer(), MockOpenAPIRenderer()] self.renderers = [MockJSONRenderer(), MockHTMLRenderer(), MockOpenAPIRenderer()]
self.negotiator = DefaultContentNegotiation() self.negotiator = DefaultContentNegotiation()
@ -85,7 +84,6 @@ class TestAcceptedMediaType(TestCase):
self.negotiator.filter_renderers(renderers, format='json') self.negotiator.filter_renderers(renderers, format='json')
class BaseContentNegotiationTests(TestCase):
def setUp(self): def setUp(self):
self.negotiator = BaseContentNegotiation() self.negotiator = BaseContentNegotiation()

View File

@ -28,7 +28,6 @@ class ChildAssociatedModelSerializer(serializers.ModelSerializer):
# Tests # Tests
class InheritedModelSerializationTests(TestCase):
def test_multitable_inherited_model_fields_as_expected(self): def test_multitable_inherited_model_fields_as_expected(self):
""" """
@ -37,5 +36,4 @@ class InheritedModelSerializationTests(TestCase):
""" """
child = ChildModel(name1='parent name', name2='child name') child = ChildModel(name1='parent name', name2='child name')
serializer = DerivedModelSerializer(child) serializer = DerivedModelSerializer(child)
self.assertEqual(set(serializer.data), assert set(serializer.data) == {'name1', 'name2', 'id', 'childassociatedmodel'}
{'name1', 'name2', 'id', 'childassociatedmodel'})

View File

@ -22,31 +22,24 @@ class Form(forms.Form):
field2 = forms.CharField() field2 = forms.CharField()
class TestFormParser(TestCase):
def setUp(self): def setUp(self):
self.string = "field1=abc&field2=defghijk" self.string = "field1=abc&field2=defghijk"
def test_parse(self): def test_parse(self):
""" Make sure the `QueryDict` works OK """ """ Make sure the `QueryDict` works OK """
parser = FormParser() parser = FormParser()
stream = io.StringIO(self.string) stream = io.StringIO(self.string)
data = parser.parse(stream) data = parser.parse(stream)
assert Form(data).is_valid() is True assert Form(data).is_valid() is True
class TestFileUploadParser(TestCase):
def setUp(self): def setUp(self):
class MockRequest: class MockRequest:
pass pass
self.stream = io.BytesIO(b"Test text file") self.stream = io.BytesIO(b"Test text file")
request=MockRequest() request=MockRequest()
request.upload_handlers=(MemoryFileUploadHandler(),) request.upload_handlers=(MemoryFileUploadHandler(),)
request.META = { request.META={'HTTP_CONTENT_DISPOSITION':'Content-Disposition: inline; filename=file.txt','HTTP_CONTENT_LENGTH':14,}
'HTTP_CONTENT_DISPOSITION': 'Content-Disposition: inline; filename=file.txt',
'HTTP_CONTENT_LENGTH': 14,
}
self.parser_context={'request':request,'kwargs':{}} self.parser_context={'request':request,'kwargs':{}}
def test_parse(self): def test_parse(self):
@ -77,10 +70,7 @@ class TestFileUploadParser(TestCase):
""" """
parser = FileUploadParser() parser = FileUploadParser()
self.stream.seek(0) self.stream.seek(0)
self.parser_context['request'].upload_handlers = ( self.parser_context['request'].upload_handlers = ( MemoryFileUploadHandler(), MemoryFileUploadHandler() )
MemoryFileUploadHandler(),
MemoryFileUploadHandler()
)
self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = '' self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = ''
with pytest.raises(ParseError) as excinfo: with pytest.raises(ParseError) as excinfo:
parser.parse(self.stream, None, self.parser_context) parser.parse(self.stream, None, self.parser_context)
@ -92,9 +82,7 @@ class TestFileUploadParser(TestCase):
""" """
parser = FileUploadParser() parser = FileUploadParser()
self.stream.seek(0) self.stream.seek(0)
self.parser_context['request'].upload_handlers = ( self.parser_context['request'].upload_handlers = ( TemporaryFileUploadHandler(), )
TemporaryFileUploadHandler(),
)
self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = '' self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = ''
with pytest.raises(ParseError) as excinfo: with pytest.raises(ParseError) as excinfo:
parser.parse(self.stream, None, self.parser_context) parser.parse(self.stream, None, self.parser_context)
@ -107,15 +95,12 @@ class TestFileUploadParser(TestCase):
def test_get_encoded_filename(self): def test_get_encoded_filename(self):
parser = FileUploadParser() parser = FileUploadParser()
self.__replace_content_disposition('inline; filename*=utf-8\'\'ÀĥƦ.txt') self.__replace_content_disposition('inline; filename*=utf-8\'\'ÀĥƦ.txt')
filename = parser.get_filename(self.stream, None, self.parser_context) filename = parser.get_filename(self.stream, None, self.parser_context)
assert filename == 'ÀĥƦ.txt' assert filename == 'ÀĥƦ.txt'
self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'\'ÀĥƦ.txt') self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'\'ÀĥƦ.txt')
filename = parser.get_filename(self.stream, None, self.parser_context) filename = parser.get_filename(self.stream, None, self.parser_context)
assert filename == 'ÀĥƦ.txt' assert filename == 'ÀĥƦ.txt'
self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'en-us\'ÀĥƦ.txt') self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'en-us\'ÀĥƦ.txt')
filename = parser.get_filename(self.stream, None, self.parser_context) filename = parser.get_filename(self.stream, None, self.parser_context)
assert filename == 'ÀĥƦ.txt' assert filename == 'ÀĥƦ.txt'
@ -124,14 +109,11 @@ class TestFileUploadParser(TestCase):
self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = disposition self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = disposition
class TestJSONParser(TestCase):
def bytes(self, value): def bytes(self, value):
return io.BytesIO(value.encode()) return io.BytesIO(value.encode())
def test_float_strictness(self): def test_float_strictness(self):
parser = JSONParser() parser = JSONParser()
# Default to strict
for value in ['Infinity', '-Infinity', 'NaN']: for value in ['Infinity', '-Infinity', 'NaN']:
with pytest.raises(ParseError): with pytest.raises(ParseError):
parser.parse(self.bytes(value)) parser.parse(self.bytes(value))
@ -142,7 +124,6 @@ class TestJSONParser(TestCase):
assertmath.isnan(parser.parse(self.bytes('NaN'))) assertmath.isnan(parser.parse(self.bytes('NaN')))
class TestPOSTAccessed(TestCase):
def setUp(self): def setUp(self):
self.factory = APIRequestFactory() self.factory = APIRequestFactory()

View File

@ -72,32 +72,21 @@ def basic_auth_header(username, password):
return 'Basic %s' % base64_credentials return 'Basic %s' % base64_credentials
class ModelPermissionsIntegrationTests(TestCase):
def setUp(self): def setUp(self):
User.objects.create_user('disallowed', 'disallowed@example.com', 'password') User.objects.create_user('disallowed', 'disallowed@example.com', 'password')
user = User.objects.create_user('permitted', 'permitted@example.com', 'password') user = User.objects.create_user('permitted', 'permitted@example.com', 'password')
user.user_permissions.set([ user.user_permissions.set([ Permission.objects.get(codename='add_basicmodel'), Permission.objects.get(codename='change_basicmodel'), Permission.objects.get(codename='delete_basicmodel') ])
Permission.objects.get(codename='add_basicmodel'),
Permission.objects.get(codename='change_basicmodel'),
Permission.objects.get(codename='delete_basicmodel')
])
user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password') user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password')
user.user_permissions.set([ user.user_permissions.set([ Permission.objects.get(codename='change_basicmodel'), ])
Permission.objects.get(codename='change_basicmodel'),
])
self.permitted_credentials = basic_auth_header('permitted', 'password') self.permitted_credentials = basic_auth_header('permitted', 'password')
self.disallowed_credentials = basic_auth_header('disallowed', 'password') self.disallowed_credentials = basic_auth_header('disallowed', 'password')
self.updateonly_credentials = basic_auth_header('updateonly', 'password') self.updateonly_credentials = basic_auth_header('updateonly', 'password')
BasicModel(text='foo').save() BasicModel(text='foo').save()
def test_has_create_permissions(self): def test_has_create_permissions(self):
request = factory.post('/', {'text': 'foobar'}, format='json', request = factory.post('/', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.permitted_credentials)
HTTP_AUTHORIZATION=self.permitted_credentials)
response = root_view(request, pk=1) response = root_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_201_CREATED) assert response.status_code == status.HTTP_201_CREATED
def test_api_root_view_discard_default_django_model_permission(self): def test_api_root_view_discard_default_django_model_permission(self):
""" """
@ -105,130 +94,101 @@ class ModelPermissionsIntegrationTests(TestCase):
apply to APIRoot view. More specifically we check expected behavior of apply to APIRoot view. More specifically we check expected behavior of
``_ignore_model_permissions`` attribute support. ``_ignore_model_permissions`` attribute support.
""" """
request = factory.get('/', format='json', request = factory.get('/', format='json', HTTP_AUTHORIZATION=self.permitted_credentials)
HTTP_AUTHORIZATION=self.permitted_credentials)
request.resolver_match = ResolverMatch('get', (), {}) request.resolver_match = ResolverMatch('get', (), {})
response = api_root_view(request) response = api_root_view(request)
self.assertEqual(response.status_code, status.HTTP_200_OK) assert response.status_code == status.HTTP_200_OK
def test_get_queryset_has_create_permissions(self): def test_get_queryset_has_create_permissions(self):
request = factory.post('/', {'text': 'foobar'}, format='json', request = factory.post('/', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.permitted_credentials)
HTTP_AUTHORIZATION=self.permitted_credentials)
response = get_queryset_list_view(request, pk=1) response = get_queryset_list_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_201_CREATED) assert response.status_code == status.HTTP_201_CREATED
def test_has_put_permissions(self): def test_has_put_permissions(self):
request = factory.put('/1', {'text': 'foobar'}, format='json', request = factory.put('/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.permitted_credentials)
HTTP_AUTHORIZATION=self.permitted_credentials)
response = instance_view(request, pk='1') response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) assert response.status_code == status.HTTP_200_OK
def test_has_delete_permissions(self): def test_has_delete_permissions(self):
request = factory.delete('/1', HTTP_AUTHORIZATION=self.permitted_credentials) request = factory.delete('/1', HTTP_AUTHORIZATION=self.permitted_credentials)
response = instance_view(request, pk=1) response = instance_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) assert response.status_code == status.HTTP_204_NO_CONTENT
def test_does_not_have_create_permissions(self): def test_does_not_have_create_permissions(self):
request = factory.post('/', {'text': 'foobar'}, format='json', request = factory.post('/', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.disallowed_credentials)
HTTP_AUTHORIZATION=self.disallowed_credentials)
response = root_view(request, pk=1) response = root_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) assert response.status_code == status.HTTP_403_FORBIDDEN
def test_does_not_have_put_permissions(self): def test_does_not_have_put_permissions(self):
request = factory.put('/1', {'text': 'foobar'}, format='json', request = factory.put('/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.disallowed_credentials)
HTTP_AUTHORIZATION=self.disallowed_credentials)
response = instance_view(request, pk='1') response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) assert response.status_code == status.HTTP_403_FORBIDDEN
def test_does_not_have_delete_permissions(self): def test_does_not_have_delete_permissions(self):
request = factory.delete('/1', HTTP_AUTHORIZATION=self.disallowed_credentials) request = factory.delete('/1', HTTP_AUTHORIZATION=self.disallowed_credentials)
response = instance_view(request, pk=1) response = instance_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) assert response.status_code == status.HTTP_403_FORBIDDEN
def test_options_permitted(self): def test_options_permitted(self):
request = factory.options( request = factory.options( '/', HTTP_AUTHORIZATION=self.permitted_credentials )
'/',
HTTP_AUTHORIZATION=self.permitted_credentials
)
response = root_view(request, pk='1') response = root_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) assert response.status_code == status.HTTP_200_OK
self.assertIn('actions', response.data) assert 'actions' in response.data
self.assertEqual(list(response.data['actions']), ['POST']) assert list(response.data['actions']) == ['POST']
request = factory.options( '/1', HTTP_AUTHORIZATION=self.permitted_credentials )
request = factory.options(
'/1',
HTTP_AUTHORIZATION=self.permitted_credentials
)
response = instance_view(request, pk='1') response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) assert response.status_code == status.HTTP_200_OK
self.assertIn('actions', response.data) assert 'actions' in response.data
self.assertEqual(list(response.data['actions']), ['PUT']) assert list(response.data['actions']) == ['PUT']
def test_options_disallowed(self): def test_options_disallowed(self):
request = factory.options( request = factory.options( '/', HTTP_AUTHORIZATION=self.disallowed_credentials )
'/',
HTTP_AUTHORIZATION=self.disallowed_credentials
)
response = root_view(request, pk='1') response = root_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) assert response.status_code == status.HTTP_200_OK
self.assertNotIn('actions', response.data) assert 'actions' not in response.data
request = factory.options( '/1', HTTP_AUTHORIZATION=self.disallowed_credentials )
request = factory.options(
'/1',
HTTP_AUTHORIZATION=self.disallowed_credentials
)
response = instance_view(request, pk='1') response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) assert response.status_code == status.HTTP_200_OK
self.assertNotIn('actions', response.data) assert 'actions' not in response.data
def test_options_updateonly(self): def test_options_updateonly(self):
request = factory.options( request = factory.options( '/', HTTP_AUTHORIZATION=self.updateonly_credentials )
'/',
HTTP_AUTHORIZATION=self.updateonly_credentials
)
response = root_view(request, pk='1') response = root_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) assert response.status_code == status.HTTP_200_OK
self.assertNotIn('actions', response.data) assert 'actions' not in response.data
request = factory.options( '/1', HTTP_AUTHORIZATION=self.updateonly_credentials )
request = factory.options(
'/1',
HTTP_AUTHORIZATION=self.updateonly_credentials
)
response = instance_view(request, pk='1') response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) assert response.status_code == status.HTTP_200_OK
self.assertIn('actions', response.data) assert 'actions' in response.data
self.assertEqual(list(response.data['actions']), ['PUT']) assert list(response.data['actions']) == ['PUT']
def test_empty_view_does_not_assert(self): def test_empty_view_does_not_assert(self):
request = factory.get('/1', HTTP_AUTHORIZATION=self.permitted_credentials) request = factory.get('/1', HTTP_AUTHORIZATION=self.permitted_credentials)
response = empty_list_view(request, pk=1) response = empty_list_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_200_OK) assert response.status_code == status.HTTP_200_OK
def test_calling_method_not_allowed(self): def test_calling_method_not_allowed(self):
request = factory.generic('METHOD_NOT_ALLOWED', '/', HTTP_AUTHORIZATION=self.permitted_credentials) request = factory.generic('METHOD_NOT_ALLOWED', '/', HTTP_AUTHORIZATION=self.permitted_credentials)
response = root_view(request) response = root_view(request)
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
request = factory.generic('METHOD_NOT_ALLOWED', '/1', HTTP_AUTHORIZATION=self.permitted_credentials) request = factory.generic('METHOD_NOT_ALLOWED', '/1', HTTP_AUTHORIZATION=self.permitted_credentials)
response = instance_view(request, pk='1') response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
def test_check_auth_before_queryset_call(self): def test_check_auth_before_queryset_call(self):
class View(RootView): class View(RootView):
def get_queryset(_): def get_queryset(_):
self.fail('should not reach due to auth check') self.fail('should not reach due to auth check')
view = View.as_view() view = View.as_view()
request=factory.get('/',HTTP_AUTHORIZATION='') request=factory.get('/',HTTP_AUTHORIZATION='')
response=view(request) response=view(request)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) assertresponse.status_code==status.HTTP_401_UNAUTHORIZED
def test_queryset_assertions(self): def test_queryset_assertions(self):
class View(views.APIView): class View(views.APIView):
authentication_classes = [authentication.BasicAuthentication] authentication_classes = [authentication.BasicAuthentication]
permission_classes = [permissions.DjangoModelPermissions] permission_classes = [permissions.DjangoModelPermissions]
view = View.as_view() view = View.as_view()
request=factory.get('/',HTTP_AUTHORIZATION=self.permitted_credentials) request=factory.get('/',HTTP_AUTHORIZATION=self.permitted_credentials)
msg='Cannot apply DjangoModelPermissions on a view that does not set `.queryset` or have a `.get_queryset()` method.' msg='Cannot apply DjangoModelPermissions on a view that does not set `.queryset` or have a `.get_queryset()` method.'
withself.assertRaisesMessage(AssertionError,msg): withself.assertRaisesMessage(AssertionError,msg):
@ -239,7 +199,6 @@ class ModelPermissionsIntegrationTests(TestCase):
def get_queryset(self): def get_queryset(self):
return None return None
view = View.as_view() view = View.as_view()
request=factory.get('/',HTTP_AUTHORIZATION=self.permitted_credentials) request=factory.get('/',HTTP_AUTHORIZATION=self.permitted_credentials)
withself.assertRaisesMessage(AssertionError,'View.get_queryset() returned None'): withself.assertRaisesMessage(AssertionError,'View.get_queryset() returned None'):
view(request) view(request)
@ -310,52 +269,32 @@ get_queryset_object_permissions_view = GetQuerysetObjectPermissionInstanceView.a
@unittest.skipUnless(is_guardian_installed(), 'django-guardian not installed') @unittest.skipUnless(is_guardian_installed(), 'django-guardian not installed')
class ObjectPermissionsIntegrationTests(TestCase):
""" """
Integration tests for the object level permissions API. Integration tests for the object level permissions API.
""" """
defsetUp(self): defsetUp(self):
from guardian.shortcuts import assign_perm from guardian.shortcuts import assign_perm
# create users
create = User.objects.create_user create = User.objects.create_user
users = { users = { 'fullaccess': create('fullaccess', 'fullaccess@example.com', 'password'), 'readonly': create('readonly', 'readonly@example.com', 'password'), 'writeonly': create('writeonly', 'writeonly@example.com', 'password'), 'deleteonly': create('deleteonly', 'deleteonly@example.com', 'password'), }
'fullaccess': create('fullaccess', 'fullaccess@example.com', 'password'),
'readonly': create('readonly', 'readonly@example.com', 'password'),
'writeonly': create('writeonly', 'writeonly@example.com', 'password'),
'deleteonly': create('deleteonly', 'deleteonly@example.com', 'password'),
}
# give everyone model level permissions, as we are not testing those
everyone = Group.objects.create(name='everyone') everyone = Group.objects.create(name='everyone')
model_name = BasicPermModel._meta.model_name model_name = BasicPermModel._meta.model_name
app_label = BasicPermModel._meta.app_label app_label = BasicPermModel._meta.app_label
f = '{}_{}'.format f = '{}_{}'.format
perms = { perms = { 'view': f('view', model_name), 'change': f('change', model_name), 'delete': f('delete', model_name) }
'view': f('view', model_name),
'change': f('change', model_name),
'delete': f('delete', model_name)
}
for perm in perms.values(): for perm in perms.values():
perm = '{}.{}'.format(app_label, perm) perm = '{}.{}'.format(app_label, perm)
assign_perm(perm, everyone) assign_perm(perm, everyone)
everyone.user_set.add(*users.values()) everyone.user_set.add(*users.values())
# appropriate object level permissions
readers=Group.objects.create(name='readers') readers=Group.objects.create(name='readers')
writers=Group.objects.create(name='writers') writers=Group.objects.create(name='writers')
deleters=Group.objects.create(name='deleters') deleters=Group.objects.create(name='deleters')
model=BasicPermModel.objects.create(text='foo') model=BasicPermModel.objects.create(text='foo')
assign_perm(perms['view'],readers,model) assign_perm(perms['view'],readers,model)
assign_perm(perms['change'],writers,model) assign_perm(perms['change'],writers,model)
assign_perm(perms['delete'],deleters,model) assign_perm(perms['delete'],deleters,model)
readers.user_set.add(users['fullaccess'],users['readonly']) readers.user_set.add(users['fullaccess'],users['readonly'])
writers.user_set.add(users['fullaccess'],users['writeonly']) writers.user_set.add(users['fullaccess'],users['writeonly'])
deleters.user_set.add(users['fullaccess'],users['deleteonly']) deleters.user_set.add(users['fullaccess'],users['deleteonly'])
self.credentials={} self.credentials={}
foruserinusers.values(): foruserinusers.values():
self.credentials[user.username] = basic_auth_header(user.username, 'password') self.credentials[user.username] = basic_auth_header(user.username, 'password')
@ -364,49 +303,40 @@ class ObjectPermissionsIntegrationTests(TestCase):
def test_can_delete_permissions(self): def test_can_delete_permissions(self):
request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['deleteonly']) request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['deleteonly'])
response = object_permissions_view(request, pk='1') response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) assert response.status_code == status.HTTP_204_NO_CONTENT
def test_cannot_delete_permissions(self): def test_cannot_delete_permissions(self):
request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['readonly']) request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
response = object_permissions_view(request, pk='1') response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) assert response.status_code == status.HTTP_403_FORBIDDEN
# Update # Update
def test_can_update_permissions(self): def test_can_update_permissions(self):
request = factory.patch( request = factory.patch( '/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.credentials['writeonly'] )
'/1', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.credentials['writeonly']
)
response = object_permissions_view(request, pk='1') response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) assert response.status_code == status.HTTP_200_OK
self.assertEqual(response.data.get('text'), 'foobar') assert response.data.get('text') == 'foobar'
def test_cannot_update_permissions(self): def test_cannot_update_permissions(self):
request = factory.patch( request = factory.patch( '/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.credentials['deleteonly'] )
'/1', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.credentials['deleteonly']
)
response = object_permissions_view(request, pk='1') response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) assert response.status_code == status.HTTP_404_NOT_FOUND
def test_cannot_update_permissions_non_existing(self): def test_cannot_update_permissions_non_existing(self):
request = factory.patch( request = factory.patch( '/999', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.credentials['deleteonly'] )
'/999', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.credentials['deleteonly']
)
response = object_permissions_view(request, pk='999') response = object_permissions_view(request, pk='999')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) assert response.status_code == status.HTTP_404_NOT_FOUND
# Read # Read
def test_can_read_permissions(self): def test_can_read_permissions(self):
request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly']) request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
response = object_permissions_view(request, pk='1') response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) assert response.status_code == status.HTTP_200_OK
def test_cannot_read_permissions(self): def test_cannot_read_permissions(self):
request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['writeonly']) request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['writeonly'])
response = object_permissions_view(request, pk='1') response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) assert response.status_code == status.HTTP_404_NOT_FOUND
def test_can_read_get_queryset_permissions(self): def test_can_read_get_queryset_permissions(self):
""" """
@ -415,7 +345,7 @@ class ObjectPermissionsIntegrationTests(TestCase):
""" """
request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly']) request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
response = get_queryset_object_permissions_view(request, pk='1') response = get_queryset_object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) assert response.status_code == status.HTTP_200_OK
# Read list # Read list
def test_django_object_permissions_filter_deprecated(self): def test_django_object_permissions_filter_deprecated(self):
@ -423,36 +353,33 @@ class ObjectPermissionsIntegrationTests(TestCase):
warnings.simplefilter("always") warnings.simplefilter("always")
DjangoObjectPermissionsFilter() DjangoObjectPermissionsFilter()
message = ("`DjangoObjectPermissionsFilter` has been deprecated and moved " message = ("`DjangoObjectPermissionsFilter` has been deprecated and moved ""to the 3rd-party django-rest-framework-guardian package.")
"to the 3rd-party django-rest-framework-guardian package.") assertlen(w)==1
self.assertEqual(len(w), 1) assertw[-1].categoryisRemovedInDRF310Warning
self.assertIs(w[-1].category, RemovedInDRF310Warning) assertstr(w[-1].message)==message
self.assertEqual(str(w[-1].message), message)
def test_can_read_list_permissions(self): def test_can_read_list_permissions(self):
request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['readonly']) request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['readonly'])
object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,) object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,)
# TODO: remove in version 3.10
with warnings.catch_warnings(record=True): with warnings.catch_warnings(record=True):
warnings.simplefilter("always") warnings.simplefilter("always")
response = object_permissions_list_view(request) response = object_permissions_list_view(request)
self.assertEqual(response.status_code, status.HTTP_200_OK) assert response.status_code== status.HTTP_200_OK
self.assertEqual(response.data[0].get('id'), 1) assertresponse.data[0].get('id')==1
def test_cannot_read_list_permissions(self): def test_cannot_read_list_permissions(self):
request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['writeonly']) request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['writeonly'])
object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,) object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,)
# TODO: remove in version 3.10
with warnings.catch_warnings(record=True): with warnings.catch_warnings(record=True):
warnings.simplefilter("always") warnings.simplefilter("always")
response = object_permissions_list_view(request) response = object_permissions_list_view(request)
self.assertEqual(response.status_code, status.HTTP_200_OK) assert response.status_code== status.HTTP_200_OK
self.assertListEqual(response.data, []) assertresponse.data==[]
def test_cannot_method_not_allowed(self): def test_cannot_method_not_allowed(self):
request = factory.generic('METHOD_NOT_ALLOWED', '/', HTTP_AUTHORIZATION=self.credentials['readonly']) request = factory.generic('METHOD_NOT_ALLOWED', '/', HTTP_AUTHORIZATION=self.credentials['readonly'])
response = object_permissions_list_view(request) response = object_permissions_list_view(request)
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
class BasicPerm(permissions.BasePermission): class BasicPerm(permissions.BasePermission):
@ -507,9 +434,6 @@ denied_view_with_detail = DeniedViewWithDetail.as_view()
denied_object_view = DeniedObjectView.as_view() denied_object_view = DeniedObjectView.as_view()
denied_object_view_with_detail = DeniedObjectViewWithDetail.as_view() denied_object_view_with_detail = DeniedObjectViewWithDetail.as_view()
class CustomPermissionsTests(TestCase):
def setUp(self): def setUp(self):
BasicModel(text='foo').save() BasicModel(text='foo').save()
User.objects.create_user('username', 'username@example.com', 'password') User.objects.create_user('username', 'username@example.com', 'password')
@ -520,39 +444,34 @@ class CustomPermissionsTests(TestCase):
def test_permission_denied(self): def test_permission_denied(self):
response = denied_view(self.request, pk=1) response = denied_view(self.request, pk=1)
detail = response.data.get('detail') detail = response.data.get('detail')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) assert response.status_code == status.HTTP_403_FORBIDDEN
self.assertNotEqual(detail, self.custom_message) assert detail != self.custom_message
def test_permission_denied_with_custom_detail(self): def test_permission_denied_with_custom_detail(self):
response = denied_view_with_detail(self.request, pk=1) response = denied_view_with_detail(self.request, pk=1)
detail = response.data.get('detail') detail = response.data.get('detail')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) assert response.status_code == status.HTTP_403_FORBIDDEN
self.assertEqual(detail, self.custom_message) assert detail == self.custom_message
def test_permission_denied_for_object(self): def test_permission_denied_for_object(self):
response = denied_object_view(self.request, pk=1) response = denied_object_view(self.request, pk=1)
detail = response.data.get('detail') detail = response.data.get('detail')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) assert response.status_code == status.HTTP_403_FORBIDDEN
self.assertNotEqual(detail, self.custom_message) assert detail != self.custom_message
def test_permission_denied_for_object_with_custom_detail(self): def test_permission_denied_for_object_with_custom_detail(self):
response = denied_object_view_with_detail(self.request, pk=1) response = denied_object_view_with_detail(self.request, pk=1)
detail = response.data.get('detail') detail = response.data.get('detail')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) assert response.status_code == status.HTTP_403_FORBIDDEN
self.assertEqual(detail, self.custom_message) assert detail == self.custom_message
class PermissionsCompositionTests(TestCase):
def setUp(self): def setUp(self):
self.username = 'john' self.username = 'john'
self.email = 'lennon@thebeatles.com' self.email = 'lennon@thebeatles.com'
self.password = 'password' self.password = 'password'
self.user = User.objects.create_user( self.user = User.objects.create_user( self.username, self.email, self.password )
self.username,
self.email,
self.password
)
self.client.login(username=self.username, password=self.password) self.client.login(username=self.username, password=self.password)
def test_and_false(self): def test_and_false(self):
@ -594,46 +513,30 @@ class PermissionsCompositionTests(TestCase):
def test_several_levels_without_negation(self): def test_several_levels_without_negation(self):
request = factory.get('/1', format='json') request = factory.get('/1', format='json')
request.user = self.user request.user = self.user
composed_perm = ( composed_perm = ( permissions.IsAuthenticated & permissions.IsAuthenticated & permissions.IsAuthenticated & permissions.IsAuthenticated )
permissions.IsAuthenticated &
permissions.IsAuthenticated &
permissions.IsAuthenticated &
permissions.IsAuthenticated
)
assert composed_perm().has_permission(request, None) is True assert composed_perm().has_permission(request, None) is True
def test_several_levels_and_precedence_with_negation(self): def test_several_levels_and_precedence_with_negation(self):
request = factory.get('/1', format='json') request = factory.get('/1', format='json')
request.user = self.user request.user = self.user
composed_perm = ( composed_perm = ( permissions.IsAuthenticated & ~ permissions.IsAdminUser & permissions.IsAuthenticated & ~(permissions.IsAdminUser & permissions.IsAdminUser) )
permissions.IsAuthenticated &
~ permissions.IsAdminUser &
permissions.IsAuthenticated &
~(permissions.IsAdminUser & permissions.IsAdminUser)
)
assert composed_perm().has_permission(request, None) is True assert composed_perm().has_permission(request, None) is True
def test_several_levels_and_precedence(self): def test_several_levels_and_precedence(self):
request = factory.get('/1', format='json') request = factory.get('/1', format='json')
request.user = self.user request.user = self.user
composed_perm = ( composed_perm = ( permissions.IsAuthenticated & permissions.IsAuthenticated | permissions.IsAuthenticated & permissions.IsAuthenticated )
permissions.IsAuthenticated &
permissions.IsAuthenticated |
permissions.IsAuthenticated &
permissions.IsAuthenticated
)
assert composed_perm().has_permission(request, None) is True assert composed_perm().has_permission(request, None) is True
@pytest.mark.skipif(not PY36, reason="assert_called_once() not available") @pytest.mark.skipif(not PY36, reason="assert_called_once() not available")
deftest_or_lazyness(self): deftest_or_lazyness(self):
request = factory.get('/1', format='json') request = factory.get('/1', format='json')
request.user = AnonymousUser() request.user = AnonymousUser()
with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow: with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny: with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
composed_perm = (permissions.AllowAny | permissions.IsAuthenticated) composed_perm = (permissions.AllowAny | permissions.IsAuthenticated)
hasperm = composed_perm().has_permission(request, None) hasperm = composed_perm().has_permission(request, None)
self.assertIs(hasperm, True) assert hasperm is True
mock_allow.assert_called_once() mock_allow.assert_called_once()
mock_deny.assert_not_called() mock_deny.assert_not_called()
@ -641,7 +544,7 @@ class PermissionsCompositionTests(TestCase):
with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny: with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
composed_perm = (permissions.IsAuthenticated | permissions.AllowAny) composed_perm = (permissions.IsAuthenticated | permissions.AllowAny)
hasperm = composed_perm().has_permission(request, None) hasperm = composed_perm().has_permission(request, None)
self.assertIs(hasperm, True) assert hasperm is True
mock_deny.assert_called_once() mock_deny.assert_called_once()
mock_allow.assert_called_once() mock_allow.assert_called_once()
@ -649,12 +552,11 @@ class PermissionsCompositionTests(TestCase):
deftest_object_or_lazyness(self): deftest_object_or_lazyness(self):
request = factory.get('/1', format='json') request = factory.get('/1', format='json')
request.user = AnonymousUser() request.user = AnonymousUser()
with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow: with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny: with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
composed_perm = (permissions.AllowAny | permissions.IsAuthenticated) composed_perm = (permissions.AllowAny | permissions.IsAuthenticated)
hasperm = composed_perm().has_object_permission(request, None, None) hasperm = composed_perm().has_object_permission(request, None, None)
self.assertIs(hasperm, True) assert hasperm is True
mock_allow.assert_called_once() mock_allow.assert_called_once()
mock_deny.assert_not_called() mock_deny.assert_not_called()
@ -662,7 +564,7 @@ class PermissionsCompositionTests(TestCase):
with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny: with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
composed_perm = (permissions.IsAuthenticated | permissions.AllowAny) composed_perm = (permissions.IsAuthenticated | permissions.AllowAny)
hasperm = composed_perm().has_object_permission(request, None, None) hasperm = composed_perm().has_object_permission(request, None, None)
self.assertIs(hasperm, True) assert hasperm is True
mock_deny.assert_called_once() mock_deny.assert_called_once()
mock_allow.assert_called_once() mock_allow.assert_called_once()
@ -670,12 +572,11 @@ class PermissionsCompositionTests(TestCase):
deftest_and_lazyness(self): deftest_and_lazyness(self):
request = factory.get('/1', format='json') request = factory.get('/1', format='json')
request.user = AnonymousUser() request.user = AnonymousUser()
with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow: with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny: with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
composed_perm = (permissions.AllowAny & permissions.IsAuthenticated) composed_perm = (permissions.AllowAny & permissions.IsAuthenticated)
hasperm = composed_perm().has_permission(request, None) hasperm = composed_perm().has_permission(request, None)
self.assertIs(hasperm, False) assert hasperm is False
mock_allow.assert_called_once() mock_allow.assert_called_once()
mock_deny.assert_called_once() mock_deny.assert_called_once()
@ -683,7 +584,7 @@ class PermissionsCompositionTests(TestCase):
with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny: with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
composed_perm = (permissions.IsAuthenticated & permissions.AllowAny) composed_perm = (permissions.IsAuthenticated & permissions.AllowAny)
hasperm = composed_perm().has_permission(request, None) hasperm = composed_perm().has_permission(request, None)
self.assertIs(hasperm, False) assert hasperm is False
mock_allow.assert_not_called() mock_allow.assert_not_called()
mock_deny.assert_called_once() mock_deny.assert_called_once()
@ -691,12 +592,11 @@ class PermissionsCompositionTests(TestCase):
deftest_object_and_lazyness(self): deftest_object_and_lazyness(self):
request = factory.get('/1', format='json') request = factory.get('/1', format='json')
request.user = AnonymousUser() request.user = AnonymousUser()
with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow: with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny: with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
composed_perm = (permissions.AllowAny & permissions.IsAuthenticated) composed_perm = (permissions.AllowAny & permissions.IsAuthenticated)
hasperm = composed_perm().has_object_permission(request, None, None) hasperm = composed_perm().has_object_permission(request, None, None)
self.assertIs(hasperm, False) assert hasperm is False
mock_allow.assert_called_once() mock_allow.assert_called_once()
mock_deny.assert_called_once() mock_deny.assert_called_once()
@ -704,6 +604,6 @@ class PermissionsCompositionTests(TestCase):
with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny: with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
composed_perm = (permissions.IsAuthenticated & permissions.AllowAny) composed_perm = (permissions.IsAuthenticated & permissions.AllowAny)
hasperm = composed_perm().has_object_permission(request, None, None) hasperm = composed_perm().has_object_permission(request, None, None)
self.assertIs(hasperm, False) assert hasperm is False
mock_allow.assert_not_called() mock_allow.assert_not_called()
mock_deny.assert_called_once() mock_deny.assert_called_once()

View File

@ -18,7 +18,6 @@ class UserUpdate(generics.UpdateAPIView):
serializer_class = UserSerializer serializer_class = UserSerializer
class TestPrefetchRelatedUpdates(TestCase):
def setUp(self): def setUp(self):
self.user = User.objects.create(username='tom', email='tom@example.com') self.user = User.objects.create(username='tom', email='tom@example.com')
self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')] self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')]
@ -31,12 +30,7 @@ class TestPrefetchRelatedUpdates(TestCase):
request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json') request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json')
response = view(request, pk=pk) response = view(request, pk=pk)
assert User.objects.get(pk=pk).groups.count() == 1 assert User.objects.get(pk=pk).groups.count() == 1
expected = { expected = { 'id': pk, 'username': 'new', 'groups': [1], 'email': 'tom@example.com' }
'id': pk,
'username': 'new',
'groups': [1],
'email': 'tom@example.com'
}
assert response.data == expected assert response.data == expected
def test_prefetch_related_excluding_instance_from_original_queryset(self): def test_prefetch_related_excluding_instance_from_original_queryset(self):
@ -49,10 +43,5 @@ class TestPrefetchRelatedUpdates(TestCase):
request = factory.put('/', {'username': 'exclude', 'groups': [groups_pk]}, format='json') request = factory.put('/', {'username': 'exclude', 'groups': [groups_pk]}, format='json')
response = view(request, pk=pk) response = view(request, pk=pk)
assert User.objects.get(pk=pk).groups.count() == 1 assert User.objects.get(pk=pk).groups.count() == 1
expected = { expected = { 'id': pk, 'username': 'exclude', 'groups': [1], 'email': 'tom@example.com' }
'id': pk,
'username': 'exclude',
'groups': [1],
'email': 'tom@example.com'
}
assert response.data == expected assert response.data == expected

View File

@ -70,7 +70,6 @@ class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer):
# TODO: Add test that .data cannot be accessed prior to .is_valid # TODO: Add test that .data cannot be accessed prior to .is_valid
@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink') @override_settings(ROOT_URLCONF='tests.test_relations_hyperlink')
class HyperlinkedManyToManyTests(TestCase):
def setUp(self): def setUp(self):
for idx in range(1, 4): for idx in range(1, 4):
target = ManyToManyTarget(name='target-%d' % idx) target = ManyToManyTarget(name='target-%d' % idx)
@ -83,22 +82,14 @@ class HyperlinkedManyToManyTests(TestCase):
def test_relative_hyperlinks(self): def test_relative_hyperlinks(self):
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': None}) serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': None})
expected = [ expected = [ {'url': '/manytomanysource/1/', 'name': 'source-1', 'targets': ['/manytomanytarget/1/']}, {'url': '/manytomanysource/2/', 'name': 'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']}, {'url': '/manytomanysource/3/', 'name': 'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']} ]
{'url': '/manytomanysource/1/', 'name': 'source-1', 'targets': ['/manytomanytarget/1/']},
{'url': '/manytomanysource/2/', 'name': 'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']},
{'url': '/manytomanysource/3/', 'name': 'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']}
]
with self.assertNumQueries(4): with self.assertNumQueries(4):
assert serializer.data == expected assert serializer.data == expected
def test_many_to_many_retrieve(self): def test_many_to_many_retrieve(self):
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']}, {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} ]
{'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']},
{'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
{'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
]
with self.assertNumQueries(4): with self.assertNumQueries(4):
assert serializer.data == expected assert serializer.data == expected
@ -111,11 +102,7 @@ class HyperlinkedManyToManyTests(TestCase):
def test_reverse_many_to_many_retrieve(self): def test_reverse_many_to_many_retrieve(self):
queryset = ManyToManyTarget.objects.all() queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']} ]
{'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}
]
with self.assertNumQueries(4): with self.assertNumQueries(4):
assert serializer.data == expected assert serializer.data == expected
@ -126,15 +113,9 @@ class HyperlinkedManyToManyTests(TestCase):
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == data assert serializer.data == data
# Ensure source 1 is updated, and everything else is as expected
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}, {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} ]
{'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']},
{'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
{'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
]
assert serializer.data == expected assert serializer.data == expected
def test_reverse_many_to_many_update(self): def test_reverse_many_to_many_update(self):
@ -144,15 +125,9 @@ class HyperlinkedManyToManyTests(TestCase):
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == data assert serializer.data == data
# Ensure target 1 is updated, and everything else is as expected
queryset = ManyToManyTarget.objects.all() queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']}, {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']} ]
{'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']},
{'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}
]
assert serializer.data == expected assert serializer.data == expected
def test_many_to_many_create(self): def test_many_to_many_create(self):
@ -162,16 +137,9 @@ class HyperlinkedManyToManyTests(TestCase):
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'source-4' assert obj.name == 'source-4'
# Ensure source 4 is added, and everything else is as expected
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']}, {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}, {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']} ]
{'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']},
{'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
{'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']},
{'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']}
]
assert serializer.data == expected assert serializer.data == expected
def test_reverse_many_to_many_create(self): def test_reverse_many_to_many_create(self):
@ -181,21 +149,13 @@ class HyperlinkedManyToManyTests(TestCase):
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'target-4' assert obj.name == 'target-4'
# Ensure target 4 is added, and everything else is as expected
queryset = ManyToManyTarget.objects.all() queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']} ]
{'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']}
]
assert serializer.data == expected assert serializer.data == expected
@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink') @override_settings(ROOT_URLCONF='tests.test_relations_hyperlink')
class HyperlinkedForeignKeyTests(TestCase):
def setUp(self): def setUp(self):
target = ForeignKeyTarget(name='target-1') target = ForeignKeyTarget(name='target-1')
target.save() target.save()
@ -208,21 +168,14 @@ class HyperlinkedForeignKeyTests(TestCase):
def test_foreign_key_retrieve(self): def test_foreign_key_retrieve(self):
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request}) serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'} ]
{'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}
]
with self.assertNumQueries(1): with self.assertNumQueries(1):
assert serializer.data == expected assert serializer.data == expected
def test_reverse_foreign_key_retrieve(self): def test_reverse_foreign_key_retrieve(self):
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']}, {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, ]
{'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']},
{'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
]
with self.assertNumQueries(3): with self.assertNumQueries(3):
assert serializer.data == expected assert serializer.data == expected
@ -233,15 +186,9 @@ class HyperlinkedForeignKeyTests(TestCase):
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == data assert serializer.data == data
# Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request}) serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'}, {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'} ]
{'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'},
{'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_update_incorrect_type(self): def test_foreign_key_update_incorrect_type(self):
@ -256,26 +203,15 @@ class HyperlinkedForeignKeyTests(TestCase):
instance = ForeignKeyTarget.objects.get(pk=2) instance = ForeignKeyTarget.objects.get(pk=2)
serializer = ForeignKeyTargetSerializer(instance, data=data, context={'request': request}) serializer = ForeignKeyTargetSerializer(instance, data=data, context={'request': request})
assert serializer.is_valid() assert serializer.is_valid()
# We shouldn't have saved anything to the db yet since save
# hasn't been called.
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
new_serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) new_serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']}, {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, ]
{'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']},
{'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
]
assert new_serializer.data == expected assert new_serializer.data == expected
serializer.save() serializer.save()
assert serializer.data == data assert serializer.data == data
# Ensure target 2 is update, and everything else is as expected
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']}, {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}, ]
{'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']},
{'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']},
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_create(self): def test_foreign_key_create(self):
@ -285,16 +221,9 @@ class HyperlinkedForeignKeyTests(TestCase):
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'source-4' assert obj.name == 'source-4'
# Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request}) serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'}, ]
{'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'},
]
assert serializer.data == expected assert serializer.data == expected
def test_reverse_foreign_key_create(self): def test_reverse_foreign_key_create(self):
@ -304,15 +233,9 @@ class HyperlinkedForeignKeyTests(TestCase):
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'target-3' assert obj.name == 'target-3'
# Ensure target 4 is added, and everything else is as expected
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']}, {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}, ]
{'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']},
{'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
{'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']},
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_update_with_invalid_null(self): def test_foreign_key_update_with_invalid_null(self):
@ -324,7 +247,6 @@ class HyperlinkedForeignKeyTests(TestCase):
@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink') @override_settings(ROOT_URLCONF='tests.test_relations_hyperlink')
class HyperlinkedNullableForeignKeyTests(TestCase):
def setUp(self): def setUp(self):
target = ForeignKeyTarget(name='target-1') target = ForeignKeyTarget(name='target-1')
target.save() target.save()
@ -337,11 +259,7 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
def test_foreign_key_retrieve_with_null(self): def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, ]
{'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_create_with_valid_null(self): def test_foreign_key_create_with_valid_null(self):
@ -351,16 +269,9 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'source-4' assert obj.name == 'source-4'
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} ]
{'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
{'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_create_with_valid_emptystring(self): def test_foreign_key_create_with_valid_emptystring(self):
@ -375,16 +286,9 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
obj = serializer.save() obj = serializer.save()
assert serializer.data == expected_data assert serializer.data == expected_data
assert obj.name == 'source-4' assert obj.name == 'source-4'
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} ]
{'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
{'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_update_with_valid_null(self): def test_foreign_key_update_with_valid_null(self):
@ -394,15 +298,9 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == data assert serializer.data == data
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}, {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, ]
{'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None},
{'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_update_with_valid_emptystring(self): def test_foreign_key_update_with_valid_emptystring(self):
@ -417,20 +315,13 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == expected_data assert serializer.data == expected_data
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}, {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, ]
{'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None},
{'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
]
assert serializer.data == expected assert serializer.data == expected
@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink') @override_settings(ROOT_URLCONF='tests.test_relations_hyperlink')
class HyperlinkedNullableOneToOneTests(TestCase):
def setUp(self): def setUp(self):
target = OneToOneTarget(name='target-1') target = OneToOneTarget(name='target-1')
target.save() target.save()
@ -442,8 +333,5 @@ class HyperlinkedNullableOneToOneTests(TestCase):
def test_reverse_foreign_key_retrieve_with_null(self): def test_reverse_foreign_key_retrieve_with_null(self):
queryset = OneToOneTarget.objects.all() queryset = OneToOneTarget.objects.all()
serializer = NullableOneToOneTargetSerializer(queryset, many=True, context={'request': request}) serializer = NullableOneToOneTargetSerializer(queryset, many=True, context={'request': request})
expected = [ expected = [ {'url': 'http://testserver/onetoonetarget/1/', 'name': 'target-1', 'nullable_source': 'http://testserver/nullableonetoonesource/1/'}, {'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None}, ]
{'url': 'http://testserver/onetoonetarget/1/', 'name': 'target-1', 'nullable_source': 'http://testserver/nullableonetoonesource/1/'},
{'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None},
]
assert serializer.data == expected assert serializer.data == expected

View File

@ -77,7 +77,6 @@ class OneToOnePKSourceSerializer(serializers.ModelSerializer):
# TODO: Add test that .data cannot be accessed prior to .is_valid # TODO: Add test that .data cannot be accessed prior to .is_valid
class PKManyToManyTests(TestCase):
def setUp(self): def setUp(self):
for idx in range(1, 4): for idx in range(1, 4):
target = ManyToManyTarget(name='target-%d' % idx) target = ManyToManyTarget(name='target-%d' % idx)
@ -90,11 +89,7 @@ class PKManyToManyTests(TestCase):
def test_many_to_many_retrieve(self): def test_many_to_many_retrieve(self):
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True) serializer = ManyToManySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'targets': [1]}, {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]} ]
{'id': 1, 'name': 'source-1', 'targets': [1]},
{'id': 2, 'name': 'source-2', 'targets': [1, 2]},
{'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
]
with self.assertNumQueries(4): with self.assertNumQueries(4):
assert serializer.data == expected assert serializer.data == expected
@ -107,11 +102,7 @@ class PKManyToManyTests(TestCase):
def test_reverse_many_to_many_retrieve(self): def test_reverse_many_to_many_retrieve(self):
queryset = ManyToManyTarget.objects.all() queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True) serializer = ManyToManyTargetSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, {'id': 3, 'name': 'target-3', 'sources': [3]} ]
{'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
{'id': 2, 'name': 'target-2', 'sources': [2, 3]},
{'id': 3, 'name': 'target-3', 'sources': [3]}
]
with self.assertNumQueries(4): with self.assertNumQueries(4):
assert serializer.data == expected assert serializer.data == expected
@ -122,15 +113,9 @@ class PKManyToManyTests(TestCase):
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == data assert serializer.data == data
# Ensure source 1 is updated, and everything else is as expected
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True) serializer = ManyToManySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]}, {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]} ]
{'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]},
{'id': 2, 'name': 'source-2', 'targets': [1, 2]},
{'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
]
assert serializer.data == expected assert serializer.data == expected
def test_reverse_many_to_many_update(self): def test_reverse_many_to_many_update(self):
@ -140,15 +125,9 @@ class PKManyToManyTests(TestCase):
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == data assert serializer.data == data
# Ensure target 1 is updated, and everything else is as expected
queryset = ManyToManyTarget.objects.all() queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True) serializer = ManyToManyTargetSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'target-1', 'sources': [1]}, {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, {'id': 3, 'name': 'target-3', 'sources': [3]} ]
{'id': 1, 'name': 'target-1', 'sources': [1]},
{'id': 2, 'name': 'target-2', 'sources': [2, 3]},
{'id': 3, 'name': 'target-3', 'sources': [3]}
]
assert serializer.data == expected assert serializer.data == expected
def test_many_to_many_create(self): def test_many_to_many_create(self):
@ -158,25 +137,15 @@ class PKManyToManyTests(TestCase):
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'source-4' assert obj.name == 'source-4'
# Ensure source 4 is added, and everything else is as expected
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True) serializer = ManyToManySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'targets': [1]}, {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}, {'id': 4, 'name': 'source-4', 'targets': [1, 3]}, ]
{'id': 1, 'name': 'source-1', 'targets': [1]},
{'id': 2, 'name': 'source-2', 'targets': [1, 2]},
{'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]},
{'id': 4, 'name': 'source-4', 'targets': [1, 3]},
]
assert serializer.data == expected assert serializer.data == expected
def test_many_to_many_unsaved(self): def test_many_to_many_unsaved(self):
source = ManyToManySource(name='source-unsaved') source = ManyToManySource(name='source-unsaved')
serializer = ManyToManySourceSerializer(source) serializer = ManyToManySourceSerializer(source)
expected = {'id': None, 'name': 'source-unsaved', 'targets': []} expected = {'id': None, 'name': 'source-unsaved', 'targets': []}
# no query if source hasn't been created yet
with self.assertNumQueries(0): with self.assertNumQueries(0):
assert serializer.data == expected assert serializer.data == expected
@ -187,20 +156,12 @@ class PKManyToManyTests(TestCase):
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'target-4' assert obj.name == 'target-4'
# Ensure target 4 is added, and everything else is as expected
queryset = ManyToManyTarget.objects.all() queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True) serializer = ManyToManyTargetSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, {'id': 3, 'name': 'target-3', 'sources': [3]}, {'id': 4, 'name': 'target-4', 'sources': [1, 3]} ]
{'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
{'id': 2, 'name': 'target-2', 'sources': [2, 3]},
{'id': 3, 'name': 'target-3', 'sources': [3]},
{'id': 4, 'name': 'target-4', 'sources': [1, 3]}
]
assert serializer.data == expected assert serializer.data == expected
class PKForeignKeyTests(TestCase):
def setUp(self): def setUp(self):
target = ForeignKeyTarget(name='target-1') target = ForeignKeyTarget(name='target-1')
target.save() target.save()
@ -213,21 +174,14 @@ class PKForeignKeyTests(TestCase):
def test_foreign_key_retrieve(self): def test_foreign_key_retrieve(self):
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True) serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': 1} ]
{'id': 1, 'name': 'source-1', 'target': 1},
{'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': 1}
]
with self.assertNumQueries(1): with self.assertNumQueries(1):
assert serializer.data == expected assert serializer.data == expected
def test_reverse_foreign_key_retrieve(self): def test_reverse_foreign_key_retrieve(self):
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True) serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, {'id': 2, 'name': 'target-2', 'sources': []}, ]
{'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
{'id': 2, 'name': 'target-2', 'sources': []},
]
with self.assertNumQueries(3): with self.assertNumQueries(3):
assert serializer.data == expected assert serializer.data == expected
@ -244,15 +198,9 @@ class PKForeignKeyTests(TestCase):
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == data assert serializer.data == data
# Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True) serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': 2}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': 1} ]
{'id': 1, 'name': 'source-1', 'target': 2},
{'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': 1}
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_update_incorrect_type(self): def test_foreign_key_update_incorrect_type(self):
@ -267,26 +215,15 @@ class PKForeignKeyTests(TestCase):
instance = ForeignKeyTarget.objects.get(pk=2) instance = ForeignKeyTarget.objects.get(pk=2)
serializer = ForeignKeyTargetSerializer(instance, data=data) serializer = ForeignKeyTargetSerializer(instance, data=data)
assert serializer.is_valid() assert serializer.is_valid()
# We shouldn't have saved anything to the db yet since save
# hasn't been called.
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
new_serializer = ForeignKeyTargetSerializer(queryset, many=True) new_serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, {'id': 2, 'name': 'target-2', 'sources': []}, ]
{'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
{'id': 2, 'name': 'target-2', 'sources': []},
]
assert new_serializer.data == expected assert new_serializer.data == expected
serializer.save() serializer.save()
assert serializer.data == data assert serializer.data == data
# Ensure target 2 is update, and everything else is as expected
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True) serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'target-1', 'sources': [2]}, {'id': 2, 'name': 'target-2', 'sources': [1, 3]}, ]
{'id': 1, 'name': 'target-1', 'sources': [2]},
{'id': 2, 'name': 'target-2', 'sources': [1, 3]},
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_create(self): def test_foreign_key_create(self):
@ -296,16 +233,9 @@ class PKForeignKeyTests(TestCase):
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'source-4' assert obj.name == 'source-4'
# Ensure source 4 is added, and everything else is as expected
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True) serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': 1}, {'id': 4, 'name': 'source-4', 'target': 2}, ]
{'id': 1, 'name': 'source-1', 'target': 1},
{'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': 1},
{'id': 4, 'name': 'source-4', 'target': 2},
]
assert serializer.data == expected assert serializer.data == expected
def test_reverse_foreign_key_create(self): def test_reverse_foreign_key_create(self):
@ -315,15 +245,9 @@ class PKForeignKeyTests(TestCase):
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'target-3' assert obj.name == 'target-3'
# Ensure target 3 is added, and everything else is as expected
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True) serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'target-1', 'sources': [2]}, {'id': 2, 'name': 'target-2', 'sources': []}, {'id': 3, 'name': 'target-3', 'sources': [1, 3]}, ]
{'id': 1, 'name': 'target-1', 'sources': [2]},
{'id': 2, 'name': 'target-2', 'sources': []},
{'id': 3, 'name': 'target-3', 'sources': [1, 3]},
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_update_with_invalid_null(self): def test_foreign_key_update_with_invalid_null(self):
@ -336,10 +260,7 @@ class PKForeignKeyTests(TestCase):
def test_foreign_key_with_unsaved(self): def test_foreign_key_with_unsaved(self):
source = ForeignKeySource(name='source-unsaved') source = ForeignKeySource(name='source-unsaved')
expected = {'id': None, 'name': 'source-unsaved', 'target': None} expected = {'id': None, 'name': 'source-unsaved', 'target': None}
serializer = ForeignKeySourceSerializer(source) serializer = ForeignKeySourceSerializer(source)
# no query if source hasn't been created yet
with self.assertNumQueries(0): with self.assertNumQueries(0):
assert serializer.data == expected assert serializer.data == expected
@ -379,7 +300,6 @@ class PKForeignKeyTests(TestCase):
def test_queryset_size_with_Q_limited_choices(self): def test_queryset_size_with_Q_limited_choices(self):
limited_target = ForeignKeyTarget(name="limited-target") limited_target = ForeignKeyTarget(name="limited-target")
limited_target.save() limited_target.save()
class QLimitedChoicesSerializer(serializers.ModelSerializer): class QLimitedChoicesSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = ForeignKeySourceWithQLimitedChoices model = ForeignKeySourceWithQLimitedChoices
@ -389,7 +309,6 @@ class PKForeignKeyTests(TestCase):
assertlen(queryset)==1 assertlen(queryset)==1
class PKNullableForeignKeyTests(TestCase):
def setUp(self): def setUp(self):
target = ForeignKeyTarget(name='target-1') target = ForeignKeyTarget(name='target-1')
target.save() target.save()
@ -402,11 +321,7 @@ class PKNullableForeignKeyTests(TestCase):
def test_foreign_key_retrieve_with_null(self): def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': None}, ]
{'id': 1, 'name': 'source-1', 'target': 1},
{'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': None},
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_create_with_valid_null(self): def test_foreign_key_create_with_valid_null(self):
@ -416,16 +331,9 @@ class PKNullableForeignKeyTests(TestCase):
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'source-4' assert obj.name == 'source-4'
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': None}, {'id': 4, 'name': 'source-4', 'target': None} ]
{'id': 1, 'name': 'source-1', 'target': 1},
{'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': None},
{'id': 4, 'name': 'source-4', 'target': None}
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_create_with_valid_emptystring(self): def test_foreign_key_create_with_valid_emptystring(self):
@ -440,16 +348,9 @@ class PKNullableForeignKeyTests(TestCase):
obj = serializer.save() obj = serializer.save()
assert serializer.data == expected_data assert serializer.data == expected_data
assert obj.name == 'source-4' assert obj.name == 'source-4'
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': None}, {'id': 4, 'name': 'source-4', 'target': None} ]
{'id': 1, 'name': 'source-1', 'target': 1},
{'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': None},
{'id': 4, 'name': 'source-4', 'target': None}
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_update_with_valid_null(self): def test_foreign_key_update_with_valid_null(self):
@ -459,15 +360,9 @@ class PKNullableForeignKeyTests(TestCase):
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == data assert serializer.data == data
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': None}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': None} ]
{'id': 1, 'name': 'source-1', 'target': None},
{'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': None}
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_update_with_valid_emptystring(self): def test_foreign_key_update_with_valid_emptystring(self):
@ -482,15 +377,9 @@ class PKNullableForeignKeyTests(TestCase):
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == expected_data assert serializer.data == expected_data
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': None}, {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': None} ]
{'id': 1, 'name': 'source-1', 'target': None},
{'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': 'source-3', 'target': None}
]
assert serializer.data == expected assert serializer.data == expected
def test_null_uuid_foreign_key_serializes_as_none(self): def test_null_uuid_foreign_key_serializes_as_none(self):
@ -505,7 +394,6 @@ class PKNullableForeignKeyTests(TestCase):
assert serializer.is_valid(), serializer.errors assert serializer.is_valid(), serializer.errors
class PKNullableOneToOneTests(TestCase):
def setUp(self): def setUp(self):
target = OneToOneTarget(name='target-1') target = OneToOneTarget(name='target-1')
target.save() target.save()
@ -517,14 +405,10 @@ class PKNullableOneToOneTests(TestCase):
def test_reverse_foreign_key_retrieve_with_null(self): def test_reverse_foreign_key_retrieve_with_null(self):
queryset = OneToOneTarget.objects.all() queryset = OneToOneTarget.objects.all()
serializer = NullableOneToOneTargetSerializer(queryset, many=True) serializer = NullableOneToOneTargetSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'target-1', 'nullable_source': None}, {'id': 2, 'name': 'target-2', 'nullable_source': 1}, ]
{'id': 1, 'name': 'target-1', 'nullable_source': None},
{'id': 2, 'name': 'target-2', 'nullable_source': 1},
]
assert serializer.data == expected assert serializer.data == expected
class OneToOnePrimaryKeyTests(TestCase):
def setUp(self): def setUp(self):
# Given: Some target models already exist # Given: Some target models already exist
@ -537,36 +421,29 @@ class OneToOnePrimaryKeyTests(TestCase):
# When: Creating a Source pointing at the id of the second Target # When: Creating a Source pointing at the id of the second Target
target_pk = self.alt_target.id target_pk = self.alt_target.id
source = OneToOnePKSourceSerializer(data={'name': 'source-2', 'target': target_pk}) source = OneToOnePKSourceSerializer(data={'name': 'source-2', 'target': target_pk})
# Then: The source is valid with the serializer
if not source.is_valid(): if not source.is_valid():
self.fail("Expected OneToOnePKTargetSerializer to be valid but had errors: {}".format(source.errors)) self.fail("Expected OneToOnePKTargetSerializer to be valid but had errors: {}".format(source.errors))
# Then: Saving the serializer creates a new object # Then: Saving the serializer creates a new object
new_source = source.save() new_source = source.save()
# Then: The new object has the same pk as the target object assertnew_source.pk==target_pk
self.assertEqual(new_source.pk, target_pk)
def test_one_to_one_when_primary_key_no_duplicates(self): def test_one_to_one_when_primary_key_no_duplicates(self):
# When: Creating a Source pointing at the id of the second Target # When: Creating a Source pointing at the id of the second Target
target_pk = self.target.id target_pk = self.target.id
data = {'name': 'source-1', 'target': target_pk} data = {'name': 'source-1', 'target': target_pk}
source = OneToOnePKSourceSerializer(data=data) source = OneToOnePKSourceSerializer(data=data)
# Then: The source is valid with the serializer assert source.is_valid()
self.assertTrue(source.is_valid())
# Then: Saving the serializer creates a new object
new_source = source.save() new_source = source.save()
# Then: The new object has the same pk as the target object assert new_source.pk == target_pk
self.assertEqual(new_source.pk, target_pk)
# When: Trying to create a second object
second_source = OneToOnePKSourceSerializer(data=data) second_source = OneToOnePKSourceSerializer(data=data)
self.assertFalse(second_source.is_valid()) assert not second_source.is_valid()
expected = {'target': ['one to one pk source with this target already exists.']} expected = {'target': ['one to one pk source with this target already exists.']}
self.assertDictEqual(second_source.errors, expected) assert second_source.errors == expected
def test_one_to_one_when_primary_key_does_not_exist(self): def test_one_to_one_when_primary_key_does_not_exist(self):
# Given: a target PK that does not exist # Given: a target PK that does not exist
target_pk = self.target.pk + self.alt_target.pk target_pk = self.target.pk + self.alt_target.pk
source = OneToOnePKSourceSerializer(data={'name': 'source-2', 'target': target_pk}) source = OneToOnePKSourceSerializer(data={'name': 'source-2', 'target': target_pk})
# Then: The source is not valid with the serializer assert not source.is_valid()
self.assertFalse(source.is_valid()) assert "Invalid pk" in source.errors['target'][0]
self.assertIn("Invalid pk", source.errors['target'][0]) assert "object does not exist" in source.errors['target'][0]
self.assertIn("object does not exist", source.errors['target'][0])

View File

@ -42,7 +42,6 @@ class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
# TODO: M2M Tests, FKTests (Non-nullable), One2One # TODO: M2M Tests, FKTests (Non-nullable), One2One
class SlugForeignKeyTests(TestCase):
def setUp(self): def setUp(self):
target = ForeignKeyTarget(name='target-1') target = ForeignKeyTarget(name='target-1')
target.save() target.save()
@ -55,11 +54,7 @@ class SlugForeignKeyTests(TestCase):
def test_foreign_key_retrieve(self): def test_foreign_key_retrieve(self):
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True) serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-1'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': 'target-1'} ]
{'id': 1, 'name': 'source-1', 'target': 'target-1'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': 'target-1'}
]
with self.assertNumQueries(4): with self.assertNumQueries(4):
assert serializer.data == expected assert serializer.data == expected
@ -72,10 +67,7 @@ class SlugForeignKeyTests(TestCase):
def test_reverse_foreign_key_retrieve(self): def test_reverse_foreign_key_retrieve(self):
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True) serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, {'id': 2, 'name': 'target-2', 'sources': []}, ]
{'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
{'id': 2, 'name': 'target-2', 'sources': []},
]
assert serializer.data == expected assert serializer.data == expected
def test_reverse_foreign_key_retrieve_prefetch_related(self): def test_reverse_foreign_key_retrieve_prefetch_related(self):
@ -91,15 +83,9 @@ class SlugForeignKeyTests(TestCase):
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == data assert serializer.data == data
# Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True) serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-2'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': 'target-1'} ]
{'id': 1, 'name': 'source-1', 'target': 'target-2'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': 'target-1'}
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_update_incorrect_type(self): def test_foreign_key_update_incorrect_type(self):
@ -114,26 +100,15 @@ class SlugForeignKeyTests(TestCase):
instance = ForeignKeyTarget.objects.get(pk=2) instance = ForeignKeyTarget.objects.get(pk=2)
serializer = ForeignKeyTargetSerializer(instance, data=data) serializer = ForeignKeyTargetSerializer(instance, data=data)
assert serializer.is_valid() assert serializer.is_valid()
# We shouldn't have saved anything to the db yet since save
# hasn't been called.
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
new_serializer = ForeignKeyTargetSerializer(queryset, many=True) new_serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, {'id': 2, 'name': 'target-2', 'sources': []}, ]
{'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
{'id': 2, 'name': 'target-2', 'sources': []},
]
assert new_serializer.data == expected assert new_serializer.data == expected
serializer.save() serializer.save()
assert serializer.data == data assert serializer.data == data
# Ensure target 2 is update, and everything else is as expected
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True) serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'target-1', 'sources': ['source-2']}, {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']}, ]
{'id': 1, 'name': 'target-1', 'sources': ['source-2']},
{'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']},
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_create(self): def test_foreign_key_create(self):
@ -144,16 +119,9 @@ class SlugForeignKeyTests(TestCase):
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'source-4' assert obj.name == 'source-4'
# Ensure source 4 is added, and everything else is as expected
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True) serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-1'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': 'target-1'}, {'id': 4, 'name': 'source-4', 'target': 'target-2'}, ]
{'id': 1, 'name': 'source-1', 'target': 'target-1'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': 'target-1'},
{'id': 4, 'name': 'source-4', 'target': 'target-2'},
]
assert serializer.data == expected assert serializer.data == expected
def test_reverse_foreign_key_create(self): def test_reverse_foreign_key_create(self):
@ -163,15 +131,9 @@ class SlugForeignKeyTests(TestCase):
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'target-3' assert obj.name == 'target-3'
# Ensure target 3 is added, and everything else is as expected
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True) serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'target-1', 'sources': ['source-2']}, {'id': 2, 'name': 'target-2', 'sources': []}, {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']}, ]
{'id': 1, 'name': 'target-1', 'sources': ['source-2']},
{'id': 2, 'name': 'target-2', 'sources': []},
{'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']},
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_update_with_invalid_null(self): def test_foreign_key_update_with_invalid_null(self):
@ -182,7 +144,6 @@ class SlugForeignKeyTests(TestCase):
assert serializer.errors == {'target': ['This field may not be null.']} assert serializer.errors == {'target': ['This field may not be null.']}
class SlugNullableForeignKeyTests(TestCase):
def setUp(self): def setUp(self):
target = ForeignKeyTarget(name='target-1') target = ForeignKeyTarget(name='target-1')
target.save() target.save()
@ -195,11 +156,7 @@ class SlugNullableForeignKeyTests(TestCase):
def test_foreign_key_retrieve_with_null(self): def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-1'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': None}, ]
{'id': 1, 'name': 'source-1', 'target': 'target-1'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': None},
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_create_with_valid_null(self): def test_foreign_key_create_with_valid_null(self):
@ -209,16 +166,9 @@ class SlugNullableForeignKeyTests(TestCase):
obj = serializer.save() obj = serializer.save()
assert serializer.data == data assert serializer.data == data
assert obj.name == 'source-4' assert obj.name == 'source-4'
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-1'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': None}, {'id': 4, 'name': 'source-4', 'target': None} ]
{'id': 1, 'name': 'source-1', 'target': 'target-1'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': None},
{'id': 4, 'name': 'source-4', 'target': None}
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_create_with_valid_emptystring(self): def test_foreign_key_create_with_valid_emptystring(self):
@ -233,16 +183,9 @@ class SlugNullableForeignKeyTests(TestCase):
obj = serializer.save() obj = serializer.save()
assert serializer.data == expected_data assert serializer.data == expected_data
assert obj.name == 'source-4' assert obj.name == 'source-4'
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': 'target-1'}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': None}, {'id': 4, 'name': 'source-4', 'target': None} ]
{'id': 1, 'name': 'source-1', 'target': 'target-1'},
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': None},
{'id': 4, 'name': 'source-4', 'target': None}
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_update_with_valid_null(self): def test_foreign_key_update_with_valid_null(self):
@ -252,15 +195,9 @@ class SlugNullableForeignKeyTests(TestCase):
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == data assert serializer.data == data
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': None}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': None} ]
{'id': 1, 'name': 'source-1', 'target': None},
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': None}
]
assert serializer.data == expected assert serializer.data == expected
def test_foreign_key_update_with_valid_emptystring(self): def test_foreign_key_update_with_valid_emptystring(self):
@ -275,13 +212,7 @@ class SlugNullableForeignKeyTests(TestCase):
assert serializer.is_valid() assert serializer.is_valid()
serializer.save() serializer.save()
assert serializer.data == expected_data assert serializer.data == expected_data
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True) serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [ expected = [ {'id': 1, 'name': 'source-1', 'target': None}, {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': None} ]
{'id': 1, 'name': 'source-1', 'target': None},
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': None}
]
assert serializer.data == expected assert serializer.data == expected

View File

@ -49,11 +49,10 @@ class DummyTestModel(models.Model):
name = models.CharField(max_length=42, default='') name = models.CharField(max_length=42, default='')
class BasicRendererTests(TestCase):
def test_expected_results(self): def test_expected_results(self):
for value, renderer_cls, expected in expected_results: for value, renderer_cls, expected in expected_results:
output = renderer_cls().render(value) output = renderer_cls().render(value)
self.assertEqual(output, expected) assert output == expected
class RendererA(BaseRenderer): class RendererA(BaseRenderer):
@ -144,7 +143,6 @@ class POSTDeniedView(APIView):
return Response() return Response()
class DocumentingRendererTests(TestCase):
def test_only_permitted_forms_are_displayed(self): def test_only_permitted_forms_are_displayed(self):
view = POSTDeniedView.as_view() view = POSTDeniedView.as_view()
request = APIRequestFactory().get('/') request = APIRequestFactory().get('/')
@ -155,90 +153,82 @@ class DocumentingRendererTests(TestCase):
@override_settings(ROOT_URLCONF='tests.test_renderers') @override_settings(ROOT_URLCONF='tests.test_renderers')
class RendererEndToEndTests(TestCase):
""" """
End-to-end testing of renderers using an RendererMixin on a generic view. End-to-end testing of renderers using an RendererMixin on a generic view.
""" """
deftest_default_renderer_serializes_content(self): deftest_default_renderer_serializes_content(self):
"""If the Accept header is not set the default renderer should serialize the response.""" """If the Accept header is not set the default renderer should serialize the response."""
resp = self.client.get('/') resp = self.client.get('/')
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
def test_head_method_serializes_no_content(self): def test_head_method_serializes_no_content(self):
"""No response must be included in HEAD requests.""" """No response must be included in HEAD requests."""
resp = self.client.head('/') resp = self.client.head('/')
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
self.assertEqual(resp.content, b'') assert resp.content == b''
def test_default_renderer_serializes_content_on_accept_any(self): def test_default_renderer_serializes_content_on_accept_any(self):
"""If the Accept header is set to */* the default renderer should serialize the response.""" """If the Accept header is set to */* the default renderer should serialize the response."""
resp = self.client.get('/', HTTP_ACCEPT='*/*') resp = self.client.get('/', HTTP_ACCEPT='*/*')
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_serializes_content_default_case(self): def test_specified_renderer_serializes_content_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response. """If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for the default renderer)""" (In this case we check that works for the default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type) resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_serializes_content_non_default_case(self): def test_specified_renderer_serializes_content_non_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response. """If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for a non-default renderer)""" (In this case we check that works for a non-default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type) resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
def test_unsatisfiable_accept_header_on_request_returns_406_status(self): def test_unsatisfiable_accept_header_on_request_returns_406_status(self):
"""If the Accept header is unsatisfiable we should return a 406 Not Acceptable response.""" """If the Accept header is unsatisfiable we should return a 406 Not Acceptable response."""
resp = self.client.get('/', HTTP_ACCEPT='foo/bar') resp = self.client.get('/', HTTP_ACCEPT='foo/bar')
self.assertEqual(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE) assert resp.status_code == status.HTTP_406_NOT_ACCEPTABLE
def test_specified_renderer_serializes_content_on_format_query(self): def test_specified_renderer_serializes_content_on_format_query(self):
"""If a 'format' query is specified, the renderer with the matching """If a 'format' query is specified, the renderer with the matching
format attribute should serialize the response.""" format attribute should serialize the response."""
param = '?%s=%s' % ( param = '?%s=%s' % ( api_settings.URL_FORMAT_OVERRIDE, RendererB.format )
api_settings.URL_FORMAT_OVERRIDE,
RendererB.format
)
resp = self.client.get('/' + param) resp = self.client.get('/' + param)
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_serializes_content_on_format_kwargs(self): def test_specified_renderer_serializes_content_on_format_kwargs(self):
"""If a 'format' keyword arg is specified, the renderer with the matching """If a 'format' keyword arg is specified, the renderer with the matching
format attribute should serialize the response.""" format attribute should serialize the response."""
resp = self.client.get('/something.formatb') resp = self.client.get('/something.formatb')
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_is_used_on_format_query_with_matching_accept(self): def test_specified_renderer_is_used_on_format_query_with_matching_accept(self):
"""If both a 'format' query and a matching Accept header specified, """If both a 'format' query and a matching Accept header specified,
the renderer with the matching format attribute should serialize the response.""" the renderer with the matching format attribute should serialize the response."""
param = '?%s=%s' % ( param = '?%s=%s' % ( api_settings.URL_FORMAT_OVERRIDE, RendererB.format )
api_settings.URL_FORMAT_OVERRIDE, resp = self.client.get('/' + param, HTTP_ACCEPT=RendererB.media_type)
RendererB.format assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
) assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
resp = self.client.get('/' + param, assert resp.status_code == DUMMYSTATUS
HTTP_ACCEPT=RendererB.media_type)
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_parse_error_renderers_browsable_api(self): def test_parse_error_renderers_browsable_api(self):
"""Invalid data should still render the browsable API correctly.""" """Invalid data should still render the browsable API correctly."""
resp = self.client.post('/parseerror', data='foobar', content_type='application/json', HTTP_ACCEPT='text/html') resp = self.client.post('/parseerror', data='foobar', content_type='application/json', HTTP_ACCEPT='text/html')
self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') assert resp['Content-Type'] == 'text/html; charset=utf-8'
self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST) assert resp.status_code == status.HTTP_400_BAD_REQUEST
def test_204_no_content_responses_have_no_content_type_set(self): def test_204_no_content_responses_have_no_content_type_set(self):
""" """
@ -247,8 +237,8 @@ class RendererEndToEndTests(TestCase):
https://github.com/encode/django-rest-framework/issues/1196 https://github.com/encode/django-rest-framework/issues/1196
""" """
resp = self.client.get('/empty') resp = self.client.get('/empty')
self.assertEqual(resp.get('Content-Type', None), None) assert resp.get('Content-Type', None) == None
self.assertEqual(resp.status_code, status.HTTP_204_NO_CONTENT) assert resp.status_code == status.HTTP_204_NO_CONTENT
def test_contains_headers_of_api_response(self): def test_contains_headers_of_api_response(self):
""" """
@ -275,7 +265,6 @@ def strip_trailing_whitespace(content):
return re.sub(' +\n', '\n', content) return re.sub(' +\n', '\n', content)
class BaseRendererTests(TestCase):
""" """
Tests BaseRenderer Tests BaseRenderer
""" """
@ -287,31 +276,29 @@ class BaseRendererTests(TestCase):
BaseRenderer().render('test') BaseRenderer().render('test')
class JSONRendererTests(TestCase):
""" """
Tests specific to the JSON Renderer Tests specific to the JSON Renderer
""" """
deftest_render_lazy_strings(self): deftest_render_lazy_strings(self):
""" """
JSONRenderer should deal with lazy translated strings. JSONRenderer should deal with lazy translated strings.
""" """
ret = JSONRenderer().render(_('test')) ret = JSONRenderer().render(_('test'))
self.assertEqual(ret, b'"test"') assert ret == b'"test"'
def test_render_queryset_values(self): def test_render_queryset_values(self):
o = DummyTestModel.objects.create(name='dummy') o = DummyTestModel.objects.create(name='dummy')
qs = DummyTestModel.objects.values('id', 'name') qs = DummyTestModel.objects.values('id', 'name')
ret = JSONRenderer().render(qs) ret = JSONRenderer().render(qs)
data = json.loads(ret.decode()) data = json.loads(ret.decode())
self.assertEqual(data, [{'id': o.id, 'name': o.name}]) assert data == [{'id': o.id, 'name': o.name}]
def test_render_queryset_values_list(self): def test_render_queryset_values_list(self):
o = DummyTestModel.objects.create(name='dummy') o = DummyTestModel.objects.create(name='dummy')
qs = DummyTestModel.objects.values_list('id', 'name') qs = DummyTestModel.objects.values_list('id', 'name')
ret = JSONRenderer().render(qs) ret = JSONRenderer().render(qs)
data = json.loads(ret.decode()) data = json.loads(ret.decode())
self.assertEqual(data, [[o.id, o.name]]) assert data == [[o.id, o.name]]
def test_render_dict_abc_obj(self): def test_render_dict_abc_obj(self):
class Dict(MutableMapping): class Dict(MutableMapping):
@ -341,7 +328,7 @@ class JSONRendererTests(TestCase):
x[2]=3 x[2]=3
ret=JSONRenderer().render(x) ret=JSONRenderer().render(x)
data=json.loads(ret.decode()) data=json.loads(ret.decode())
self.assertEqual(data, {'key': 'string value', '2': 3}) assertdata=={'key':'string value','2':3}
def test_render_obj_with_getitem(self): def test_render_obj_with_getitem(self):
class DictLike: class DictLike:
@ -356,13 +343,11 @@ class JSONRendererTests(TestCase):
x = DictLike() x = DictLike()
x.set({'a':1,'b':'string'}) x.set({'a':1,'b':'string'})
with self.assertRaises(TypeError): withpytest.raises(TypeError):
JSONRenderer().render(x) JSONRenderer().render(x)
def test_float_strictness(self): def test_float_strictness(self):
renderer = JSONRenderer() renderer = JSONRenderer()
# Default to strict
for value in [float('inf'), float('-inf'), float('nan')]: for value in [float('inf'), float('-inf'), float('nan')]:
with pytest.raises(ValueError): with pytest.raises(ValueError):
renderer.render(value) renderer.render(value)
@ -379,8 +364,7 @@ class JSONRendererTests(TestCase):
obj = {'foo': ['bar', 'baz']} obj = {'foo': ['bar', 'baz']}
renderer = JSONRenderer() renderer = JSONRenderer()
content = renderer.render(obj, 'application/json') content = renderer.render(obj, 'application/json')
# Fix failing test case which depends on version of JSON library. assert content.decode() == _flat_repr
self.assertEqual(content.decode(), _flat_repr)
def test_with_content_type_args(self): def test_with_content_type_args(self):
""" """
@ -389,10 +373,9 @@ class JSONRendererTests(TestCase):
obj = {'foo': ['bar', 'baz']} obj = {'foo': ['bar', 'baz']}
renderer = JSONRenderer() renderer = JSONRenderer()
content = renderer.render(obj, 'application/json; indent=2') content = renderer.render(obj, 'application/json; indent=2')
self.assertEqual(strip_trailing_whitespace(content.decode()), _indented_repr) assert strip_trailing_whitespace(content.decode()) == _indented_repr
class UnicodeJSONRendererTests(TestCase):
""" """
Tests specific for the Unicode JSON Renderer Tests specific for the Unicode JSON Renderer
""" """
@ -400,7 +383,7 @@ class UnicodeJSONRendererTests(TestCase):
obj = {'countries': ['United Kingdom', 'France', 'España']} obj = {'countries': ['United Kingdom', 'France', 'España']}
renderer = JSONRenderer() renderer = JSONRenderer()
content = renderer.render(obj, 'application/json') content = renderer.render(obj, 'application/json')
self.assertEqual(content, '{"countries":["United Kingdom","France","España"]}'.encode()) assert content == '{"countries":["United Kingdom","France","España"]}'.encode()
def test_u2028_u2029(self): def test_u2028_u2029(self):
# The \u2028 and \u2029 characters should be escaped, # The \u2028 and \u2029 characters should be escaped,
@ -409,10 +392,9 @@ class UnicodeJSONRendererTests(TestCase):
obj = {'should_escape': '\u2028\u2029'} obj = {'should_escape': '\u2028\u2029'}
renderer = JSONRenderer() renderer = JSONRenderer()
content = renderer.render(obj, 'application/json') content = renderer.render(obj, 'application/json')
self.assertEqual(content, '{"should_escape":"\\u2028\\u2029"}'.encode()) assert content == '{"should_escape":"\\u2028\\u2029"}'.encode()
class AsciiJSONRendererTests(TestCase):
""" """
Tests specific for the Unicode JSON Renderer Tests specific for the Unicode JSON Renderer
""" """
@ -422,12 +404,11 @@ class AsciiJSONRendererTests(TestCase):
obj = {'countries': ['United Kingdom', 'France', 'España']} obj = {'countries': ['United Kingdom', 'France', 'España']}
renderer=AsciiJSONRenderer() renderer=AsciiJSONRenderer()
content=renderer.render(obj,'application/json') content=renderer.render(obj,'application/json')
self.assertEqual(content, '{"countries":["United Kingdom","France","Espa\\u00f1a"]}'.encode()) assertcontent=='{"countries":["United Kingdom","France","Espa\\u00f1a"]}'.encode()
# Tests for caching issue, #346 # Tests for caching issue, #346
@override_settings(ROOT_URLCONF='tests.test_renderers') @override_settings(ROOT_URLCONF='tests.test_renderers')
class CacheRenderTest(TestCase):
""" """
Tests specific to caching responses Tests specific to caching responses
""" """
@ -476,7 +457,6 @@ class TestJSONIndentationStyles:
assert renderer.render(data) == b'{"a": 1, "b": 2}' assert renderer.render(data) == b'{"a": 1, "b": 2}'
class TestHiddenFieldHTMLFormRenderer(TestCase):
def test_hidden_field_rendering(self): def test_hidden_field_rendering(self):
class TestSerializer(serializers.Serializer): class TestSerializer(serializers.Serializer):
published = serializers.HiddenField(default=True) published = serializers.HiddenField(default=True)
@ -489,7 +469,6 @@ class TestHiddenFieldHTMLFormRenderer(TestCase):
assertrendered=='' assertrendered==''
class TestHTMLFormRenderer(TestCase):
def setUp(self): def setUp(self):
class TestSerializer(serializers.Serializer): class TestSerializer(serializers.Serializer):
test_field = serializers.CharField() test_field = serializers.CharField()
@ -500,31 +479,23 @@ class TestHTMLFormRenderer(TestCase):
def test_render_with_default_args(self): def test_render_with_default_args(self):
self.serializer.is_valid() self.serializer.is_valid()
renderer = HTMLFormRenderer() renderer = HTMLFormRenderer()
result = renderer.render(self.serializer.data) result = renderer.render(self.serializer.data)
assert isinstance(result, SafeText)
self.assertIsInstance(result, SafeText)
def test_render_with_provided_args(self): def test_render_with_provided_args(self):
self.serializer.is_valid() self.serializer.is_valid()
renderer = HTMLFormRenderer() renderer = HTMLFormRenderer()
result = renderer.render(self.serializer.data, None, {}) result = renderer.render(self.serializer.data, None, {})
assert isinstance(result, SafeText)
self.assertIsInstance(result, SafeText)
class TestChoiceFieldHTMLFormRenderer(TestCase):
""" """
Test rendering ChoiceField with HTMLFormRenderer. Test rendering ChoiceField with HTMLFormRenderer.
""" """
defsetUp(self): defsetUp(self):
choices = ((1, 'Option1'), (2, 'Option2'), (12, 'Option12')) choices = ((1, 'Option1'), (2, 'Option2'), (12, 'Option12'))
class TestSerializer(serializers.Serializer): class TestSerializer(serializers.Serializer):
test_field = serializers.ChoiceField(choices=choices, test_field = serializers.ChoiceField(choices=choices, initial=2)
initial=2)
self.TestSerializer = TestSerializer self.TestSerializer = TestSerializer
self.renderer=HTMLFormRenderer() self.renderer=HTMLFormRenderer()
@ -532,76 +503,55 @@ class TestChoiceFieldHTMLFormRenderer(TestCase):
def test_render_initial_option(self): def test_render_initial_option(self):
serializer = self.TestSerializer() serializer = self.TestSerializer()
result = self.renderer.render(serializer.data) result = self.renderer.render(serializer.data)
assert isinstance(result, SafeText)
self.assertIsInstance(result, SafeText) self.assertInHTML('<option value="2" selected>Option2</option>', result)
self.assertInHTML('<option value="2" selected>Option2</option>',
result)
self.assertInHTML('<option value="1">Option1</option>', result) self.assertInHTML('<option value="1">Option1</option>', result)
self.assertInHTML('<option value="12">Option12</option>', result) self.assertInHTML('<option value="12">Option12</option>', result)
def test_render_selected_option(self): def test_render_selected_option(self):
serializer = self.TestSerializer(data={'test_field': '12'}) serializer = self.TestSerializer(data={'test_field': '12'})
serializer.is_valid() serializer.is_valid()
result = self.renderer.render(serializer.data) result = self.renderer.render(serializer.data)
assert isinstance(result, SafeText)
self.assertIsInstance(result, SafeText) self.assertInHTML('<option value="12" selected>Option12</option>', result)
self.assertInHTML('<option value="12" selected>Option12</option>',
result)
self.assertInHTML('<option value="1">Option1</option>', result) self.assertInHTML('<option value="1">Option1</option>', result)
self.assertInHTML('<option value="2">Option2</option>', result) self.assertInHTML('<option value="2">Option2</option>', result)
class TestMultipleChoiceFieldHTMLFormRenderer(TestCase):
""" """
Test rendering MultipleChoiceField with HTMLFormRenderer. Test rendering MultipleChoiceField with HTMLFormRenderer.
""" """
defsetUp(self): defsetUp(self):
self.renderer = HTMLFormRenderer() self.renderer = HTMLFormRenderer()
def test_render_selected_option_with_string_option_ids(self): def test_render_selected_option_with_string_option_ids(self):
choices = (('1', 'Option1'), ('2', 'Option2'), ('12', 'Option12'), choices = (('1', 'Option1'), ('2', 'Option2'), ('12', 'Option12'), ('}', 'OptionBrace'))
('}', 'OptionBrace'))
class TestSerializer(serializers.Serializer): class TestSerializer(serializers.Serializer):
test_field = serializers.MultipleChoiceField(choices=choices) test_field = serializers.MultipleChoiceField(choices=choices)
serializer = TestSerializer(data={'test_field': ['12']}) serializer = TestSerializer(data={'test_field': ['12']})
serializer.is_valid() serializer.is_valid()
result=self.renderer.render(serializer.data) result=self.renderer.render(serializer.data)
assertisinstance(result, SafeText)
self.assertIsInstance(result, SafeText) self.assertInHTML('<option value="12" selected>Option12</option>',result)
self.assertInHTML('<option value="12" selected>Option12</option>',
result)
self.assertInHTML('<option value="1">Option1</option>',result) self.assertInHTML('<option value="1">Option1</option>',result)
self.assertInHTML('<option value="2">Option2</option>',result) self.assertInHTML('<option value="2">Option2</option>',result)
self.assertInHTML('<option value="}">OptionBrace</option>',result) self.assertInHTML('<option value="}">OptionBrace</option>',result)
def test_render_selected_option_with_integer_option_ids(self): def test_render_selected_option_with_integer_option_ids(self):
choices = ((1, 'Option1'), (2, 'Option2'), (12, 'Option12')) choices = ((1, 'Option1'), (2, 'Option2'), (12, 'Option12'))
class TestSerializer(serializers.Serializer): class TestSerializer(serializers.Serializer):
test_field = serializers.MultipleChoiceField(choices=choices) test_field = serializers.MultipleChoiceField(choices=choices)
serializer = TestSerializer(data={'test_field': ['12']}) serializer = TestSerializer(data={'test_field': ['12']})
serializer.is_valid() serializer.is_valid()
result=self.renderer.render(serializer.data) result=self.renderer.render(serializer.data)
assertisinstance(result, SafeText)
self.assertIsInstance(result, SafeText) self.assertInHTML('<option value="12" selected>Option12</option>',result)
self.assertInHTML('<option value="12" selected>Option12</option>',
result)
self.assertInHTML('<option value="1">Option1</option>',result) self.assertInHTML('<option value="1">Option1</option>',result)
self.assertInHTML('<option value="2">Option2</option>',result) self.assertInHTML('<option value="2">Option2</option>',result)
class StaticHTMLRendererTests(TestCase):
""" """
Tests specific for Static HTML Renderer Tests specific for Static HTML Renderer
""" """
@ -614,10 +564,7 @@ class StaticHTMLRendererTests(TestCase):
assert result == data assert result == data
def test_static_renderer_with_exception(self): def test_static_renderer_with_exception(self):
context = { context = { 'response': Response(status=500, exception=True), 'request': Request(HttpRequest()) }
'response': Response(status=500, exception=True),
'request': Request(HttpRequest())
}
result = self.renderer.render({}, renderer_context=context) result = self.renderer.render({}, renderer_context=context)
assert result == '500 Internal Server Error' assert result == '500 Internal Server Error'
@ -658,7 +605,6 @@ class BrowsableAPIRendererTests(URLPatternsTestCase):
assert '>Extra list action<' in resp.content.decode() assert '>Extra list action<' in resp.content.decode()
class AdminRendererTests(TestCase):
def setUp(self): def setUp(self):
self.renderer = AdminRenderer() self.renderer = AdminRenderer()
@ -669,24 +615,16 @@ class AdminRendererTests(TestCase):
request = Request(HttpRequest()) request = Request(HttpRequest())
request.build_absolute_uri=lambda:'http://example.com' request.build_absolute_uri=lambda:'http://example.com'
response=Response(status=201,headers={'Location':'/test'}) response=Response(status=201,headers={'Location':'/test'})
context = { context={'view':DummyView(),'request':request,'response':response}
'view': DummyView(), result=self.renderer.render(data={'test':'test'},renderer_context=context)
'request': request,
'response': response
}
result = self.renderer.render(data={'test': 'test'},
renderer_context=context)
assertresult=='' assertresult==''
assertresponse.status_code==status.HTTP_303_SEE_OTHER assertresponse.status_code==status.HTTP_303_SEE_OTHER
assertresponse['Location']=='http://example.com' assertresponse['Location']=='http://example.com'
def test_render_dict(self): def test_render_dict(self):
factory = APIRequestFactory() factory = APIRequestFactory()
class DummyView(APIView): class DummyView(APIView):
renderer_classes = (AdminRenderer, ) renderer_classes = (AdminRenderer, )
def get(self, request): def get(self, request):
return Response({'foo': 'a string'}) return Response({'foo': 'a string'})
view = DummyView.as_view() view = DummyView.as_view()
@ -697,10 +635,8 @@ class AdminRendererTests(TestCase):
def test_render_dict_with_items_key(self): def test_render_dict_with_items_key(self):
factory = APIRequestFactory() factory = APIRequestFactory()
class DummyView(APIView): class DummyView(APIView):
renderer_classes = (AdminRenderer, ) renderer_classes = (AdminRenderer, )
def get(self, request): def get(self, request):
return Response({'items': 'a string'}) return Response({'items': 'a string'})
@ -712,10 +648,8 @@ class AdminRendererTests(TestCase):
def test_render_dict_with_iteritems_key(self): def test_render_dict_with_iteritems_key(self):
factory = APIRequestFactory() factory = APIRequestFactory()
class DummyView(APIView): class DummyView(APIView):
renderer_classes = (AdminRenderer, ) renderer_classes = (AdminRenderer, )
def get(self, request): def get(self, request):
return Response({'iteritems': 'a string'}) return Response({'iteritems': 'a string'})
@ -727,12 +661,10 @@ class AdminRendererTests(TestCase):
def test_get_result_url(self): def test_get_result_url(self):
factory = APIRequestFactory() factory = APIRequestFactory()
class DummyGenericViewsetLike(APIView): class DummyGenericViewsetLike(APIView):
lookup_field = 'test' lookup_field = 'test'
def reverse_action(view, *args, **kwargs): def reverse_action(view, *args, **kwargs):
self.assertEqual(kwargs['kwargs']['test'], 1) assert kwargs['kwargs']['test'] == 1
return '/example/' return '/example/'
# get the view instance instead of the view function # get the view instance instead of the view function
@ -740,13 +672,11 @@ class AdminRendererTests(TestCase):
request=factory.get('/') request=factory.get('/')
response=view(request) response=view(request)
view=response.renderer_context['view'] view=response.renderer_context['view']
assertself.renderer.get_result_url({'test':1},view)=='/example/'
self.assertEqual(self.renderer.get_result_url({'test': 1}, view), '/example/') assertself.renderer.get_result_url({},view)is None
self.assertIsNone(self.renderer.get_result_url({}, view))
def test_get_result_url_no_result(self): def test_get_result_url_no_result(self):
factory = APIRequestFactory() factory = APIRequestFactory()
class DummyView(APIView): class DummyView(APIView):
lookup_field = 'test' lookup_field = 'test'
@ -755,16 +685,13 @@ class AdminRendererTests(TestCase):
request=factory.get('/') request=factory.get('/')
response=view(request) response=view(request)
view=response.renderer_context['view'] view=response.renderer_context['view']
assertself.renderer.get_result_url({'test':1},view)is None
self.assertIsNone(self.renderer.get_result_url({'test': 1}, view)) assertself.renderer.get_result_url({},view)is None
self.assertIsNone(self.renderer.get_result_url({}, view))
def test_get_context_result_urls(self): def test_get_context_result_urls(self):
factory = APIRequestFactory() factory = APIRequestFactory()
class DummyView(APIView): class DummyView(APIView):
lookup_field = 'test' lookup_field = 'test'
def reverse_action(view, url_name, args=None, kwargs=None): def reverse_action(view, url_name, args=None, kwargs=None):
return '/%s/%d' % (url_name, kwargs['test']) return '/%s/%d' % (url_name, kwargs['test'])
@ -772,71 +699,39 @@ class AdminRendererTests(TestCase):
view = DummyView.as_view() view = DummyView.as_view()
request=factory.get('/') request=factory.get('/')
response=view(request) response=view(request)
data=[{'test':1},{'url':'/example','test':2},{'url':None,'test':3},{},]
data = [ context={'view':DummyView(),'request':Request(request),'response':response}
{'test': 1},
{'url': '/example', 'test': 2},
{'url': None, 'test': 3},
{},
]
context = {
'view': DummyView(),
'request': Request(request),
'response': response
}
context=self.renderer.get_context(data,None,context) context=self.renderer.get_context(data,None,context)
results=context['results'] results=context['results']
assertlen(results)==4
self.assertEqual(len(results), 4) assertresults[0]['url']=='/detail/1'
self.assertEqual(results[0]['url'], '/detail/1') assertresults[1]['url']=='/example'
self.assertEqual(results[1]['url'], '/example') assertresults[2]['url']==None
self.assertEqual(results[2]['url'], None) assert'url'not inresults[3]
self.assertNotIn('url', results[3])
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed') @pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
class TestDocumentationRenderer(TestCase):
def test_document_with_link_named_data(self): def test_document_with_link_named_data(self):
""" """
Ref #5395: Doc's `document.data` would fail with a Link named "data". Ref #5395: Doc's `document.data` would fail with a Link named "data".
As per #4972, use templatetag instead. As per #4972, use templatetag instead.
""" """
document = coreapi.Document( document = coreapi.Document( title='Data Endpoint API', url='https://api.example.org/', content={ 'data': coreapi.Link( url='/data/', action='get', fields=[], description='Return data.' ) } )
title='Data Endpoint API',
url='https://api.example.org/',
content={
'data': coreapi.Link(
url='/data/',
action='get',
fields=[],
description='Return data.'
)
}
)
factory = APIRequestFactory() factory = APIRequestFactory()
request = factory.get('/') request = factory.get('/')
renderer = DocumentationRenderer() renderer = DocumentationRenderer()
html = renderer.render(document, accepted_media_type="text/html", renderer_context={"request": request}) html = renderer.render(document, accepted_media_type="text/html", renderer_context={"request": request})
assert '<h1>Data Endpoint API</h1>' in html assert '<h1>Data Endpoint API</h1>' in html
def test_shell_code_example_rendering(self): def test_shell_code_example_rendering(self):
template = loader.get_template('rest_framework/docs/langs/shell.html') template = loader.get_template('rest_framework/docs/langs/shell.html')
context = { context = { 'document': coreapi.Document(url='https://api.example.org/'), 'link_key': 'testcases > list', 'link': coreapi.Link(url='/data/', action='get', fields=[]), }
'document': coreapi.Document(url='https://api.example.org/'),
'link_key': 'testcases > list',
'link': coreapi.Link(url='/data/', action='get', fields=[]),
}
html = template.render(context) html = template.render(context)
assert 'testcases list' in html assert 'testcases list' in html
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed') @pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
class TestSchemaJSRenderer(TestCase):
def test_schemajs_output(self): def test_schemajs_output(self):
""" """
@ -845,9 +740,7 @@ class TestSchemaJSRenderer(TestCase):
""" """
factory = APIRequestFactory() factory = APIRequestFactory()
request = factory.get('/') request = factory.get('/')
renderer = SchemaJSRenderer() renderer = SchemaJSRenderer()
output = renderer.render('data', renderer_context={"request": request}) output = renderer.render('data', renderer_context={"request": request})
assert "'ImRhdGEi'" in output assert "'ImRhdGEi'" in output
assert "'b'ImRhdGEi''" not in output assert "'b'ImRhdGEi''" not in output

View File

@ -23,16 +23,9 @@ from rest_framework.test import APIClient, APIRequestFactory
from rest_framework.views import APIView from rest_framework.views import APIView
factory = APIRequestFactory() factory = APIRequestFactory()
class TestInitializer(TestCase):
def test_request_type(self): def test_request_type(self):
request = Request(factory.get('/')) request = Request(factory.get('/'))
message = ( 'The `request` argument must be an instance of ' '`django.http.HttpRequest`, not `rest_framework.request.Request`.' )
message = (
'The `request` argument must be an instance of '
'`django.http.HttpRequest`, not `rest_framework.request.Request`.'
)
with self.assertRaisesMessage(AssertionError, message): with self.assertRaisesMessage(AssertionError, message):
Request(request) Request(request)
@ -50,7 +43,6 @@ class PlainTextParser(BaseParser):
return stream.read() return stream.read()
class TestContentParsing(TestCase):
def test_standard_behaviour_determines_no_content_GET(self): def test_standard_behaviour_determines_no_content_GET(self):
""" """
Ensure request.data returns empty QueryDict for GET request. Ensure request.data returns empty QueryDict for GET request.
@ -160,7 +152,6 @@ urlpatterns = [
@override_settings( @override_settings(
ROOT_URLCONF='tests.test_request', ROOT_URLCONF='tests.test_request',
FILE_UPLOAD_HANDLERS=['django.core.files.uploadhandler.TemporaryFileUploadHandler']) FILE_UPLOAD_HANDLERS=['django.core.files.uploadhandler.TemporaryFileUploadHandler'])
class FileUploadTests(TestCase):
def test_fileuploads_closed_at_request_end(self): def test_fileuploads_closed_at_request_end(self):
with tempfile.NamedTemporaryFile() as f: with tempfile.NamedTemporaryFile() as f:
@ -168,13 +159,11 @@ class FileUploadTests(TestCase):
# sanity check that file was processed # sanity check that file was processed
assert len(response.data) == 1 assert len(response.data) == 1
forfileinresponse.data: forfileinresponse.data:
assert not os.path.exists(file) assert not os.path.exists(file)
@override_settings(ROOT_URLCONF='tests.test_request') @override_settings(ROOT_URLCONF='tests.test_request')
class TestContentParsingWithAuthentication(TestCase):
def setUp(self): def setUp(self):
self.csrf_client = APIClient(enforce_csrf_checks=True) self.csrf_client = APIClient(enforce_csrf_checks=True)
self.username = 'john' self.username = 'john'
@ -188,15 +177,12 @@ class TestContentParsingWithAuthentication(TestCase):
doesn't log in. doesn't log in.
""" """
content = {'example': 'example'} content = {'example': 'example'}
response = self.client.post('/', content) response = self.client.post('/', content)
assert status.HTTP_200_OK == response.status_code assert status.HTTP_200_OK == response.status_code
response = self.csrf_client.post('/', content) response = self.csrf_client.post('/', content)
assert status.HTTP_200_OK == response.status_code assert status.HTTP_200_OK == response.status_code
class TestUserSetter(TestCase):
def setUp(self): def setUp(self):
# Pass request object through session middleware so session is # Pass request object through session middleware so session is
@ -205,7 +191,6 @@ class TestUserSetter(TestCase):
self.request = Request(self.wrapped_request) self.request = Request(self.wrapped_request)
SessionMiddleware().process_request(self.wrapped_request) SessionMiddleware().process_request(self.wrapped_request)
AuthenticationMiddleware().process_request(self.wrapped_request) AuthenticationMiddleware().process_request(self.wrapped_request)
User.objects.create_user('ringo', 'starr@thebeatles.com', 'yellow') User.objects.create_user('ringo', 'starr@thebeatles.com', 'yellow')
self.user = authenticate(username='ringo', password='yellow') self.user = authenticate(username='ringo', password='yellow')
@ -237,11 +222,7 @@ class TestUserSetter(TestCase):
self.MISSPELLED_NAME_THAT_DOESNT_EXIST self.MISSPELLED_NAME_THAT_DOESNT_EXIST
request = Request(self.wrapped_request, authenticators=(AuthRaisesAttributeError(),)) request = Request(self.wrapped_request, authenticators=(AuthRaisesAttributeError(),))
# The middleware processes the underlying Django request, sets anonymous user
assertself.wrapped_request.user.is_anonymous assertself.wrapped_request.user.is_anonymous
# The DRF request object does not have a user and should run authenticators
expected=r"no attribute 'MISSPELLED_NAME_THAT_DOESNT_EXIST'" expected=r"no attribute 'MISSPELLED_NAME_THAT_DOESNT_EXIST'"
withpytest.raises(WrappedAttributeError,match=expected): withpytest.raises(WrappedAttributeError,match=expected):
request.user request.user
@ -253,14 +234,12 @@ class TestUserSetter(TestCase):
login(request, self.user) login(request, self.user)
class TestAuthSetter(TestCase):
def test_auth_can_be_set(self): def test_auth_can_be_set(self):
request = Request(factory.get('/')) request = Request(factory.get('/'))
request.auth = 'DUMMY' request.auth = 'DUMMY'
assert request.auth == 'DUMMY' assert request.auth == 'DUMMY'
class TestSecure(TestCase):
def test_default_secure_false(self): def test_default_secure_false(self):
request = Request(factory.get('/', secure=False)) request = Request(factory.get('/', secure=False))
@ -271,15 +250,12 @@ class TestSecure(TestCase):
assert request.scheme == 'https' assert request.scheme == 'https'
class TestHttpRequest(TestCase):
def test_attribute_access_proxy(self): def test_attribute_access_proxy(self):
http_request = factory.get('/') http_request = factory.get('/')
request = Request(http_request) request = Request(http_request)
inner_sentinel = object() inner_sentinel = object()
http_request.inner_property = inner_sentinel http_request.inner_property = inner_sentinel
assert request.inner_property is inner_sentinel assert request.inner_property is inner_sentinel
outer_sentinel = object() outer_sentinel = object()
request.inner_property = outer_sentinel request.inner_property = outer_sentinel
assert request.inner_property is outer_sentinel assert request.inner_property is outer_sentinel
@ -288,7 +264,6 @@ class TestHttpRequest(TestCase):
# ensure the exception message is not for the underlying WSGIRequest # ensure the exception message is not for the underlying WSGIRequest
http_request = factory.get('/') http_request = factory.get('/')
request = Request(http_request) request = Request(http_request)
message = "'Request' object has no attribute 'inner_property'" message = "'Request' object has no attribute 'inner_property'"
with self.assertRaisesMessage(AttributeError, message): with self.assertRaisesMessage(AttributeError, message):
request.inner_property request.inner_property
@ -301,12 +276,8 @@ class TestHttpRequest(TestCase):
""" """
response = APIClient().post('/echo/', data={'a': 'b'}, format='json') response = APIClient().post('/echo/', data={'a': 'b'}, format='json')
request = response.renderer_context['request'] request = response.renderer_context['request']
# ensure that request stream was consumed by json parser
assert request.content_type.startswith('application/json') assert request.content_type.startswith('application/json')
assert response.data == {'a': 'b'} assert response.data == {'a': 'b'}
# pass same HttpRequest to view, stream already consumed
with pytest.raises(RawPostDataException): with pytest.raises(RawPostDataException):
EchoView.as_view()(request._request) EchoView.as_view()(request._request)
@ -320,15 +291,9 @@ class TestHttpRequest(TestCase):
""" """
response = APIClient().post('/echo/', data={'a': 'b'}) response = APIClient().post('/echo/', data={'a': 'b'})
request = response.renderer_context['request'] request = response.renderer_context['request']
# ensure that request stream was consumed by form parser
assert request.content_type.startswith('multipart/form-data') assert request.content_type.startswith('multipart/form-data')
assert response.data == {'a': ['b']} assert response.data == {'a': ['b']}
# pass same HttpRequest to view, form data set on underlying request
response = EchoView.as_view()(request._request) response = EchoView.as_view()(request._request)
request = response.renderer_context['request'] request = response.renderer_context['request']
# ensure that request stream was consumed by form parser
assert request.content_type.startswith('multipart/form-data') assert request.content_type.startswith('multipart/form-data')
assert response.data == {'a': ['b']} assert response.data == {'a': ['b']}

View File

@ -131,93 +131,86 @@ urlpatterns = [
# TODO: Clean tests bellow - remove duplicates with above, better unit testing, ... # TODO: Clean tests bellow - remove duplicates with above, better unit testing, ...
@override_settings(ROOT_URLCONF='tests.test_response') @override_settings(ROOT_URLCONF='tests.test_response')
class RendererIntegrationTests(TestCase):
""" """
End-to-end testing of renderers using an ResponseMixin on a generic view. End-to-end testing of renderers using an ResponseMixin on a generic view.
""" """
deftest_default_renderer_serializes_content(self): deftest_default_renderer_serializes_content(self):
"""If the Accept header is not set the default renderer should serialize the response.""" """If the Accept header is not set the default renderer should serialize the response."""
resp = self.client.get('/') resp = self.client.get('/')
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
def test_head_method_serializes_no_content(self): def test_head_method_serializes_no_content(self):
"""No response must be included in HEAD requests.""" """No response must be included in HEAD requests."""
resp = self.client.head('/') resp = self.client.head('/')
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
self.assertEqual(resp.content, b'') assert resp.content == b''
def test_default_renderer_serializes_content_on_accept_any(self): def test_default_renderer_serializes_content_on_accept_any(self):
"""If the Accept header is set to */* the default renderer should serialize the response.""" """If the Accept header is set to */* the default renderer should serialize the response."""
resp = self.client.get('/', HTTP_ACCEPT='*/*') resp = self.client.get('/', HTTP_ACCEPT='*/*')
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_serializes_content_default_case(self): def test_specified_renderer_serializes_content_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response. """If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for the default renderer)""" (In this case we check that works for the default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type) resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererA.media_type + '; charset=utf-8'
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) assert resp.content == RENDERER_A_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_serializes_content_non_default_case(self): def test_specified_renderer_serializes_content_non_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response. """If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for a non-default renderer)""" (In this case we check that works for a non-default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type) resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_serializes_content_on_format_query(self): def test_specified_renderer_serializes_content_on_format_query(self):
"""If a 'format' query is specified, the renderer with the matching """If a 'format' query is specified, the renderer with the matching
format attribute should serialize the response.""" format attribute should serialize the response."""
resp = self.client.get('/?format=%s' % RendererB.format) resp = self.client.get('/?format=%s' % RendererB.format)
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_serializes_content_on_format_kwargs(self): def test_specified_renderer_serializes_content_on_format_kwargs(self):
"""If a 'format' keyword arg is specified, the renderer with the matching """If a 'format' keyword arg is specified, the renderer with the matching
format attribute should serialize the response.""" format attribute should serialize the response."""
resp = self.client.get('/something.formatb') resp = self.client.get('/something.formatb')
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.status_code, DUMMYSTATUS) assert resp.status_code == DUMMYSTATUS
def test_specified_renderer_is_used_on_format_query_with_matching_accept(self): def test_specified_renderer_is_used_on_format_query_with_matching_accept(self):
"""If both a 'format' query and a matching Accept header specified, """If both a 'format' query and a matching Accept header specified,
the renderer with the matching format attribute should serialize the response.""" the renderer with the matching format attribute should serialize the response."""
resp = self.client.get('/?format=%s' % RendererB.format, resp = self.client.get('/?format=%s' % RendererB.format, HTTP_ACCEPT=RendererB.media_type)
HTTP_ACCEPT=RendererB.media_type) assert resp['Content-Type'] == RendererB.media_type + '; charset=utf-8'
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') assert resp.content == RENDERER_B_SERIALIZER(DUMMYCONTENT)
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) assert resp.status_code == DUMMYSTATUS
self.assertEqual(resp.status_code, DUMMYSTATUS)
@override_settings(ROOT_URLCONF='tests.test_response') @override_settings(ROOT_URLCONF='tests.test_response')
class UnsupportedMediaTypeTests(TestCase):
def test_should_allow_posting_json(self): def test_should_allow_posting_json(self):
response = self.client.post('/json', data='{"test": 123}', content_type='application/json') response = self.client.post('/json', data='{"test": 123}', content_type='application/json')
assert response.status_code == 200
self.assertEqual(response.status_code, 200)
def test_should_not_allow_posting_xml(self): def test_should_not_allow_posting_xml(self):
response = self.client.post('/json', data='<test>123</test>', content_type='application/xml') response = self.client.post('/json', data='<test>123</test>', content_type='application/xml')
assert response.status_code == 415
self.assertEqual(response.status_code, 415)
def test_should_not_allow_posting_a_form(self): def test_should_not_allow_posting_a_form(self):
response = self.client.post('/json', data={'test': 123}) response = self.client.post('/json', data={'test': 123})
assert response.status_code == 415
self.assertEqual(response.status_code, 415)
@override_settings(ROOT_URLCONF='tests.test_response') @override_settings(ROOT_URLCONF='tests.test_response')
class Issue122Tests(TestCase):
""" """
Tests that covers #122. Tests that covers #122.
""" """
@ -235,19 +228,17 @@ class Issue122Tests(TestCase):
@override_settings(ROOT_URLCONF='tests.test_response') @override_settings(ROOT_URLCONF='tests.test_response')
class Issue467Tests(TestCase):
""" """
Tests for #467 Tests for #467
""" """
deftest_form_has_label_and_help_text(self): deftest_form_has_label_and_help_text(self):
resp = self.client.get('/html_new_model') resp = self.client.get('/html_new_model')
self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') assert resp['Content-Type'] == 'text/html; charset=utf-8'
# self.assertContains(resp, 'Text comes here') # self.assertContains(resp, 'Text comes here')
# self.assertContains(resp, 'Text description.') # self.assertContains(resp, 'Text description.')
@override_settings(ROOT_URLCONF='tests.test_response') @override_settings(ROOT_URLCONF='tests.test_response')
class Issue807Tests(TestCase):
""" """
Covers #807 Covers #807
""" """
@ -258,7 +249,7 @@ class Issue807Tests(TestCase):
headers = {"HTTP_ACCEPT": RendererA.media_type} headers = {"HTTP_ACCEPT": RendererA.media_type}
resp = self.client.get('/', **headers) resp = self.client.get('/', **headers)
expected = "{}; charset={}".format(RendererA.media_type, 'utf-8') expected = "{}; charset={}".format(RendererA.media_type, 'utf-8')
self.assertEqual(expected, resp['Content-Type']) assert expected == resp['Content-Type']
def test_if_there_is_charset_specified_on_renderer_it_gets_appended(self): def test_if_there_is_charset_specified_on_renderer_it_gets_appended(self):
""" """
@ -268,7 +259,7 @@ class Issue807Tests(TestCase):
headers = {"HTTP_ACCEPT": RendererC.media_type} headers = {"HTTP_ACCEPT": RendererC.media_type}
resp = self.client.get('/', **headers) resp = self.client.get('/', **headers)
expected = "{}; charset={}".format(RendererC.media_type, RendererC.charset) expected = "{}; charset={}".format(RendererC.media_type, RendererC.charset)
self.assertEqual(expected, resp['Content-Type']) assert expected == resp['Content-Type']
def test_content_type_set_explicitly_on_response(self): def test_content_type_set_explicitly_on_response(self):
""" """
@ -276,10 +267,10 @@ class Issue807Tests(TestCase):
""" """
headers = {"HTTP_ACCEPT": RendererC.media_type} headers = {"HTTP_ACCEPT": RendererC.media_type}
resp = self.client.get('/setbyview', **headers) resp = self.client.get('/setbyview', **headers)
self.assertEqual('setbyview', resp['Content-Type']) assert 'setbyview' == resp['Content-Type']
def test_form_has_label_and_help_text(self): def test_form_has_label_and_help_text(self):
resp = self.client.get('/html_new_model') resp = self.client.get('/html_new_model')
self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') assert resp['Content-Type'] == 'text/html; charset=utf-8'
# self.assertContains(resp, 'Text comes here') # self.assertContains(resp, 'Text comes here')
# self.assertContains(resp, 'Text description.') # self.assertContains(resp, 'Text description.')

View File

@ -30,7 +30,6 @@ class MockVersioningScheme:
@override_settings(ROOT_URLCONF='tests.test_reverse') @override_settings(ROOT_URLCONF='tests.test_reverse')
class ReverseTests(TestCase):
""" """
Tests for fully qualified URLs when using `reverse`. Tests for fully qualified URLs when using `reverse`.
""" """
@ -42,13 +41,11 @@ class ReverseTests(TestCase):
def test_reverse_with_versioning_scheme(self): def test_reverse_with_versioning_scheme(self):
request = factory.get('/view') request = factory.get('/view')
request.versioning_scheme = MockVersioningScheme() request.versioning_scheme = MockVersioningScheme()
url = reverse('view', request=request) url = reverse('view', request=request)
assert url == 'http://scheme-reversed/view' assert url == 'http://scheme-reversed/view'
def test_reverse_with_versioning_scheme_fallback_to_default_on_error(self): def test_reverse_with_versioning_scheme_fallback_to_default_on_error(self):
request = factory.get('/view') request = factory.get('/view')
request.versioning_scheme = MockVersioningScheme(raise_error=True) request.versioning_scheme = MockVersioningScheme(raise_error=True)
url = reverse('view', request=request) url = reverse('view', request=request)
assert url == 'http://testserver/view' assert url == 'http://testserver/view'

View File

@ -214,7 +214,6 @@ class TestCustomLookupFields(URLPatternsTestCase, TestCase):
assert response.data == {"url": "http://testserver/example/notes/a%20b/", "uuid": "a b", "text": "baz qux"} assert response.data == {"url": "http://testserver/example/notes/a%20b/", "uuid": "a b", "text": "baz qux"}
class TestLookupValueRegex(TestCase):
""" """
Ensure the router honors lookup_value_regex when applied Ensure the router honors lookup_value_regex when applied
to the viewset. to the viewset.
@ -264,7 +263,6 @@ class TestLookupUrlKwargs(URLPatternsTestCase, TestCase):
assert response.data == {"url": "http://testserver/example/notes/123/", "uuid": "123", "text": "foo bar"} assert response.data == {"url": "http://testserver/example/notes/123/", "uuid": "123", "text": "foo bar"}
class TestTrailingSlashIncluded(TestCase):
def setUp(self): def setUp(self):
class NoteViewSet(viewsets.ModelViewSet): class NoteViewSet(viewsets.ModelViewSet):
queryset = RouterTestModel.objects.all() queryset = RouterTestModel.objects.all()
@ -279,7 +277,6 @@ class TestTrailingSlashIncluded(TestCase):
assert expected[idx] == get_regex_pattern(self.urls[idx]) assert expected[idx] == get_regex_pattern(self.urls[idx])
class TestTrailingSlashRemoved(TestCase):
def setUp(self): def setUp(self):
class NoteViewSet(viewsets.ModelViewSet): class NoteViewSet(viewsets.ModelViewSet):
queryset = RouterTestModel.objects.all() queryset = RouterTestModel.objects.all()
@ -294,7 +291,6 @@ class TestTrailingSlashRemoved(TestCase):
assert expected[idx] == get_regex_pattern(self.urls[idx]) assert expected[idx] == get_regex_pattern(self.urls[idx])
class TestNameableRoot(TestCase):
def setUp(self): def setUp(self):
class NoteViewSet(viewsets.ModelViewSet): class NoteViewSet(viewsets.ModelViewSet):
queryset = RouterTestModel.objects.all() queryset = RouterTestModel.objects.all()
@ -309,21 +305,16 @@ class TestNameableRoot(TestCase):
assert expected == self.urls[-1].name assert expected == self.urls[-1].name
class TestActionKeywordArgs(TestCase):
""" """
Ensure keyword arguments passed in the `@action` decorator Ensure keyword arguments passed in the `@action` decorator
are properly handled. Refs #940. are properly handled. Refs #940.
""" """
defsetUp(self): defsetUp(self):
class TestViewSet(viewsets.ModelViewSet): class TestViewSet(viewsets.ModelViewSet):
permission_classes = [] permission_classes = []
@action(methods=['post'], detail=True, permission_classes=[permissions.AllowAny]) @action(methods=['post'], detail=True, permission_classes=[permissions.AllowAny])
def custom(self, request, *args, **kwargs): def custom(self, request, *args, **kwargs):
return Response({ return Response({ 'permission_classes': self.permission_classes })
'permission_classes': self.permission_classes
})
self.router = SimpleRouter() self.router = SimpleRouter()
self.router.register(r'test',TestViewSet,basename='test') self.router.register(r'test',TestViewSet,basename='test')
@ -335,24 +326,19 @@ class TestActionKeywordArgs(TestCase):
assert response.data == {'permission_classes': [permissions.AllowAny]} assert response.data == {'permission_classes': [permissions.AllowAny]}
class TestActionAppliedToExistingRoute(TestCase):
""" """
Ensure `@action` decorator raises an except when applied Ensure `@action` decorator raises an except when applied
to an existing route to an existing route
""" """
deftest_exception_raised_when_action_applied_to_existing_route(self): deftest_exception_raised_when_action_applied_to_existing_route(self):
class TestViewSet(viewsets.ModelViewSet): class TestViewSet(viewsets.ModelViewSet):
@action(methods=['post'], detail=True) @action(methods=['post'], detail=True)
def retrieve(self, request, *args, **kwargs): def retrieve(self, request, *args, **kwargs):
return Response({ return Response({ 'hello': 'world' })
'hello': 'world'
})
self.router = SimpleRouter() self.router = SimpleRouter()
self.router.register(r'test',TestViewSet,basename='test') self.router.register(r'test',TestViewSet,basename='test')
withpytest.raises(ImproperlyConfigured): withpytest.raises(ImproperlyConfigured):
self.router.urls self.router.urls
@ -390,28 +376,17 @@ class SubDynamicListAndDetailViewSet(DynamicListAndDetailViewSet):
pass pass
class TestDynamicListAndDetailRouter(TestCase):
def setUp(self): def setUp(self):
self.router = SimpleRouter() self.router = SimpleRouter()
def _test_list_and_detail_route_decorators(self, viewset): def _test_list_and_detail_route_decorators(self, viewset):
routes = self.router.get_routes(viewset) routes = self.router.get_routes(viewset)
decorator_routes = [r for r in routes if not (r.name.endswith('-list') or r.name.endswith('-detail'))] decorator_routes = [r for r in routes if not (r.name.endswith('-list') or r.name.endswith('-detail'))]
MethodNamesMap = namedtuple('MethodNamesMap', 'method_name url_path') MethodNamesMap = namedtuple('MethodNamesMap', 'method_name url_path')
# Make sure all these endpoints exist and none have been clobbered for i, endpoint in enumerate([MethodNamesMap('list_custom_route_get', 'list_custom-route'), MethodNamesMap('list_route_get', 'list_route_get'), MethodNamesMap('list_route_post', 'list_route_post'), MethodNamesMap('detail_custom_route_get', 'detail_custom-route'), MethodNamesMap('detail_route_get', 'detail_route_get'), MethodNamesMap('detail_route_post', 'detail_route_post') ]):
for i, endpoint in enumerate([MethodNamesMap('list_custom_route_get', 'list_custom-route'),
MethodNamesMap('list_route_get', 'list_route_get'),
MethodNamesMap('list_route_post', 'list_route_post'),
MethodNamesMap('detail_custom_route_get', 'detail_custom-route'),
MethodNamesMap('detail_route_get', 'detail_route_get'),
MethodNamesMap('detail_route_post', 'detail_route_post')
]):
route = decorator_routes[i] route = decorator_routes[i]
# check url listing
method_name = endpoint.method_name method_name = endpoint.method_name
url_path = endpoint.url_path url_path = endpoint.url_path
if method_name.startswith('list_'): if method_name.startswith('list_'):
assert route.url == '^{{prefix}}/{0}{{trailing_slash}}$'.format(url_path) assert route.url == '^{{prefix}}/{0}{{trailing_slash}}$'.format(url_path)
else: else:
@ -490,14 +465,11 @@ class TestViewInitkwargs(URLPatternsTestCase, TestCase):
assert initkwargs['basename'] == 'routertestmodel' assert initkwargs['basename'] == 'routertestmodel'
class TestBaseNameRename(TestCase):
def test_base_name_and_basename_assertion(self): def test_base_name_and_basename_assertion(self):
router = SimpleRouter() router = SimpleRouter()
msg = "Do not provide both the `basename` and `base_name` arguments." msg = "Do not provide both the `basename` and `base_name` arguments."
with warnings.catch_warnings(record=True) as w, \ with warnings.catch_warnings(record=True) as w, self.assertRaisesMessage(AssertionError, msg):
self.assertRaisesMessage(AssertionError, msg):
warnings.simplefilter('always') warnings.simplefilter('always')
router.register('mock', MockViewSet, 'mock', base_name='mock') router.register('mock', MockViewSet, 'mock', base_name='mock')
@ -507,7 +479,6 @@ class TestBaseNameRename(TestCase):
def test_base_name_argument_deprecation(self): def test_base_name_argument_deprecation(self):
router = SimpleRouter() router = SimpleRouter()
with pytest.warns(RemovedInDRF311Warning) as w: with pytest.warns(RemovedInDRF311Warning) as w:
warnings.simplefilter('always') warnings.simplefilter('always')
router.register('mock', MockViewSet, base_name='mock') router.register('mock', MockViewSet, base_name='mock')
@ -515,44 +486,31 @@ class TestBaseNameRename(TestCase):
msg = "The `base_name` argument is pending deprecation in favor of `basename`." msg = "The `base_name` argument is pending deprecation in favor of `basename`."
assertlen(w)==1 assertlen(w)==1
assertstr(w[0].message)==msg assertstr(w[0].message)==msg
assert router.registry == [ assertrouter.registry==[('mock',MockViewSet,'mock'),]
('mock', MockViewSet, 'mock'),
]
def test_basename_argument_no_warnings(self): def test_basename_argument_no_warnings(self):
router = SimpleRouter() router = SimpleRouter()
with warnings.catch_warnings(record=True) as w: with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always') warnings.simplefilter('always')
router.register('mock', MockViewSet, basename='mock') router.register('mock', MockViewSet, basename='mock')
assert len(w) == 0 assert len(w) == 0
assert router.registry == [ assertrouter.registry==[('mock',MockViewSet,'mock'),]
('mock', MockViewSet, 'mock'),
]
def test_get_default_base_name_deprecation(self): def test_get_default_base_name_deprecation(self):
msg = "`CustomRouter.get_default_base_name` method should be renamed `get_default_basename`." msg = "`CustomRouter.get_default_base_name` method should be renamed `get_default_basename`."
# Class definition should raise a warning
with pytest.warns(RemovedInDRF311Warning) as w: with pytest.warns(RemovedInDRF311Warning) as w:
warnings.simplefilter('always') warnings.simplefilter('always')
class CustomRouter(SimpleRouter): class CustomRouter(SimpleRouter):
def get_default_base_name(self, viewset): def get_default_base_name(self, viewset):
return 'foo' return 'foo'
assert len(w) == 1 assert len(w) == 1
assertstr(w[0].message)==msg assertstr(w[0].message)==msg
# Deprecated method implementation should still be called
withwarnings.catch_warnings(record=True)asw: withwarnings.catch_warnings(record=True)asw:
warnings.simplefilter('always') warnings.simplefilter('always')
router = CustomRouter() router = CustomRouter()
router.register('mock', MockViewSet) router.register('mock', MockViewSet)
assert len(w) == 0 assert len(w) == 0
assert router.registry == [ assertrouter.registry==[('mock',MockViewSet,'foo'),]
('mock', MockViewSet, 'foo'),
]

File diff suppressed because one or more lines are too long

View File

@ -4,13 +4,9 @@ Tests to cover bulk create and update using serializers.
from django.test import TestCase from django.test import TestCase
from rest_framework import serializers from rest_framework import serializers
class BulkCreateSerializerTests(TestCase):
""" """
Creating multiple instances using serializers. Creating multiple instances using serializers.
""" """
defsetUp(self): defsetUp(self):
class BookSerializer(serializers.Serializer): class BookSerializer(serializers.Serializer):
id = serializers.IntegerField() id = serializers.IntegerField()
@ -23,23 +19,7 @@ class BulkCreateSerializerTests(TestCase):
""" """
Correct bulk update serialization should return the input data. Correct bulk update serialization should return the input data.
""" """
data = [ { 'id': 0, 'title': 'The electric kool-aid acid test', 'author': 'Tom Wolfe' }, { 'id': 1, 'title': 'If this is a man', 'author': 'Primo Levi' }, { 'id': 2, 'title': 'The wind-up bird chronicle', 'author': 'Haruki Murakami' } ]
data = [
{
'id': 0,
'title': 'The electric kool-aid acid test',
'author': 'Tom Wolfe'
}, {
'id': 1,
'title': 'If this is a man',
'author': 'Primo Levi'
}, {
'id': 2,
'title': 'The wind-up bird chronicle',
'author': 'Haruki Murakami'
}
]
serializer = self.BookSerializer(data=data, many=True) serializer = self.BookSerializer(data=data, many=True)
assert serializer.is_valid() is True assert serializer.is_valid() is True
assert serializer.validated_data == data assert serializer.validated_data == data
@ -49,28 +29,8 @@ class BulkCreateSerializerTests(TestCase):
""" """
Incorrect bulk create serialization should return errors. Incorrect bulk create serialization should return errors.
""" """
data = [ { 'id': 0, 'title': 'The electric kool-aid acid test', 'author': 'Tom Wolfe' }, { 'id': 1, 'title': 'If this is a man', 'author': 'Primo Levi' }, { 'id': 'foo', 'title': 'The wind-up bird chronicle', 'author': 'Haruki Murakami' } ]
data = [ expected_errors = [ {}, {}, {'id': ['A valid integer is required.']} ]
{
'id': 0,
'title': 'The electric kool-aid acid test',
'author': 'Tom Wolfe'
}, {
'id': 1,
'title': 'If this is a man',
'author': 'Primo Levi'
}, {
'id': 'foo',
'title': 'The wind-up bird chronicle',
'author': 'Haruki Murakami'
}
]
expected_errors = [
{},
{},
{'id': ['A valid integer is required.']}
]
serializer = self.BookSerializer(data=data, many=True) serializer = self.BookSerializer(data=data, many=True)
assert serializer.is_valid() is False assert serializer.is_valid() is False
assert serializer.errors == expected_errors assert serializer.errors == expected_errors
@ -83,14 +43,8 @@ class BulkCreateSerializerTests(TestCase):
data = ['foo', 'bar', 'baz'] data = ['foo', 'bar', 'baz']
serializer = self.BookSerializer(data=data, many=True) serializer = self.BookSerializer(data=data, many=True)
assert serializer.is_valid() is False assert serializer.is_valid() is False
message = 'Invalid data. Expected a dictionary, but got str.' message = 'Invalid data. Expected a dictionary, but got str.'
expected_errors = [ expected_errors = [ {'non_field_errors': [message]}, {'non_field_errors': [message]}, {'non_field_errors': [message]} ]
{'non_field_errors': [message]},
{'non_field_errors': [message]},
{'non_field_errors': [message]}
]
assert serializer.errors == expected_errors assert serializer.errors == expected_errors
def test_invalid_single_datatype(self): def test_invalid_single_datatype(self):
@ -100,9 +54,7 @@ class BulkCreateSerializerTests(TestCase):
data = 123 data = 123
serializer = self.BookSerializer(data=data, many=True) serializer = self.BookSerializer(data=data, many=True)
assert serializer.is_valid() is False assert serializer.is_valid() is False
expected_errors = {'non_field_errors': ['Expected a list of items but got type "int".']} expected_errors = {'non_field_errors': ['Expected a list of items but got type "int".']}
assert serializer.errors == expected_errors assert serializer.errors == expected_errors
def test_invalid_single_object(self): def test_invalid_single_object(self):
@ -110,14 +62,8 @@ class BulkCreateSerializerTests(TestCase):
Data containing only a single object, instead of a list of objects Data containing only a single object, instead of a list of objects
should return errors. should return errors.
""" """
data = { data = { 'id': 0, 'title': 'The electric kool-aid acid test', 'author': 'Tom Wolfe' }
'id': 0,
'title': 'The electric kool-aid acid test',
'author': 'Tom Wolfe'
}
serializer = self.BookSerializer(data=data, many=True) serializer = self.BookSerializer(data=data, many=True)
assert serializer.is_valid() is False assert serializer.is_valid() is False
expected_errors = {'non_field_errors': ['Expected a list of items but got type "dict".']} expected_errors = {'non_field_errors': ['Expected a list of items but got type "dict".']}
assert serializer.errors == expected_errors assert serializer.errors == expected_errors