mirror of
				https://github.com/encode/django-rest-framework.git
				synced 2025-11-01 00:17:40 +03:00 
			
		
		
		
	Add pagination. Thanks @devioustree!
This commit is contained in:
		
							parent
							
								
									42cdd00591
								
							
						
					
					
						commit
						5db422c9d3
					
				|  | @ -1,23 +1,20 @@ | |||
| """ | ||||
| The :mod:`mixins` module provides a set of reusable `mixin`  | ||||
| The :mod:`mixins` module provides a set of reusable `mixin` | ||||
| classes that can be added to a `View`. | ||||
| """ | ||||
| 
 | ||||
| from django.contrib.auth.models import AnonymousUser | ||||
| from django.db.models.query import QuerySet | ||||
| from django.core.paginator import Paginator | ||||
| from django.db.models.fields.related import ForeignKey | ||||
| from django.http import HttpResponse | ||||
| 
 | ||||
| from djangorestframework import status | ||||
| from djangorestframework.parsers import FormParser, MultiPartParser | ||||
| from djangorestframework.renderers import BaseRenderer | ||||
| from djangorestframework.resources import Resource, FormResource, ModelResource | ||||
| from djangorestframework.response import Response, ErrorResponse | ||||
| from djangorestframework.utils import as_tuple, MSIE_USER_AGENT_REGEX | ||||
| from djangorestframework.utils.mediatypes import is_form_media_type, order_by_precedence | ||||
| 
 | ||||
| from decimal import Decimal | ||||
| import re | ||||
| from StringIO import StringIO | ||||
| 
 | ||||
| 
 | ||||
|  | @ -52,7 +49,7 @@ class RequestMixin(object): | |||
| 
 | ||||
|     """ | ||||
|     The set of request parsers that the view can handle. | ||||
|      | ||||
| 
 | ||||
|     Should be a tuple/list of classes as described in the :mod:`parsers` module. | ||||
|     """ | ||||
|     parsers = () | ||||
|  | @ -158,7 +155,7 @@ class RequestMixin(object): | |||
|         # We only need to use form overloading on form POST requests. | ||||
|         if not self._USE_FORM_OVERLOADING or self._method != 'POST' or not is_form_media_type(self._content_type): | ||||
|             return | ||||
|          | ||||
| 
 | ||||
|         # At this point we're committed to parsing the request as form data. | ||||
|         self._data = data = self.request.POST.copy() | ||||
|         self._files = self.request.FILES | ||||
|  | @ -203,12 +200,12 @@ class RequestMixin(object): | |||
|         """ | ||||
|         return [parser.media_type for parser in self.parsers] | ||||
| 
 | ||||
|      | ||||
| 
 | ||||
|     @property | ||||
|     def _default_parser(self): | ||||
|         """ | ||||
|         Return the view's default parser class. | ||||
|         """         | ||||
|         """ | ||||
|         return self.parsers[0] | ||||
| 
 | ||||
| 
 | ||||
|  | @ -218,7 +215,7 @@ class RequestMixin(object): | |||
| class ResponseMixin(object): | ||||
|     """ | ||||
|     Adds behavior for pluggable `Renderers` to a :class:`views.View` class. | ||||
|      | ||||
| 
 | ||||
|     Default behavior is to use standard HTTP Accept header content negotiation. | ||||
|     Also supports overriding the content type by specifying an ``_accept=`` parameter in the URL. | ||||
|     Ignores Accept headers from Internet Explorer user agents and uses a sensible browser Accept header instead. | ||||
|  | @ -229,8 +226,8 @@ class ResponseMixin(object): | |||
| 
 | ||||
|     """ | ||||
|     The set of response renderers that the view can handle. | ||||
|      | ||||
|     Should be a tuple/list of classes as described in the :mod:`renderers` module.     | ||||
| 
 | ||||
