mirror of
				https://github.com/encode/django-rest-framework.git
				synced 2025-11-04 09:57:55 +03:00 
			
		
		
		
	Merge pull request #2887 from ticosax/doom-transaction-on-error
Rollback the transaction on error if ATOMIC_REQUESTS is set.
This commit is contained in:
		
						commit
						1957679368
					
				| 
						 | 
					@ -7,6 +7,7 @@ versions of django/python, and compatibility wrappers around optional packages.
 | 
				
			||||||
from __future__ import unicode_literals
 | 
					from __future__ import unicode_literals
 | 
				
			||||||
from django.core.exceptions import ImproperlyConfigured
 | 
					from django.core.exceptions import ImproperlyConfigured
 | 
				
			||||||
from django.conf import settings
 | 
					from django.conf import settings
 | 
				
			||||||
 | 
					from django.db import connection, transaction
 | 
				
			||||||
from django.utils.encoding import force_text
 | 
					from django.utils.encoding import force_text
 | 
				
			||||||
from django.utils.six.moves.urllib.parse import urlparse as _urlparse
 | 
					from django.utils.six.moves.urllib.parse import urlparse as _urlparse
 | 
				
			||||||
from django.utils import six
 | 
					from django.utils import six
 | 
				
			||||||
| 
						 | 
					@ -266,3 +267,19 @@ if django.VERSION >= (1, 8):
 | 
				
			||||||
    from django.utils.duration import duration_string
 | 
					    from django.utils.duration import duration_string
 | 
				
			||||||
else:
 | 
					else:
 | 
				
			||||||
    DurationField = duration_string = parse_duration = None
 | 
					    DurationField = duration_string = parse_duration = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def set_rollback():
 | 
				
			||||||
 | 
					    if hasattr(transaction, 'set_rollback'):
 | 
				
			||||||
 | 
					        if connection.settings_dict.get('ATOMIC_REQUESTS', False):
 | 
				
			||||||
 | 
					            # If running in >=1.6 then mark a rollback as required,
 | 
				
			||||||
 | 
					            # and allow it to be handled by Django.
 | 
				
			||||||
 | 
					            transaction.set_rollback(True)
 | 
				
			||||||
 | 
					    elif transaction.is_managed():
 | 
				
			||||||
 | 
					        # Otherwise handle it explicitly if in managed mode.
 | 
				
			||||||
 | 
					        if transaction.is_dirty():
 | 
				
			||||||
 | 
					            transaction.rollback()
 | 
				
			||||||
 | 
					        transaction.leave_transaction_management()
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        # transaction not managed
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -9,7 +9,7 @@ from django.utils.encoding import smart_text
 | 
				
			||||||
from django.utils.translation import ugettext_lazy as _
 | 
					from django.utils.translation import ugettext_lazy as _
 | 
				
			||||||
from django.views.decorators.csrf import csrf_exempt
 | 
					from django.views.decorators.csrf import csrf_exempt
 | 
				
			||||||
from rest_framework import status, exceptions
 | 
					from rest_framework import status, exceptions
 | 
				
			||||||
from rest_framework.compat import HttpResponseBase, View
 | 
					from rest_framework.compat import HttpResponseBase, View, set_rollback
 | 
				
			||||||
from rest_framework.request import Request
 | 
					from rest_framework.request import Request
 | 
				
			||||||
from rest_framework.response import Response
 | 
					from rest_framework.response import Response
 | 
				
			||||||
from rest_framework.settings import api_settings
 | 
					from rest_framework.settings import api_settings
 | 
				
			||||||
