Add pagination. Thanks @devioustree!

This commit is contained in:
Tom Christie 2011-12-09 13:37:53 +00:00
parent 42cdd00591
commit 5db422c9d3
2 changed files with 254 additions and 43 deletions

View File

@ -1,23 +1,20 @@
""" """
The :mod:`mixins` module provides a set of reusable `mixin` The :mod:`mixins` module provides a set of reusable `mixin`
classes that can be added to a `View`. classes that can be added to a `View`.
""" """
from django.contrib.auth.models import AnonymousUser from django.contrib.auth.models import AnonymousUser
from django.db.models.query import QuerySet from django.core.paginator import Paginator
from django.db.models.fields.related import ForeignKey from django.db.models.fields.related import ForeignKey
from django.http import HttpResponse from django.http import HttpResponse
from djangorestframework import status from djangorestframework import status
from djangorestframework.parsers import FormParser, MultiPartParser
from djangorestframework.renderers import BaseRenderer from djangorestframework.renderers import BaseRenderer
from djangorestframework.resources import Resource, FormResource, ModelResource from djangorestframework.resources import Resource, FormResource, ModelResource
from djangorestframework.response import Response, ErrorResponse from djangorestframework.response import Response, ErrorResponse
from djangorestframework.utils import as_tuple, MSIE_USER_AGENT_REGEX from djangorestframework.utils import as_tuple, MSIE_USER_AGENT_REGEX
from djangorestframework.utils.mediatypes import is_form_media_type, order_by_precedence from djangorestframework.utils.mediatypes import is_form_media_type, order_by_precedence
from decimal import Decimal
import re
from StringIO import StringIO from StringIO import StringIO
@ -52,7 +49,7 @@ class RequestMixin(object):
""" """
The set of request parsers that the view can handle. The set of request parsers that the view can handle.
Should be a tuple/list of classes as described in the :mod:`parsers` module. Should be a tuple/list of classes as described in the :mod:`parsers` module.
""" """
parsers = () parsers = ()
@ -158,7 +155,7 @@ class RequestMixin(object):
# We only need to use form overloading on form POST requests. # We only need to use form overloading on form POST requests.
if not self._USE_FORM_OVERLOADING or self._method != 'POST' or not is_form_media_type(self._content_type): if not self._USE_FORM_OVERLOADING or self._method != 'POST' or not is_form_media_type(self._content_type):
return return
# At this point we're committed to parsing the request as form data. # At this point we're committed to parsing the request as form data.
self._data = data = self.request.POST.copy() self._data = data = self.request.POST.copy()
self._files = self.request.FILES self._files = self.request.FILES
@ -203,12 +200,12 @@ class RequestMixin(object):
""" """
return [parser.media_type for parser in self.parsers] return [parser.media_type for parser in self.parsers]
@property @property
def _default_parser(self): def _default_parser(self):
""" """
Return the view's default parser class. Return the view's default parser class.
""" """
return self.parsers[0] return self.parsers[0]
@ -218,7 +215,7 @@ class RequestMixin(object):
class ResponseMixin(object): class ResponseMixin(object):
""" """
Adds behavior for pluggable `Renderers` to a :class:`views.View` class. Adds behavior for pluggable `Renderers` to a :class:`views.View` class.
Default behavior is to use standard HTTP Accept header content negotiation. Default behavior is to use standard HTTP Accept header content negotiation.
Also supports overriding the content type by specifying an ``_accept=`` parameter in the URL. Also supports overriding the content type by specifying an ``_accept=`` parameter in the URL.
Ignores Accept headers from Internet Explorer user agents and uses a sensible browser Accept header instead. Ignores Accept headers from Internet Explorer user agents and uses a sensible browser Accept header instead.
@ -229,8 +226,8 @@ class ResponseMixin(object):
""" """
The set of response renderers that the view can handle. The set of response renderers that the view can handle.
Should be a tuple/list of classes as described in the :mod:`renderers` module. Should be a tuple/list of classes as described in the :mod:`renderers` module.
""" """
renderers = () renderers = ()
@ -253,7 +250,7 @@ class ResponseMixin(object):
# Set the media type of the response # Set the media type of the response
# Note that the renderer *could* override it in .render() if required. # Note that the renderer *could* override it in .render() if required.
response.media_type = renderer.media_type response.media_type = renderer.media_type
# Serialize the response content # Serialize the response content
if response.has_content_body: if response.has_content_body:
content = renderer.render(response.cleaned_content, media_type) content = renderer.render(response.cleaned_content, media_type)
@ -317,7 +314,7 @@ class ResponseMixin(object):
Return an list of all the media types that this view can render. Return an list of all the media types that this view can render.
""" """
return [renderer.media_type for renderer in self.renderers] return [renderer.media_type for renderer in self.renderers]
@property @property
def _rendered_formats(self): def _rendered_formats(self):
""" """
@ -339,18 +336,18 @@ class AuthMixin(object):
""" """
Simple :class:`mixin` class to add authentication and permission checking to a :class:`View` class. Simple :class:`mixin` class to add authentication and permission checking to a :class:`View` class.
""" """
""" """
The set of authentication types that this view can handle. The set of authentication types that this view can handle.
Should be a tuple/list of classes as described in the :mod:`authentication` module. Should be a tuple/list of classes as described in the :mod:`authentication` module.
""" """
authentication = () authentication = ()
""" """
The set of permissions that will be enforced on this view. The set of permissions that will be enforced on this view.
Should be a tuple/list of classes as described in the :mod:`permissions` module. Should be a tuple/list of classes as described in the :mod:`permissions` module.
""" """
permissions = () permissions = ()
@ -359,7 +356,7 @@ class AuthMixin(object):
def user(self): def user(self):
""" """
Returns the :obj:`user` for the current request, as determined by the set of Returns the :obj:`user` for the current request, as determined by the set of
:class:`authentication` classes applied to the :class:`View`. :class:`authentication` classes applied to the :class:`View`.
""" """
if not hasattr(self, '_user'): if not hasattr(self, '_user'):
self._user = self._authenticate() self._user = self._authenticate()
@ -541,13 +538,13 @@ class CreateModelMixin(object):
for fieldname in m2m_data: for fieldname in m2m_data:
manager = getattr(instance, fieldname) manager = getattr(instance, fieldname)
if hasattr(manager, 'add'): if hasattr(manager, 'add'):
manager.add(*m2m_data[fieldname][1]) manager.add(*m2m_data[fieldname][1])
else: else:
data = {} data = {}
data[manager.source_field_name] = instance data[manager.source_field_name] = instance
for related_item in m2m_data[fieldname][1]: for related_item in m2m_data[fieldname][1]:
data[m2m_data[fieldname][0]] = related_item data[m2m_data[fieldname][0]] = related_item
manager.through(**data).save() manager.through(**data).save()
@ -564,8 +561,8 @@ class UpdateModelMixin(object):
""" """
def put(self, request, *args, **kwargs): def put(self, request, *args, **kwargs):
model = self.resource.model model = self.resource.model
# TODO: update on the url of a non-existing resource url doesn't work correctly at the moment - will end up with a new url # TODO: update on the url of a non-existing resource url doesn't work correctly at the moment - will end up with a new url
try: try:
if args: if args:
# If we have any none kwargs then assume the last represents the primary key # If we have any none kwargs then assume the last represents the primary key
@ -640,3 +637,93 @@ class ListModelMixin(object):
return queryset.filter(**kwargs) return queryset.filter(**kwargs)
########## Pagination Mixins ##########
class PaginatorMixin(object):
"""
Adds pagination support to GET requests
Obviously should only be used on lists :)
A default limit can be set by setting `limit` on the object. This will also
be used as the maximum if the client sets the `limit` GET param
"""
limit = 20
def get_limit(self):
""" Helper method to determine what the `limit` should be """
try:
limit = int(self.request.GET.get('limit', self.limit))
return min(limit, self.limit)
except ValueError:
return self.limit
def url_with_page_number(self, page_number):
""" Constructs a url used for getting the next/previous urls """
url = "%s?page=%d" % (self.request.path, page_number)
limit = self.get_limit()
if limit != self.limit:
url = "%s&limit=%d" % (url, limit)
return url
def next(self, page):
""" Returns a url to the next page of results (if any) """
if not page.has_next():
return None
return self.url_with_page_number(page.next_page_number())
def previous(self, page):
""" Returns a url to the previous page of results (if any) """
if not page.has_previous():
return None
return self.url_with_page_number(page.previous_page_number())
def serialize_page_info(self, page):
""" This is some useful information that is added to the response """
return {
'next': self.next(page),
'page': page.number,
'pages': page.paginator.num_pages,
'per_page': self.get_limit(),
'previous': self.previous(page),
'total': page.paginator.count,
}
def filter_response(self, obj):
"""
Given the response content, paginate and then serialize.
The response is modified to include to useful data relating to the number
of objects, number of pages, next/previous urls etc. etc.
The serialised objects are put into `results` on this new, modified
response
"""
# We don't want to paginate responses for anything other than GET requests
if self.method.upper() != 'GET':
return self._resource.filter_response(obj)
paginator = Paginator(obj, self.get_limit())
try:
page_num = int(self.request.GET.get('page', '1'))
except ValueError:
raise ErrorResponse(status.HTTP_404_NOT_FOUND,
{'detail': 'That page contains no results'})
if page_num not in paginator.page_range:
raise ErrorResponse(status.HTTP_404_NOT_FOUND,
{'detail': 'That page contains no results'})
page = paginator.page(page_num)
serialized_object_list = self._resource.filter_response(page.object_list)
serialized_page_info = self.serialize_page_info(page)
serialized_page_info['results'] = serialized_object_list
return serialized_page_info

