Merge pull request #768 from kevinastone/master

Fixeds DjangoFilterBackend Incompatibility with SingleObjectMixin
This commit is contained in:
Tom Christie 2013-03-29 06:36:46 -07:00
commit 2169c34a6f
3 changed files with 70 additions and 17 deletions

View File

@ -55,6 +55,6 @@ class DjangoFilterBackend(BaseFilterBackend):
filter_class = self.get_filter_class(view) filter_class = self.get_filter_class(view)
if filter_class: if filter_class:
return filter_class(request.QUERY_PARAMS, queryset=queryset) return filter_class(request.QUERY_PARAMS, queryset=queryset).qs
return queryset return queryset

View File

@ -1,11 +1,12 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import datetime import datetime
from decimal import Decimal from decimal import Decimal
from django.core.urlresolvers import reverse
from django.test import TestCase from django.test import TestCase
from django.test.client import RequestFactory from django.test.client import RequestFactory
from django.utils import unittest from django.utils import unittest
from rest_framework import generics, status, filters from rest_framework import generics, status, filters
from rest_framework.compat import django_filters from rest_framework.compat import django_filters, patterns, url
from rest_framework.tests.models import FilterableItem, BasicModel from rest_framework.tests.models import FilterableItem, BasicModel
factory = RequestFactory() factory = RequestFactory()
@ -46,12 +47,21 @@ if django_filters:
filter_class = MisconfiguredFilter filter_class = MisconfiguredFilter
filter_backend = filters.DjangoFilterBackend filter_backend = filters.DjangoFilterBackend
class FilterClassDetailView(generics.RetrieveAPIView):
model = FilterableItem
filter_class = SeveralFieldsFilter
filter_backend = filters.DjangoFilterBackend
class IntegrationTestFiltering(TestCase): urlpatterns = patterns('',
""" url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'),
Integration tests for filtered list views. url(r'^$', FilterClassRootView.as_view(), name='root-view'),
""" )
class CommonFilteringTestCase(TestCase):
def _serialize_object(self, obj):
return {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
def setUp(self): def setUp(self):
""" """
Create 10 FilterableItem instances. Create 10 FilterableItem instances.
@ -65,10 +75,16 @@ class IntegrationTestFiltering(TestCase):
self.objects = FilterableItem.objects self.objects = FilterableItem.objects
self.data = [ self.data = [
{'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} self._serialize_object(obj)
for obj in self.objects.all() for obj in self.objects.all()
] ]
class IntegrationTestFiltering(CommonFilteringTestCase):
"""
Integration tests for filtered list views.
"""
@unittest.skipUnless(django_filters, 'django-filters not installed') @unittest.skipUnless(django_filters, 'django-filters not installed')
def test_get_filtered_fields_root_view(self): def test_get_filtered_fields_root_view(self):
""" """
@ -167,3 +183,50 @@ class IntegrationTestFiltering(TestCase):
request = factory.get('/?integer=%s' % search_integer) request = factory.get('/?integer=%s' % search_integer)
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
class IntegrationTestDetailFiltering(CommonFilteringTestCase):
"""
Integration tests for filtered detail views.
"""
urls = 'rest_framework.tests.filterset'
def _get_url(self, item):
return reverse('detail-view', kwargs=dict(pk=item.pk))
@unittest.skipUnless(django_filters, 'django-filters not installed')
def test_get_filtered_detail_view(self):
"""
GET requests to filtered RetrieveAPIView that have a filter_class set
should return filtered results.
"""
item = self.objects.all()[0]
data = self._serialize_object(item)
# Basic test with no filter.
response = self.client.get(self._get_url(item))
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, data)
# Tests that the decimal filter set that should fail.
search_decimal = Decimal('4.25')
high_item = self.objects.filter(decimal__gt=search_decimal)[0]
response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(high_item), param=search_decimal))
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
# Tests that the decimal filter set that should succeed.
search_decimal = Decimal('4.25')
low_item = self.objects.filter(decimal__lt=search_decimal)[0]
low_item_data = self._serialize_object(low_item)
response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(low_item), param=search_decimal))
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, low_item_data)
# Tests that multiple filters works.
search_decimal = Decimal('5.25')
search_date = datetime.date(2012, 10, 2)
valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0]
valid_item_data = self._serialize_object(valid_item)
response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date))
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, valid_item_data)

View File

@ -129,16 +129,6 @@ class IntegrationTestPaginationAndFiltering(TestCase):
view = FilterFieldsRootView.as_view() view = FilterFieldsRootView.as_view()
EXPECTED_NUM_QUERIES = 2 EXPECTED_NUM_QUERIES = 2
if django.VERSION < (1, 4):
# On Django 1.3 we need to use django-filter 0.5.4
#
# The filter objects there don't expose a `.count()` method,
# which means we only make a single query *but* it's a single
# query across *all* of the queryset, instead of a COUNT and then
# a SELECT with a LIMIT.
#
# Although this is fewer queries, it's actually a regression.
EXPECTED_NUM_QUERIES = 1
request = factory.get('/?decimal=15.20') request = factory.get('/?decimal=15.20')
with self.assertNumQueries(EXPECTED_NUM_QUERIES): with self.assertNumQueries(EXPECTED_NUM_QUERIES):