fix: backward pagination (#1346)

Co-authored-by: Thomas Leonard <thomas@loftorbital.com>
Co-authored-by: Laurent  <laurent.riviere.pro@gmail.com>
This commit is contained in:
Thomas Leonard 2022-09-22 16:01:28 +01:00 committed by GitHub
parent 42a40b4df0
commit 3473fe025e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 48 additions and 27 deletions

View File

@ -152,22 +152,24 @@ class DjangoConnectionField(ConnectionField):
array_length = iterable.count() array_length = iterable.count()
else: else:
array_length = len(iterable) array_length = len(iterable)
array_slice_length = (
min(max_limit, array_length) if max_limit is not None else array_length
)
# If after is higher than list_length, connection_from_list_slice # If after is higher than array_length, connection_from_array_slice
# would try to do a negative slicing which makes django throw an # would try to do a negative slicing which makes django throw an
# AssertionError # AssertionError
slice_start = min( slice_start = min(
get_offset_with_default(args.get("after"), -1) + 1, array_length get_offset_with_default(args.get("after"), -1) + 1,
array_length,
) )
array_slice_length = array_length - slice_start
if max_limit is not None and args.get("first", None) is None: # Impose the maximum limit via the `first` field if neither first or last are already provided
if args.get("last", None) is not None: # (note that if any of them is provided they must be under max_limit otherwise an error is raised).
slice_start = max(array_length - args["last"], 0) if (
else: max_limit is not None
args["first"] = max_limit and args.get("first", None) is None
and args.get("last", None) is None
):
args["first"] = max_limit
connection = connection_from_array_slice( connection = connection_from_array_slice(
iterable[slice_start:], iterable[slice_start:],

View File

@ -1243,6 +1243,7 @@ def test_should_have_next_page(graphene_settings):
} }
@pytest.mark.parametrize("max_limit", [100, 4])
class TestBackwardPagination: class TestBackwardPagination:
def setup_schema(self, graphene_settings, max_limit): def setup_schema(self, graphene_settings, max_limit):
graphene_settings.RELAY_CONNECTION_MAX_LIMIT = max_limit graphene_settings.RELAY_CONNECTION_MAX_LIMIT = max_limit
@ -1261,8 +1262,8 @@ class TestBackwardPagination:
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)
return schema return schema
def do_queries(self, schema): def test_query_last(self, graphene_settings, max_limit):
# Simply last 3 schema = self.setup_schema(graphene_settings, max_limit=max_limit)
query_last = """ query_last = """
query { query {
allReporters(last: 3) { allReporters(last: 3) {
@ -1282,7 +1283,8 @@ class TestBackwardPagination:
e["node"]["firstName"] for e in result.data["allReporters"]["edges"] e["node"]["firstName"] for e in result.data["allReporters"]["edges"]
] == ["First 3", "First 4", "First 5"] ] == ["First 3", "First 4", "First 5"]
# Use a combination of first and last def test_query_first_and_last(self, graphene_settings, max_limit):
schema = self.setup_schema(graphene_settings, max_limit=max_limit)
query_first_and_last = """ query_first_and_last = """
query { query {
allReporters(first: 4, last: 3) { allReporters(first: 4, last: 3) {
@ -1302,7 +1304,8 @@ class TestBackwardPagination:
e["node"]["firstName"] for e in result.data["allReporters"]["edges"] e["node"]["firstName"] for e in result.data["allReporters"]["edges"]
] == ["First 1", "First 2", "First 3"] ] == ["First 1", "First 2", "First 3"]
# Use a combination of first and last and after def test_query_first_last_and_after(self, graphene_settings, max_limit):
schema = self.setup_schema(graphene_settings, max_limit=max_limit)
query_first_last_and_after = """ query_first_last_and_after = """
query queryAfter($after: String) { query queryAfter($after: String) {
allReporters(first: 4, last: 3, after: $after) { allReporters(first: 4, last: 3, after: $after) {
@ -1317,7 +1320,8 @@ class TestBackwardPagination:
after = base64.b64encode(b"arrayconnection:0").decode() after = base64.b64encode(b"arrayconnection:0").decode()
result = schema.execute( result = schema.execute(
query_first_last_and_after, variable_values=dict(after=after) query_first_last_and_after,
variable_values=dict(after=after),
) )
assert not result.errors assert not result.errors
assert len(result.data["allReporters"]["edges"]) == 3 assert len(result.data["allReporters"]["edges"]) == 3
@ -1325,20 +1329,35 @@ class TestBackwardPagination:
e["node"]["firstName"] for e in result.data["allReporters"]["edges"] e["node"]["firstName"] for e in result.data["allReporters"]["edges"]
] == ["First 2", "First 3", "First 4"] ] == ["First 2", "First 3", "First 4"]
def test_should_query(self, graphene_settings): def test_query_last_and_before(self, graphene_settings, max_limit):
schema = self.setup_schema(graphene_settings, max_limit=max_limit)
query_first_last_and_after = """
query queryAfter($before: String) {
allReporters(last: 1, before: $before) {
edges {
node {
firstName
}
}
}
}
""" """
Backward pagination should work as expected
"""
schema = self.setup_schema(graphene_settings, max_limit=100)
self.do_queries(schema)
def test_should_query_with_low_max_limit(self, graphene_settings): result = schema.execute(
""" query_first_last_and_after,
When doing backward pagination (using last) in combination with a max limit higher than the number of objects )
we should really retrieve the last ones. assert not result.errors
""" assert len(result.data["allReporters"]["edges"]) == 1
schema = self.setup_schema(graphene_settings, max_limit=4) assert result.data["allReporters"]["edges"][0]["node"]["firstName"] == "First 5"
self.do_queries(schema)
before = base64.b64encode(b"arrayconnection:5").decode()
result = schema.execute(
query_first_last_and_after,
variable_values=dict(before=before),
)
assert not result.errors
assert len(result.data["allReporters"]["edges"]) == 1
assert result.data["allReporters"]["edges"][0]["node"]["firstName"] == "First 4"
def test_should_preserve_prefetch_related(django_assert_num_queries): def test_should_preserve_prefetch_related(django_assert_num_queries):