diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 56dedcae4..fab617810 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -2,8 +2,9 @@ from __future__ import unicode_literals from rest_framework import exceptions, generics, pagination, serializers, status, filters from rest_framework.request import Request -from rest_framework.pagination import PageLink, PAGE_BREAK +from rest_framework.pagination import PageLink, PAGE_BREAK, Cursor from rest_framework.test import APIRequestFactory +from rest_framework.utils.urls import replace_query_param, remove_query_param import pytest factory = APIRequestFactory() @@ -509,6 +510,30 @@ class TestCursorPagination: page_size = 5 ordering = 'created' + class CustomCursorPagination(ExamplePagination): + cursor_query_param = 'since' + reverse_query_param = 'before' + + def decode_cursor(self, request): + rev = False + if self.reverse_query_param in request.query_params: + rev = True + query_param = self.reverse_query_param + elif self.cursor_query_param in request.query_params: + query_param = self.cursor_query_param + else: + return + return Cursor(0, rev, request.query_params[query_param]) + + def encode_cursor(self, cursor): + if cursor.reverse: + query_param = self.reverse_query_param + base_url = remove_query_param(self.base_url, self.cursor_query_param) + else: + query_param = self.cursor_query_param + base_url = remove_query_param(self.base_url, self.reverse_query_param) + return replace_query_param(base_url, query_param, cursor.position) + def setup(self): self.pagination = self.ExamplePagination() self.queryset = self.MockQuerySet( @@ -643,6 +668,66 @@ class TestCursorPagination: assert isinstance(self.pagination.to_html(), type('')) + def test_custom_cursor_format(self): + # setup + self.pagination = self.CustomCursorPagination() + # The CustomCursorPagination expects unique keys + self.queryset = self.MockQuerySet( + map(self.MockObject, [ + 1, 2, 4, 8, 10, + 11, 12, 13, 14, 15, + 18, 33, 35, 36, 37, + 38, 39, 40, 41, + ]) + ) + + # actual test + (previous, current, next, previous_url, next_url) = self.get_pages('/') + + assert previous is None + assert current == [1, 2, 4, 8, 10] + assert next == [11, 12, 13, 14, 15] + assert previous_url is None + assert 'since=10' in next_url + + (previous, current, next, previous_url, next_url) = self.get_pages(next_url) + + assert previous == [1, 2, 4, 8, 10] + assert current == [11, 12, 13, 14, 15] + assert next == [18, 33, 35, 36, 37] + assert 'before=11' in previous_url + assert 'since' not in previous_url + assert 'since=15' in next_url + + (previous, current, next, previous_url, next_url) = self.get_pages(next_url) + (previous, current, next, previous_url, next_url) = self.get_pages(next_url) + + assert previous == [18, 33, 35, 36, 37] + assert current == [38, 39, 40, 41] + assert next is None + assert 'before=38' in previous_url + assert 'since' not in previous_url + assert next_url is None + + (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) + + assert previous == [11, 12, 13, 14, 15] + assert current == [18, 33, 35, 36, 37] + assert next == [38, 39, 40, 41] + assert 'before=18' in previous_url + assert 'since=37' in next_url + assert 'before' not in next_url + + (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) + (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) + + assert previous is None + assert current == [1, 2, 4, 8, 10] + assert next == [11, 12, 13, 14, 15] + assert previous_url is None + assert 'since=10' in next_url + assert 'before' not in next_url + def test_get_displayed_page_numbers(): """