View File

@ -1,14 +1,17 @@
"""Tests for the status module""" """Tests for the mixin module"""
from django.test import TestCase from django.test import TestCase
from django.utils import simplejson as json
from djangorestframework import status from djangorestframework import status
from djangorestframework.compat import RequestFactory from djangorestframework.compat import RequestFactory
from django.contrib.auth.models import Group, User from django.contrib.auth.models import Group, User
from djangorestframework.mixins import CreateModelMixin from djangorestframework.mixins import CreateModelMixin, PaginatorMixin
from djangorestframework.resources import ModelResource from djangorestframework.resources import ModelResource
from djangorestframework.response import Response
from djangorestframework.tests.models import CustomUser from djangorestframework.tests.models import CustomUser
from djangorestframework.views import View
class TestModelCreation(TestCase): class TestModelCreation(TestCase):
"""Tests on CreateModelMixin""" """Tests on CreateModelMixin"""
def setUp(self): def setUp(self):
@ -25,23 +28,26 @@ class TestModelCreation(TestCase):
mixin = CreateModelMixin() mixin = CreateModelMixin()
mixin.resource = GroupResource mixin.resource = GroupResource
mixin.CONTENT = form_data mixin.CONTENT = form_data
response = mixin.post(request) response = mixin.post(request)
self.assertEquals(1, Group.objects.count()) self.assertEquals(1, Group.objects.count())
self.assertEquals('foo', response.cleaned_content.name) self.assertEquals('foo', response.cleaned_content.name)
def test_creation_with_m2m_relation(self): def test_creation_with_m2m_relation(self):
class UserResource(ModelResource): class UserResource(ModelResource):
model = User model = User
def url(self, instance): def url(self, instance):
return "/users/%i" % instance.id return "/users/%i" % instance.id
group = Group(name='foo') group = Group(name='foo')
group.save() group.save()
form_data = {'username': 'bar', 'password': 'baz', 'groups': [group.id]} form_data = {
'username': 'bar',
'password': 'baz',
'groups': [group.id]
}
request = self.req.post('/groups', data=form_data) request = self.req.post('/groups', data=form_data)
cleaned_data = dict(form_data) cleaned_data = dict(form_data)
cleaned_data['groups'] = [group] cleaned_data['groups'] = [group]
@ -53,18 +59,18 @@ class TestModelCreation(TestCase):
self.assertEquals(1, User.objects.count()) self.assertEquals(1, User.objects.count())
self.assertEquals(1, response.cleaned_content.groups.count()) self.assertEquals(1, response.cleaned_content.groups.count())
self.assertEquals('foo', response.cleaned_content.groups.all()[0].name) self.assertEquals('foo', response.cleaned_content.groups.all()[0].name)
def test_creation_with_m2m_relation_through(self): def test_creation_with_m2m_relation_through(self):
""" """
Tests creation where the m2m relation uses a through table Tests creation where the m2m relation uses a through table
""" """
class UserResource(ModelResource): class UserResource(ModelResource):
model = CustomUser model = CustomUser
def url(self, instance): def url(self, instance):
return "/customusers/%i" % instance.id return "/customusers/%i" % instance.id
form_data = {'username': 'bar0', 'groups': []} form_data = {'username': 'bar0', 'groups': []}
request = self.req.post('/groups', data=form_data) request = self.req.post('/groups', data=form_data)
cleaned_data = dict(form_data) cleaned_data = dict(form_data)
cleaned_data['groups'] = [] cleaned_data['groups'] = []
@ -74,12 +80,12 @@ class TestModelCreation(TestCase):
response = mixin.post(request) response = mixin.post(request)
self.assertEquals(1, CustomUser.objects.count()) self.assertEquals(1, CustomUser.objects.count())
self.assertEquals(0, response.cleaned_content.groups.count()) self.assertEquals(0, response.cleaned_content.groups.count())
group = Group(name='foo1') group = Group(name='foo1')
group.save() group.save()
form_data = {'username': 'bar1', 'groups': [group.id]} form_data = {'username': 'bar1', 'groups': [group.id]}
request = self.req.post('/groups', data=form_data) request = self.req.post('/groups', data=form_data)
cleaned_data = dict(form_data) cleaned_data = dict(form_data)
cleaned_data['groups'] = [group] cleaned_data['groups'] = [group]
@ -91,12 +97,11 @@ class TestModelCreation(TestCase):
self.assertEquals(2, CustomUser.objects.count()) self.assertEquals(2, CustomUser.objects.count())
self.assertEquals(1, response.cleaned_content.groups.count()) self.assertEquals(1, response.cleaned_content.groups.count())
self.assertEquals('foo1', response.cleaned_content.groups.all()[0].name) self.assertEquals('foo1', response.cleaned_content.groups.all()[0].name)
group2 = Group(name='foo2') group2 = Group(name='foo2')
group2.save() group2.save()
form_data = {'username': 'bar2', 'groups': [group.id, group2.id]} form_data = {'username': 'bar2', 'groups': [group.id, group2.id]}
request = self.req.post('/groups', data=form_data) request = self.req.post('/groups', data=form_data)
cleaned_data = dict(form_data) cleaned_data = dict(form_data)
cleaned_data['groups'] = [group, group2] cleaned_data['groups'] = [group, group2]
@ -109,5 +114,124 @@ class TestModelCreation(TestCase):
self.assertEquals(2, response.cleaned_content.groups.count()) self.assertEquals(2, response.cleaned_content.groups.count())
self.assertEquals('foo1', response.cleaned_content.groups.all()[0].name) self.assertEquals('foo1', response.cleaned_content.groups.all()[0].name)
self.assertEquals('foo2', response.cleaned_content.groups.all()[1].name) self.assertEquals('foo2', response.cleaned_content.groups.all()[1].name)
class MockPaginatorView(PaginatorMixin, View):
total = 60
def get(self, request):
return range(0, self.total)
def post(self, request):
return Response(status.CREATED, {'status': 'OK'})
class TestPagination(TestCase):
def setUp(self):
self.req = RequestFactory()
def test_default_limit(self):
""" Tests if pagination works without overwriting the limit """
request = self.req.get('/paginator')
response = MockPaginatorView.as_view()(request)
content = json.loads(response.content)
self.assertEqual(response.status_code, status.OK)
self.assertEqual(MockPaginatorView.total, content['total'])
self.assertEqual(MockPaginatorView.limit, content['per_page'])
self.assertEqual(range(0, MockPaginatorView.limit), content['results'])
def test_overwriting_limit(self):
""" Tests if the limit can be overwritten """
limit = 10
request = self.req.get('/paginator')
response = MockPaginatorView.as_view(limit=limit)(request)
content = json.loads(response.content)
self.assertEqual(response.status_code, status.OK)
self.assertEqual(content['per_page'], limit)
self.assertEqual(range(0, limit), content['results'])
def test_limit_param(self):
""" Tests if the client can set the limit """
from math import ceil
limit = 5
num_pages = int(ceil(MockPaginatorView.total / float(limit)))
request = self.req.get('/paginator/?limit=%d' % limit)
response = MockPaginatorView.as_view()(request)
content = json.loads(response.content)
self.assertEqual(response.status_code, status.OK)
self.assertEqual(MockPaginatorView.total, content['total'])
self.assertEqual(limit, content['per_page'])
self.assertEqual(num_pages, content['pages'])
def test_exceeding_limit(self):
""" Makes sure the client cannot exceed the default limit """
from math import ceil
limit = MockPaginatorView.limit + 10
num_pages = int(ceil(MockPaginatorView.total / float(limit)))
request = self.req.get('/paginator/?limit=%d' % limit)
response = MockPaginatorView.as_view()(request)
content = json.loads(response.content)
self.assertEqual(response.status_code, status.OK)
self.assertEqual(MockPaginatorView.total, content['total'])
self.assertNotEqual(limit, content['per_page'])
self.assertNotEqual(num_pages, content['pages'])
self.assertEqual(MockPaginatorView.limit, content['per_page'])
def test_only_works_for_get(self):
""" Pagination should only work for GET requests """
request = self.req.post('/paginator', data={'content': 'spam'})
response = MockPaginatorView.as_view()(request)
content = json.loads(response.content)
self.assertEqual(response.status_code, status.CREATED)
self.assertEqual(None, content.get('per_page'))
self.assertEqual('OK', content['status'])
def test_non_int_page(self):
""" Tests that it can handle invalid values """
request = self.req.get('/paginator/?page=spam')
response = MockPaginatorView.as_view()(request)
self.assertEqual(response.status_code, status.NOT_FOUND)
def test_page_range(self):
""" Tests that the page range is handle correctly """
request = self.req.get('/paginator/?page=0')
response = MockPaginatorView.as_view()(request)
content = json.loads(response.content)
self.assertEqual(response.status_code, status.NOT_FOUND)
request = self.req.get('/paginator/')
response = MockPaginatorView.as_view()(request)
content = json.loads(response.content)
self.assertEqual(response.status_code, status.OK)
self.assertEqual(range(0, MockPaginatorView.limit), content['results'])
num_pages = content['pages']
request = self.req.get('/paginator/?page=%d' % num_pages)
response = MockPaginatorView.as_view()(request)
content = json.loads(response.content)
self.assertEqual(response.status_code, status.OK)
self.assertEqual(range(MockPaginatorView.limit*(num_pages-1), MockPaginatorView.total), content['results'])
request = self.req.get('/paginator/?page=%d' % (num_pages + 1,))
response = MockPaginatorView.as_view()(request)
content = json.loads(response.content)
self.assertEqual(response.status_code, status.NOT_FOUND)