First pass at 3.1 pagination API

This commit is contained in:
Tom Christie 2015-01-09 15:30:36 +00:00
parent 11efde8905
commit 73feaf6299
5 changed files with 248 additions and 436 deletions

View File

@ -2,29 +2,13 @@
Generic views that provide commonly needed behaviour.
"""
from __future__ import unicode_literals
from django.core.paginator import Paginator, InvalidPage
from django.db.models.query import QuerySet
from django.http import Http404
from django.shortcuts import get_object_or_404 as _get_object_or_404
from django.utils import six
from django.utils.translation import ugettext as _
from rest_framework import views, mixins
from rest_framework.settings import api_settings
def strict_positive_int(integer_string, cutoff=None):
"""
Cast a string to a strictly positive integer.
"""
ret = int(integer_string)
if ret <= 0:
raise ValueError()
if cutoff:
ret = min(ret, cutoff)
return ret
def get_object_or_404(queryset, *filter_args, **filter_kwargs):
"""
Same as Django's standard shortcut, but make sure to also raise 404
@ -40,7 +24,6 @@ class GenericAPIView(views.APIView):
"""
Base class for all other generic views.
"""
# You'll need to either set these attributes,
# or override `get_queryset()`/`get_serializer_class()`.
# If you are overriding a view method, it is important that you call
@ -50,146 +33,16 @@ class GenericAPIView(views.APIView):
queryset = None
serializer_class = None
# If you want to use object lookups other than pk, set this attribute.
# If you want to use object lookups other than pk, set 'lookup_field'.
# For more complex lookup requirements override `get_object()`.
lookup_field = 'pk'
lookup_url_kwarg = None
# Pagination settings
paginate_by = api_settings.PAGINATE_BY
paginate_by_param = api_settings.PAGINATE_BY_PARAM
max_paginate_by = api_settings.MAX_PAGINATE_BY
pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS
page_kwarg = 'page'
# The filter backend classes to use for queryset filtering
filter_backends = api_settings.DEFAULT_FILTER_BACKENDS
# The following attribute may be subject to change,
# and should be considered private API.
paginator_class = Paginator
def get_serializer_context(self):
"""
Extra context provided to the serializer class.
"""
return {
'request': self.request,
'format': self.format_kwarg,
'view': self
}
def get_serializer(self, *args, **kwargs):
"""
Return the serializer instance that should be used for validating and
deserializing input, and for serializing output.
"""
serializer_class = self.get_serializer_class()
kwargs['context'] = self.get_serializer_context()
return serializer_class(*args, **kwargs)
def get_pagination_serializer(self, page):
"""
Return a serializer instance to use with paginated data.
"""
class SerializerClass(self.pagination_serializer_class):
class Meta:
object_serializer_class = self.get_serializer_class()
pagination_serializer_class = SerializerClass
context = self.get_serializer_context()
return pagination_serializer_class(instance=page, context=context)
def paginate_queryset(self, queryset):
"""
Paginate a queryset if required, either returning a page object,
or `None` if pagination is not configured for this view.
"""
page_size = self.get_paginate_by()
if not page_size:
return None
paginator = self.paginator_class(queryset, page_size)
page_kwarg = self.kwargs.get(self.page_kwarg)
page_query_param = self.request.query_params.get(self.page_kwarg)
page = page_kwarg or page_query_param or 1
try:
page_number = paginator.validate_number(page)
except InvalidPage:
if page == 'last':
page_number = paginator.num_pages
else:
raise Http404(_('Choose a valid page number. Page numbers must be a whole number, or must be the string "last".'))
try:
page = paginator.page(page_number)
except InvalidPage as exc:
error_format = _('Invalid page "{page_number}": {message}.')
raise Http404(error_format.format(
page_number=page_number, message=six.text_type(exc)
))
return page
def filter_queryset(self, queryset):
"""
Given a queryset, filter it with whichever filter backend is in use.
You are unlikely to want to override this method, although you may need
to call it either from a list view, or from a custom `get_object`
method if you want to apply the configured filtering backend to the
default queryset.
"""
for backend in self.get_filter_backends():
queryset = backend().filter_queryset(self.request, queryset, self)
return queryset
def get_filter_backends(self):
"""
Returns the list of filter backends that this view requires.
"""
return list(self.filter_backends)
# The following methods provide default implementations
# that you may want to override for more complex cases.
def get_paginate_by(self):
"""
Return the size of pages to use with pagination.
If `PAGINATE_BY_PARAM` is set it will attempt to get the page size
from a named query parameter in the url, eg. ?page_size=100
Otherwise defaults to using `self.paginate_by`.
"""
if self.paginate_by_param:
try:
return strict_positive_int(
self.request.query_params[self.paginate_by_param],
cutoff=self.max_paginate_by
)
except (KeyError, ValueError):
pass
return self.paginate_by
def get_serializer_class(self):
"""
Return the class to use for the serializer.
Defaults to using `self.serializer_class`.
You may want to override this if you need to provide different
serializations depending on the incoming request.
(Eg. admins get full serialization, others get basic serialization)
"""
assert self.serializer_class is not None, (
"'%s' should either include a `serializer_class` attribute, "
"or override the `get_serializer_class()` method."
% self.__class__.__name__
)
return self.serializer_class
# The style to use for queryset pagination.
pagination_class = api_settings.DEFAULT_PAGINATION_CLASS
def get_queryset(self):
"""
@ -246,6 +99,73 @@ class GenericAPIView(views.APIView):
return obj
def get_serializer(self, *args, **kwargs):
"""
Return the serializer instance that should be used for validating and
deserializing input, and for serializing output.
"""
serializer_class = self.get_serializer_class()
kwargs['context'] = self.get_serializer_context()
return serializer_class(*args, **kwargs)
def get_serializer_class(self):
"""
Return the class to use for the serializer.
Defaults to using `self.serializer_class`.
You may want to override this if you need to provide different
serializations depending on the incoming request.
(Eg. admins get full serialization, others get basic serialization)
"""
assert self.serializer_class is not None, (
"'%s' should either include a `serializer_class` attribute, "
"or override the `get_serializer_class()` method."
% self.__class__.__name__
)
return self.serializer_class
def get_serializer_context(self):
"""
Extra context provided to the serializer class.
"""
return {
'request': self.request,
'format': self.format_kwarg,
'view': self
}
def filter_queryset(self, queryset):
"""
Given a queryset, filter it with whichever filter backend is in use.
You are unlikely to want to override this method, although you may need
to call it either from a list view, or from a custom `get_object`
method if you want to apply the configured filtering backend to the
default queryset.
"""
for backend in list(self.filter_backends):
queryset = backend().filter_queryset(self.request, queryset, self)
return queryset
@property
def pager(self):
if not hasattr(self, '_pager'):
if self.pagination_class is None:
self._pager = None
else:
self._pager = self.pagination_class()
return self._pager
def paginate_queryset(self, queryset):
if self.pager is None:
return None
return self.pager.paginate_queryset(queryset, self.request, view=self)
def get_paginated_response(self, objects):
return self.pager.get_paginated_response(objects)
# Concrete view classes that provide method handlers
# by composing the mixin classes with the base view.

View File

@ -5,7 +5,6 @@ We don't bind behaviour to http method handlers yet,
which allows mixin classes to be composed in interesting ways.
"""
from __future__ import unicode_literals
from rest_framework import status
from rest_framework.response import Response
from rest_framework.settings import api_settings
@ -37,12 +36,14 @@ class ListModelMixin(object):
List a queryset.
"""
def list(self, request, *args, **kwargs):
instance = self.filter_queryset(self.get_queryset())
page = self.paginate_queryset(instance)
queryset = self.filter_queryset(self.get_queryset())
page = self.paginate_queryset(queryset)
if page is not None:
serializer = self.get_pagination_serializer(page)
else:
serializer = self.get_serializer(instance, many=True)
serializer = self.get_serializer(page, many=True)
return self.get_paginated_response(serializer.data)
serializer = self.get_serializer(queryset, many=True)
return Response(serializer.data)

View File

@ -3,87 +3,192 @@ Pagination serializers determine the structure of the output that should
be used for paginated responses.
"""
from __future__ import unicode_literals
from rest_framework import serializers
from django.core.paginator import InvalidPage, Paginator as DjangoPaginator
from django.utils import six
from django.utils.translation import ugettext as _
from rest_framework.compat import OrderedDict
from rest_framework.exceptions import NotFound
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework.templatetags.rest_framework import replace_query_param
class NextPageField(serializers.Field):
def _strict_positive_int(integer_string, cutoff=None):
"""
Field that returns a link to the next page in paginated results.
Cast a string to a strictly positive integer.
"""
page_field = 'page'
def to_representation(self, value):
if not value.has_next():
return None
page = value.next_page_number()
request = self.context.get('request')
url = request and request.build_absolute_uri() or ''
return replace_query_param(url, self.page_field, page)
ret = int(integer_string)
if ret <= 0:
raise ValueError()
if cutoff:
ret = min(ret, cutoff)
return ret
class PreviousPageField(serializers.Field):
class BasePagination(object):
def paginate_queryset(self, queryset, request):
raise NotImplemented('paginate_queryset() must be implemented.')
def get_paginated_response(self, data, page, request):
raise NotImplemented('get_paginated_response() must be implemented.')
class PageNumberPagination(BasePagination):
"""
Field that returns a link to the previous page in paginated results.
A simple page number based style that supports page numbers as
query parameters. For example:
http://api.example.org/accounts/?page=4
http://api.example.org/accounts/?page=4&page_size=100
"""
page_field = 'page'
# The default page size.
# Defaults to `None`, meaning pagination is disabled.
paginate_by = api_settings.PAGINATE_BY
def to_representation(self, value):
if not value.has_previous():
return None
page = value.previous_page_number()
request = self.context.get('request')
url = request and request.build_absolute_uri() or ''
return replace_query_param(url, self.page_field, page)
# Client can control the page using this query parameter.
page_query_param = 'page'
# Client can control the page size using this query parameter.
# Default is 'None'. Set to eg 'page_size' to enable usage.
paginate_by_param = api_settings.PAGINATE_BY_PARAM
class DefaultObjectSerializer(serializers.ReadOnlyField):
"""
If no object serializer is specified, then this serializer will be applied
as the default.
"""
# Set to an integer to limit the maximum page size the client may request.
# Only relevant if 'paginate_by_param' has also been set.
max_paginate_by = api_settings.MAX_PAGINATE_BY
def __init__(self, source=None, many=None, context=None):
# Note: Swallow context and many kwargs - only required for
# eg. ModelSerializer.
super(DefaultObjectSerializer, self).__init__(source=source)
class BasePaginationSerializer(serializers.Serializer):
"""
A base class for pagination serializers to inherit from,
to make implementing custom serializers more easy.
"""
results_field = 'results'
def __init__(self, *args, **kwargs):
def paginate_queryset(self, queryset, request, view):
"""
Override init to add in the object serializer field on-the-fly.
Paginate a queryset if required, either returning a page object,
or `None` if pagination is not configured for this view.
"""
super(BasePaginationSerializer, self).__init__(*args, **kwargs)
results_field = self.results_field
for attr in (
'paginate_by', 'page_query_param',
'paginate_by_param', 'max_paginate_by'
):
if hasattr(view, attr):
setattr(self, attr, getattr(view, attr))
page_size = self.get_page_size(request)
if not page_size:
return None
paginator = DjangoPaginator(queryset, page_size)
page_string = request.query_params.get(self.page_query_param, 1)
try:
page_number = paginator.validate_number(page_string)
except InvalidPage:
if page_string == 'last':
page_number = paginator.num_pages
else:
msg = _(
'Choose a valid page number. Page numbers must be a '
'whole number, or must be the string "last".'
)
raise NotFound(msg)
try:
object_serializer = self.Meta.object_serializer_class
except AttributeError:
object_serializer = DefaultObjectSerializer
self.page = paginator.page(page_number)
except InvalidPage as exc:
msg = _('Invalid page "{page_number}": {message}.').format(
page_number=page_number, message=six.text_type(exc)
)
raise NotFound(msg)
self.request = request
return self.page
def get_paginated_response(self, objects):
return Response(OrderedDict([
('count', self.page.paginator.count),
('next', self.get_next_link()),
('previous', self.get_previous_link()),
('results', objects)
]))
def get_page_size(self, request):
if self.paginate_by_param:
try:
return _strict_positive_int(
request.query_params[self.paginate_by_param],
cutoff=self.max_paginate_by
)
except (KeyError, ValueError):
pass
return self.paginate_by
def get_next_link(self):
if not self.page.has_next():
return None
url = self.request.build_absolute_uri()
page_number = self.page.next_page_number()
return replace_query_param(url, self.page_query_param, page_number)
def get_previous_link(self):
if not self.page.has_previous():
return None
url = self.request.build_absolute_uri()
page_number = self.page.previous_page_number()
return replace_query_param(url, self.page_query_param, page_number)
class LimitOffsetPagination(BasePagination):
"""
A limit/offset based style. For example:
http://api.example.org/accounts/?limit=100
http://api.example.org/accounts/?offset=400&limit=100
"""
default_limit = api_settings.PAGINATE_BY
limit_query_param = 'limit'
offset_query_param = 'offset'
max_limit = None
def paginate_queryset(self, queryset, request, view):
self.limit = self.get_limit(request)
self.offset = self.get_offset(request)
self.count = queryset.count()
self.request = request
return queryset[self.offset:self.offset + self.limit]
def get_paginated_response(self, objects):
return Response(OrderedDict([
('count', self.count),
('next', self.get_next_link()),
('previous', self.get_previous_link()),
('results', objects)
]))
def get_limit(self, request):
if self.limit_query_param:
try:
return _strict_positive_int(
request.query_params[self.limit_query_param],
cutoff=self.max_limit
)
except (KeyError, ValueError):
pass
return self.default_limit
def get_offset(self, request):
try:
list_serializer_class = object_serializer.Meta.list_serializer_class
except AttributeError:
list_serializer_class = serializers.ListSerializer
return _strict_positive_int(
request.query_params[self.offset_query_param],
)
except (KeyError, ValueError):
return 0
self.fields[results_field] = list_serializer_class(
child=object_serializer(),
source='object_list'
)
self.fields[results_field].bind(field_name=results_field, parent=self)
def get_next_link(self, page):
if self.offset + self.limit >= self.count:
return None
url = self.request.build_absolute_uri()
offset = self.offset + self.limit
return replace_query_param(url, self.offset_query_param, offset)
class PaginationSerializer(BasePaginationSerializer):
"""
A default implementation of a pagination serializer.
"""
count = serializers.ReadOnlyField(source='paginator.count')
next = NextPageField(source='*')
previous = PreviousPageField(source='*')
def get_previous_link(self, page):
if self.offset - self.limit < 0:
return None
url = self.request.build_absolute_uri()
offset = self.offset - self.limit
return replace_query_param(url, self.offset_query_param, offset)

