mirror of
				https://github.com/encode/django-rest-framework.git
				synced 2025-11-04 01:47:59 +03:00 
			
		
		
		
	Merge pull request #2887 from ticosax/doom-transaction-on-error
Rollback the transaction on error if ATOMIC_REQUESTS is set.
This commit is contained in:
		
						commit
						1957679368
					
				| 
						 | 
				
			
			@ -7,6 +7,7 @@ versions of django/python, and compatibility wrappers around optional packages.
 | 
			
		|||
from __future__ import unicode_literals
 | 
			
		||||
from django.core.exceptions import ImproperlyConfigured
 | 
			
		||||
from django.conf import settings
 | 
			
		||||
from django.db import connection, transaction
 | 
			
		||||
from django.utils.encoding import force_text
 | 
			
		||||
from django.utils.six.moves.urllib.parse import urlparse as _urlparse
 | 
			
		||||
from django.utils import six
 | 
			
		||||
| 
						 | 
				
			
			@ -266,3 +267,19 @@ if django.VERSION >= (1, 8):
 | 
			
		|||
    from django.utils.duration import duration_string
 | 
			
		||||
else:
 | 
			
		||||
    DurationField = duration_string = parse_duration = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def set_rollback():
 | 
			
		||||
    if hasattr(transaction, 'set_rollback'):
 | 
			
		||||
        if connection.settings_dict.get('ATOMIC_REQUESTS', False):
 | 
			
		||||
            # If running in >=1.6 then mark a rollback as required,
 | 
			
		||||
            # and allow it to be handled by Django.
 | 
			
		||||
            transaction.set_rollback(True)
 | 
			
		||||
    elif transaction.is_managed():
 | 
			
		||||
        # Otherwise handle it explicitly if in managed mode.
 | 
			
		||||
        if transaction.is_dirty():
 | 
			
		||||
            transaction.rollback()
 | 
			
		||||
        transaction.leave_transaction_management()
 | 
			
		||||
    else:
 | 
			
		||||
        # transaction not managed
 | 
			
		||||
        pass
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -9,7 +9,7 @@ from django.utils.encoding import smart_text
 | 
			
		|||
from django.utils.translation import ugettext_lazy as _
 | 
			
		||||
from django.views.decorators.csrf import csrf_exempt
 | 
			
		||||
from rest_framework import status, exceptions
 | 
			
		||||
from rest_framework.compat import HttpResponseBase, View
 | 
			
		||||
from rest_framework.compat import HttpResponseBase, View, set_rollback
 | 
			
		||||
from rest_framework.request import Request
 | 
			
		||||
from rest_framework.response import Response
 | 
			
		||||
from rest_framework.settings import api_settings
 | 
			
		||||
| 
						 | 
				
			
			@ -71,16 +71,21 @@ def exception_handler(exc, context):
 | 
			
		|||
        else:
 | 
			
		||||
            data = {'detail': exc.detail}
 | 
			
		||||
 | 
			
		||||
        set_rollback()
 | 
			
		||||
        return Response(data, status=exc.status_code, headers=headers)
 | 
			
		||||
 | 
			
		||||
    elif isinstance(exc, Http404):
 | 
			
		||||
        msg = _('Not found.')
 | 
			
		||||
        data = {'detail': six.text_type(msg)}
 | 
			
		||||
 | 
			
		||||
        set_rollback()
 | 
			
		||||
        return Response(data, status=status.HTTP_404_NOT_FOUND)
 | 
			
		||||
 | 
			
		||||
    elif isinstance(exc, PermissionDenied):
 | 
			
		||||
        msg = _('Permission denied.')
 | 
			
		||||
        data = {'detail': six.text_type(msg)}
 | 
			
		||||
 | 
			
		||||
        set_rollback()
 | 
			
		||||
        return Response(data, status=status.HTTP_403_FORBIDDEN)
 | 
			
		||||
 | 
			
		||||
    # Note: Unhandled exceptions will raise a 500 error.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										110
									
								
								tests/test_atomic_requests.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										110
									
								
								tests/test_atomic_requests.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,110 @@
 | 
			
		|||
from __future__ import unicode_literals
 | 
			
		||||
 | 
			
		||||
from django.db import connection, connections, transaction
 | 
			
		||||
from django.test import TestCase
 | 
			
		||||
from django.utils.unittest import skipUnless
 | 
			
		||||
from rest_framework import status
 | 
			
		||||
from rest_framework.exceptions import APIException
 | 
			
		||||
from rest_framework.response import Response
 | 
			
		||||
from rest_framework.test import APIRequestFactory
 | 
			
		||||
from rest_framework.views import APIView
 | 
			
		||||
