diff --git a/graphene_django/debug/middleware.py b/graphene_django/debug/middleware.py index 50b0cfe..3e66a8f 100644 --- a/graphene_django/debug/middleware.py +++ b/graphene_django/debug/middleware.py @@ -9,6 +9,8 @@ from graphql.type.definition import GraphQLNonNull from django.db.models import QuerySet +from ..utils import is_sync_function + class DjangoDebugContext: def __init__(self): @@ -89,9 +91,7 @@ class DjangoSyncRequiredMiddleware: if hasattr(parent_type, "graphene_type") and hasattr( parent_type.graphene_type._meta, "model" ): - if not inspect.iscoroutinefunction(next) and not inspect.isasyncgenfunction( - next - ): + if is_sync_function(next): return sync_to_async(next)(root, info, **args) ## In addition, if we're resolving to a DjangoObject type @@ -99,15 +99,11 @@ class DjangoSyncRequiredMiddleware: if hasattr(return_type, "graphene_type") and hasattr( return_type.graphene_type._meta, "model" ): - if not inspect.iscoroutinefunction(next) and not inspect.isasyncgenfunction( - next - ): + if is_sync_function(next): return sync_to_async(next)(root, info, **args) if info.parent_type.name == "Mutation": - if not inspect.iscoroutinefunction(next) and not inspect.isasyncgenfunction( - next - ): + if is_sync_function(next): return sync_to_async(next)(root, info, **args) return next(root, info, **args) diff --git a/graphene_django/fields.py b/graphene_django/fields.py index 081d870..db846b9 100644 --- a/graphene_django/fields.py +++ b/graphene_django/fields.py @@ -11,7 +11,6 @@ from graphql_relay import ( ) from asgiref.sync import sync_to_async -from asyncio import get_running_loop from graphene import Int, NonNull from graphene.relay import ConnectionField @@ -19,7 +18,7 @@ from graphene.relay.connection import connection_adapter, page_info_adapter from graphene.types import Field, List from .settings import graphene_settings -from .utils import maybe_queryset +from .utils import maybe_queryset, is_sync_function, is_running_async class DjangoListField(Field): @@ -92,16 +91,12 @@ class DjangoListField(Field): _type = _type.of_type django_object_type = _type.of_type.of_type - try: - get_running_loop() - except RuntimeError: + if not is_running_async(): return partial( self.list_resolver, django_object_type, resolver, self.get_manager() ) else: - if not inspect.iscoroutinefunction( - resolver - ) and not inspect.isasyncgenfunction(resolver): + if is_sync_function(resolver): async_resolver = sync_to_async(resolver) ## This is needed because our middleware can't detect the resolver as async when we returns partial[couroutine] @@ -271,14 +266,8 @@ class DjangoConnectionField(ConnectionField): # eventually leads to DjangoObjectType's get_queryset (accepts queryset) # or a resolve_foo (does not accept queryset) - try: - get_running_loop() - except RuntimeError: - pass - else: - if not inspect.iscoroutinefunction( - resolver - ) and not inspect.isasyncgenfunction(resolver): + if is_running_async(): + if is_sync_function(resolver): resolver = sync_to_async(resolver) iterable = resolver(root, info, **args) @@ -305,11 +294,7 @@ class DjangoConnectionField(ConnectionField): # thus the iterable gets refiltered by resolve_queryset # but iterable might be promise - try: - get_running_loop() - except RuntimeError: - pass - else: + if is_running_async(): async def perform_resolve(iterable): iterable = await sync_to_async(queryset_resolver)( diff --git a/graphene_django/rest_framework/mutation.py b/graphene_django/rest_framework/mutation.py index 46d3593..ca2764d 100644 --- a/graphene_django/rest_framework/mutation.py +++ b/graphene_django/rest_framework/mutation.py @@ -2,7 +2,6 @@ from collections import OrderedDict from django.shortcuts import get_object_or_404 from rest_framework import serializers -from asyncio import get_running_loop from asgiref.sync import sync_to_async import graphene @@ -13,6 +12,7 @@ from graphene.types.objecttype import yank_fields_from_attrs from ..types import ErrorType from .serializer_converter import convert_serializer_field +from ..utils import is_running_async class SerializerMutationOptions(MutationOptions): @@ -154,11 +154,7 @@ class SerializerMutation(ClientIDMutation): kwargs = cls.get_serializer_kwargs(root, info, **input) serializer = cls._meta.serializer_class(**kwargs) - try: - get_running_loop() - except RuntimeError: - pass - else: + if is_running_async(): async def perform_mutate_async(): if await sync_to_async(serializer.is_valid)(): diff --git a/graphene_django/types.py b/graphene_django/types.py index a064b90..7216cde 100644 --- a/graphene_django/types.py +++ b/graphene_django/types.py @@ -16,6 +16,7 @@ from .utils import ( camelize, get_model_fields, is_valid_django_model, + is_running_async, ) ALL_FIELDS = "__all__" @@ -288,13 +289,7 @@ class DjangoObjectType(ObjectType): def get_node(cls, info, id): queryset = cls.get_queryset(cls._meta.model.objects, info) try: - try: - import asyncio - - asyncio.get_running_loop() - except RuntimeError: - pass - else: + if is_running_async(): return queryset.aget(pk=id) return queryset.get(pk=id) diff --git a/graphene_django/utils/__init__.py b/graphene_django/utils/__init__.py index 671b060..7344ce5 100644 --- a/graphene_django/utils/__init__.py +++ b/graphene_django/utils/__init__.py @@ -6,6 +6,8 @@ from .utils import ( get_reverse_fields, is_valid_django_model, maybe_queryset, + is_sync_function, + is_running_async, ) __all__ = [ @@ -16,4 +18,6 @@ __all__ = [ "camelize", "is_valid_django_model", "GraphQLTestCase", + "is_sync_function", + "is_running_async", ] diff --git a/graphene_django/utils/utils.py b/graphene_django/utils/utils.py index 343a3a7..beac3e6 100644 --- a/graphene_django/utils/utils.py +++ b/graphene_django/utils/utils.py @@ -1,4 +1,5 @@ import inspect +from asyncio import get_running_loop from django.db import connection, models, transaction from django.db.models.manager import Manager @@ -105,3 +106,18 @@ def set_rollback(): atomic_requests = connection.settings_dict.get("ATOMIC_REQUESTS", False) if atomic_requests and connection.in_atomic_block: transaction.set_rollback(True) + + +def is_running_async(): + try: + get_running_loop() + except RuntimeError: + return False + else: + return True + + +def is_sync_function(func): + return not inspect.iscoroutinefunction(func) and not inspect.isasyncgenfunction( + func + )