Support atomic transaction views in multiple database connections

This commit is contained in:
Arad Bar Sadeh 2020-10-26 14:43:09 +02:00
parent 0618fa88e1
commit ac3e8495e7

View File

@ -16,30 +16,35 @@ factory = APIRequestFactory()
class BasicView(APIView): class BasicView(APIView):
database = 'default'
def get_queryset(self):
return BasicModel.objects.using(self.database).all()
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs):
BasicModel.objects.create() self.get_queryset().create()
return Response({'method': 'GET'}) return Response({'method': 'GET'})
class ErrorView(APIView): class ErrorView(BasicView):
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs):
BasicModel.objects.create() self.get_queryset().create()
raise Exception raise Exception
class APIExceptionView(APIView): class APIExceptionView(BasicView):
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs):
BasicModel.objects.create() self.get_queryset().create()
raise APIException raise APIException
class NonAtomicAPIExceptionView(APIView): class NonAtomicAPIExceptionView(BasicView):
@transaction.non_atomic_requests @transaction.non_atomic_requests
def dispatch(self, *args, **kwargs): def dispatch(self, *args, **kwargs):
return super().dispatch(*args, **kwargs) return super().dispatch(*args, **kwargs)
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
BasicModel.objects.all() self.get_queryset()
raise Http404 raise Http404
@ -53,34 +58,52 @@ urlpatterns = (
"'atomic' requires transactions and savepoints." "'atomic' requires transactions and savepoints."
) )
class DBTransactionTests(TestCase): class DBTransactionTests(TestCase):
databases = '__all__'
def setUp(self): def setUp(self):
self.view = BasicView.as_view() self.view = BasicView
connections.databases['default']['ATOMIC_REQUESTS'] = True for database in connections.databases:
connections.databases[database]['ATOMIC_REQUESTS'] = True
def tearDown(self): def tearDown(self):
connections.databases['default']['ATOMIC_REQUESTS'] = False for database in connections.databases:
connections.databases[database]['ATOMIC_REQUESTS'] = False
def test_no_exception_commit_transaction(self): def test_no_exception_commit_transaction(self):
request = factory.post('/') request = factory.post('/')
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request) response = self.view.as_view()(request)
assert not transaction.get_rollback() assert not transaction.get_rollback()
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert BasicModel.objects.count() == 1 assert BasicModel.objects.count() == 1
def test_no_exception_commit_transaction_spare_connection(self):
request = factory.post('/')
with self.assertNumQueries(1, using='spare'):
view = self.view.as_view(database='spare')
response = view(request)
assert not transaction.get_rollback(using='spare')
assert response.status_code == status.HTTP_200_OK
assert BasicModel.objects.using('spare').count() == 1
@unittest.skipUnless( @unittest.skipUnless(
connection.features.uses_savepoints, connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints." "'atomic' requires transactions and savepoints."
) )
class DBTransactionErrorTests(TestCase): class DBTransactionErrorTests(TestCase):
databases = '__all__'
def setUp(self): def setUp(self):
self.view = ErrorView.as_view() self.view = ErrorView
connections.databases['default']['ATOMIC_REQUESTS'] = True for database in connections.databases:
connections.databases[database]['ATOMIC_REQUESTS'] = True
def tearDown(self): def tearDown(self):
connections.databases['default']['ATOMIC_REQUESTS'] = False for database in connections.databases:
connections.databases[database]['ATOMIC_REQUESTS'] = False
def test_generic_exception_delegate_transaction_management(self): def test_generic_exception_delegate_transaction_management(self):
""" """
@ -95,22 +118,37 @@ class DBTransactionErrorTests(TestCase):
# 2 - insert # 2 - insert
# 3 - release savepoint # 3 - release savepoint
with transaction.atomic(): with transaction.atomic():
self.assertRaises(Exception, self.view, request) self.assertRaises(Exception, self.view.as_view(), request)
assert not transaction.get_rollback() assert not transaction.get_rollback()
assert BasicModel.objects.count() == 1 assert BasicModel.objects.count() == 1
def test_generic_exception_delegate_transaction_management_spare_connections(self):
request = factory.post('/')
with self.assertNumQueries(3, using='spare'):
# 1 - begin savepoint
# 2 - insert
# 3 - release savepoint
with transaction.atomic(using='spare'):
self.assertRaises(Exception, self.view.as_view(database='spare'), request)
assert not transaction.get_rollback(using='spare')
assert BasicModel.objects.using('spare').count() == 1
@unittest.skipUnless( @unittest.skipUnless(
connection.features.uses_savepoints, connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints." "'atomic' requires transactions and savepoints."
) )
class DBTransactionAPIExceptionTests(TestCase): class DBTransactionAPIExceptionTests(TestCase):
databases = '__all__'
def setUp(self): def setUp(self):
self.view = APIExceptionView.as_view() self.view = APIExceptionView
connections.databases['default']['ATOMIC_REQUESTS'] = True for database in connections.databases:
connections.databases[database]['ATOMIC_REQUESTS'] = True
def tearDown(self): def tearDown(self):
connections.databases['default']['ATOMIC_REQUESTS'] = False for database in connections.databases:
connections.databases[database]['ATOMIC_REQUESTS'] = False
def test_api_exception_rollback_transaction(self): def test_api_exception_rollback_transaction(self):
""" """
@ -124,11 +162,28 @@ class DBTransactionAPIExceptionTests(TestCase):
# 3 - rollback savepoint # 3 - rollback savepoint
# 4 - release savepoint # 4 - release savepoint
with transaction.atomic(): with transaction.atomic():
response = self.view(request) response = self.view.as_view()(request)
assert transaction.get_rollback() assert transaction.get_rollback()
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert BasicModel.objects.count() == 0 assert BasicModel.objects.count() == 0
def test_api_exception_rollback_transaction_spare_connection(self):
"""
Transaction is rollbacked by our transaction atomic block.
"""
request = factory.post('/')
num_queries = 4 if connections['spare'].features.can_release_savepoints else 3
with self.assertNumQueries(num_queries, using='spare'):
# 1 - begin savepoint
# 2 - insert
# 3 - rollback savepoint
# 4 - release savepoint
with transaction.atomic(using='spare'):
response = self.view.as_view(database='spare')(request)
assert transaction.get_rollback(using='spare')
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert BasicModel.objects.using('spare').count() == 0
@unittest.skipUnless( @unittest.skipUnless(
connection.features.uses_savepoints, connection.features.uses_savepoints,
@ -171,11 +226,15 @@ class MultiDBTransactionAPIExceptionTests(TestCase):
) )
@override_settings(ROOT_URLCONF='tests.test_atomic_requests') @override_settings(ROOT_URLCONF='tests.test_atomic_requests')
class NonAtomicDBTransactionAPIExceptionTests(TransactionTestCase): class NonAtomicDBTransactionAPIExceptionTests(TransactionTestCase):
databases = '__all__'
def setUp(self): def setUp(self):
connections.databases['default']['ATOMIC_REQUESTS'] = True for database in connections.databases:
connections.databases[database]['ATOMIC_REQUESTS'] = True
def tearDown(self): def tearDown(self):
connections.databases['default']['ATOMIC_REQUESTS'] = False for database in connections.databases:
connections.databases[database]['ATOMIC_REQUESTS'] = False
def test_api_exception_rollback_transaction_non_atomic_view(self): def test_api_exception_rollback_transaction_non_atomic_view(self):
response = self.client.get('/') response = self.client.get('/')