SearchFilter and tests

This commit is contained in:
Tom Christie 2013-05-10 21:57:20 +01:00
parent 773a92eab3
commit 8ce36d2bf1
2 changed files with 87 additions and 3 deletions

View File

@ -74,6 +74,8 @@ class DjangoFilterBackend(BaseFilterBackend):
class SearchFilter(BaseFilterBackend):
search_param = 'search'
def construct_search(self, field_name):
if field_name.startswith('^'):
return "%s__istartswith" % field_name[1:]
@ -90,10 +92,13 @@ class SearchFilter(BaseFilterBackend):
if not search_fields:
return None
search_terms = request.QUERY_PARAMS.get(self.search_param)
orm_lookups = [self.construct_search(str(search_field))
for search_field in self.search_fields]
for bit in self.query.split():
for search_field in search_fields]
for bit in search_terms.split():
or_queries = [models.Q(**{orm_lookup: bit})
for orm_lookup in orm_lookups]
queryset = queryset.filter(reduce(operator.or_, or_queries))
return queryset

View File

@ -1,17 +1,24 @@
from __future__ import unicode_literals
import datetime
from decimal import Decimal
from django.db import models
from django.core.urlresolvers import reverse
from django.test import TestCase
from django.test.client import RequestFactory
from django.utils import unittest
from rest_framework import generics, serializers, status, filters
from rest_framework.compat import django_filters, patterns, url
from rest_framework.tests.models import FilterableItem, BasicModel
from rest_framework.tests.models import BasicModel
factory = RequestFactory()
class FilterableItem(models.Model):
text = models.CharField(max_length=100)
decimal = models.DecimalField(max_digits=4, decimal_places=2)
date = models.DateField()
if django_filters:
# Basic filter on a list view.
class FilterFieldsRootView(generics.ListCreateAPIView):
@ -256,3 +263,75 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase):
response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date))
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, valid_item_data)
class SearchFilterModel(models.Model):
title = models.CharField(max_length=20)
text = models.CharField(max_length=100)
class SearchFilterTests(TestCase):
def setUp(self):
# Sequence of title/text is:
#
# z abc
# zz bcd
# zzz cde
# ...
for idx in range(10):
title = 'z' * (idx + 1)
text = (
chr(idx + ord('a')) +
chr(idx + ord('b')) +
chr(idx + ord('c'))
)
SearchFilterModel(title=title, text=text).save()
def test_search(self):
class SearchListView(generics.ListAPIView):
model = SearchFilterModel
filter_backends = (filters.SearchFilter,)
search_fields = ('title', 'text')
view = SearchListView.as_view()
request = factory.get('?search=b')
response = view(request)
self.assertEqual(
response.data,
[
{u'id': 1, 'title': u'z', 'text': u'abc'},
{u'id': 2, 'title': u'zz', 'text': u'bcd'}
]
)
def test_exact_search(self):
class SearchListView(generics.ListAPIView):
model = SearchFilterModel
filter_backends = (filters.SearchFilter,)
search_fields = ('=title', 'text')
view = SearchListView.as_view()
request = factory.get('?search=zzz')
response = view(request)
self.assertEqual(
response.data,
[
{u'id': 3, 'title': u'zzz', 'text': u'cde'}
]
)
def test_startswith_search(self):
class SearchListView(generics.ListAPIView):
model = SearchFilterModel
filter_backends = (filters.SearchFilter,)
search_fields = ('title', '^text')
view = SearchListView.as_view()
request = factory.get('?search=b')
response = view(request)
self.assertEqual(
response.data,
[
{u'id': 2, 'title': u'zz', 'text': u'bcd'}
]
)