from __future__ import unicode_literals import datetime from decimal import Decimal from django.core.paginator import Paginator from django.test import TestCase from django.utils import unittest from rest_framework import generics, serializers, status, pagination, filters from rest_framework.compat import django_filters from rest_framework.test import APIRequestFactory from .models import BasicModel, FilterableItem factory = APIRequestFactory() # Helper function to split arguments out of an url def split_arguments_from_url(url): if '?' not in url: return url path, args = url.split('?') args = dict(r.split('=') for r in args.split('&')) return path, args class BasicSerializer(serializers.ModelSerializer): class Meta: model = BasicModel class FilterableItemSerializer(serializers.ModelSerializer): class Meta: model = FilterableItem class RootView(generics.ListCreateAPIView): """ Example description for OPTIONS. """ queryset = BasicModel.objects.all() serializer_class = BasicSerializer paginate_by = 10 class DefaultPageSizeKwargView(generics.ListAPIView): """ View for testing default paginate_by_param usage """ queryset = BasicModel.objects.all() serializer_class = BasicSerializer class PaginateByParamView(generics.ListAPIView): """ View for testing custom paginate_by_param usage """ queryset = BasicModel.objects.all() serializer_class = BasicSerializer paginate_by_param = 'page_size' class MaxPaginateByView(generics.ListAPIView): """ View for testing custom max_paginate_by usage """ queryset = BasicModel.objects.all() serializer_class = BasicSerializer paginate_by = 3 max_paginate_by = 5 paginate_by_param = 'page_size' class IntegrationTestPagination(TestCase): """ Integration tests for paginated list views. """ def setUp(self): """ Create 26 BasicModel instances. """ for char in 'abcdefghijklmnopqrstuvwxyz': BasicModel(text=char * 3).save() self.objects = BasicModel.objects self.data = [ {'id': obj.id, 'text': obj.text} for obj in self.objects.all() ] self.view = RootView.as_view() def test_get_paginated_root_view(self): """ GET requests to paginated ListCreateAPIView should return paginated results. """ request = factory.get('/') # Note: Database queries are a `SELECT COUNT`, and `SELECT ` with self.assertNumQueries(2): response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data['count'], 26) self.assertEqual(response.data['results'], self.data[:10]) self.assertNotEqual(response.data['next'], None) self.assertEqual(response.data['previous'], None) request = factory.get(*split_arguments_from_url(response.data['next'])) with self.assertNumQueries(2): response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data['count'], 26) self.assertEqual(response.data['results'], self.data[10:20]) self.assertNotEqual(response.data['next'], None) self.assertNotEqual(response.data['previous'], None) request = factory.get(*split_arguments_from_url(response.data['next'])) with self.assertNumQueries(2): response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data['count'], 26) self.assertEqual(response.data['results'], self.data[20:]) self.assertEqual(response.data['next'], None) self.assertNotEqual(response.data['previous'], None) class IntegrationTestPaginationAndFiltering(TestCase): def setUp(self): """ Create 50 FilterableItem instances. """ base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8)) for i in range(26): text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc. decimal = base_data[1] + i date = base_data[2] - datetime.timedelta(days=i * 2) FilterableItem(text=text, decimal=decimal, date=date).save() self.objects = FilterableItem.objects self.data = [ {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} for obj in self.objects.all() ] @unittest.skipUnless(django_filters, 'django-filter not installed') def test_get_django_filter_paginated_filtered_root_view(self): """ GET requests to paginated filtered ListCreateAPIView should return paginated results. The next and previous links should preserve the filtered parameters. """ class DecimalFilter(django_filters.FilterSet): decimal = django_filters.NumberFilter(lookup_type='lt') class Meta: model = FilterableItem fields = ['text', 'decimal', 'date'] class FilterFieldsRootView(generics.ListCreateAPIView): queryset = FilterableItem.objects.all() serializer_class = FilterableItemSerializer paginate_by = 10 filter_class = DecimalFilter filter_backends = (filters.DjangoFilterBackend,) view = FilterFieldsRootView.as_view() EXPECTED_NUM_QUERIES = 2 request = factory.get('/', {'decimal': '15.20'}) with self.assertNumQueries(EXPECTED_NUM_QUERIES): response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data['count'], 15) self.assertEqual(response.data['results'], self.data[:10]) self.assertNotEqual(response.data['next'], None) self.assertEqual(response.data['previous'], None) request = factory.get(*split_arguments_from_url(response.data['next'])) with self.assertNumQueries(EXPECTED_NUM_QUERIES): response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data['count'], 15) self.assertEqual(response.data['results'], self.data[10:15]) self.assertEqual(response.data['next'], None) self.assertNotEqual(response.data['previous'], None) request = factory.get(*split_arguments_from_url(response.data['previous'])) with self.assertNumQueries(EXPECTED_NUM_QUERIES): response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data['count'], 15) self.assertEqual(response.data['results'], self.data[:10]) self.assertNotEqual(response.data['next'], None) self.assertEqual(response.data['previous'], None) def test_get_basic_paginated_filtered_root_view(self): """ Same as `test_get_django_filter_paginated_filtered_root_view`, except using a custom filter backend instead of the django-filter backend, """ class DecimalFilterBackend(filters.BaseFilterBackend): def filter_queryset(self, request, queryset, view): return queryset.filter(decimal__lt=Decimal(request.GET['decimal'])) class BasicFilterFieldsRootView(generics.ListCreateAPIView): queryset = FilterableItem.objects.all() serializer_class = FilterableItemSerializer paginate_by = 10 filter_backends = (DecimalFilterBackend,) view = BasicFilterFieldsRootView.as_view() request = factory.get('/', {'decimal': '15.20'}) with self.assertNumQueries(2): response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data['count'], 15) self.assertEqual(response.data['results'], self.data[:10]) self.assertNotEqual(response.data['next'], None) self.assertEqual(response.data['previous'], None) request = factory.get(*split_arguments_from_url(response.data['next'])) with self.assertNumQueries(2): response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data['count'], 15) self.assertEqual(response.data['results'], self.data[10:15]) self.assertEqual(response.data['next'], None) self.assertNotEqual(response.data['previous'], None) request = factory.get(*split_arguments_from_url(response.data['previous'])) with self.assertNumQueries(2): response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data['count'], 15) self.assertEqual(response.data['results'], self.data[:10]) self.assertNotEqual(response.data['next'], None) self.assertEqual(response.data['previous'], None) class PassOnContextPaginationSerializer(pagination.PaginationSerializer): class Meta: object_serializer_class = serializers.Serializer class UnitTestPagination(TestCase): """ Unit tests for pagination of primitive objects. """ def setUp(self): self.objects = [char * 3 for char in 'abcdefghijklmnopqrstuvwxyz'] paginator = Paginator(self.objects, 10) self.first_page = paginator.page(1) self.last_page = paginator.page(3) def test_native_pagination(self): serializer = pagination.PaginationSerializer(self.first_page) self.assertEqual(serializer.data['count'], 26) self.assertEqual(serializer.data['next'], '?page=2') self.assertEqual(serializer.data['previous'], None) self.assertEqual(serializer.data['results'], self.objects[:10]) serializer = pagination.PaginationSerializer(self.last_page) self.assertEqual(serializer.data['count'], 26) self.assertEqual(serializer.data['next'], None) self.assertEqual(serializer.data['previous'], '?page=2') self.assertEqual(serializer.data['results'], self.objects[20:]) def test_context_available_in_result(self): """ Ensure context gets passed through to the object serializer. """ serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'}) serializer.data results = serializer.fields[serializer.results_field] self.assertEqual(serializer.context, results.context) class TestUnpaginated(TestCase): """ Tests for list views without pagination. """ def setUp(self): """ Create 13 BasicModel instances. """ for i in range(13): BasicModel(text=i).save() self.objects = BasicModel.objects self.data = [ {'id': obj.id, 'text': obj.text} for obj in self.objects.all() ] self.view = DefaultPageSizeKwargView.as_view() def test_unpaginated(self): """ Tests the default page size for this view. no page size --> no limit --> no meta data """ request = factory.get('/') response = self.view(request) self.assertEqual(response.data, self.data) class TestCustomPaginateByParam(TestCase): """ Tests for list views with default page size kwarg """ def setUp(self): """ Create 13 BasicModel instances. """ for i in range(13): BasicModel(text=i).save() self.objects = BasicModel.objects self.data = [ {'id': obj.id, 'text': obj.text} for obj in self.objects.all() ] self.view = PaginateByParamView.as_view() def test_default_page_size(self): """ Tests the default page size for this view. no page size --> no limit --> no meta data """ request = factory.get('/') response = self.view(request).render() self.assertEqual(response.data, self.data) def test_paginate_by_param(self): """ If paginate_by_param is set, the new kwarg should limit per view requests. """ request = factory.get('/', {'page_size': 5}) response = self.view(request).render() self.assertEqual(response.data['count'], 13) self.assertEqual(response.data['results'], self.data[:5]) class TestMaxPaginateByParam(TestCase): """ Tests for list views with max_paginate_by kwarg """ def setUp(self): """ Create 13 BasicModel instances. """ for i in range(13): BasicModel(text=i).save() self.objects = BasicModel.objects self.data = [ {'id': obj.id, 'text': obj.text} for obj in self.objects.all() ] self.view = MaxPaginateByView.as_view() def test_max_paginate_by(self): """ If max_paginate_by is set, it should limit page size for the view. """ request = factory.get('/', data={'page_size': 10}) response = self.view(request).render() self.assertEqual(response.data['count'], 13) self.assertEqual(response.data['results'], self.data[:5]) def test_max_paginate_by_without_page_size_param(self): """ If max_paginate_by is set, but client does not specifiy page_size, standard `paginate_by` behavior should be used. """ request = factory.get('/') response = self.view(request).render() self.assertEqual(response.data['results'], self.data[:3]) # Tests for context in pagination serializers class CustomField(serializers.Field): def to_native(self, value): if 'view' not in self.context: raise RuntimeError("context isn't getting passed into custom field") return "value" class BasicModelSerializer(serializers.Serializer): text = CustomField() def __init__(self, *args, **kwargs): super(BasicModelSerializer, self).__init__(*args, **kwargs) if 'view' not in self.context: raise RuntimeError("context isn't getting passed into serializer init") class TestContextPassedToCustomField(TestCase): def setUp(self): BasicModel.objects.create(text='ala ma kota') def test_with_pagination(self): class ListView(generics.ListCreateAPIView): queryset = BasicModel.objects.all() serializer_class = BasicModelSerializer paginate_by = 1 self.view = ListView.as_view() request = factory.get('/') response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) # Tests for custom pagination serializers class LinksSerializer(serializers.Serializer): next = pagination.NextPageField(source='*') prev = pagination.PreviousPageField(source='*') class CustomPaginationSerializer(pagination.BasePaginationSerializer): links = LinksSerializer(source='*') # Takes the page object as the source total_results = serializers.Field(source='paginator.count') results_field = 'objects' class TestCustomPaginationSerializer(TestCase): def setUp(self): objects = ['john', 'paul', 'george', 'ringo'] paginator = Paginator(objects, 2) self.page = paginator.page(1) def test_custom_pagination_serializer(self): request = APIRequestFactory().get('/foobar') serializer = CustomPaginationSerializer( instance=self.page, context={'request': request} ) expected = { 'links': { 'next': 'http://testserver/foobar?page=2', 'prev': None }, 'total_results': 4, 'objects': ['john', 'paul'] } self.assertEqual(serializer.data, expected) class NonIntegerPage(object): def __init__(self, paginator, object_list, prev_token, token, next_token): self.paginator = paginator self.object_list = object_list self.prev_token = prev_token self.token = token self.next_token = next_token def has_next(self): return not not self.next_token def next_page_number(self): return self.next_token def has_previous(self): return not not self.prev_token def previous_page_number(self): return self.prev_token class NonIntegerPaginator(object): def __init__(self, object_list, per_page): self.object_list = object_list self.per_page = per_page def count(self): # pretend like we don't know how many pages we have return None def page(self, token=None): if token: try: first = self.object_list.index(token) except ValueError: first = 0 else: first = 0 n = len(self.object_list) last = min(first + self.per_page, n) prev_token = self.object_list[last - (2 * self.per_page)] if first else None next_token = self.object_list[last] if last < n else None return NonIntegerPage(self, self.object_list[first:last], prev_token, token, next_token) class TestNonIntegerPagination(TestCase): def test_custom_pagination_serializer(self): objects = ['john', 'paul', 'george', 'ringo'] paginator = NonIntegerPaginator(objects, 2) request = APIRequestFactory().get('/foobar') serializer = CustomPaginationSerializer( instance=paginator.page(), context={'request': request} ) expected = { 'links': { 'next': 'http://testserver/foobar?page={0}'.format(objects[2]), 'prev': None }, 'total_results': None, 'objects': objects[:2] } self.assertEqual(serializer.data, expected) request = APIRequestFactory().get('/foobar') serializer = CustomPaginationSerializer( instance=paginator.page('george'), context={'request': request} ) expected = { 'links': { 'next': None, 'prev': 'http://testserver/foobar?page={0}'.format(objects[0]), }, 'total_results': None, 'objects': objects[2:] } self.assertEqual(serializer.data, expected)