diff --git a/graphene_django/tests/test_views.py b/graphene_django/tests/test_views.py index 669f60f..b1332bd 100644 --- a/graphene_django/tests/test_views.py +++ b/graphene_django/tests/test_views.py @@ -864,7 +864,7 @@ def test_model_form_mutation_multiple_creation_invalid_non_atomic(client): graphene_settings.ATOMIC_MUTATIONS = old_graphene_atomic_mutations -@patch("rest_framework.views.transaction.set_rollback") +@patch("graphene_django.utils.utils.transaction.set_rollback") def test_query_errors_atomic_request(set_rollback_mock, client): old_atomic_mutations = connection.settings_dict.get("ATOMIC_MUTATIONS", False) old_atomic_requests = connection.settings_dict["ATOMIC_REQUESTS"] @@ -883,7 +883,7 @@ def test_query_errors_atomic_request(set_rollback_mock, client): graphene_settings.ATOMIC_MUTATIONS = old_graphene_atomic_mutations -@patch("rest_framework.views.transaction.set_rollback") +@patch("graphene_django.utils.utils.transaction.set_rollback") def test_query_errors_non_atomic(set_rollback_mock, client): old_atomic_mutations = connection.settings_dict.get("ATOMIC_MUTATIONS", False) old_atomic_requests = connection.settings_dict["ATOMIC_REQUESTS"] diff --git a/graphene_django/utils/utils.py b/graphene_django/utils/utils.py index c1d3572..b1c9a7d 100644 --- a/graphene_django/utils/utils.py +++ b/graphene_django/utils/utils.py @@ -1,7 +1,7 @@ import inspect import six -from django.db import models +from django.db import connection, models, transaction from django.db.models.manager import Manager from django.utils.encoding import force_text from django.utils.functional import Promise @@ -100,3 +100,9 @@ def import_single_dispatch(): ) return singledispatch + + +def set_rollback(): + atomic_requests = connection.settings_dict.get("ATOMIC_REQUESTS", False) + if atomic_requests and connection.in_atomic_block: + transaction.set_rollback(True) diff --git a/graphene_django/views.py b/graphene_django/views.py index b1e9352..e81f760 100644 --- a/graphene_django/views.py +++ b/graphene_django/views.py @@ -18,9 +18,8 @@ from graphql.execution import ExecutionResult from graphql.type.schema import GraphQLSchema from graphql.execution.middleware import MiddlewareManager -from rest_framework.views import set_rollback - from graphene_django.constants import MUTATION_ERRORS_FLAG +from graphene_django.utils.utils import set_rollback from .settings import graphene_settings