diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index 5938063af..586e7298d 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -201,7 +201,34 @@ class BasePagination(object): raise NotImplementedError('to_html() must be implemented to display page controls.') -class PageNumberPagination(BasePagination): +class BasePageSizePagination(BasePagination): + # The default page size. + # Defaults to `None`, meaning pagination is disabled. + page_size = api_settings.PAGE_SIZE + + # Set to an integer to limit the maximum page size the client may request. + # Only relevant if 'page_size_query_param' has also been set. + max_page_size = None + + # Client can control the page size using this query parameter. + # Default is 'None'. Set to eg 'page_size' to enable usage. + page_size_query_param = None + + def get_page_size(self, request): + if self.page_size_query_param: + try: + return _positive_int( + request.query_params[self.page_size_query_param], + strict=True, + cutoff=self.max_page_size + ) + except (KeyError, ValueError): + pass + + return self.page_size + + +class PageNumberPagination(BasePageSizePagination): """ A simple page number based style that supports page numbers as query parameters. For example: @@ -209,21 +236,10 @@ class PageNumberPagination(BasePagination): http://api.example.org/accounts/?page=4 http://api.example.org/accounts/?page=4&page_size=100 """ - # The default page size. - # Defaults to `None`, meaning pagination is disabled. - page_size = api_settings.PAGE_SIZE # Client can control the page using this query parameter. page_query_param = 'page' - # Client can control the page size using this query parameter. - # Default is 'None'. Set to eg 'page_size' to enable usage. - page_size_query_param = None - - # Set to an integer to limit the maximum page size the client may request. - # Only relevant if 'page_size_query_param' has also been set. - max_page_size = None - last_page_strings = ('last',) template = 'rest_framework/pagination/numbers.html' @@ -318,19 +334,6 @@ class PageNumberPagination(BasePagination): ('results', data) ])) - def get_page_size(self, request): - if self.page_size_query_param: - try: - return _positive_int( - request.query_params[self.page_size_query_param], - strict=True, - cutoff=self.max_page_size - ) - except (KeyError, ValueError): - pass - - return self.page_size - def get_next_link(self): if not self.page.has_next(): return None @@ -484,34 +487,29 @@ class LimitOffsetPagination(BasePagination): return template.render(context) -class CursorPagination(BasePagination): +class CursorPagination(BasePageSizePagination): """ The cursor pagination implementation is neccessarily complex. For an overview of the position/offset style we use, see this post: http://cramer.io/2011/03/08/building-cursors-for-the-disqus-api/ """ cursor_query_param = 'cursor' - page_size = api_settings.PAGE_SIZE invalid_cursor_message = _('Invalid cursor') ordering = '-created' template = 'rest_framework/pagination/previous_and_next.html' def paginate_queryset(self, queryset, request, view=None): + self.page_size = self.get_page_size(request) if self.page_size is None: return None self.base_url = request.build_absolute_uri() self.ordering = self.get_ordering(request, queryset, view) - # Determine if we have a cursor, and if so then decode it. - encoded = request.query_params.get(self.cursor_query_param) - if encoded is None: - self.cursor = None + self.cursor = self.decode_cursor(request) + if self.cursor is None: (offset, reverse, current_position) = (0, False, None) else: - self.cursor = _decode_cursor(encoded) - if self.cursor is None: - raise NotFound(self.invalid_cursor_message) (offset, reverse, current_position) = self.cursor # Cursor pagination always enforces an ordering. @@ -623,8 +621,7 @@ class CursorPagination(BasePagination): position = self.previous_position cursor = Cursor(offset=offset, reverse=False, position=position) - encoded = _encode_cursor(cursor) - return replace_query_param(self.base_url, self.cursor_query_param, encoded) + return self.encode_cursor(cursor) def get_previous_link(self): if not self.has_previous: @@ -672,8 +669,7 @@ class CursorPagination(BasePagination): position = self.next_position cursor = Cursor(offset=offset, reverse=True, position=position) - encoded = _encode_cursor(cursor) - return replace_query_param(self.base_url, self.cursor_query_param, encoded) + return self.encode_cursor(cursor) def get_ordering(self, request, queryset, view): """ @@ -715,6 +711,19 @@ class CursorPagination(BasePagination): return (ordering,) return tuple(ordering) + def decode_cursor(self, request): + # Determine if we have a cursor, and if so then decode it. + encoded = request.query_params.get(self.cursor_query_param) + if encoded is not None: + cursor = _decode_cursor(encoded) + if cursor is None: + raise NotFound(self.invalid_cursor_message) + return cursor + + def encode_cursor(self, cursor): + encoded = _encode_cursor(cursor) + return replace_query_param(self.base_url, self.cursor_query_param, encoded) + def _get_position_from_instance(self, instance, ordering): attr = getattr(instance, ordering[0].lstrip('-')) return six.text_type(attr) diff --git a/tests/test_pagination.py b/tests/test_pagination.py index c67550c85..18feb22f3 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -2,8 +2,9 @@ from __future__ import unicode_literals from rest_framework import exceptions, generics, pagination, serializers, status, filters from rest_framework.request import Request -from rest_framework.pagination import PageLink, PAGE_BREAK +from rest_framework.pagination import PageLink, PAGE_BREAK, Cursor from rest_framework.test import APIRequestFactory +from rest_framework.utils.urls import replace_query_param, remove_query_param import pytest factory = APIRequestFactory() @@ -186,6 +187,7 @@ class TestPageNumberPagination: def setup(self): class ExamplePagination(pagination.PageNumberPagination): page_size = 5 + self.pagination = ExamplePagination() self.queryset = range(1, 101) @@ -475,52 +477,77 @@ class TestCursorPagination: """ Unit tests for `pagination.CursorPagination`. """ + class MockObject(object): + def __init__(self, idx): + self.created = idx - def setup(self): - class MockObject(object): - def __init__(self, idx): - self.created = idx + class MockQuerySet(object): + def __init__(self, items): + self.items = list(items) - class MockQuerySet(object): - def __init__(self, items): - self.items = items - - def filter(self, created__gt=None, created__lt=None): - if created__gt is not None: - return MockQuerySet([ - item for item in self.items - if item.created > int(created__gt) - ]) - - assert created__lt is not None - return MockQuerySet([ + def filter(self, created__gt=None, created__lt=None): + if created__gt is not None: + return type(self)([ item for item in self.items - if item.created < int(created__lt) + if item.created > int(created__gt) ]) - def order_by(self, *ordering): - if ordering[0].startswith('-'): - return MockQuerySet(list(reversed(self.items))) - return self + assert created__lt is not None + return type(self)([ + item for item in self.items + if item.created < int(created__lt) + ]) - def __getitem__(self, sliced): - return self.items[sliced] + def order_by(self, *ordering): + if ordering[0].startswith('-'): + return type(self)(list(reversed(self.items))) + return self - class ExamplePagination(pagination.CursorPagination): - page_size = 5 - ordering = 'created' + def __getitem__(self, sliced): + return self.items[sliced] - self.pagination = ExamplePagination() - self.queryset = MockQuerySet([ - MockObject(idx) for idx in [ + class ExamplePagination(pagination.CursorPagination): + page_size = 5 + page_size_query_param = 'page_size' + max_page_size = 20 + ordering = 'created' + + class CustomCursorPagination(ExamplePagination): + cursor_query_param = 'since' + reverse_query_param = 'before' + + def decode_cursor(self, request): + rev = False + if self.reverse_query_param in request.query_params: + rev = True + query_param = self.reverse_query_param + elif self.cursor_query_param in request.query_params: + query_param = self.cursor_query_param + else: + return + return Cursor(0, rev, request.query_params[query_param]) + + def encode_cursor(self, cursor): + if cursor.reverse: + query_param = self.reverse_query_param + base_url = remove_query_param(self.base_url, self.cursor_query_param) + else: + query_param = self.cursor_query_param + base_url = remove_query_param(self.base_url, self.reverse_query_param) + return replace_query_param(base_url, query_param, cursor.position) + + def setup(self): + self.pagination = self.ExamplePagination() + self.queryset = self.MockQuerySet( + map(self.MockObject, [ 1, 1, 1, 1, 1, 1, 2, 3, 4, 4, 4, 4, 5, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 9, 9, 9, 9, 9, 9 - ] - ]) + ]) + ) def get_pages(self, url): """ @@ -643,6 +670,84 @@ class TestCursorPagination: assert isinstance(self.pagination.to_html(), type('')) + def test_page_size(self): + (previous, current, next, previous_url, next_url) = \ + self.get_pages('/?page_size=10') + + assert previous is None + assert current == [1, 1, 1, 1, 1, 1, 2, 3, 4, 4] + assert next == [4, 4, 5, 6, 7, 7, 7, 7, 7, 7] + assert 'page_size=10' in next_url + + (previous, current, next, previous_url, next_url) = \ + self.get_pages(next_url.replace('page_size=10', 'page_size=4')) + + assert previous == [2, 3, 4, 4] + assert current == [4, 4, 5, 6] + assert next == [7, 7, 7, 7] + assert 'page_size=4' in previous_url + assert 'page_size=4' in next_url + + def test_custom_cursor_format(self): + # setup + self.pagination = self.CustomCursorPagination() + # The CustomCursorPagination expects unique keys + self.queryset = self.MockQuerySet( + map(self.MockObject, [ + 1, 2, 4, 8, 10, + 11, 12, 13, 14, 15, + 18, 33, 35, 36, 37, + 38, 39, 40, 41, + ]) + ) + + # actual test + (previous, current, next, previous_url, next_url) = self.get_pages('/') + + assert previous is None + assert current == [1, 2, 4, 8, 10] + assert next == [11, 12, 13, 14, 15] + assert previous_url is None + assert 'since=10' in next_url + + (previous, current, next, previous_url, next_url) = self.get_pages(next_url) + + assert previous == [1, 2, 4, 8, 10] + assert current == [11, 12, 13, 14, 15] + assert next == [18, 33, 35, 36, 37] + assert 'before=11' in previous_url + assert 'since' not in previous_url + assert 'since=15' in next_url + + (previous, current, next, previous_url, next_url) = self.get_pages(next_url) + (previous, current, next, previous_url, next_url) = self.get_pages(next_url) + + assert previous == [18, 33, 35, 36, 37] + assert current == [38, 39, 40, 41] + assert next is None + assert 'before=38' in previous_url + assert 'since' not in previous_url + assert next_url is None + + (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) + + assert previous == [11, 12, 13, 14, 15] + assert current == [18, 33, 35, 36, 37] + assert next == [38, 39, 40, 41] + assert 'before=18' in previous_url + assert 'since=37' in next_url + assert 'before' not in next_url + + (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) + (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) + + assert previous is None + assert current == [1, 2, 4, 8, 10] + assert next == [11, 12, 13, 14, 15] + assert previous_url is None + assert 'since=10' in next_url + assert 'before' not in next_url + def test_get_displayed_page_numbers(): """