mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-01-24 08:14:16 +03:00
Merge pull request #768 from kevinastone/master
Fixeds DjangoFilterBackend Incompatibility with SingleObjectMixin
This commit is contained in:
commit
2169c34a6f
|
@ -55,6 +55,6 @@ class DjangoFilterBackend(BaseFilterBackend):
|
||||||
filter_class = self.get_filter_class(view)
|
filter_class = self.get_filter_class(view)
|
||||||
|
|
||||||
if filter_class:
|
if filter_class:
|
||||||
return filter_class(request.QUERY_PARAMS, queryset=queryset)
|
return filter_class(request.QUERY_PARAMS, queryset=queryset).qs
|
||||||
|
|
||||||
return queryset
|
return queryset
|
||||||
|
|
|
@ -1,11 +1,12 @@
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
import datetime
|
import datetime
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
|
from django.core.urlresolvers import reverse
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from django.test.client import RequestFactory
|
from django.test.client import RequestFactory
|
||||||
from django.utils import unittest
|
from django.utils import unittest
|
||||||
from rest_framework import generics, status, filters
|
from rest_framework import generics, status, filters
|
||||||
from rest_framework.compat import django_filters
|
from rest_framework.compat import django_filters, patterns, url
|
||||||
from rest_framework.tests.models import FilterableItem, BasicModel
|
from rest_framework.tests.models import FilterableItem, BasicModel
|
||||||
|
|
||||||
factory = RequestFactory()
|
factory = RequestFactory()
|
||||||
|
@ -46,12 +47,21 @@ if django_filters:
|
||||||
filter_class = MisconfiguredFilter
|
filter_class = MisconfiguredFilter
|
||||||
filter_backend = filters.DjangoFilterBackend
|
filter_backend = filters.DjangoFilterBackend
|
||||||
|
|
||||||
|
class FilterClassDetailView(generics.RetrieveAPIView):
|
||||||
|
model = FilterableItem
|
||||||
|
filter_class = SeveralFieldsFilter
|
||||||
|
filter_backend = filters.DjangoFilterBackend
|
||||||
|
|
||||||
class IntegrationTestFiltering(TestCase):
|
urlpatterns = patterns('',
|
||||||
"""
|
url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'),
|
||||||
Integration tests for filtered list views.
|
url(r'^$', FilterClassRootView.as_view(), name='root-view'),
|
||||||
"""
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CommonFilteringTestCase(TestCase):
|
||||||
|
def _serialize_object(self, obj):
|
||||||
|
return {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
"""
|
"""
|
||||||
Create 10 FilterableItem instances.
|
Create 10 FilterableItem instances.
|
||||||
|
@ -65,10 +75,16 @@ class IntegrationTestFiltering(TestCase):
|
||||||
|
|
||||||
self.objects = FilterableItem.objects
|
self.objects = FilterableItem.objects
|
||||||
self.data = [
|
self.data = [
|
||||||
{'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
|
self._serialize_object(obj)
|
||||||
for obj in self.objects.all()
|
for obj in self.objects.all()
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class IntegrationTestFiltering(CommonFilteringTestCase):
|
||||||
|
"""
|
||||||
|
Integration tests for filtered list views.
|
||||||
|
"""
|
||||||
|
|
||||||
@unittest.skipUnless(django_filters, 'django-filters not installed')
|
@unittest.skipUnless(django_filters, 'django-filters not installed')
|
||||||
def test_get_filtered_fields_root_view(self):
|
def test_get_filtered_fields_root_view(self):
|
||||||
"""
|
"""
|
||||||
|
@ -167,3 +183,50 @@ class IntegrationTestFiltering(TestCase):
|
||||||
request = factory.get('/?integer=%s' % search_integer)
|
request = factory.get('/?integer=%s' % search_integer)
|
||||||
response = view(request).render()
|
response = view(request).render()
|
||||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||||
|
|
||||||
|
|
||||||
|
class IntegrationTestDetailFiltering(CommonFilteringTestCase):
|
||||||
|
"""
|
||||||
|
Integration tests for filtered detail views.
|
||||||
|
"""
|
||||||
|
urls = 'rest_framework.tests.filterset'
|
||||||
|
|
||||||
|
def _get_url(self, item):
|
||||||
|
return reverse('detail-view', kwargs=dict(pk=item.pk))
|
||||||
|
|
||||||
|
@unittest.skipUnless(django_filters, 'django-filters not installed')
|
||||||
|
def test_get_filtered_detail_view(self):
|
||||||
|
"""
|
||||||
|
GET requests to filtered RetrieveAPIView that have a filter_class set
|
||||||
|
should return filtered results.
|
||||||
|
"""
|
||||||
|
item = self.objects.all()[0]
|
||||||
|
data = self._serialize_object(item)
|
||||||
|
|
||||||
|
# Basic test with no filter.
|
||||||
|
response = self.client.get(self._get_url(item))
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||||
|
self.assertEqual(response.data, data)
|
||||||
|
|
||||||
|
# Tests that the decimal filter set that should fail.
|
||||||
|
search_decimal = Decimal('4.25')
|
||||||
|
high_item = self.objects.filter(decimal__gt=search_decimal)[0]
|
||||||
|
response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(high_item), param=search_decimal))
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
|
||||||
|
|
||||||
|
# Tests that the decimal filter set that should succeed.
|
||||||
|
search_decimal = Decimal('4.25')
|
||||||
|
low_item = self.objects.filter(decimal__lt=search_decimal)[0]
|
||||||
|
low_item_data = self._serialize_object(low_item)
|
||||||
|
response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(low_item), param=search_decimal))
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||||
|
self.assertEqual(response.data, low_item_data)
|
||||||
|
|
||||||
|
# Tests that multiple filters works.
|
||||||
|
search_decimal = Decimal('5.25')
|
||||||
|
search_date = datetime.date(2012, 10, 2)
|
||||||
|
valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0]
|
||||||
|
valid_item_data = self._serialize_object(valid_item)
|
||||||
|
response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date))
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||||
|
self.assertEqual(response.data, valid_item_data)
|
||||||
|
|
|
@ -129,16 +129,6 @@ class IntegrationTestPaginationAndFiltering(TestCase):
|
||||||
view = FilterFieldsRootView.as_view()
|
view = FilterFieldsRootView.as_view()
|
||||||
|
|
||||||
EXPECTED_NUM_QUERIES = 2
|
EXPECTED_NUM_QUERIES = 2
|
||||||
if django.VERSION < (1, 4):
|
|
||||||
# On Django 1.3 we need to use django-filter 0.5.4
|
|
||||||
#
|
|
||||||
# The filter objects there don't expose a `.count()` method,
|
|
||||||
# which means we only make a single query *but* it's a single
|
|
||||||
# query across *all* of the queryset, instead of a COUNT and then
|
|
||||||
# a SELECT with a LIMIT.
|
|
||||||
#
|
|
||||||
# Although this is fewer queries, it's actually a regression.
|
|
||||||
EXPECTED_NUM_QUERIES = 1
|
|
||||||
|
|
||||||
request = factory.get('/?decimal=15.20')
|
request = factory.get('/?decimal=15.20')
|
||||||
with self.assertNumQueries(EXPECTED_NUM_QUERIES):
|
with self.assertNumQueries(EXPECTED_NUM_QUERIES):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user