mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-05-29 02:33:07 +03:00
First pass at 3.1 pagination API
This commit is contained in:
parent
11efde8905
commit
73feaf6299
|
@ -2,29 +2,13 @@
|
||||||
Generic views that provide commonly needed behaviour.
|
Generic views that provide commonly needed behaviour.
|
||||||
"""
|
"""
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
from django.core.paginator import Paginator, InvalidPage
|
|
||||||
from django.db.models.query import QuerySet
|
from django.db.models.query import QuerySet
|
||||||
from django.http import Http404
|
from django.http import Http404
|
||||||
from django.shortcuts import get_object_or_404 as _get_object_or_404
|
from django.shortcuts import get_object_or_404 as _get_object_or_404
|
||||||
from django.utils import six
|
|
||||||
from django.utils.translation import ugettext as _
|
|
||||||
from rest_framework import views, mixins
|
from rest_framework import views, mixins
|
||||||
from rest_framework.settings import api_settings
|
from rest_framework.settings import api_settings
|
||||||
|
|
||||||
|
|
||||||
def strict_positive_int(integer_string, cutoff=None):
|
|
||||||
"""
|
|
||||||
Cast a string to a strictly positive integer.
|
|
||||||
"""
|
|
||||||
ret = int(integer_string)
|
|
||||||
if ret <= 0:
|
|
||||||
raise ValueError()
|
|
||||||
if cutoff:
|
|
||||||
ret = min(ret, cutoff)
|
|
||||||
return ret
|
|
||||||
|
|
||||||
|
|
||||||
def get_object_or_404(queryset, *filter_args, **filter_kwargs):
|
def get_object_or_404(queryset, *filter_args, **filter_kwargs):
|
||||||
"""
|
"""
|
||||||
Same as Django's standard shortcut, but make sure to also raise 404
|
Same as Django's standard shortcut, but make sure to also raise 404
|
||||||
|
@ -40,7 +24,6 @@ class GenericAPIView(views.APIView):
|
||||||
"""
|
"""
|
||||||
Base class for all other generic views.
|
Base class for all other generic views.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# You'll need to either set these attributes,
|
# You'll need to either set these attributes,
|
||||||
# or override `get_queryset()`/`get_serializer_class()`.
|
# or override `get_queryset()`/`get_serializer_class()`.
|
||||||
# If you are overriding a view method, it is important that you call
|
# If you are overriding a view method, it is important that you call
|
||||||
|
@ -50,146 +33,16 @@ class GenericAPIView(views.APIView):
|
||||||
queryset = None
|
queryset = None
|
||||||
serializer_class = None
|
serializer_class = None
|
||||||
|
|
||||||
# If you want to use object lookups other than pk, set this attribute.
|
# If you want to use object lookups other than pk, set 'lookup_field'.
|
||||||
# For more complex lookup requirements override `get_object()`.
|
# For more complex lookup requirements override `get_object()`.
|
||||||
lookup_field = 'pk'
|
lookup_field = 'pk'
|
||||||
lookup_url_kwarg = None
|
lookup_url_kwarg = None
|
||||||
|
|
||||||
# Pagination settings
|
|
||||||
paginate_by = api_settings.PAGINATE_BY
|
|
||||||
paginate_by_param = api_settings.PAGINATE_BY_PARAM
|
|
||||||
max_paginate_by = api_settings.MAX_PAGINATE_BY
|
|
||||||
pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS
|
|
||||||
page_kwarg = 'page'
|
|
||||||
|
|
||||||
# The filter backend classes to use for queryset filtering
|
# The filter backend classes to use for queryset filtering
|
||||||
filter_backends = api_settings.DEFAULT_FILTER_BACKENDS
|
filter_backends = api_settings.DEFAULT_FILTER_BACKENDS
|
||||||
|
|
||||||
# The following attribute may be subject to change,
|
# The style to use for queryset pagination.
|
||||||
# and should be considered private API.
|
pagination_class = api_settings.DEFAULT_PAGINATION_CLASS
|
||||||
paginator_class = Paginator
|
|
||||||
|
|
||||||
def get_serializer_context(self):
|
|
||||||
"""
|
|
||||||
Extra context provided to the serializer class.
|
|
||||||
"""
|
|
||||||
return {
|
|
||||||
'request': self.request,
|
|
||||||
'format': self.format_kwarg,
|
|
||||||
'view': self
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_serializer(self, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Return the serializer instance that should be used for validating and
|
|
||||||
deserializing input, and for serializing output.
|
|
||||||
"""
|
|
||||||
serializer_class = self.get_serializer_class()
|
|
||||||
kwargs['context'] = self.get_serializer_context()
|
|
||||||
return serializer_class(*args, **kwargs)
|
|
||||||
|
|
||||||
def get_pagination_serializer(self, page):
|
|
||||||
"""
|
|
||||||
Return a serializer instance to use with paginated data.
|
|
||||||
"""
|
|
||||||
class SerializerClass(self.pagination_serializer_class):
|
|
||||||
class Meta:
|
|
||||||
object_serializer_class = self.get_serializer_class()
|
|
||||||
|
|
||||||
pagination_serializer_class = SerializerClass
|
|
||||||
context = self.get_serializer_context()
|
|
||||||
return pagination_serializer_class(instance=page, context=context)
|
|
||||||
|
|
||||||
def paginate_queryset(self, queryset):
|
|
||||||
"""
|
|
||||||
Paginate a queryset if required, either returning a page object,
|
|
||||||
or `None` if pagination is not configured for this view.
|
|
||||||
"""
|
|
||||||
page_size = self.get_paginate_by()
|
|
||||||
if not page_size:
|
|
||||||
return None
|
|
||||||
|
|
||||||
paginator = self.paginator_class(queryset, page_size)
|
|
||||||
page_kwarg = self.kwargs.get(self.page_kwarg)
|
|
||||||
page_query_param = self.request.query_params.get(self.page_kwarg)
|
|
||||||
page = page_kwarg or page_query_param or 1
|
|
||||||
try:
|
|
||||||
page_number = paginator.validate_number(page)
|
|
||||||
except InvalidPage:
|
|
||||||
if page == 'last':
|
|
||||||
page_number = paginator.num_pages
|
|
||||||
else:
|
|
||||||
raise Http404(_('Choose a valid page number. Page numbers must be a whole number, or must be the string "last".'))
|
|
||||||
|
|
||||||
try:
|
|
||||||
page = paginator.page(page_number)
|
|
||||||
except InvalidPage as exc:
|
|
||||||
error_format = _('Invalid page "{page_number}": {message}.')
|
|
||||||
raise Http404(error_format.format(
|
|
||||||
page_number=page_number, message=six.text_type(exc)
|
|
||||||
))
|
|
||||||
|
|
||||||
return page
|
|
||||||
|
|
||||||
def filter_queryset(self, queryset):
|
|
||||||
"""
|
|
||||||
Given a queryset, filter it with whichever filter backend is in use.
|
|
||||||
|
|
||||||
You are unlikely to want to override this method, although you may need
|
|
||||||
to call it either from a list view, or from a custom `get_object`
|
|
||||||
method if you want to apply the configured filtering backend to the
|
|
||||||
default queryset.
|
|
||||||
"""
|
|
||||||
for backend in self.get_filter_backends():
|
|
||||||
queryset = backend().filter_queryset(self.request, queryset, self)
|
|
||||||
return queryset
|
|
||||||
|
|
||||||
def get_filter_backends(self):
|
|
||||||
"""
|
|
||||||
Returns the list of filter backends that this view requires.
|
|
||||||
"""
|
|
||||||
return list(self.filter_backends)
|
|
||||||
|
|
||||||
# The following methods provide default implementations
|
|
||||||
# that you may want to override for more complex cases.
|
|
||||||
|
|
||||||
def get_paginate_by(self):
|
|
||||||
"""
|
|
||||||
Return the size of pages to use with pagination.
|
|
||||||
|
|
||||||
If `PAGINATE_BY_PARAM` is set it will attempt to get the page size
|
|
||||||
from a named query parameter in the url, eg. ?page_size=100
|
|
||||||
|
|
||||||
Otherwise defaults to using `self.paginate_by`.
|
|
||||||
"""
|
|
||||||
if self.paginate_by_param:
|
|
||||||
try:
|
|
||||||
return strict_positive_int(
|
|
||||||
self.request.query_params[self.paginate_by_param],
|
|
||||||
cutoff=self.max_paginate_by
|
|
||||||
)
|
|
||||||
except (KeyError, ValueError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
return self.paginate_by
|
|
||||||
|
|
||||||
def get_serializer_class(self):
|
|
||||||
"""
|
|
||||||
Return the class to use for the serializer.
|
|
||||||
Defaults to using `self.serializer_class`.
|
|
||||||
|
|
||||||
You may want to override this if you need to provide different
|
|
||||||
serializations depending on the incoming request.
|
|
||||||
|
|
||||||
(Eg. admins get full serialization, others get basic serialization)
|
|
||||||
"""
|
|
||||||
assert self.serializer_class is not None, (
|
|
||||||
"'%s' should either include a `serializer_class` attribute, "
|
|
||||||
"or override the `get_serializer_class()` method."
|
|
||||||
% self.__class__.__name__
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.serializer_class
|
|
||||||
|
|
||||||
def get_queryset(self):
|
def get_queryset(self):
|
||||||
"""
|
"""
|
||||||
|
@ -246,6 +99,73 @@ class GenericAPIView(views.APIView):
|
||||||
|
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
def get_serializer(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Return the serializer instance that should be used for validating and
|
||||||
|
deserializing input, and for serializing output.
|
||||||
|
"""
|
||||||
|
serializer_class = self.get_serializer_class()
|
||||||
|
kwargs['context'] = self.get_serializer_context()
|
||||||
|
return serializer_class(*args, **kwargs)
|
||||||
|
|
||||||
|
def get_serializer_class(self):
|
||||||
|
"""
|
||||||
|
Return the class to use for the serializer.
|
||||||
|
Defaults to using `self.serializer_class`.
|
||||||
|
|
||||||
|
You may want to override this if you need to provide different
|
||||||
|
serializations depending on the incoming request.
|
||||||
|
|
||||||
|
(Eg. admins get full serialization, others get basic serialization)
|
||||||
|
"""
|
||||||
|
assert self.serializer_class is not None, (
|
||||||
|
"'%s' should either include a `serializer_class` attribute, "
|
||||||
|
"or override the `get_serializer_class()` method."
|
||||||
|
% self.__class__.__name__
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.serializer_class
|
||||||
|
|
||||||
|
def get_serializer_context(self):
|
||||||
|
"""
|
||||||
|
Extra context provided to the serializer class.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
'request': self.request,
|
||||||
|
'format': self.format_kwarg,
|
||||||
|
'view': self
|
||||||
|
}
|
||||||
|
|
||||||
|
def filter_queryset(self, queryset):
|
||||||
|
"""
|
||||||
|
Given a queryset, filter it with whichever filter backend is in use.
|
||||||
|
|
||||||
|
You are unlikely to want to override this method, although you may need
|
||||||
|
to call it either from a list view, or from a custom `get_object`
|
||||||
|
method if you want to apply the configured filtering backend to the
|
||||||
|
default queryset.
|
||||||
|
"""
|
||||||
|
for backend in list(self.filter_backends):
|
||||||
|
queryset = backend().filter_queryset(self.request, queryset, self)
|
||||||
|
return queryset
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pager(self):
|
||||||
|
if not hasattr(self, '_pager'):
|
||||||
|
if self.pagination_class is None:
|
||||||
|
self._pager = None
|
||||||
|
else:
|
||||||
|
self._pager = self.pagination_class()
|
||||||
|
return self._pager
|
||||||
|
|
||||||
|
def paginate_queryset(self, queryset):
|
||||||
|
if self.pager is None:
|
||||||
|
return None
|
||||||
|
return self.pager.paginate_queryset(queryset, self.request, view=self)
|
||||||
|
|
||||||
|
def get_paginated_response(self, objects):
|
||||||
|
return self.pager.get_paginated_response(objects)
|
||||||
|
|
||||||
|
|
||||||
# Concrete view classes that provide method handlers
|
# Concrete view classes that provide method handlers
|
||||||
# by composing the mixin classes with the base view.
|
# by composing the mixin classes with the base view.
|
||||||
|
|
|
@ -5,7 +5,6 @@ We don't bind behaviour to http method handlers yet,
|
||||||
which allows mixin classes to be composed in interesting ways.
|
which allows mixin classes to be composed in interesting ways.
|
||||||
"""
|
"""
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
from rest_framework import status
|
from rest_framework import status
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework.settings import api_settings
|
from rest_framework.settings import api_settings
|
||||||
|
@ -37,12 +36,14 @@ class ListModelMixin(object):
|
||||||
List a queryset.
|
List a queryset.
|
||||||
"""
|
"""
|
||||||
def list(self, request, *args, **kwargs):
|
def list(self, request, *args, **kwargs):
|
||||||
instance = self.filter_queryset(self.get_queryset())
|
queryset = self.filter_queryset(self.get_queryset())
|
||||||
page = self.paginate_queryset(instance)
|
|
||||||
|
page = self.paginate_queryset(queryset)
|
||||||
if page is not None:
|
if page is not None:
|
||||||
serializer = self.get_pagination_serializer(page)
|
serializer = self.get_serializer(page, many=True)
|
||||||
else:
|
return self.get_paginated_response(serializer.data)
|
||||||
serializer = self.get_serializer(instance, many=True)
|
|
||||||
|
serializer = self.get_serializer(queryset, many=True)
|
||||||
return Response(serializer.data)
|
return Response(serializer.data)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -3,87 +3,192 @@ Pagination serializers determine the structure of the output that should
|
||||||
be used for paginated responses.
|
be used for paginated responses.
|
||||||
"""
|
"""
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
from rest_framework import serializers
|
from django.core.paginator import InvalidPage, Paginator as DjangoPaginator
|
||||||
|
from django.utils import six
|
||||||
|
from django.utils.translation import ugettext as _
|
||||||
|
from rest_framework.compat import OrderedDict
|
||||||
|
from rest_framework.exceptions import NotFound
|
||||||
|
from rest_framework.response import Response
|
||||||
|
from rest_framework.settings import api_settings
|
||||||
from rest_framework.templatetags.rest_framework import replace_query_param
|
from rest_framework.templatetags.rest_framework import replace_query_param
|
||||||
|
|
||||||
|
|
||||||
class NextPageField(serializers.Field):
|
def _strict_positive_int(integer_string, cutoff=None):
|
||||||
"""
|
"""
|
||||||
Field that returns a link to the next page in paginated results.
|
Cast a string to a strictly positive integer.
|
||||||
"""
|
"""
|
||||||
page_field = 'page'
|
ret = int(integer_string)
|
||||||
|
if ret <= 0:
|
||||||
def to_representation(self, value):
|
raise ValueError()
|
||||||
if not value.has_next():
|
if cutoff:
|
||||||
return None
|
ret = min(ret, cutoff)
|
||||||
page = value.next_page_number()
|
return ret
|
||||||
request = self.context.get('request')
|
|
||||||
url = request and request.build_absolute_uri() or ''
|
|
||||||
return replace_query_param(url, self.page_field, page)
|
|
||||||
|
|
||||||
|
|
||||||
class PreviousPageField(serializers.Field):
|
class BasePagination(object):
|
||||||
|
def paginate_queryset(self, queryset, request):
|
||||||
|
raise NotImplemented('paginate_queryset() must be implemented.')
|
||||||
|
|
||||||
|
def get_paginated_response(self, data, page, request):
|
||||||
|
raise NotImplemented('get_paginated_response() must be implemented.')
|
||||||
|
|
||||||
|
|
||||||
|
class PageNumberPagination(BasePagination):
|
||||||
"""
|
"""
|
||||||
Field that returns a link to the previous page in paginated results.
|
A simple page number based style that supports page numbers as
|
||||||
|
query parameters. For example:
|
||||||
|
|
||||||
|
http://api.example.org/accounts/?page=4
|
||||||
|
http://api.example.org/accounts/?page=4&page_size=100
|
||||||
"""
|
"""
|
||||||
page_field = 'page'
|
# The default page size.
|
||||||
|
# Defaults to `None`, meaning pagination is disabled.
|
||||||
|
paginate_by = api_settings.PAGINATE_BY
|
||||||
|
|
||||||
def to_representation(self, value):
|
# Client can control the page using this query parameter.
|
||||||
if not value.has_previous():
|
page_query_param = 'page'
|
||||||
return None
|
|
||||||
page = value.previous_page_number()
|
|
||||||
request = self.context.get('request')
|
|
||||||
url = request and request.build_absolute_uri() or ''
|
|
||||||
return replace_query_param(url, self.page_field, page)
|
|
||||||
|
|
||||||
|
# Client can control the page size using this query parameter.
|
||||||
|
# Default is 'None'. Set to eg 'page_size' to enable usage.
|
||||||
|
paginate_by_param = api_settings.PAGINATE_BY_PARAM
|
||||||
|
|
||||||
class DefaultObjectSerializer(serializers.ReadOnlyField):
|
# Set to an integer to limit the maximum page size the client may request.
|
||||||
"""
|
# Only relevant if 'paginate_by_param' has also been set.
|
||||||
If no object serializer is specified, then this serializer will be applied
|
max_paginate_by = api_settings.MAX_PAGINATE_BY
|
||||||
as the default.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, source=None, many=None, context=None):
|
def paginate_queryset(self, queryset, request, view):
|
||||||
# Note: Swallow context and many kwargs - only required for
|
|
||||||
# eg. ModelSerializer.
|
|
||||||
super(DefaultObjectSerializer, self).__init__(source=source)
|
|
||||||
|
|
||||||
|
|
||||||
class BasePaginationSerializer(serializers.Serializer):
|
|
||||||
"""
|
|
||||||
A base class for pagination serializers to inherit from,
|
|
||||||
to make implementing custom serializers more easy.
|
|
||||||
"""
|
|
||||||
results_field = 'results'
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
"""
|
"""
|
||||||
Override init to add in the object serializer field on-the-fly.
|
Paginate a queryset if required, either returning a page object,
|
||||||
|
or `None` if pagination is not configured for this view.
|
||||||
"""
|
"""
|
||||||
super(BasePaginationSerializer, self).__init__(*args, **kwargs)
|
for attr in (
|
||||||
results_field = self.results_field
|
'paginate_by', 'page_query_param',
|
||||||
|
'paginate_by_param', 'max_paginate_by'
|
||||||
|
):
|
||||||
|
if hasattr(view, attr):
|
||||||
|
setattr(self, attr, getattr(view, attr))
|
||||||
|
|
||||||
|
page_size = self.get_page_size(request)
|
||||||
|
if not page_size:
|
||||||
|
return None
|
||||||
|
|
||||||
|
paginator = DjangoPaginator(queryset, page_size)
|
||||||
|
page_string = request.query_params.get(self.page_query_param, 1)
|
||||||
|
try:
|
||||||
|
page_number = paginator.validate_number(page_string)
|
||||||
|
except InvalidPage:
|
||||||
|
if page_string == 'last':
|
||||||
|
page_number = paginator.num_pages
|
||||||
|
else:
|
||||||
|
msg = _(
|
||||||
|
'Choose a valid page number. Page numbers must be a '
|
||||||
|
'whole number, or must be the string "last".'
|
||||||
|
)
|
||||||
|
raise NotFound(msg)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
object_serializer = self.Meta.object_serializer_class
|
self.page = paginator.page(page_number)
|
||||||
except AttributeError:
|
except InvalidPage as exc:
|
||||||
object_serializer = DefaultObjectSerializer
|
msg = _('Invalid page "{page_number}": {message}.').format(
|
||||||
|
page_number=page_number, message=six.text_type(exc)
|
||||||
|
)
|
||||||
|
raise NotFound(msg)
|
||||||
|
|
||||||
|
self.request = request
|
||||||
|
return self.page
|
||||||
|
|
||||||
|
def get_paginated_response(self, objects):
|
||||||
|
return Response(OrderedDict([
|
||||||
|
('count', self.page.paginator.count),
|
||||||
|
('next', self.get_next_link()),
|
||||||
|
('previous', self.get_previous_link()),
|
||||||
|
('results', objects)
|
||||||
|
]))
|
||||||
|
|
||||||
|
def get_page_size(self, request):
|
||||||
|
if self.paginate_by_param:
|
||||||
|
try:
|
||||||
|
return _strict_positive_int(
|
||||||
|
request.query_params[self.paginate_by_param],
|
||||||
|
cutoff=self.max_paginate_by
|
||||||
|
)
|
||||||
|
except (KeyError, ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
return self.paginate_by
|
||||||
|
|
||||||
|
def get_next_link(self):
|
||||||
|
if not self.page.has_next():
|
||||||
|
return None
|
||||||
|
url = self.request.build_absolute_uri()
|
||||||
|
page_number = self.page.next_page_number()
|
||||||
|
return replace_query_param(url, self.page_query_param, page_number)
|
||||||
|
|
||||||
|
def get_previous_link(self):
|
||||||
|
if not self.page.has_previous():
|
||||||
|
return None
|
||||||
|
url = self.request.build_absolute_uri()
|
||||||
|
page_number = self.page.previous_page_number()
|
||||||
|
return replace_query_param(url, self.page_query_param, page_number)
|
||||||
|
|
||||||
|
|
||||||
|
class LimitOffsetPagination(BasePagination):
|
||||||
|
"""
|
||||||
|
A limit/offset based style. For example:
|
||||||
|
|
||||||
|
http://api.example.org/accounts/?limit=100
|
||||||
|
http://api.example.org/accounts/?offset=400&limit=100
|
||||||
|
"""
|
||||||
|
default_limit = api_settings.PAGINATE_BY
|
||||||
|
limit_query_param = 'limit'
|
||||||
|
offset_query_param = 'offset'
|
||||||
|
max_limit = None
|
||||||
|
|
||||||
|
def paginate_queryset(self, queryset, request, view):
|
||||||
|
self.limit = self.get_limit(request)
|
||||||
|
self.offset = self.get_offset(request)
|
||||||
|
self.count = queryset.count()
|
||||||
|
self.request = request
|
||||||
|
return queryset[self.offset:self.offset + self.limit]
|
||||||
|
|
||||||
|
def get_paginated_response(self, objects):
|
||||||
|
return Response(OrderedDict([
|
||||||
|
('count', self.count),
|
||||||
|
('next', self.get_next_link()),
|
||||||
|
('previous', self.get_previous_link()),
|
||||||
|
('results', objects)
|
||||||
|
]))
|
||||||
|
|
||||||
|
def get_limit(self, request):
|
||||||
|
if self.limit_query_param:
|
||||||
|
try:
|
||||||
|
return _strict_positive_int(
|
||||||
|
request.query_params[self.limit_query_param],
|
||||||
|
cutoff=self.max_limit
|
||||||
|
)
|
||||||
|
except (KeyError, ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
return self.default_limit
|
||||||
|
|
||||||
|
def get_offset(self, request):
|
||||||
try:
|
try:
|
||||||
list_serializer_class = object_serializer.Meta.list_serializer_class
|
return _strict_positive_int(
|
||||||
except AttributeError:
|
request.query_params[self.offset_query_param],
|
||||||
list_serializer_class = serializers.ListSerializer
|
)
|
||||||
|
except (KeyError, ValueError):
|
||||||
|
return 0
|
||||||
|
|
||||||
self.fields[results_field] = list_serializer_class(
|
def get_next_link(self, page):
|
||||||
child=object_serializer(),
|
if self.offset + self.limit >= self.count:
|
||||||
source='object_list'
|
return None
|
||||||
)
|
url = self.request.build_absolute_uri()
|
||||||
self.fields[results_field].bind(field_name=results_field, parent=self)
|
offset = self.offset + self.limit
|
||||||
|
return replace_query_param(url, self.offset_query_param, offset)
|
||||||
|
|
||||||
|
def get_previous_link(self, page):
|
||||||
class PaginationSerializer(BasePaginationSerializer):
|
if self.offset - self.limit < 0:
|
||||||
"""
|
return None
|
||||||
A default implementation of a pagination serializer.
|
url = self.request.build_absolute_uri()
|
||||||
"""
|
offset = self.offset - self.limit
|
||||||
count = serializers.ReadOnlyField(source='paginator.count')
|
return replace_query_param(url, self.offset_query_param, offset)
|
||||||
next = NextPageField(source='*')
|
|
||||||
previous = PreviousPageField(source='*')
|
|
||||||
|
|
|
@ -49,7 +49,7 @@ DEFAULTS = {
|
||||||
'DEFAULT_VERSIONING_CLASS': None,
|
'DEFAULT_VERSIONING_CLASS': None,
|
||||||
|
|
||||||
# Generic view behavior
|
# Generic view behavior
|
||||||
'DEFAULT_PAGINATION_SERIALIZER_CLASS': 'rest_framework.pagination.PaginationSerializer',
|
'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination',
|
||||||
'DEFAULT_FILTER_BACKENDS': (),
|
'DEFAULT_FILTER_BACKENDS': (),
|
||||||
|
|
||||||
# Throttling
|
# Throttling
|
||||||
|
@ -130,7 +130,7 @@ IMPORT_STRINGS = (
|
||||||
'DEFAULT_CONTENT_NEGOTIATION_CLASS',
|
'DEFAULT_CONTENT_NEGOTIATION_CLASS',
|
||||||
'DEFAULT_METADATA_CLASS',
|
'DEFAULT_METADATA_CLASS',
|
||||||
'DEFAULT_VERSIONING_CLASS',
|
'DEFAULT_VERSIONING_CLASS',
|
||||||
'DEFAULT_PAGINATION_SERIALIZER_CLASS',
|
'DEFAULT_PAGINATION_CLASS',
|
||||||
'DEFAULT_FILTER_BACKENDS',
|
'DEFAULT_FILTER_BACKENDS',
|
||||||
'EXCEPTION_HANDLER',
|
'EXCEPTION_HANDLER',
|
||||||
'TEST_REQUEST_RENDERER_CLASSES',
|
'TEST_REQUEST_RENDERER_CLASSES',
|
||||||
|
|
|
@ -1,10 +1,9 @@
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
import datetime
|
import datetime
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from django.core.paginator import Paginator
|
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from django.utils import unittest
|
from django.utils import unittest
|
||||||
from rest_framework import generics, serializers, status, pagination, filters
|
from rest_framework import generics, serializers, status, filters
|
||||||
from rest_framework.compat import django_filters
|
from rest_framework.compat import django_filters
|
||||||
from rest_framework.test import APIRequestFactory
|
from rest_framework.test import APIRequestFactory
|
||||||
from .models import BasicModel, FilterableItem
|
from .models import BasicModel, FilterableItem
|
||||||
|
@ -238,45 +237,6 @@ class IntegrationTestPaginationAndFiltering(TestCase):
|
||||||
self.assertEqual(response.data['previous'], None)
|
self.assertEqual(response.data['previous'], None)
|
||||||
|
|
||||||
|
|
||||||
class PassOnContextPaginationSerializer(pagination.PaginationSerializer):
|
|
||||||
class Meta:
|
|
||||||
object_serializer_class = serializers.Serializer
|
|
||||||
|
|
||||||
|
|
||||||
class UnitTestPagination(TestCase):
|
|
||||||
"""
|
|
||||||
Unit tests for pagination of primitive objects.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
self.objects = [char * 3 for char in 'abcdefghijklmnopqrstuvwxyz']
|
|
||||||
paginator = Paginator(self.objects, 10)
|
|
||||||
self.first_page = paginator.page(1)
|
|
||||||
self.last_page = paginator.page(3)
|
|
||||||
|
|
||||||
def test_native_pagination(self):
|
|
||||||
serializer = pagination.PaginationSerializer(self.first_page)
|
|
||||||
self.assertEqual(serializer.data['count'], 26)
|
|
||||||
self.assertEqual(serializer.data['next'], '?page=2')
|
|
||||||
self.assertEqual(serializer.data['previous'], None)
|
|
||||||
self.assertEqual(serializer.data['results'], self.objects[:10])
|
|
||||||
|
|
||||||
serializer = pagination.PaginationSerializer(self.last_page)
|
|
||||||
self.assertEqual(serializer.data['count'], 26)
|
|
||||||
self.assertEqual(serializer.data['next'], None)
|
|
||||||
self.assertEqual(serializer.data['previous'], '?page=2')
|
|
||||||
self.assertEqual(serializer.data['results'], self.objects[20:])
|
|
||||||
|
|
||||||
def test_context_available_in_result(self):
|
|
||||||
"""
|
|
||||||
Ensure context gets passed through to the object serializer.
|
|
||||||
"""
|
|
||||||
serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'})
|
|
||||||
serializer.data
|
|
||||||
results = serializer.fields[serializer.results_field]
|
|
||||||
self.assertEqual(serializer.context, results.context)
|
|
||||||
|
|
||||||
|
|
||||||
class TestUnpaginated(TestCase):
|
class TestUnpaginated(TestCase):
|
||||||
"""
|
"""
|
||||||
Tests for list views without pagination.
|
Tests for list views without pagination.
|
||||||
|
@ -377,177 +337,3 @@ class TestMaxPaginateByParam(TestCase):
|
||||||
request = factory.get('/')
|
request = factory.get('/')
|
||||||
response = self.view(request).render()
|
response = self.view(request).render()
|
||||||
self.assertEqual(response.data['results'], self.data[:3])
|
self.assertEqual(response.data['results'], self.data[:3])
|
||||||
|
|
||||||
|
|
||||||
# Tests for context in pagination serializers
|
|
||||||
|
|
||||||
class CustomField(serializers.ReadOnlyField):
|
|
||||||
def to_native(self, value):
|
|
||||||
if 'view' not in self.context:
|
|
||||||
raise RuntimeError("context isn't getting passed into custom field")
|
|
||||||
return "value"
|
|
||||||
|
|
||||||
|
|
||||||
class BasicModelSerializer(serializers.Serializer):
|
|
||||||
text = CustomField()
|
|
||||||
|
|
||||||
def to_native(self, value):
|
|
||||||
if 'view' not in self.context:
|
|
||||||
raise RuntimeError("context isn't getting passed into serializer")
|
|
||||||
return super(BasicSerializer, self).to_native(value)
|
|
||||||
|
|
||||||
|
|
||||||
class TestContextPassedToCustomField(TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
BasicModel.objects.create(text='ala ma kota')
|
|
||||||
|
|
||||||
def test_with_pagination(self):
|
|
||||||
class ListView(generics.ListCreateAPIView):
|
|
||||||
queryset = BasicModel.objects.all()
|
|
||||||
serializer_class = BasicModelSerializer
|
|
||||||
paginate_by = 1
|
|
||||||
|
|
||||||
self.view = ListView.as_view()
|
|
||||||
request = factory.get('/')
|
|
||||||
response = self.view(request).render()
|
|
||||||
|
|
||||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
|
||||||
|
|
||||||
|
|
||||||
# Tests for custom pagination serializers
|
|
||||||
|
|
||||||
class LinksSerializer(serializers.Serializer):
|
|
||||||
next = pagination.NextPageField(source='*')
|
|
||||||
prev = pagination.PreviousPageField(source='*')
|
|
||||||
|
|
||||||
|
|
||||||
class CustomPaginationSerializer(pagination.BasePaginationSerializer):
|
|
||||||
links = LinksSerializer(source='*') # Takes the page object as the source
|
|
||||||
total_results = serializers.ReadOnlyField(source='paginator.count')
|
|
||||||
|
|
||||||
results_field = 'objects'
|
|
||||||
|
|
||||||
|
|
||||||
class CustomFooSerializer(serializers.Serializer):
|
|
||||||
foo = serializers.CharField()
|
|
||||||
|
|
||||||
|
|
||||||
class CustomFooPaginationSerializer(pagination.PaginationSerializer):
|
|
||||||
class Meta:
|
|
||||||
object_serializer_class = CustomFooSerializer
|
|
||||||
|
|
||||||
|
|
||||||
class TestCustomPaginationSerializer(TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
objects = ['john', 'paul', 'george', 'ringo']
|
|
||||||
paginator = Paginator(objects, 2)
|
|
||||||
self.page = paginator.page(1)
|
|
||||||
|
|
||||||
def test_custom_pagination_serializer(self):
|
|
||||||
request = APIRequestFactory().get('/foobar')
|
|
||||||
serializer = CustomPaginationSerializer(
|
|
||||||
instance=self.page,
|
|
||||||
context={'request': request}
|
|
||||||
)
|
|
||||||
expected = {
|
|
||||||
'links': {
|
|
||||||
'next': 'http://testserver/foobar?page=2',
|
|
||||||
'prev': None
|
|
||||||
},
|
|
||||||
'total_results': 4,
|
|
||||||
'objects': ['john', 'paul']
|
|
||||||
}
|
|
||||||
self.assertEqual(serializer.data, expected)
|
|
||||||
|
|
||||||
def test_custom_pagination_serializer_with_custom_object_serializer(self):
|
|
||||||
objects = [
|
|
||||||
{'foo': 'bar'},
|
|
||||||
{'foo': 'spam'}
|
|
||||||
]
|
|
||||||
paginator = Paginator(objects, 1)
|
|
||||||
page = paginator.page(1)
|
|
||||||
serializer = CustomFooPaginationSerializer(page)
|
|
||||||
serializer.data
|
|
||||||
|
|
||||||
|
|
||||||
class NonIntegerPage(object):
|
|
||||||
|
|
||||||
def __init__(self, paginator, object_list, prev_token, token, next_token):
|
|
||||||
self.paginator = paginator
|
|
||||||
self.object_list = object_list
|
|
||||||
self.prev_token = prev_token
|
|
||||||
self.token = token
|
|
||||||
self.next_token = next_token
|
|
||||||
|
|
||||||
def has_next(self):
|
|
||||||
return not not self.next_token
|
|
||||||
|
|
||||||
def next_page_number(self):
|
|
||||||
return self.next_token
|
|
||||||
|
|
||||||
def has_previous(self):
|
|
||||||
return not not self.prev_token
|
|
||||||
|
|
||||||
def previous_page_number(self):
|
|
||||||
return self.prev_token
|
|
||||||
|
|
||||||
|
|
||||||
class NonIntegerPaginator(object):
|
|
||||||
|
|
||||||
def __init__(self, object_list, per_page):
|
|
||||||
self.object_list = object_list
|
|
||||||
self.per_page = per_page
|
|
||||||
|
|
||||||
def count(self):
|
|
||||||
# pretend like we don't know how many pages we have
|
|
||||||
return None
|
|
||||||
|
|
||||||
def page(self, token=None):
|
|
||||||
if token:
|
|
||||||
try:
|
|
||||||
first = self.object_list.index(token)
|
|
||||||
except ValueError:
|
|
||||||
first = 0
|
|
||||||
else:
|
|
||||||
first = 0
|
|
||||||
n = len(self.object_list)
|
|
||||||
last = min(first + self.per_page, n)
|
|
||||||
prev_token = self.object_list[last - (2 * self.per_page)] if first else None
|
|
||||||
next_token = self.object_list[last] if last < n else None
|
|
||||||
return NonIntegerPage(self, self.object_list[first:last], prev_token, token, next_token)
|
|
||||||
|
|
||||||
|
|
||||||
class TestNonIntegerPagination(TestCase):
|
|
||||||
def test_custom_pagination_serializer(self):
|
|
||||||
objects = ['john', 'paul', 'george', 'ringo']
|
|
||||||
paginator = NonIntegerPaginator(objects, 2)
|
|
||||||
|
|
||||||
request = APIRequestFactory().get('/foobar')
|
|
||||||
serializer = CustomPaginationSerializer(
|
|
||||||
instance=paginator.page(),
|
|
||||||
context={'request': request}
|
|
||||||
)
|
|
||||||
expected = {
|
|
||||||
'links': {
|
|
||||||
'next': 'http://testserver/foobar?page={0}'.format(objects[2]),
|
|
||||||
'prev': None
|
|
||||||
},
|
|
||||||
'total_results': None,
|
|
||||||
'objects': objects[:2]
|
|
||||||
}
|
|
||||||
self.assertEqual(serializer.data, expected)
|
|
||||||
|
|
||||||
request = APIRequestFactory().get('/foobar')
|
|
||||||
serializer = CustomPaginationSerializer(
|
|
||||||
instance=paginator.page('george'),
|
|
||||||
context={'request': request}
|
|
||||||
)
|
|
||||||
expected = {
|
|
||||||
'links': {
|
|
||||||
'next': None,
|
|
||||||
'prev': 'http://testserver/foobar?page={0}'.format(objects[0]),
|
|
||||||
},
|
|
||||||
'total_results': None,
|
|
||||||
'objects': objects[2:]
|
|
||||||
}
|
|
||||||
self.assertEqual(serializer.data, expected)
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user