This commit is contained in:
Josh Warwick 2023-05-05 16:28:40 +01:00
parent c501fdb20b
commit 58b92e6ed3
4 changed files with 10 additions and 5 deletions

View File

@ -95,6 +95,7 @@ class DjangoFilterConnectionField(DjangoConnectionField):
qs = super().resolve_queryset(connection, iterable, info, args) qs = super().resolve_queryset(connection, iterable, info, args)
if info.is_awaitable(qs): if info.is_awaitable(qs):
async def filter_async(qs): async def filter_async(qs):
filterset = filterset_class( filterset = filterset_class(
data=filter_kwargs(), queryset=await qs, request=info.context data=filter_kwargs(), queryset=await qs, request=info.context
@ -102,6 +103,7 @@ class DjangoFilterConnectionField(DjangoConnectionField):
if await sync_to_async(filterset.is_valid)(): if await sync_to_async(filterset.is_valid)():
return filterset.qs return filterset.qs
raise ValidationError(filterset.form.errors.as_json()) raise ValidationError(filterset.form.errors.as_json())
return filter_async(qs) return filter_async(qs)
filterset = filterset_class( filterset = filterset_class(

View File

@ -154,17 +154,19 @@ class SerializerMutation(ClientIDMutation):
kwargs = cls.get_serializer_kwargs(root, info, **input) kwargs = cls.get_serializer_kwargs(root, info, **input)
serializer = cls._meta.serializer_class(**kwargs) serializer = cls._meta.serializer_class(**kwargs)
try: try:
get_running_loop() get_running_loop()
except RuntimeError: except RuntimeError:
pass pass
else: else:
async def perform_mutate_async(): async def perform_mutate_async():
if await sync_to_async(serializer.is_valid)(): if await sync_to_async(serializer.is_valid)():
return await sync_to_async(cls.perform_mutate)(serializer, info) return await sync_to_async(cls.perform_mutate)(serializer, info)
else: else:
errors = ErrorType.from_errors(serializer.errors) errors = ErrorType.from_errors(serializer.errors)
return cls(errors=errors) return cls(errors=errors)
return perform_mutate_async() return perform_mutate_async()
if serializer.is_valid(): if serializer.is_valid():

View File

@ -288,14 +288,15 @@ class DjangoObjectType(ObjectType):
def get_node(cls, info, id): def get_node(cls, info, id):
queryset = cls.get_queryset(cls._meta.model.objects, info) queryset = cls.get_queryset(cls._meta.model.objects, info)
try: try:
try: try:
import asyncio import asyncio
asyncio.get_running_loop() asyncio.get_running_loop()
except RuntimeError: except RuntimeError:
pass pass
else: else:
return queryset.aget(pk=id) return queryset.aget(pk=id)
return queryset.get(pk=id) return queryset.get(pk=id)
except cls._meta.model.DoesNotExist: except cls._meta.model.DoesNotExist:
return None return None

View File

@ -22,7 +22,7 @@ tests_require = [
"pytz", "pytz",
"django-filter>=22.1", "django-filter>=22.1",
"pytest-django>=4.5.2", "pytest-django>=4.5.2",
"pytest-asyncio>=0.16,<2" "pytest-asyncio>=0.16,<2",
] + rest_framework_require ] + rest_framework_require