View File

@ -49,7 +49,7 @@ DEFAULTS = {
'DEFAULT_VERSIONING_CLASS': None,
# Generic view behavior
'DEFAULT_PAGINATION_SERIALIZER_CLASS': 'rest_framework.pagination.PaginationSerializer',
'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination',
'DEFAULT_FILTER_BACKENDS': (),
# Throttling
@ -130,7 +130,7 @@ IMPORT_STRINGS = (
'DEFAULT_CONTENT_NEGOTIATION_CLASS',
'DEFAULT_METADATA_CLASS',
'DEFAULT_VERSIONING_CLASS',
'DEFAULT_PAGINATION_SERIALIZER_CLASS',
'DEFAULT_PAGINATION_CLASS',
'DEFAULT_FILTER_BACKENDS',
'EXCEPTION_HANDLER',
'TEST_REQUEST_RENDERER_CLASSES',

View File

@ -1,10 +1,9 @@
from __future__ import unicode_literals
import datetime
from decimal import Decimal
from django.core.paginator import Paginator
from django.test import TestCase
from django.utils import unittest
from rest_framework import generics, serializers, status, pagination, filters
from rest_framework import generics, serializers, status, filters
from rest_framework.compat import django_filters
from rest_framework.test import APIRequestFactory
from .models import BasicModel, FilterableItem
@ -238,45 +237,6 @@ class IntegrationTestPaginationAndFiltering(TestCase):
self.assertEqual(response.data['previous'], None)
class PassOnContextPaginationSerializer(pagination.PaginationSerializer):
class Meta:
object_serializer_class = serializers.Serializer
class UnitTestPagination(TestCase):
"""
Unit tests for pagination of primitive objects.
"""
def setUp(self):
self.objects = [char * 3 for char in 'abcdefghijklmnopqrstuvwxyz']
paginator = Paginator(self.objects, 10)
self.first_page = paginator.page(1)
self.last_page = paginator.page(3)
def test_native_pagination(self):
serializer = pagination.PaginationSerializer(self.first_page)
self.assertEqual(serializer.data['count'], 26)
self.assertEqual(serializer.data['next'], '?page=2')
self.assertEqual(serializer.data['previous'], None)
self.assertEqual(serializer.data['results'], self.objects[:10])
serializer = pagination.PaginationSerializer(self.last_page)
self.assertEqual(serializer.data['count'], 26)
self.assertEqual(serializer.data['next'], None)
self.assertEqual(serializer.data['previous'], '?page=2')
self.assertEqual(serializer.data['results'], self.objects[20:])
def test_context_available_in_result(self):
"""
Ensure context gets passed through to the object serializer.
"""
serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'})
serializer.data
results = serializer.fields[serializer.results_field]
self.assertEqual(serializer.context, results.context)
class TestUnpaginated(TestCase):
"""
Tests for list views without pagination.
@ -377,177 +337,3 @@ class TestMaxPaginateByParam(TestCase):
request = factory.get('/')
response = self.view(request).render()
self.assertEqual(response.data['results'], self.data[:3])
# Tests for context in pagination serializers
class CustomField(serializers.ReadOnlyField):
def to_native(self, value):
if 'view' not in self.context:
raise RuntimeError("context isn't getting passed into custom field")
return "value"
class BasicModelSerializer(serializers.Serializer):
text = CustomField()
def to_native(self, value):
if 'view' not in self.context:
raise RuntimeError("context isn't getting passed into serializer")
return super(BasicSerializer, self).to_native(value)
class TestContextPassedToCustomField(TestCase):
def setUp(self):
BasicModel.objects.create(text='ala ma kota')
def test_with_pagination(self):
class ListView(generics.ListCreateAPIView):
queryset = BasicModel.objects.all()
serializer_class = BasicModelSerializer
paginate_by = 1
self.view = ListView.as_view()
request = factory.get('/')
response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
# Tests for custom pagination serializers
class LinksSerializer(serializers.Serializer):
next = pagination.NextPageField(source='*')
prev = pagination.PreviousPageField(source='*')
class CustomPaginationSerializer(pagination.BasePaginationSerializer):
links = LinksSerializer(source='*') # Takes the page object as the source
total_results = serializers.ReadOnlyField(source='paginator.count')
results_field = 'objects'
class CustomFooSerializer(serializers.Serializer):
foo = serializers.CharField()
class CustomFooPaginationSerializer(pagination.PaginationSerializer):
class Meta:
object_serializer_class = CustomFooSerializer
class TestCustomPaginationSerializer(TestCase):
def setUp(self):
objects = ['john', 'paul', 'george', 'ringo']
paginator = Paginator(objects, 2)
self.page = paginator.page(1)
def test_custom_pagination_serializer(self):
request = APIRequestFactory().get('/foobar')
serializer = CustomPaginationSerializer(
instance=self.page,
context={'request': request}
)
expected = {
'links': {
'next': 'http://testserver/foobar?page=2',
'prev': None
},
'total_results': 4,
'objects': ['john', 'paul']
}
self.assertEqual(serializer.data, expected)
def test_custom_pagination_serializer_with_custom_object_serializer(self):
objects = [
{'foo': 'bar'},
{'foo': 'spam'}
]
paginator = Paginator(objects, 1)
page = paginator.page(1)
serializer = CustomFooPaginationSerializer(page)
serializer.data
class NonIntegerPage(object):
def __init__(self, paginator, object_list, prev_token, token, next_token):
self.paginator = paginator
self.object_list = object_list
self.prev_token = prev_token
self.token = token
self.next_token = next_token
def has_next(self):
return not not self.next_token
def next_page_number(self):
return self.next_token
def has_previous(self):
return not not self.prev_token
def previous_page_number(self):
return self.prev_token
class NonIntegerPaginator(object):
def __init__(self, object_list, per_page):
self.object_list = object_list
self.per_page = per_page
def count(self):
# pretend like we don't know how many pages we have
return None
def page(self, token=None):
if token:
try:
first = self.object_list.index(token)
except ValueError:
first = 0
else:
first = 0
n = len(self.object_list)
last = min(first + self.per_page, n)
prev_token = self.object_list[last - (2 * self.per_page)] if first else None
next_token = self.object_list[last] if last < n else None
return NonIntegerPage(self, self.object_list[first:last], prev_token, token, next_token)
class TestNonIntegerPagination(TestCase):
def test_custom_pagination_serializer(self):
objects = ['john', 'paul', 'george', 'ringo']
paginator = NonIntegerPaginator(objects, 2)
request = APIRequestFactory().get('/foobar')
serializer = CustomPaginationSerializer(
instance=paginator.page(),
context={'request': request}
)
expected = {
'links': {
'next': 'http://testserver/foobar?page={0}'.format(objects[2]),
'prev': None
},
'total_results': None,
'objects': objects[:2]
}
self.assertEqual(serializer.data, expected)
request = APIRequestFactory().get('/foobar')
serializer = CustomPaginationSerializer(
instance=paginator.page('george'),
context={'request': request}
)
expected = {
'links': {
'next': None,
'prev': 'http://testserver/foobar?page={0}'.format(objects[0]),
},
'total_results': None,
'objects': objects[2:]
}
self.assertEqual(serializer.data, expected)