mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-08-02 19:40:13 +03:00
Merge 986581d3a7
into f18158358d
This commit is contained in:
commit
dac81ded1a
|
@ -11,20 +11,10 @@ from django.utils.translation import ugettext as _
|
||||||
from rest_framework import views, mixins, exceptions
|
from rest_framework import views, mixins, exceptions
|
||||||
from rest_framework.request import clone_request
|
from rest_framework.request import clone_request
|
||||||
from rest_framework.settings import api_settings
|
from rest_framework.settings import api_settings
|
||||||
|
from rest_framework.pagination import strict_positive_int
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
|
||||||
def strict_positive_int(integer_string, cutoff=None):
|
|
||||||
"""
|
|
||||||
Cast a string to a strictly positive integer.
|
|
||||||
"""
|
|
||||||
ret = int(integer_string)
|
|
||||||
if ret <= 0:
|
|
||||||
raise ValueError()
|
|
||||||
if cutoff:
|
|
||||||
ret = min(ret, cutoff)
|
|
||||||
return ret
|
|
||||||
|
|
||||||
def get_object_or_404(queryset, *filter_args, **filter_kwargs):
|
def get_object_or_404(queryset, *filter_args, **filter_kwargs):
|
||||||
"""
|
"""
|
||||||
Same as Django's standard shortcut, but make sure to raise 404
|
Same as Django's standard shortcut, but make sure to raise 404
|
||||||
|
|
|
@ -10,6 +10,7 @@ from django.http import Http404
|
||||||
from rest_framework import status
|
from rest_framework import status
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework.request import clone_request
|
from rest_framework.request import clone_request
|
||||||
|
from rest_framework import pagination
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
|
||||||
|
@ -187,3 +188,35 @@ class DestroyModelMixin(object):
|
||||||
obj = self.get_object()
|
obj = self.get_object()
|
||||||
obj.delete()
|
obj.delete()
|
||||||
return Response(status=status.HTTP_204_NO_CONTENT)
|
return Response(status=status.HTTP_204_NO_CONTENT)
|
||||||
|
|
||||||
|
|
||||||
|
class OffsetLimitPaginationMixin(object):
|
||||||
|
offset_kwarg = 'offset'
|
||||||
|
paginate_by_param = 'limit'
|
||||||
|
pagination_serializer_class = pagination.OffsetLimitPaginationSerializer
|
||||||
|
|
||||||
|
def paginate_queryset(self, queryset):
|
||||||
|
limit = self.get_paginate_by()
|
||||||
|
if not limit:
|
||||||
|
return # pagination not configured
|
||||||
|
offset_kwarg = self.kwargs.get(self.offset_kwarg)
|
||||||
|
offset_query_param = self.request.QUERY_PARAMS.get(self.offset_kwarg)
|
||||||
|
offset = offset_kwarg or offset_query_param or 0
|
||||||
|
try:
|
||||||
|
offset_number = pagination.strict_positive_int(offset)
|
||||||
|
except ValueError:
|
||||||
|
offset_number = 0
|
||||||
|
return pagination.OffsetLimitPage(queryset, offset_number, limit)
|
||||||
|
|
||||||
|
|
||||||
|
class LinkPaginationMixin(object):
|
||||||
|
pagination_serializer_class = pagination.LinkPaginationSerializer
|
||||||
|
|
||||||
|
def paginate_queryset(self, queryset, page_size=None):
|
||||||
|
page = super(LinkPaginationMixin, self).paginate_queryset(
|
||||||
|
queryset, page_size)
|
||||||
|
if page is not None:
|
||||||
|
page_ser = self.get_pagination_serializer(page)
|
||||||
|
self.headers.update(page_ser.get_link_header())
|
||||||
|
self.object_list = page.object_list
|
||||||
|
return None # Don't use pagination serializer on response
|
||||||
|
|
|
@ -7,6 +7,18 @@ from rest_framework import serializers
|
||||||
from rest_framework.templatetags.rest_framework import replace_query_param
|
from rest_framework.templatetags.rest_framework import replace_query_param
|
||||||
|
|
||||||
|
|
||||||
|
def strict_positive_int(integer_string, cutoff=None):
|
||||||
|
"""
|
||||||
|
Cast a string to a strictly positive integer.
|
||||||
|
"""
|
||||||
|
ret = int(integer_string)
|
||||||
|
if ret <= 0:
|
||||||
|
raise ValueError()
|
||||||
|
if cutoff:
|
||||||
|
ret = min(ret, cutoff)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
class NextPageField(serializers.Field):
|
class NextPageField(serializers.Field):
|
||||||
"""
|
"""
|
||||||
Field that returns a link to the next page in paginated results.
|
Field that returns a link to the next page in paginated results.
|
||||||
|
@ -37,6 +49,35 @@ class PreviousPageField(serializers.Field):
|
||||||
return replace_query_param(url, self.page_field, page)
|
return replace_query_param(url, self.page_field, page)
|
||||||
|
|
||||||
|
|
||||||
|
class FirstPageField(serializers.Field):
|
||||||
|
"""
|
||||||
|
Field that returns a link to the first page in paginated results.
|
||||||
|
"""
|
||||||
|
page_field = 'page'
|
||||||
|
|
||||||
|
def to_native(self, value):
|
||||||
|
if not value.has_previous():
|
||||||
|
return None
|
||||||
|
request = self.context.get('request')
|
||||||
|
url = request and request.build_absolute_uri() or ''
|
||||||
|
return replace_query_param(url, self.page_field, 1)
|
||||||
|
|
||||||
|
|
||||||
|
class LastPageField(serializers.Field):
|
||||||
|
"""
|
||||||
|
Field that returns a link to the previous page in paginated results.
|
||||||
|
"""
|
||||||
|
page_field = 'page'
|
||||||
|
|
||||||
|
def to_native(self, value):
|
||||||
|
if not value.has_next():
|
||||||
|
return None
|
||||||
|
page = value.paginator.num_pages
|
||||||
|
request = self.context.get('request')
|
||||||
|
url = request and request.build_absolute_uri() or ''
|
||||||
|
return replace_query_param(url, self.page_field, page)
|
||||||
|
|
||||||
|
|
||||||
class DefaultObjectSerializer(serializers.Field):
|
class DefaultObjectSerializer(serializers.Field):
|
||||||
"""
|
"""
|
||||||
If no object serializer is specified, then this serializer will be applied
|
If no object serializer is specified, then this serializer will be applied
|
||||||
|
@ -82,7 +123,8 @@ class BasePaginationSerializer(serializers.Serializer):
|
||||||
else:
|
else:
|
||||||
context_kwarg = {}
|
context_kwarg = {}
|
||||||
|
|
||||||
self.fields[results_field] = object_serializer(source='object_list', **context_kwarg)
|
self.fields[results_field] = object_serializer(source='object_list',
|
||||||
|
**context_kwarg)
|
||||||
|
|
||||||
|
|
||||||
class PaginationSerializer(BasePaginationSerializer):
|
class PaginationSerializer(BasePaginationSerializer):
|
||||||
|
@ -92,3 +134,42 @@ class PaginationSerializer(BasePaginationSerializer):
|
||||||
count = serializers.Field(source='paginator.count')
|
count = serializers.Field(source='paginator.count')
|
||||||
next = NextPageField(source='*')
|
next = NextPageField(source='*')
|
||||||
previous = PreviousPageField(source='*')
|
previous = PreviousPageField(source='*')
|
||||||
|
|
||||||
|
|
||||||
|
class LinkPaginationSerializer(serializers.Serializer):
|
||||||
|
""" Pagination serializer in order to build Link header """
|
||||||
|
first = FirstPageField(source='*')
|
||||||
|
next = NextPageField(source='*')
|
||||||
|
previous = PreviousPageField(source='*')
|
||||||
|
last = LastPageField(source='*')
|
||||||
|
|
||||||
|
def get_link_header(self):
|
||||||
|
link_keader_items = [
|
||||||
|
'<%s>; rel="%s"' % (link, rel)
|
||||||
|
for rel, link in self.data.items()
|
||||||
|
if link is not None
|
||||||
|
]
|
||||||
|
return {'Link': ', '.join(link_keader_items)}
|
||||||
|
|
||||||
|
|
||||||
|
class OffsetLimitPage(object):
|
||||||
|
"""
|
||||||
|
A base class to allow offset and limit when listing a queryset.
|
||||||
|
"""
|
||||||
|
def __init__(self, queryset, offset, limit):
|
||||||
|
self.count = self._set_count(queryset)
|
||||||
|
self.object_list = queryset[offset:offset + limit]
|
||||||
|
|
||||||
|
def _set_count(self, queryset):
|
||||||
|
try:
|
||||||
|
return queryset.count()
|
||||||
|
except (AttributeError, TypeError):
|
||||||
|
# AttributeError if object_list has no count() method.
|
||||||
|
# TypeError if object_list.count() requires arguments
|
||||||
|
# (i.e. is of type list).
|
||||||
|
return len(queryset)
|
||||||
|
|
||||||
|
|
||||||
|
class OffsetLimitPaginationSerializer(BasePaginationSerializer):
|
||||||
|
""" OffsetLimitPage serializer """
|
||||||
|
count = serializers.Field()
|
||||||
|
|
|
@ -5,7 +5,8 @@ from django.db import models
|
||||||
from django.core.paginator import Paginator
|
from django.core.paginator import Paginator
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from django.utils import unittest
|
from django.utils import unittest
|
||||||
from rest_framework import generics, status, pagination, filters, serializers
|
from rest_framework import (generics, status, pagination, filters, serializers,
|
||||||
|
mixins)
|
||||||
from rest_framework.compat import django_filters
|
from rest_framework.compat import django_filters
|
||||||
from rest_framework.test import APIRequestFactory
|
from rest_framework.test import APIRequestFactory
|
||||||
from rest_framework.tests.models import BasicModel
|
from rest_framework.tests.models import BasicModel
|
||||||
|
@ -13,6 +14,48 @@ from rest_framework.tests.models import BasicModel
|
||||||
factory = APIRequestFactory()
|
factory = APIRequestFactory()
|
||||||
|
|
||||||
|
|
||||||
|
def parse_header_link(value):
|
||||||
|
# Got it from python-requests code
|
||||||
|
# https://github.com/kennethreitz/requests/blob/master/requests/utils.py#L467
|
||||||
|
def parse_header_links(value):
|
||||||
|
"""Return a dict of parsed link headers proxies.
|
||||||
|
|
||||||
|
i.e. Link: <http:/.../front.jpeg>; rel=front; type="image/jpeg",<http://.../back.jpeg>; rel=back;type="image/jpeg"
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
links = []
|
||||||
|
|
||||||
|
replace_chars = " '\""
|
||||||
|
|
||||||
|
for val in value.split(","):
|
||||||
|
try:
|
||||||
|
url, params = val.split(";", 1)
|
||||||
|
except ValueError:
|
||||||
|
url, params = val, ''
|
||||||
|
|
||||||
|
link = {}
|
||||||
|
|
||||||
|
link["url"] = url.strip("<> '\"")
|
||||||
|
|
||||||
|
for param in params.split(";"):
|
||||||
|
try:
|
||||||
|
key, value = param.split("=")
|
||||||
|
except ValueError:
|
||||||
|
break
|
||||||
|
|
||||||
|
link[key.strip(replace_chars)] = value.strip(replace_chars)
|
||||||
|
|
||||||
|
links.append(link)
|
||||||
|
|
||||||
|
return links
|
||||||
|
links = parse_header_links(value)
|
||||||
|
link_map = {}
|
||||||
|
for link in links:
|
||||||
|
link_map[link['rel']] = link['url']
|
||||||
|
return link_map
|
||||||
|
|
||||||
|
|
||||||
class FilterableItem(models.Model):
|
class FilterableItem(models.Model):
|
||||||
text = models.CharField(max_length=100)
|
text = models.CharField(max_length=100)
|
||||||
decimal = models.DecimalField(max_digits=4, decimal_places=2)
|
decimal = models.DecimalField(max_digits=4, decimal_places=2)
|
||||||
|
@ -52,6 +95,18 @@ class MaxPaginateByView(generics.ListAPIView):
|
||||||
paginate_by_param = 'page_size'
|
paginate_by_param = 'page_size'
|
||||||
|
|
||||||
|
|
||||||
|
class LinkPaginationView(mixins.LinkPaginationMixin,
|
||||||
|
generics.ListCreateAPIView):
|
||||||
|
model = BasicModel
|
||||||
|
paginate_by = 10
|
||||||
|
|
||||||
|
|
||||||
|
class OffsetLimitPaginationView(mixins.OffsetLimitPaginationMixin,
|
||||||
|
generics.ListCreateAPIView):
|
||||||
|
model = BasicModel
|
||||||
|
paginate_by = 10
|
||||||
|
|
||||||
|
|
||||||
class IntegrationTestPagination(TestCase):
|
class IntegrationTestPagination(TestCase):
|
||||||
"""
|
"""
|
||||||
Integration tests for paginated list views.
|
Integration tests for paginated list views.
|
||||||
|
@ -103,6 +158,102 @@ class IntegrationTestPagination(TestCase):
|
||||||
self.assertNotEqual(response.data['previous'], None)
|
self.assertNotEqual(response.data['previous'], None)
|
||||||
|
|
||||||
|
|
||||||
|
class IntegrationTestLinkPagination(TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"""
|
||||||
|
Create 26 BasicModel instances.
|
||||||
|
"""
|
||||||
|
for char in 'abcdefghijklmnopqrstuvwxyz':
|
||||||
|
BasicModel(text=char * 3).save()
|
||||||
|
self.objects = BasicModel.objects
|
||||||
|
self.data = [
|
||||||
|
{'id': obj.id, 'text': obj.text}
|
||||||
|
for obj in self.objects.all()
|
||||||
|
]
|
||||||
|
self.view = LinkPaginationView.as_view()
|
||||||
|
|
||||||
|
def test_get_paginated_root_view(self):
|
||||||
|
"""
|
||||||
|
GET requests to paginated ListCreateAPIView should return paginated results.
|
||||||
|
"""
|
||||||
|
request = factory.get('/')
|
||||||
|
# Note: Database queries are a `SELECT COUNT`, and `SELECT <fields>`
|
||||||
|
with self.assertNumQueries(2):
|
||||||
|
response = self.view(request).render()
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||||
|
self.assertEqual(response.data, self.data[0:10])
|
||||||
|
link_header = parse_header_link(response['Link'])
|
||||||
|
self.assertTrue('next' in link_header)
|
||||||
|
self.assertTrue('last' in link_header)
|
||||||
|
self.assertFalse('previous' in link_header)
|
||||||
|
self.assertFalse('first' in link_header)
|
||||||
|
|
||||||
|
request = factory.get(link_header['next'])
|
||||||
|
with self.assertNumQueries(2):
|
||||||
|
response = self.view(request).render()
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||||
|
self.assertEqual(response.data, self.data[10:20])
|
||||||
|
link_header = parse_header_link(response['Link'])
|
||||||
|
self.assertTrue('next' in link_header)
|
||||||
|
self.assertTrue('last' in link_header)
|
||||||
|
self.assertTrue('previous' in link_header)
|
||||||
|
self.assertTrue('first' in link_header)
|
||||||
|
|
||||||
|
request = factory.get(link_header['last'])
|
||||||
|
with self.assertNumQueries(2):
|
||||||
|
response = self.view(request).render()
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||||
|
self.assertEqual(response.data, self.data[20:])
|
||||||
|
link_header = parse_header_link(response['Link'])
|
||||||
|
self.assertFalse('next' in link_header)
|
||||||
|
self.assertFalse('last' in link_header)
|
||||||
|
self.assertTrue('previous' in link_header)
|
||||||
|
self.assertTrue('first' in link_header)
|
||||||
|
|
||||||
|
|
||||||
|
class IntegrationTestOffsetLimitPagination(TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"""
|
||||||
|
Create 26 BasicModel instances.
|
||||||
|
"""
|
||||||
|
for char in 'abcdefghijklmnopqrstuvwxyz':
|
||||||
|
BasicModel(text=char * 3).save()
|
||||||
|
self.objects = BasicModel.objects
|
||||||
|
self.data = [
|
||||||
|
{'id': obj.id, 'text': obj.text}
|
||||||
|
for obj in self.objects.all()
|
||||||
|
]
|
||||||
|
self.view = OffsetLimitPaginationView.as_view()
|
||||||
|
|
||||||
|
def test_get_paginated_root_view(self):
|
||||||
|
"""
|
||||||
|
GET requests to paginated ListCreateAPIView should return paginated results.
|
||||||
|
"""
|
||||||
|
request = factory.get('/')
|
||||||
|
# Note: Database queries are a `SELECT COUNT`, and `SELECT <fields>`
|
||||||
|
with self.assertNumQueries(2):
|
||||||
|
response = self.view(request).render()
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||||
|
self.assertEqual(response.data['results'], self.data[0:10])
|
||||||
|
self.assertEqual(response.data['count'], 26)
|
||||||
|
|
||||||
|
request = factory.get('/', {'offset': 10})
|
||||||
|
with self.assertNumQueries(2):
|
||||||
|
response = self.view(request).render()
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||||
|
self.assertEqual(response.data['results'], self.data[10:20])
|
||||||
|
self.assertEqual(response.data['count'], 26)
|
||||||
|
|
||||||
|
request = factory.get('/', {'offset': 20, 'limit': 3})
|
||||||
|
with self.assertNumQueries(2):
|
||||||
|
response = self.view(request).render()
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||||
|
self.assertEqual(response.data['results'], self.data[20:23])
|
||||||
|
self.assertEqual(response.data['count'], 26)
|
||||||
|
|
||||||
|
|
||||||
class IntegrationTestPaginationAndFiltering(TestCase):
|
class IntegrationTestPaginationAndFiltering(TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
@ -258,6 +409,63 @@ class UnitTestPagination(TestCase):
|
||||||
self.assertEqual(serializer.context, results.context)
|
self.assertEqual(serializer.context, results.context)
|
||||||
|
|
||||||
|
|
||||||
|
class UnitTestLinkPagination(TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.objects = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||||
|
paginator = Paginator(self.objects, 2)
|
||||||
|
self.first_page = paginator.page(1)
|
||||||
|
self.third_page = paginator.page(3)
|
||||||
|
self.last_page = paginator.page(5)
|
||||||
|
|
||||||
|
def test_native_pagination(self):
|
||||||
|
serializer = pagination.LinkPaginationSerializer(self.first_page)
|
||||||
|
self.assertEqual(serializer.data['next'], '?page=2')
|
||||||
|
self.assertEqual(serializer.data['previous'], None)
|
||||||
|
self.assertEqual(serializer.data['first'], None)
|
||||||
|
self.assertEqual(serializer.data['last'], '?page=5')
|
||||||
|
|
||||||
|
serializer = pagination.LinkPaginationSerializer(self.third_page)
|
||||||
|
self.assertEqual(serializer.data['next'], '?page=4')
|
||||||
|
self.assertEqual(serializer.data['previous'], '?page=2')
|
||||||
|
self.assertEqual(serializer.data['first'], '?page=1')
|
||||||
|
self.assertEqual(serializer.data['last'], '?page=5')
|
||||||
|
|
||||||
|
serializer = pagination.LinkPaginationSerializer(self.last_page)
|
||||||
|
self.assertEqual(serializer.data['next'], None)
|
||||||
|
self.assertEqual(serializer.data['previous'], '?page=4')
|
||||||
|
self.assertEqual(serializer.data['first'], '?page=1')
|
||||||
|
self.assertEqual(serializer.data['last'], None)
|
||||||
|
|
||||||
|
|
||||||
|
class UnitTestOffsetLimitPagination(TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.objects = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||||
|
self.first_page = pagination.OffsetLimitPage(
|
||||||
|
self.objects, offset=0, limit=2)
|
||||||
|
self.third_page = pagination.OffsetLimitPage(
|
||||||
|
self.objects, offset=4, limit=2)
|
||||||
|
self.last_page = pagination.OffsetLimitPage(
|
||||||
|
self.objects, offset=8, limit=2)
|
||||||
|
|
||||||
|
def test_native_pagination(self):
|
||||||
|
serializer = pagination.OffsetLimitPaginationSerializer(
|
||||||
|
self.first_page)
|
||||||
|
self.assertEqual(serializer.data['count'], 10)
|
||||||
|
self.assertEqual(serializer.data['results'], self.objects[:2])
|
||||||
|
|
||||||
|
serializer = pagination.OffsetLimitPaginationSerializer(
|
||||||
|
self.third_page)
|
||||||
|
self.assertEqual(serializer.data['count'], 10)
|
||||||
|
self.assertEqual(serializer.data['results'], self.objects[4:6])
|
||||||
|
|
||||||
|
serializer = pagination.OffsetLimitPaginationSerializer(
|
||||||
|
self.last_page)
|
||||||
|
self.assertEqual(serializer.data['count'], 10)
|
||||||
|
self.assertEqual(serializer.data['results'], self.objects[8:])
|
||||||
|
|
||||||
|
|
||||||
class TestUnpaginated(TestCase):
|
class TestUnpaginated(TestCase):
|
||||||
"""
|
"""
|
||||||
Tests for list views without pagination.
|
Tests for list views without pagination.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user