mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-22 09:36:49 +03:00
First pass at 3.1 pagination API
This commit is contained in:
parent
11efde8905
commit
73feaf6299
|
@ -2,29 +2,13 @@
|
|||
Generic views that provide commonly needed behaviour.
|
||||
"""
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from django.core.paginator import Paginator, InvalidPage
|
||||
from django.db.models.query import QuerySet
|
||||
from django.http import Http404
|
||||
from django.shortcuts import get_object_or_404 as _get_object_or_404
|
||||
from django.utils import six
|
||||
from django.utils.translation import ugettext as _
|
||||
from rest_framework import views, mixins
|
||||
from rest_framework.settings import api_settings
|
||||
|
||||
|
||||
def strict_positive_int(integer_string, cutoff=None):
|
||||
"""
|
||||
Cast a string to a strictly positive integer.
|
||||
"""
|
||||
ret = int(integer_string)
|
||||
if ret <= 0:
|
||||
raise ValueError()
|
||||
if cutoff:
|
||||
ret = min(ret, cutoff)
|
||||
return ret
|
||||
|
||||
|
||||
def get_object_or_404(queryset, *filter_args, **filter_kwargs):
|
||||
"""
|
||||
Same as Django's standard shortcut, but make sure to also raise 404
|
||||
|
@ -40,7 +24,6 @@ class GenericAPIView(views.APIView):
|
|||
"""
|
||||
Base class for all other generic views.
|
||||
"""
|
||||
|
||||
# You'll need to either set these attributes,
|
||||
# or override `get_queryset()`/`get_serializer_class()`.
|
||||
# If you are overriding a view method, it is important that you call
|
||||
|
@ -50,146 +33,16 @@ class GenericAPIView(views.APIView):
|
|||
queryset = None
|
||||
serializer_class = None
|
||||
|
||||
# If you want to use object lookups other than pk, set this attribute.
|
||||
# If you want to use object lookups other than pk, set 'lookup_field'.
|
||||
# For more complex lookup requirements override `get_object()`.
|
||||
lookup_field = 'pk'
|
||||
lookup_url_kwarg = None
|
||||
|
||||
# Pagination settings
|
||||
paginate_by = api_settings.PAGINATE_BY
|
||||
paginate_by_param = api_settings.PAGINATE_BY_PARAM
|
||||
max_paginate_by = api_settings.MAX_PAGINATE_BY
|
||||
pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS
|
||||
page_kwarg = 'page'
|
||||
|
||||
# The filter backend classes to use for queryset filtering
|
||||
filter_backends = api_settings.DEFAULT_FILTER_BACKENDS
|
||||
|
||||
# The following attribute may be subject to change,
|
||||
# and should be considered private API.
|
||||
paginator_class = Paginator
|
||||
|
||||
def get_serializer_context(self):
|
||||
"""
|
||||
Extra context provided to the serializer class.
|
||||
"""
|
||||
return {
|
||||
'request': self.request,
|
||||
'format': self.format_kwarg,
|
||||
'view': self
|
||||
}
|
||||
|
||||
def get_serializer(self, *args, **kwargs):
|
||||
"""
|
||||
Return the serializer instance that should be used for validating and
|
||||
deserializing input, and for serializing output.
|
||||
"""
|
||||
serializer_class = self.get_serializer_class()
|
||||
kwargs['context'] = self.get_serializer_context()
|
||||
return serializer_class(*args, **kwargs)
|
||||
|
||||
def get_pagination_serializer(self, page):
|
||||
"""
|
||||
Return a serializer instance to use with paginated data.
|
||||
"""
|
||||
class SerializerClass(self.pagination_serializer_class):
|
||||
class Meta:
|
||||
object_serializer_class = self.get_serializer_class()
|
||||
|
||||
pagination_serializer_class = SerializerClass
|
||||
context = self.get_serializer_context()
|
||||
return pagination_serializer_class(instance=page, context=context)
|
||||
|
||||
def paginate_queryset(self, queryset):
|
||||
"""
|
||||
Paginate a queryset if required, either returning a page object,
|
||||
or `None` if pagination is not configured for this view.
|
||||
"""
|
||||
page_size = self.get_paginate_by()
|
||||
if not page_size:
|
||||
return None
|
||||
|
||||
paginator = self.paginator_class(queryset, page_size)
|
||||
page_kwarg = self.kwargs.get(self.page_kwarg)
|
||||
page_query_param = self.request.query_params.get(self.page_kwarg)
|
||||
page = page_kwarg or page_query_param or 1
|
||||
try:
|
||||
page_number = paginator.validate_number(page)
|
||||
except InvalidPage:
|
||||
if page == 'last':
|
||||
page_number = paginator.num_pages
|
||||
else:
|
||||
raise Http404(_('Choose a valid page number. Page numbers must be a whole number, or must be the string "last".'))
|
||||
|
||||
try:
|
||||
page = paginator.page(page_number)
|
||||
except InvalidPage as exc:
|
||||
error_format = _('Invalid page "{page_number}": {message}.')
|
||||
raise Http404(error_format.format(
|
||||
page_number=page_number, message=six.text_type(exc)
|
||||
))
|
||||
|
||||
return page
|
||||
|
||||
def filter_queryset(self, queryset):
|
||||
"""
|
||||
Given a queryset, filter it with whichever filter backend is in use.
|
||||
|
||||
You are unlikely to want to override this method, although you may need
|
||||
to call it either from a list view, or from a custom `get_object`
|
||||
method if you want to apply the configured filtering backend to the
|
||||
default queryset.
|
||||
"""
|
||||
for backend in self.get_filter_backends():
|
||||
queryset = backend().filter_queryset(self.request, queryset, self)
|
||||
return queryset
|
||||
|
||||
def get_filter_backends(self):
|
||||
"""
|
||||
Returns the list of filter backends that this view requires.
|
||||
"""
|
||||
return list(self.filter_backends)
|
||||
|
||||
# The following methods provide default implementations
|
||||
# that you may want to override for more complex cases.
|
||||
|
||||
def get_paginate_by(self):
|
||||
"""
|
||||
Return the size of pages to use with pagination.
|
||||
|
||||
If `PAGINATE_BY_PARAM` is set it will attempt to get the page size
|
||||
from a named query parameter in the url, eg. ?page_size=100
|
||||
|
||||
Otherwise defaults to using `self.paginate_by`.
|
||||
"""
|
||||
if self.paginate_by_param:
|
||||
try:
|
||||
return strict_positive_int(
|
||||
self.request.query_params[self.paginate_by_param],
|
||||
cutoff=self.max_paginate_by
|
||||
)
|
||||
except (KeyError, ValueError):
|
||||
pass
|
||||
|
||||
return self.paginate_by
|
||||
|
||||
def get_serializer_class(self):
|
||||
"""
|
||||
Return the class to use for the serializer.
|
||||
Defaults to using `self.serializer_class`.
|
||||
|
||||
You may want to override this if you need to provide different
|
||||
serializations depending on the incoming request.
|
||||
|
||||
(Eg. admins get full serialization, others get basic serialization)
|
||||
"""
|
||||
assert self.serializer_class is not None, (
|
||||
"'%s' should either include a `serializer_class` attribute, "
|
||||
"or override the `get_serializer_class()` method."
|
||||
% self.__class__.__name__
|
||||
)
|
||||
|
||||
return self.serializer_class
|
||||
# The style to use for queryset pagination.
|
||||
pagination_class = api_settings.DEFAULT_PAGINATION_CLASS
|
||||
|
||||
def get_queryset(self):
|
||||
"""
|
||||
|
@ -246,6 +99,73 @@ class GenericAPIView(views.APIView):
|
|||
|
||||
return obj
|
||||
|
||||
def get_serializer(self, *args, **kwargs):
|
||||
"""
|
||||
Return the serializer instance that should be used for validating and
|
||||
deserializing input, and for serializing output.
|
||||
"""
|
||||
serializer_class = self.get_serializer_class()
|
||||
kwargs['context'] = self.get_serializer_context()
|
||||
return serializer_class(*args, **kwargs)
|
||||
|
||||
def get_serializer_class(self):
|
||||
"""
|
||||
Return the class to use for the serializer.
|
||||
Defaults to using `self.serializer_class`.
|
||||
|
||||
You may want to override this if you need to provide different
|
||||
serializations depending on the incoming request.
|
||||
|
||||
(Eg. admins get full serialization, others get basic serialization)
|
||||
"""
|
||||
assert self.serializer_class is not None, (
|
||||
"'%s' should either include a `serializer_class` attribute, "
|
||||
"or override the `get_serializer_class()` method."
|
||||
% self.__class__.__name__
|
||||
)
|
||||
|
||||
return self.serializer_class
|
||||
|
||||
def get_serializer_context(self):
|
||||
"""
|
||||
Extra context provided to the serializer class.
|
||||
"""
|
||||
return {
|
||||
'request': self.request,
|
||||
'format': self.format_kwarg,
|
||||
'view': self
|
||||
}
|
||||
|
||||
def filter_queryset(self, queryset):
|
||||
"""
|
||||
Given a queryset, filter it with whichever filter backend is in use.
|
||||
|
||||
You are unlikely to want to override this method, although you may need
|
||||
to call it either from a list view, or from a custom `get_object`
|
||||
method if you want to apply the configured filtering backend to the
|
||||
default queryset.
|
||||
"""
|
||||
for backend in list(self.filter_backends):
|
||||
queryset = backend().filter_queryset(self.request, queryset, self)
|
||||
return queryset
|
||||
|
||||
@property
|
||||
def pager(self):
|
||||
if not hasattr(self, '_pager'):
|
||||
if self.pagination_class is None:
|
||||
self._pager = None
|
||||
else:
|
||||
self._pager = self.pagination_class()
|
||||
return self._pager
|
||||
|
||||
def paginate_queryset(self, queryset):
|
||||
if self.pager is None:
|
||||
return None
|
||||
return self.pager.paginate_queryset(queryset, self.request, view=self)
|
||||
|
||||
def get_paginated_response(self, objects):
|
||||
return self.pager.get_paginated_response(objects)
|
||||
|
||||
|
||||
# Concrete view classes that provide method handlers
|
||||
# by composing the mixin classes with the base view.
|
||||
|
|
|
@ -5,7 +5,6 @@ We don't bind behaviour to http method handlers yet,
|
|||
which allows mixin classes to be composed in interesting ways.
|
||||
"""
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from rest_framework import status
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.settings import api_settings
|
||||
|
@ -37,12 +36,14 @@ class ListModelMixin(object):
|
|||
List a queryset.
|
||||
"""
|
||||
def list(self, request, *args, **kwargs):
|
||||
instance = self.filter_queryset(self.get_queryset())
|
||||
page = self.paginate_queryset(instance)
|
||||
queryset = self.filter_queryset(self.get_queryset())
|
||||
|
||||
page = self.paginate_queryset(queryset)
|
||||
if page is not None:
|
||||
serializer = self.get_pagination_serializer(page)
|
||||
else:
|
||||
serializer = self.get_serializer(instance, many=True)
|
||||
serializer = self.get_serializer(page, many=True)
|
||||
return self.get_paginated_response(serializer.data)
|
||||
|
||||
serializer = self.get_serializer(queryset, many=True)
|
||||
return Response(serializer.data)
|
||||
|
||||
|
||||
|
|
|
@ -3,87 +3,192 @@ Pagination serializers determine the structure of the output that should
|
|||
be used for paginated responses.
|
||||
"""
|
||||
from __future__ import unicode_literals
|
||||
from rest_framework import serializers
|
||||
from django.core.paginator import InvalidPage, Paginator as DjangoPaginator
|
||||
from django.utils import six
|
||||
from django.utils.translation import ugettext as _
|
||||
from rest_framework.compat import OrderedDict
|
||||
from rest_framework.exceptions import NotFound
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.templatetags.rest_framework import replace_query_param
|
||||
|
||||
|
||||
class NextPageField(serializers.Field):
|
||||
def _strict_positive_int(integer_string, cutoff=None):
|
||||
"""
|
||||
Field that returns a link to the next page in paginated results.
|
||||
Cast a string to a strictly positive integer.
|
||||
"""
|
||||
page_field = 'page'
|
||||
|
||||
def to_representation(self, value):
|
||||
if not value.has_next():
|
||||
return None
|
||||
page = value.next_page_number()
|
||||
request = self.context.get('request')
|
||||
url = request and request.build_absolute_uri() or ''
|
||||
return replace_query_param(url, self.page_field, page)
|
||||
ret = int(integer_string)
|
||||
if ret <= 0:
|
||||
raise ValueError()
|
||||
if cutoff:
|
||||
ret = min(ret, cutoff)
|
||||
return ret
|
||||
|
||||
|
||||
class PreviousPageField(serializers.Field):
|
||||
class BasePagination(object):
|
||||
def paginate_queryset(self, queryset, request):
|
||||
raise NotImplemented('paginate_queryset() must be implemented.')
|
||||
|
||||
def get_paginated_response(self, data, page, request):
|
||||
raise NotImplemented('get_paginated_response() must be implemented.')
|
||||
|
||||
|
||||
class PageNumberPagination(BasePagination):
|
||||
"""
|
||||
Field that returns a link to the previous page in paginated results.
|
||||
A simple page number based style that supports page numbers as
|
||||
query parameters. For example:
|
||||
|
||||
http://api.example.org/accounts/?page=4
|
||||
http://api.example.org/accounts/?page=4&page_size=100
|
||||
"""
|
||||
page_field = 'page'
|
||||
# The default page size.
|
||||
# Defaults to `None`, meaning pagination is disabled.
|
||||
paginate_by = api_settings.PAGINATE_BY
|
||||
|
||||
def to_representation(self, value):
|
||||
if not value.has_previous():
|
||||
return None
|
||||
page = value.previous_page_number()
|
||||
request = self.context.get('request')
|
||||
url = request and request.build_absolute_uri() or ''
|
||||
return replace_query_param(url, self.page_field, page)
|
||||
# Client can control the page using this query parameter.
|
||||
page_query_param = 'page'
|
||||
|
||||
# Client can control the page size using this query parameter.
|
||||
# Default is 'None'. Set to eg 'page_size' to enable usage.
|
||||
paginate_by_param = api_settings.PAGINATE_BY_PARAM
|
||||
|
||||
class DefaultObjectSerializer(serializers.ReadOnlyField):
|
||||
"""
|
||||
If no object serializer is specified, then this serializer will be applied
|
||||
as the default.
|
||||
"""
|
||||
# Set to an integer to limit the maximum page size the client may request.
|
||||
# Only relevant if 'paginate_by_param' has also been set.
|
||||
max_paginate_by = api_settings.MAX_PAGINATE_BY
|
||||
|
||||
def __init__(self, source=None, many=None, context=None):
|
||||
# Note: Swallow context and many kwargs - only required for
|
||||
# eg. ModelSerializer.
|
||||
super(DefaultObjectSerializer, self).__init__(source=source)
|
||||
|
||||
|
||||
class BasePaginationSerializer(serializers.Serializer):
|
||||
"""
|
||||
A base class for pagination serializers to inherit from,
|
||||
to make implementing custom serializers more easy.
|
||||
"""
|
||||
results_field = 'results'
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def paginate_queryset(self, queryset, request, view):
|
||||
"""
|
||||
Override init to add in the object serializer field on-the-fly.
|
||||
Paginate a queryset if required, either returning a page object,
|
||||
or `None` if pagination is not configured for this view.
|
||||
"""
|
||||
super(BasePaginationSerializer, self).__init__(*args, **kwargs)
|
||||
results_field = self.results_field
|
||||
for attr in (
|
||||
'paginate_by', 'page_query_param',
|
||||
'paginate_by_param', 'max_paginate_by'
|
||||
):
|
||||
if hasattr(view, attr):
|
||||
setattr(self, attr, getattr(view, attr))
|
||||
|
||||
page_size = self.get_page_size(request)
|
||||
if not page_size:
|
||||
return None
|
||||
|
||||
paginator = DjangoPaginator(queryset, page_size)
|
||||
page_string = request.query_params.get(self.page_query_param, 1)
|
||||
try:
|
||||
page_number = paginator.validate_number(page_string)
|
||||
except InvalidPage:
|
||||
if page_string == 'last':
|
||||
page_number = paginator.num_pages
|
||||
else:
|
||||
msg = _(
|
||||
'Choose a valid page number. Page numbers must be a '
|
||||
'whole number, or must be the string "last".'
|
||||
)
|
||||
raise NotFound(msg)
|
||||
|
||||
try:
|
||||
object_serializer = self.Meta.object_serializer_class
|
||||
except AttributeError:
|
||||
object_serializer = DefaultObjectSerializer
|
||||
self.page = paginator.page(page_number)
|
||||
except InvalidPage as exc:
|
||||
msg = _('Invalid page "{page_number}": {message}.').format(
|
||||
page_number=page_number, message=six.text_type(exc)
|
||||
)
|
||||
raise NotFound(msg)
|
||||
|
||||
self.request = request
|
||||
return self.page
|
||||
|
||||
def get_paginated_response(self, objects):
|
||||
return Response(OrderedDict([
|
||||
('count', self.page.paginator.count),
|
||||
('next', self.get_next_link()),
|
||||
('previous', self.get_previous_link()),
|
||||
('results', objects)
|
||||
]))
|
||||
|
||||
def get_page_size(self, request):
|
||||
if self.paginate_by_param:
|
||||
try:
|
||||
return _strict_positive_int(
|
||||
request.query_params[self.paginate_by_param],
|
||||
cutoff=self.max_paginate_by
|
||||
)
|
||||
except (KeyError, ValueError):
|
||||
pass
|
||||
|
||||
return self.paginate_by
|
||||
|
||||
def get_next_link(self):
|
||||
if not self.page.has_next():
|
||||
return None
|
||||
url = self.request.build_absolute_uri()
|
||||
page_number = self.page.next_page_number()
|
||||
return replace_query_param(url, self.page_query_param, page_number)
|
||||
|
||||
def get_previous_link(self):
|
||||
if not self.page.has_previous():
|
||||
return None
|
||||
url = self.request.build_absolute_uri()
|
||||
page_number = self.page.previous_page_number()
|
||||
return replace_query_param(url, self.page_query_param, page_number)
|
||||
|
||||
|
||||
class LimitOffsetPagination(BasePagination):
|
||||
"""
|
||||
A limit/offset based style. For example:
|
||||
|
||||
http://api.example.org/accounts/?limit=100
|
||||
http://api.example.org/accounts/?offset=400&limit=100
|
||||
"""
|
||||
default_limit = api_settings.PAGINATE_BY
|
||||
limit_query_param = 'limit'
|
||||
offset_query_param = 'offset'
|
||||
max_limit = None
|
||||
|
||||
def paginate_queryset(self, queryset, request, view):
|
||||
self.limit = self.get_limit(request)
|
||||
self.offset = self.get_offset(request)
|
||||
self.count = queryset.count()
|
||||
self.request = request
|
||||
return queryset[self.offset:self.offset + self.limit]
|
||||
|
||||
def get_paginated_response(self, objects):
|
||||
return Response(OrderedDict([
|
||||
('count', self.count),
|
||||
('next', self.get_next_link()),
|
||||
('previous', self.get_previous_link()),
|
||||
('results', objects)
|
||||
]))
|
||||
|
||||
def get_limit(self, request):
|
||||
if self.limit_query_param:
|
||||
try:
|
||||
return _strict_positive_int(
|
||||
request.query_params[self.limit_query_param],
|
||||
cutoff=self.max_limit
|
||||
)
|
||||
except (KeyError, ValueError):
|
||||
pass
|
||||
|
||||
return self.default_limit
|
||||
|
||||
def get_offset(self, request):
|
||||
try:
|
||||
list_serializer_class = object_serializer.Meta.list_serializer_class
|
||||
except AttributeError:
|
||||
list_serializer_class = serializers.ListSerializer
|
||||
return _strict_positive_int(
|
||||
request.query_params[self.offset_query_param],
|
||||
)
|
||||
except (KeyError, ValueError):
|
||||
return 0
|
||||
|
||||
self.fields[results_field] = list_serializer_class(
|
||||
child=object_serializer(),
|
||||
source='object_list'
|
||||
)
|
||||
self.fields[results_field].bind(field_name=results_field, parent=self)
|
||||
def get_next_link(self, page):
|
||||
if self.offset + self.limit >= self.count:
|
||||
return None
|
||||
url = self.request.build_absolute_uri()
|
||||
offset = self.offset + self.limit
|
||||
return replace_query_param(url, self.offset_query_param, offset)
|
||||
|
||||
|
||||
class PaginationSerializer(BasePaginationSerializer):
|
||||
"""
|
||||
A default implementation of a pagination serializer.
|
||||
"""
|
||||
count = serializers.ReadOnlyField(source='paginator.count')
|
||||
next = NextPageField(source='*')
|
||||
previous = PreviousPageField(source='*')
|
||||
def get_previous_link(self, page):
|
||||
if self.offset - self.limit < 0:
|
||||
return None
|
||||
url = self.request.build_absolute_uri()
|
||||
offset = self.offset - self.limit
|
||||
return replace_query_param(url, self.offset_query_param, offset)
|
||||
|
|
|
@ -49,7 +49,7 @@ DEFAULTS = {
|
|||
'DEFAULT_VERSIONING_CLASS': None,
|
||||
|
||||
# Generic view behavior
|
||||
'DEFAULT_PAGINATION_SERIALIZER_CLASS': 'rest_framework.pagination.PaginationSerializer',
|
||||
'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination',
|
||||
'DEFAULT_FILTER_BACKENDS': (),
|
||||
|
||||
# Throttling
|
||||
|
@ -130,7 +130,7 @@ IMPORT_STRINGS = (
|
|||
'DEFAULT_CONTENT_NEGOTIATION_CLASS',
|
||||
'DEFAULT_METADATA_CLASS',
|
||||
'DEFAULT_VERSIONING_CLASS',
|
||||
'DEFAULT_PAGINATION_SERIALIZER_CLASS',
|
||||
'DEFAULT_PAGINATION_CLASS',
|
||||
'DEFAULT_FILTER_BACKENDS',
|
||||
'EXCEPTION_HANDLER',
|
||||
'TEST_REQUEST_RENDERER_CLASSES',
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
from __future__ import unicode_literals
|
||||
import datetime
|
||||
from decimal import Decimal
|
||||
from django.core.paginator import Paginator
|
||||
from django.test import TestCase
|
||||
from django.utils import unittest
|
||||
from rest_framework import generics, serializers, status, pagination, filters
|
||||
from rest_framework import generics, serializers, status, filters
|
||||
from rest_framework.compat import django_filters
|
||||
from rest_framework.test import APIRequestFactory
|
||||
from .models import BasicModel, FilterableItem
|
||||
|
@ -238,45 +237,6 @@ class IntegrationTestPaginationAndFiltering(TestCase):
|
|||
self.assertEqual(response.data['previous'], None)
|
||||
|
||||
|
||||
class PassOnContextPaginationSerializer(pagination.PaginationSerializer):
|
||||
class Meta:
|
||||
object_serializer_class = serializers.Serializer
|
||||
|
||||
|
||||
class UnitTestPagination(TestCase):
|
||||
"""
|
||||
Unit tests for pagination of primitive objects.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.objects = [char * 3 for char in 'abcdefghijklmnopqrstuvwxyz']
|
||||
paginator = Paginator(self.objects, 10)
|
||||
self.first_page = paginator.page(1)
|
||||
self.last_page = paginator.page(3)
|
||||
|
||||
def test_native_pagination(self):
|
||||
serializer = pagination.PaginationSerializer(self.first_page)
|
||||
self.assertEqual(serializer.data['count'], 26)
|
||||
self.assertEqual(serializer.data['next'], '?page=2')
|
||||
self.assertEqual(serializer.data['previous'], None)
|
||||
self.assertEqual(serializer.data['results'], self.objects[:10])
|
||||
|
||||
serializer = pagination.PaginationSerializer(self.last_page)
|
||||
self.assertEqual(serializer.data['count'], 26)
|
||||
self.assertEqual(serializer.data['next'], None)
|
||||
self.assertEqual(serializer.data['previous'], '?page=2')
|
||||
self.assertEqual(serializer.data['results'], self.objects[20:])
|
||||
|
||||
def test_context_available_in_result(self):
|
||||
"""
|
||||
Ensure context gets passed through to the object serializer.
|
||||
"""
|
||||
serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'})
|
||||
serializer.data
|
||||
results = serializer.fields[serializer.results_field]
|
||||
self.assertEqual(serializer.context, results.context)
|
||||
|
||||
|
||||
class TestUnpaginated(TestCase):
|
||||
"""
|
||||
Tests for list views without pagination.
|
||||
|
@ -377,177 +337,3 @@ class TestMaxPaginateByParam(TestCase):
|
|||
request = factory.get('/')
|
||||
response = self.view(request).render()
|
||||
self.assertEqual(response.data['results'], self.data[:3])
|
||||
|
||||
|
||||
# Tests for context in pagination serializers
|
||||
|
||||
class CustomField(serializers.ReadOnlyField):
|
||||
def to_native(self, value):
|
||||
if 'view' not in self.context:
|
||||
raise RuntimeError("context isn't getting passed into custom field")
|
||||
return "value"
|
||||
|
||||
|
||||
class BasicModelSerializer(serializers.Serializer):
|
||||
text = CustomField()
|
||||
|
||||
def to_native(self, value):
|
||||
if 'view' not in self.context:
|
||||
raise RuntimeError("context isn't getting passed into serializer")
|
||||
return super(BasicSerializer, self).to_native(value)
|
||||
|
||||
|
||||
class TestContextPassedToCustomField(TestCase):
|
||||
def setUp(self):
|
||||
BasicModel.objects.create(text='ala ma kota')
|
||||
|
||||
def test_with_pagination(self):
|
||||
class ListView(generics.ListCreateAPIView):
|
||||
queryset = BasicModel.objects.all()
|
||||
serializer_class = BasicModelSerializer
|
||||
paginate_by = 1
|
||||
|
||||
self.view = ListView.as_view()
|
||||
request = factory.get('/')
|
||||
response = self.view(request).render()
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
|
||||
# Tests for custom pagination serializers
|
||||
|
||||
class LinksSerializer(serializers.Serializer):
|
||||
next = pagination.NextPageField(source='*')
|
||||
prev = pagination.PreviousPageField(source='*')
|
||||
|
||||
|
||||
class CustomPaginationSerializer(pagination.BasePaginationSerializer):
|
||||
links = LinksSerializer(source='*') # Takes the page object as the source
|
||||
total_results = serializers.ReadOnlyField(source='paginator.count')
|
||||
|
||||
results_field = 'objects'
|
||||
|
||||
|
||||
class CustomFooSerializer(serializers.Serializer):
|
||||
foo = serializers.CharField()
|
||||
|
||||
|
||||
class CustomFooPaginationSerializer(pagination.PaginationSerializer):
|
||||
class Meta:
|
||||
object_serializer_class = CustomFooSerializer
|
||||
|
||||
|
||||
class TestCustomPaginationSerializer(TestCase):
|
||||
def setUp(self):
|
||||
objects = ['john', 'paul', 'george', 'ringo']
|
||||
paginator = Paginator(objects, 2)
|
||||
self.page = paginator.page(1)
|
||||
|
||||
def test_custom_pagination_serializer(self):
|
||||
request = APIRequestFactory().get('/foobar')
|
||||
serializer = CustomPaginationSerializer(
|
||||
instance=self.page,
|
||||
context={'request': request}
|
||||
)
|
||||
expected = {
|
||||
'links': {
|
||||
'next': 'http://testserver/foobar?page=2',
|
||||
'prev': None
|
||||
},
|
||||
'total_results': 4,
|
||||
'objects': ['john', 'paul']
|
||||
}
|
||||
self.assertEqual(serializer.data, expected)
|
||||
|
||||
def test_custom_pagination_serializer_with_custom_object_serializer(self):
|
||||
objects = [
|
||||
{'foo': 'bar'},
|
||||
{'foo': 'spam'}
|
||||
]
|
||||
paginator = Paginator(objects, 1)
|
||||
page = paginator.page(1)
|
||||
serializer = CustomFooPaginationSerializer(page)
|
||||
serializer.data
|
||||
|
||||
|
||||
class NonIntegerPage(object):
|
||||
|
||||
def __init__(self, paginator, object_list, prev_token, token, next_token):
|
||||
self.paginator = paginator
|
||||
self.object_list = object_list
|
||||
self.prev_token = prev_token
|
||||
self.token = token
|
||||
self.next_token = next_token
|
||||
|
||||
def has_next(self):
|
||||
return not not self.next_token
|
||||
|
||||
def next_page_number(self):
|
||||
return self.next_token
|
||||
|
||||
def has_previous(self):
|
||||
return not not self.prev_token
|
||||
|
||||
def previous_page_number(self):
|
||||
return self.prev_token
|
||||
|
||||
|
||||
class NonIntegerPaginator(object):
|
||||
|
||||
def __init__(self, object_list, per_page):
|
||||
self.object_list = object_list
|
||||
self.per_page = per_page
|
||||
|
||||
def count(self):
|
||||
# pretend like we don't know how many pages we have
|
||||
return None
|
||||
|
||||
def page(self, token=None):
|
||||
if token:
|
||||
try:
|
||||
first = self.object_list.index(token)
|
||||
except ValueError:
|
||||
first = 0
|
||||
else:
|
||||
first = 0
|
||||
n = len(self.object_list)
|
||||
last = min(first + self.per_page, n)
|
||||
prev_token = self.object_list[last - (2 * self.per_page)] if first else None
|
||||
next_token = self.object_list[last] if last < n else None
|
||||
return NonIntegerPage(self, self.object_list[first:last], prev_token, token, next_token)
|
||||
|
||||
|
||||
class TestNonIntegerPagination(TestCase):
|
||||
def test_custom_pagination_serializer(self):
|
||||
objects = ['john', 'paul', 'george', 'ringo']
|
||||
paginator = NonIntegerPaginator(objects, 2)
|
||||
|
||||
request = APIRequestFactory().get('/foobar')
|
||||
serializer = CustomPaginationSerializer(
|
||||
instance=paginator.page(),
|
||||
context={'request': request}
|
||||
)
|
||||
expected = {
|
||||
'links': {
|
||||
'next': 'http://testserver/foobar?page={0}'.format(objects[2]),
|
||||
'prev': None
|
||||
},
|
||||
'total_results': None,
|
||||
'objects': objects[:2]
|
||||
}
|
||||
self.assertEqual(serializer.data, expected)
|
||||
|
||||
request = APIRequestFactory().get('/foobar')
|
||||
serializer = CustomPaginationSerializer(
|
||||
instance=paginator.page('george'),
|
||||
context={'request': request}
|
||||
)
|
||||
expected = {
|
||||
'links': {
|
||||
'next': None,
|
||||
'prev': 'http://testserver/foobar?page={0}'.format(objects[0]),
|
||||
},
|
||||
'total_results': None,
|
||||
'objects': objects[2:]
|
||||
}
|
||||
self.assertEqual(serializer.data, expected)
|
||||
|
|
Loading…
Reference in New Issue
Block a user