From 27fd48586eec76118b60fddb9d0b28c7343446c9 Mon Sep 17 00:00:00 2001 From: Nicolas Delaby Date: Wed, 22 Apr 2015 11:33:01 +0200 Subject: [PATCH 1/5] allow to pass arbitrary arguments to py.test --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index f91f8b3ff..e240275f0 100644 --- a/tox.ini +++ b/tox.ini @@ -6,7 +6,7 @@ envlist = {py27,py32,py33,py34}-django{17,18,master} [testenv] -commands = ./runtests.py --fast +commands = ./runtests.py --fast {posargs} setenv = PYTHONDONTWRITEBYTECODE=1 deps = From c2d24172372385047691842219447ad55d2ca0c9 Mon Sep 17 00:00:00 2001 From: Nicolas Delaby Date: Wed, 29 Apr 2015 15:08:52 +0200 Subject: [PATCH 2/5] Tell default error handler to doom the transaction on error if `ATOMIC_REQUESTS` is enabled. --- rest_framework/compat.py | 17 ++++++ rest_framework/views.py | 7 ++- tests/test_atomic_requests.py | 105 ++++++++++++++++++++++++++++++++++ 3 files changed, 128 insertions(+), 1 deletion(-) create mode 100644 tests/test_atomic_requests.py diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 8d6151fa2..139d085d9 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -7,6 +7,7 @@ versions of django/python, and compatibility wrappers around optional packages. from __future__ import unicode_literals from django.core.exceptions import ImproperlyConfigured from django.conf import settings +from django.db import connection, transaction from django.utils.encoding import force_text from django.utils.six.moves.urllib.parse import urlparse as _urlparse from django.utils import six @@ -266,3 +267,19 @@ if django.VERSION >= (1, 8): from django.utils.duration import duration_string else: DurationField = duration_string = parse_duration = None + + +def set_rollback(): + if hasattr(transaction, 'set_rollback'): + if connection.settings_dict.get('ATOMIC_REQUESTS', False): + # If running in >=1.6 then mark a rollback as required, + # and allow it to be handled by Django. + transaction.set_rollback(True) + elif transaction.is_managed(): + # Otherwise handle it explicitly if in managed mode. + if transaction.is_dirty(): + transaction.rollback() + transaction.leave_transaction_management() + else: + # transaction not managed + pass diff --git a/rest_framework/views.py b/rest_framework/views.py index f0aadc0e5..ce2e74b38 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -9,7 +9,7 @@ from django.utils.encoding import smart_text from django.utils.translation import ugettext_lazy as _ from django.views.decorators.csrf import csrf_exempt from rest_framework import status, exceptions -from rest_framework.compat import HttpResponseBase, View +from rest_framework.compat import HttpResponseBase, View, set_rollback from rest_framework.request import Request from rest_framework.response import Response from rest_framework.settings import api_settings @@ -71,16 +71,21 @@ def exception_handler(exc, context): else: data = {'detail': exc.detail} + set_rollback() return Response(data, status=exc.status_code, headers=headers) elif isinstance(exc, Http404): msg = _('Not found.') data = {'detail': six.text_type(msg)} + + set_rollback() return Response(data, status=status.HTTP_404_NOT_FOUND) elif isinstance(exc, PermissionDenied): msg = _('Permission denied.') data = {'detail': six.text_type(msg)} + + set_rollback() return Response(data, status=status.HTTP_403_FORBIDDEN) # Note: Unhandled exceptions will raise a 500 error. diff --git a/tests/test_atomic_requests.py b/tests/test_atomic_requests.py new file mode 100644 index 000000000..4e55b650b --- /dev/null +++ b/tests/test_atomic_requests.py @@ -0,0 +1,105 @@ +from __future__ import unicode_literals + +from django.db import connection, connections, transaction +from django.test import TestCase +from django.utils.unittest import skipUnless +from rest_framework import status +from rest_framework.exceptions import APIException +from rest_framework.response import Response +from rest_framework.test import APIRequestFactory +from rest_framework.views import APIView +from tests.models import BasicModel + + +factory = APIRequestFactory() + + +class BasicView(APIView): + def get(self, request, *args, **kwargs): + BasicModel.objects.create() + return Response({'method': 'GET'}) + + +class ErrorView(APIView): + def get(self, request, *args, **kwargs): + BasicModel.objects.create() + raise Exception + + +class APIExceptionView(APIView): + def get(self, request, *args, **kwargs): + BasicModel.objects.create() + raise APIException + + +@skipUnless(connection.features.uses_savepoints, + "'atomic' requires transactions and savepoints.") +class DBTransactionTests(TestCase): + def setUp(self): + self.view = BasicView.as_view() + connections.databases['default']['ATOMIC_REQUESTS'] = True + + def tearDown(self): + connections.databases['default']['ATOMIC_REQUESTS'] = False + + def test_no_exception_conmmit_transaction(self): + request = factory.get('/') + + with self.assertNumQueries(1): + response = self.view(request) + self.assertFalse(transaction.get_rollback()) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + +@skipUnless(connection.features.uses_savepoints, + "'atomic' requires transactions and savepoints.") +class DBTransactionErrorTests(TestCase): + def setUp(self): + self.view = ErrorView.as_view() + connections.databases['default']['ATOMIC_REQUESTS'] = True + + def tearDown(self): + connections.databases['default']['ATOMIC_REQUESTS'] = False + + def test_error_rollback_transaction(self): + """ + Transaction is eventually managed by outer-most transaction atomic + block. DRF do not try to interfere here. + """ + request = factory.get('/') + with self.assertNumQueries(3): + # 1 - begin savepoint + # 2 - insert + # 3 - release savepoint + with transaction.atomic(): + self.assertRaises(Exception, self.view, request) + self.assertFalse(transaction.get_rollback()) + + +@skipUnless(connection.features.uses_savepoints, + "'atomic' requires transactions and savepoints.") +class DBTransactionAPIExceptionTests(TestCase): + def setUp(self): + self.view = APIExceptionView.as_view() + connections.databases['default']['ATOMIC_REQUESTS'] = True + + def tearDown(self): + connections.databases['default']['ATOMIC_REQUESTS'] = False + + def test_api_exception_rollback_transaction(self): + """ + Transaction is rollbacked by our transaction atomic block. + """ + request = factory.get('/') + num_queries = (4 if getattr(connection.features, + 'can_release_savepoints', False) else 3) + with self.assertNumQueries(num_queries): + # 1 - begin savepoint + # 2 - insert + # 3 - rollback savepoint + # 4 - release savepoint (django>=1.8 only) + with transaction.atomic(): + response = self.view(request) + self.assertTrue(transaction.get_rollback()) + self.assertEqual(response.status_code, + status.HTTP_500_INTERNAL_SERVER_ERROR) From d1371cc949afcc66c7e7f497bab62ec655cddf31 Mon Sep 17 00:00:00 2001 From: Nicolas Delaby Date: Wed, 29 Apr 2015 16:28:22 +0200 Subject: [PATCH 3/5] Use post instead of get for sanity of use-case. --- tests/test_atomic_requests.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_atomic_requests.py b/tests/test_atomic_requests.py index 4e55b650b..b3bace3bb 100644 --- a/tests/test_atomic_requests.py +++ b/tests/test_atomic_requests.py @@ -15,19 +15,19 @@ factory = APIRequestFactory() class BasicView(APIView): - def get(self, request, *args, **kwargs): + def post(self, request, *args, **kwargs): BasicModel.objects.create() return Response({'method': 'GET'}) class ErrorView(APIView): - def get(self, request, *args, **kwargs): + def post(self, request, *args, **kwargs): BasicModel.objects.create() raise Exception class APIExceptionView(APIView): - def get(self, request, *args, **kwargs): + def post(self, request, *args, **kwargs): BasicModel.objects.create() raise APIException @@ -43,7 +43,7 @@ class DBTransactionTests(TestCase): connections.databases['default']['ATOMIC_REQUESTS'] = False def test_no_exception_conmmit_transaction(self): - request = factory.get('/') + request = factory.post('/') with self.assertNumQueries(1): response = self.view(request) @@ -66,7 +66,7 @@ class DBTransactionErrorTests(TestCase): Transaction is eventually managed by outer-most transaction atomic block. DRF do not try to interfere here. """ - request = factory.get('/') + request = factory.post('/') with self.assertNumQueries(3): # 1 - begin savepoint # 2 - insert @@ -90,7 +90,7 @@ class DBTransactionAPIExceptionTests(TestCase): """ Transaction is rollbacked by our transaction atomic block. """ - request = factory.get('/') + request = factory.post('/') num_queries = (4 if getattr(connection.features, 'can_release_savepoints', False) else 3) with self.assertNumQueries(num_queries): From 8ad38208a183343bd1bd2b499966dc98edc2863b Mon Sep 17 00:00:00 2001 From: Nicolas Delaby Date: Wed, 29 Apr 2015 16:28:48 +0200 Subject: [PATCH 4/5] more assertions make the test more readable --- tests/test_atomic_requests.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_atomic_requests.py b/tests/test_atomic_requests.py index b3bace3bb..09f3742ad 100644 --- a/tests/test_atomic_requests.py +++ b/tests/test_atomic_requests.py @@ -49,6 +49,7 @@ class DBTransactionTests(TestCase): response = self.view(request) self.assertFalse(transaction.get_rollback()) self.assertEqual(response.status_code, status.HTTP_200_OK) + assert BasicModel.objects.count() == 1 @skipUnless(connection.features.uses_savepoints, @@ -74,6 +75,7 @@ class DBTransactionErrorTests(TestCase): with transaction.atomic(): self.assertRaises(Exception, self.view, request) self.assertFalse(transaction.get_rollback()) + assert BasicModel.objects.count() == 1 @skipUnless(connection.features.uses_savepoints, @@ -103,3 +105,4 @@ class DBTransactionAPIExceptionTests(TestCase): self.assertTrue(transaction.get_rollback()) self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) + assert BasicModel.objects.count() == 0 From 34dc98e8ad7ff82f82e58c6bf2170bacfdb449c7 Mon Sep 17 00:00:00 2001 From: Nicolas Delaby Date: Wed, 29 Apr 2015 16:29:09 +0200 Subject: [PATCH 5/5] improve wording --- tests/test_atomic_requests.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_atomic_requests.py b/tests/test_atomic_requests.py index 09f3742ad..9410fea5e 100644 --- a/tests/test_atomic_requests.py +++ b/tests/test_atomic_requests.py @@ -62,10 +62,12 @@ class DBTransactionErrorTests(TestCase): def tearDown(self): connections.databases['default']['ATOMIC_REQUESTS'] = False - def test_error_rollback_transaction(self): + def test_generic_exception_delegate_transaction_management(self): """ Transaction is eventually managed by outer-most transaction atomic block. DRF do not try to interfere here. + + We let django deal with the transaction when it will catch the Exception. """ request = factory.post('/') with self.assertNumQueries(3):