Improve set_rollback() behaviour

\## Description

Fixes #6921.

Added tests that fail before and pass afterwards.

Remove the check for `connection.in_atomic_block` to determine if the current request is under a `transaction.atomic` from `ATOMIC_REQUESTS`. Instead, duplicate the method that Django itself uses [in BaseHandler](964dd4f4f2/django/core/handlers/base.py (L64)).

This requires fetching the actual view function from `as_view()`, as seen by the URL resolver / BaseHandler. Since this requires `request`, I've also changed the accesses in `get_exception_handler_context` to be direct attribute accesses rather than `getattr()`. It seems the `getattr` defaults not accessible since `self.request`, `self.args`, and `self.kwargs` are always set in `dispatch()` before `handle_exception()` can ever be called. This is useful since `request` is always needed for the new `set_rollback` logic.
This commit is contained in:
Adam Johnson 2019-09-10 11:21:20 +01:00
parent de497a9bf1
commit 5e5f559b4d
2 changed files with 48 additions and 18 deletions

View File

@ -3,7 +3,7 @@ Provides an APIView class that is the base of all views in REST framework.
""" """
from django.conf import settings from django.conf import settings
from django.core.exceptions import PermissionDenied from django.core.exceptions import PermissionDenied
from django.db import connection, models, transaction from django.db import connections, models, transaction
from django.http import Http404 from django.http import Http404
from django.http.response import HttpResponseBase from django.http.response import HttpResponseBase
from django.utils.cache import cc_delim_re, patch_vary_headers from django.utils.cache import cc_delim_re, patch_vary_headers
@ -62,10 +62,19 @@ def get_view_description(view, html=False):
return description return description
def set_rollback(): def set_rollback(request):
atomic_requests = connection.settings_dict.get('ATOMIC_REQUESTS', False) # We need the actual view func returned by the URL resolver which gets used
if atomic_requests and connection.in_atomic_block: # by Django's BaseHandler to determine `non_atomic_requests`. Be cautious
transaction.set_rollback(True) # when fetching it though as it won't be set when views are tested with
# requessts from a RequestFactory.
try:
non_atomic_requests = request.resolver_match.func._non_atomic_requests
except AttributeError:
non_atomic_requests = set()
for db in connections.all():
if db.settings_dict['ATOMIC_REQUESTS'] and db.alias not in non_atomic_requests:
transaction.set_rollback(True, using=db.alias)
def exception_handler(exc, context): def exception_handler(exc, context):
@ -95,7 +104,7 @@ def exception_handler(exc, context):
else: else:
data = {'detail': exc.detail} data = {'detail': exc.detail}
set_rollback() set_rollback(context['request'])
return Response(data, status=exc.status_code, headers=headers) return Response(data, status=exc.status_code, headers=headers)
return None return None
@ -223,9 +232,9 @@ class APIView(View):
""" """
return { return {
'view': self, 'view': self,
'args': getattr(self, 'args', ()), 'args': self.args,
'kwargs': getattr(self, 'kwargs', {}), 'kwargs': self.kwargs,
'request': getattr(self, 'request', None) 'request': self.request,
} }
def get_view_name(self): def get_view_name(self):

View File

@ -3,7 +3,7 @@ import unittest
from django.conf.urls import url from django.conf.urls import url
from django.db import connection, connections, transaction from django.db import connection, connections, transaction
from django.http import Http404 from django.http import Http404
from django.test import TestCase, TransactionTestCase, override_settings from django.test import TestCase, override_settings
from rest_framework import status from rest_framework import status
from rest_framework.exceptions import APIException from rest_framework.exceptions import APIException
@ -39,12 +39,24 @@ class NonAtomicAPIExceptionView(APIView):
return super().dispatch(*args, **kwargs) return super().dispatch(*args, **kwargs)
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
BasicModel.objects.all() list(BasicModel.objects.all())
raise Http404
class UrlDecoratedNonAtomicAPIExceptionView(APIView):
def get(self, request, *args, **kwargs):
list(BasicModel.objects.all())
raise Http404 raise Http404
urlpatterns = ( urlpatterns = (
url(r'^$', NonAtomicAPIExceptionView.as_view()), url(r'^non-atomic-exception$', NonAtomicAPIExceptionView.as_view()),
url(
r'^url-decorated-non-atomic-exception$',
transaction.non_atomic_requests(
UrlDecoratedNonAtomicAPIExceptionView.as_view()
),
),
) )
@ -94,8 +106,8 @@ class DBTransactionErrorTests(TestCase):
# 1 - begin savepoint # 1 - begin savepoint
# 2 - insert # 2 - insert
# 3 - release savepoint # 3 - release savepoint
with transaction.atomic(): with transaction.atomic(), self.assertRaises(Exception):
self.assertRaises(Exception, self.view, request) self.view(request)
assert not transaction.get_rollback() assert not transaction.get_rollback()
assert BasicModel.objects.count() == 1 assert BasicModel.objects.count() == 1
@ -135,7 +147,7 @@ class DBTransactionAPIExceptionTests(TestCase):
"'atomic' requires transactions and savepoints." "'atomic' requires transactions and savepoints."
) )
@override_settings(ROOT_URLCONF='tests.test_atomic_requests') @override_settings(ROOT_URLCONF='tests.test_atomic_requests')
class NonAtomicDBTransactionAPIExceptionTests(TransactionTestCase): class NonAtomicDBTransactionAPIExceptionTests(TestCase):
def setUp(self): def setUp(self):
connections.databases['default']['ATOMIC_REQUESTS'] = True connections.databases['default']['ATOMIC_REQUESTS'] = True
@ -143,8 +155,17 @@ class NonAtomicDBTransactionAPIExceptionTests(TransactionTestCase):
connections.databases['default']['ATOMIC_REQUESTS'] = False connections.databases['default']['ATOMIC_REQUESTS'] = False
def test_api_exception_rollback_transaction_non_atomic_view(self): def test_api_exception_rollback_transaction_non_atomic_view(self):
response = self.client.get('/') response = self.client.get('/non-atomic-exception')
# without checking connection.in_atomic_block view raises 500
# due attempt to rollback without transaction
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
assert not transaction.get_rollback()
# Check we can still perform DB queries
list(BasicModel.objects.all())
def test_api_exception_rollback_transaction_url_decorated_non_atomic_view(self):
response = self.client.get('/url-decorated-non-atomic-exception')
assert response.status_code == status.HTTP_404_NOT_FOUND
assert not transaction.get_rollback()
# Check we can still perform DB queries
list(BasicModel.objects.all())