From ae73a191f914bcdf63e4cefa3d433b7621d93510 Mon Sep 17 00:00:00 2001 From: David Medina Date: Thu, 10 Oct 2013 13:39:40 +0200 Subject: [PATCH] Add Link and Off&lim pagination tests --- rest_framework/tests/test_pagination.py | 210 +++++++++++++++++++++++- 1 file changed, 209 insertions(+), 1 deletion(-) diff --git a/rest_framework/tests/test_pagination.py b/rest_framework/tests/test_pagination.py index d6bc7895c..d2f2bdf8c 100644 --- a/rest_framework/tests/test_pagination.py +++ b/rest_framework/tests/test_pagination.py @@ -5,7 +5,8 @@ from django.db import models from django.core.paginator import Paginator from django.test import TestCase from django.utils import unittest -from rest_framework import generics, status, pagination, filters, serializers +from rest_framework import (generics, status, pagination, filters, serializers, + mixins) from rest_framework.compat import django_filters from rest_framework.test import APIRequestFactory from rest_framework.tests.models import BasicModel @@ -13,6 +14,48 @@ from rest_framework.tests.models import BasicModel factory = APIRequestFactory() +def parse_header_link(value): + # Got it from python-requests code + # https://github.com/kennethreitz/requests/blob/master/requests/utils.py#L467 + def parse_header_links(value): + """Return a dict of parsed link headers proxies. + + i.e. Link: ; rel=front; type="image/jpeg",; rel=back;type="image/jpeg" + + """ + + links = [] + + replace_chars = " '\"" + + for val in value.split(","): + try: + url, params = val.split(";", 1) + except ValueError: + url, params = val, '' + + link = {} + + link["url"] = url.strip("<> '\"") + + for param in params.split(";"): + try: + key, value = param.split("=") + except ValueError: + break + + link[key.strip(replace_chars)] = value.strip(replace_chars) + + links.append(link) + + return links + links = parse_header_links(value) + link_map = {} + for link in links: + link_map[link['rel']] = link['url'] + return link_map + + class FilterableItem(models.Model): text = models.CharField(max_length=100) decimal = models.DecimalField(max_digits=4, decimal_places=2) @@ -52,6 +95,18 @@ class MaxPaginateByView(generics.ListAPIView): paginate_by_param = 'page_size' +class LinkPaginationView(mixins.LinkPaginationMixin, + generics.ListCreateAPIView): + model = BasicModel + paginate_by = 10 + + +class OffsetLimitPaginationView(mixins.OffsetLimitPaginationMixin, + generics.ListCreateAPIView): + model = BasicModel + paginate_by = 10 + + class IntegrationTestPagination(TestCase): """ Integration tests for paginated list views. @@ -103,6 +158,102 @@ class IntegrationTestPagination(TestCase): self.assertNotEqual(response.data['previous'], None) +class IntegrationTestLinkPagination(TestCase): + + def setUp(self): + """ + Create 26 BasicModel instances. + """ + for char in 'abcdefghijklmnopqrstuvwxyz': + BasicModel(text=char * 3).save() + self.objects = BasicModel.objects + self.data = [ + {'id': obj.id, 'text': obj.text} + for obj in self.objects.all() + ] + self.view = LinkPaginationView.as_view() + + def test_get_paginated_root_view(self): + """ + GET requests to paginated ListCreateAPIView should return paginated results. + """ + request = factory.get('/') + # Note: Database queries are a `SELECT COUNT`, and `SELECT ` + with self.assertNumQueries(2): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data[0:10]) + link_header = parse_header_link(response['Link']) + self.assertTrue('next' in link_header) + self.assertTrue('last' in link_header) + self.assertFalse('previous' in link_header) + self.assertFalse('first' in link_header) + + request = factory.get(link_header['next']) + with self.assertNumQueries(2): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data[10:20]) + link_header = parse_header_link(response['Link']) + self.assertTrue('next' in link_header) + self.assertTrue('last' in link_header) + self.assertTrue('previous' in link_header) + self.assertTrue('first' in link_header) + + request = factory.get(link_header['last']) + with self.assertNumQueries(2): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data[20:]) + link_header = parse_header_link(response['Link']) + self.assertFalse('next' in link_header) + self.assertFalse('last' in link_header) + self.assertTrue('previous' in link_header) + self.assertTrue('first' in link_header) + + +class IntegrationTestOffsetLimitPagination(TestCase): + + def setUp(self): + """ + Create 26 BasicModel instances. + """ + for char in 'abcdefghijklmnopqrstuvwxyz': + BasicModel(text=char * 3).save() + self.objects = BasicModel.objects + self.data = [ + {'id': obj.id, 'text': obj.text} + for obj in self.objects.all() + ] + self.view = OffsetLimitPaginationView.as_view() + + def test_get_paginated_root_view(self): + """ + GET requests to paginated ListCreateAPIView should return paginated results. + """ + request = factory.get('/') + # Note: Database queries are a `SELECT COUNT`, and `SELECT ` + with self.assertNumQueries(2): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['results'], self.data[0:10]) + self.assertEqual(response.data['count'], 26) + + request = factory.get('/', {'offset': 10}) + with self.assertNumQueries(2): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['results'], self.data[10:20]) + self.assertEqual(response.data['count'], 26) + + request = factory.get('/', {'offset': 20, 'limit': 3}) + with self.assertNumQueries(2): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['results'], self.data[20:23]) + self.assertEqual(response.data['count'], 26) + + class IntegrationTestPaginationAndFiltering(TestCase): def setUp(self): @@ -258,6 +409,63 @@ class UnitTestPagination(TestCase): self.assertEqual(serializer.context, results.context) +class UnitTestLinkPagination(TestCase): + + def setUp(self): + self.objects = range(10) + paginator = Paginator(self.objects, 2) + self.first_page = paginator.page(1) + self.third_page = paginator.page(3) + self.last_page = paginator.page(5) + + def test_native_pagination(self): + serializer = pagination.LinkPaginationSerializer(self.first_page) + self.assertEqual(serializer.data['next'], '?page=2') + self.assertEqual(serializer.data['previous'], None) + self.assertEqual(serializer.data['first'], None) + self.assertEqual(serializer.data['last'], '?page=5') + + serializer = pagination.LinkPaginationSerializer(self.third_page) + self.assertEqual(serializer.data['next'], '?page=4') + self.assertEqual(serializer.data['previous'], '?page=2') + self.assertEqual(serializer.data['first'], '?page=1') + self.assertEqual(serializer.data['last'], '?page=5') + + serializer = pagination.LinkPaginationSerializer(self.last_page) + self.assertEqual(serializer.data['next'], None) + self.assertEqual(serializer.data['previous'], '?page=4') + self.assertEqual(serializer.data['first'], '?page=1') + self.assertEqual(serializer.data['last'], None) + + +class UnitTestOffsetLimitPagination(TestCase): + + def setUp(self): + self.objects = range(10) + self.first_page = pagination.OffsetLimitPage( + self.objects, offset=0, limit=2) + self.third_page = pagination.OffsetLimitPage( + self.objects, offset=4, limit=2) + self.last_page = pagination.OffsetLimitPage( + self.objects, offset=8, limit=2) + + def test_native_pagination(self): + serializer = pagination.OffsetLimitPaginationSerializer( + self.first_page) + self.assertEqual(serializer.data['count'], 10) + self.assertEqual(serializer.data['results'], self.objects[:2]) + + serializer = pagination.OffsetLimitPaginationSerializer( + self.third_page) + self.assertEqual(serializer.data['count'], 10) + self.assertEqual(serializer.data['results'], self.objects[4:6]) + + serializer = pagination.OffsetLimitPaginationSerializer( + self.last_page) + self.assertEqual(serializer.data['count'], 10) + self.assertEqual(serializer.data['results'], self.objects[8:]) + + class TestUnpaginated(TestCase): """ Tests for list views without pagination.