diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 0d709c37a..12fb64138 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -2,29 +2,13 @@ Generic views that provide commonly needed behaviour. """ from __future__ import unicode_literals - -from django.core.paginator import Paginator, InvalidPage from django.db.models.query import QuerySet from django.http import Http404 from django.shortcuts import get_object_or_404 as _get_object_or_404 -from django.utils import six -from django.utils.translation import ugettext as _ from rest_framework import views, mixins from rest_framework.settings import api_settings -def strict_positive_int(integer_string, cutoff=None): - """ - Cast a string to a strictly positive integer. - """ - ret = int(integer_string) - if ret <= 0: - raise ValueError() - if cutoff: - ret = min(ret, cutoff) - return ret - - def get_object_or_404(queryset, *filter_args, **filter_kwargs): """ Same as Django's standard shortcut, but make sure to also raise 404 @@ -40,7 +24,6 @@ class GenericAPIView(views.APIView): """ Base class for all other generic views. """ - # You'll need to either set these attributes, # or override `get_queryset()`/`get_serializer_class()`. # If you are overriding a view method, it is important that you call @@ -50,146 +33,16 @@ class GenericAPIView(views.APIView): queryset = None serializer_class = None - # If you want to use object lookups other than pk, set this attribute. + # If you want to use object lookups other than pk, set 'lookup_field'. # For more complex lookup requirements override `get_object()`. lookup_field = 'pk' lookup_url_kwarg = None - # Pagination settings - paginate_by = api_settings.PAGINATE_BY - paginate_by_param = api_settings.PAGINATE_BY_PARAM - max_paginate_by = api_settings.MAX_PAGINATE_BY - pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS - page_kwarg = 'page' - # The filter backend classes to use for queryset filtering filter_backends = api_settings.DEFAULT_FILTER_BACKENDS - # The following attribute may be subject to change, - # and should be considered private API. - paginator_class = Paginator - - def get_serializer_context(self): - """ - Extra context provided to the serializer class. - """ - return { - 'request': self.request, - 'format': self.format_kwarg, - 'view': self - } - - def get_serializer(self, *args, **kwargs): - """ - Return the serializer instance that should be used for validating and - deserializing input, and for serializing output. - """ - serializer_class = self.get_serializer_class() - kwargs['context'] = self.get_serializer_context() - return serializer_class(*args, **kwargs) - - def get_pagination_serializer(self, page): - """ - Return a serializer instance to use with paginated data. - """ - class SerializerClass(self.pagination_serializer_class): - class Meta: - object_serializer_class = self.get_serializer_class() - - pagination_serializer_class = SerializerClass - context = self.get_serializer_context() - return pagination_serializer_class(instance=page, context=context) - - def paginate_queryset(self, queryset): - """ - Paginate a queryset if required, either returning a page object, - or `None` if pagination is not configured for this view. - """ - page_size = self.get_paginate_by() - if not page_size: - return None - - paginator = self.paginator_class(queryset, page_size) - page_kwarg = self.kwargs.get(self.page_kwarg) - page_query_param = self.request.query_params.get(self.page_kwarg) - page = page_kwarg or page_query_param or 1 - try: - page_number = paginator.validate_number(page) - except InvalidPage: - if page == 'last': - page_number = paginator.num_pages - else: - raise Http404(_('Choose a valid page number. Page numbers must be a whole number, or must be the string "last".')) - - try: - page = paginator.page(page_number) - except InvalidPage as exc: - error_format = _('Invalid page "{page_number}": {message}.') - raise Http404(error_format.format( - page_number=page_number, message=six.text_type(exc) - )) - - return page - - def filter_queryset(self, queryset): - """ - Given a queryset, filter it with whichever filter backend is in use. - - You are unlikely to want to override this method, although you may need - to call it either from a list view, or from a custom `get_object` - method if you want to apply the configured filtering backend to the - default queryset. - """ - for backend in self.get_filter_backends(): - queryset = backend().filter_queryset(self.request, queryset, self) - return queryset - - def get_filter_backends(self): - """ - Returns the list of filter backends that this view requires. - """ - return list(self.filter_backends) - - # The following methods provide default implementations - # that you may want to override for more complex cases. - - def get_paginate_by(self): - """ - Return the size of pages to use with pagination. - - If `PAGINATE_BY_PARAM` is set it will attempt to get the page size - from a named query parameter in the url, eg. ?page_size=100 - - Otherwise defaults to using `self.paginate_by`. - """ - if self.paginate_by_param: - try: - return strict_positive_int( - self.request.query_params[self.paginate_by_param], - cutoff=self.max_paginate_by - ) - except (KeyError, ValueError): - pass - - return self.paginate_by - - def get_serializer_class(self): - """ - Return the class to use for the serializer. - Defaults to using `self.serializer_class`. - - You may want to override this if you need to provide different - serializations depending on the incoming request. - - (Eg. admins get full serialization, others get basic serialization) - """ - assert self.serializer_class is not None, ( - "'%s' should either include a `serializer_class` attribute, " - "or override the `get_serializer_class()` method." - % self.__class__.__name__ - ) - - return self.serializer_class + # The style to use for queryset pagination. + pagination_class = api_settings.DEFAULT_PAGINATION_CLASS def get_queryset(self): """ @@ -246,6 +99,73 @@ class GenericAPIView(views.APIView): return obj + def get_serializer(self, *args, **kwargs): + """ + Return the serializer instance that should be used for validating and + deserializing input, and for serializing output. + """ + serializer_class = self.get_serializer_class() + kwargs['context'] = self.get_serializer_context() + return serializer_class(*args, **kwargs) + + def get_serializer_class(self): + """ + Return the class to use for the serializer. + Defaults to using `self.serializer_class`. + + You may want to override this if you need to provide different + serializations depending on the incoming request. + + (Eg. admins get full serialization, others get basic serialization) + """ + assert self.serializer_class is not None, ( + "'%s' should either include a `serializer_class` attribute, " + "or override the `get_serializer_class()` method." + % self.__class__.__name__ + ) + + return self.serializer_class + + def get_serializer_context(self): + """ + Extra context provided to the serializer class. + """ + return { + 'request': self.request, + 'format': self.format_kwarg, + 'view': self + } + + def filter_queryset(self, queryset): + """ + Given a queryset, filter it with whichever filter backend is in use. + + You are unlikely to want to override this method, although you may need + to call it either from a list view, or from a custom `get_object` + method if you want to apply the configured filtering backend to the + default queryset. + """ + for backend in list(self.filter_backends): + queryset = backend().filter_queryset(self.request, queryset, self) + return queryset + + @property + def pager(self): + if not hasattr(self, '_pager'): + if self.pagination_class is None: + self._pager = None + else: + self._pager = self.pagination_class() + return self._pager + + def paginate_queryset(self, queryset): + if self.pager is None: + return None + return self.pager.paginate_queryset(queryset, self.request, view=self) + + def get_paginated_response(self, objects): + return self.pager.get_paginated_response(objects) + # Concrete view classes that provide method handlers # by composing the mixin classes with the base view. diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 2074a1072..c34cfcee1 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -5,7 +5,6 @@ We don't bind behaviour to http method handlers yet, which allows mixin classes to be composed in interesting ways. """ from __future__ import unicode_literals - from rest_framework import status from rest_framework.response import Response from rest_framework.settings import api_settings @@ -37,12 +36,14 @@ class ListModelMixin(object): List a queryset. """ def list(self, request, *args, **kwargs): - instance = self.filter_queryset(self.get_queryset()) - page = self.paginate_queryset(instance) + queryset = self.filter_queryset(self.get_queryset()) + + page = self.paginate_queryset(queryset) if page is not None: - serializer = self.get_pagination_serializer(page) - else: - serializer = self.get_serializer(instance, many=True) + serializer = self.get_serializer(page, many=True) + return self.get_paginated_response(serializer.data) + + serializer = self.get_serializer(queryset, many=True) return Response(serializer.data) diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index f31e5fa4c..da2d60a44 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -3,87 +3,192 @@ Pagination serializers determine the structure of the output that should be used for paginated responses. """ from __future__ import unicode_literals -from rest_framework import serializers +from django.core.paginator import InvalidPage, Paginator as DjangoPaginator +from django.utils import six +from django.utils.translation import ugettext as _ +from rest_framework.compat import OrderedDict +from rest_framework.exceptions import NotFound +from rest_framework.response import Response +from rest_framework.settings import api_settings from rest_framework.templatetags.rest_framework import replace_query_param -class NextPageField(serializers.Field): +def _strict_positive_int(integer_string, cutoff=None): """ - Field that returns a link to the next page in paginated results. + Cast a string to a strictly positive integer. """ - page_field = 'page' - - def to_representation(self, value): - if not value.has_next(): - return None - page = value.next_page_number() - request = self.context.get('request') - url = request and request.build_absolute_uri() or '' - return replace_query_param(url, self.page_field, page) + ret = int(integer_string) + if ret <= 0: + raise ValueError() + if cutoff: + ret = min(ret, cutoff) + return ret -class PreviousPageField(serializers.Field): +class BasePagination(object): + def paginate_queryset(self, queryset, request): + raise NotImplemented('paginate_queryset() must be implemented.') + + def get_paginated_response(self, data, page, request): + raise NotImplemented('get_paginated_response() must be implemented.') + + +class PageNumberPagination(BasePagination): """ - Field that returns a link to the previous page in paginated results. + A simple page number based style that supports page numbers as + query parameters. For example: + + http://api.example.org/accounts/?page=4 + http://api.example.org/accounts/?page=4&page_size=100 """ - page_field = 'page' + # The default page size. + # Defaults to `None`, meaning pagination is disabled. + paginate_by = api_settings.PAGINATE_BY - def to_representation(self, value): - if not value.has_previous(): - return None - page = value.previous_page_number() - request = self.context.get('request') - url = request and request.build_absolute_uri() or '' - return replace_query_param(url, self.page_field, page) + # Client can control the page using this query parameter. + page_query_param = 'page' + # Client can control the page size using this query parameter. + # Default is 'None'. Set to eg 'page_size' to enable usage. + paginate_by_param = api_settings.PAGINATE_BY_PARAM -class DefaultObjectSerializer(serializers.ReadOnlyField): - """ - If no object serializer is specified, then this serializer will be applied - as the default. - """ + # Set to an integer to limit the maximum page size the client may request. + # Only relevant if 'paginate_by_param' has also been set. + max_paginate_by = api_settings.MAX_PAGINATE_BY - def __init__(self, source=None, many=None, context=None): - # Note: Swallow context and many kwargs - only required for - # eg. ModelSerializer. - super(DefaultObjectSerializer, self).__init__(source=source) - - -class BasePaginationSerializer(serializers.Serializer): - """ - A base class for pagination serializers to inherit from, - to make implementing custom serializers more easy. - """ - results_field = 'results' - - def __init__(self, *args, **kwargs): + def paginate_queryset(self, queryset, request, view): """ - Override init to add in the object serializer field on-the-fly. + Paginate a queryset if required, either returning a page object, + or `None` if pagination is not configured for this view. """ - super(BasePaginationSerializer, self).__init__(*args, **kwargs) - results_field = self.results_field + for attr in ( + 'paginate_by', 'page_query_param', + 'paginate_by_param', 'max_paginate_by' + ): + if hasattr(view, attr): + setattr(self, attr, getattr(view, attr)) + + page_size = self.get_page_size(request) + if not page_size: + return None + + paginator = DjangoPaginator(queryset, page_size) + page_string = request.query_params.get(self.page_query_param, 1) + try: + page_number = paginator.validate_number(page_string) + except InvalidPage: + if page_string == 'last': + page_number = paginator.num_pages + else: + msg = _( + 'Choose a valid page number. Page numbers must be a ' + 'whole number, or must be the string "last".' + ) + raise NotFound(msg) try: - object_serializer = self.Meta.object_serializer_class - except AttributeError: - object_serializer = DefaultObjectSerializer + self.page = paginator.page(page_number) + except InvalidPage as exc: + msg = _('Invalid page "{page_number}": {message}.').format( + page_number=page_number, message=six.text_type(exc) + ) + raise NotFound(msg) + self.request = request + return self.page + + def get_paginated_response(self, objects): + return Response(OrderedDict([ + ('count', self.page.paginator.count), + ('next', self.get_next_link()), + ('previous', self.get_previous_link()), + ('results', objects) + ])) + + def get_page_size(self, request): + if self.paginate_by_param: + try: + return _strict_positive_int( + request.query_params[self.paginate_by_param], + cutoff=self.max_paginate_by + ) + except (KeyError, ValueError): + pass + + return self.paginate_by + + def get_next_link(self): + if not self.page.has_next(): + return None + url = self.request.build_absolute_uri() + page_number = self.page.next_page_number() + return replace_query_param(url, self.page_query_param, page_number) + + def get_previous_link(self): + if not self.page.has_previous(): + return None + url = self.request.build_absolute_uri() + page_number = self.page.previous_page_number() + return replace_query_param(url, self.page_query_param, page_number) + + +class LimitOffsetPagination(BasePagination): + """ + A limit/offset based style. For example: + + http://api.example.org/accounts/?limit=100 + http://api.example.org/accounts/?offset=400&limit=100 + """ + default_limit = api_settings.PAGINATE_BY + limit_query_param = 'limit' + offset_query_param = 'offset' + max_limit = None + + def paginate_queryset(self, queryset, request, view): + self.limit = self.get_limit(request) + self.offset = self.get_offset(request) + self.count = queryset.count() + self.request = request + return queryset[self.offset:self.offset + self.limit] + + def get_paginated_response(self, objects): + return Response(OrderedDict([ + ('count', self.count), + ('next', self.get_next_link()), + ('previous', self.get_previous_link()), + ('results', objects) + ])) + + def get_limit(self, request): + if self.limit_query_param: + try: + return _strict_positive_int( + request.query_params[self.limit_query_param], + cutoff=self.max_limit + ) + except (KeyError, ValueError): + pass + + return self.default_limit + + def get_offset(self, request): try: - list_serializer_class = object_serializer.Meta.list_serializer_class - except AttributeError: - list_serializer_class = serializers.ListSerializer + return _strict_positive_int( + request.query_params[self.offset_query_param], + ) + except (KeyError, ValueError): + return 0 - self.fields[results_field] = list_serializer_class( - child=object_serializer(), - source='object_list' - ) - self.fields[results_field].bind(field_name=results_field, parent=self) + def get_next_link(self, page): + if self.offset + self.limit >= self.count: + return None + url = self.request.build_absolute_uri() + offset = self.offset + self.limit + return replace_query_param(url, self.offset_query_param, offset) - -class PaginationSerializer(BasePaginationSerializer): - """ - A default implementation of a pagination serializer. - """ - count = serializers.ReadOnlyField(source='paginator.count') - next = NextPageField(source='*') - previous = PreviousPageField(source='*') + def get_previous_link(self, page): + if self.offset - self.limit < 0: + return None + url = self.request.build_absolute_uri() + offset = self.offset - self.limit + return replace_query_param(url, self.offset_query_param, offset) diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 877d461be..3cce26b1c 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -49,7 +49,7 @@ DEFAULTS = { 'DEFAULT_VERSIONING_CLASS': None, # Generic view behavior - 'DEFAULT_PAGINATION_SERIALIZER_CLASS': 'rest_framework.pagination.PaginationSerializer', + 'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination', 'DEFAULT_FILTER_BACKENDS': (), # Throttling @@ -130,7 +130,7 @@ IMPORT_STRINGS = ( 'DEFAULT_CONTENT_NEGOTIATION_CLASS', 'DEFAULT_METADATA_CLASS', 'DEFAULT_VERSIONING_CLASS', - 'DEFAULT_PAGINATION_SERIALIZER_CLASS', + 'DEFAULT_PAGINATION_CLASS', 'DEFAULT_FILTER_BACKENDS', 'EXCEPTION_HANDLER', 'TEST_REQUEST_RENDERER_CLASSES', diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 1fd9cf9c4..d410cd5eb 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -1,10 +1,9 @@ from __future__ import unicode_literals import datetime from decimal import Decimal -from django.core.paginator import Paginator from django.test import TestCase from django.utils import unittest -from rest_framework import generics, serializers, status, pagination, filters +from rest_framework import generics, serializers, status, filters from rest_framework.compat import django_filters from rest_framework.test import APIRequestFactory from .models import BasicModel, FilterableItem @@ -238,45 +237,6 @@ class IntegrationTestPaginationAndFiltering(TestCase): self.assertEqual(response.data['previous'], None) -class PassOnContextPaginationSerializer(pagination.PaginationSerializer): - class Meta: - object_serializer_class = serializers.Serializer - - -class UnitTestPagination(TestCase): - """ - Unit tests for pagination of primitive objects. - """ - - def setUp(self): - self.objects = [char * 3 for char in 'abcdefghijklmnopqrstuvwxyz'] - paginator = Paginator(self.objects, 10) - self.first_page = paginator.page(1) - self.last_page = paginator.page(3) - - def test_native_pagination(self): - serializer = pagination.PaginationSerializer(self.first_page) - self.assertEqual(serializer.data['count'], 26) - self.assertEqual(serializer.data['next'], '?page=2') - self.assertEqual(serializer.data['previous'], None) - self.assertEqual(serializer.data['results'], self.objects[:10]) - - serializer = pagination.PaginationSerializer(self.last_page) - self.assertEqual(serializer.data['count'], 26) - self.assertEqual(serializer.data['next'], None) - self.assertEqual(serializer.data['previous'], '?page=2') - self.assertEqual(serializer.data['results'], self.objects[20:]) - - def test_context_available_in_result(self): - """ - Ensure context gets passed through to the object serializer. - """ - serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'}) - serializer.data - results = serializer.fields[serializer.results_field] - self.assertEqual(serializer.context, results.context) - - class TestUnpaginated(TestCase): """ Tests for list views without pagination. @@ -377,177 +337,3 @@ class TestMaxPaginateByParam(TestCase): request = factory.get('/') response = self.view(request).render() self.assertEqual(response.data['results'], self.data[:3]) - - -# Tests for context in pagination serializers - -class CustomField(serializers.ReadOnlyField): - def to_native(self, value): - if 'view' not in self.context: - raise RuntimeError("context isn't getting passed into custom field") - return "value" - - -class BasicModelSerializer(serializers.Serializer): - text = CustomField() - - def to_native(self, value): - if 'view' not in self.context: - raise RuntimeError("context isn't getting passed into serializer") - return super(BasicSerializer, self).to_native(value) - - -class TestContextPassedToCustomField(TestCase): - def setUp(self): - BasicModel.objects.create(text='ala ma kota') - - def test_with_pagination(self): - class ListView(generics.ListCreateAPIView): - queryset = BasicModel.objects.all() - serializer_class = BasicModelSerializer - paginate_by = 1 - - self.view = ListView.as_view() - request = factory.get('/') - response = self.view(request).render() - - self.assertEqual(response.status_code, status.HTTP_200_OK) - - -# Tests for custom pagination serializers - -class LinksSerializer(serializers.Serializer): - next = pagination.NextPageField(source='*') - prev = pagination.PreviousPageField(source='*') - - -class CustomPaginationSerializer(pagination.BasePaginationSerializer): - links = LinksSerializer(source='*') # Takes the page object as the source - total_results = serializers.ReadOnlyField(source='paginator.count') - - results_field = 'objects' - - -class CustomFooSerializer(serializers.Serializer): - foo = serializers.CharField() - - -class CustomFooPaginationSerializer(pagination.PaginationSerializer): - class Meta: - object_serializer_class = CustomFooSerializer - - -class TestCustomPaginationSerializer(TestCase): - def setUp(self): - objects = ['john', 'paul', 'george', 'ringo'] - paginator = Paginator(objects, 2) - self.page = paginator.page(1) - - def test_custom_pagination_serializer(self): - request = APIRequestFactory().get('/foobar') - serializer = CustomPaginationSerializer( - instance=self.page, - context={'request': request} - ) - expected = { - 'links': { - 'next': 'http://testserver/foobar?page=2', - 'prev': None - }, - 'total_results': 4, - 'objects': ['john', 'paul'] - } - self.assertEqual(serializer.data, expected) - - def test_custom_pagination_serializer_with_custom_object_serializer(self): - objects = [ - {'foo': 'bar'}, - {'foo': 'spam'} - ] - paginator = Paginator(objects, 1) - page = paginator.page(1) - serializer = CustomFooPaginationSerializer(page) - serializer.data - - -class NonIntegerPage(object): - - def __init__(self, paginator, object_list, prev_token, token, next_token): - self.paginator = paginator - self.object_list = object_list - self.prev_token = prev_token - self.token = token - self.next_token = next_token - - def has_next(self): - return not not self.next_token - - def next_page_number(self): - return self.next_token - - def has_previous(self): - return not not self.prev_token - - def previous_page_number(self): - return self.prev_token - - -class NonIntegerPaginator(object): - - def __init__(self, object_list, per_page): - self.object_list = object_list - self.per_page = per_page - - def count(self): - # pretend like we don't know how many pages we have - return None - - def page(self, token=None): - if token: - try: - first = self.object_list.index(token) - except ValueError: - first = 0 - else: - first = 0 - n = len(self.object_list) - last = min(first + self.per_page, n) - prev_token = self.object_list[last - (2 * self.per_page)] if first else None - next_token = self.object_list[last] if last < n else None - return NonIntegerPage(self, self.object_list[first:last], prev_token, token, next_token) - - -class TestNonIntegerPagination(TestCase): - def test_custom_pagination_serializer(self): - objects = ['john', 'paul', 'george', 'ringo'] - paginator = NonIntegerPaginator(objects, 2) - - request = APIRequestFactory().get('/foobar') - serializer = CustomPaginationSerializer( - instance=paginator.page(), - context={'request': request} - ) - expected = { - 'links': { - 'next': 'http://testserver/foobar?page={0}'.format(objects[2]), - 'prev': None - }, - 'total_results': None, - 'objects': objects[:2] - } - self.assertEqual(serializer.data, expected) - - request = APIRequestFactory().get('/foobar') - serializer = CustomPaginationSerializer( - instance=paginator.page('george'), - context={'request': request} - ) - expected = { - 'links': { - 'next': None, - 'prev': 'http://testserver/foobar?page={0}'.format(objects[0]), - }, - 'total_results': None, - 'objects': objects[2:] - } - self.assertEqual(serializer.data, expected)