from tests.models import BasicModel
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
factory = APIRequestFactory()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BasicView(APIView):
 | 
			
		||||
    def post(self, request, *args, **kwargs):
 | 
			
		||||
        BasicModel.objects.create()
 | 
			
		||||
        return Response({'method': 'GET'})
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ErrorView(APIView):
 | 
			
		||||
    def post(self, request, *args, **kwargs):
 | 
			
		||||
        BasicModel.objects.create()
 | 
			
		||||
        raise Exception
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class APIExceptionView(APIView):
 | 
			
		||||
    def post(self, request, *args, **kwargs):
 | 
			
		||||
        BasicModel.objects.create()
 | 
			
		||||
        raise APIException
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@skipUnless(connection.features.uses_savepoints,
 | 
			
		||||
            "'atomic' requires transactions and savepoints.")
 | 
			
		||||
class DBTransactionTests(TestCase):
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        self.view = BasicView.as_view()
 | 
			
		||||
        connections.databases['default']['ATOMIC_REQUESTS'] = True
 | 
			
		||||
 | 
			
		||||
    def tearDown(self):
 | 
			
		||||
        connections.databases['default']['ATOMIC_REQUESTS'] = False
 | 
			
		||||
 | 
			
		||||
    def test_no_exception_conmmit_transaction(self):
 | 
			
		||||
        request = factory.post('/')
 | 
			
		||||
 | 
			
		||||
        with self.assertNumQueries(1):
 | 
			
		||||
            response = self.view(request)
 | 
			
		||||
        self.assertFalse(transaction.get_rollback())
 | 
			
		||||
        self.assertEqual(response.status_code, status.HTTP_200_OK)
 | 
			
		||||
        assert BasicModel.objects.count() == 1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@skipUnless(connection.features.uses_savepoints,
 | 
			
		||||
            "'atomic' requires transactions and savepoints.")
 | 
			
		||||
class DBTransactionErrorTests(TestCase):
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        self.view = ErrorView.as_view()
 | 
			
		||||
        connections.databases['default']['ATOMIC_REQUESTS'] = True
 | 
			
		||||
 | 
			
		||||
    def tearDown(self):
 | 
			
		||||
        connections.databases['default']['ATOMIC_REQUESTS'] = False
 | 
			
		||||
 | 
			
		||||
    def test_generic_exception_delegate_transaction_management(self):
 | 
			
		||||
        """
 | 
			
		||||
        Transaction is eventually managed by outer-most transaction atomic
 | 
			
		||||
        block. DRF do not try to interfere here.
 | 
			
		||||
 | 
			
		||||
        We let django deal with the transaction when it will catch the Exception.
 | 
			
		||||
        """
 | 
			
		||||
        request = factory.post('/')
 | 
			
		||||
        with self.assertNumQueries(3):
 | 
			
		||||
            # 1 - begin savepoint
 | 
			
		||||
            # 2 - insert
 | 
			
		||||
            # 3 - release savepoint
 | 
			
		||||
            with transaction.atomic():
 | 
			
		||||
                self.assertRaises(Exception, self.view, request)
 | 
			
		||||
                self.assertFalse(transaction.get_rollback())
 | 
			
		||||
        assert BasicModel.objects.count() == 1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@skipUnless(connection.features.uses_savepoints,
 | 
			
		||||
            "'atomic' requires transactions and savepoints.")
 | 
			
		||||
class DBTransactionAPIExceptionTests(TestCase):
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        self.view = APIExceptionView.as_view()
 | 
			
		||||
        connections.databases['default']['ATOMIC_REQUESTS'] = True
 | 
			
		||||
 | 
			
		||||
    def tearDown(self):
 | 
			
		||||
        connections.databases['default']['ATOMIC_REQUESTS'] = False
 | 
			
		||||
 | 
			
		||||
    def test_api_exception_rollback_transaction(self):
 | 
			
		||||
        """
 | 
			
		||||
        Transaction is rollbacked by our transaction atomic block.
 | 
			
		||||
        """
 | 
			
		||||
        request = factory.post('/')
 | 
			
		||||
        num_queries = (4 if getattr(connection.features,
 | 
			
		||||
                                    'can_release_savepoints', False) else 3)
 | 
			
		||||
        with self.assertNumQueries(num_queries):
 | 
			
		||||
            # 1 - begin savepoint
 | 
			
		||||
            # 2 - insert
 | 
			
		||||
            # 3 - rollback savepoint
 | 
			
		||||
            # 4 - release savepoint (django>=1.8 only)
 | 
			
		||||
            with transaction.atomic():
 | 
			
		||||
                response = self.view(request)
 | 
			
		||||
                self.assertTrue(transaction.get_rollback())
 | 
			
		||||
        self.assertEqual(response.status_code,
 | 
			
		||||
                         status.HTTP_500_INTERNAL_SERVER_ERROR)
 | 
			
		||||
        assert BasicModel.objects.count() == 0
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user