diff --git a/graphene_django/tests/async_test_helper.py b/graphene_django/tests/async_test_helper.py index 0487f89..5785c6c 100644 --- a/graphene_django/tests/async_test_helper.py +++ b/graphene_django/tests/async_test_helper.py @@ -1,6 +1,6 @@ from asgiref.sync import async_to_sync -def assert_async_result_equal(schema, query, result): - async_result = async_to_sync(schema.execute_async)(query) +def assert_async_result_equal(schema, query, result, **kwargs): + async_result = async_to_sync(schema.execute_async)(query, **kwargs) assert async_result == result diff --git a/graphene_django/tests/test_query.py b/graphene_django/tests/test_query.py index 383ff2e..19343af 100644 --- a/graphene_django/tests/test_query.py +++ b/graphene_django/tests/test_query.py @@ -7,6 +7,7 @@ from django.db.models import Q from django.utils.functional import SimpleLazyObject from graphql_relay import to_global_id from pytest import raises +from asgiref.sync import sync_to_async, async_to_sync import graphene from graphene.relay import Node @@ -16,6 +17,7 @@ from ..fields import DjangoConnectionField from ..types import DjangoObjectType from ..utils import DJANGO_FILTER_INSTALLED from .models import Article, CNNReporter, Film, FilmDetails, Person, Pet, Reporter +from .async_test_helper import assert_async_result_equal def test_should_query_only_fields(): @@ -34,6 +36,7 @@ def test_should_query_only_fields(): """ result = schema.execute(query) assert not result.errors + assert_async_result_equal(schema, query, result) def test_should_query_simplelazy_objects(): @@ -59,6 +62,7 @@ def test_should_query_simplelazy_objects(): result = schema.execute(query) assert not result.errors assert result.data == {"reporter": {"id": "1"}} + assert_async_result_equal(schema, query, result) def test_should_query_wrapped_simplelazy_objects(): @@ -84,6 +88,7 @@ def test_should_query_wrapped_simplelazy_objects(): result = schema.execute(query) assert not result.errors assert result.data == {"reporter": {"id": "1"}} + assert_async_result_equal(schema, query, result) def test_should_query_well(): @@ -112,6 +117,7 @@ def test_should_query_well(): result = schema.execute(query) assert not result.errors assert result.data == expected + assert_async_result_equal(schema, query, result) @pytest.mark.skipif(IntegerRangeField is MissingType, reason="RangeField should exist") @@ -167,6 +173,7 @@ def test_should_query_postgres_fields(): result = schema.execute(query) assert not result.errors assert result.data == expected + assert_async_result_equal(schema, query, result) def test_should_node(): @@ -248,6 +255,7 @@ def test_should_node(): result = schema.execute(query) assert not result.errors assert result.data == expected + assert_async_result_equal(schema, query, result) def test_should_query_onetoone_fields(): @@ -306,6 +314,7 @@ def test_should_query_onetoone_fields(): result = schema.execute(query) assert not result.errors assert result.data == expected + assert_async_result_equal(schema, query, result) def test_should_query_connectionfields(): @@ -344,6 +353,7 @@ def test_should_query_connectionfields(): "edges": [{"node": {"id": "UmVwb3J0ZXJUeXBlOjE="}}], } } + assert_async_result_equal(schema, query, result) def test_should_keep_annotations(): @@ -403,6 +413,7 @@ def test_should_keep_annotations(): """ result = schema.execute(query) assert not result.errors + assert_async_result_equal(schema, query, result) @pytest.mark.skipif( @@ -484,6 +495,7 @@ def test_should_query_node_filtering(): result = schema.execute(query) assert not result.errors assert result.data == expected + assert_async_result_equal(schema, query, result) @pytest.mark.skipif( @@ -529,6 +541,7 @@ def test_should_query_node_filtering_with_distinct_queryset(): result = schema.execute(query) assert not result.errors assert result.data == expected + assert_async_result_equal(schema, query, result) @pytest.mark.skipif( @@ -618,6 +631,7 @@ def test_should_query_node_multiple_filtering(): result = schema.execute(query) assert not result.errors assert result.data == expected + assert_async_result_equal(schema, query, result) def test_should_enforce_first_or_last(graphene_settings): @@ -658,6 +672,7 @@ def test_should_enforce_first_or_last(graphene_settings): "paginate the `allReporters` connection.\n" ) assert result.data == expected + assert_async_result_equal(schema, query, result) def test_should_error_if_first_is_greater_than_max(graphene_settings): @@ -700,6 +715,7 @@ def test_should_error_if_first_is_greater_than_max(graphene_settings): "exceeds the `first` limit of 100 records.\n" ) assert result.data == expected + assert_async_result_equal(schema, query, result) def test_should_error_if_last_is_greater_than_max(graphene_settings): @@ -742,6 +758,7 @@ def test_should_error_if_last_is_greater_than_max(graphene_settings): "exceeds the `last` limit of 100 records.\n" ) assert result.data == expected + assert_async_result_equal(schema, query, result) def test_should_query_promise_connectionfields(): @@ -777,6 +794,7 @@ def test_should_query_promise_connectionfields(): result = schema.execute(query) assert not result.errors assert result.data == expected + assert_async_result_equal(schema, query, result) def test_should_query_connectionfields_with_last(): @@ -814,6 +832,7 @@ def test_should_query_connectionfields_with_last(): result = schema.execute(query) assert not result.errors assert result.data == expected + assert_async_result_equal(schema, query, result) def test_should_query_connectionfields_with_manager(): @@ -855,6 +874,7 @@ def test_should_query_connectionfields_with_manager(): result = schema.execute(query) assert not result.errors assert result.data == expected + assert_async_result_equal(schema, query, result) def test_should_query_dataloader_fields(): @@ -957,6 +977,108 @@ def test_should_query_dataloader_fields(): assert result.data == expected +def test_should_query_dataloader_fields_async(): + from promise import Promise + from promise.dataloader import DataLoader + + def article_batch_load_fn(keys): + queryset = Article.objects.filter(reporter_id__in=keys) + return Promise.resolve( + [ + [article for article in queryset if article.reporter_id == id] + for id in keys + ] + ) + + article_loader = DataLoader(article_batch_load_fn) + + class ArticleType(DjangoObjectType): + class Meta: + model = Article + interfaces = (Node,) + fields = "__all__" + + class ReporterType(DjangoObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + use_connection = True + fields = "__all__" + + articles = DjangoConnectionField(ArticleType) + + @staticmethod + @sync_to_async + def resolve_articles(self, info, **args): + return article_loader.load(self.id).get() + + class Query(graphene.ObjectType): + all_reporters = DjangoConnectionField(ReporterType) + + r = Reporter.objects.create( + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 + ) + + Article.objects.create( + headline="Article Node 1", + pub_date=datetime.date.today(), + pub_date_time=datetime.datetime.now(), + reporter=r, + editor=r, + lang="es", + ) + Article.objects.create( + headline="Article Node 2", + pub_date=datetime.date.today(), + pub_date_time=datetime.datetime.now(), + reporter=r, + editor=r, + lang="en", + ) + + schema = graphene.Schema(query=Query) + query = """ + query ReporterPromiseConnectionQuery { + allReporters(first: 1) { + edges { + node { + id + articles(first: 2) { + edges { + node { + headline + } + } + } + } + } + } + } + """ + + expected = { + "allReporters": { + "edges": [ + { + "node": { + "id": "UmVwb3J0ZXJUeXBlOjE=", + "articles": { + "edges": [ + {"node": {"headline": "Article Node 1"}}, + {"node": {"headline": "Article Node 2"}}, + ] + }, + } + } + ] + } + } + + result = async_to_sync(schema.execute_async)(query) + assert not result.errors + assert result.data == expected + + def test_should_handle_inherited_choices(): class BaseModel(models.Model): choice_field = models.IntegerField(choices=((0, "zero"), (1, "one"))) @@ -1063,6 +1185,7 @@ def test_proxy_model_support(): result = schema.execute(query) assert not result.errors assert result.data == expected + assert_async_result_equal(schema, query, result) def test_should_resolve_get_queryset_connectionfields(): @@ -1108,6 +1231,7 @@ def test_should_resolve_get_queryset_connectionfields(): result = schema.execute(query) assert not result.errors assert result.data == expected + assert_async_result_equal(schema, query, result) def test_connection_should_limit_after_to_list_length(): @@ -1145,6 +1269,7 @@ def test_connection_should_limit_after_to_list_length(): expected = {"allReporters": {"edges": []}} assert not result.errors assert result.data == expected + assert_async_result_equal(schema, query, result, variable_values=dict(after=after)) REPORTERS = [ @@ -1188,6 +1313,7 @@ def test_should_return_max_limit(graphene_settings): result = schema.execute(query) assert not result.errors assert len(result.data["allReporters"]["edges"]) == 4 + assert_async_result_equal(schema, query, result) def test_should_have_next_page(graphene_settings): @@ -1226,6 +1352,7 @@ def test_should_have_next_page(graphene_settings): assert not result.errors assert len(result.data["allReporters"]["edges"]) == 4 assert result.data["allReporters"]["pageInfo"]["hasNextPage"] + assert_async_result_equal(schema, query, result, variable_values={}) last_result = result.data["allReporters"]["pageInfo"]["endCursor"] result2 = schema.execute(query, variable_values=dict(first=4, after=last_result)) @@ -1239,6 +1366,9 @@ def test_should_have_next_page(graphene_settings): assert {to_global_id("ReporterType", reporter.id) for reporter in db_reporters} == { gql_reporter["node"]["id"] for gql_reporter in gql_reporters } + assert_async_result_equal( + schema, query, result2, variable_values=dict(first=4, after=last_result) + ) @pytest.mark.parametrize("max_limit", [100, 4]) @@ -1262,7 +1392,7 @@ class TestBackwardPagination: def test_query_last(self, graphene_settings, max_limit): schema = self.setup_schema(graphene_settings, max_limit=max_limit) - query_last = """ + query = """ query { allReporters(last: 3) { edges { @@ -1274,16 +1404,17 @@ class TestBackwardPagination: } """ - result = schema.execute(query_last) + result = schema.execute(query) assert not result.errors assert len(result.data["allReporters"]["edges"]) == 3 assert [ e["node"]["firstName"] for e in result.data["allReporters"]["edges"] ] == ["First 3", "First 4", "First 5"] + assert_async_result_equal(schema, query, result) def test_query_first_and_last(self, graphene_settings, max_limit): schema = self.setup_schema(graphene_settings, max_limit=max_limit) - query_first_and_last = """ + query = """ query { allReporters(first: 4, last: 3) { edges { @@ -1295,16 +1426,17 @@ class TestBackwardPagination: } """ - result = schema.execute(query_first_and_last) + result = schema.execute(query) assert not result.errors assert len(result.data["allReporters"]["edges"]) == 3 assert [ e["node"]["firstName"] for e in result.data["allReporters"]["edges"] ] == ["First 1", "First 2", "First 3"] + assert_async_result_equal(schema, query, result) def test_query_first_last_and_after(self, graphene_settings, max_limit): schema = self.setup_schema(graphene_settings, max_limit=max_limit) - query_first_last_and_after = """ + query = """ query queryAfter($after: String) { allReporters(first: 4, last: 3, after: $after) { edges { @@ -1318,7 +1450,7 @@ class TestBackwardPagination: after = base64.b64encode(b"arrayconnection:0").decode() result = schema.execute( - query_first_last_and_after, + query, variable_values=dict(after=after), ) assert not result.errors @@ -1326,10 +1458,13 @@ class TestBackwardPagination: assert [ e["node"]["firstName"] for e in result.data["allReporters"]["edges"] ] == ["First 2", "First 3", "First 4"] + assert_async_result_equal( + schema, query, result, variable_values=dict(after=after) + ) def test_query_last_and_before(self, graphene_settings, max_limit): schema = self.setup_schema(graphene_settings, max_limit=max_limit) - query_first_last_and_after = """ + query = """ query queryAfter($before: String) { allReporters(last: 1, before: $before) { edges { @@ -1342,20 +1477,24 @@ class TestBackwardPagination: """ result = schema.execute( - query_first_last_and_after, + query, ) assert not result.errors assert len(result.data["allReporters"]["edges"]) == 1 assert result.data["allReporters"]["edges"][0]["node"]["firstName"] == "First 5" + assert_async_result_equal(schema, query, result) before = base64.b64encode(b"arrayconnection:5").decode() result = schema.execute( - query_first_last_and_after, + query, variable_values=dict(before=before), ) assert not result.errors assert len(result.data["allReporters"]["edges"]) == 1 assert result.data["allReporters"]["edges"][0]["node"]["firstName"] == "First 4" + assert_async_result_equal( + schema, query, result, variable_values=dict(before=before) + ) def test_should_preserve_prefetch_related(django_assert_num_queries): @@ -1410,6 +1549,7 @@ def test_should_preserve_prefetch_related(django_assert_num_queries): with django_assert_num_queries(3) as captured: result = schema.execute(query) assert not result.errors + assert_async_result_equal(schema, query, result) def test_should_preserve_annotations(): @@ -1465,6 +1605,7 @@ def test_should_preserve_annotations(): } assert result.data == expected, str(result.data) assert not result.errors + assert_async_result_equal(schema, query, result) def test_connection_should_enable_offset_filtering(): @@ -1504,6 +1645,7 @@ def test_connection_should_enable_offset_filtering(): } } assert result.data == expected + assert_async_result_equal(schema, query, result) def test_connection_should_enable_offset_filtering_higher_than_max_limit( @@ -1548,6 +1690,7 @@ def test_connection_should_enable_offset_filtering_higher_than_max_limit( } } assert result.data == expected + assert_async_result_equal(schema, query, result) def test_connection_should_forbid_offset_filtering_with_before(): @@ -1578,6 +1721,9 @@ def test_connection_should_forbid_offset_filtering_with_before(): expected_error = "You can't provide a `before` value at the same time as an `offset` value to properly paginate the `allReporters` connection." assert len(result.errors) == 1 assert result.errors[0].message == expected_error + assert_async_result_equal( + schema, query, result, variable_values=dict(before=before) + ) def test_connection_should_allow_offset_filtering_with_after(): @@ -1620,6 +1766,7 @@ def test_connection_should_allow_offset_filtering_with_after(): } } assert result.data == expected + assert_async_result_equal(schema, query, result, variable_values=dict(after=after)) def test_connection_should_succeed_if_last_higher_than_number_of_objects(): @@ -1650,6 +1797,7 @@ def test_connection_should_succeed_if_last_higher_than_number_of_objects(): assert not result.errors expected = {"allReporters": {"edges": []}} assert result.data == expected + assert_async_result_equal(schema, query, result, variable_values=dict(last=2)) Reporter.objects.create(first_name="John", last_name="Doe") Reporter.objects.create(first_name="Some", last_name="Guy") @@ -1667,6 +1815,7 @@ def test_connection_should_succeed_if_last_higher_than_number_of_objects(): } } assert result.data == expected + assert_async_result_equal(schema, query, result, variable_values=dict(last=2)) result = schema.execute(query, variable_values=dict(last=4)) assert not result.errors @@ -1681,6 +1830,7 @@ def test_connection_should_succeed_if_last_higher_than_number_of_objects(): } } assert result.data == expected + assert_async_result_equal(schema, query, result, variable_values=dict(last=4)) result = schema.execute(query, variable_values=dict(last=20)) assert not result.errors @@ -1695,6 +1845,7 @@ def test_connection_should_succeed_if_last_higher_than_number_of_objects(): } } assert result.data == expected + assert_async_result_equal(schema, query, result, variable_values=dict(last=20)) def test_should_query_nullable_foreign_key():