This commit is contained in:
Adam Johnson 2025-05-04 20:07:09 +03:00 committed by GitHub
commit 1d55a54bec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 21 additions and 13 deletions

View File

@ -4,7 +4,7 @@ Provides an APIView class that is the base of all views in REST framework.
from django import VERSION as DJANGO_VERSION from django import VERSION as DJANGO_VERSION
from django.conf import settings from django.conf import settings
from django.core.exceptions import PermissionDenied from django.core.exceptions import PermissionDenied
from django.db import connections, models from django.db import connections, models, transaction
from django.http import Http404 from django.http import Http404
from django.http.response import HttpResponseBase from django.http.response import HttpResponseBase
from django.utils.cache import cc_delim_re, patch_vary_headers from django.utils.cache import cc_delim_re, patch_vary_headers
@ -64,9 +64,13 @@ def get_view_description(view, html=False):
def set_rollback(): def set_rollback():
# Rollback all connections that have ATOMIC_REQUESTS set, if it looks their
# @atomic for the request was started
# Note this check in_atomic_block may be a false positive due to
# transactions started another way, e.g. through testing with TestCase
for db in connections.all(): for db in connections.all():
if db.settings_dict['ATOMIC_REQUESTS'] and db.in_atomic_block: if db.settings_dict['ATOMIC_REQUESTS'] and db.in_atomic_block:
db.set_rollback(True) transaction.set_rollback(True, using=db.alias)
def exception_handler(exc, context): def exception_handler(exc, context):
@ -229,9 +233,9 @@ class APIView(View):
""" """
return { return {
'view': self, 'view': self,
'args': getattr(self, 'args', ()), 'args': self.args,
'kwargs': getattr(self, 'kwargs', {}), 'kwargs': self.kwargs,
'request': getattr(self, 'request', None) 'request': self.request,
} }
def get_view_name(self): def get_view_name(self):

View File

@ -39,11 +39,12 @@ class NonAtomicAPIExceptionView(APIView):
return super().dispatch(*args, **kwargs) return super().dispatch(*args, **kwargs)
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
BasicModel.objects.all() list(BasicModel.objects.all())
raise Http404 raise Http404
urlpatterns = ( urlpatterns = (
path('non-atomic-exception', NonAtomicAPIExceptionView.as_view()),
path('', NonAtomicAPIExceptionView.as_view()), path('', NonAtomicAPIExceptionView.as_view()),
) )
@ -94,8 +95,8 @@ class DBTransactionErrorTests(TestCase):
# 1 - begin savepoint # 1 - begin savepoint
# 2 - insert # 2 - insert
# 3 - release savepoint # 3 - release savepoint
with transaction.atomic(): with transaction.atomic(), self.assertRaises(Exception):
self.assertRaises(Exception, self.view, request) self.view(request)
assert not transaction.get_rollback() assert not transaction.get_rollback()
assert BasicModel.objects.count() == 1 assert BasicModel.objects.count() == 1
@ -174,12 +175,15 @@ class NonAtomicDBTransactionAPIExceptionTests(TransactionTestCase):
def setUp(self): def setUp(self):
connections.databases['default']['ATOMIC_REQUESTS'] = True connections.databases['default']['ATOMIC_REQUESTS'] = True
def tearDown(self): @self.addCleanup
def restore_atomic_requests():
connections.databases['default']['ATOMIC_REQUESTS'] = False connections.databases['default']['ATOMIC_REQUESTS'] = False
def test_api_exception_rollback_transaction_non_atomic_view(self): def test_api_exception_rollback_transaction_non_atomic_view(self):
response = self.client.get('/') response = self.client.get('/non-atomic-exception')
# without checking connection.in_atomic_block view raises 500 # without check for db.in_atomic_block, would raise 500 due to attempt
# due attempt to rollback without transaction # to rollback without transaction
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
# Check we can still perform DB queries
list(BasicModel.objects.all())