|     Should be a tuple/list of classes as described in the :mod:`renderers` module. | ||||
|     """ | ||||
|     renderers = () | ||||
| 
 | ||||
|  | @ -253,7 +250,7 @@ class ResponseMixin(object): | |||
|         # Set the media type of the response | ||||
|         # Note that the renderer *could* override it in .render() if required. | ||||
|         response.media_type = renderer.media_type | ||||
|          | ||||
| 
 | ||||
|         # Serialize the response content | ||||
|         if response.has_content_body: | ||||
|             content = renderer.render(response.cleaned_content, media_type) | ||||
|  | @ -317,7 +314,7 @@ class ResponseMixin(object): | |||
|         Return an list of all the media types that this view can render. | ||||
|         """ | ||||
|         return [renderer.media_type for renderer in self.renderers] | ||||
|      | ||||
| 
 | ||||
|     @property | ||||
|     def _rendered_formats(self): | ||||
|         """ | ||||
|  | @ -339,18 +336,18 @@ class AuthMixin(object): | |||
|     """ | ||||
|     Simple :class:`mixin` class to add authentication and permission checking to a :class:`View` class. | ||||
|     """ | ||||
|      | ||||
| 
 | ||||
|     """ | ||||
|     The set of authentication types that this view can handle. | ||||
|     | ||||
|     Should be a tuple/list of classes as described in the :mod:`authentication` module.     | ||||
| 
 | ||||
|     Should be a tuple/list of classes as described in the :mod:`authentication` module. | ||||
|     """ | ||||
|     authentication = () | ||||
| 
 | ||||
|     """ | ||||
|     The set of permissions that will be enforced on this view. | ||||
|      | ||||
|     Should be a tuple/list of classes as described in the :mod:`permissions` module.     | ||||
| 
 | ||||
|     Should be a tuple/list of classes as described in the :mod:`permissions` module. | ||||
|     """ | ||||
|     permissions = () | ||||
| 
 | ||||
|  | @ -359,7 +356,7 @@ class AuthMixin(object): | |||
|     def user(self): | ||||
|         """ | ||||
|         Returns the :obj:`user` for the current request, as determined by the set of | ||||
|         :class:`authentication` classes applied to the :class:`View`.   | ||||
|         :class:`authentication` classes applied to the :class:`View`. | ||||
|         """ | ||||
|         if not hasattr(self, '_user'): | ||||
|             self._user = self._authenticate() | ||||
|  | @ -541,13 +538,13 @@ class CreateModelMixin(object): | |||
| 
 | ||||
|         for fieldname in m2m_data: | ||||
|             manager = getattr(instance, fieldname) | ||||
|              | ||||
| 
 | ||||
|             if hasattr(manager, 'add'): | ||||
|                 manager.add(*m2m_data[fieldname][1]) | ||||
|             else: | ||||
|                 data = {} | ||||
|                 data[manager.source_field_name] = instance | ||||
|                  | ||||
| 
 | ||||
|                 for related_item in m2m_data[fieldname][1]: | ||||
|                     data[m2m_data[fieldname][0]] = related_item | ||||
|                     manager.through(**data).save() | ||||
|  | @ -564,8 +561,8 @@ class UpdateModelMixin(object): | |||
|     """ | ||||
|     def put(self, request, *args, **kwargs): | ||||
|         model = self.resource.model | ||||
|          | ||||
|         # TODO: update on the url of a non-existing resource url doesn't work correctly at the moment - will end up with a new url  | ||||
| 
 | ||||
|         # TODO: update on the url of a non-existing resource url doesn't work correctly at the moment - will end up with a new url | ||||
|         try: | ||||
|             if args: | ||||
|                 # If we have any none kwargs then assume the last represents the primary key | ||||
|  | @ -640,3 +637,93 @@ class ListModelMixin(object): | |||
|         return queryset.filter(**kwargs) | ||||
| 
 | ||||
| 
 | ||||
