From 805a915e7c2c0122ea588fa0dcac9f92e3276bbc Mon Sep 17 00:00:00 2001 From: "S. Andrew Sheppard" Date: Wed, 27 May 2015 21:06:57 -0500 Subject: [PATCH 01/16] can't nest unique_together relations --- tests/test_model_serializer.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py index dc34649ea..048a12fba 100644 --- a/tests/test_model_serializer.py +++ b/tests/test_model_serializer.py @@ -316,6 +316,13 @@ class RelationalModel(models.Model): through = models.ManyToManyField(ThroughTargetModel, through=Supplementary, related_name='reverse_through') +class UniqueTogetherModel(models.Model): + foreign_key = models.ForeignKey(ForeignKeyTargetModel, related_name='unique_foreign_key') + one_to_one = models.OneToOneField(OneToOneTargetModel, related_name='unique_one_to_one') + class Meta: + unique_together = ("foreign_key", "one_to_one") + + class TestRelationalFieldMappings(TestCase): def test_pk_relations(self): class TestSerializer(serializers.ModelSerializer): @@ -395,6 +402,25 @@ class TestRelationalFieldMappings(TestCase): """) self.assertEqual(unicode_repr(TestSerializer()), expected) + def test_nested_unique_together_relations(self): + class TestSerializer(serializers.HyperlinkedModelSerializer): + class Meta: + model = UniqueTogetherModel + depth = 1 + expected = dedent(""" + TestSerializer(): + url = HyperlinkedIdentityField(view_name='uniquetogethermodel-detail') + foreign_key = NestedSerializer(read_only=True): + url = HyperlinkedIdentityField(view_name='foreignkeytargetmodel-detail') + name = CharField(max_length=100) + one_to_one = NestedSerializer(read_only=True): + url = HyperlinkedIdentityField(view_name='onetoonetargetmodel-detail') + name = CharField(max_length=100) + class Meta: + validators = [] + """) + self.assertEqual(unicode_repr(TestSerializer()), expected) + def test_pk_reverse_foreign_key(self): class TestSerializer(serializers.ModelSerializer): class Meta: From 8c7b5fc5c132ee4648aef7e95868dec9064d8ef6 Mon Sep 17 00:00:00 2001 From: "S. Andrew Sheppard" Date: Wed, 27 May 2015 21:14:08 -0500 Subject: [PATCH 02/16] pop required extra_kwargs if read_only is set --- rest_framework/serializers.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 8e1e50bc7..b1d58ee5f 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -1088,6 +1088,9 @@ class ModelSerializer(Serializer): if extra_kwargs.get('default') and kwargs.get('required') is False: kwargs.pop('required') + if kwargs.get('read_only', False): + extra_kwargs.pop('required', None) + kwargs.update(extra_kwargs) return kwargs From 0b8b288be597bd4d93d30c74f7c017a9d8abf497 Mon Sep 17 00:00:00 2001 From: "S. Andrew Sheppard" Date: Thu, 28 May 2015 08:20:43 -0500 Subject: [PATCH 03/16] python2 compat --- tests/test_model_serializer.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py index 048a12fba..8ff97b766 100644 --- a/tests/test_model_serializer.py +++ b/tests/test_model_serializer.py @@ -419,6 +419,13 @@ class TestRelationalFieldMappings(TestCase): class Meta: validators = [] """) + if six.PY2: + # This case is also too awkward to resolve fully across both py2 + # and py3. (See above) + expected = expected.replace( + "('foreign_key', 'one_to_one')", + "(u'foreign_key', u'one_to_one')" + ) self.assertEqual(unicode_repr(TestSerializer()), expected) def test_pk_reverse_foreign_key(self): From 4a3c844b7fceb68af86025b5f09ac24f5dffed6a Mon Sep 17 00:00:00 2001 From: "S. Andrew Sheppard" Date: Thu, 28 May 2015 08:29:15 -0500 Subject: [PATCH 04/16] flake8 --- tests/test_model_serializer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py index 8ff97b766..49ad8ac10 100644 --- a/tests/test_model_serializer.py +++ b/tests/test_model_serializer.py @@ -319,6 +319,7 @@ class RelationalModel(models.Model): class UniqueTogetherModel(models.Model): foreign_key = models.ForeignKey(ForeignKeyTargetModel, related_name='unique_foreign_key') one_to_one = models.OneToOneField(OneToOneTargetModel, related_name='unique_one_to_one') + class Meta: unique_together = ("foreign_key", "one_to_one") From 2f524ec1a32c9ff225e09afef70aabae82aee61b Mon Sep 17 00:00:00 2001 From: Xavier Ordoquy Date: Mon, 1 Jun 2015 15:46:27 +0100 Subject: [PATCH 05/16] Remove an extra MockHTMLDict definition. --- tests/test_fields.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/tests/test_fields.py b/tests/test_fields.py index 568e8d5e7..a407f5058 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -221,6 +221,14 @@ class TestInvalidErrorKey: assert str(exc_info.value) == expected +class MockHTMLDict(dict): + """ + This class mocks up a dictionary like object, that behaves + as if it was returned for multipart or urlencoded data. + """ + getlist = None + + class TestBooleanHTMLInput: def setup(self): class TestSerializer(serializers.Serializer): @@ -234,21 +242,11 @@ class TestBooleanHTMLInput: """ # This class mocks up a dictionary like object, that behaves # as if it was returned for multipart or urlencoded data. - class MockHTMLDict(dict): - getlist = None serializer = self.Serializer(data=MockHTMLDict()) assert serializer.is_valid() assert serializer.validated_data == {'archived': False} -class MockHTMLDict(dict): - """ - This class mocks up a dictionary like object, that behaves - as if it was returned for multipart or urlencoded data. - """ - getlist = None - - class TestHTMLInput: def test_empty_html_charfield(self): class TestSerializer(serializers.Serializer): From 989c08109bf0913b43a16ec7f951bbb6f6f6bef6 Mon Sep 17 00:00:00 2001 From: Xavier Ordoquy Date: Mon, 1 Jun 2015 16:04:05 +0100 Subject: [PATCH 06/16] Failing test case for #2894 --- tests/test_fields.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/test_fields.py b/tests/test_fields.py index 568e8d5e7..c4c2eeeca 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -1,6 +1,7 @@ from decimal import Decimal from django.utils import timezone from rest_framework import serializers +import rest_framework import datetime import django import pytest @@ -1017,6 +1018,12 @@ class TestMultipleChoiceField(FieldValues): ] ) + def test_against_partial_updates(self): + # serializer = self.Serializer(data=MockHTMLDict()) + from django.http import QueryDict + field = serializers.MultipleChoiceField(choices=(('a', 'a'), ('b', 'b'))) + assert field.get_value(QueryDict({})) == rest_framework.fields.empty + # File serializers... From 94e2d3ca610ec1fdb97641a094801f69a1ab2613 Mon Sep 17 00:00:00 2001 From: Xavier Ordoquy Date: Mon, 1 Jun 2015 16:13:12 +0100 Subject: [PATCH 07/16] Test case upgrade to use partial data --- tests/test_fields.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_fields.py b/tests/test_fields.py index c4c2eeeca..0867852be 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -1018,10 +1018,13 @@ class TestMultipleChoiceField(FieldValues): ] ) - def test_against_partial_updates(self): + def test_against_partial_and_full_updates(self): # serializer = self.Serializer(data=MockHTMLDict()) from django.http import QueryDict field = serializers.MultipleChoiceField(choices=(('a', 'a'), ('b', 'b'))) + field.partial = False + assert field.get_value(QueryDict({})) == [] + field.partial = True assert field.get_value(QueryDict({})) == rest_framework.fields.empty From 5c90bf9cc00e9870ba1d1d5bd3113ce797e73306 Mon Sep 17 00:00:00 2001 From: Xavier Ordoquy Date: Mon, 1 Jun 2015 16:13:35 +0100 Subject: [PATCH 08/16] Fix for #2894 thanks to @carljm --- rest_framework/fields.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index d8bb0a017..e7a4cee5f 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1060,7 +1060,11 @@ class MultipleChoiceField(ChoiceField): # We override the default field access in order to support # lists in HTML forms. if html.is_html_input(dictionary): - return dictionary.getlist(self.field_name) + ret = dictionary.getlist(self.field_name) + if getattr(self.root, 'partial', False) and not ret: + ret = empty + return ret + return dictionary.get(self.field_name, empty) def to_internal_value(self, data): From f701ecceb74ba8bfc0e270b64460626bd4544766 Mon Sep 17 00:00:00 2001 From: Nicolas Delaby Date: Mon, 1 Jun 2015 18:20:53 +0200 Subject: [PATCH 09/16] Add DurationField --- docs/api-guide/fields.md | 12 ++++++++++++ rest_framework/compat.py | 8 ++++++++ rest_framework/fields.py | 25 ++++++++++++++++++++++++- rest_framework/serializers.py | 8 +++++++- tests/test_fields.py | 23 +++++++++++++++++++++++ tests/test_model_serializer.py | 26 +++++++++++++++++++++++++- 6 files changed, 99 insertions(+), 3 deletions(-) diff --git a/docs/api-guide/fields.md b/docs/api-guide/fields.md index c87db7854..aad188511 100644 --- a/docs/api-guide/fields.md +++ b/docs/api-guide/fields.md @@ -302,6 +302,18 @@ Corresponds to `django.db.models.fields.TimeField` Format strings may either be [Python strftime formats][strftime] which explicitly specify the format, or the special string `'iso-8601'`, which indicates that [ISO 8601][iso8601] style times should be used. (eg `'12:34:56.000000'`) +## DurationField + +A Duration representation. +Corresponds to `django.db.models.fields.Duration` + +The `validated_data` for these fields will contain a `datetime.timedelta` instance. +The representation is a string following this format `'[DD] [HH:[MM:]]ss[.uuuuuu]'`. + +**Note:** This field is only available with Django versions >= 1.8. + +**Signature:** `DurationField()` + --- # Choice selection fields diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 1ba907314..8d6151fa2 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -258,3 +258,11 @@ else: SHORT_SEPARATORS = (b',', b':') LONG_SEPARATORS = (b', ', b': ') INDENT_SEPARATORS = (b',', b': ') + + +if django.VERSION >= (1, 8): + from django.db.models import DurationField + from django.utils.dateparse import parse_duration + from django.utils.duration import duration_string +else: + DurationField = duration_string = parse_duration = None diff --git a/rest_framework/fields.py b/rest_framework/fields.py index d8bb0a017..85c451078 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -12,7 +12,7 @@ from rest_framework import ISO_8601 from rest_framework.compat import ( EmailValidator, MinValueValidator, MaxValueValidator, MinLengthValidator, MaxLengthValidator, URLValidator, OrderedDict, - unicode_repr, unicode_to_repr + unicode_repr, unicode_to_repr, parse_duration, duration_string, ) from rest_framework.exceptions import ValidationError from rest_framework.settings import api_settings @@ -1003,6 +1003,29 @@ class TimeField(Field): return value.strftime(self.format) +class DurationField(Field): + default_error_messages = { + 'invalid': _('Duration has wrong format. Use one of these formats instead: {format}.'), + } + + def __init__(self, *args, **kwargs): + if parse_duration is None: + raise NotImplementedError( + 'DurationField not supported for django versions prior to 1.8') + return super(DurationField, self).__init__(*args, **kwargs) + + def to_internal_value(self, value): + if isinstance(value, datetime.timedelta): + return value + parsed = parse_duration(value) + if parsed is not None: + return parsed + self.fail('invalid', format='[DD] [HH:[MM:]]ss[.uuuuuu]') + + def to_representation(self, value): + return duration_string(value) + + # Choice types... class ChoiceField(Field): diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 73ac6bc2a..55f571db9 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -15,7 +15,11 @@ from django.db import models from django.db.models.fields import FieldDoesNotExist, Field as DjangoModelField from django.db.models import query from django.utils.translation import ugettext_lazy as _ -from rest_framework.compat import postgres_fields, unicode_to_repr +from rest_framework.compat import ( + postgres_fields, + unicode_to_repr, + DurationField as ModelDurationField, +) from rest_framework.utils import model_meta from rest_framework.utils.field_mapping import ( get_url_kwargs, get_field_kwargs, @@ -731,6 +735,8 @@ class ModelSerializer(Serializer): models.TimeField: TimeField, models.URLField: URLField, } + if ModelDurationField is not None: + serializer_field_mapping[ModelDurationField] = DurationField serializer_related_field = PrimaryKeyRelatedField serializer_url_field = HyperlinkedIdentityField serializer_choice_field = ChoiceField diff --git a/tests/test_fields.py b/tests/test_fields.py index 568e8d5e7..ae1920d9f 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -905,6 +905,29 @@ class TestNoOutputFormatTimeField(FieldValues): field = serializers.TimeField(format=None) +@pytest.mark.skipif(django.VERSION < (1, 8), + reason='DurationField is only available for django1.8+') +class TestDurationField(FieldValues): + """ + Valid and invalid values for `DurationField`. + """ + valid_inputs = { + '13': datetime.timedelta(seconds=13), + '3 08:32:01.000123': datetime.timedelta(days=3, hours=8, minutes=32, seconds=1, microseconds=123), + '08:01': datetime.timedelta(minutes=8, seconds=1), + datetime.timedelta(days=3, hours=8, minutes=32, seconds=1, microseconds=123): datetime.timedelta(days=3, hours=8, minutes=32, seconds=1, microseconds=123), + } + invalid_inputs = { + 'abc': ['Duration has wrong format. Use one of these formats instead: [DD] [HH:[MM:]]ss[.uuuuuu].'], + '3 08:32 01.123': ['Duration has wrong format. Use one of these formats instead: [DD] [HH:[MM:]]ss[.uuuuuu].'], + } + outputs = { + datetime.timedelta(days=3, hours=8, minutes=32, seconds=1, microseconds=123): '3 08:32:01.000123', + } + if django.VERSION >= (1, 8): + field = serializers.DurationField() + + # Choice types... class TestChoiceField(FieldValues): diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py index dc34649ea..a94133823 100644 --- a/tests/test_model_serializer.py +++ b/tests/test_model_serializer.py @@ -6,13 +6,15 @@ These tests deal with ensuring that we correctly map the model fields onto an appropriate set of serializer fields for each case. """ from __future__ import unicode_literals +import django from django.core.exceptions import ImproperlyConfigured from django.core.validators import MaxValueValidator, MinValueValidator, MinLengthValidator from django.db import models from django.test import TestCase from django.utils import six +import pytest from rest_framework import serializers -from rest_framework.compat import unicode_repr +from rest_framework.compat import unicode_repr, DurationField as ModelDurationField def dedent(blocktext): @@ -284,6 +286,28 @@ class TestRegularFieldMappings(TestCase): ChildSerializer().fields +@pytest.mark.skipif(django.VERSION < (1, 8), + reason='DurationField is only available for django1.8+') +class TestDurationFieldMapping(TestCase): + def test_duration_field(self): + class DurationFieldModel(models.Model): + """ + A model that defines DurationField. + """ + duration_field = ModelDurationField() + + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = DurationFieldModel + + expected = dedent(""" + TestSerializer(): + id = IntegerField(label='ID', read_only=True) + duration_field = DurationField() + """) + self.assertEqual(unicode_repr(TestSerializer()), expected) + + # Tests for relational field mappings. # ------------------------------------ From 4ad8c17371e25acbdce4e2f449efccc4df072270 Mon Sep 17 00:00:00 2001 From: Xavier Ordoquy Date: Mon, 1 Jun 2015 18:13:08 +0100 Subject: [PATCH 10/16] Add a warning about totally custom login views. --- docs/api-guide/authentication.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/api-guide/authentication.md b/docs/api-guide/authentication.md index 54f6ac030..2ccf7d721 100644 --- a/docs/api-guide/authentication.md +++ b/docs/api-guide/authentication.md @@ -247,6 +247,10 @@ Unauthenticated responses that are denied permission will result in an `HTTP 403 If you're using an AJAX style API with SessionAuthentication, you'll need to make sure you include a valid CSRF token for any "unsafe" HTTP method calls, such as `PUT`, `PATCH`, `POST` or `DELETE` requests. See the [Django CSRF documentation][csrf-ajax] for more details. +**Warning**: Always use Django's standard login view when creating login pages. This will ensure your login views are properly protected. + +CSRF validation in REST framework works slightly differently to standard Django due to the need to support both session and non-session based authentication to the same views. This means that only authenticated requests require CSRF tokens, and anonymous requests may be sent without CSRF tokens. This behaviour is not suitable for login views, which should always have CSRF validation applied. + # Custom authentication To implement a custom authentication scheme, subclass `BaseAuthentication` and override the `.authenticate(self, request)` method. The method should return a two-tuple of `(user, auth)` if authentication succeeds, or `None` otherwise. From 3f3a01b4b82dcd70ca8ca2100fe16bc8074ba873 Mon Sep 17 00:00:00 2001 From: Nicolas Delaby Date: Tue, 2 Jun 2015 09:18:24 +0200 Subject: [PATCH 11/16] fix Typo --- docs/api-guide/fields.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/api-guide/fields.md b/docs/api-guide/fields.md index aad188511..a6743c5a5 100644 --- a/docs/api-guide/fields.md +++ b/docs/api-guide/fields.md @@ -305,7 +305,7 @@ Format strings may either be [Python strftime formats][strftime] which explicitl ## DurationField A Duration representation. -Corresponds to `django.db.models.fields.Duration` +Corresponds to `django.db.models.fields.DurationField` The `validated_data` for these fields will contain a `datetime.timedelta` instance. The representation is a string following this format `'[DD] [HH:[MM:]]ss[.uuuuuu]'`. From 27fd48586eec76118b60fddb9d0b28c7343446c9 Mon Sep 17 00:00:00 2001 From: Nicolas Delaby Date: Wed, 22 Apr 2015 11:33:01 +0200 Subject: [PATCH 12/16] allow to pass arbitrary arguments to py.test --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index f91f8b3ff..e240275f0 100644 --- a/tox.ini +++ b/tox.ini @@ -6,7 +6,7 @@ envlist = {py27,py32,py33,py34}-django{17,18,master} [testenv] -commands = ./runtests.py --fast +commands = ./runtests.py --fast {posargs} setenv = PYTHONDONTWRITEBYTECODE=1 deps = From c2d24172372385047691842219447ad55d2ca0c9 Mon Sep 17 00:00:00 2001 From: Nicolas Delaby Date: Wed, 29 Apr 2015 15:08:52 +0200 Subject: [PATCH 13/16] Tell default error handler to doom the transaction on error if `ATOMIC_REQUESTS` is enabled. --- rest_framework/compat.py | 17 ++++++ rest_framework/views.py | 7 ++- tests/test_atomic_requests.py | 105 ++++++++++++++++++++++++++++++++++ 3 files changed, 128 insertions(+), 1 deletion(-) create mode 100644 tests/test_atomic_requests.py diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 8d6151fa2..139d085d9 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -7,6 +7,7 @@ versions of django/python, and compatibility wrappers around optional packages. from __future__ import unicode_literals from django.core.exceptions import ImproperlyConfigured from django.conf import settings +from django.db import connection, transaction from django.utils.encoding import force_text from django.utils.six.moves.urllib.parse import urlparse as _urlparse from django.utils import six @@ -266,3 +267,19 @@ if django.VERSION >= (1, 8): from django.utils.duration import duration_string else: DurationField = duration_string = parse_duration = None + + +def set_rollback(): + if hasattr(transaction, 'set_rollback'): + if connection.settings_dict.get('ATOMIC_REQUESTS', False): + # If running in >=1.6 then mark a rollback as required, + # and allow it to be handled by Django. + transaction.set_rollback(True) + elif transaction.is_managed(): + # Otherwise handle it explicitly if in managed mode. + if transaction.is_dirty(): + transaction.rollback() + transaction.leave_transaction_management() + else: + # transaction not managed + pass diff --git a/rest_framework/views.py b/rest_framework/views.py index f0aadc0e5..ce2e74b38 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -9,7 +9,7 @@ from django.utils.encoding import smart_text from django.utils.translation import ugettext_lazy as _ from django.views.decorators.csrf import csrf_exempt from rest_framework import status, exceptions -from rest_framework.compat import HttpResponseBase, View +from rest_framework.compat import HttpResponseBase, View, set_rollback from rest_framework.request import Request from rest_framework.response import Response from rest_framework.settings import api_settings @@ -71,16 +71,21 @@ def exception_handler(exc, context): else: data = {'detail': exc.detail} + set_rollback() return Response(data, status=exc.status_code, headers=headers) elif isinstance(exc, Http404): msg = _('Not found.') data = {'detail': six.text_type(msg)} + + set_rollback() return Response(data, status=status.HTTP_404_NOT_FOUND) elif isinstance(exc, PermissionDenied): msg = _('Permission denied.') data = {'detail': six.text_type(msg)} + + set_rollback() return Response(data, status=status.HTTP_403_FORBIDDEN) # Note: Unhandled exceptions will raise a 500 error. diff --git a/tests/test_atomic_requests.py b/tests/test_atomic_requests.py new file mode 100644 index 000000000..4e55b650b --- /dev/null +++ b/tests/test_atomic_requests.py @@ -0,0 +1,105 @@ +from __future__ import unicode_literals + +from django.db import connection, connections, transaction +from django.test import TestCase +from django.utils.unittest import skipUnless +from rest_framework import status +from rest_framework.exceptions import APIException +from rest_framework.response import Response +from rest_framework.test import APIRequestFactory +from rest_framework.views import APIView +from tests.models import BasicModel + + +factory = APIRequestFactory() + + +class BasicView(APIView): + def get(self, request, *args, **kwargs): + BasicModel.objects.create() + return Response({'method': 'GET'}) + + +class ErrorView(APIView): + def get(self, request, *args, **kwargs): + BasicModel.objects.create() + raise Exception + + +class APIExceptionView(APIView): + def get(self, request, *args, **kwargs): + BasicModel.objects.create() + raise APIException + + +@skipUnless(connection.features.uses_savepoints, + "'atomic' requires transactions and savepoints.") +class DBTransactionTests(TestCase): + def setUp(self): + self.view = BasicView.as_view() + connections.databases['default']['ATOMIC_REQUESTS'] = True + + def tearDown(self): + connections.databases['default']['ATOMIC_REQUESTS'] = False + + def test_no_exception_conmmit_transaction(self): + request = factory.get('/') + + with self.assertNumQueries(1): + response = self.view(request) + self.assertFalse(transaction.get_rollback()) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + +@skipUnless(connection.features.uses_savepoints, + "'atomic' requires transactions and savepoints.") +class DBTransactionErrorTests(TestCase): + def setUp(self): + self.view = ErrorView.as_view() + connections.databases['default']['ATOMIC_REQUESTS'] = True + + def tearDown(self): + connections.databases['default']['ATOMIC_REQUESTS'] = False + + def test_error_rollback_transaction(self): + """ + Transaction is eventually managed by outer-most transaction atomic + block. DRF do not try to interfere here. + """ + request = factory.get('/') + with self.assertNumQueries(3): + # 1 - begin savepoint + # 2 - insert + # 3 - release savepoint + with transaction.atomic(): + self.assertRaises(Exception, self.view, request) + self.assertFalse(transaction.get_rollback()) + + +@skipUnless(connection.features.uses_savepoints, + "'atomic' requires transactions and savepoints.") +class DBTransactionAPIExceptionTests(TestCase): + def setUp(self): + self.view = APIExceptionView.as_view() + connections.databases['default']['ATOMIC_REQUESTS'] = True + + def tearDown(self): + connections.databases['default']['ATOMIC_REQUESTS'] = False + + def test_api_exception_rollback_transaction(self): + """ + Transaction is rollbacked by our transaction atomic block. + """ + request = factory.get('/') + num_queries = (4 if getattr(connection.features, + 'can_release_savepoints', False) else 3) + with self.assertNumQueries(num_queries): + # 1 - begin savepoint + # 2 - insert + # 3 - rollback savepoint + # 4 - release savepoint (django>=1.8 only) + with transaction.atomic(): + response = self.view(request) + self.assertTrue(transaction.get_rollback()) + self.assertEqual(response.status_code, + status.HTTP_500_INTERNAL_SERVER_ERROR) From d1371cc949afcc66c7e7f497bab62ec655cddf31 Mon Sep 17 00:00:00 2001 From: Nicolas Delaby Date: Wed, 29 Apr 2015 16:28:22 +0200 Subject: [PATCH 14/16] Use post instead of get for sanity of use-case. --- tests/test_atomic_requests.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_atomic_requests.py b/tests/test_atomic_requests.py index 4e55b650b..b3bace3bb 100644 --- a/tests/test_atomic_requests.py +++ b/tests/test_atomic_requests.py @@ -15,19 +15,19 @@ factory = APIRequestFactory() class BasicView(APIView): - def get(self, request, *args, **kwargs): + def post(self, request, *args, **kwargs): BasicModel.objects.create() return Response({'method': 'GET'}) class ErrorView(APIView): - def get(self, request, *args, **kwargs): + def post(self, request, *args, **kwargs): BasicModel.objects.create() raise Exception class APIExceptionView(APIView): - def get(self, request, *args, **kwargs): + def post(self, request, *args, **kwargs): BasicModel.objects.create() raise APIException @@ -43,7 +43,7 @@ class DBTransactionTests(TestCase): connections.databases['default']['ATOMIC_REQUESTS'] = False def test_no_exception_conmmit_transaction(self): - request = factory.get('/') + request = factory.post('/') with self.assertNumQueries(1): response = self.view(request) @@ -66,7 +66,7 @@ class DBTransactionErrorTests(TestCase): Transaction is eventually managed by outer-most transaction atomic block. DRF do not try to interfere here. """ - request = factory.get('/') + request = factory.post('/') with self.assertNumQueries(3): # 1 - begin savepoint # 2 - insert @@ -90,7 +90,7 @@ class DBTransactionAPIExceptionTests(TestCase): """ Transaction is rollbacked by our transaction atomic block. """ - request = factory.get('/') + request = factory.post('/') num_queries = (4 if getattr(connection.features, 'can_release_savepoints', False) else 3) with self.assertNumQueries(num_queries): From 8ad38208a183343bd1bd2b499966dc98edc2863b Mon Sep 17 00:00:00 2001 From: Nicolas Delaby Date: Wed, 29 Apr 2015 16:28:48 +0200 Subject: [PATCH 15/16] more assertions make the test more readable --- tests/test_atomic_requests.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_atomic_requests.py b/tests/test_atomic_requests.py index b3bace3bb..09f3742ad 100644 --- a/tests/test_atomic_requests.py +++ b/tests/test_atomic_requests.py @@ -49,6 +49,7 @@ class DBTransactionTests(TestCase): response = self.view(request) self.assertFalse(transaction.get_rollback()) self.assertEqual(response.status_code, status.HTTP_200_OK) + assert BasicModel.objects.count() == 1 @skipUnless(connection.features.uses_savepoints, @@ -74,6 +75,7 @@ class DBTransactionErrorTests(TestCase): with transaction.atomic(): self.assertRaises(Exception, self.view, request) self.assertFalse(transaction.get_rollback()) + assert BasicModel.objects.count() == 1 @skipUnless(connection.features.uses_savepoints, @@ -103,3 +105,4 @@ class DBTransactionAPIExceptionTests(TestCase): self.assertTrue(transaction.get_rollback()) self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) + assert BasicModel.objects.count() == 0 From 34dc98e8ad7ff82f82e58c6bf2170bacfdb449c7 Mon Sep 17 00:00:00 2001 From: Nicolas Delaby Date: Wed, 29 Apr 2015 16:29:09 +0200 Subject: [PATCH 16/16] improve wording --- tests/test_atomic_requests.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_atomic_requests.py b/tests/test_atomic_requests.py index 09f3742ad..9410fea5e 100644 --- a/tests/test_atomic_requests.py +++ b/tests/test_atomic_requests.py @@ -62,10 +62,12 @@ class DBTransactionErrorTests(TestCase): def tearDown(self): connections.databases['default']['ATOMIC_REQUESTS'] = False - def test_error_rollback_transaction(self): + def test_generic_exception_delegate_transaction_management(self): """ Transaction is eventually managed by outer-most transaction atomic block. DRF do not try to interfere here. + + We let django deal with the transaction when it will catch the Exception. """ request = factory.post('/') with self.assertNumQueries(3):