mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-03-31 15:24:31 +03:00
Add pagination. Thanks @devioustree!
This commit is contained in:
parent
42cdd00591
commit
5db422c9d3
|
@ -1,23 +1,20 @@
|
|||
"""
|
||||
The :mod:`mixins` module provides a set of reusable `mixin`
|
||||
The :mod:`mixins` module provides a set of reusable `mixin`
|
||||
classes that can be added to a `View`.
|
||||
"""
|
||||
|
||||
from django.contrib.auth.models import AnonymousUser
|
||||
from django.db.models.query import QuerySet
|
||||
from django.core.paginator import Paginator
|
||||
from django.db.models.fields.related import ForeignKey
|
||||
from django.http import HttpResponse
|
||||
|
||||
from djangorestframework import status
|
||||
from djangorestframework.parsers import FormParser, MultiPartParser
|
||||
from djangorestframework.renderers import BaseRenderer
|
||||
from djangorestframework.resources import Resource, FormResource, ModelResource
|
||||
from djangorestframework.response import Response, ErrorResponse
|
||||
from djangorestframework.utils import as_tuple, MSIE_USER_AGENT_REGEX
|
||||
from djangorestframework.utils.mediatypes import is_form_media_type, order_by_precedence
|
||||
|
||||
from decimal import Decimal
|
||||
import re
|
||||
from StringIO import StringIO
|
||||
|
||||
|
||||
|
@ -52,7 +49,7 @@ class RequestMixin(object):
|
|||
|
||||
"""
|
||||
The set of request parsers that the view can handle.
|
||||
|
||||
|
||||
Should be a tuple/list of classes as described in the :mod:`parsers` module.
|
||||
"""
|
||||
parsers = ()
|
||||
|
@ -158,7 +155,7 @@ class RequestMixin(object):
|
|||
# We only need to use form overloading on form POST requests.
|
||||
if not self._USE_FORM_OVERLOADING or self._method != 'POST' or not is_form_media_type(self._content_type):
|
||||
return
|
||||
|
||||
|
||||
# At this point we're committed to parsing the request as form data.
|
||||
self._data = data = self.request.POST.copy()
|
||||
self._files = self.request.FILES
|
||||
|
@ -203,12 +200,12 @@ class RequestMixin(object):
|
|||
"""
|
||||
return [parser.media_type for parser in self.parsers]
|
||||
|
||||
|
||||
|
||||
@property
|
||||
def _default_parser(self):
|
||||
"""
|
||||
Return the view's default parser class.
|
||||
"""
|
||||
"""
|
||||
return self.parsers[0]
|
||||
|
||||
|
||||
|
@ -218,7 +215,7 @@ class RequestMixin(object):
|
|||
class ResponseMixin(object):
|
||||
"""
|
||||
Adds behavior for pluggable `Renderers` to a :class:`views.View` class.
|
||||
|
||||
|
||||
Default behavior is to use standard HTTP Accept header content negotiation.
|
||||
Also supports overriding the content type by specifying an ``_accept=`` parameter in the URL.
|
||||
Ignores Accept headers from Internet Explorer user agents and uses a sensible browser Accept header instead.
|
||||
|
@ -229,8 +226,8 @@ class ResponseMixin(object):
|
|||
|
||||
"""
|
||||
The set of response renderers that the view can handle.
|
||||
|
||||
Should be a tuple/list of classes as described in the :mod:`renderers` module.
|
||||
|
||||
Should be a tuple/list of classes as described in the :mod:`renderers` module.
|
||||
"""
|
||||
renderers = ()
|
||||
|
||||
|
@ -253,7 +250,7 @@ class ResponseMixin(object):
|
|||
# Set the media type of the response
|
||||
# Note that the renderer *could* override it in .render() if required.
|
||||
response.media_type = renderer.media_type
|
||||
|
||||
|
||||
# Serialize the response content
|
||||
if response.has_content_body:
|
||||
content = renderer.render(response.cleaned_content, media_type)
|
||||
|
@ -317,7 +314,7 @@ class ResponseMixin(object):
|
|||
Return an list of all the media types that this view can render.
|
||||
"""
|
||||
return [renderer.media_type for renderer in self.renderers]
|
||||
|
||||
|
||||
@property
|
||||
def _rendered_formats(self):
|
||||
"""
|
||||
|
@ -339,18 +336,18 @@ class AuthMixin(object):
|
|||
"""
|
||||
Simple :class:`mixin` class to add authentication and permission checking to a :class:`View` class.
|
||||
"""
|
||||
|
||||
|
||||
"""
|
||||
The set of authentication types that this view can handle.
|
||||
|
||||
Should be a tuple/list of classes as described in the :mod:`authentication` module.
|
||||
|
||||
Should be a tuple/list of classes as described in the :mod:`authentication` module.
|
||||
"""
|
||||
authentication = ()
|
||||
|
||||
"""
|
||||
The set of permissions that will be enforced on this view.
|
||||
|
||||
Should be a tuple/list of classes as described in the :mod:`permissions` module.
|
||||
|
||||
Should be a tuple/list of classes as described in the :mod:`permissions` module.
|
||||
"""
|
||||
permissions = ()
|
||||
|
||||
|
@ -359,7 +356,7 @@ class AuthMixin(object):
|
|||
def user(self):
|
||||
"""
|
||||
Returns the :obj:`user` for the current request, as determined by the set of
|
||||
:class:`authentication` classes applied to the :class:`View`.
|
||||
:class:`authentication` classes applied to the :class:`View`.
|
||||
"""
|
||||
if not hasattr(self, '_user'):
|
||||
self._user = self._authenticate()
|
||||
|
@ -541,13 +538,13 @@ class CreateModelMixin(object):
|
|||
|
||||
for fieldname in m2m_data:
|
||||
manager = getattr(instance, fieldname)
|
||||
|
||||
|
||||
if hasattr(manager, 'add'):
|
||||
manager.add(*m2m_data[fieldname][1])
|
||||
else:
|
||||
data = {}
|
||||
data[manager.source_field_name] = instance
|
||||
|
||||
|
||||
for related_item in m2m_data[fieldname][1]:
|
||||
data[m2m_data[fieldname][0]] = related_item
|
||||
manager.through(**data).save()
|
||||
|
@ -564,8 +561,8 @@ class UpdateModelMixin(object):
|
|||
"""
|
||||
def put(self, request, *args, **kwargs):
|
||||
model = self.resource.model
|
||||
|
||||
# TODO: update on the url of a non-existing resource url doesn't work correctly at the moment - will end up with a new url
|
||||
|
||||
# TODO: update on the url of a non-existing resource url doesn't work correctly at the moment - will end up with a new url
|
||||
try:
|
||||
if args:
|
||||
# If we have any none kwargs then assume the last represents the primary key
|
||||
|
@ -640,3 +637,93 @@ class ListModelMixin(object):
|
|||
return queryset.filter(**kwargs)
|
||||
|
||||
|
||||
########## Pagination Mixins ##########
|
||||
|
||||
class PaginatorMixin(object):
|
||||
"""
|
||||
Adds pagination support to GET requests
|
||||
Obviously should only be used on lists :)
|
||||
|
||||
A default limit can be set by setting `limit` on the object. This will also
|
||||
be used as the maximum if the client sets the `limit` GET param
|
||||
"""
|
||||
limit = 20
|
||||
|
||||
def get_limit(self):
|
||||
""" Helper method to determine what the `limit` should be """
|
||||
try:
|
||||
limit = int(self.request.GET.get('limit', self.limit))
|
||||
return min(limit, self.limit)
|
||||
except ValueError:
|
||||
return self.limit
|
||||
|
||||
def url_with_page_number(self, page_number):
|
||||
""" Constructs a url used for getting the next/previous urls """
|
||||
url = "%s?page=%d" % (self.request.path, page_number)
|
||||
|
||||
limit = self.get_limit()
|
||||
if limit != self.limit:
|
||||
url = "%s&limit=%d" % (url, limit)
|
||||
|
||||
return url
|
||||
|
||||
def next(self, page):
|
||||
""" Returns a url to the next page of results (if any) """
|
||||
if not page.has_next():
|
||||
return None
|
||||
|
||||
return self.url_with_page_number(page.next_page_number())
|
||||
|
||||
def previous(self, page):
|
||||
""" Returns a url to the previous page of results (if any) """
|
||||
if not page.has_previous():
|
||||
return None
|
||||
|
||||
return self.url_with_page_number(page.previous_page_number())
|
||||
|
||||
def serialize_page_info(self, page):
|
||||
""" This is some useful information that is added to the response """
|
||||
return {
|
||||
'next': self.next(page),
|
||||
'page': page.number,
|
||||
'pages': page.paginator.num_pages,
|
||||
'per_page': self.get_limit(),
|
||||
'previous': self.previous(page),
|
||||
'total': page.paginator.count,
|
||||
}
|
||||
|
||||
def filter_response(self, obj):
|
||||
"""
|
||||
Given the response content, paginate and then serialize.
|
||||
|
||||
The response is modified to include to useful data relating to the number
|
||||
of objects, number of pages, next/previous urls etc. etc.
|
||||
|
||||
The serialised objects are put into `results` on this new, modified
|
||||
response
|
||||
"""
|
||||
|
||||
# We don't want to paginate responses for anything other than GET requests
|
||||
if self.method.upper() != 'GET':
|
||||
return self._resource.filter_response(obj)
|
||||
|
||||
paginator = Paginator(obj, self.get_limit())
|
||||
|
||||
try:
|
||||
page_num = int(self.request.GET.get('page', '1'))
|
||||
except ValueError:
|
||||
raise ErrorResponse(status.HTTP_404_NOT_FOUND,
|
||||
{'detail': 'That page contains no results'})
|
||||
|
||||
if page_num not in paginator.page_range:
|
||||
raise ErrorResponse(status.HTTP_404_NOT_FOUND,
|
||||
{'detail': 'That page contains no results'})
|
||||
|
||||
page = paginator.page(page_num)
|
||||
|
||||
serialized_object_list = self._resource.filter_response(page.object_list)
|
||||
serialized_page_info = self.serialize_page_info(page)
|
||||
|
||||
serialized_page_info['results'] = serialized_object_list
|
||||
|
||||
return serialized_page_info
|
||||
|
|
|
@ -1,14 +1,17 @@
|
|||
"""Tests for the status module"""
|
||||
"""Tests for the mixin module"""
|
||||
from django.test import TestCase
|
||||
from django.utils import simplejson as json
|
||||
from djangorestframework import status
|
||||
from djangorestframework.compat import RequestFactory
|
||||
from django.contrib.auth.models import Group, User
|
||||
from djangorestframework.mixins import CreateModelMixin
|
||||
from djangorestframework.mixins import CreateModelMixin, PaginatorMixin
|
||||
from djangorestframework.resources import ModelResource
|
||||
from djangorestframework.response import Response
|
||||
from djangorestframework.tests.models import CustomUser
|
||||
from djangorestframework.views import View
|
||||
|
||||
|
||||
class TestModelCreation(TestCase):
|
||||
class TestModelCreation(TestCase):
|
||||
"""Tests on CreateModelMixin"""
|
||||
|
||||
def setUp(self):
|
||||
|
@ -25,23 +28,26 @@ class TestModelCreation(TestCase):
|
|||
mixin = CreateModelMixin()
|
||||
mixin.resource = GroupResource
|
||||
mixin.CONTENT = form_data
|
||||
|
||||
|
||||
response = mixin.post(request)
|
||||
self.assertEquals(1, Group.objects.count())
|
||||
self.assertEquals('foo', response.cleaned_content.name)
|
||||
|
||||
|
||||
def test_creation_with_m2m_relation(self):
|
||||
class UserResource(ModelResource):
|
||||
model = User
|
||||
|
||||
|
||||
def url(self, instance):
|
||||
return "/users/%i" % instance.id
|
||||
|
||||
group = Group(name='foo')
|
||||
group.save()
|
||||
|
||||
form_data = {'username': 'bar', 'password': 'baz', 'groups': [group.id]}
|
||||
form_data = {
|
||||
'username': 'bar',
|
||||
'password': 'baz',
|
||||
'groups': [group.id]
|
||||
}
|
||||
request = self.req.post('/groups', data=form_data)
|
||||
cleaned_data = dict(form_data)
|
||||
cleaned_data['groups'] = [group]
|
||||
|
@ -53,18 +59,18 @@ class TestModelCreation(TestCase):
|
|||
self.assertEquals(1, User.objects.count())
|
||||
self.assertEquals(1, response.cleaned_content.groups.count())
|
||||
self.assertEquals('foo', response.cleaned_content.groups.all()[0].name)
|
||||
|
||||
|
||||
def test_creation_with_m2m_relation_through(self):
|
||||
"""
|
||||
Tests creation where the m2m relation uses a through table
|
||||
"""
|
||||
class UserResource(ModelResource):
|
||||
model = CustomUser
|
||||
|
||||
|
||||
def url(self, instance):
|
||||
return "/customusers/%i" % instance.id
|
||||
|
||||
form_data = {'username': 'bar0', 'groups': []}
|
||||
|
||||
form_data = {'username': 'bar0', 'groups': []}
|
||||
request = self.req.post('/groups', data=form_data)
|
||||
cleaned_data = dict(form_data)
|
||||
cleaned_data['groups'] = []
|
||||
|
@ -74,12 +80,12 @@ class TestModelCreation(TestCase):
|
|||
|
||||
response = mixin.post(request)
|
||||
self.assertEquals(1, CustomUser.objects.count())
|
||||
self.assertEquals(0, response.cleaned_content.groups.count())
|
||||
self.assertEquals(0, response.cleaned_content.groups.count())
|
||||
|
||||
group = Group(name='foo1')
|
||||
group.save()
|
||||
|
||||
form_data = {'username': 'bar1', 'groups': [group.id]}
|
||||
form_data = {'username': 'bar1', 'groups': [group.id]}
|
||||
request = self.req.post('/groups', data=form_data)
|
||||
cleaned_data = dict(form_data)
|
||||
cleaned_data['groups'] = [group]
|
||||
|
@ -91,12 +97,11 @@ class TestModelCreation(TestCase):
|
|||
self.assertEquals(2, CustomUser.objects.count())
|
||||
self.assertEquals(1, response.cleaned_content.groups.count())
|
||||
self.assertEquals('foo1', response.cleaned_content.groups.all()[0].name)
|
||||
|
||||
|
||||
|
||||
group2 = Group(name='foo2')
|
||||
group2.save()
|
||||
|
||||
form_data = {'username': 'bar2', 'groups': [group.id, group2.id]}
|
||||
group2.save()
|
||||
|
||||
form_data = {'username': 'bar2', 'groups': [group.id, group2.id]}
|
||||
request = self.req.post('/groups', data=form_data)
|
||||
cleaned_data = dict(form_data)
|
||||
cleaned_data['groups'] = [group, group2]
|
||||
|
@ -109,5 +114,124 @@ class TestModelCreation(TestCase):
|
|||
self.assertEquals(2, response.cleaned_content.groups.count())
|
||||
self.assertEquals('foo1', response.cleaned_content.groups.all()[0].name)
|
||||
self.assertEquals('foo2', response.cleaned_content.groups.all()[1].name)
|
||||
|
||||
|
||||
|
||||
class MockPaginatorView(PaginatorMixin, View):
|
||||
total = 60
|
||||
|
||||
def get(self, request):
|
||||
return range(0, self.total)
|
||||
|
||||
def post(self, request):
|
||||
return Response(status.CREATED, {'status': 'OK'})
|
||||
|
||||
|
||||
class TestPagination(TestCase):
|
||||
def setUp(self):
|
||||
self.req = RequestFactory()
|
||||
|
||||
def test_default_limit(self):
|
||||
""" Tests if pagination works without overwriting the limit """
|
||||
request = self.req.get('/paginator')
|
||||
response = MockPaginatorView.as_view()(request)
|
||||
|
||||
content = json.loads(response.content)
|
||||
|
||||
self.assertEqual(response.status_code, status.OK)
|
||||
self.assertEqual(MockPaginatorView.total, content['total'])
|
||||
self.assertEqual(MockPaginatorView.limit, content['per_page'])
|
||||
|
||||
self.assertEqual(range(0, MockPaginatorView.limit), content['results'])
|
||||
|
||||
def test_overwriting_limit(self):
|
||||
""" Tests if the limit can be overwritten """
|
||||
limit = 10
|
||||
|
||||
request = self.req.get('/paginator')
|
||||
response = MockPaginatorView.as_view(limit=limit)(request)
|
||||
|
||||
content = json.loads(response.content)
|
||||
|
||||
self.assertEqual(response.status_code, status.OK)
|
||||
self.assertEqual(content['per_page'], limit)
|
||||
|
||||
self.assertEqual(range(0, limit), content['results'])
|
||||
|
||||
def test_limit_param(self):
|
||||
""" Tests if the client can set the limit """
|
||||
from math import ceil
|
||||
|
||||
limit = 5
|
||||
num_pages = int(ceil(MockPaginatorView.total / float(limit)))
|
||||
|
||||
request = self.req.get('/paginator/?limit=%d' % limit)
|
||||
response = MockPaginatorView.as_view()(request)
|
||||
|
||||
content = json.loads(response.content)
|
||||
|
||||
self.assertEqual(response.status_code, status.OK)
|
||||
self.assertEqual(MockPaginatorView.total, content['total'])
|
||||
self.assertEqual(limit, content['per_page'])
|
||||
self.assertEqual(num_pages, content['pages'])
|
||||
|
||||
def test_exceeding_limit(self):
|
||||
""" Makes sure the client cannot exceed the default limit """
|
||||
from math import ceil
|
||||
|
||||
limit = MockPaginatorView.limit + 10
|
||||
num_pages = int(ceil(MockPaginatorView.total / float(limit)))
|
||||
|
||||
request = self.req.get('/paginator/?limit=%d' % limit)
|
||||
response = MockPaginatorView.as_view()(request)
|
||||
|
||||
content = json.loads(response.content)
|
||||
|
||||
self.assertEqual(response.status_code, status.OK)
|
||||
self.assertEqual(MockPaginatorView.total, content['total'])
|
||||
self.assertNotEqual(limit, content['per_page'])
|
||||
self.assertNotEqual(num_pages, content['pages'])
|
||||
self.assertEqual(MockPaginatorView.limit, content['per_page'])
|
||||
|
||||
def test_only_works_for_get(self):
|
||||
""" Pagination should only work for GET requests """
|
||||
request = self.req.post('/paginator', data={'content': 'spam'})
|
||||
response = MockPaginatorView.as_view()(request)
|
||||
|
||||
content = json.loads(response.content)
|
||||
|
||||
self.assertEqual(response.status_code, status.CREATED)
|
||||
self.assertEqual(None, content.get('per_page'))
|
||||
self.assertEqual('OK', content['status'])
|
||||
|
||||
def test_non_int_page(self):
|
||||
""" Tests that it can handle invalid values """
|
||||
request = self.req.get('/paginator/?page=spam')
|
||||
response = MockPaginatorView.as_view()(request)
|
||||
|
||||
self.assertEqual(response.status_code, status.NOT_FOUND)
|
||||
|
||||
def test_page_range(self):
|
||||
""" Tests that the page range is handle correctly """
|
||||
request = self.req.get('/paginator/?page=0')
|
||||
response = MockPaginatorView.as_view()(request)
|
||||
content = json.loads(response.content)
|
||||
self.assertEqual(response.status_code, status.NOT_FOUND)
|
||||
|
||||
request = self.req.get('/paginator/')
|
||||
response = MockPaginatorView.as_view()(request)
|
||||
content = json.loads(response.content)
|
||||
self.assertEqual(response.status_code, status.OK)
|
||||
self.assertEqual(range(0, MockPaginatorView.limit), content['results'])
|
||||
|
||||
num_pages = content['pages']
|
||||
|
||||
request = self.req.get('/paginator/?page=%d' % num_pages)
|
||||
response = MockPaginatorView.as_view()(request)
|
||||
content = json.loads(response.content)
|
||||
self.assertEqual(response.status_code, status.OK)
|
||||
self.assertEqual(range(MockPaginatorView.limit*(num_pages-1), MockPaginatorView.total), content['results'])
|
||||
|
||||
request = self.req.get('/paginator/?page=%d' % (num_pages + 1,))
|
||||
response = MockPaginatorView.as_view()(request)
|
||||
content = json.loads(response.content)
|
||||
self.assertEqual(response.status_code, status.NOT_FOUND)
|
||||
|
|
Loading…
Reference in New Issue
Block a user