diff --git a/graphene_django/debug/middleware.py b/graphene_django/debug/middleware.py index c0a316b..f8ea42e 100644 --- a/graphene_django/debug/middleware.py +++ b/graphene_django/debug/middleware.py @@ -5,6 +5,9 @@ import inspect from .sql.tracking import unwrap_cursor, wrap_cursor from .exception.formating import wrap_exception from .types import DjangoDebug +from graphql.type.definition import GraphQLNonNull + +from django.db.models import QuerySet class DjangoDebugContext: @@ -74,6 +77,12 @@ class DjangoDebugMiddleware: class DjangoSyncRequiredMiddleware: def resolve(self, next, root, info, **args): parent_type = info.parent_type + return_type = info.return_type + + if isinstance(parent_type, GraphQLNonNull): + parent_type = parent_type.of_type + if isinstance(return_type, GraphQLNonNull): + return_type = return_type.of_type ## Anytime the parent is a DjangoObject type # and we're resolving a sync field, we need to wrap it in a sync_to_async @@ -87,23 +96,28 @@ class DjangoSyncRequiredMiddleware: ## In addition, if we're resolving to a DjangoObject type # we likely need to wrap it in a sync_to_async as well - if hasattr(info.return_type, "graphene_type") and hasattr( - info.return_type.graphene_type._meta, "model" + if hasattr(return_type, "graphene_type") and hasattr( + return_type.graphene_type._meta, "model" ): if not inspect.iscoroutinefunction(next) and not inspect.isasyncgenfunction( next ): return sync_to_async(next)(root, info, **args) - ## We also need to handle custom resolvers around Connections - # but only when their parent is not already a DjangoObject type - # this case already gets handled above. - if hasattr(info.return_type, "graphene_type"): - if hasattr(info.return_type.graphene_type, "Edge"): - node_type = info.return_type.graphene_type.Edge.node.type + ## We can move this resolver logic into the field resolver itself and probably should + if hasattr(return_type, "graphene_type"): + if hasattr(return_type.graphene_type, "Edge"): + node_type = return_type.graphene_type.Edge.node.type if hasattr(node_type, "_meta") and hasattr(node_type._meta, "model"): if not inspect.iscoroutinefunction( next ) and not inspect.isasyncgenfunction(next): return sync_to_async(next)(root, info, **args) + + if info.parent_type.name == "Mutation": + if not inspect.iscoroutinefunction(next) and not inspect.isasyncgenfunction( + next + ): + return sync_to_async(next)(root, info, **args) + return next(root, info, **args) diff --git a/graphene_django/fields.py b/graphene_django/fields.py index 43225c2..8f4ce5f 100644 --- a/graphene_django/fields.py +++ b/graphene_django/fields.py @@ -1,3 +1,4 @@ +import inspect from functools import partial from django.db.models.query import QuerySet @@ -82,13 +83,6 @@ class DjangoListField(Field): # Pass queryset to the DjangoObjectType get_queryset method queryset = maybe_queryset(django_object_type.get_queryset(queryset, info)) - try: - get_running_loop() - except RuntimeError: - pass - else: - return sync_to_async(list)(queryset) - return queryset def wrap_resolve(self, parent_resolver): @@ -97,12 +91,31 @@ class DjangoListField(Field): if isinstance(_type, NonNull): _type = _type.of_type django_object_type = _type.of_type.of_type - return partial( - self.list_resolver, - django_object_type, - resolver, - self.get_manager(), - ) + + try: + get_running_loop() + except RuntimeError: + return partial( + self.list_resolver, django_object_type, resolver, self.get_manager() + ) + else: + if not inspect.iscoroutinefunction( + resolver + ) and not inspect.isasyncgenfunction(resolver): + async_resolver = sync_to_async(resolver) + + ## This is needed because our middleware can't detect the resolver as async when we returns partial[couroutine] + async def wrapped_resolver(root, info, **args): + return await self.list_resolver( + django_object_type, + async_resolver, + self.get_manager(), + root, + info, + **args + ) + + return wrapped_resolver class DjangoConnectionField(ConnectionField): @@ -257,7 +270,6 @@ class DjangoConnectionField(ConnectionField): # eventually leads to DjangoObjectType's get_queryset (accepts queryset) # or a resolve_foo (does not accept queryset) - iterable = resolver(root, info, **args) if info.is_awaitable(iterable):