| ########## Pagination Mixins ########## | ||||
| 
 | ||||
| class PaginatorMixin(object): | ||||
|     """ | ||||
|     Adds pagination support to GET requests | ||||
|     Obviously should only be used on lists :) | ||||
| 
 | ||||
|     A default limit can be set by setting `limit` on the object. This will also | ||||
|     be used as the maximum if the client sets the `limit` GET param | ||||
|     """ | ||||
|     limit = 20 | ||||
| 
 | ||||
|     def get_limit(self): | ||||
|         """ Helper method to determine what the `limit` should be """ | ||||
|         try: | ||||
|             limit = int(self.request.GET.get('limit', self.limit)) | ||||
|             return min(limit, self.limit) | ||||
|         except ValueError: | ||||
|             return self.limit | ||||
| 
 | ||||
|     def url_with_page_number(self, page_number): | ||||
|         """ Constructs a url used for getting the next/previous urls """ | ||||
|         url = "%s?page=%d" % (self.request.path, page_number) | ||||
| 
 | ||||
|         limit = self.get_limit() | ||||
|         if limit != self.limit: | ||||
|             url = "%s&limit=%d" % (url, limit) | ||||
| 
 | ||||
|         return url | ||||
| 
 | ||||
|     def next(self, page): | ||||
|         """ Returns a url to the next page of results (if any) """ | ||||
|         if not page.has_next(): | ||||
|             return None | ||||
| 
 | ||||
|         return self.url_with_page_number(page.next_page_number()) | ||||
| 
 | ||||
|     def previous(self, page): | ||||
|         """ Returns a url to the previous page of results (if any) """ | ||||
|         if not page.has_previous(): | ||||
|             return None | ||||
| 
 | ||||
|         return self.url_with_page_number(page.previous_page_number()) | ||||
| 
 | ||||
|     def serialize_page_info(self, page): | ||||
|         """ This is some useful information that is added to the response """ | ||||
|         return { | ||||
|             'next': self.next(page), | ||||
|             'page': page.number, | ||||
|             'pages': page.paginator.num_pages, | ||||
|             'per_page': self.get_limit(), | ||||
|             'previous': self.previous(page), | ||||
|             'total': page.paginator.count, | ||||
|         } | ||||
| 
 | ||||
|     def filter_response(self, obj): | ||||
|         """ | ||||
|         Given the response content, paginate and then serialize. | ||||
| 
 | ||||
|         The response is modified to include to useful data relating to the number | ||||
|         of objects, number of pages, next/previous urls etc. etc. | ||||
| 
 | ||||
|         The serialised objects are put into `results` on this new, modified | ||||
|         response | ||||
|         """ | ||||
| 
 | ||||
|         # We don't want to paginate responses for anything other than GET requests | ||||
|         if self.method.upper() != 'GET': | ||||
|             return self._resource.filter_response(obj) | ||||
| 
 | ||||
|         paginator = Paginator(obj, self.get_limit()) | ||||
| 
 | ||||
|         try: | ||||
|             page_num = int(self.request.GET.get('page', '1')) | ||||
|         except ValueError: | ||||
|             raise ErrorResponse(status.HTTP_404_NOT_FOUND, | ||||
|                                 {'detail': 'That page contains no results'}) | ||||
| 
 | ||||
|         if page_num not in paginator.page_range: | ||||
|             raise ErrorResponse(status.HTTP_404_NOT_FOUND, | ||||
|                                 {'detail': 'That page contains no results'}) | ||||
| 
 | ||||
|         page = paginator.page(page_num) | ||||
| 
 | ||||
|         serialized_object_list = self._resource.filter_response(page.object_list) | ||||
|         serialized_page_info = self.serialize_page_info(page) | ||||
| 
 | ||||
|         serialized_page_info['results'] = serialized_object_list | ||||
| 
 | ||||
