This commit is contained in:
David Medina 2013-10-10 09:38:39 -07:00
commit dac81ded1a
4 changed files with 325 additions and 13 deletions

View File

@ -11,20 +11,10 @@ from django.utils.translation import ugettext as _
from rest_framework import views, mixins, exceptions
from rest_framework.request import clone_request
from rest_framework.settings import api_settings
from rest_framework.pagination import strict_positive_int
import warnings
def strict_positive_int(integer_string, cutoff=None):
"""
Cast a string to a strictly positive integer.
"""
ret = int(integer_string)
if ret <= 0:
raise ValueError()
if cutoff:
ret = min(ret, cutoff)
return ret
def get_object_or_404(queryset, *filter_args, **filter_kwargs):
"""
Same as Django's standard shortcut, but make sure to raise 404

View File

@ -10,6 +10,7 @@ from django.http import Http404
from rest_framework import status
from rest_framework.response import Response
from rest_framework.request import clone_request
from rest_framework import pagination
import warnings
@ -187,3 +188,35 @@ class DestroyModelMixin(object):
obj = self.get_object()
obj.delete()
return Response(status=status.HTTP_204_NO_CONTENT)
class OffsetLimitPaginationMixin(object):
offset_kwarg = 'offset'
paginate_by_param = 'limit'
pagination_serializer_class = pagination.OffsetLimitPaginationSerializer
def paginate_queryset(self, queryset):
limit = self.get_paginate_by()
if not limit:
return # pagination not configured
offset_kwarg = self.kwargs.get(self.offset_kwarg)
offset_query_param = self.request.QUERY_PARAMS.get(self.offset_kwarg)
offset = offset_kwarg or offset_query_param or 0
try:
offset_number = pagination.strict_positive_int(offset)
except ValueError:
offset_number = 0
return pagination.OffsetLimitPage(queryset, offset_number, limit)
class LinkPaginationMixin(object):
pagination_serializer_class = pagination.LinkPaginationSerializer
def paginate_queryset(self, queryset, page_size=None):
page = super(LinkPaginationMixin, self).paginate_queryset(
queryset, page_size)
if page is not None:
page_ser = self.get_pagination_serializer(page)
self.headers.update(page_ser.get_link_header())
self.object_list = page.object_list
return None # Don't use pagination serializer on response

View File

@ -7,6 +7,18 @@ from rest_framework import serializers
from rest_framework.templatetags.rest_framework import replace_query_param
def strict_positive_int(integer_string, cutoff=None):
"""
Cast a string to a strictly positive integer.
"""
ret = int(integer_string)
if ret <= 0:
raise ValueError()
if cutoff:
ret = min(ret, cutoff)
return ret
class NextPageField(serializers.Field):
"""
Field that returns a link to the next page in paginated results.
@ -37,6 +49,35 @@ class PreviousPageField(serializers.Field):
return replace_query_param(url, self.page_field, page)
class FirstPageField(serializers.Field):
"""
Field that returns a link to the first page in paginated results.
"""
page_field = 'page'
def to_native(self, value):
if not value.has_previous():
return None
request = self.context.get('request')
url = request and request.build_absolute_uri() or ''
return replace_query_param(url, self.page_field, 1)
class LastPageField(serializers.Field):
"""
Field that returns a link to the previous page in paginated results.
"""
page_field = 'page'
def to_native(self, value):
if not value.has_next():
return None
page = value.paginator.num_pages
request = self.context.get('request')
url = request and request.build_absolute_uri() or ''
return replace_query_param(url, self.page_field, page)
class DefaultObjectSerializer(serializers.Field):
"""
If no object serializer is specified, then this serializer will be applied
@ -82,7 +123,8 @@ class BasePaginationSerializer(serializers.Serializer):
else:
context_kwarg = {}
self.fields[results_field] = object_serializer(source='object_list', **context_kwarg)
self.fields[results_field] = object_serializer(source='object_list',
**context_kwarg)
class PaginationSerializer(BasePaginationSerializer):
@ -92,3 +134,42 @@ class PaginationSerializer(BasePaginationSerializer):
count = serializers.Field(source='paginator.count')
next = NextPageField(source='*')
previous = PreviousPageField(source='*')
class LinkPaginationSerializer(serializers.Serializer):
""" Pagination serializer in order to build Link header """
first = FirstPageField(source='*')
next = NextPageField(source='*')
previous = PreviousPageField(source='*')
last = LastPageField(source='*')
def get_link_header(self):
link_keader_items = [
'<%s>; rel="%s"' % (link, rel)
for rel, link in self.data.items()
if link is not None
]
return {'Link': ', '.join(link_keader_items)}
class OffsetLimitPage(object):
"""
A base class to allow offset and limit when listing a queryset.
"""
def __init__(self, queryset, offset, limit):
self.count = self._set_count(queryset)
self.object_list = queryset[offset:offset + limit]
def _set_count(self, queryset):
try:
return queryset.count()
except (AttributeError, TypeError):
# AttributeError if object_list has no count() method.
# TypeError if object_list.count() requires arguments
# (i.e. is of type list).
return len(queryset)
class OffsetLimitPaginationSerializer(BasePaginationSerializer):
""" OffsetLimitPage serializer """
count = serializers.Field()

View File

@ -5,7 +5,8 @@ from django.db import models
from django.core.paginator import Paginator
from django.test import TestCase
from django.utils import unittest
from rest_framework import generics, status, pagination, filters, serializers
from rest_framework import (generics, status, pagination, filters, serializers,
mixins)
from rest_framework.compat import django_filters
from rest_framework.test import APIRequestFactory
from rest_framework.tests.models import BasicModel
@ -13,6 +14,48 @@ from rest_framework.tests.models import BasicModel
factory = APIRequestFactory()
def parse_header_link(value):
# Got it from python-requests code
# https://github.com/kennethreitz/requests/blob/master/requests/utils.py#L467
def parse_header_links(value):
"""Return a dict of parsed link headers proxies.
i.e. Link: <http:/.../front.jpeg>; rel=front; type="image/jpeg",<http://.../back.jpeg>; rel=back;type="image/jpeg"
"""
links = []
replace_chars = " '\""
for val in value.split(","):
try:
url, params = val.split(";", 1)
except ValueError:
url, params = val, ''
link = {}
link["url"] = url.strip("<> '\"")
for param in params.split(";"):
try:
key, value = param.split("=")
except ValueError:
break
link[key.strip(replace_chars)] = value.strip(replace_chars)
links.append(link)
return links
links = parse_header_links(value)
link_map = {}
for link in links:
link_map[link['rel']] = link['url']
return link_map
class FilterableItem(models.Model):
text = models.CharField(max_length=100)
decimal = models.DecimalField(max_digits=4, decimal_places=2)
@ -52,6 +95,18 @@ class MaxPaginateByView(generics.ListAPIView):
paginate_by_param = 'page_size'
class LinkPaginationView(mixins.LinkPaginationMixin,
generics.ListCreateAPIView):
model = BasicModel
paginate_by = 10
class OffsetLimitPaginationView(mixins.OffsetLimitPaginationMixin,
generics.ListCreateAPIView):
model = BasicModel
paginate_by = 10
class IntegrationTestPagination(TestCase):
"""
Integration tests for paginated list views.
@ -103,6 +158,102 @@ class IntegrationTestPagination(TestCase):
self.assertNotEqual(response.data['previous'], None)
class IntegrationTestLinkPagination(TestCase):
def setUp(self):
"""
Create 26 BasicModel instances.
"""
for char in 'abcdefghijklmnopqrstuvwxyz':
BasicModel(text=char * 3).save()
self.objects = BasicModel.objects
self.data = [
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
self.view = LinkPaginationView.as_view()
def test_get_paginated_root_view(self):
"""
GET requests to paginated ListCreateAPIView should return paginated results.
"""
request = factory.get('/')
# Note: Database queries are a `SELECT COUNT`, and `SELECT <fields>`
with self.assertNumQueries(2):
response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, self.data[0:10])
link_header = parse_header_link(response['Link'])
self.assertTrue('next' in link_header)
self.assertTrue('last' in link_header)
self.assertFalse('previous' in link_header)
self.assertFalse('first' in link_header)
request = factory.get(link_header['next'])
with self.assertNumQueries(2):
response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, self.data[10:20])
link_header = parse_header_link(response['Link'])
self.assertTrue('next' in link_header)
self.assertTrue('last' in link_header)
self.assertTrue('previous' in link_header)
self.assertTrue('first' in link_header)
request = factory.get(link_header['last'])
with self.assertNumQueries(2):
response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, self.data[20:])
link_header = parse_header_link(response['Link'])
self.assertFalse('next' in link_header)
self.assertFalse('last' in link_header)
self.assertTrue('previous' in link_header)
self.assertTrue('first' in link_header)
class IntegrationTestOffsetLimitPagination(TestCase):
def setUp(self):
"""
Create 26 BasicModel instances.
"""
for char in 'abcdefghijklmnopqrstuvwxyz':
BasicModel(text=char * 3).save()
self.objects = BasicModel.objects
self.data = [
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
self.view = OffsetLimitPaginationView.as_view()
def test_get_paginated_root_view(self):
"""
GET requests to paginated ListCreateAPIView should return paginated results.
"""
request = factory.get('/')
# Note: Database queries are a `SELECT COUNT`, and `SELECT <fields>`
with self.assertNumQueries(2):
response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data['results'], self.data[0:10])
self.assertEqual(response.data['count'], 26)
request = factory.get('/', {'offset': 10})
with self.assertNumQueries(2):
response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data['results'], self.data[10:20])
self.assertEqual(response.data['count'], 26)
request = factory.get('/', {'offset': 20, 'limit': 3})
with self.assertNumQueries(2):
response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data['results'], self.data[20:23])
self.assertEqual(response.data['count'], 26)
class IntegrationTestPaginationAndFiltering(TestCase):
def setUp(self):
@ -258,6 +409,63 @@ class UnitTestPagination(TestCase):
self.assertEqual(serializer.context, results.context)
class UnitTestLinkPagination(TestCase):
def setUp(self):
self.objects = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
paginator = Paginator(self.objects, 2)
self.first_page = paginator.page(1)
self.third_page = paginator.page(3)
self.last_page = paginator.page(5)
def test_native_pagination(self):
serializer = pagination.LinkPaginationSerializer(self.first_page)
self.assertEqual(serializer.data['next'], '?page=2')
self.assertEqual(serializer.data['previous'], None)
self.assertEqual(serializer.data['first'], None)
self.assertEqual(serializer.data['last'], '?page=5')
serializer = pagination.LinkPaginationSerializer(self.third_page)
self.assertEqual(serializer.data['next'], '?page=4')
self.assertEqual(serializer.data['previous'], '?page=2')
self.assertEqual(serializer.data['first'], '?page=1')
self.assertEqual(serializer.data['last'], '?page=5')
serializer = pagination.LinkPaginationSerializer(self.last_page)
self.assertEqual(serializer.data['next'], None)
self.assertEqual(serializer.data['previous'], '?page=4')
self.assertEqual(serializer.data['first'], '?page=1')
self.assertEqual(serializer.data['last'], None)
class UnitTestOffsetLimitPagination(TestCase):
def setUp(self):
self.objects = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
self.first_page = pagination.OffsetLimitPage(
self.objects, offset=0, limit=2)
self.third_page = pagination.OffsetLimitPage(
self.objects, offset=4, limit=2)
self.last_page = pagination.OffsetLimitPage(
self.objects, offset=8, limit=2)
def test_native_pagination(self):
serializer = pagination.OffsetLimitPaginationSerializer(
self.first_page)
self.assertEqual(serializer.data['count'], 10)
self.assertEqual(serializer.data['results'], self.objects[:2])
serializer = pagination.OffsetLimitPaginationSerializer(
self.third_page)
self.assertEqual(serializer.data['count'], 10)
self.assertEqual(serializer.data['results'], self.objects[4:6])
serializer = pagination.OffsetLimitPaginationSerializer(
self.last_page)
self.assertEqual(serializer.data['count'], 10)
self.assertEqual(serializer.data['results'], self.objects[8:])
class TestUnpaginated(TestCase):
"""
Tests for list views without pagination.