From 3c229b619efb546971c3df46e30a9ff18aca5721 Mon Sep 17 00:00:00 2001 From: Paul Craciunoiu Date: Thu, 25 Jun 2020 06:00:24 -0600 Subject: [PATCH] Fix hasNextPage - revert to count. Fix after (#986) Co-authored-by: Jonathan Kim --- graphene_django/debug/tests/test_query.py | 57 ++++++++--------------- graphene_django/fields.py | 32 ++++++++----- graphene_django/tests/test_query.py | 55 +++++++++++++++++++++- 3 files changed, 94 insertions(+), 50 deletions(-) diff --git a/graphene_django/debug/tests/test_query.py b/graphene_django/debug/tests/test_query.py index 4c057ed..d71c3fb 100644 --- a/graphene_django/debug/tests/test_query.py +++ b/graphene_django/debug/tests/test_query.py @@ -56,8 +56,8 @@ def test_should_query_field(): assert result.data == expected -@pytest.mark.parametrize("max_limit,does_count", [(None, True), (100, False)]) -def test_should_query_nested_field(graphene_settings, max_limit, does_count): +@pytest.mark.parametrize("max_limit", [None, 100]) +def test_should_query_nested_field(graphene_settings, max_limit): graphene_settings.RELAY_CONNECTION_MAX_LIMIT = max_limit r1 = Reporter(last_name="ABA") @@ -117,18 +117,11 @@ def test_should_query_nested_field(graphene_settings, max_limit, does_count): assert not result.errors query = str(Reporter.objects.order_by("pk")[:1].query) assert result.data["__debug"]["sql"][0]["rawSql"] == query - if does_count: - assert "COUNT" in result.data["__debug"]["sql"][1]["rawSql"] - assert "tests_reporter_pets" in result.data["__debug"]["sql"][2]["rawSql"] - assert "COUNT" in result.data["__debug"]["sql"][3]["rawSql"] - assert "tests_reporter_pets" in result.data["__debug"]["sql"][4]["rawSql"] - assert len(result.data["__debug"]["sql"]) == 5 - else: - assert len(result.data["__debug"]["sql"]) == 3 - for i in range(len(result.data["__debug"]["sql"])): - assert "COUNT" not in result.data["__debug"]["sql"][i]["rawSql"] - assert "tests_reporter_pets" in result.data["__debug"]["sql"][1]["rawSql"] - assert "tests_reporter_pets" in result.data["__debug"]["sql"][2]["rawSql"] + assert "COUNT" in result.data["__debug"]["sql"][1]["rawSql"] + assert "tests_reporter_pets" in result.data["__debug"]["sql"][2]["rawSql"] + assert "COUNT" in result.data["__debug"]["sql"][3]["rawSql"] + assert "tests_reporter_pets" in result.data["__debug"]["sql"][4]["rawSql"] + assert len(result.data["__debug"]["sql"]) == 5 assert result.data["reporter"] == expected["reporter"] @@ -175,8 +168,8 @@ def test_should_query_list(): assert result.data == expected -@pytest.mark.parametrize("max_limit,does_count", [(None, True), (100, False)]) -def test_should_query_connection(graphene_settings, max_limit, does_count): +@pytest.mark.parametrize("max_limit", [None, 100]) +def test_should_query_connection(graphene_settings, max_limit): graphene_settings.RELAY_CONNECTION_MAX_LIMIT = max_limit r1 = Reporter(last_name="ABA") @@ -219,20 +212,14 @@ def test_should_query_connection(graphene_settings, max_limit, does_count): ) assert not result.errors assert result.data["allReporters"] == expected["allReporters"] - if does_count: - assert len(result.data["__debug"]["sql"]) == 2 - assert "COUNT" in result.data["__debug"]["sql"][0]["rawSql"] - query = str(Reporter.objects.all()[:1].query) - assert result.data["__debug"]["sql"][1]["rawSql"] == query - else: - assert len(result.data["__debug"]["sql"]) == 1 - assert "COUNT" not in result.data["__debug"]["sql"][0]["rawSql"] - query = str(Reporter.objects.all()[:1].query) - assert result.data["__debug"]["sql"][0]["rawSql"] == query + assert len(result.data["__debug"]["sql"]) == 2 + assert "COUNT" in result.data["__debug"]["sql"][0]["rawSql"] + query = str(Reporter.objects.all()[:1].query) + assert result.data["__debug"]["sql"][1]["rawSql"] == query -@pytest.mark.parametrize("max_limit,does_count", [(None, True), (100, False)]) -def test_should_query_connectionfilter(graphene_settings, max_limit, does_count): +@pytest.mark.parametrize("max_limit", [None, 100]) +def test_should_query_connectionfilter(graphene_settings, max_limit): graphene_settings.RELAY_CONNECTION_MAX_LIMIT = max_limit from ...filter import DjangoFilterConnectionField @@ -278,13 +265,7 @@ def test_should_query_connectionfilter(graphene_settings, max_limit, does_count) ) assert not result.errors assert result.data["allReporters"] == expected["allReporters"] - if does_count: - assert len(result.data["__debug"]["sql"]) == 2 - assert "COUNT" in result.data["__debug"]["sql"][0]["rawSql"] - query = str(Reporter.objects.all()[:1].query) - assert result.data["__debug"]["sql"][1]["rawSql"] == query - else: - assert len(result.data["__debug"]["sql"]) == 1 - assert "COUNT" not in result.data["__debug"]["sql"][0]["rawSql"] - query = str(Reporter.objects.all()[:1].query) - assert result.data["__debug"]["sql"][0]["rawSql"] == query + assert len(result.data["__debug"]["sql"]) == 2 + assert "COUNT" in result.data["__debug"]["sql"][0]["rawSql"] + query = str(Reporter.objects.all()[:1].query) + assert result.data["__debug"]["sql"][1]["rawSql"] == query diff --git a/graphene_django/fields.py b/graphene_django/fields.py index 9b102bd..ac7ce45 100644 --- a/graphene_django/fields.py +++ b/graphene_django/fields.py @@ -2,7 +2,10 @@ from functools import partial import six from django.db.models.query import QuerySet -from graphql_relay.connection.arrayconnection import connection_from_list_slice +from graphql_relay.connection.arrayconnection import ( + connection_from_list_slice, + get_offset_with_default, +) from promise import Promise from graphene import NonNull @@ -129,25 +132,32 @@ class DjangoConnectionField(ConnectionField): @classmethod def resolve_connection(cls, connection, args, iterable, max_limit=None): iterable = maybe_queryset(iterable) - # When slicing from the end, need to retrieve the iterable length. - if args.get("last"): - max_limit = None + if isinstance(iterable, QuerySet): - _len = max_limit or iterable.count() + list_length = iterable.count() + list_slice_length = ( + min(max_limit, list_length) if max_limit is not None else list_length + ) else: - _len = max_limit or len(iterable) + list_length = len(iterable) + list_slice_length = ( + min(max_limit, list_length) if max_limit is not None else list_length + ) + + after = get_offset_with_default(args.get("after"), -1) + 1 + connection = connection_from_list_slice( - iterable, + iterable[after:], args, - slice_start=0, - list_length=_len, - list_slice_length=_len, + slice_start=after, + list_length=list_length, + list_slice_length=list_slice_length, connection_type=connection, edge_type=connection.Edge, pageinfo_type=PageInfo, ) connection.iterable = iterable - connection.length = _len + connection.length = list_length return connection @classmethod diff --git a/graphene_django/tests/test_query.py b/graphene_django/tests/test_query.py index e6ed49e..64f54bb 100644 --- a/graphene_django/tests/test_query.py +++ b/graphene_django/tests/test_query.py @@ -1126,6 +1126,59 @@ def test_should_return_max_limit(graphene_settings): assert len(result.data["allReporters"]["edges"]) == 4 +def test_should_have_next_page(graphene_settings): + graphene_settings.RELAY_CONNECTION_MAX_LIMIT = 6 + reporters = [Reporter(**kwargs) for kwargs in REPORTERS] + Reporter.objects.bulk_create(reporters) + db_reporters = Reporter.objects.all() + + class ReporterType(DjangoObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + + class Query(graphene.ObjectType): + all_reporters = DjangoConnectionField(ReporterType) + + schema = graphene.Schema(query=Query) + # Need first: 4 here to trigger the `has_next_page` logic in graphql-relay + # See `arrayconnection.py::connection_from_list_slice`: + # has_next_page=isinstance(first, int) and end_offset < upper_bound + query = """ + query AllReporters($first: Int, $after: String) { + allReporters(first: $first, after: $after) { + pageInfo { + hasNextPage + endCursor + } + edges { + node { + id + } + } + } + } + """ + + result = schema.execute(query, variable_values=dict(first=4)) + assert not result.errors + assert len(result.data["allReporters"]["edges"]) == 4 + assert result.data["allReporters"]["pageInfo"]["hasNextPage"] + + last_result = result.data["allReporters"]["pageInfo"]["endCursor"] + result2 = schema.execute(query, variable_values=dict(first=4, after=last_result)) + assert not result2.errors + assert len(result2.data["allReporters"]["edges"]) == 2 + assert not result2.data["allReporters"]["pageInfo"]["hasNextPage"] + gql_reporters = ( + result.data["allReporters"]["edges"] + result2.data["allReporters"]["edges"] + ) + + assert {to_global_id("ReporterType", reporter.id) for reporter in db_reporters} == { + gql_reporter["node"]["id"] for gql_reporter in gql_reporters + } + + def test_should_preserve_prefetch_related(django_assert_num_queries): class ReporterType(DjangoObjectType): class Meta: @@ -1172,7 +1225,7 @@ def test_should_preserve_prefetch_related(django_assert_num_queries): } """ schema = graphene.Schema(query=Query) - with django_assert_num_queries(2) as captured: + with django_assert_num_queries(3) as captured: result = schema.execute(query) assert not result.errors