SearchFilter and tests

This commit is contained in:
Tom Christie 2013-05-10 21:57:20 +01:00
parent 773a92eab3
commit 8ce36d2bf1
2 changed files with 87 additions and 3 deletions

View File

@ -74,6 +74,8 @@ class DjangoFilterBackend(BaseFilterBackend):
class SearchFilter(BaseFilterBackend): class SearchFilter(BaseFilterBackend):
search_param = 'search'
def construct_search(self, field_name): def construct_search(self, field_name):
if field_name.startswith('^'): if field_name.startswith('^'):
return "%s__istartswith" % field_name[1:] return "%s__istartswith" % field_name[1:]
@ -90,10 +92,13 @@ class SearchFilter(BaseFilterBackend):
if not search_fields: if not search_fields:
return None return None
search_terms = request.QUERY_PARAMS.get(self.search_param)
orm_lookups = [self.construct_search(str(search_field)) orm_lookups = [self.construct_search(str(search_field))
for search_field in self.search_fields] for search_field in search_fields]
for bit in self.query.split():
for bit in search_terms.split():
or_queries = [models.Q(**{orm_lookup: bit}) or_queries = [models.Q(**{orm_lookup: bit})
for orm_lookup in orm_lookups] for orm_lookup in orm_lookups]
queryset = queryset.filter(reduce(operator.or_, or_queries)) queryset = queryset.filter(reduce(operator.or_, or_queries))
return queryset return queryset

View File

@ -1,17 +1,24 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import datetime import datetime
from decimal import Decimal from decimal import Decimal
from django.db import models
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
from django.test import TestCase from django.test import TestCase
from django.test.client import RequestFactory from django.test.client import RequestFactory
from django.utils import unittest from django.utils import unittest
from rest_framework import generics, serializers, status, filters from rest_framework import generics, serializers, status, filters
from rest_framework.compat import django_filters, patterns, url from rest_framework.compat import django_filters, patterns, url
from rest_framework.tests.models import FilterableItem, BasicModel from rest_framework.tests.models import BasicModel
factory = RequestFactory() factory = RequestFactory()
class FilterableItem(models.Model):
text = models.CharField(max_length=100)
decimal = models.DecimalField(max_digits=4, decimal_places=2)
date = models.DateField()
if django_filters: if django_filters:
# Basic filter on a list view. # Basic filter on a list view.
class FilterFieldsRootView(generics.ListCreateAPIView): class FilterFieldsRootView(generics.ListCreateAPIView):
@ -256,3 +263,75 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase):
response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date)) response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date))
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, valid_item_data) self.assertEqual(response.data, valid_item_data)
class SearchFilterModel(models.Model):
title = models.CharField(max_length=20)
text = models.CharField(max_length=100)
class SearchFilterTests(TestCase):
def setUp(self):
# Sequence of title/text is:
#
# z abc
# zz bcd
# zzz cde
# ...
for idx in range(10):
title = 'z' * (idx + 1)
text = (
chr(idx + ord('a')) +
chr(idx + ord('b')) +
chr(idx + ord('c'))
)
SearchFilterModel(title=title, text=text).save()
def test_search(self):
class SearchListView(generics.ListAPIView):
model = SearchFilterModel
filter_backends = (filters.SearchFilter,)
search_fields = ('title', 'text')
view = SearchListView.as_view()
request = factory.get('?search=b')
response = view(request)
self.assertEqual(
response.data,
[
{u'id': 1, 'title': u'z', 'text': u'abc'},
{u'id': 2, 'title': u'zz', 'text': u'bcd'}
]
)
def test_exact_search(self):
class SearchListView(generics.ListAPIView):
model = SearchFilterModel
filter_backends = (filters.SearchFilter,)
search_fields = ('=title', 'text')
view = SearchListView.as_view()
request = factory.get('?search=zzz')
response = view(request)
self.assertEqual(
response.data,
[
{u'id': 3, 'title': u'zzz', 'text': u'cde'}
]
)
def test_startswith_search(self):
class SearchListView(generics.ListAPIView):
model = SearchFilterModel
filter_backends = (filters.SearchFilter,)
search_fields = ('title', '^text')
view = SearchListView.as_view()
request = factory.get('?search=b')
response = view(request)
self.assertEqual(
response.data,
[
{u'id': 2, 'title': u'zz', 'text': u'bcd'}
]
)