mirror of
https://github.com/graphql-python/graphene-django.git
synced 2025-07-01 10:53:13 +03:00
Refactor out async helper functions
This commit is contained in:
parent
d3f8fcf906
commit
930248f78d
|
@ -9,6 +9,8 @@ from graphql.type.definition import GraphQLNonNull
|
||||||
|
|
||||||
from django.db.models import QuerySet
|
from django.db.models import QuerySet
|
||||||
|
|
||||||
|
from ..utils import is_sync_function
|
||||||
|
|
||||||
|
|
||||||
class DjangoDebugContext:
|
class DjangoDebugContext:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -89,9 +91,7 @@ class DjangoSyncRequiredMiddleware:
|
||||||
if hasattr(parent_type, "graphene_type") and hasattr(
|
if hasattr(parent_type, "graphene_type") and hasattr(
|
||||||
parent_type.graphene_type._meta, "model"
|
parent_type.graphene_type._meta, "model"
|
||||||
):
|
):
|
||||||
if not inspect.iscoroutinefunction(next) and not inspect.isasyncgenfunction(
|
if is_sync_function(next):
|
||||||
next
|
|
||||||
):
|
|
||||||
return sync_to_async(next)(root, info, **args)
|
return sync_to_async(next)(root, info, **args)
|
||||||
|
|
||||||
## In addition, if we're resolving to a DjangoObject type
|
## In addition, if we're resolving to a DjangoObject type
|
||||||
|
@ -99,15 +99,11 @@ class DjangoSyncRequiredMiddleware:
|
||||||
if hasattr(return_type, "graphene_type") and hasattr(
|
if hasattr(return_type, "graphene_type") and hasattr(
|
||||||
return_type.graphene_type._meta, "model"
|
return_type.graphene_type._meta, "model"
|
||||||
):
|
):
|
||||||
if not inspect.iscoroutinefunction(next) and not inspect.isasyncgenfunction(
|
if is_sync_function(next):
|
||||||
next
|
|
||||||
):
|
|
||||||
return sync_to_async(next)(root, info, **args)
|
return sync_to_async(next)(root, info, **args)
|
||||||
|
|
||||||
if info.parent_type.name == "Mutation":
|
if info.parent_type.name == "Mutation":
|
||||||
if not inspect.iscoroutinefunction(next) and not inspect.isasyncgenfunction(
|
if is_sync_function(next):
|
||||||
next
|
|
||||||
):
|
|
||||||
return sync_to_async(next)(root, info, **args)
|
return sync_to_async(next)(root, info, **args)
|
||||||
|
|
||||||
return next(root, info, **args)
|
return next(root, info, **args)
|
||||||
|
|
|
@ -11,7 +11,6 @@ from graphql_relay import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
from asyncio import get_running_loop
|
|
||||||
|
|
||||||
from graphene import Int, NonNull
|
from graphene import Int, NonNull
|
||||||
from graphene.relay import ConnectionField
|
from graphene.relay import ConnectionField
|
||||||
|
@ -19,7 +18,7 @@ from graphene.relay.connection import connection_adapter, page_info_adapter
|
||||||
from graphene.types import Field, List
|
from graphene.types import Field, List
|
||||||
|
|
||||||
from .settings import graphene_settings
|
from .settings import graphene_settings
|
||||||
from .utils import maybe_queryset
|
from .utils import maybe_queryset, is_sync_function, is_running_async
|
||||||
|
|
||||||
|
|
||||||
class DjangoListField(Field):
|
class DjangoListField(Field):
|
||||||
|
@ -92,16 +91,12 @@ class DjangoListField(Field):
|
||||||
_type = _type.of_type
|
_type = _type.of_type
|
||||||
django_object_type = _type.of_type.of_type
|
django_object_type = _type.of_type.of_type
|
||||||
|
|
||||||
try:
|
if not is_running_async():
|
||||||
get_running_loop()
|
|
||||||
except RuntimeError:
|
|
||||||
return partial(
|
return partial(
|
||||||
self.list_resolver, django_object_type, resolver, self.get_manager()
|
self.list_resolver, django_object_type, resolver, self.get_manager()
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if not inspect.iscoroutinefunction(
|
if is_sync_function(resolver):
|
||||||
resolver
|
|
||||||
) and not inspect.isasyncgenfunction(resolver):
|
|
||||||
async_resolver = sync_to_async(resolver)
|
async_resolver = sync_to_async(resolver)
|
||||||
|
|
||||||
## This is needed because our middleware can't detect the resolver as async when we returns partial[couroutine]
|
## This is needed because our middleware can't detect the resolver as async when we returns partial[couroutine]
|
||||||
|
@ -271,14 +266,8 @@ class DjangoConnectionField(ConnectionField):
|
||||||
# eventually leads to DjangoObjectType's get_queryset (accepts queryset)
|
# eventually leads to DjangoObjectType's get_queryset (accepts queryset)
|
||||||
# or a resolve_foo (does not accept queryset)
|
# or a resolve_foo (does not accept queryset)
|
||||||
|
|
||||||
try:
|
if is_running_async():
|
||||||
get_running_loop()
|
if is_sync_function(resolver):
|
||||||
except RuntimeError:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
if not inspect.iscoroutinefunction(
|
|
||||||
resolver
|
|
||||||
) and not inspect.isasyncgenfunction(resolver):
|
|
||||||
resolver = sync_to_async(resolver)
|
resolver = sync_to_async(resolver)
|
||||||
|
|
||||||
iterable = resolver(root, info, **args)
|
iterable = resolver(root, info, **args)
|
||||||
|
@ -305,11 +294,7 @@ class DjangoConnectionField(ConnectionField):
|
||||||
# thus the iterable gets refiltered by resolve_queryset
|
# thus the iterable gets refiltered by resolve_queryset
|
||||||
# but iterable might be promise
|
# but iterable might be promise
|
||||||
|
|
||||||
try:
|
if is_running_async():
|
||||||
get_running_loop()
|
|
||||||
except RuntimeError:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
|
|
||||||
async def perform_resolve(iterable):
|
async def perform_resolve(iterable):
|
||||||
iterable = await sync_to_async(queryset_resolver)(
|
iterable = await sync_to_async(queryset_resolver)(
|
||||||
|
|
|
@ -2,7 +2,6 @@ from collections import OrderedDict
|
||||||
|
|
||||||
from django.shortcuts import get_object_or_404
|
from django.shortcuts import get_object_or_404
|
||||||
from rest_framework import serializers
|
from rest_framework import serializers
|
||||||
from asyncio import get_running_loop
|
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
|
|
||||||
import graphene
|
import graphene
|
||||||
|
@ -13,6 +12,7 @@ from graphene.types.objecttype import yank_fields_from_attrs
|
||||||
|
|
||||||
from ..types import ErrorType
|
from ..types import ErrorType
|
||||||
from .serializer_converter import convert_serializer_field
|
from .serializer_converter import convert_serializer_field
|
||||||
|
from ..utils import is_running_async
|
||||||
|
|
||||||
|
|
||||||
class SerializerMutationOptions(MutationOptions):
|
class SerializerMutationOptions(MutationOptions):
|
||||||
|
@ -154,11 +154,7 @@ class SerializerMutation(ClientIDMutation):
|
||||||
kwargs = cls.get_serializer_kwargs(root, info, **input)
|
kwargs = cls.get_serializer_kwargs(root, info, **input)
|
||||||
serializer = cls._meta.serializer_class(**kwargs)
|
serializer = cls._meta.serializer_class(**kwargs)
|
||||||
|
|
||||||
try:
|
if is_running_async():
|
||||||
get_running_loop()
|
|
||||||
except RuntimeError:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
|
|
||||||
async def perform_mutate_async():
|
async def perform_mutate_async():
|
||||||
if await sync_to_async(serializer.is_valid)():
|
if await sync_to_async(serializer.is_valid)():
|
||||||
|
|
|
@ -16,6 +16,7 @@ from .utils import (
|
||||||
camelize,
|
camelize,
|
||||||
get_model_fields,
|
get_model_fields,
|
||||||
is_valid_django_model,
|
is_valid_django_model,
|
||||||
|
is_running_async,
|
||||||
)
|
)
|
||||||
|
|
||||||
ALL_FIELDS = "__all__"
|
ALL_FIELDS = "__all__"
|
||||||
|
@ -288,13 +289,7 @@ class DjangoObjectType(ObjectType):
|
||||||
def get_node(cls, info, id):
|
def get_node(cls, info, id):
|
||||||
queryset = cls.get_queryset(cls._meta.model.objects, info)
|
queryset = cls.get_queryset(cls._meta.model.objects, info)
|
||||||
try:
|
try:
|
||||||
try:
|
if is_running_async():
|
||||||
import asyncio
|
|
||||||
|
|
||||||
asyncio.get_running_loop()
|
|
||||||
except RuntimeError:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
return queryset.aget(pk=id)
|
return queryset.aget(pk=id)
|
||||||
|
|
||||||
return queryset.get(pk=id)
|
return queryset.get(pk=id)
|
||||||
|
|
|
@ -6,6 +6,8 @@ from .utils import (
|
||||||
get_reverse_fields,
|
get_reverse_fields,
|
||||||
is_valid_django_model,
|
is_valid_django_model,
|
||||||
maybe_queryset,
|
maybe_queryset,
|
||||||
|
is_sync_function,
|
||||||
|
is_running_async,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
@ -16,4 +18,6 @@ __all__ = [
|
||||||
"camelize",
|
"camelize",
|
||||||
"is_valid_django_model",
|
"is_valid_django_model",
|
||||||
"GraphQLTestCase",
|
"GraphQLTestCase",
|
||||||
|
"is_sync_function",
|
||||||
|
"is_running_async",
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import inspect
|
import inspect
|
||||||
|
from asyncio import get_running_loop
|
||||||
|
|
||||||
from django.db import connection, models, transaction
|
from django.db import connection, models, transaction
|
||||||
from django.db.models.manager import Manager
|
from django.db.models.manager import Manager
|
||||||
|
@ -105,3 +106,18 @@ def set_rollback():
|
||||||
atomic_requests = connection.settings_dict.get("ATOMIC_REQUESTS", False)
|
atomic_requests = connection.settings_dict.get("ATOMIC_REQUESTS", False)
|
||||||
if atomic_requests and connection.in_atomic_block:
|
if atomic_requests and connection.in_atomic_block:
|
||||||
transaction.set_rollback(True)
|
transaction.set_rollback(True)
|
||||||
|
|
||||||
|
|
||||||
|
def is_running_async():
|
||||||
|
try:
|
||||||
|
get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def is_sync_function(func):
|
||||||
|
return not inspect.iscoroutinefunction(func) and not inspect.isasyncgenfunction(
|
||||||
|
func
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user