This commit is contained in:
Kevin Turner 2017-09-27 08:52:37 +00:00 committed by GitHub
commit 878ac12d64
2 changed files with 56 additions and 2 deletions

View File

@ -7,6 +7,7 @@ from __future__ import unicode_literals
from base64 import b64decode, b64encode from base64 import b64decode, b64encode
from collections import OrderedDict, namedtuple from collections import OrderedDict, namedtuple
import decimal
from django.core.paginator import Paginator as DjangoPaginator from django.core.paginator import Paginator as DjangoPaginator
from django.core.paginator import InvalidPage from django.core.paginator import InvalidPage
@ -497,6 +498,9 @@ class CursorPagination(BasePagination):
# queries, by having a hard cap on the maximum possible size of the offset. # queries, by having a hard cap on the maximum possible size of the offset.
offset_cutoff = 1000 offset_cutoff = 1000
__rounding_down = decimal.Context(prec=14, rounding=decimal.ROUND_FLOOR)
__rounding_up = decimal.Context(prec=14, rounding=decimal.ROUND_CEILING)
def paginate_queryset(self, queryset, request, view=None): def paginate_queryset(self, queryset, request, view=None):
self.page_size = self.get_page_size(request) self.page_size = self.get_page_size(request)
if not self.page_size: if not self.page_size:
@ -775,6 +779,13 @@ class CursorPagination(BasePagination):
attr = instance[field_name] attr = instance[field_name]
else: else:
attr = getattr(instance, field_name) attr = getattr(instance, field_name)
if isinstance(attr, float):
if ordering[0][0] == '-':
attr = self.__rounding_down.create_decimal_from_float(attr)
else:
attr = self.__rounding_up.create_decimal_from_float(attr)
return six.text_type(attr) return six.text_type(attr)
def get_paginated_response(self, data): def get_paginated_response(self, data):

View File

@ -7,9 +7,9 @@ from django.db import models
from django.test import TestCase from django.test import TestCase
from rest_framework import ( from rest_framework import (
exceptions, filters, generics, pagination, serializers, status exceptions, filters, generics, pagination, serializers, status, viewsets
) )
from rest_framework.pagination import PAGE_BREAK, PageLink from rest_framework.pagination import CursorPagination, PAGE_BREAK, PageLink
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
@ -986,3 +986,46 @@ def test_get_displayed_page_numbers():
assert displayed_page_numbers(7, 9) == [1, None, 6, 7, 8, 9] assert displayed_page_numbers(7, 9) == [1, None, 6, 7, 8, 9]
assert displayed_page_numbers(8, 9) == [1, None, 7, 8, 9] assert displayed_page_numbers(8, 9) == [1, None, 7, 8, 9]
assert displayed_page_numbers(9, 9) == [1, None, 7, 8, 9] assert displayed_page_numbers(9, 9) == [1, None, 7, 8, 9]
class ThreeItemCursorPagination(CursorPagination):
page_size = 3
ordering = 'score'
class FloatyModel(models.Model):
score = models.FloatField()
class PassThroughSerializer(serializers.BaseSerializer):
def to_representation(self, item):
return item
class FloatyViewSet(viewsets.ReadOnlyModelViewSet):
queryset = FloatyModel.objects.all()
pagination_class = ThreeItemCursorPagination
serializer_class = PassThroughSerializer
page_size = 3
class TestCursorPaginationWithFloatingPointPosition(TestCase):
def setUp(self):
self.view = FloatyViewSet.as_view(actions={'get': 'list'})
def test_page_boundary_does_not_repeat_elements(self):
for i in range(12):
FloatyModel.objects.create(score=i/9.0)
request = factory.get('/')
first_response = self.view(request)
first_page_last_item = first_response.data['results'][-1]
second_request = factory.get(first_response.data['next'])
second_response = self.view(second_request)
second_page_first_item = second_response.data['results'][0]
self.assertNotEqual(first_page_last_item.pk, second_page_first_item.pk)
self.assertLess(first_page_last_item.score, second_page_first_item.score)