diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index dc120d8e8..75b06f1e8 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -2,15 +2,20 @@ Pagination serializers determine the structure of the output that should be used for paginated responses. """ +import json +import operator + from base64 import b64decode, b64encode from collections import OrderedDict, namedtuple from urllib import parse +from functools import reduce from django.core.paginator import InvalidPage from django.core.paginator import Paginator as DjangoPaginator from django.template import loader from django.utils.encoding import force_str from django.utils.translation import gettext_lazy as _ +from django.db.models.query import Q from rest_framework.compat import coreapi, coreschema from rest_framework.exceptions import NotFound @@ -616,25 +621,41 @@ class CursorPagination(BasePagination): else: (offset, reverse, current_position) = self.cursor - # Cursor pagination always enforces an ordering. - if reverse: - queryset = queryset.order_by(*_reverse_ordering(self.ordering)) - else: - queryset = queryset.order_by(*self.ordering) - # If we have a cursor with a fixed position then filter by that. if current_position is not None: - order = self.ordering[0] - is_reversed = order.startswith('-') - order_attr = order.lstrip('-') + current_position_list = json.loads(current_position) - # Test for: (cursor reversed) XOR (queryset reversed) - if self.cursor.reverse != is_reversed: - kwargs = {order_attr + '__lt': current_position} - else: - kwargs = {order_attr + '__gt': current_position} + q_objects_equals = {} + q_objects_compare = {} - queryset = queryset.filter(**kwargs) + for order, position in zip(self.ordering, current_position_list): + is_reversed = order.startswith("-") + order_attr = order.lstrip("-") + + q_objects_equals[order] = Q(**{order_attr: position}) + + # Test for: (cursor reversed) XOR (queryset reversed) + if self.cursor.reverse != is_reversed: + q_objects_compare[order] = Q( + **{(order_attr + "__lt"): position} + ) + else: + q_objects_compare[order] = Q( + **{(order_attr + "__gt"): position} + ) + + filter_list = [] + # starting with the second field + for i in range(len(self.ordering)): + # The first operands need to be equals + # the last operands need to be gt + equals = list(self.ordering[:i+2]) + greater_than_q = q_objects_compare[equals.pop()] + sub_filters = [q_objects_equals[e] for e in equals] + sub_filters.append(greater_than_q) + filter_list.append(reduce(operator.and_, sub_filters)) + + queryset = queryset.filter(reduce(operator.or_, filter_list)) # If we have an offset cursor then offset the entire page by that amount. # We also always fetch an extra item in order to determine if there is a @@ -839,7 +860,14 @@ class CursorPagination(BasePagination): ) if isinstance(ordering, str): - return (ordering,) + ordering = (ordering,) + + pk_name = queryset.model._meta.pk.name + + # Always include a unique key to order by + if not {f"-{pk_name}", pk_name, "pk", "-pk"} & set(ordering): + ordering = ordering + (pk_name,) + return tuple(ordering) def decode_cursor(self, request): @@ -884,12 +912,18 @@ class CursorPagination(BasePagination): return replace_query_param(self.base_url, self.cursor_query_param, encoded) def _get_position_from_instance(self, instance, ordering): - field_name = ordering[0].lstrip('-') - if isinstance(instance, dict): - attr = instance[field_name] - else: - attr = getattr(instance, field_name) - return str(attr) + fields = [] + + for o in ordering: + field_name = o.lstrip("-") + if isinstance(instance, dict): + attr = instance[field_name] + else: + attr = getattr(instance, field_name) + + fields.append(str(attr)) + + return json.dumps(fields).encode() def get_paginated_response(self, data): return Response(OrderedDict([