mirror of
https://github.com/graphql-python/graphene-django.git
synced 2025-07-27 08:30:03 +03:00
Handle the default django list field and test the async execution of the fields
This commit is contained in:
parent
c10753d4b1
commit
e9d5e88ea2
|
@ -53,7 +53,28 @@ class DjangoListField(Field):
|
||||||
def list_resolver(
|
def list_resolver(
|
||||||
django_object_type, resolver, default_manager, root, info, **args
|
django_object_type, resolver, default_manager, root, info, **args
|
||||||
):
|
):
|
||||||
queryset = maybe_queryset(resolver(root, info, **args))
|
iterable = resolver(root, info, **args)
|
||||||
|
|
||||||
|
if info.is_awaitable(iterable):
|
||||||
|
|
||||||
|
async def resolve_list_async(iterable):
|
||||||
|
queryset = maybe_queryset(await iterable)
|
||||||
|
if queryset is None:
|
||||||
|
queryset = maybe_queryset(default_manager)
|
||||||
|
|
||||||
|
if isinstance(queryset, QuerySet):
|
||||||
|
# Pass queryset to the DjangoObjectType get_queryset method
|
||||||
|
queryset = maybe_queryset(
|
||||||
|
await sync_to_async(django_object_type.get_queryset)(
|
||||||
|
queryset, info
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return await sync_to_async(list)(queryset)
|
||||||
|
|
||||||
|
return resolve_list_async(iterable)
|
||||||
|
|
||||||
|
queryset = maybe_queryset(iterable)
|
||||||
if queryset is None:
|
if queryset is None:
|
||||||
queryset = maybe_queryset(default_manager)
|
queryset = maybe_queryset(default_manager)
|
||||||
|
|
||||||
|
@ -64,9 +85,9 @@ class DjangoListField(Field):
|
||||||
try:
|
try:
|
||||||
get_running_loop()
|
get_running_loop()
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
return queryset.aiterator()
|
return sync_to_async(list)(queryset)
|
||||||
|
|
||||||
return queryset
|
return queryset
|
||||||
|
|
||||||
|
@ -240,6 +261,7 @@ class DjangoConnectionField(ConnectionField):
|
||||||
iterable = resolver(root, info, **args)
|
iterable = resolver(root, info, **args)
|
||||||
|
|
||||||
if info.is_awaitable(iterable):
|
if info.is_awaitable(iterable):
|
||||||
|
|
||||||
async def resolve_connection_async(iterable):
|
async def resolve_connection_async(iterable):
|
||||||
iterable = await iterable
|
iterable = await iterable
|
||||||
if iterable is None:
|
if iterable is None:
|
||||||
|
@ -250,7 +272,10 @@ class DjangoConnectionField(ConnectionField):
|
||||||
if info.is_awaitable(iterable):
|
if info.is_awaitable(iterable):
|
||||||
iterable = await iterable
|
iterable = await iterable
|
||||||
|
|
||||||
return await sync_to_async(cls.resolve_connection)(connection, args, iterable, max_limit=max_limit)
|
return await sync_to_async(cls.resolve_connection)(
|
||||||
|
connection, args, iterable, max_limit=max_limit
|
||||||
|
)
|
||||||
|
|
||||||
return resolve_connection_async(iterable)
|
return resolve_connection_async(iterable)
|
||||||
|
|
||||||
if iterable is None:
|
if iterable is None:
|
||||||
|
@ -262,10 +287,11 @@ class DjangoConnectionField(ConnectionField):
|
||||||
try:
|
try:
|
||||||
get_running_loop()
|
get_running_loop()
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
return sync_to_async(cls.resolve_connection)(connection, args, iterable, max_limit=max_limit)
|
return sync_to_async(cls.resolve_connection)(
|
||||||
|
connection, args, iterable, max_limit=max_limit
|
||||||
|
)
|
||||||
|
|
||||||
return cls.resolve_connection(connection, args, iterable, max_limit=max_limit)
|
return cls.resolve_connection(connection, args, iterable, max_limit=max_limit)
|
||||||
|
|
||||||
|
|
6
graphene_django/tests/async_test_helper.py
Normal file
6
graphene_django/tests/async_test_helper.py
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
from asgiref.sync import async_to_sync
|
||||||
|
|
||||||
|
|
||||||
|
def assert_async_result_equal(schema, query, result):
|
||||||
|
async_result = async_to_sync(schema.execute_async)(query)
|
||||||
|
assert async_result == result
|
|
@ -1,6 +1,7 @@
|
||||||
import datetime
|
import datetime
|
||||||
import re
|
import re
|
||||||
from django.db.models import Count, Prefetch
|
from django.db.models import Count, Prefetch
|
||||||
|
from asgiref.sync import sync_to_async, async_to_sync
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
@ -14,6 +15,7 @@ from .models import (
|
||||||
FilmDetails as FilmDetailsModel,
|
FilmDetails as FilmDetailsModel,
|
||||||
Reporter as ReporterModel,
|
Reporter as ReporterModel,
|
||||||
)
|
)
|
||||||
|
from .async_test_helper import assert_async_result_equal
|
||||||
|
|
||||||
|
|
||||||
class TestDjangoListField:
|
class TestDjangoListField:
|
||||||
|
@ -75,6 +77,7 @@ class TestDjangoListField:
|
||||||
|
|
||||||
result = schema.execute(query)
|
result = schema.execute(query)
|
||||||
|
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert result.data == {
|
assert result.data == {
|
||||||
"reporters": [{"firstName": "Tara"}, {"firstName": "Debra"}]
|
"reporters": [{"firstName": "Tara"}, {"firstName": "Debra"}]
|
||||||
|
@ -102,6 +105,7 @@ class TestDjangoListField:
|
||||||
result = schema.execute(query)
|
result = schema.execute(query)
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert result.data == {"reporters": []}
|
assert result.data == {"reporters": []}
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
ReporterModel.objects.create(first_name="Tara", last_name="West")
|
ReporterModel.objects.create(first_name="Tara", last_name="West")
|
||||||
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
|
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
|
||||||
|
@ -112,6 +116,7 @@ class TestDjangoListField:
|
||||||
assert result.data == {
|
assert result.data == {
|
||||||
"reporters": [{"firstName": "Tara"}, {"firstName": "Debra"}]
|
"reporters": [{"firstName": "Tara"}, {"firstName": "Debra"}]
|
||||||
}
|
}
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
def test_override_resolver(self):
|
def test_override_resolver(self):
|
||||||
class Reporter(DjangoObjectType):
|
class Reporter(DjangoObjectType):
|
||||||
|
@ -139,6 +144,37 @@ class TestDjangoListField:
|
||||||
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
|
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
|
||||||
|
|
||||||
result = schema.execute(query)
|
result = schema.execute(query)
|
||||||
|
assert not result.errors
|
||||||
|
assert result.data == {"reporters": [{"firstName": "Tara"}]}
|
||||||
|
|
||||||
|
def test_override_resolver_async_execution(self):
|
||||||
|
class Reporter(DjangoObjectType):
|
||||||
|
class Meta:
|
||||||
|
model = ReporterModel
|
||||||
|
fields = ("first_name",)
|
||||||
|
|
||||||
|
class Query(ObjectType):
|
||||||
|
reporters = DjangoListField(Reporter)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@sync_to_async
|
||||||
|
def resolve_reporters(_, info):
|
||||||
|
return ReporterModel.objects.filter(first_name="Tara")
|
||||||
|
|
||||||
|
schema = Schema(query=Query)
|
||||||
|
|
||||||
|
query = """
|
||||||
|
query {
|
||||||
|
reporters {
|
||||||
|
firstName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
ReporterModel.objects.create(first_name="Tara", last_name="West")
|
||||||
|
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
|
||||||
|
|
||||||
|
result = async_to_sync(schema.execute_async)(query)
|
||||||
|
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert result.data == {"reporters": [{"firstName": "Tara"}]}
|
assert result.data == {"reporters": [{"firstName": "Tara"}]}
|
||||||
|
@ -203,6 +239,7 @@ class TestDjangoListField:
|
||||||
{"firstName": "Debra", "articles": []},
|
{"firstName": "Debra", "articles": []},
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
def test_override_resolver_nested_list_field(self):
|
def test_override_resolver_nested_list_field(self):
|
||||||
class Article(DjangoObjectType):
|
class Article(DjangoObjectType):
|
||||||
|
@ -261,6 +298,7 @@ class TestDjangoListField:
|
||||||
{"firstName": "Debra", "articles": []},
|
{"firstName": "Debra", "articles": []},
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
def test_get_queryset_filter(self):
|
def test_get_queryset_filter(self):
|
||||||
class Reporter(DjangoObjectType):
|
class Reporter(DjangoObjectType):
|
||||||
|
@ -306,6 +344,7 @@ class TestDjangoListField:
|
||||||
|
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert result.data == {"reporters": [{"firstName": "Tara"}]}
|
assert result.data == {"reporters": [{"firstName": "Tara"}]}
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
def test_resolve_list(self):
|
def test_resolve_list(self):
|
||||||
"""Resolving a plain list should work (and not call get_queryset)"""
|
"""Resolving a plain list should work (and not call get_queryset)"""
|
||||||
|
@ -354,6 +393,55 @@ class TestDjangoListField:
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert result.data == {"reporters": [{"firstName": "Debra"}]}
|
assert result.data == {"reporters": [{"firstName": "Debra"}]}
|
||||||
|
|
||||||
|
def test_resolve_list_async(self):
|
||||||
|
"""Resolving a plain list should work (and not call get_queryset) when running under async"""
|
||||||
|
|
||||||
|
class Reporter(DjangoObjectType):
|
||||||
|
class Meta:
|
||||||
|
model = ReporterModel
|
||||||
|
fields = ("first_name", "articles")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_queryset(cls, queryset, info):
|
||||||
|
# Only get reporters with at least 1 article
|
||||||
|
return queryset.annotate(article_count=Count("articles")).filter(
|
||||||
|
article_count__gt=0
|
||||||
|
)
|
||||||
|
|
||||||
|
class Query(ObjectType):
|
||||||
|
reporters = DjangoListField(Reporter)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@sync_to_async
|
||||||
|
def resolve_reporters(_, info):
|
||||||
|
return [ReporterModel.objects.get(first_name="Debra")]
|
||||||
|
|
||||||
|
schema = Schema(query=Query)
|
||||||
|
|
||||||
|
query = """
|
||||||
|
query {
|
||||||
|
reporters {
|
||||||
|
firstName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
|
||||||
|
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
|
||||||
|
|
||||||
|
ArticleModel.objects.create(
|
||||||
|
headline="Amazing news",
|
||||||
|
reporter=r1,
|
||||||
|
pub_date=datetime.date.today(),
|
||||||
|
pub_date_time=datetime.datetime.now(),
|
||||||
|
editor=r1,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = async_to_sync(schema.execute_async)(query)
|
||||||
|
|
||||||
|
assert not result.errors
|
||||||
|
assert result.data == {"reporters": [{"firstName": "Debra"}]}
|
||||||
|
|
||||||
def test_get_queryset_foreign_key(self):
|
def test_get_queryset_foreign_key(self):
|
||||||
class Article(DjangoObjectType):
|
class Article(DjangoObjectType):
|
||||||
class Meta:
|
class Meta:
|
||||||
|
@ -413,6 +501,7 @@ class TestDjangoListField:
|
||||||
{"firstName": "Debra", "articles": []},
|
{"firstName": "Debra", "articles": []},
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
def test_resolve_list_external_resolver(self):
|
def test_resolve_list_external_resolver(self):
|
||||||
"""Resolving a plain list from external resolver should work (and not call get_queryset)"""
|
"""Resolving a plain list from external resolver should work (and not call get_queryset)"""
|
||||||
|
@ -461,6 +550,54 @@ class TestDjangoListField:
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert result.data == {"reporters": [{"firstName": "Debra"}]}
|
assert result.data == {"reporters": [{"firstName": "Debra"}]}
|
||||||
|
|
||||||
|
def test_resolve_list_external_resolver_async(self):
|
||||||
|
"""Resolving a plain list from external resolver should work (and not call get_queryset)"""
|
||||||
|
|
||||||
|
class Reporter(DjangoObjectType):
|
||||||
|
class Meta:
|
||||||
|
model = ReporterModel
|
||||||
|
fields = ("first_name", "articles")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_queryset(cls, queryset, info):
|
||||||
|
# Only get reporters with at least 1 article
|
||||||
|
return queryset.annotate(article_count=Count("articles")).filter(
|
||||||
|
article_count__gt=0
|
||||||
|
)
|
||||||
|
|
||||||
|
@sync_to_async
|
||||||
|
def resolve_reporters(_, info):
|
||||||
|
return [ReporterModel.objects.get(first_name="Debra")]
|
||||||
|
|
||||||
|
class Query(ObjectType):
|
||||||
|
reporters = DjangoListField(Reporter, resolver=resolve_reporters)
|
||||||
|
|
||||||
|
schema = Schema(query=Query)
|
||||||
|
|
||||||
|
query = """
|
||||||
|
query {
|
||||||
|
reporters {
|
||||||
|
firstName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
|
||||||
|
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
|
||||||
|
|
||||||
|
ArticleModel.objects.create(
|
||||||
|
headline="Amazing news",
|
||||||
|
reporter=r1,
|
||||||
|
pub_date=datetime.date.today(),
|
||||||
|
pub_date_time=datetime.datetime.now(),
|
||||||
|
editor=r1,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = async_to_sync(schema.execute_async)(query)
|
||||||
|
|
||||||
|
assert not result.errors
|
||||||
|
assert result.data == {"reporters": [{"firstName": "Debra"}]}
|
||||||
|
|
||||||
def test_get_queryset_filter_external_resolver(self):
|
def test_get_queryset_filter_external_resolver(self):
|
||||||
class Reporter(DjangoObjectType):
|
class Reporter(DjangoObjectType):
|
||||||
class Meta:
|
class Meta:
|
||||||
|
@ -505,6 +642,7 @@ class TestDjangoListField:
|
||||||
|
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert result.data == {"reporters": [{"firstName": "Tara"}]}
|
assert result.data == {"reporters": [{"firstName": "Tara"}]}
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
||||||
def test_select_related_and_prefetch_related_are_respected(
|
def test_select_related_and_prefetch_related_are_respected(
|
||||||
self, django_assert_num_queries
|
self, django_assert_num_queries
|
||||||
|
@ -647,3 +785,4 @@ class TestDjangoListField:
|
||||||
r'SELECT .* FROM "tests_film" INNER JOIN "tests_film_reporters" .* LEFT OUTER JOIN "tests_filmdetails"',
|
r'SELECT .* FROM "tests_film" INNER JOIN "tests_film_reporters" .* LEFT OUTER JOIN "tests_filmdetails"',
|
||||||
captured.captured_queries[1]["sql"],
|
captured.captured_queries[1]["sql"],
|
||||||
)
|
)
|
||||||
|
assert_async_result_equal(schema, query, result)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user