Merge branch 'restframework2-filter' of git://github.com/onepercentclub/django-rest-framework into filtering

This commit is contained in:
Tom Christie 2012-11-07 20:13:27 +00:00
commit 9fd061a0b6
9 changed files with 292 additions and 9 deletions

View File

@ -11,6 +11,7 @@ env:
install: install:
- pip install $DJANGO - pip install $DJANGO
- pip install -r requirements.txt --use-mirrors
- export PYTHONPATH=. - export PYTHONPATH=.
script: script:

View File

@ -1 +1,2 @@
Django>=1.3 Django>=1.3
-e git+https://github.com/alex/django-filter.git@0e4b3d703b31574922ab86fc78a86164aad0c1d0#egg=django-filter

View File

@ -6,7 +6,7 @@ from rest_framework import views, mixins
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from django.views.generic.detail import SingleObjectMixin from django.views.generic.detail import SingleObjectMixin
from django.views.generic.list import MultipleObjectMixin from django.views.generic.list import MultipleObjectMixin
import django_filters
### Base classes for the generic views ### ### Base classes for the generic views ###
@ -58,6 +58,37 @@ class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView):
pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS
paginate_by = api_settings.PAGINATE_BY paginate_by = api_settings.PAGINATE_BY
filter_class = None
filter_fields = None
def get_filter_class(self):
"""
Return the django-filters `FilterSet` used to filter the queryset.
"""
if self.filter_class:
return self.filter_class
if self.filter_fields:
class AutoFilterSet(django_filters.FilterSet):
class Meta:
model = self.model
fields = self.filter_fields
return AutoFilterSet
return None
def filter_queryset(self, queryset):
filter_class = self.get_filter_class()
if filter_class:
assert issubclass(filter_class.Meta.model, self.model), \
"%s is not a subclass of %s" % (filter_class.Meta.model, self.model)
return filter_class(self.request.GET, queryset=queryset)
return queryset
def get_filtered_queryset(self):
return self.filter_queryset(self.get_queryset())
def get_pagination_serializer_class(self): def get_pagination_serializer_class(self):
""" """

View File

@ -34,7 +34,7 @@ class ListModelMixin(object):
empty_error = u"Empty list and '%(class_name)s.allow_empty' is False." empty_error = u"Empty list and '%(class_name)s.allow_empty' is False."
def list(self, request, *args, **kwargs): def list(self, request, *args, **kwargs):
self.object_list = self.get_queryset() self.object_list = self.get_filtered_queryset()
# Default is to allow empty querysets. This can be altered by setting # Default is to allow empty querysets. This can be altered by setting
# `.allow_empty = False`, to raise 404 errors on empty querysets. # `.allow_empty = False`, to raise 404 errors on empty querysets.

View File

@ -3,7 +3,11 @@ from rest_framework import serializers
# TODO: Support URLconf kwarg-style paging # TODO: Support URLconf kwarg-style paging
class NextPageField(serializers.Field): class PageField(serializers.Field):
page_field = 'page'
class NextPageField(PageField):
""" """
Field that returns a link to the next page in paginated results. Field that returns a link to the next page in paginated results.
""" """
@ -12,13 +16,16 @@ class NextPageField(serializers.Field):
return None return None
page = value.next_page_number() page = value.next_page_number()
request = self.context.get('request') request = self.context.get('request')
relative_url = '?page=%d' % page relative_url = '?%s=%d' % (self.page_field, page)
if request: if request:
for field, value in request.QUERY_PARAMS.iteritems():
if field != self.page_field:
relative_url += '&%s=%s' % (field, value)
return request.build_absolute_uri(relative_url) return request.build_absolute_uri(relative_url)
return relative_url return relative_url
class PreviousPageField(serializers.Field): class PreviousPageField(PageField):
""" """
Field that returns a link to the previous page in paginated results. Field that returns a link to the previous page in paginated results.
""" """
@ -27,9 +34,12 @@ class PreviousPageField(serializers.Field):
return None return None
page = value.previous_page_number() page = value.previous_page_number()
request = self.context.get('request') request = self.context.get('request')
relative_url = '?page=%d' % page relative_url = '?%s=%d' % (self.page_field, page)
if request: if request:
return request.build_absolute_uri('?page=%d' % page) for field, value in request.QUERY_PARAMS.iteritems():
if field != self.page_field:
relative_url += '&%s=%s' % (field, value)
return request.build_absolute_uri(relative_url)
return relative_url return relative_url

View File

