mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-11 04:07:39 +03:00
Add pagination. Thanks @devioustree!
This commit is contained in:
parent
42cdd00591
commit
5db422c9d3
|
@ -4,20 +4,17 @@ classes that can be added to a `View`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from django.contrib.auth.models import AnonymousUser
|
from django.contrib.auth.models import AnonymousUser
|
||||||
from django.db.models.query import QuerySet
|
from django.core.paginator import Paginator
|
||||||
from django.db.models.fields.related import ForeignKey
|
from django.db.models.fields.related import ForeignKey
|
||||||
from django.http import HttpResponse
|
from django.http import HttpResponse
|
||||||
|
|
||||||
from djangorestframework import status
|
from djangorestframework import status
|
||||||
from djangorestframework.parsers import FormParser, MultiPartParser
|
|
||||||
from djangorestframework.renderers import BaseRenderer
|
from djangorestframework.renderers import BaseRenderer
|
||||||
from djangorestframework.resources import Resource, FormResource, ModelResource
|
from djangorestframework.resources import Resource, FormResource, ModelResource
|
||||||
from djangorestframework.response import Response, ErrorResponse
|
from djangorestframework.response import Response, ErrorResponse
|
||||||
from djangorestframework.utils import as_tuple, MSIE_USER_AGENT_REGEX
|
from djangorestframework.utils import as_tuple, MSIE_USER_AGENT_REGEX
|
||||||
from djangorestframework.utils.mediatypes import is_form_media_type, order_by_precedence
|
from djangorestframework.utils.mediatypes import is_form_media_type, order_by_precedence
|
||||||
|
|
||||||
from decimal import Decimal
|
|
||||||
import re
|
|
||||||
from StringIO import StringIO
|
from StringIO import StringIO
|
||||||
|
|
||||||
|
|
||||||
|
@ -640,3 +637,93 @@ class ListModelMixin(object):
|
||||||
return queryset.filter(**kwargs)
|
return queryset.filter(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
########## Pagination Mixins ##########
|
||||||
|
|
||||||
|
class PaginatorMixin(object):
|
||||||
|
"""
|
||||||
|
Adds pagination support to GET requests
|
||||||
|
Obviously should only be used on lists :)
|
||||||
|
|
||||||
|
A default limit can be set by setting `limit` on the object. This will also
|
||||||
|
be used as the maximum if the client sets the `limit` GET param
|
||||||
|
"""
|
||||||
|
limit = 20
|
||||||
|
|
||||||
|
def get_limit(self):
|
||||||
|
""" Helper method to determine what the `limit` should be """
|
||||||
|
try:
|
||||||
|
limit = int(self.request.GET.get('limit', self.limit))
|
||||||
|
return min(limit, self.limit)
|
||||||
|
except ValueError:
|
||||||
|
return self.limit
|
||||||
|
|
||||||
|
def url_with_page_number(self, page_number):
|
||||||
|
""" Constructs a url used for getting the next/previous urls """
|
||||||
|
url = "%s?page=%d" % (self.request.path, page_number)
|
||||||
|
|
||||||
|
limit = self.get_limit()
|
||||||
|
if limit != self.limit:
|
||||||
|
url = "%s&limit=%d" % (url, limit)
|
||||||
|
|
||||||
|
return url
|
||||||
|
|
||||||
|
def next(self, page):
|
||||||
|
""" Returns a url to the next page of results (if any) """
|
||||||
|
if not page.has_next():
|
||||||
|
return None
|
||||||
|
|
||||||
|
return self.url_with_page_number(page.next_page_number())
|
||||||
|
|
||||||
|
def previous(self, page):
|
||||||
|
""" Returns a url to the previous page of results (if any) """
|
||||||
|
if not page.has_previous():
|
||||||
|
return None
|
||||||
|
|
||||||
|
return self.url_with_page_number(page.previous_page_number())
|
||||||
|
|
||||||
|
def serialize_page_info(self, page):
|
||||||
|
""" This is some useful information that is added to the response """
|
||||||
|
return {
|
||||||
|
'next': self.next(page),
|
||||||
|
'page': page.number,
|
||||||
|
'pages': page.paginator.num_pages,
|
||||||
|
'per_page': self.get_limit(),
|
||||||
|
'previous': self.previous(page),
|
||||||
|
'total': page.paginator.count,
|
||||||
|
}
|
||||||
|
|
||||||
|
def filter_response(self, obj):
|
||||||
|
"""
|
||||||
|
Given the response content, paginate and then serialize.
|
||||||
|
|
||||||
|
The response is modified to include to useful data relating to the number
|
||||||
|
of objects, number of pages, next/previous urls etc. etc.
|
||||||
|
|
||||||
|
The serialised objects are put into `results` on this new, modified
|
||||||
|
response
|
||||||
|
"""
|
||||||
|
|
||||||
|
# We don't want to paginate responses for anything other than GET requests
|
||||||
|
if self.method.upper() != 'GET':
|
||||||
|
return self._resource.filter_response(obj)
|
||||||
|
|
||||||
|
paginator = Paginator(obj, self.get_limit())
|
||||||
|
|
||||||
|
try:
|
||||||
|
page_num = int(self.request.GET.get('page', '1'))
|
||||||
|
except ValueError:
|
||||||
|
raise ErrorResponse(status.HTTP_404_NOT_FOUND,
|
||||||
|
{'detail': 'That page contains no results'})
|
||||||
|
|
||||||
|
if page_num not in paginator.page_range:
|
||||||
|
raise ErrorResponse(status.HTTP_404_NOT_FOUND,
|
||||||
|
{'detail': 'That page contains no results'})
|
||||||
|
|
||||||
|
page = paginator.page(page_num)
|
||||||
|
|
||||||
|
serialized_object_list = self._resource.filter_response(page.object_list)
|
||||||
|
serialized_page_info = self.serialize_page_info(page)
|
||||||
|
|
||||||
|
serialized_page_info['results'] = serialized_object_list
|
||||||
|
|
||||||
|
return serialized_page_info
|
||||||
|
|
|
@ -1,11 +1,14 @@
|
||||||
"""Tests for the status module"""
|
"""Tests for the mixin module"""
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
from django.utils import simplejson as json
|
||||||
from djangorestframework import status
|
from djangorestframework import status
|
||||||
from djangorestframework.compat import RequestFactory
|
from djangorestframework.compat import RequestFactory
|
||||||
from django.contrib.auth.models import Group, User
|
from django.contrib.auth.models import Group, User
|
||||||
from djangorestframework.mixins import CreateModelMixin
|
from djangorestframework.mixins import CreateModelMixin, PaginatorMixin
|
||||||
from djangorestframework.resources import ModelResource
|
from djangorestframework.resources import ModelResource
|
||||||
|
from djangorestframework.response import Response
|
||||||
from djangorestframework.tests.models import CustomUser
|
from djangorestframework.tests.models import CustomUser
|
||||||
|
from djangorestframework.views import View
|
||||||
|
|
||||||
|
|
||||||
class TestModelCreation(TestCase):
|
class TestModelCreation(TestCase):
|
||||||
|
@ -30,7 +33,6 @@ class TestModelCreation(TestCase):
|
||||||
self.assertEquals(1, Group.objects.count())
|
self.assertEquals(1, Group.objects.count())
|
||||||
self.assertEquals('foo', response.cleaned_content.name)
|
self.assertEquals('foo', response.cleaned_content.name)
|
||||||
|
|
||||||
|
|
||||||
def test_creation_with_m2m_relation(self):
|
def test_creation_with_m2m_relation(self):
|
||||||
class UserResource(ModelResource):
|
class UserResource(ModelResource):
|
||||||
model = User
|
model = User
|
||||||
|
@ -41,7 +43,11 @@ class TestModelCreation(TestCase):
|
||||||
group = Group(name='foo')
|
group = Group(name='foo')
|
||||||
group.save()
|
group.save()
|
||||||
|
|
||||||
form_data = {'username': 'bar', 'password': 'baz', 'groups': [group.id]}
|
form_data = {
|
||||||
|
'username': 'bar',
|
||||||
|
'password': 'baz',
|
||||||
|
'groups': [group.id]
|
||||||
|
}
|
||||||
request = self.req.post('/groups', data=form_data)
|
request = self.req.post('/groups', data=form_data)
|
||||||
cleaned_data = dict(form_data)
|
cleaned_data = dict(form_data)
|
||||||
cleaned_data['groups'] = [group]
|
cleaned_data['groups'] = [group]
|
||||||
|
@ -92,7 +98,6 @@ class TestModelCreation(TestCase):
|
||||||
self.assertEquals(1, response.cleaned_content.groups.count())
|
self.assertEquals(1, response.cleaned_content.groups.count())
|
||||||
self.assertEquals('foo1', response.cleaned_content.groups.all()[0].name)
|
self.assertEquals('foo1', response.cleaned_content.groups.all()[0].name)
|
||||||
|
|
||||||
|
|
||||||
group2 = Group(name='foo2')
|
group2 = Group(name='foo2')
|
||||||
group2.save()
|
group2.save()
|
||||||
|
|
||||||
|
@ -111,3 +116,122 @@ class TestModelCreation(TestCase):
|
||||||
self.assertEquals('foo2', response.cleaned_content.groups.all()[1].name)
|
self.assertEquals('foo2', response.cleaned_content.groups.all()[1].name)
|
||||||
|
|
||||||
|
|
||||||
|
class MockPaginatorView(PaginatorMixin, View):
|
||||||
|
total = 60
|
||||||
|
|
||||||
|
def get(self, request):
|
||||||
|
return range(0, self.total)
|
||||||
|
|
||||||
|
def post(self, request):
|
||||||
|
return Response(status.CREATED, {'status': 'OK'})
|
||||||
|
|
||||||
|
|
||||||
|
class TestPagination(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.req = RequestFactory()
|
||||||
|
|
||||||
|
def test_default_limit(self):
|
||||||
|
""" Tests if pagination works without overwriting the limit """
|
||||||
|
request = self.req.get('/paginator')
|
||||||
|
response = MockPaginatorView.as_view()(request)
|
||||||
|
|
||||||
|
content = json.loads(response.content)
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, status.OK)
|
||||||
|
self.assertEqual(MockPaginatorView.total, content['total'])
|
||||||
|
self.assertEqual(MockPaginatorView.limit, content['per_page'])
|
||||||
|
|
||||||
|
self.assertEqual(range(0, MockPaginatorView.limit), content['results'])
|
||||||
|
|
||||||
|
def test_overwriting_limit(self):
|
||||||
|
""" Tests if the limit can be overwritten """
|
||||||
|
limit = 10
|
||||||
|
|
||||||
|
request = self.req.get('/paginator')
|
||||||
|
response = MockPaginatorView.as_view(limit=limit)(request)
|
||||||
|
|
||||||
|
content = json.loads(response.content)
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, status.OK)
|
||||||
|
self.assertEqual(content['per_page'], limit)
|
||||||
|
|
||||||
|
self.assertEqual(range(0, limit), content['results'])
|
||||||
|
|
||||||
|
def test_limit_param(self):
|
||||||
|
""" Tests if the client can set the limit """
|
||||||
|
from math import ceil
|
||||||
|
|
||||||
|
limit = 5
|
||||||
|
num_pages = int(ceil(MockPaginatorView.total / float(limit)))
|
||||||
|
|
||||||
|
request = self.req.get('/paginator/?limit=%d' % limit)
|
||||||
|
response = MockPaginatorView.as_view()(request)
|
||||||
|
|
||||||
|
content = json.loads(response.content)
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, status.OK)
|
||||||
|
self.assertEqual(MockPaginatorView.total, content['total'])
|
||||||
|
self.assertEqual(limit, content['per_page'])
|
||||||
|
self.assertEqual(num_pages, content['pages'])
|
||||||
|
|
||||||
|
def test_exceeding_limit(self):
|
||||||
|
""" Makes sure the client cannot exceed the default limit """
|
||||||
|
from math import ceil
|
||||||
|
|
||||||
|
limit = MockPaginatorView.limit + 10
|
||||||
|
num_pages = int(ceil(MockPaginatorView.total / float(limit)))
|
||||||
|
|
||||||
|
request = self.req.get('/paginator/?limit=%d' % limit)
|
||||||
|
response = MockPaginatorView.as_view()(request)
|
||||||
|
|
||||||
|
content = json.loads(response.content)
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, status.OK)
|
||||||
|
self.assertEqual(MockPaginatorView.total, content['total'])
|
||||||
|
self.assertNotEqual(limit, content['per_page'])
|
||||||
|
self.assertNotEqual(num_pages, content['pages'])
|
||||||
|
self.assertEqual(MockPaginatorView.limit, content['per_page'])
|
||||||
|
|
||||||
|
def test_only_works_for_get(self):
|
||||||
|
""" Pagination should only work for GET requests """
|
||||||
|
request = self.req.post('/paginator', data={'content': 'spam'})
|
||||||
|
response = MockPaginatorView.as_view()(request)
|
||||||
|
|
||||||
|
content = json.loads(response.content)
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, status.CREATED)
|
||||||
|
self.assertEqual(None, content.get('per_page'))
|
||||||
|
self.assertEqual('OK', content['status'])
|
||||||
|
|
||||||
|
def test_non_int_page(self):
|
||||||
|
""" Tests that it can handle invalid values """
|
||||||
|
request = self.req.get('/paginator/?page=spam')
|
||||||
|
response = MockPaginatorView.as_view()(request)
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, status.NOT_FOUND)
|
||||||
|
|
||||||
|
def test_page_range(self):
|
||||||
|
""" Tests that the page range is handle correctly """
|
||||||
|
request = self.req.get('/paginator/?page=0')
|
||||||
|
response = MockPaginatorView.as_view()(request)
|
||||||
|
content = json.loads(response.content)
|
||||||
|
self.assertEqual(response.status_code, status.NOT_FOUND)
|
||||||
|
|
||||||
|
request = self.req.get('/paginator/')
|
||||||
|
response = MockPaginatorView.as_view()(request)
|
||||||
|
content = json.loads(response.content)
|
||||||
|
self.assertEqual(response.status_code, status.OK)
|
||||||
|
self.assertEqual(range(0, MockPaginatorView.limit), content['results'])
|
||||||
|
|
||||||
|
num_pages = content['pages']
|
||||||
|
|
||||||
|
request = self.req.get('/paginator/?page=%d' % num_pages)
|
||||||
|
response = MockPaginatorView.as_view()(request)
|
||||||
|
content = json.loads(response.content)
|
||||||
|
self.assertEqual(response.status_code, status.OK)
|
||||||
|
self.assertEqual(range(MockPaginatorView.limit*(num_pages-1), MockPaginatorView.total), content['results'])
|
||||||
|
|
||||||
|
request = self.req.get('/paginator/?page=%d' % (num_pages + 1,))
|
||||||
|
response = MockPaginatorView.as_view()(request)
|
||||||
|
content = json.loads(response.content)
|
||||||
|
self.assertEqual(response.status_code, status.NOT_FOUND)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user