mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-11 12:17:24 +03:00
SearchFilter and tests
This commit is contained in:
parent
773a92eab3
commit
8ce36d2bf1
|
@ -74,6 +74,8 @@ class DjangoFilterBackend(BaseFilterBackend):
|
||||||
|
|
||||||
|
|
||||||
class SearchFilter(BaseFilterBackend):
|
class SearchFilter(BaseFilterBackend):
|
||||||
|
search_param = 'search'
|
||||||
|
|
||||||
def construct_search(self, field_name):
|
def construct_search(self, field_name):
|
||||||
if field_name.startswith('^'):
|
if field_name.startswith('^'):
|
||||||
return "%s__istartswith" % field_name[1:]
|
return "%s__istartswith" % field_name[1:]
|
||||||
|
@ -90,10 +92,13 @@ class SearchFilter(BaseFilterBackend):
|
||||||
if not search_fields:
|
if not search_fields:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
search_terms = request.QUERY_PARAMS.get(self.search_param)
|
||||||
orm_lookups = [self.construct_search(str(search_field))
|
orm_lookups = [self.construct_search(str(search_field))
|
||||||
for search_field in self.search_fields]
|
for search_field in search_fields]
|
||||||
for bit in self.query.split():
|
|
||||||
|
for bit in search_terms.split():
|
||||||
or_queries = [models.Q(**{orm_lookup: bit})
|
or_queries = [models.Q(**{orm_lookup: bit})
|
||||||
for orm_lookup in orm_lookups]
|
for orm_lookup in orm_lookups]
|
||||||
queryset = queryset.filter(reduce(operator.or_, or_queries))
|
queryset = queryset.filter(reduce(operator.or_, or_queries))
|
||||||
|
|
||||||
return queryset
|
return queryset
|
||||||
|
|
|
@ -1,17 +1,24 @@
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
import datetime
|
import datetime
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
|
from django.db import models
|
||||||
from django.core.urlresolvers import reverse
|
from django.core.urlresolvers import reverse
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from django.test.client import RequestFactory
|
from django.test.client import RequestFactory
|
||||||
from django.utils import unittest
|
from django.utils import unittest
|
||||||
from rest_framework import generics, serializers, status, filters
|
from rest_framework import generics, serializers, status, filters
|
||||||
from rest_framework.compat import django_filters, patterns, url
|
from rest_framework.compat import django_filters, patterns, url
|
||||||
from rest_framework.tests.models import FilterableItem, BasicModel
|
from rest_framework.tests.models import BasicModel
|
||||||
|
|
||||||
factory = RequestFactory()
|
factory = RequestFactory()
|
||||||
|
|
||||||
|
|
||||||
|
class FilterableItem(models.Model):
|
||||||
|
text = models.CharField(max_length=100)
|
||||||
|
decimal = models.DecimalField(max_digits=4, decimal_places=2)
|
||||||
|
date = models.DateField()
|
||||||
|
|
||||||
|
|
||||||
if django_filters:
|
if django_filters:
|
||||||
# Basic filter on a list view.
|
# Basic filter on a list view.
|
||||||
class FilterFieldsRootView(generics.ListCreateAPIView):
|
class FilterFieldsRootView(generics.ListCreateAPIView):
|
||||||
|
@ -256,3 +263,75 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase):
|
||||||
response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date))
|
response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date))
|
||||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||||
self.assertEqual(response.data, valid_item_data)
|
self.assertEqual(response.data, valid_item_data)
|
||||||
|
|
||||||
|
|
||||||
|
class SearchFilterModel(models.Model):
|
||||||
|
title = models.CharField(max_length=20)
|
||||||
|
text = models.CharField(max_length=100)
|
||||||
|
|
||||||
|
|
||||||
|
class SearchFilterTests(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
# Sequence of title/text is:
|
||||||
|
#
|
||||||
|
# z abc
|
||||||
|
# zz bcd
|
||||||
|
# zzz cde
|
||||||
|
# ...
|
||||||
|
for idx in range(10):
|
||||||
|
title = 'z' * (idx + 1)
|
||||||
|
text = (
|
||||||
|
chr(idx + ord('a')) +
|
||||||
|
chr(idx + ord('b')) +
|
||||||
|
chr(idx + ord('c'))
|
||||||
|
)
|
||||||
|
SearchFilterModel(title=title, text=text).save()
|
||||||
|
|
||||||
|
def test_search(self):
|
||||||
|
class SearchListView(generics.ListAPIView):
|
||||||
|
model = SearchFilterModel
|
||||||
|
filter_backends = (filters.SearchFilter,)
|
||||||
|
search_fields = ('title', 'text')
|
||||||
|
|
||||||
|
view = SearchListView.as_view()
|
||||||
|
request = factory.get('?search=b')
|
||||||
|
response = view(request)
|
||||||
|
self.assertEqual(
|
||||||
|
response.data,
|
||||||
|
[
|
||||||
|
{u'id': 1, 'title': u'z', 'text': u'abc'},
|
||||||
|
{u'id': 2, 'title': u'zz', 'text': u'bcd'}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_exact_search(self):
|
||||||
|
class SearchListView(generics.ListAPIView):
|
||||||
|
model = SearchFilterModel
|
||||||
|
filter_backends = (filters.SearchFilter,)
|
||||||
|
search_fields = ('=title', 'text')
|
||||||
|
|
||||||
|
view = SearchListView.as_view()
|
||||||
|
request = factory.get('?search=zzz')
|
||||||
|
response = view(request)
|
||||||
|
self.assertEqual(
|
||||||
|
response.data,
|
||||||
|
[
|
||||||
|
{u'id': 3, 'title': u'zzz', 'text': u'cde'}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_startswith_search(self):
|
||||||
|
class SearchListView(generics.ListAPIView):
|
||||||
|
model = SearchFilterModel
|
||||||
|
filter_backends = (filters.SearchFilter,)
|
||||||
|
search_fields = ('title', '^text')
|
||||||
|
|
||||||
|
view = SearchListView.as_view()
|
||||||
|
request = factory.get('?search=b')
|
||||||
|
response = view(request)
|
||||||
|
self.assertEqual(
|
||||||
|
response.data,
|
||||||
|
[
|
||||||
|
{u'id': 2, 'title': u'zz', 'text': u'bcd'}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user