mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-01-23 15:54:16 +03:00
First attempt at adding filter support.
The filter support uses django-filter to work its magic.
This commit is contained in:
parent
83f39b3dce
commit
1e9ece0f93
|
@ -1 +1,2 @@
|
||||||
Django>=1.3
|
Django>=1.3
|
||||||
|
-e git+https://github.com/alex/django-filter.git#egg=django-filter
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
"""
|
"""
|
||||||
Generic views that provide commmonly needed behaviour.
|
Generic views that provide commonly needed behaviour.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from rest_framework import views, mixins
|
from rest_framework import views, mixins
|
||||||
from rest_framework.settings import api_settings
|
from rest_framework.settings import api_settings
|
||||||
from django.views.generic.detail import SingleObjectMixin
|
from django.views.generic.detail import SingleObjectMixin
|
||||||
from django.views.generic.list import MultipleObjectMixin
|
from django.views.generic.list import MultipleObjectMixin
|
||||||
|
import django_filters
|
||||||
|
|
||||||
### Base classes for the generic views ###
|
### Base classes for the generic views ###
|
||||||
|
|
||||||
|
@ -58,6 +58,37 @@ class MultipleObjectBaseView(MultipleObjectMixin, BaseView):
|
||||||
|
|
||||||
pagination_serializer_class = api_settings.PAGINATION_SERIALIZER
|
pagination_serializer_class = api_settings.PAGINATION_SERIALIZER
|
||||||
paginate_by = api_settings.PAGINATE_BY
|
paginate_by = api_settings.PAGINATE_BY
|
||||||
|
filter_class = None
|
||||||
|
filter_fields = None
|
||||||
|
|
||||||
|
def get_filter_class(self):
|
||||||
|
"""
|
||||||
|
Return the django-filters `FilterSet` used to filter the queryset.
|
||||||
|
"""
|
||||||
|
if self.filter_class:
|
||||||
|
return self.filter_class
|
||||||
|
|
||||||
|
if self.filter_fields:
|
||||||
|
class AutoFilterSet(django_filters.FilterSet):
|
||||||
|
class Meta:
|
||||||
|
model = self.model
|
||||||
|
fields = self.filter_fields
|
||||||
|
return AutoFilterSet
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def filter_queryset(self, queryset):
|
||||||
|
filter_class = self.get_filter_class()
|
||||||
|
|
||||||
|
if filter_class:
|
||||||
|
assert issubclass(filter_class.Meta.model, self.model), \
|
||||||
|
"%s is not a subclass of %s" % (filter_class.Meta.model, self.model)
|
||||||
|
return filter_class(self.request.GET, queryset=queryset)
|
||||||
|
|
||||||
|
return queryset
|
||||||
|
|
||||||
|
def get_filtered_queryset(self):
|
||||||
|
return self.filter_queryset(self.get_queryset())
|
||||||
|
|
||||||
def get_pagination_serializer_class(self):
|
def get_pagination_serializer_class(self):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -33,7 +33,7 @@ class ListModelMixin(object):
|
||||||
empty_error = u"Empty list and '%(class_name)s.allow_empty' is False."
|
empty_error = u"Empty list and '%(class_name)s.allow_empty' is False."
|
||||||
|
|
||||||
def list(self, request, *args, **kwargs):
|
def list(self, request, *args, **kwargs):
|
||||||
self.object_list = self.get_queryset()
|
self.object_list = self.get_filtered_queryset()
|
||||||
|
|
||||||
# Default is to allow empty querysets. This can be altered by setting
|
# Default is to allow empty querysets. This can be altered by setting
|
||||||
# `.allow_empty = False`, to raise 404 errors on empty querysets.
|
# `.allow_empty = False`, to raise 404 errors on empty querysets.
|
||||||
|
|
160
rest_framework/tests/filterset.py
Normal file
160
rest_framework/tests/filterset.py
Normal file
|
@ -0,0 +1,160 @@
|
||||||
|
import datetime
|
||||||
|
from django.test import TestCase
|
||||||
|
from django.test.client import RequestFactory
|
||||||
|
from rest_framework import generics, status
|
||||||
|
from rest_framework.tests.models import FilterableItem, BasicModel
|
||||||
|
import django_filters
|
||||||
|
|
||||||
|
factory = RequestFactory()
|
||||||
|
|
||||||
|
# Basic filter on a list view.
|
||||||
|
class FilterFieldsRootView(generics.ListCreateAPIView):
|
||||||
|
model = FilterableItem
|
||||||
|
filter_fields = ['decimal', 'date']
|
||||||
|
|
||||||
|
|
||||||
|
# These class are used to test a filter class.
|
||||||
|
class SeveralFieldsFilter(django_filters.FilterSet):
|
||||||
|
text = django_filters.CharFilter(lookup_type='icontains')
|
||||||
|
decimal = django_filters.NumberFilter(lookup_type='lt')
|
||||||
|
date = django_filters.DateFilter(lookup_type='gt')
|
||||||
|
class Meta:
|
||||||
|
model = FilterableItem
|
||||||
|
fields = ['text', 'decimal', 'date']
|
||||||
|
|
||||||
|
|
||||||
|
class FilterClassRootView(generics.ListCreateAPIView):
|
||||||
|
model = FilterableItem
|
||||||
|
filter_class = SeveralFieldsFilter
|
||||||
|
|
||||||
|
|
||||||
|
# These classes are used to test a misconfigured filter class.
|
||||||
|
class MisconfiguredFilter(django_filters.FilterSet):
|
||||||
|
text = django_filters.CharFilter(lookup_type='icontains')
|
||||||
|
class Meta:
|
||||||
|
model = BasicModel
|
||||||
|
fields = ['text']
|
||||||
|
|
||||||
|
|
||||||
|
class IncorrectlyConfiguredRootView(generics.ListCreateAPIView):
|
||||||
|
model = FilterableItem
|
||||||
|
filter_class = MisconfiguredFilter
|
||||||
|
|
||||||
|
|
||||||
|
class IntegrationTestFiltering(TestCase):
|
||||||
|
"""
|
||||||
|
Integration tests for filtered list views.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"""
|
||||||
|
Create 10 FilterableItem instances.
|
||||||
|
"""
|
||||||
|
base_data = ('a', 0.25, datetime.date(2012, 10, 8))
|
||||||
|
for i in range(10):
|
||||||
|
text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
|
||||||
|
decimal = base_data[1] + i
|
||||||
|
date = base_data[2] - datetime.timedelta(days=i * 2)
|
||||||
|
FilterableItem(text=text, decimal=decimal, date=date).save()
|
||||||
|
|
||||||
|
self.objects = FilterableItem.objects
|
||||||
|
self.data = [
|
||||||
|
{'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
|
||||||
|
for obj in self.objects.all()
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_get_filtered_fields_root_view(self):
|
||||||
|
"""
|
||||||
|
GET requests to paginated ListCreateAPIView should return paginated results.
|
||||||
|
"""
|
||||||
|
view = FilterFieldsRootView.as_view()
|
||||||
|
|
||||||
|
# Basic test with no filter.
|
||||||
|
request = factory.get('/')
|
||||||
|
response = view(request).render()
|
||||||
|
self.assertEquals(response.status_code, status.HTTP_200_OK)
|
||||||
|
self.assertEquals(response.data, self.data)
|
||||||
|
|
||||||
|
# Tests that the decimal filter works.
|
||||||
|
search_decimal = 2.25
|
||||||
|
request = factory.get('/?decimal=%s' % search_decimal)
|
||||||
|
response = view(request).render()
|
||||||
|
self.assertEquals(response.status_code, status.HTTP_200_OK)
|
||||||
|
expected_data = [ f for f in self.data if f['decimal'] == search_decimal ]
|
||||||
|
self.assertEquals(response.data, expected_data)
|
||||||
|
|
||||||
|
# Tests that the date filter works.
|
||||||
|
search_date = datetime.date(2012, 9, 22)
|
||||||
|
request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22'
|
||||||
|
response = view(request).render()
|
||||||
|
self.assertEquals(response.status_code, status.HTTP_200_OK)
|
||||||
|
expected_data = [ f for f in self.data if f['date'] == search_date ]
|
||||||
|
self.assertEquals(response.data, expected_data)
|
||||||
|
|
||||||
|
def test_get_filtered_class_root_view(self):
|
||||||
|
"""
|
||||||
|
GET requests to filtered ListCreateAPIView that have a filter_class set
|
||||||
|
should return filtered results.
|
||||||
|
"""
|
||||||
|
view = FilterClassRootView.as_view()
|
||||||
|
|
||||||
|
# Basic test with no filter.
|
||||||
|
request = factory.get('/')
|
||||||
|
response = view(request).render()
|
||||||
|
self.assertEquals(response.status_code, status.HTTP_200_OK)
|
||||||
|
self.assertEquals(response.data, self.data)
|
||||||
|
|
||||||
|
# Tests that the decimal filter set with 'lt' in the filter class works.
|
||||||
|
search_decimal = 4.25
|
||||||
|
request = factory.get('/?decimal=%s' % search_decimal)
|
||||||
|
response = view(request).render()
|
||||||
|
self.assertEquals(response.status_code, status.HTTP_200_OK)
|
||||||
|
expected_data = [ f for f in self.data if f['decimal'] < search_decimal ]
|
||||||
|
self.assertEquals(response.data, expected_data)
|
||||||
|
|
||||||
|
# Tests that the date filter set with 'gt' in the filter class works.
|
||||||
|
search_date = datetime.date(2012, 10, 2)
|
||||||
|
request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02'
|
||||||
|
response = view(request).render()
|
||||||
|
self.assertEquals(response.status_code, status.HTTP_200_OK)
|
||||||
|
expected_data = [ f for f in self.data if f['date'] > search_date ]
|
||||||
|
self.assertEquals(response.data, expected_data)
|
||||||
|
|
||||||
|
# Tests that the text filter set with 'icontains' in the filter class works.
|
||||||
|
search_text = 'ff'
|
||||||
|
request = factory.get('/?text=%s' % search_text)
|
||||||
|
response = view(request).render()
|
||||||
|
self.assertEquals(response.status_code, status.HTTP_200_OK)
|
||||||
|
expected_data = [ f for f in self.data if search_text in f['text'].lower() ]
|
||||||
|
self.assertEquals(response.data, expected_data)
|
||||||
|
|
||||||
|
# Tests that multiple filters works.
|
||||||
|
search_decimal = 5.25
|
||||||
|
search_date = datetime.date(2012, 10, 2)
|
||||||
|
request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date))
|
||||||
|
response = view(request).render()
|
||||||
|
self.assertEquals(response.status_code, status.HTTP_200_OK)
|
||||||
|
expected_data = [ f for f in self.data if f['date'] > search_date and
|
||||||
|
f['decimal'] < search_decimal ]
|
||||||
|
self.assertEquals(response.data, expected_data)
|
||||||
|
|
||||||
|
def test_incorrectly_configured_filter(self):
|
||||||
|
"""
|
||||||
|
An error should be displayed when the filter class is misconfigured.
|
||||||
|
"""
|
||||||
|
view = IncorrectlyConfiguredRootView.as_view()
|
||||||
|
|
||||||
|
request = factory.get('/')
|
||||||
|
self.assertRaises(AssertionError, view, request)
|
||||||
|
|
||||||
|
# TODO Return 400 filter paramater requested that hasn't been configured.
|
||||||
|
def test_bad_request(self):
|
||||||
|
"""
|
||||||
|
GET requests with filters that aren't configured should return 400.
|
||||||
|
"""
|
||||||
|
view = FilterFieldsRootView.as_view()
|
||||||
|
|
||||||
|
search_integer = 10
|
||||||
|
request = factory.get('/?integer=%s' % search_integer)
|
||||||
|
response = view(request).render()
|
||||||
|
self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST)
|
|
@ -85,6 +85,13 @@ class Bookmark(RESTFrameworkModel):
|
||||||
tags = GenericRelation(TaggedItem)
|
tags = GenericRelation(TaggedItem)
|
||||||
|
|
||||||
|
|
||||||
|
# Model to test filtering.
|
||||||
|
class FilterableItem(RESTFrameworkModel):
|
||||||
|
text = models.CharField(max_length=100)
|
||||||
|
decimal = models.DecimalField(max_digits=4, decimal_places=2)
|
||||||
|
date = models.DateField()
|
||||||
|
|
||||||
|
|
||||||
# Model for regression test for #285
|
# Model for regression test for #285
|
||||||
|
|
||||||
class Comment(RESTFrameworkModel):
|
class Comment(RESTFrameworkModel):
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
|
import datetime
|
||||||
from django.core.paginator import Paginator
|
from django.core.paginator import Paginator
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from django.test.client import RequestFactory
|
from django.test.client import RequestFactory
|
||||||
from rest_framework import generics, status, pagination
|
from rest_framework import generics, status, pagination
|
||||||
from rest_framework.tests.models import BasicModel
|
from rest_framework.tests.models import BasicModel, FilterableItem
|
||||||
|
import django_filters
|
||||||
|
|
||||||
factory = RequestFactory()
|
factory = RequestFactory()
|
||||||
|
|
||||||
|
@ -15,6 +17,19 @@ class RootView(generics.ListCreateAPIView):
|
||||||
paginate_by = 10
|
paginate_by = 10
|
||||||
|
|
||||||
|
|
||||||
|
class DecimalFilter(django_filters.FilterSet):
|
||||||
|
decimal = django_filters.NumberFilter(lookup_type='lt')
|
||||||
|
class Meta:
|
||||||
|
model = FilterableItem
|
||||||
|
fields = ['text', 'decimal', 'date']
|
||||||
|
|
||||||
|
|
||||||
|
class FilterFieldsRootView(generics.ListCreateAPIView):
|
||||||
|
model = FilterableItem
|
||||||
|
paginate_by = 10
|
||||||
|
filter_class = DecimalFilter
|
||||||
|
|
||||||
|
|
||||||
class IntegrationTestPagination(TestCase):
|
class IntegrationTestPagination(TestCase):
|
||||||
"""
|
"""
|
||||||
Integration tests for paginated list views.
|
Integration tests for paginated list views.
|
||||||
|
@ -22,7 +37,7 @@ class IntegrationTestPagination(TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
"""
|
"""
|
||||||
Create 26 BasicModel intances.
|
Create 26 BasicModel instances.
|
||||||
"""
|
"""
|
||||||
for char in 'abcdefghijklmnopqrstuvwxyz':
|
for char in 'abcdefghijklmnopqrstuvwxyz':
|
||||||
BasicModel(text=char * 3).save()
|
BasicModel(text=char * 3).save()
|
||||||
|
@ -61,6 +76,56 @@ class IntegrationTestPagination(TestCase):
|
||||||
self.assertEquals(response.data['next'], None)
|
self.assertEquals(response.data['next'], None)
|
||||||
self.assertNotEquals(response.data['previous'], None)
|
self.assertNotEquals(response.data['previous'], None)
|
||||||
|
|
||||||
|
class IntegrationTestPaginationAndFiltering(TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"""
|
||||||
|
Create 50 FilterableItem instances.
|
||||||
|
"""
|
||||||
|
base_data = ('a', 0.25, datetime.date(2012, 10, 8))
|
||||||
|
for i in range(26):
|
||||||
|
text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
|
||||||
|
decimal = base_data[1] + i
|
||||||
|
date = base_data[2] - datetime.timedelta(days=i * 2)
|
||||||
|
FilterableItem(text=text, decimal=decimal, date=date).save()
|
||||||
|
|
||||||
|
self.objects = FilterableItem.objects
|
||||||
|
self.data = [
|
||||||
|
{'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
|
||||||
|
for obj in self.objects.all()
|
||||||
|
]
|
||||||
|
self.view = FilterFieldsRootView.as_view()
|
||||||
|
|
||||||
|
def test_get_paginated_filtered_root_view(self):
|
||||||
|
"""
|
||||||
|
GET requests to paginated filtered ListCreateAPIView should return
|
||||||
|
paginated results. The next and previous links should preserve the
|
||||||
|
filtered parameters.
|
||||||
|
"""
|
||||||
|
request = factory.get('/?decimal=15.20')
|
||||||
|
response = self.view(request).render()
|
||||||
|
self.assertEquals(response.status_code, status.HTTP_200_OK)
|
||||||
|
self.assertEquals(response.data['count'], 15)
|
||||||
|
self.assertEquals(response.data['results'], self.data[:10])
|
||||||
|
self.assertNotEquals(response.data['next'], None)
|
||||||
|
self.assertEquals(response.data['previous'], None)
|
||||||
|
|
||||||
|
request = factory.get(response.data['next'])
|
||||||
|
response = self.view(request).render()
|
||||||
|
self.assertEquals(response.status_code, status.HTTP_200_OK)
|
||||||
|
self.assertEquals(response.data['count'], 15)
|
||||||
|
self.assertEquals(response.data['results'], self.data[10:15])
|
||||||
|
self.assertEquals(response.data['next'], None)
|
||||||
|
self.assertNotEquals(response.data['previous'], None)
|
||||||
|
|
||||||
|
request = factory.get(response.data['previous'])
|
||||||
|
response = self.view(request).render()
|
||||||
|
self.assertEquals(response.status_code, status.HTTP_200_OK)
|
||||||
|
self.assertEquals(response.data['count'], 15)
|
||||||
|
self.assertEquals(response.data['results'], self.data[:10])
|
||||||
|
self.assertNotEquals(response.data['next'], None)
|
||||||
|
self.assertEquals(response.data['previous'], None)
|
||||||
|
|
||||||
|
|
||||||
class UnitTestPagination(TestCase):
|
class UnitTestPagination(TestCase):
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue
Block a user