|         return serialized_page_info | ||||
|  |  | |||
|  | @ -1,14 +1,17 @@ | |||
| """Tests for the status module""" | ||||
| """Tests for the mixin module""" | ||||
| from django.test import TestCase | ||||
| from django.utils import simplejson as json | ||||
| from djangorestframework import status | ||||
| from djangorestframework.compat import RequestFactory | ||||
| from django.contrib.auth.models import Group, User | ||||
| from djangorestframework.mixins import CreateModelMixin | ||||
| from djangorestframework.mixins import CreateModelMixin, PaginatorMixin | ||||
| from djangorestframework.resources import ModelResource | ||||
| from djangorestframework.response import Response | ||||
| from djangorestframework.tests.models import CustomUser | ||||
| from djangorestframework.views import View | ||||
| 
 | ||||
| 
 | ||||
| class TestModelCreation(TestCase):  | ||||
| class TestModelCreation(TestCase): | ||||
|     """Tests on CreateModelMixin""" | ||||
| 
 | ||||
|     def setUp(self): | ||||
|  | @ -25,23 +28,26 @@ class TestModelCreation(TestCase): | |||
|         mixin = CreateModelMixin() | ||||
|         mixin.resource = GroupResource | ||||
|         mixin.CONTENT = form_data | ||||
|          | ||||
| 
 | ||||
|         response = mixin.post(request) | ||||
|         self.assertEquals(1, Group.objects.count()) | ||||
|         self.assertEquals('foo', response.cleaned_content.name) | ||||
|          | ||||
| 
 | ||||
|     def test_creation_with_m2m_relation(self): | ||||
|         class UserResource(ModelResource): | ||||
|             model = User | ||||
|     | ||||
| 
 | ||||
|             def url(self, instance): | ||||
|                 return "/users/%i" % instance.id | ||||
| 
 | ||||
|         group = Group(name='foo') | ||||
|         group.save() | ||||
| 
 | ||||
|         form_data = {'username': 'bar', 'password': 'baz', 'groups': [group.id]}         | ||||
|         form_data = { | ||||
|             'username': 'bar', | ||||
|             'password': 'baz', | ||||
|             'groups': [group.id] | ||||
|         } | ||||
|         request = self.req.post('/groups', data=form_data) | ||||
|         cleaned_data = dict(form_data) | ||||
|         cleaned_data['groups'] = [group] | ||||
|  | @ -53,18 +59,18 @@ class TestModelCreation(TestCase): | |||
|         self.assertEquals(1, User.objects.count()) | ||||
|         self.assertEquals(1, response.cleaned_content.groups.count()) | ||||
|         self.assertEquals('foo', response.cleaned_content.groups.all()[0].name) | ||||
|          | ||||
| 
 | ||||
|     def test_creation_with_m2m_relation_through(self): | ||||
|         """ | ||||
|         Tests creation where the m2m relation uses a through table | ||||
|         """ | ||||
|         class UserResource(ModelResource): | ||||
|             model = CustomUser | ||||
|     | ||||
| 
 | ||||
|             def url(self, instance): | ||||
|                 return "/customusers/%i" % instance.id | ||||
|              | ||||
|         form_data = {'username': 'bar0', 'groups': []}         | ||||
| 
 | ||||
|         form_data = {'username': 'bar0', 'groups': []} | ||||
|         request = self.req.post('/groups', data=form_data) | ||||
|         cleaned_data = dict(form_data) | ||||
|         cleaned_data['groups'] = [] | ||||
|  | @ -74,12 +80,12 @@ class TestModelCreation(TestCase): | |||
| 
 | ||||
|         response = mixin.post(request) | ||||
|         self.assertEquals(1, CustomUser.objects.count()) | ||||
|         self.assertEquals(0, response.cleaned_content.groups.count())             | ||||
|         self.assertEquals(0, response.cleaned_content.groups.count()) | ||||
| 
 | ||||
