diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 8d6151fa2..139d085d9 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -7,6 +7,7 @@ versions of django/python, and compatibility wrappers around optional packages. from __future__ import unicode_literals from django.core.exceptions import ImproperlyConfigured from django.conf import settings +from django.db import connection, transaction from django.utils.encoding import force_text from django.utils.six.moves.urllib.parse import urlparse as _urlparse from django.utils import six @@ -266,3 +267,19 @@ if django.VERSION >= (1, 8): from django.utils.duration import duration_string else: DurationField = duration_string = parse_duration = None + + +def set_rollback(): + if hasattr(transaction, 'set_rollback'): + if connection.settings_dict.get('ATOMIC_REQUESTS', False): + # If running in >=1.6 then mark a rollback as required, + # and allow it to be handled by Django. + transaction.set_rollback(True) + elif transaction.is_managed(): + # Otherwise handle it explicitly if in managed mode. + if transaction.is_dirty(): + transaction.rollback() + transaction.leave_transaction_management() + else: + # transaction not managed + pass diff --git a/rest_framework/views.py b/rest_framework/views.py index f0aadc0e5..ce2e74b38 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -9,7 +9,7 @@ from django.utils.encoding import smart_text from django.utils.translation import ugettext_lazy as _ from django.views.decorators.csrf import csrf_exempt from rest_framework import status, exceptions -from rest_framework.compat import HttpResponseBase, View +from rest_framework.compat import HttpResponseBase, View, set_rollback from rest_framework.request import Request from rest_framework.response import Response from rest_framework.settings import api_settings @@ -71,16 +71,21 @@ def exception_handler(exc, context): else: data = {'detail': exc.detail} + set_rollback() return Response(data, status=exc.status_code, headers=headers) elif isinstance(exc, Http404): msg = _('Not found.') data = {'detail': six.text_type(msg)} + + set_rollback() return Response(data, status=status.HTTP_404_NOT_FOUND) elif isinstance(exc, PermissionDenied): msg = _('Permission denied.') data = {'detail': six.text_type(msg)} + + set_rollback() return Response(data, status=status.HTTP_403_FORBIDDEN) # Note: Unhandled exceptions will raise a 500 error. diff --git a/tests/test_atomic_requests.py b/tests/test_atomic_requests.py new file mode 100644 index 000000000..4e55b650b --- /dev/null +++ b/tests/test_atomic_requests.py @@ -0,0 +1,105 @@ +from __future__ import unicode_literals + +from django.db import connection, connections, transaction +from django.test import TestCase +from django.utils.unittest import skipUnless +from rest_framework import status +from rest_framework.exceptions import APIException +from rest_framework.response import Response +from rest_framework.test import APIRequestFactory +from rest_framework.views import APIView +from tests.models import BasicModel + + +factory = APIRequestFactory() + + +class BasicView(APIView): + def get(self, request, *args, **kwargs): + BasicModel.objects.create() + return Response({'method': 'GET'}) + + +class ErrorView(APIView): + def get(self, request, *args, **kwargs): + BasicModel.objects.create() + raise Exception + + +class APIExceptionView(APIView): + def get(self, request, *args, **kwargs): + BasicModel.objects.create() + raise APIException + + +@skipUnless(connection.features.uses_savepoints, + "'atomic' requires transactions and savepoints.") +class DBTransactionTests(TestCase): + def setUp(self): + self.view = BasicView.as_view() + connections.databases['default']['ATOMIC_REQUESTS'] = True + + def tearDown(self): + connections.databases['default']['ATOMIC_REQUESTS'] = False + + def test_no_exception_conmmit_transaction(self): + request = factory.get('/') + + with self.assertNumQueries(1): + response = self.view(request) + self.assertFalse(transaction.get_rollback()) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + +@skipUnless(connection.features.uses_savepoints, + "'atomic' requires transactions and savepoints.") +class DBTransactionErrorTests(TestCase): + def setUp(self): + self.view = ErrorView.as_view() + connections.databases['default']['ATOMIC_REQUESTS'] = True + + def tearDown(self): + connections.databases['default']['ATOMIC_REQUESTS'] = False + + def test_error_rollback_transaction(self): + """ + Transaction is eventually managed by outer-most transaction atomic + block. DRF do not try to interfere here. + """ + request = factory.get('/') + with self.assertNumQueries(3): + # 1 - begin savepoint + # 2 - insert + # 3 - release savepoint + with transaction.atomic(): + self.assertRaises(Exception, self.view, request) + self.assertFalse(transaction.get_rollback()) + + +@skipUnless(connection.features.uses_savepoints, + "'atomic' requires transactions and savepoints.") +class DBTransactionAPIExceptionTests(TestCase): + def setUp(self): + self.view = APIExceptionView.as_view() + connections.databases['default']['ATOMIC_REQUESTS'] = True + + def tearDown(self): + connections.databases['default']['ATOMIC_REQUESTS'] = False + + def test_api_exception_rollback_transaction(self): + """ + Transaction is rollbacked by our transaction atomic block. + """ + request = factory.get('/') + num_queries = (4 if getattr(connection.features, + 'can_release_savepoints', False) else 3) + with self.assertNumQueries(num_queries): + # 1 - begin savepoint + # 2 - insert + # 3 - rollback savepoint + # 4 - release savepoint (django>=1.8 only) + with transaction.atomic(): + response = self.view(request) + self.assertTrue(transaction.get_rollback()) + self.assertEqual(response.status_code, + status.HTTP_500_INTERNAL_SERVER_ERROR)