Handle the default django list field and test the async execution of the fields

This commit is contained in:
Josh Warwick 2023-05-05 11:18:21 +01:00
parent c10753d4b1
commit e9d5e88ea2
3 changed files with 184 additions and 13 deletions

View File

@ -53,7 +53,28 @@ class DjangoListField(Field):
def list_resolver(
django_object_type, resolver, default_manager, root, info, **args
):
queryset = maybe_queryset(resolver(root, info, **args))
iterable = resolver(root, info, **args)
if info.is_awaitable(iterable):
async def resolve_list_async(iterable):
queryset = maybe_queryset(await iterable)
if queryset is None:
queryset = maybe_queryset(default_manager)
if isinstance(queryset, QuerySet):
# Pass queryset to the DjangoObjectType get_queryset method
queryset = maybe_queryset(
await sync_to_async(django_object_type.get_queryset)(
queryset, info
)
)
return await sync_to_async(list)(queryset)
return resolve_list_async(iterable)
queryset = maybe_queryset(iterable)
if queryset is None:
queryset = maybe_queryset(default_manager)
@ -61,12 +82,12 @@ class DjangoListField(Field):
# Pass queryset to the DjangoObjectType get_queryset method
queryset = maybe_queryset(django_object_type.get_queryset(queryset, info))
try:
try:
get_running_loop()
except RuntimeError:
pass
pass
else:
return queryset.aiterator()
return sync_to_async(list)(queryset)
return queryset
@ -238,34 +259,39 @@ class DjangoConnectionField(ConnectionField):
# or a resolve_foo (does not accept queryset)
iterable = resolver(root, info, **args)
if info.is_awaitable(iterable):
async def resolve_connection_async(iterable):
iterable = await iterable
if iterable is None:
iterable = default_manager
## This could also be async
iterable = queryset_resolver(connection, iterable, info, args)
if info.is_awaitable(iterable):
iterable = await iterable
return await sync_to_async(cls.resolve_connection)(connection, args, iterable, max_limit=max_limit)
return await sync_to_async(cls.resolve_connection)(
connection, args, iterable, max_limit=max_limit
)
return resolve_connection_async(iterable)
if iterable is None:
iterable = default_manager
# thus the iterable gets refiltered by resolve_queryset
# but iterable might be promise
iterable = queryset_resolver(connection, iterable, info, args)
try:
try:
get_running_loop()
except RuntimeError:
pass
pass
else:
return sync_to_async(cls.resolve_connection)(connection, args, iterable, max_limit=max_limit)
return sync_to_async(cls.resolve_connection)(
connection, args, iterable, max_limit=max_limit
)
return cls.resolve_connection(connection, args, iterable, max_limit=max_limit)

View File

@ -0,0 +1,6 @@
from asgiref.sync import async_to_sync
def assert_async_result_equal(schema, query, result):
async_result = async_to_sync(schema.execute_async)(query)
assert async_result == result

View File