|         group = Group(name='foo1') | ||||
|         group.save() | ||||
| 
 | ||||
|         form_data = {'username': 'bar1', 'groups': [group.id]}         | ||||
|         form_data = {'username': 'bar1', 'groups': [group.id]} | ||||
|         request = self.req.post('/groups', data=form_data) | ||||
|         cleaned_data = dict(form_data) | ||||
|         cleaned_data['groups'] = [group] | ||||
|  | @ -91,12 +97,11 @@ class TestModelCreation(TestCase): | |||
|         self.assertEquals(2, CustomUser.objects.count()) | ||||
|         self.assertEquals(1, response.cleaned_content.groups.count()) | ||||
|         self.assertEquals('foo1', response.cleaned_content.groups.all()[0].name) | ||||
|          | ||||
|          | ||||
| 
 | ||||
|         group2 = Group(name='foo2') | ||||
|         group2.save()         | ||||
|          | ||||
|         form_data = {'username': 'bar2', 'groups': [group.id, group2.id]}         | ||||
|         group2.save() | ||||
| 
 | ||||
|         form_data = {'username': 'bar2', 'groups': [group.id, group2.id]} | ||||
|         request = self.req.post('/groups', data=form_data) | ||||
|         cleaned_data = dict(form_data) | ||||
|         cleaned_data['groups'] = [group, group2] | ||||
|  | @ -109,5 +114,124 @@ class TestModelCreation(TestCase): | |||
|         self.assertEquals(2, response.cleaned_content.groups.count()) | ||||
|         self.assertEquals('foo1', response.cleaned_content.groups.all()[0].name) | ||||
|         self.assertEquals('foo2', response.cleaned_content.groups.all()[1].name) | ||||
|          | ||||
| 
 | ||||
| 
 | ||||
| class MockPaginatorView(PaginatorMixin, View): | ||||
|     total = 60 | ||||
| 
 | ||||
|     def get(self, request): | ||||
|         return range(0, self.total) | ||||
| 
 | ||||
|     def post(self, request): | ||||
|         return Response(status.CREATED, {'status': 'OK'}) | ||||
| 
 | ||||
| 
 | ||||
| class TestPagination(TestCase): | ||||
|     def setUp(self): | ||||
|         self.req = RequestFactory() | ||||
| 
 | ||||
|     def test_default_limit(self): | ||||
|         """ Tests if pagination works without overwriting the limit """ | ||||
|         request = self.req.get('/paginator') | ||||
|         response = MockPaginatorView.as_view()(request) | ||||
| 
 | ||||
|         content = json.loads(response.content) | ||||
| 
 | ||||
|         self.assertEqual(response.status_code, status.OK) | ||||
|         self.assertEqual(MockPaginatorView.total, content['total']) | ||||
|         self.assertEqual(MockPaginatorView.limit, content['per_page']) | ||||
| 
 | ||||
|         self.assertEqual(range(0, MockPaginatorView.limit), content['results']) | ||||
| 
 | ||||
|     def test_overwriting_limit(self): | ||||
|         """ Tests if the limit can be overwritten """ | ||||
|         limit = 10 | ||||
| 
 | ||||
|         request = self.req.get('/paginator') | ||||
|         response = MockPaginatorView.as_view(limit=limit)(request) | ||||
| 
 | ||||
|         content = json.loads(response.content) | ||||
| 
 | ||||
|         self.assertEqual(response.status_code, status.OK) | ||||
|         self.assertEqual(content['per_page'], limit) | ||||
| 
 | ||||
|         self.assertEqual(range(0, limit), content['results']) | ||||
| 
 | ||||
|     def test_limit_param(self): | ||||
|         """ Tests if the client can set the limit """ | ||||
|         from math import ceil | ||||
| 
 | ||||
|         limit = 5 | ||||
|         num_pages = int(ceil(MockPaginatorView.total / float(limit))) | ||||
| 
 | ||||