| 
						 | 
					@ -71,16 +71,21 @@ def exception_handler(exc, context):
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            data = {'detail': exc.detail}
 | 
					            data = {'detail': exc.detail}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        set_rollback()
 | 
				
			||||||
        return Response(data, status=exc.status_code, headers=headers)
 | 
					        return Response(data, status=exc.status_code, headers=headers)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    elif isinstance(exc, Http404):
 | 
					    elif isinstance(exc, Http404):
 | 
				
			||||||
        msg = _('Not found.')
 | 
					        msg = _('Not found.')
 | 
				
			||||||
        data = {'detail': six.text_type(msg)}
 | 
					        data = {'detail': six.text_type(msg)}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        set_rollback()
 | 
				
			||||||
        return Response(data, status=status.HTTP_404_NOT_FOUND)
 | 
					        return Response(data, status=status.HTTP_404_NOT_FOUND)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    elif isinstance(exc, PermissionDenied):
 | 
					    elif isinstance(exc, PermissionDenied):
 | 
				
			||||||
        msg = _('Permission denied.')
 | 
					        msg = _('Permission denied.')
 | 
				
			||||||
        data = {'detail': six.text_type(msg)}
 | 
					        data = {'detail': six.text_type(msg)}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        set_rollback()
 | 
				
			||||||
        return Response(data, status=status.HTTP_403_FORBIDDEN)
 | 
					        return Response(data, status=status.HTTP_403_FORBIDDEN)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Note: Unhandled exceptions will raise a 500 error.
 | 
					    # Note: Unhandled exceptions will raise a 500 error.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										110
									
								
								tests/test_atomic_requests.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										110
									
								
								tests/test_atomic_requests.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,110 @@
 | 
				
			||||||
 | 
					from __future__ import unicode_literals
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from django.db import connection, connections, transaction
 | 
				
			||||||
 | 
					from django.test import TestCase
 | 
				
			||||||
 | 
					from django.utils.unittest import skipUnless
 | 
				
			||||||
 | 
					from rest_framework import status
 | 
				
			||||||
 | 
					from rest_framework.exceptions import APIException
 | 
				
			||||||
 | 
					from rest_framework.response import Response
 | 
				
			||||||
 | 
					from rest_framework.test import APIRequestFactory
 | 
				
			||||||
 | 
					from rest_framework.views import APIView
 | 
				
			||||||
 | 
					from tests.models import BasicModel
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					factory = APIRequestFactory()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class BasicView(APIView):
 | 
				
			||||||
 | 
					    def post(self, request, *args, **kwargs):
 | 
				
			||||||
 | 
					        BasicModel.objects.create()
 | 
				
			||||||
 | 
					        return Response({'method': 'GET'})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ErrorView(APIView):
 | 
				
			||||||
 | 
					    def post(self, request, *args, **kwargs):
 | 
				
			||||||
 | 
					        BasicModel.objects.create()
 | 
				
			||||||
 | 
					        raise Exception
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class APIExceptionView(APIView):
 | 
				
			||||||
 | 
					    def post(self, request, *args, **kwargs):
 | 
				
			||||||
 | 
					        BasicModel.objects.create()
 | 
				
			||||||
 | 
					        raise APIException
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@skipUnless(connection.features.uses_savepoints,
 | 
				
			||||||
 | 
					            "'atomic' requires transactions and savepoints.")
 | 
				
			||||||
 | 
					class DBTransactionTests(TestCase):
 | 
				
			||||||
 | 
					    def setUp(self):
 | 
				
			||||||
 | 
					        self.view = BasicView.as_view()
 | 
				
			||||||
 | 
					        connections.databases['default']['ATOMIC_REQUESTS'] = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def tearDown(self):
 | 
				
			||||||
 | 
					        connections.databases['default']['ATOMIC_REQUESTS'] = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_no_exception_conmmit_transaction(self):
 | 
				
			||||||
 | 
					        request = factory.post('/')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        with self.assertNumQueries(1):
 | 
				
			||||||
 | 
					            response = self.view(request)
 | 
				
			||||||
 | 
					        self.assertFalse(transaction.get_rollback())
 | 
				
			||||||
 | 
					        self.assertEqual(response.status_code, status.HTTP_200_OK)
 | 
				
			||||||
 | 
					        assert BasicModel.objects.count() == 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@skipUnless(connection.features.uses_savepoints,
 | 
				
			||||||
 | 
					            "'atomic' requires transactions and savepoints.")
 | 
				
			||||||
 | 
					class DBTransactionErrorTests(TestCase):
 | 
				
			||||||
 | 
					    def setUp(self):
 | 
				
			||||||
 | 
					        self.view = ErrorView.as_view()
 | 
				
			||||||
 | 
					        connections.databases['default']['ATOMIC_REQUESTS'] = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def tearDown(self):
 | 
				
			||||||
 | 
					        connections.databases['default']['ATOMIC_REQUESTS'] = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_generic_exception_delegate_transaction_management(self):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Transaction is eventually managed by outer-most transaction atomic
 | 
				
			||||||
 | 
					        block. DRF do not try to interfere here.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        We let django deal with the transaction when it will catch the Exception.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        request = factory.post('/')
 | 
				
			||||||
 | 
					        with self.assertNumQueries(3):
 | 
				
			||||||
 | 
					            # 1 - begin savepoint
 | 
				
			||||||
 | 
					            # 2 - insert
 | 
				
			||||||
 | 
					            # 3 - release savepoint
 | 
				
			||||||
 | 
					            with transaction.atomic():
 | 
				
			||||||
 | 
					                self.assertRaises(Exception, self.view, request)
 | 
				
			||||||
 | 
					                self.assertFalse(transaction.get_rollback())
 | 
				
			||||||
 | 
					        assert BasicModel.objects.count() == 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@skipUnless(connection.features.uses_savepoints,
 | 
				
			||||||
 | 
					            "'atomic' requires transactions and savepoints.")
 | 
				
			||||||
 | 
					class DBTransactionAPIExceptionTests(TestCase):
 | 
				
			||||||
 | 
					    def setUp(self):
 | 
				
			||||||
 | 
					        self.view = APIExceptionView.as_view()
 | 
				
			||||||
 | 
					        connections.databases['default']['ATOMIC_REQUESTS'] = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def tearDown(self):
 | 
				
			||||||
 | 
					        connections.databases['default']['ATOMIC_REQUESTS'] = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_api_exception_rollback_transaction(self):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Transaction is rollbacked by our transaction atomic block.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        request = factory.post('/')
 | 
				
			||||||
 | 
					        num_queries = (4 if getattr(connection.features,
 | 
				
			||||||
 | 
					                                    'can_release_savepoints', False) else 3)
 | 
				
			||||||
 | 
					        with self.assertNumQueries(num_queries):
 | 
				
			||||||
 | 
					            # 1 - begin savepoint
 | 
				
			||||||
 | 
					            # 2 - insert
 | 
				
			||||||
 | 
					            # 3 - rollback savepoint
 | 
				
			||||||
 | 
					            # 4 - release savepoint (django>=1.8 only)
 | 
				
			||||||
 | 
					            with transaction.atomic():
 | 
				
			||||||
 | 
					                response = self.view(request)
 | 
				
			||||||
 | 
					                self.assertTrue(transaction.get_rollback())
 | 
				
			||||||
 | 
					        self.assertEqual(response.status_code,
 | 
				
			||||||
 | 
					                         status.HTTP_500_INTERNAL_SERVER_ERROR)
 | 
				
			||||||
 | 
					        assert BasicModel.objects.count() == 0
 | 
				
			||||||
							
								
								
									
										2
									
								
								tox.ini
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								tox.ini
									
									
									
									
									
								
							| 
						 | 
					@ -6,7 +6,7 @@ envlist =
 | 
				
			||||||
       {py27,py32,py33,py34}-django{17,18,master}
 | 
					       {py27,py32,py33,py34}-django{17,18,master}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
[testenv]
 | 
					[testenv]
 | 
				
			||||||
commands = ./runtests.py --fast
 | 
					commands = ./runtests.py --fast {posargs}
 | 
				
			||||||
setenv =
 | 
					setenv =
 | 
				
			||||||
       PYTHONDONTWRITEBYTECODE=1
 | 
					       PYTHONDONTWRITEBYTECODE=1
 | 
				
			||||||
deps =
 | 
					deps =
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user