From ac3e8495e7af7a82c2c75af3819726293a5466ac Mon Sep 17 00:00:00 2001 From: Arad Bar Sadeh Date: Mon, 26 Oct 2020 14:43:09 +0200 Subject: [PATCH] Support atomic transaction views in multiple database connections --- tests/test_atomic_requests.py | 101 +++++++++++++++++++++++++++------- 1 file changed, 80 insertions(+), 21 deletions(-) diff --git a/tests/test_atomic_requests.py b/tests/test_atomic_requests.py index beda5cba1..14eaf7987 100644 --- a/tests/test_atomic_requests.py +++ b/tests/test_atomic_requests.py @@ -16,30 +16,35 @@ factory = APIRequestFactory() class BasicView(APIView): + database = 'default' + + def get_queryset(self): + return BasicModel.objects.using(self.database).all() + def post(self, request, *args, **kwargs): - BasicModel.objects.create() + self.get_queryset().create() return Response({'method': 'GET'}) -class ErrorView(APIView): +class ErrorView(BasicView): def post(self, request, *args, **kwargs): - BasicModel.objects.create() + self.get_queryset().create() raise Exception -class APIExceptionView(APIView): +class APIExceptionView(BasicView): def post(self, request, *args, **kwargs): - BasicModel.objects.create() + self.get_queryset().create() raise APIException -class NonAtomicAPIExceptionView(APIView): +class NonAtomicAPIExceptionView(BasicView): @transaction.non_atomic_requests def dispatch(self, *args, **kwargs): return super().dispatch(*args, **kwargs) def get(self, request, *args, **kwargs): - BasicModel.objects.all() + self.get_queryset() raise Http404 @@ -53,34 +58,52 @@ urlpatterns = ( "'atomic' requires transactions and savepoints." ) class DBTransactionTests(TestCase): + databases = '__all__' + def setUp(self): - self.view = BasicView.as_view() - connections.databases['default']['ATOMIC_REQUESTS'] = True + self.view = BasicView + for database in connections.databases: + connections.databases[database]['ATOMIC_REQUESTS'] = True def tearDown(self): - connections.databases['default']['ATOMIC_REQUESTS'] = False + for database in connections.databases: + connections.databases[database]['ATOMIC_REQUESTS'] = False def test_no_exception_commit_transaction(self): request = factory.post('/') with self.assertNumQueries(1): - response = self.view(request) + response = self.view.as_view()(request) assert not transaction.get_rollback() assert response.status_code == status.HTTP_200_OK assert BasicModel.objects.count() == 1 + def test_no_exception_commit_transaction_spare_connection(self): + request = factory.post('/') + + with self.assertNumQueries(1, using='spare'): + view = self.view.as_view(database='spare') + response = view(request) + assert not transaction.get_rollback(using='spare') + assert response.status_code == status.HTTP_200_OK + assert BasicModel.objects.using('spare').count() == 1 + @unittest.skipUnless( connection.features.uses_savepoints, "'atomic' requires transactions and savepoints." ) class DBTransactionErrorTests(TestCase): + databases = '__all__' + def setUp(self): - self.view = ErrorView.as_view() - connections.databases['default']['ATOMIC_REQUESTS'] = True + self.view = ErrorView + for database in connections.databases: + connections.databases[database]['ATOMIC_REQUESTS'] = True def tearDown(self): - connections.databases['default']['ATOMIC_REQUESTS'] = False + for database in connections.databases: + connections.databases[database]['ATOMIC_REQUESTS'] = False def test_generic_exception_delegate_transaction_management(self): """ @@ -95,22 +118,37 @@ class DBTransactionErrorTests(TestCase): # 2 - insert # 3 - release savepoint with transaction.atomic(): - self.assertRaises(Exception, self.view, request) + self.assertRaises(Exception, self.view.as_view(), request) assert not transaction.get_rollback() assert BasicModel.objects.count() == 1 + def test_generic_exception_delegate_transaction_management_spare_connections(self): + request = factory.post('/') + with self.assertNumQueries(3, using='spare'): + # 1 - begin savepoint + # 2 - insert + # 3 - release savepoint + with transaction.atomic(using='spare'): + self.assertRaises(Exception, self.view.as_view(database='spare'), request) + assert not transaction.get_rollback(using='spare') + assert BasicModel.objects.using('spare').count() == 1 + @unittest.skipUnless( connection.features.uses_savepoints, "'atomic' requires transactions and savepoints." ) class DBTransactionAPIExceptionTests(TestCase): + databases = '__all__' + def setUp(self): - self.view = APIExceptionView.as_view() - connections.databases['default']['ATOMIC_REQUESTS'] = True + self.view = APIExceptionView + for database in connections.databases: + connections.databases[database]['ATOMIC_REQUESTS'] = True def tearDown(self): - connections.databases['default']['ATOMIC_REQUESTS'] = False + for database in connections.databases: + connections.databases[database]['ATOMIC_REQUESTS'] = False def test_api_exception_rollback_transaction(self): """ @@ -124,11 +162,28 @@ class DBTransactionAPIExceptionTests(TestCase): # 3 - rollback savepoint # 4 - release savepoint with transaction.atomic(): - response = self.view(request) + response = self.view.as_view()(request) assert transaction.get_rollback() assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert BasicModel.objects.count() == 0 + def test_api_exception_rollback_transaction_spare_connection(self): + """ + Transaction is rollbacked by our transaction atomic block. + """ + request = factory.post('/') + num_queries = 4 if connections['spare'].features.can_release_savepoints else 3 + with self.assertNumQueries(num_queries, using='spare'): + # 1 - begin savepoint + # 2 - insert + # 3 - rollback savepoint + # 4 - release savepoint + with transaction.atomic(using='spare'): + response = self.view.as_view(database='spare')(request) + assert transaction.get_rollback(using='spare') + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert BasicModel.objects.using('spare').count() == 0 + @unittest.skipUnless( connection.features.uses_savepoints, @@ -171,11 +226,15 @@ class MultiDBTransactionAPIExceptionTests(TestCase): ) @override_settings(ROOT_URLCONF='tests.test_atomic_requests') class NonAtomicDBTransactionAPIExceptionTests(TransactionTestCase): + databases = '__all__' + def setUp(self): - connections.databases['default']['ATOMIC_REQUESTS'] = True + for database in connections.databases: + connections.databases[database]['ATOMIC_REQUESTS'] = True def tearDown(self): - connections.databases['default']['ATOMIC_REQUESTS'] = False + for database in connections.databases: + connections.databases[database]['ATOMIC_REQUESTS'] = False def test_api_exception_rollback_transaction_non_atomic_view(self): response = self.client.get('/')