Update to pull out mutations and ensure that DjangoFields don't get double sync'd

This commit is contained in:
Josh Warwick 2023-05-10 19:17:30 +01:00
parent b134ab0a3e
commit 8c068fbc2b
2 changed files with 48 additions and 22 deletions

View File

@ -5,6 +5,9 @@ import inspect
from .sql.tracking import unwrap_cursor, wrap_cursor
from .exception.formating import wrap_exception
from .types import DjangoDebug
from graphql.type.definition import GraphQLNonNull
from django.db.models import QuerySet
class DjangoDebugContext:
@ -74,6 +77,12 @@ class DjangoDebugMiddleware:
class DjangoSyncRequiredMiddleware:
def resolve(self, next, root, info, **args):
parent_type = info.parent_type
return_type = info.return_type
if isinstance(parent_type, GraphQLNonNull):
parent_type = parent_type.of_type
if isinstance(return_type, GraphQLNonNull):
return_type = return_type.of_type
## Anytime the parent is a DjangoObject type
# and we're resolving a sync field, we need to wrap it in a sync_to_async
@ -87,23 +96,28 @@ class DjangoSyncRequiredMiddleware:
## In addition, if we're resolving to a DjangoObject type
# we likely need to wrap it in a sync_to_async as well
if hasattr(info.return_type, "graphene_type") and hasattr(
info.return_type.graphene_type._meta, "model"
if hasattr(return_type, "graphene_type") and hasattr(
return_type.graphene_type._meta, "model"
):
if not inspect.iscoroutinefunction(next) and not inspect.isasyncgenfunction(
next
):
return sync_to_async(next)(root, info, **args)
## We also need to handle custom resolvers around Connections
# but only when their parent is not already a DjangoObject type
# this case already gets handled above.
if hasattr(info.return_type, "graphene_type"):
if hasattr(info.return_type.graphene_type, "Edge"):
node_type = info.return_type.graphene_type.Edge.node.type
## We can move this resolver logic into the field resolver itself and probably should
if hasattr(return_type, "graphene_type"):
if hasattr(return_type.graphene_type, "Edge"):
node_type = return_type.graphene_type.Edge.node.type
if hasattr(node_type, "_meta") and hasattr(node_type._meta, "model"):
if not inspect.iscoroutinefunction(
next
) and not inspect.isasyncgenfunction(next):
return sync_to_async(next)(root, info, **args)
if info.parent_type.name == "Mutation":
if not inspect.iscoroutinefunction(next) and not inspect.isasyncgenfunction(
next
):
return sync_to_async(next)(root, info, **args)
return next(root, info, **args)

View File

@ -1,3 +1,4 @@
import inspect
from functools import partial
from django.db.models.query import QuerySet
@ -82,13 +83,6 @@ class DjangoListField(Field):
# Pass queryset to the DjangoObjectType get_queryset method
queryset = maybe_queryset(django_object_type.get_queryset(queryset, info))
try:
get_running_loop()
except RuntimeError:
pass
else:
return sync_to_async(list)(queryset)
return queryset
def wrap_resolve(self, parent_resolver):
@ -97,12 +91,31 @@ class DjangoListField(Field):
if isinstance(_type, NonNull):
_type = _type.of_type
django_object_type = _type.of_type.of_type
try:
get_running_loop()
except RuntimeError:
return partial(
self.list_resolver,
django_object_type,
resolver,
self.get_manager(),
self.list_resolver, django_object_type, resolver, self.get_manager()
)
else:
if not inspect.iscoroutinefunction(
resolver
) and not inspect.isasyncgenfunction(resolver):
async_resolver = sync_to_async(resolver)
## This is needed because our middleware can't detect the resolver as async when we returns partial[couroutine]
async def wrapped_resolver(root, info, **args):
return await self.list_resolver(
django_object_type,
async_resolver,
self.get_manager(),
root,
info,
**args
)
return wrapped_resolver
class DjangoConnectionField(ConnectionField):
@ -257,7 +270,6 @@ class DjangoConnectionField(ConnectionField):
# eventually leads to DjangoObjectType's get_queryset (accepts queryset)
# or a resolve_foo (does not accept queryset)
iterable = resolver(root, info, **args)
if info.is_awaitable(iterable):