diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 4f134bce6..127957514 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -11,20 +11,10 @@ from django.utils.translation import ugettext as _ from rest_framework import views, mixins, exceptions from rest_framework.request import clone_request from rest_framework.settings import api_settings +from rest_framework.pagination import strict_positive_int import warnings -def strict_positive_int(integer_string, cutoff=None): - """ - Cast a string to a strictly positive integer. - """ - ret = int(integer_string) - if ret <= 0: - raise ValueError() - if cutoff: - ret = min(ret, cutoff) - return ret - def get_object_or_404(queryset, *filter_args, **filter_kwargs): """ Same as Django's standard shortcut, but make sure to raise 404 diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 426865ff9..a562868cd 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -10,6 +10,7 @@ from django.http import Http404 from rest_framework import status from rest_framework.response import Response from rest_framework.request import clone_request +from rest_framework import pagination import warnings @@ -187,3 +188,35 @@ class DestroyModelMixin(object): obj = self.get_object() obj.delete() return Response(status=status.HTTP_204_NO_CONTENT) + + +class OffsetLimitPaginationMixin(object): + offset_kwarg = 'offset' + paginate_by_param = 'limit' + pagination_serializer_class = pagination.OffsetLimitPaginationSerializer + + def paginate_queryset(self, queryset): + limit = self.get_paginate_by() + if not limit: + return # pagination not configured + offset_kwarg = self.kwargs.get(self.offset_kwarg) + offset_query_param = self.request.QUERY_PARAMS.get(self.offset_kwarg) + offset = offset_kwarg or offset_query_param or 0 + try: + offset_number = pagination.strict_positive_int(offset) + except ValueError: + offset_number = 0 + return pagination.OffsetLimitPage(queryset, offset_number, limit) + + +class LinkPaginationMixin(object): + pagination_serializer_class = pagination.LinkPaginationSerializer + + def paginate_queryset(self, queryset, page_size=None): + page = super(LinkPaginationMixin, self).paginate_queryset( + queryset, page_size) + if page is not None: + page_ser = self.get_pagination_serializer(page) + self.headers.update(page_ser.get_link_header()) + self.object_list = page.object_list + return None # Don't use pagination serializer on response diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index d51ea929b..c3c06ac34 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -7,6 +7,18 @@ from rest_framework import serializers from rest_framework.templatetags.rest_framework import replace_query_param +def strict_positive_int(integer_string, cutoff=None): + """ + Cast a string to a strictly positive integer. + """ + ret = int(integer_string) + if ret <= 0: + raise ValueError() + if cutoff: + ret = min(ret, cutoff) + return ret + + class NextPageField(serializers.Field): """ Field that returns a link to the next page in paginated results. @@ -37,6 +49,35 @@ class PreviousPageField(serializers.Field): return replace_query_param(url, self.page_field, page) +class FirstPageField(serializers.Field): + """ + Field that returns a link to the first page in paginated results. + """ + page_field = 'page' + + def to_native(self, value): + if not value.has_previous(): + return None + request = self.context.get('request') + url = request and request.build_absolute_uri() or '' + return replace_query_param(url, self.page_field, 1) + + +class LastPageField(serializers.Field): + """ + Field that returns a link to the previous page in paginated results. + """ + page_field = 'page' + + def to_native(self, value): + if not value.has_next(): + return None + page = value.paginator.num_pages + request = self.context.get('request') + url = request and request.build_absolute_uri() or '' + return replace_query_param(url, self.page_field, page) + + class DefaultObjectSerializer(serializers.Field): """ If no object serializer is specified, then this serializer will be applied @@ -82,7 +123,8 @@ class BasePaginationSerializer(serializers.Serializer): else: context_kwarg = {} - self.fields[results_field] = object_serializer(source='object_list', **context_kwarg) + self.fields[results_field] = object_serializer(source='object_list', + **context_kwarg) class PaginationSerializer(BasePaginationSerializer): @@ -92,3 +134,42 @@ class PaginationSerializer(BasePaginationSerializer): count = serializers.Field(source='paginator.count') next = NextPageField(source='*') previous = PreviousPageField(source='*') + + +class LinkPaginationSerializer(serializers.Serializer): + """ Pagination serializer in order to build Link header """ + first = FirstPageField(source='*') + next = NextPageField(source='*') + previous = PreviousPageField(source='*') + last = LastPageField(source='*') + + def get_link_header(self): + link_keader_items = [ + '<%s>; rel="%s"' % (link, rel) + for rel, link in self.data.items() + if link is not None + ] + return {'Link': ', '.join(link_keader_items)} + + +class OffsetLimitPage(object): + """ + A base class to allow offset and limit when listing a queryset. + """ + def __init__(self, queryset, offset, limit): + self.count = self._set_count(queryset) + self.object_list = queryset[offset:offset + limit] + + def _set_count(self, queryset): + try: + return queryset.count() + except (AttributeError, TypeError): + # AttributeError if object_list has no count() method. + # TypeError if object_list.count() requires arguments + # (i.e. is of type list). + return len(queryset) + + +class OffsetLimitPaginationSerializer(BasePaginationSerializer): + """ OffsetLimitPage serializer """ + count = serializers.Field() diff --git a/rest_framework/tests/test_pagination.py b/rest_framework/tests/test_pagination.py index d6bc7895c..29d790e8f 100644 --- a/rest_framework/tests/test_pagination.py +++ b/rest_framework/tests/test_pagination.py @@ -5,7 +5,8 @@ from django.db import models from django.core.paginator import Paginator from django.test import TestCase from django.utils import unittest -from rest_framework import generics, status, pagination, filters, serializers +from rest_framework import (generics, status, pagination, filters, serializers, + mixins) from rest_framework.compat import django_filters from rest_framework.test import APIRequestFactory from rest_framework.tests.models import BasicModel @@ -13,6 +14,48 @@ from rest_framework.tests.models import BasicModel factory = APIRequestFactory() +def parse_header_link(value): + # Got it from python-requests code + # https://github.com/kennethreitz/requests/blob/master/requests/utils.py#L467 + def parse_header_links(value): + """Return a dict of parsed link headers proxies. + + i.e. Link: ; rel=front; type="image/jpeg",; rel=back;type="image/jpeg" + + """ + + links = [] + + replace_chars = " '\"" + + for val in value.split(","): + try: + url, params = val.split(";", 1) + except ValueError: + url, params = val, '' + + link = {} + + link["url"] = url.strip("<> '\"") + + for param in params.split(";"): + try: + key, value = param.split("=") + except ValueError: + break + + link[key.strip(replace_chars)] = value.strip(replace_chars) + + links.append(link) + + return links + links = parse_header_links(value) + link_map = {} + for link in links: + link_map[link['rel']] = link['url'] + return link_map + + class FilterableItem(models.Model): text = models.CharField(max_length=100) decimal = models.DecimalField(max_digits=4, decimal_places=2) @@ -52,6 +95,18 @@ class MaxPaginateByView(generics.ListAPIView): paginate_by_param = 'page_size' +class LinkPaginationView(mixins.LinkPaginationMixin, + generics.ListCreateAPIView): + model = BasicModel + paginate_by = 10 + + +class OffsetLimitPaginationView(mixins.OffsetLimitPaginationMixin, + generics.ListCreateAPIView): + model = BasicModel + paginate_by = 10 + + class IntegrationTestPagination(TestCase): """ Integration tests for paginated list views. @@ -103,6 +158,102 @@ class IntegrationTestPagination(TestCase): self.assertNotEqual(response.data['previous'], None) +class IntegrationTestLinkPagination(TestCase): + + def setUp(self): + """ + Create 26 BasicModel instances. + """ + for char in 'abcdefghijklmnopqrstuvwxyz': + BasicModel(text=char * 3).save() + self.objects = BasicModel.objects + self.data = [ + {'id': obj.id, 'text': obj.text} + for obj in self.objects.all() + ] + self.view = LinkPaginationView.as_view() + + def test_get_paginated_root_view(self): + """ + GET requests to paginated ListCreateAPIView should return paginated results. + """ + request = factory.get('/') + # Note: Database queries are a `SELECT COUNT`, and `SELECT ` + with self.assertNumQueries(2): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data[0:10]) + link_header = parse_header_link(response['Link']) + self.assertTrue('next' in link_header) + self.assertTrue('last' in link_header) + self.assertFalse('previous' in link_header) + self.assertFalse('first' in link_header) + + request = factory.get(link_header['next']) + with self.assertNumQueries(2): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data[10:20]) + link_header = parse_header_link(response['Link']) + self.assertTrue('next' in link_header) + self.assertTrue('last' in link_header) + self.assertTrue('previous' in link_header) + self.assertTrue('first' in link_header) + + request = factory.get(link_header['last']) + with self.assertNumQueries(2): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data[20:]) + link_header = parse_header_link(response['Link']) + self.assertFalse('next' in link_header) + self.assertFalse('last' in link_header) + self.assertTrue('previous' in link_header) + self.assertTrue('first' in link_header) + + +class IntegrationTestOffsetLimitPagination(TestCase): + + def setUp(self): + """ + Create 26 BasicModel instances. + """ + for char in 'abcdefghijklmnopqrstuvwxyz': + BasicModel(text=char * 3).save() + self.objects = BasicModel.objects + self.data = [ + {'id': obj.id, 'text': obj.text} + for obj in self.objects.all() + ] + self.view = OffsetLimitPaginationView.as_view() + + def test_get_paginated_root_view(self): + """ + GET requests to paginated ListCreateAPIView should return paginated results. + """ + request = factory.get('/') + # Note: Database queries are a `SELECT COUNT`, and `SELECT ` + with self.assertNumQueries(2): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['results'], self.data[0:10]) + self.assertEqual(response.data['count'], 26) + + request = factory.get('/', {'offset': 10}) + with self.assertNumQueries(2): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['results'], self.data[10:20]) + self.assertEqual(response.data['count'], 26) + + request = factory.get('/', {'offset': 20, 'limit': 3}) + with self.assertNumQueries(2): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['results'], self.data[20:23]) + self.assertEqual(response.data['count'], 26) + + class IntegrationTestPaginationAndFiltering(TestCase): def setUp(self): @@ -258,6 +409,63 @@ class UnitTestPagination(TestCase): self.assertEqual(serializer.context, results.context) +class UnitTestLinkPagination(TestCase): + + def setUp(self): + self.objects = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + paginator = Paginator(self.objects, 2) + self.first_page = paginator.page(1) + self.third_page = paginator.page(3) + self.last_page = paginator.page(5) + + def test_native_pagination(self): + serializer = pagination.LinkPaginationSerializer(self.first_page) + self.assertEqual(serializer.data['next'], '?page=2') + self.assertEqual(serializer.data['previous'], None) + self.assertEqual(serializer.data['first'], None) + self.assertEqual(serializer.data['last'], '?page=5') + + serializer = pagination.LinkPaginationSerializer(self.third_page) + self.assertEqual(serializer.data['next'], '?page=4') + self.assertEqual(serializer.data['previous'], '?page=2') + self.assertEqual(serializer.data['first'], '?page=1') + self.assertEqual(serializer.data['last'], '?page=5') + + serializer = pagination.LinkPaginationSerializer(self.last_page) + self.assertEqual(serializer.data['next'], None) + self.assertEqual(serializer.data['previous'], '?page=4') + self.assertEqual(serializer.data['first'], '?page=1') + self.assertEqual(serializer.data['last'], None) + + +class UnitTestOffsetLimitPagination(TestCase): + + def setUp(self): + self.objects = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + self.first_page = pagination.OffsetLimitPage( + self.objects, offset=0, limit=2) + self.third_page = pagination.OffsetLimitPage( + self.objects, offset=4, limit=2) + self.last_page = pagination.OffsetLimitPage( + self.objects, offset=8, limit=2) + + def test_native_pagination(self): + serializer = pagination.OffsetLimitPaginationSerializer( + self.first_page) + self.assertEqual(serializer.data['count'], 10) + self.assertEqual(serializer.data['results'], self.objects[:2]) + + serializer = pagination.OffsetLimitPaginationSerializer( + self.third_page) + self.assertEqual(serializer.data['count'], 10) + self.assertEqual(serializer.data['results'], self.objects[4:6]) + + serializer = pagination.OffsetLimitPaginationSerializer( + self.last_page) + self.assertEqual(serializer.data['count'], 10) + self.assertEqual(serializer.data['results'], self.objects[8:]) + + class TestUnpaginated(TestCase): """ Tests for list views without pagination.