mirror of
https://github.com/graphql-python/graphene-django.git
synced 2025-04-13 05:34:20 +03:00
Refactor out async helper functions
This commit is contained in:
parent
d3f8fcf906
commit
930248f78d
|
@ -9,6 +9,8 @@ from graphql.type.definition import GraphQLNonNull
|
|||
|
||||
from django.db.models import QuerySet
|
||||
|
||||
from ..utils import is_sync_function
|
||||
|
||||
|
||||
class DjangoDebugContext:
|
||||
def __init__(self):
|
||||
|
@ -89,9 +91,7 @@ class DjangoSyncRequiredMiddleware:
|
|||
if hasattr(parent_type, "graphene_type") and hasattr(
|
||||
parent_type.graphene_type._meta, "model"
|
||||
):
|
||||
if not inspect.iscoroutinefunction(next) and not inspect.isasyncgenfunction(
|
||||
next
|
||||
):
|
||||
if is_sync_function(next):
|
||||
return sync_to_async(next)(root, info, **args)
|
||||
|
||||
## In addition, if we're resolving to a DjangoObject type
|
||||
|
@ -99,15 +99,11 @@ class DjangoSyncRequiredMiddleware:
|
|||
if hasattr(return_type, "graphene_type") and hasattr(
|
||||
return_type.graphene_type._meta, "model"
|
||||
):
|
||||
if not inspect.iscoroutinefunction(next) and not inspect.isasyncgenfunction(
|
||||
next
|
||||
):
|
||||
if is_sync_function(next):
|
||||
return sync_to_async(next)(root, info, **args)
|
||||
|
||||
if info.parent_type.name == "Mutation":
|
||||
if not inspect.iscoroutinefunction(next) and not inspect.isasyncgenfunction(
|
||||
next
|
||||
):
|
||||
if is_sync_function(next):
|
||||
return sync_to_async(next)(root, info, **args)
|
||||
|
||||
return next(root, info, **args)
|
||||
|
|
|
@ -11,7 +11,6 @@ from graphql_relay import (
|
|||
)
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
from asyncio import get_running_loop
|
||||
|
||||
from graphene import Int, NonNull
|
||||
from graphene.relay import ConnectionField
|
||||
|
@ -19,7 +18,7 @@ from graphene.relay.connection import connection_adapter, page_info_adapter
|
|||
from graphene.types import Field, List
|
||||
|
||||
from .settings import graphene_settings
|
||||
from .utils import maybe_queryset
|
||||
from .utils import maybe_queryset, is_sync_function, is_running_async
|
||||
|
||||
|
||||
class DjangoListField(Field):
|
||||
|
@ -92,16 +91,12 @@ class DjangoListField(Field):
|
|||
_type = _type.of_type
|
||||
django_object_type = _type.of_type.of_type
|
||||
|
||||
try:
|
||||
get_running_loop()
|
||||
except RuntimeError:
|
||||
if not is_running_async():
|
||||
return partial(
|
||||
self.list_resolver, django_object_type, resolver, self.get_manager()
|
||||
)
|
||||
else:
|
||||
if not inspect.iscoroutinefunction(
|
||||
resolver
|
||||
) and not inspect.isasyncgenfunction(resolver):
|
||||
if is_sync_function(resolver):
|
||||
async_resolver = sync_to_async(resolver)
|
||||
|
||||
## This is needed because our middleware can't detect the resolver as async when we returns partial[couroutine]
|
||||
|
@ -271,14 +266,8 @@ class DjangoConnectionField(ConnectionField):
|
|||
# eventually leads to DjangoObjectType's get_queryset (accepts queryset)
|
||||
# or a resolve_foo (does not accept queryset)
|
||||
|
||||
try:
|
||||
get_running_loop()
|
||||
except RuntimeError:
|
||||
pass
|
||||
else:
|
||||
if not inspect.iscoroutinefunction(
|
||||
resolver
|
||||
) and not inspect.isasyncgenfunction(resolver):
|
||||
if is_running_async():
|
||||
if is_sync_function(resolver):
|
||||
resolver = sync_to_async(resolver)
|
||||
|
||||
iterable = resolver(root, info, **args)
|
||||
|
@ -305,11 +294,7 @@ class DjangoConnectionField(ConnectionField):
|
|||
# thus the iterable gets refiltered by resolve_queryset
|
||||
# but iterable might be promise
|
||||
|
||||
try:
|
||||
get_running_loop()
|
||||
except RuntimeError:
|
||||
pass
|
||||
else:
|
||||
if is_running_async():
|
||||
|
||||
async def perform_resolve(iterable):
|
||||
iterable = await sync_to_async(queryset_resolver)(
|
||||
|
|
|
@ -2,7 +2,6 @@ from collections import OrderedDict
|
|||
|
||||
from django.shortcuts import get_object_or_404
|
||||
from rest_framework import serializers
|
||||
from asyncio import get_running_loop
|
||||
from asgiref.sync import sync_to_async
|
||||
|
||||
import graphene
|
||||
|
@ -13,6 +12,7 @@ from graphene.types.objecttype import yank_fields_from_attrs
|
|||
|
||||
from ..types import ErrorType
|
||||
from .serializer_converter import convert_serializer_field
|
||||
from ..utils import is_running_async
|
||||
|
||||
|
||||
class SerializerMutationOptions(MutationOptions):
|
||||
|
@ -154,11 +154,7 @@ class SerializerMutation(ClientIDMutation):
|
|||
kwargs = cls.get_serializer_kwargs(root, info, **input)
|
||||
serializer = cls._meta.serializer_class(**kwargs)
|
||||
|
||||
try:
|
||||
get_running_loop()
|
||||
except RuntimeError:
|
||||
pass
|
||||
else:
|
||||
if is_running_async():
|
||||
|
||||
async def perform_mutate_async():
|
||||
if await sync_to_async(serializer.is_valid)():
|
||||
|
|
|
@ -16,6 +16,7 @@ from .utils import (
|
|||
camelize,
|
||||
get_model_fields,
|
||||
is_valid_django_model,
|
||||
is_running_async,
|
||||
)
|
||||
|
||||
ALL_FIELDS = "__all__"
|
||||
|
@ -288,13 +289,7 @@ class DjangoObjectType(ObjectType):
|
|||
def get_node(cls, info, id):
|
||||
queryset = cls.get_queryset(cls._meta.model.objects, info)
|
||||
try:
|
||||
try:
|
||||
import asyncio
|
||||
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
pass
|
||||
else:
|
||||
if is_running_async():
|
||||
return queryset.aget(pk=id)
|
||||
|
||||
return queryset.get(pk=id)
|
||||
|
|
|
@ -6,6 +6,8 @@ from .utils import (
|
|||
get_reverse_fields,
|
||||
is_valid_django_model,
|
||||
maybe_queryset,
|
||||
is_sync_function,
|
||||
is_running_async,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
|
@ -16,4 +18,6 @@ __all__ = [
|
|||
"camelize",
|
||||
"is_valid_django_model",
|
||||
"GraphQLTestCase",
|
||||
"is_sync_function",
|
||||
"is_running_async",
|
||||
]
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import inspect
|
||||
from asyncio import get_running_loop
|
||||
|
||||
from django.db import connection, models, transaction
|
||||
from django.db.models.manager import Manager
|
||||
|
@ -105,3 +106,18 @@ def set_rollback():
|
|||
atomic_requests = connection.settings_dict.get("ATOMIC_REQUESTS", False)
|
||||
if atomic_requests and connection.in_atomic_block:
|
||||
transaction.set_rollback(True)
|
||||
|
||||
|
||||
def is_running_async():
|
||||
try:
|
||||
get_running_loop()
|
||||
except RuntimeError:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def is_sync_function(func):
|
||||
return not inspect.iscoroutinefunction(func) and not inspect.isasyncgenfunction(
|
||||
func
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue
Block a user