Update to pull out mutations and ensure that DjangoFields don't get double sync'd

This commit is contained in:
Josh Warwick 2023-05-10 19:17:30 +01:00
parent b134ab0a3e
commit 8c068fbc2b
2 changed files with 48 additions and 22 deletions

View File

@ -5,6 +5,9 @@ import inspect
from .sql.tracking import unwrap_cursor, wrap_cursor from .sql.tracking import unwrap_cursor, wrap_cursor
from .exception.formating import wrap_exception from .exception.formating import wrap_exception
from .types import DjangoDebug from .types import DjangoDebug
from graphql.type.definition import GraphQLNonNull
from django.db.models import QuerySet
class DjangoDebugContext: class DjangoDebugContext:
@ -74,6 +77,12 @@ class DjangoDebugMiddleware:
class DjangoSyncRequiredMiddleware: class DjangoSyncRequiredMiddleware:
def resolve(self, next, root, info, **args): def resolve(self, next, root, info, **args):
parent_type = info.parent_type parent_type = info.parent_type
return_type = info.return_type
if isinstance(parent_type, GraphQLNonNull):
parent_type = parent_type.of_type
if isinstance(return_type, GraphQLNonNull):
return_type = return_type.of_type
## Anytime the parent is a DjangoObject type ## Anytime the parent is a DjangoObject type
# and we're resolving a sync field, we need to wrap it in a sync_to_async # and we're resolving a sync field, we need to wrap it in a sync_to_async
@ -87,23 +96,28 @@ class DjangoSyncRequiredMiddleware:
## In addition, if we're resolving to a DjangoObject type ## In addition, if we're resolving to a DjangoObject type
# we likely need to wrap it in a sync_to_async as well # we likely need to wrap it in a sync_to_async as well
if hasattr(info.return_type, "graphene_type") and hasattr( if hasattr(return_type, "graphene_type") and hasattr(
info.return_type.graphene_type._meta, "model" return_type.graphene_type._meta, "model"
): ):
if not inspect.iscoroutinefunction(next) and not inspect.isasyncgenfunction( if not inspect.iscoroutinefunction(next) and not inspect.isasyncgenfunction(
next next
): ):
return sync_to_async(next)(root, info, **args) return sync_to_async(next)(root, info, **args)
## We also need to handle custom resolvers around Connections ## We can move this resolver logic into the field resolver itself and probably should
# but only when their parent is not already a DjangoObject type if hasattr(return_type, "graphene_type"):
# this case already gets handled above. if hasattr(return_type.graphene_type, "Edge"):
if hasattr(info.return_type, "graphene_type"): node_type = return_type.graphene_type.Edge.node.type
if hasattr(info.return_type.graphene_type, "Edge"):
node_type = info.return_type.graphene_type.Edge.node.type
if hasattr(node_type, "_meta") and hasattr(node_type._meta, "model"): if hasattr(node_type, "_meta") and hasattr(node_type._meta, "model"):
if not inspect.iscoroutinefunction( if not inspect.iscoroutinefunction(
next next
) and not inspect.isasyncgenfunction(next): ) and not inspect.isasyncgenfunction(next):
return sync_to_async(next)(root, info, **args) return sync_to_async(next)(root, info, **args)
if info.parent_type.name == "Mutation":
if not inspect.iscoroutinefunction(next) and not inspect.isasyncgenfunction(
next
):
return sync_to_async(next)(root, info, **args)
return next(root, info, **args) return next(root, info, **args)

View File

@ -1,3 +1,4 @@
import inspect
from functools import partial from functools import partial
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
@ -82,13 +83,6 @@ class DjangoListField(Field):
# Pass queryset to the DjangoObjectType get_queryset method # Pass queryset to the DjangoObjectType get_queryset method
queryset = maybe_queryset(django_object_type.get_queryset(queryset, info)) queryset = maybe_queryset(django_object_type.get_queryset(queryset, info))
try:
get_running_loop()
except RuntimeError:
pass
else:
return sync_to_async(list)(queryset)
return queryset return queryset
def wrap_resolve(self, parent_resolver): def wrap_resolve(self, parent_resolver):
@ -97,12 +91,31 @@ class DjangoListField(Field):
if isinstance(_type, NonNull): if isinstance(_type, NonNull):
_type = _type.of_type _type = _type.of_type
django_object_type = _type.of_type.of_type django_object_type = _type.of_type.of_type
return partial(
self.list_resolver, try:
django_object_type, get_running_loop()
resolver, except RuntimeError:
self.get_manager(), return partial(
) self.list_resolver, django_object_type, resolver, self.get_manager()
)
else:
if not inspect.iscoroutinefunction(
resolver
) and not inspect.isasyncgenfunction(resolver):
async_resolver = sync_to_async(resolver)
## This is needed because our middleware can't detect the resolver as async when we returns partial[couroutine]
async def wrapped_resolver(root, info, **args):
return await self.list_resolver(
django_object_type,
async_resolver,
self.get_manager(),
root,
info,
**args
)
return wrapped_resolver
class DjangoConnectionField(ConnectionField): class DjangoConnectionField(ConnectionField):
@ -257,7 +270,6 @@ class DjangoConnectionField(ConnectionField):
# eventually leads to DjangoObjectType's get_queryset (accepts queryset) # eventually leads to DjangoObjectType's get_queryset (accepts queryset)
# or a resolve_foo (does not accept queryset) # or a resolve_foo (does not accept queryset)
iterable = resolver(root, info, **args) iterable = resolver(root, info, **args)
if info.is_awaitable(iterable): if info.is_awaitable(iterable):