@ -0,0 +1,160 @@
import datetime
from decimal import Decimal
from django.test import TestCase
from django.test.client import RequestFactory
from rest_framework import generics, status
from rest_framework.tests.models import FilterableItem, BasicModel
import django_filters
factory = RequestFactory()
# Basic filter on a list view.
class FilterFieldsRootView(generics.ListCreateAPIView):
model = FilterableItem
filter_fields = ['decimal', 'date']
# These class are used to test a filter class.
class SeveralFieldsFilter(django_filters.FilterSet):
text = django_filters.CharFilter(lookup_type='icontains')
decimal = django_filters.NumberFilter(lookup_type='lt')
date = django_filters.DateFilter(lookup_type='gt')
class Meta:
model = FilterableItem
fields = ['text', 'decimal', 'date']
class FilterClassRootView(generics.ListCreateAPIView):
model = FilterableItem
filter_class = SeveralFieldsFilter
# These classes are used to test a misconfigured filter class.
class MisconfiguredFilter(django_filters.FilterSet):
text = django_filters.CharFilter(lookup_type='icontains')
class Meta:
model = BasicModel
fields = ['text']
class IncorrectlyConfiguredRootView(generics.ListCreateAPIView):
model = FilterableItem
filter_class = MisconfiguredFilter
class IntegrationTestFiltering(TestCase):
"""
Integration tests for filtered list views.
"""
def setUp(self):
"""
Create 10 FilterableItem instances.
"""
base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8))
for i in range(10):
text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
decimal = base_data[1] + i
date = base_data[2] - datetime.timedelta(days=i * 2)
FilterableItem(text=text, decimal=decimal, date=date).save()
self.objects = FilterableItem.objects
self.data = [
{'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
for obj in self.objects.all()
]
def test_get_filtered_fields_root_view(self):
"""
GET requests to paginated ListCreateAPIView should return paginated results.
"""
view = FilterFieldsRootView.as_view()
# Basic test with no filter.
request = factory.get('/')
response = view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, self.data)
# Tests that the decimal filter works.
search_decimal = Decimal('2.25')
request = factory.get('/?decimal=%s' % search_decimal)
response = view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
expected_data = [ f for f in self.data if f['decimal'] == search_decimal ]
self.assertEquals(response.data, expected_data)
# Tests that the date filter works.
search_date = datetime.date(2012, 9, 22)
request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22'
response = view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
expected_data = [ f for f in self.data if f['date'] == search_date ]
self.assertEquals(response.data, expected_data)
def test_get_filtered_class_root_view(self):
"""
GET requests to filtered ListCreateAPIView that have a filter_class set
should return filtered results.
"""
view = FilterClassRootView.as_view()
# Basic test with no filter.
request = factory.get('/')
response = view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, self.data)
# Tests that the decimal filter set with 'lt' in the filter class works.
search_decimal = Decimal('4.25')
request = factory.get('/?decimal=%s' % search_decimal)
response = view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
expected_data = [ f for f in self.data if f['decimal'] < search_decimal ]
self.assertEquals(response.data, expected_data)
# Tests that the date filter set with 'gt' in the filter class works.
search_date = datetime.date(2012, 10, 2)
request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02'
response = view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
expected_data = [ f for f in self.data if f['date'] > search_date ]
self.assertEquals(response.data, expected_data)
# Tests that the text filter set with 'icontains' in the filter class works.
search_text = 'ff'
request = factory.get('/?text=%s' % search_text)
response = view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
expected_data = [ f for f in self.data if search_text in f['text'].lower() ]
self.assertEquals(response.data, expected_data)
# Tests that multiple filters works.
search_decimal = Decimal('5.25')
search_date = datetime.date(2012, 10, 2)
request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date))
response = view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
expected_data = [ f for f in self.data if f['date'] > search_date and
f['decimal'] < search_decimal ]
self.assertEquals(response.data, expected_data)
def test_incorrectly_configured_filter(self):
"""
An error should be displayed when the filter class is misconfigured.
"""
view = IncorrectlyConfiguredRootView.as_view()
request = factory.get('/')
self.assertRaises(AssertionError, view, request)
def test_unknown_filter(self):
"""
GET requests with filters that aren't configured should return 200.
"""
view = FilterFieldsRootView.as_view()
search_integer = 10
request = factory.get('/?integer=%s' % search_integer)
response = view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)

View File

@ -95,6 +95,13 @@ class Bookmark(RESTFrameworkModel):
tags = GenericRelation(TaggedItem) tags = GenericRelation(TaggedItem)
# Model to test filtering.
class FilterableItem(RESTFrameworkModel):
text = models.CharField(max_length=100)
decimal = models.DecimalField(max_digits=4, decimal_places=2)
date = models.DateField()
# Model for regression test for #285 # Model for regression test for #285
class Comment(RESTFrameworkModel): class Comment(RESTFrameworkModel):

View File

