mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-08-02 11:30:12 +03:00
Merge 986581d3a7
into f18158358d
This commit is contained in:
commit
dac81ded1a
|
@ -11,20 +11,10 @@ from django.utils.translation import ugettext as _
|
|||
from rest_framework import views, mixins, exceptions
|
||||
from rest_framework.request import clone_request
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.pagination import strict_positive_int
|
||||
import warnings
|
||||
|
||||
|
||||
def strict_positive_int(integer_string, cutoff=None):
|
||||
"""
|
||||
Cast a string to a strictly positive integer.
|
||||
"""
|
||||
ret = int(integer_string)
|
||||
if ret <= 0:
|
||||
raise ValueError()
|
||||
if cutoff:
|
||||
ret = min(ret, cutoff)
|
||||
return ret
|
||||
|
||||
def get_object_or_404(queryset, *filter_args, **filter_kwargs):
|
||||
"""
|
||||
Same as Django's standard shortcut, but make sure to raise 404
|
||||
|
|
|
@ -10,6 +10,7 @@ from django.http import Http404
|
|||
from rest_framework import status
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.request import clone_request
|
||||
from rest_framework import pagination
|
||||
import warnings
|
||||
|
||||
|
||||
|
@ -187,3 +188,35 @@ class DestroyModelMixin(object):
|
|||
obj = self.get_object()
|
||||
obj.delete()
|
||||
return Response(status=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
|
||||
class OffsetLimitPaginationMixin(object):
|
||||
offset_kwarg = 'offset'
|
||||
paginate_by_param = 'limit'
|
||||
pagination_serializer_class = pagination.OffsetLimitPaginationSerializer
|
||||
|
||||
def paginate_queryset(self, queryset):
|
||||
limit = self.get_paginate_by()
|
||||
if not limit:
|
||||
return # pagination not configured
|
||||
offset_kwarg = self.kwargs.get(self.offset_kwarg)
|
||||
offset_query_param = self.request.QUERY_PARAMS.get(self.offset_kwarg)
|
||||
offset = offset_kwarg or offset_query_param or 0
|
||||
try:
|
||||
offset_number = pagination.strict_positive_int(offset)
|
||||
except ValueError:
|
||||
offset_number = 0
|
||||
return pagination.OffsetLimitPage(queryset, offset_number, limit)
|
||||
|
||||
|
||||
class LinkPaginationMixin(object):
|
||||
pagination_serializer_class = pagination.LinkPaginationSerializer
|
||||
|
||||
def paginate_queryset(self, queryset, page_size=None):
|
||||
page = super(LinkPaginationMixin, self).paginate_queryset(
|
||||
queryset, page_size)
|
||||
if page is not None:
|
||||
page_ser = self.get_pagination_serializer(page)
|
||||
self.headers.update(page_ser.get_link_header())
|
||||
self.object_list = page.object_list
|
||||
return None # Don't use pagination serializer on response
|
||||
|
|
|
@ -7,6 +7,18 @@ from rest_framework import serializers
|
|||
from rest_framework.templatetags.rest_framework import replace_query_param
|
||||
|
||||
|
||||
def strict_positive_int(integer_string, cutoff=None):
|
||||
"""
|
||||
Cast a string to a strictly positive integer.
|
||||
"""
|
||||
ret = int(integer_string)
|
||||
if ret <= 0:
|
||||
raise ValueError()
|
||||
if cutoff:
|
||||
ret = min(ret, cutoff)
|
||||
return ret
|
||||
|
||||
|
||||
class NextPageField(serializers.Field):
|
||||
"""
|
||||
Field that returns a link to the next page in paginated results.
|
||||
|
@ -37,6 +49,35 @@ class PreviousPageField(serializers.Field):
|
|||
return replace_query_param(url, self.page_field, page)
|
||||
|
||||
|
||||
class FirstPageField(serializers.Field):
|
||||
"""
|
||||
Field that returns a link to the first page in paginated results.
|
||||
"""
|
||||
page_field = 'page'
|
||||
|
||||
def to_native(self, value):
|
||||
if not value.has_previous():
|
||||
return None
|
||||
request = self.context.get('request')
|
||||
url = request and request.build_absolute_uri() or ''
|
||||
return replace_query_param(url, self.page_field, 1)
|
||||
|
||||
|
||||
class LastPageField(serializers.Field):
|
||||
"""
|
||||
Field that returns a link to the previous page in paginated results.
|
||||
"""
|
||||
page_field = 'page'
|
||||
|
||||
def to_native(self, value):
|
||||
if not value.has_next():
|
||||
return None
|
||||
page = value.paginator.num_pages
|
||||
request = self.context.get('request')
|
||||
url = request and request.build_absolute_uri() or ''
|
||||
return replace_query_param(url, self.page_field, page)
|
||||
|
||||
|
||||
class DefaultObjectSerializer(serializers.Field):
|
||||
"""
|
||||
If no object serializer is specified, then this serializer will be applied
|
||||
|
@ -82,7 +123,8 @@ class BasePaginationSerializer(serializers.Serializer):
|
|||
else:
|
||||
context_kwarg = {}
|
||||
|
||||
self.fields[results_field] = object_serializer(source='object_list', **context_kwarg)
|
||||
self.fields[results_field] = object_serializer(source='object_list',
|
||||
**context_kwarg)
|
||||
|
||||
|
||||
class PaginationSerializer(BasePaginationSerializer):
|
||||
|
@ -92,3 +134,42 @@ class PaginationSerializer(BasePaginationSerializer):
|
|||
count = serializers.Field(source='paginator.count')
|
||||
next = NextPageField(source='*')
|
||||
previous = PreviousPageField(source='*')
|
||||
|
||||
|
||||
class LinkPaginationSerializer(serializers.Serializer):
|
||||
""" Pagination serializer in order to build Link header """
|
||||
first = FirstPageField(source='*')
|
||||
next = NextPageField(source='*')
|
||||
previous = PreviousPageField(source='*')
|
||||
last = LastPageField(source='*')
|
||||
|
||||
def get_link_header(self):
|
||||
link_keader_items = [
|
||||
'<%s>; rel="%s"' % (link, rel)
|
||||
for rel, link in self.data.items()
|
||||
if link is not None
|
||||
]
|
||||
return {'Link': ', '.join(link_keader_items)}
|
||||
|
||||
|
||||
class OffsetLimitPage(object):
|
||||
"""
|
||||
A base class to allow offset and limit when listing a queryset.
|
||||
"""
|
||||
def __init__(self, queryset, offset, limit):
|
||||
self.count = self._set_count(queryset)
|
||||
self.object_list = queryset[offset:offset + limit]
|
||||
|
||||
def _set_count(self, queryset):
|
||||
try:
|
||||
return queryset.count()
|
||||
except (AttributeError, TypeError):
|
||||
# AttributeError if object_list has no count() method.
|
||||
# TypeError if object_list.count() requires arguments
|
||||
# (i.e. is of type list).
|
||||
return len(queryset)
|
||||
|
||||
|
||||
class OffsetLimitPaginationSerializer(BasePaginationSerializer):
|
||||
""" OffsetLimitPage serializer """
|
||||
count = serializers.Field()
|
||||
|
|
|
@ -5,7 +5,8 @@ from django.db import models
|
|||
from django.core.paginator import Paginator
|
||||
from django.test import TestCase
|
||||
from django.utils import unittest
|
||||
from rest_framework import generics, status, pagination, filters, serializers
|
||||
from rest_framework import (generics, status, pagination, filters, serializers,
|
||||
mixins)
|
||||
from rest_framework.compat import django_filters
|
||||
from rest_framework.test import APIRequestFactory
|
||||
from rest_framework.tests.models import BasicModel
|
||||
|
@ -13,6 +14,48 @@ from rest_framework.tests.models import BasicModel
|
|||
factory = APIRequestFactory()
|
||||
|
||||
|
||||
def parse_header_link(value):
|
||||
# Got it from python-requests code
|
||||
# https://github.com/kennethreitz/requests/blob/master/requests/utils.py#L467
|
||||
def parse_header_links(value):
|
||||
"""Return a dict of parsed link headers proxies.
|
||||
|
||||
i.e. Link: <http:/.../front.jpeg>; rel=front; type="image/jpeg",<http://.../back.jpeg>; rel=back;type="image/jpeg"
|
||||
|
||||
"""
|
||||
|
||||
links = []
|
||||
|
||||
replace_chars = " '\""
|
||||
|
||||
for val in value.split(","):
|
||||
try:
|
||||
url, params = val.split(";", 1)
|
||||
except ValueError:
|
||||
url, params = val, ''
|
||||
|
||||
link = {}
|
||||
|
||||
link["url"] = url.strip("<> '\"")
|
||||
|
||||
for param in params.split(";"):
|
||||
try:
|
||||
key, value = param.split("=")
|
||||
except ValueError:
|
||||
break
|
||||
|
||||
link[key.strip(replace_chars)] = value.strip(replace_chars)
|
||||
|
||||
links.append(link)
|
||||
|
||||
return links
|
||||
links = parse_header_links(value)
|
||||
link_map = {}
|
||||
for link in links:
|
||||
link_map[link['rel']] = link['url']
|
||||
return link_map
|
||||
|
||||
|
||||
class FilterableItem(models.Model):
|
||||
text = models.CharField(max_length=100)
|
||||
decimal = models.DecimalField(max_digits=4, decimal_places=2)
|
||||
|
@ -52,6 +95,18 @@ class MaxPaginateByView(generics.ListAPIView):
|
|||
paginate_by_param = 'page_size'
|
||||
|
||||
|
||||
class LinkPaginationView(mixins.LinkPaginationMixin,
|
||||
generics.ListCreateAPIView):
|
||||
model = BasicModel
|
||||
paginate_by = 10
|
||||
|
||||
|
||||
class OffsetLimitPaginationView(mixins.OffsetLimitPaginationMixin,
|
||||
generics.ListCreateAPIView):
|
||||
model = BasicModel
|
||||
paginate_by = 10
|
||||
|
||||
|
||||
class IntegrationTestPagination(TestCase):
|
||||
"""
|
||||
Integration tests for paginated list views.
|
||||
|
@ -103,6 +158,102 @@ class IntegrationTestPagination(TestCase):
|
|||
self.assertNotEqual(response.data['previous'], None)
|
||||
|
||||
|
||||
class IntegrationTestLinkPagination(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Create 26 BasicModel instances.
|
||||
"""
|
||||
for char in 'abcdefghijklmnopqrstuvwxyz':
|
||||
BasicModel(text=char * 3).save()
|
||||
self.objects = BasicModel.objects
|
||||
self.data = [
|
||||
{'id': obj.id, 'text': obj.text}
|
||||
for obj in self.objects.all()
|
||||
]
|
||||
self.view = LinkPaginationView.as_view()
|
||||
|
||||
def test_get_paginated_root_view(self):
|
||||
"""
|
||||
GET requests to paginated ListCreateAPIView should return paginated results.
|
||||
"""
|
||||
request = factory.get('/')
|
||||
# Note: Database queries are a `SELECT COUNT`, and `SELECT <fields>`
|
||||
with self.assertNumQueries(2):
|
||||
response = self.view(request).render()
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data, self.data[0:10])
|
||||
link_header = parse_header_link(response['Link'])
|
||||
self.assertTrue('next' in link_header)
|
||||
self.assertTrue('last' in link_header)
|
||||
self.assertFalse('previous' in link_header)
|
||||
self.assertFalse('first' in link_header)
|
||||
|
||||
request = factory.get(link_header['next'])
|
||||
with self.assertNumQueries(2):
|
||||
response = self.view(request).render()
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data, self.data[10:20])
|
||||
link_header = parse_header_link(response['Link'])
|
||||
self.assertTrue('next' in link_header)
|
||||
self.assertTrue('last' in link_header)
|
||||
self.assertTrue('previous' in link_header)
|
||||
self.assertTrue('first' in link_header)
|
||||
|
||||
request = factory.get(link_header['last'])
|
||||
with self.assertNumQueries(2):
|
||||
response = self.view(request).render()
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data, self.data[20:])
|
||||
link_header = parse_header_link(response['Link'])
|
||||
self.assertFalse('next' in link_header)
|
||||
self.assertFalse('last' in link_header)
|
||||
self.assertTrue('previous' in link_header)
|
||||
self.assertTrue('first' in link_header)
|
||||
|
||||
|
||||
class IntegrationTestOffsetLimitPagination(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Create 26 BasicModel instances.
|
||||
"""
|
||||
for char in 'abcdefghijklmnopqrstuvwxyz':
|
||||
BasicModel(text=char * 3).save()
|
||||
self.objects = BasicModel.objects
|
||||
self.data = [
|
||||
{'id': obj.id, 'text': obj.text}
|
||||
for obj in self.objects.all()
|
||||
]
|
||||
self.view = OffsetLimitPaginationView.as_view()
|
||||
|
||||
def test_get_paginated_root_view(self):
|
||||
"""
|
||||
GET requests to paginated ListCreateAPIView should return paginated results.
|
||||
"""
|
||||
request = factory.get('/')
|
||||
# Note: Database queries are a `SELECT COUNT`, and `SELECT <fields>`
|
||||
with self.assertNumQueries(2):
|
||||
response = self.view(request).render()
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data['results'], self.data[0:10])
|
||||
self.assertEqual(response.data['count'], 26)
|
||||
|
||||
request = factory.get('/', {'offset': 10})
|
||||
with self.assertNumQueries(2):
|
||||
response = self.view(request).render()
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data['results'], self.data[10:20])
|
||||
self.assertEqual(response.data['count'], 26)
|
||||
|
||||
request = factory.get('/', {'offset': 20, 'limit': 3})
|
||||
with self.assertNumQueries(2):
|
||||
response = self.view(request).render()
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data['results'], self.data[20:23])
|
||||
self.assertEqual(response.data['count'], 26)
|
||||
|
||||
|
||||
class IntegrationTestPaginationAndFiltering(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -258,6 +409,63 @@ class UnitTestPagination(TestCase):
|
|||
self.assertEqual(serializer.context, results.context)
|
||||
|
||||
|
||||
class UnitTestLinkPagination(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.objects = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
paginator = Paginator(self.objects, 2)
|
||||
self.first_page = paginator.page(1)
|
||||
self.third_page = paginator.page(3)
|
||||
self.last_page = paginator.page(5)
|
||||
|
||||
def test_native_pagination(self):
|
||||
serializer = pagination.LinkPaginationSerializer(self.first_page)
|
||||
self.assertEqual(serializer.data['next'], '?page=2')
|
||||
self.assertEqual(serializer.data['previous'], None)
|
||||
self.assertEqual(serializer.data['first'], None)
|
||||
self.assertEqual(serializer.data['last'], '?page=5')
|
||||
|
||||
serializer = pagination.LinkPaginationSerializer(self.third_page)
|
||||
self.assertEqual(serializer.data['next'], '?page=4')
|
||||
self.assertEqual(serializer.data['previous'], '?page=2')
|
||||
self.assertEqual(serializer.data['first'], '?page=1')
|
||||
self.assertEqual(serializer.data['last'], '?page=5')
|
||||
|
||||
serializer = pagination.LinkPaginationSerializer(self.last_page)
|
||||
self.assertEqual(serializer.data['next'], None)
|
||||
self.assertEqual(serializer.data['previous'], '?page=4')
|
||||
self.assertEqual(serializer.data['first'], '?page=1')
|
||||
self.assertEqual(serializer.data['last'], None)
|
||||
|
||||
|
||||
class UnitTestOffsetLimitPagination(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.objects = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
self.first_page = pagination.OffsetLimitPage(
|
||||
self.objects, offset=0, limit=2)
|
||||
self.third_page = pagination.OffsetLimitPage(
|
||||
self.objects, offset=4, limit=2)
|
||||
self.last_page = pagination.OffsetLimitPage(
|
||||
self.objects, offset=8, limit=2)
|
||||
|
||||
def test_native_pagination(self):
|
||||
serializer = pagination.OffsetLimitPaginationSerializer(
|
||||
self.first_page)
|
||||
self.assertEqual(serializer.data['count'], 10)
|
||||
self.assertEqual(serializer.data['results'], self.objects[:2])
|
||||
|
||||
serializer = pagination.OffsetLimitPaginationSerializer(
|
||||
self.third_page)
|
||||
self.assertEqual(serializer.data['count'], 10)
|
||||
self.assertEqual(serializer.data['results'], self.objects[4:6])
|
||||
|
||||
serializer = pagination.OffsetLimitPaginationSerializer(
|
||||
self.last_page)
|
||||
self.assertEqual(serializer.data['count'], 10)
|
||||
self.assertEqual(serializer.data['results'], self.objects[8:])
|
||||
|
||||
|
||||
class TestUnpaginated(TestCase):
|
||||
"""
|
||||
Tests for list views without pagination.
|
||||
|
|
Loading…
Reference in New Issue
Block a user