|         request = self.req.get('/paginator/?limit=%d' % limit) | ||||
|         response = MockPaginatorView.as_view()(request) | ||||
| 
 | ||||
|         content = json.loads(response.content) | ||||
| 
 | ||||
|         self.assertEqual(response.status_code, status.OK) | ||||
|         self.assertEqual(MockPaginatorView.total, content['total']) | ||||
|         self.assertEqual(limit, content['per_page']) | ||||
|         self.assertEqual(num_pages, content['pages']) | ||||
| 
 | ||||
|     def test_exceeding_limit(self): | ||||
|         """ Makes sure the client cannot exceed the default limit """ | ||||
|         from math import ceil | ||||
| 
 | ||||
|         limit = MockPaginatorView.limit + 10 | ||||
|         num_pages = int(ceil(MockPaginatorView.total / float(limit))) | ||||
| 
 | ||||
|         request = self.req.get('/paginator/?limit=%d' % limit) | ||||
|         response = MockPaginatorView.as_view()(request) | ||||
| 
 | ||||
|         content = json.loads(response.content) | ||||
| 
 | ||||
|         self.assertEqual(response.status_code, status.OK) | ||||
|         self.assertEqual(MockPaginatorView.total, content['total']) | ||||
|         self.assertNotEqual(limit, content['per_page']) | ||||
|         self.assertNotEqual(num_pages, content['pages']) | ||||
|         self.assertEqual(MockPaginatorView.limit, content['per_page']) | ||||
| 
 | ||||
|     def test_only_works_for_get(self): | ||||
|         """ Pagination should only work for GET requests """ | ||||
|         request = self.req.post('/paginator', data={'content': 'spam'}) | ||||
|         response = MockPaginatorView.as_view()(request) | ||||
| 
 | ||||
|         content = json.loads(response.content) | ||||
| 
 | ||||
|         self.assertEqual(response.status_code, status.CREATED) | ||||
|         self.assertEqual(None, content.get('per_page')) | ||||
|         self.assertEqual('OK', content['status']) | ||||
| 
 | ||||
|     def test_non_int_page(self): | ||||
|         """ Tests that it can handle invalid values """ | ||||
|         request = self.req.get('/paginator/?page=spam') | ||||
|         response = MockPaginatorView.as_view()(request) | ||||
| 
 | ||||
|         self.assertEqual(response.status_code, status.NOT_FOUND) | ||||
| 
 | ||||
|     def test_page_range(self): | ||||
|         """ Tests that the page range is handle correctly """ | ||||
|         request = self.req.get('/paginator/?page=0') | ||||
|         response = MockPaginatorView.as_view()(request) | ||||
|         content = json.loads(response.content) | ||||
|         self.assertEqual(response.status_code, status.NOT_FOUND) | ||||
| 
 | ||||
|         request = self.req.get('/paginator/') | ||||
|         response = MockPaginatorView.as_view()(request) | ||||
|         content = json.loads(response.content) | ||||
|         self.assertEqual(response.status_code, status.OK) | ||||
|         self.assertEqual(range(0, MockPaginatorView.limit), content['results']) | ||||
| 
 | ||||
|         num_pages = content['pages'] | ||||
| 
 | ||||
|         request = self.req.get('/paginator/?page=%d' % num_pages) | ||||
|         response = MockPaginatorView.as_view()(request) | ||||
|         content = json.loads(response.content) | ||||
|         self.assertEqual(response.status_code, status.OK) | ||||
|         self.assertEqual(range(MockPaginatorView.limit*(num_pages-1), MockPaginatorView.total), content['results']) | ||||
| 
 | ||||
|         request = self.req.get('/paginator/?page=%d' % (num_pages + 1,)) | ||||
|         response = MockPaginatorView.as_view()(request) | ||||
|         content = json.loads(response.content) | ||||
|         self.assertEqual(response.status_code, status.NOT_FOUND) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user