@ -1,6 +1,7 @@
import datetime
import re
from django.db.models import Count, Prefetch
from asgiref.sync import sync_to_async, async_to_sync
import pytest
@ -14,6 +15,7 @@ from .models import (
FilmDetails as FilmDetailsModel,
Reporter as ReporterModel,
)
from .async_test_helper import assert_async_result_equal
class TestDjangoListField:
@ -75,6 +77,7 @@ class TestDjangoListField:
result = schema.execute(query)
assert_async_result_equal(schema, query, result)
assert not result.errors
assert result.data == {
"reporters": [{"firstName": "Tara"}, {"firstName": "Debra"}]
@ -102,6 +105,7 @@ class TestDjangoListField:
result = schema.execute(query)
assert not result.errors
assert result.data == {"reporters": []}
assert_async_result_equal(schema, query, result)
ReporterModel.objects.create(first_name="Tara", last_name="West")
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
@ -112,6 +116,7 @@ class TestDjangoListField:
assert result.data == {
"reporters": [{"firstName": "Tara"}, {"firstName": "Debra"}]
}
assert_async_result_equal(schema, query, result)
def test_override_resolver(self):
class Reporter(DjangoObjectType):
@ -139,6 +144,37 @@ class TestDjangoListField:
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
result = schema.execute(query)
assert not result.errors
assert result.data == {"reporters": [{"firstName": "Tara"}]}
def test_override_resolver_async_execution(self):
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
fields = ("first_name",)
class Query(ObjectType):
reporters = DjangoListField(Reporter)
@staticmethod
@sync_to_async
def resolve_reporters(_, info):
return ReporterModel.objects.filter(first_name="Tara")
schema = Schema(query=Query)
query = """
query {
reporters {
firstName
}
}
"""
ReporterModel.objects.create(first_name="Tara", last_name="West")
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
result = async_to_sync(schema.execute_async)(query)
assert not result.errors
assert result.data == {"reporters": [{"firstName": "Tara"}]}
@ -203,6 +239,7 @@ class TestDjangoListField:
{"firstName": "Debra", "articles": []},
]
}
assert_async_result_equal(schema, query, result)
def test_override_resolver_nested_list_field(self):
class Article(DjangoObjectType):
@ -261,6 +298,7 @@ class TestDjangoListField:
{"firstName": "Debra", "articles": []},
]
}
assert_async_result_equal(schema, query, result)
def test_get_queryset_filter(self):
class Reporter(DjangoObjectType):
@ -306,6 +344,7 @@ class TestDjangoListField:
assert not result.errors
assert result.data == {"reporters": [{"firstName": "Tara"}]}
assert_async_result_equal(schema, query, result)
def test_resolve_list(self):
"""Resolving a plain list should work (and not call get_queryset)"""
@ -354,6 +393,55 @@ class TestDjangoListField:
assert not result.errors
assert result.data == {"reporters": [{"firstName": "Debra"}]}
def test_resolve_list_async(self):
"""Resolving a plain list should work (and not call get_queryset) when running under async"""
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
fields = ("first_name", "articles")
@classmethod
def get_queryset(cls, queryset, info):
# Only get reporters with at least 1 article
return queryset.annotate(article_count=Count("articles")).filter(
article_count__gt=0
)
class Query(ObjectType):
reporters = DjangoListField(Reporter)
@staticmethod
@sync_to_async
def resolve_reporters(_, info):
return [ReporterModel.objects.get(first_name="Debra")]
schema = Schema(query=Query)
query = """
query {
reporters {
firstName
}
}
"""
r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
ArticleModel.objects.create(
headline="Amazing news",
reporter=r1,
pub_date=datetime.date.today(),
pub_date_time=datetime.datetime.now(),
editor=r1,
)
result = async_to_sync(schema.execute_async)(query)
assert not result.errors
assert result.data == {"reporters": [{"firstName": "Debra"}]}
def test_get_queryset_foreign_key(self):
class Article(DjangoObjectType):
class Meta:
@ -413,6 +501,7 @@ class TestDjangoListField:
{"firstName": "Debra", "articles": []},
]
}
assert_async_result_equal(schema, query, result)
def test_resolve_list_external_resolver(self):
"""Resolving a plain list from external resolver should work (and not call get_queryset)"""
@ -461,6 +550,54 @@ class TestDjangoListField:
assert not result.errors
assert result.data == {"reporters": [{"firstName": "Debra"}]}
def test_resolve_list_external_resolver_async(self):
"""Resolving a plain list from external resolver should work (and not call get_queryset)"""
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
fields = ("first_name", "articles")
@classmethod
def get_queryset(cls, queryset, info):
# Only get reporters with at least 1 article
return queryset.annotate(article_count=Count("articles")).filter(
article_count__gt=0
)
@sync_to_async
def resolve_reporters(_, info):
return [ReporterModel.objects.get(first_name="Debra")]
class Query(ObjectType):
reporters = DjangoListField(Reporter, resolver=resolve_reporters)
schema = Schema(query=Query)
query = """
query {
reporters {
firstName
}
}
"""
r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
ArticleModel.objects.create(
headline="Amazing news",
reporter=r1,
pub_date=datetime.date.today(),
pub_date_time=datetime.datetime.now(),
editor=r1,
)
result = async_to_sync(schema.execute_async)(query)
assert not result.errors
assert result.data == {"reporters": [{"firstName": "Debra"}]}
def test_get_queryset_filter_external_resolver(self):
class Reporter(DjangoObjectType):
class Meta:
@ -505,6 +642,7 @@ class TestDjangoListField:
assert not result.errors
assert result.data == {"reporters": [{"firstName": "Tara"}]}
assert_async_result_equal(schema, query, result)
def test_select_related_and_prefetch_related_are_respected(
self, django_assert_num_queries
@ -647,3 +785,4 @@ class TestDjangoListField:
r'SELECT .* FROM "tests_film" INNER JOIN "tests_film_reporters" .* LEFT OUTER JOIN "tests_filmdetails"',
captured.captured_queries[1]["sql"],
)
assert_async_result_equal(schema, query, result)