mirror of
https://github.com/graphql-python/graphene-django.git
synced 2025-07-24 15:09:51 +03:00
Handle the default django list field and test the async execution of the fields
This commit is contained in:
parent
c10753d4b1
commit
e9d5e88ea2
|
@ -53,7 +53,28 @@ class DjangoListField(Field):
|
|||
def list_resolver(
|
||||
django_object_type, resolver, default_manager, root, info, **args
|
||||
):
|
||||
queryset = maybe_queryset(resolver(root, info, **args))
|
||||
iterable = resolver(root, info, **args)
|
||||
|
||||
if info.is_awaitable(iterable):
|
||||
|
||||
async def resolve_list_async(iterable):
|
||||
queryset = maybe_queryset(await iterable)
|
||||
if queryset is None:
|
||||
queryset = maybe_queryset(default_manager)
|
||||
|
||||
if isinstance(queryset, QuerySet):
|
||||
# Pass queryset to the DjangoObjectType get_queryset method
|
||||
queryset = maybe_queryset(
|
||||
await sync_to_async(django_object_type.get_queryset)(
|
||||
queryset, info
|
||||
)
|
||||
)
|
||||
|
||||
return await sync_to_async(list)(queryset)
|
||||
|
||||
return resolve_list_async(iterable)
|
||||
|
||||
queryset = maybe_queryset(iterable)
|
||||
if queryset is None:
|
||||
queryset = maybe_queryset(default_manager)
|
||||
|
||||
|
@ -61,12 +82,12 @@ class DjangoListField(Field):
|
|||
# Pass queryset to the DjangoObjectType get_queryset method
|
||||
queryset = maybe_queryset(django_object_type.get_queryset(queryset, info))
|
||||
|
||||
try:
|
||||
try:
|
||||
get_running_loop()
|
||||
except RuntimeError:
|
||||
pass
|
||||
pass
|
||||
else:
|
||||
return queryset.aiterator()
|
||||
return sync_to_async(list)(queryset)
|
||||
|
||||
return queryset
|
||||
|
||||
|
@ -238,34 +259,39 @@ class DjangoConnectionField(ConnectionField):
|
|||
# or a resolve_foo (does not accept queryset)
|
||||
|
||||
iterable = resolver(root, info, **args)
|
||||
|
||||
|
||||
if info.is_awaitable(iterable):
|
||||
|
||||
async def resolve_connection_async(iterable):
|
||||
iterable = await iterable
|
||||
if iterable is None:
|
||||
iterable = default_manager
|
||||
## This could also be async
|
||||
iterable = queryset_resolver(connection, iterable, info, args)
|
||||
|
||||
|
||||
if info.is_awaitable(iterable):
|
||||
iterable = await iterable
|
||||
|
||||
return await sync_to_async(cls.resolve_connection)(connection, args, iterable, max_limit=max_limit)
|
||||
|
||||
return await sync_to_async(cls.resolve_connection)(
|
||||
connection, args, iterable, max_limit=max_limit
|
||||
)
|
||||
|
||||
return resolve_connection_async(iterable)
|
||||
|
||||
|
||||
if iterable is None:
|
||||
iterable = default_manager
|
||||
# thus the iterable gets refiltered by resolve_queryset
|
||||
# but iterable might be promise
|
||||
iterable = queryset_resolver(connection, iterable, info, args)
|
||||
|
||||
try:
|
||||
try:
|
||||
get_running_loop()
|
||||
except RuntimeError:
|
||||
pass
|
||||
pass
|
||||
else:
|
||||
return sync_to_async(cls.resolve_connection)(connection, args, iterable, max_limit=max_limit)
|
||||
|
||||
return sync_to_async(cls.resolve_connection)(
|
||||
connection, args, iterable, max_limit=max_limit
|
||||
)
|
||||
|
||||
return cls.resolve_connection(connection, args, iterable, max_limit=max_limit)
|
||||
|
||||
|
|
6
graphene_django/tests/async_test_helper.py
Normal file
6
graphene_django/tests/async_test_helper.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
from asgiref.sync import async_to_sync
|
||||
|
||||
|
||||
def assert_async_result_equal(schema, query, result):
|
||||
async_result = async_to_sync(schema.execute_async)(query)
|
||||
assert async_result == result
|
|
@ -1,6 +1,7 @@
|
|||
import datetime
|
||||
import re
|
||||
from django.db.models import Count, Prefetch
|
||||
from asgiref.sync import sync_to_async, async_to_sync
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -14,6 +15,7 @@ from .models import (
|
|||
FilmDetails as FilmDetailsModel,
|
||||
Reporter as ReporterModel,
|
||||
)
|
||||
from .async_test_helper import assert_async_result_equal
|
||||
|
||||
|
||||
class TestDjangoListField:
|
||||
|
@ -75,6 +77,7 @@ class TestDjangoListField:
|
|||
|
||||
result = schema.execute(query)
|
||||
|
||||
assert_async_result_equal(schema, query, result)
|
||||
assert not result.errors
|
||||
assert result.data == {
|
||||
"reporters": [{"firstName": "Tara"}, {"firstName": "Debra"}]
|
||||
|
@ -102,6 +105,7 @@ class TestDjangoListField:
|
|||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data == {"reporters": []}
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
ReporterModel.objects.create(first_name="Tara", last_name="West")
|
||||
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
|
||||
|
@ -112,6 +116,7 @@ class TestDjangoListField:
|
|||
assert result.data == {
|
||||
"reporters": [{"firstName": "Tara"}, {"firstName": "Debra"}]
|
||||
}
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
def test_override_resolver(self):
|
||||
class Reporter(DjangoObjectType):
|
||||
|
@ -139,6 +144,37 @@ class TestDjangoListField:
|
|||
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
|
||||
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data == {"reporters": [{"firstName": "Tara"}]}
|
||||
|
||||
def test_override_resolver_async_execution(self):
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
fields = ("first_name",)
|
||||
|
||||
class Query(ObjectType):
|
||||
reporters = DjangoListField(Reporter)
|
||||
|
||||
@staticmethod
|
||||
@sync_to_async
|
||||
def resolve_reporters(_, info):
|
||||
return ReporterModel.objects.filter(first_name="Tara")
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
reporters {
|
||||
firstName
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
ReporterModel.objects.create(first_name="Tara", last_name="West")
|
||||
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
|
||||
|
||||
result = async_to_sync(schema.execute_async)(query)
|
||||
|
||||
assert not result.errors
|
||||
assert result.data == {"reporters": [{"firstName": "Tara"}]}
|
||||
|
@ -203,6 +239,7 @@ class TestDjangoListField:
|
|||
{"firstName": "Debra", "articles": []},
|
||||
]
|
||||
}
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
def test_override_resolver_nested_list_field(self):
|
||||
class Article(DjangoObjectType):
|
||||
|
@ -261,6 +298,7 @@ class TestDjangoListField:
|
|||
{"firstName": "Debra", "articles": []},
|
||||
]
|
||||
}
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
def test_get_queryset_filter(self):
|
||||
class Reporter(DjangoObjectType):
|
||||
|
@ -306,6 +344,7 @@ class TestDjangoListField:
|
|||
|
||||
assert not result.errors
|
||||
assert result.data == {"reporters": [{"firstName": "Tara"}]}
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
def test_resolve_list(self):
|
||||
"""Resolving a plain list should work (and not call get_queryset)"""
|
||||
|
@ -354,6 +393,55 @@ class TestDjangoListField:
|
|||
assert not result.errors
|
||||
assert result.data == {"reporters": [{"firstName": "Debra"}]}
|
||||
|
||||
def test_resolve_list_async(self):
|
||||
"""Resolving a plain list should work (and not call get_queryset) when running under async"""
|
||||
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
fields = ("first_name", "articles")
|
||||
|
||||
@classmethod
|
||||
def get_queryset(cls, queryset, info):
|
||||
# Only get reporters with at least 1 article
|
||||
return queryset.annotate(article_count=Count("articles")).filter(
|
||||
article_count__gt=0
|
||||
)
|
||||
|
||||
class Query(ObjectType):
|
||||
reporters = DjangoListField(Reporter)
|
||||
|
||||
@staticmethod
|
||||
@sync_to_async
|
||||
def resolve_reporters(_, info):
|
||||
return [ReporterModel.objects.get(first_name="Debra")]
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
reporters {
|
||||
firstName
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
|
||||
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
|
||||
|
||||
ArticleModel.objects.create(
|
||||
headline="Amazing news",
|
||||
reporter=r1,
|
||||
pub_date=datetime.date.today(),
|
||||
pub_date_time=datetime.datetime.now(),
|
||||
editor=r1,
|
||||
)
|
||||
|
||||
result = async_to_sync(schema.execute_async)(query)
|
||||
|
||||
assert not result.errors
|
||||
assert result.data == {"reporters": [{"firstName": "Debra"}]}
|
||||
|
||||
def test_get_queryset_foreign_key(self):
|
||||
class Article(DjangoObjectType):
|
||||
class Meta:
|
||||
|
@ -413,6 +501,7 @@ class TestDjangoListField:
|
|||
{"firstName": "Debra", "articles": []},
|
||||
]
|
||||
}
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
def test_resolve_list_external_resolver(self):
|
||||
"""Resolving a plain list from external resolver should work (and not call get_queryset)"""
|
||||
|
@ -461,6 +550,54 @@ class TestDjangoListField:
|
|||
assert not result.errors
|
||||
assert result.data == {"reporters": [{"firstName": "Debra"}]}
|
||||
|
||||
def test_resolve_list_external_resolver_async(self):
|
||||
"""Resolving a plain list from external resolver should work (and not call get_queryset)"""
|
||||
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
fields = ("first_name", "articles")
|
||||
|
||||
@classmethod
|
||||
def get_queryset(cls, queryset, info):
|
||||
# Only get reporters with at least 1 article
|
||||
return queryset.annotate(article_count=Count("articles")).filter(
|
||||
article_count__gt=0
|
||||
)
|
||||
|
||||
@sync_to_async
|
||||
def resolve_reporters(_, info):
|
||||
return [ReporterModel.objects.get(first_name="Debra")]
|
||||
|
||||
class Query(ObjectType):
|
||||
reporters = DjangoListField(Reporter, resolver=resolve_reporters)
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
reporters {
|
||||
firstName
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
|
||||
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
|
||||
|
||||
ArticleModel.objects.create(
|
||||
headline="Amazing news",
|
||||
reporter=r1,
|
||||
pub_date=datetime.date.today(),
|
||||
pub_date_time=datetime.datetime.now(),
|
||||
editor=r1,
|
||||
)
|
||||
|
||||
result = async_to_sync(schema.execute_async)(query)
|
||||
|
||||
assert not result.errors
|
||||
assert result.data == {"reporters": [{"firstName": "Debra"}]}
|
||||
|
||||
def test_get_queryset_filter_external_resolver(self):
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
|
@ -505,6 +642,7 @@ class TestDjangoListField:
|
|||
|
||||
assert not result.errors
|
||||
assert result.data == {"reporters": [{"firstName": "Tara"}]}
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
||||
def test_select_related_and_prefetch_related_are_respected(
|
||||
self, django_assert_num_queries
|
||||
|
@ -647,3 +785,4 @@ class TestDjangoListField:
|
|||
r'SELECT .* FROM "tests_film" INNER JOIN "tests_film_reporters" .* LEFT OUTER JOIN "tests_filmdetails"',
|
||||
captured.captured_queries[1]["sql"],
|
||||
)
|
||||
assert_async_result_equal(schema, query, result)
|
||||
|
|
Loading…
Reference in New Issue
Block a user