Refactor out async helper functions

This commit is contained in:
Josh Warwick 2023-05-16 16:20:13 +01:00
parent d3f8fcf906
commit 930248f78d
6 changed files with 35 additions and 43 deletions

View File

@ -9,6 +9,8 @@ from graphql.type.definition import GraphQLNonNull
from django.db.models import QuerySet
from ..utils import is_sync_function
class DjangoDebugContext:
def __init__(self):
@ -89,9 +91,7 @@ class DjangoSyncRequiredMiddleware:
if hasattr(parent_type, "graphene_type") and hasattr(
parent_type.graphene_type._meta, "model"
):
if not inspect.iscoroutinefunction(next) and not inspect.isasyncgenfunction(
next
):
if is_sync_function(next):
return sync_to_async(next)(root, info, **args)
## In addition, if we're resolving to a DjangoObject type
@ -99,15 +99,11 @@ class DjangoSyncRequiredMiddleware:
if hasattr(return_type, "graphene_type") and hasattr(
return_type.graphene_type._meta, "model"
):
if not inspect.iscoroutinefunction(next) and not inspect.isasyncgenfunction(
next
):
if is_sync_function(next):
return sync_to_async(next)(root, info, **args)
if info.parent_type.name == "Mutation":
if not inspect.iscoroutinefunction(next) and not inspect.isasyncgenfunction(
next
):
if is_sync_function(next):
return sync_to_async(next)(root, info, **args)
return next(root, info, **args)

View File

@ -11,7 +11,6 @@ from graphql_relay import (
)
from asgiref.sync import sync_to_async
from asyncio import get_running_loop
from graphene import Int, NonNull
from graphene.relay import ConnectionField
@ -19,7 +18,7 @@ from graphene.relay.connection import connection_adapter, page_info_adapter
from graphene.types import Field, List
from .settings import graphene_settings
from .utils import maybe_queryset
from .utils import maybe_queryset, is_sync_function, is_running_async
class DjangoListField(Field):
@ -92,16 +91,12 @@ class DjangoListField(Field):
_type = _type.of_type
django_object_type = _type.of_type.of_type
try:
get_running_loop()
except RuntimeError:
if not is_running_async():
return partial(
self.list_resolver, django_object_type, resolver, self.get_manager()
)
else:
if not inspect.iscoroutinefunction(
resolver
) and not inspect.isasyncgenfunction(resolver):
if is_sync_function(resolver):
async_resolver = sync_to_async(resolver)
## This is needed because our middleware can't detect the resolver as async when we returns partial[couroutine]
@ -271,14 +266,8 @@ class DjangoConnectionField(ConnectionField):
# eventually leads to DjangoObjectType's get_queryset (accepts queryset)
# or a resolve_foo (does not accept queryset)
try:
get_running_loop()
except RuntimeError:
pass
else:
if not inspect.iscoroutinefunction(
resolver
) and not inspect.isasyncgenfunction(resolver):
if is_running_async():
if is_sync_function(resolver):
resolver = sync_to_async(resolver)
iterable = resolver(root, info, **args)
@ -305,11 +294,7 @@ class DjangoConnectionField(ConnectionField):
# thus the iterable gets refiltered by resolve_queryset
# but iterable might be promise
try:
get_running_loop()
except RuntimeError:
pass
else:
if is_running_async():
async def perform_resolve(iterable):
iterable = await sync_to_async(queryset_resolver)(

View File

@ -2,7 +2,6 @@ from collections import OrderedDict
from django.shortcuts import get_object_or_404
from rest_framework import serializers
from asyncio import get_running_loop
from asgiref.sync import sync_to_async
import graphene
@ -13,6 +12,7 @@ from graphene.types.objecttype import yank_fields_from_attrs
from ..types import ErrorType
from .serializer_converter import convert_serializer_field
from ..utils import is_running_async
class SerializerMutationOptions(MutationOptions):
@ -154,11 +154,7 @@ class SerializerMutation(ClientIDMutation):
kwargs = cls.get_serializer_kwargs(root, info, **input)
serializer = cls._meta.serializer_class(**kwargs)
try:
get_running_loop()
except RuntimeError:
pass
else:
if is_running_async():
async def perform_mutate_async():
if await sync_to_async(serializer.is_valid)():

View File

@ -16,6 +16,7 @@ from .utils import (
camelize,
get_model_fields,
is_valid_django_model,
is_running_async,
)
ALL_FIELDS = "__all__"
@ -288,13 +289,7 @@ class DjangoObjectType(ObjectType):
def get_node(cls, info, id):
queryset = cls.get_queryset(cls._meta.model.objects, info)
try:
try:
import asyncio
asyncio.get_running_loop()
except RuntimeError:
pass
else:
if is_running_async():
return queryset.aget(pk=id)
return queryset.get(pk=id)

View File

@ -6,6 +6,8 @@ from .utils import (
get_reverse_fields,
is_valid_django_model,
maybe_queryset,
is_sync_function,
is_running_async,
)
__all__ = [
@ -16,4 +18,6 @@ __all__ = [
"camelize",
"is_valid_django_model",
"GraphQLTestCase",
"is_sync_function",
"is_running_async",
]

View File

@ -1,4 +1,5 @@
import inspect
from asyncio import get_running_loop
from django.db import connection, models, transaction
from django.db.models.manager import Manager
@ -105,3 +106,18 @@ def set_rollback():
atomic_requests = connection.settings_dict.get("ATOMIC_REQUESTS", False)
if atomic_requests and connection.in_atomic_block:
transaction.set_rollback(True)
def is_running_async():
try:
get_running_loop()
except RuntimeError:
return False
else:
return True
def is_sync_function(func):
return not inspect.iscoroutinefunction(func) and not inspect.isasyncgenfunction(
func
)