From 7ddaf9f5e683d31a5e9907c80e57e83e4c332bfd Mon Sep 17 00:00:00 2001 From: Josh Warwick Date: Fri, 31 Mar 2023 11:28:57 -0700 Subject: [PATCH] Handle coroutine results from resolvers in connections and filter connections --- graphene_django/fields.py | 14 ++++++++++++++ graphene_django/filter/fields.py | 12 ++++++++++++ 2 files changed, 26 insertions(+) diff --git a/graphene_django/fields.py b/graphene_django/fields.py index 68db6a0..c8899de 100644 --- a/graphene_django/fields.py +++ b/graphene_django/fields.py @@ -239,6 +239,20 @@ class DjangoConnectionField(ConnectionField): iterable = resolver(root, info, **args) + if info.is_awaitable(iterable): + async def resolve_connection_async(): + iterable = await iterable + if iterable is None: + iterable = default_manager + ## This could also be async + iterable = queryset_resolver(connection, iterable, info, args) + + if info.is_awaitable(iterable): + iterable = await iterable + + return await sync_to_async(cls.resolve_connection)(connection, args, iterable, max_limit=max_limit) + return resolve_connection_async() + if iterable is None: iterable = default_manager # thus the iterable gets refiltered by resolve_queryset diff --git a/graphene_django/filter/fields.py b/graphene_django/filter/fields.py index cdb8f85..a141460 100644 --- a/graphene_django/filter/fields.py +++ b/graphene_django/filter/fields.py @@ -7,6 +7,8 @@ from graphene.types.enum import EnumType from graphene.types.argument import to_arguments from graphene.utils.str_converters import to_snake_case +from asgiref.sync import sync_to_async + from ..fields import DjangoConnectionField from .utils import get_filtering_args_from_filterset, get_filterset_class @@ -92,6 +94,16 @@ class DjangoFilterConnectionField(DjangoConnectionField): qs = super().resolve_queryset(connection, iterable, info, args) + if info.is_awaitable(qs): + async def filter_async(): + filterset = filterset_class( + data=filter_kwargs(), queryset=await qs, request=info.context + ) + if await sync_to_async(filterset.is_valid)(): + return filterset.qs + raise ValidationError(filterset.form.errors.as_json()) + return filter_async() + filterset = filterset_class( data=filter_kwargs(), queryset=qs, request=info.context )