diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index 708da29cd..8ccdc342c 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -711,7 +711,11 @@ class CursorPagination(BasePagination): return replace_query_param(self.base_url, self.cursor_query_param, encoded) def _get_position_from_instance(self, instance, ordering): - attr = getattr(instance, ordering[0].lstrip('-')) + field_name = ordering[0].lstrip('-') + if isinstance(instance, dict): + attr = instance[field_name] + else: + attr = getattr(instance, field_name) return six.text_type(attr) def get_paginated_response(self, data): diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 170d95899..9f2e1c57c 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -3,6 +3,8 @@ from __future__ import unicode_literals import pytest from django.core.paginator import Paginator as DjangoPaginator +from django.db import models +from django.test import TestCase from rest_framework import ( exceptions, filters, generics, pagination, serializers, status @@ -530,85 +532,7 @@ class TestLimitOffset: assert content.get('previous') == prev_url -class TestCursorPagination: - """ - Unit tests for `pagination.CursorPagination`. - """ - - def setup(self): - class MockObject(object): - def __init__(self, idx): - self.created = idx - - class MockQuerySet(object): - def __init__(self, items): - self.items = items - - def filter(self, created__gt=None, created__lt=None): - if created__gt is not None: - return MockQuerySet([ - item for item in self.items - if item.created > int(created__gt) - ]) - - assert created__lt is not None - return MockQuerySet([ - item for item in self.items - if item.created < int(created__lt) - ]) - - def order_by(self, *ordering): - if ordering[0].startswith('-'): - return MockQuerySet(list(reversed(self.items))) - return self - - def __getitem__(self, sliced): - return self.items[sliced] - - class ExamplePagination(pagination.CursorPagination): - page_size = 5 - ordering = 'created' - - self.pagination = ExamplePagination() - self.queryset = MockQuerySet([ - MockObject(idx) for idx in [ - 1, 1, 1, 1, 1, - 1, 2, 3, 4, 4, - 4, 4, 5, 6, 7, - 7, 7, 7, 7, 7, - 7, 7, 7, 8, 9, - 9, 9, 9, 9, 9 - ] - ]) - - def get_pages(self, url): - """ - Given a URL return a tuple of: - - (previous page, current page, next page, previous url, next url) - """ - request = Request(factory.get(url)) - queryset = self.pagination.paginate_queryset(self.queryset, request) - current = [item.created for item in queryset] - - next_url = self.pagination.get_next_link() - previous_url = self.pagination.get_previous_link() - - if next_url is not None: - request = Request(factory.get(next_url)) - queryset = self.pagination.paginate_queryset(self.queryset, request) - next = [item.created for item in queryset] - else: - next = None - - if previous_url is not None: - request = Request(factory.get(previous_url)) - queryset = self.pagination.paginate_queryset(self.queryset, request) - previous = [item.created for item in queryset] - else: - previous = None - - return (previous, current, next, previous_url, next_url) +class CursorPaginationTestsMixin: def test_invalid_cursor(self): request = Request(factory.get('/', {'cursor': '123'})) @@ -703,6 +627,145 @@ class TestCursorPagination: assert isinstance(self.pagination.to_html(), type('')) +class TestCursorPagination(CursorPaginationTestsMixin): + """ + Unit tests for `pagination.CursorPagination`. + """ + + def setup(self): + class MockObject(object): + def __init__(self, idx): + self.created = idx + + class MockQuerySet(object): + def __init__(self, items): + self.items = items + + def filter(self, created__gt=None, created__lt=None): + if created__gt is not None: + return MockQuerySet([ + item for item in self.items + if item.created > int(created__gt) + ]) + + assert created__lt is not None + return MockQuerySet([ + item for item in self.items + if item.created < int(created__lt) + ]) + + def order_by(self, *ordering): + if ordering[0].startswith('-'): + return MockQuerySet(list(reversed(self.items))) + return self + + def __getitem__(self, sliced): + return self.items[sliced] + + class ExamplePagination(pagination.CursorPagination): + page_size = 5 + ordering = 'created' + + self.pagination = ExamplePagination() + self.queryset = MockQuerySet([ + MockObject(idx) for idx in [ + 1, 1, 1, 1, 1, + 1, 2, 3, 4, 4, + 4, 4, 5, 6, 7, + 7, 7, 7, 7, 7, + 7, 7, 7, 8, 9, + 9, 9, 9, 9, 9 + ] + ]) + + def get_pages(self, url): + """ + Given a URL return a tuple of: + + (previous page, current page, next page, previous url, next url) + """ + request = Request(factory.get(url)) + queryset = self.pagination.paginate_queryset(self.queryset, request) + current = [item.created for item in queryset] + + next_url = self.pagination.get_next_link() + previous_url = self.pagination.get_previous_link() + + if next_url is not None: + request = Request(factory.get(next_url)) + queryset = self.pagination.paginate_queryset(self.queryset, request) + next = [item.created for item in queryset] + else: + next = None + + if previous_url is not None: + request = Request(factory.get(previous_url)) + queryset = self.pagination.paginate_queryset(self.queryset, request) + previous = [item.created for item in queryset] + else: + previous = None + + return (previous, current, next, previous_url, next_url) + + +class CursorPaginationModel(models.Model): + created = models.IntegerField() + + +class TestCursorPaginationWithValueQueryset(CursorPaginationTestsMixin, TestCase): + """ + Unit tests for `pagination.CursorPagination` for value querysets. + """ + + def setUp(self): + class ExamplePagination(pagination.CursorPagination): + page_size = 5 + ordering = 'created' + + self.pagination = ExamplePagination() + data = [ + 1, 1, 1, 1, 1, + 1, 2, 3, 4, 4, + 4, 4, 5, 6, 7, + 7, 7, 7, 7, 7, + 7, 7, 7, 8, 9, + 9, 9, 9, 9, 9 + ] + for idx in data: + CursorPaginationModel.objects.create(created=idx) + + self.queryset = CursorPaginationModel.objects.values() + + def get_pages(self, url): + """ + Given a URL return a tuple of: + + (previous page, current page, next page, previous url, next url) + """ + request = Request(factory.get(url)) + queryset = self.pagination.paginate_queryset(self.queryset, request) + current = [item['created'] for item in queryset] + + next_url = self.pagination.get_next_link() + previous_url = self.pagination.get_previous_link() + + if next_url is not None: + request = Request(factory.get(next_url)) + queryset = self.pagination.paginate_queryset(self.queryset, request) + next = [item['created'] for item in queryset] + else: + next = None + + if previous_url is not None: + request = Request(factory.get(previous_url)) + queryset = self.pagination.paginate_queryset(self.queryset, request) + previous = [item['created'] for item in queryset] + else: + previous = None + + return (previous, current, next, previous_url, next_url) + + def test_get_displayed_page_numbers(): """ Test our contextual page display function.