@ -1,8 +1,11 @@
import datetime
from decimal import Decimal
from django.core.paginator import Paginator from django.core.paginator import Paginator
from django.test import TestCase from django.test import TestCase
from django.test.client import RequestFactory from django.test.client import RequestFactory
from rest_framework import generics, status, pagination from rest_framework import generics, status, pagination
from rest_framework.tests.models import BasicModel from rest_framework.tests.models import BasicModel, FilterableItem
import django_filters
factory = RequestFactory() factory = RequestFactory()
@ -15,6 +18,19 @@ class RootView(generics.ListCreateAPIView):
paginate_by = 10 paginate_by = 10
class DecimalFilter(django_filters.FilterSet):
decimal = django_filters.NumberFilter(lookup_type='lt')
class Meta:
model = FilterableItem
fields = ['text', 'decimal', 'date']
class FilterFieldsRootView(generics.ListCreateAPIView):
model = FilterableItem
paginate_by = 10
filter_class = DecimalFilter
class IntegrationTestPagination(TestCase): class IntegrationTestPagination(TestCase):
""" """
Integration tests for paginated list views. Integration tests for paginated list views.
@ -22,7 +38,7 @@ class IntegrationTestPagination(TestCase):
def setUp(self): def setUp(self):
""" """
Create 26 BasicModel intances. Create 26 BasicModel instances.
""" """
for char in 'abcdefghijklmnopqrstuvwxyz': for char in 'abcdefghijklmnopqrstuvwxyz':
BasicModel(text=char * 3).save() BasicModel(text=char * 3).save()
@ -62,6 +78,57 @@ class IntegrationTestPagination(TestCase):
self.assertNotEquals(response.data['previous'], None) self.assertNotEquals(response.data['previous'], None)
class IntegrationTestPaginationAndFiltering(TestCase):
def setUp(self):
"""
Create 50 FilterableItem instances.
"""
base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8))
for i in range(26):
text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
decimal = base_data[1] + i
date = base_data[2] - datetime.timedelta(days=i * 2)
FilterableItem(text=text, decimal=decimal, date=date).save()
self.objects = FilterableItem.objects
self.data = [
{'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
for obj in self.objects.all()
]
self.view = FilterFieldsRootView.as_view()
def test_get_paginated_filtered_root_view(self):
"""
GET requests to paginated filtered ListCreateAPIView should return
paginated results. The next and previous links should preserve the
filtered parameters.
"""
request = factory.get('/?decimal=15.20')
response = self.view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data['count'], 15)
self.assertEquals(response.data['results'], self.data[:10])
self.assertNotEquals(response.data['next'], None)
self.assertEquals(response.data['previous'], None)
request = factory.get(response.data['next'])
response = self.view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data['count'], 15)
self.assertEquals(response.data['results'], self.data[10:15])
self.assertEquals(response.data['next'], None)
self.assertNotEquals(response.data['previous'], None)
request = factory.get(response.data['previous'])
response = self.view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data['count'], 15)
self.assertEquals(response.data['results'], self.data[:10])
self.assertNotEquals(response.data['next'], None)
self.assertEquals(response.data['previous'], None)
class UnitTestPagination(TestCase): class UnitTestPagination(TestCase):
""" """
Unit tests for pagination of primative objects. Unit tests for pagination of primative objects.

View File

@ -8,23 +8,29 @@ commands = {envpython} rest_framework/runtests/runtests.py
[testenv:py2.7-django1.5] [testenv:py2.7-django1.5]
basepython = python2.7 basepython = python2.7
deps = https://github.com/django/django/zipball/master deps = https://github.com/django/django/zipball/master
git+https://github.com/alex/django-filter.git@0e4b3d703b31574922ab86fc78a86164aad0c1d0#egg=django-filter
[testenv:py2.7-django1.4] [testenv:py2.7-django1.4]
basepython = python2.7 basepython = python2.7
deps = django==1.4.1 deps = django==1.4.1
git+https://github.com/alex/django-filter.git@0e4b3d703b31574922ab86fc78a86164aad0c1d0#egg=django-filter
[testenv:py2.7-django1.3] [testenv:py2.7-django1.3]
basepython = python2.7 basepython = python2.7
deps = django==1.3.3 deps = django==1.3.3
git+https://github.com/alex/django-filter.git@0e4b3d703b31574922ab86fc78a86164aad0c1d0#egg=django-filter
[testenv:py2.6-django1.5] [testenv:py2.6-django1.5]
basepython = python2.6 basepython = python2.6
deps = https://github.com/django/django/zipball/master deps = https://github.com/django/django/zipball/master
git+https://github.com/alex/django-filter.git@0e4b3d703b31574922ab86fc78a86164aad0c1d0#egg=django-filter
[testenv:py2.6-django1.4] [testenv:py2.6-django1.4]
basepython = python2.6 basepython = python2.6
deps = django==1.4.1 deps = django==1.4.1
git+https://github.com/alex/django-filter.git@0e4b3d703b31574922ab86fc78a86164aad0c1d0#egg=django-filter
[testenv:py2.6-django1.3] [testenv:py2.6-django1.3]
basepython = python2.6 basepython = python2.6
deps = django==1.3.3 deps = django==1.3.3
git+https://github.com/alex/django-filter.git@0e4b3d703b31574922ab86fc78a86164aad0c1d0#egg=django-filter