mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-02-03 05:04:31 +03:00
SearchFilter and tests
This commit is contained in:
parent
773a92eab3
commit
8ce36d2bf1
|
@ -74,6 +74,8 @@ class DjangoFilterBackend(BaseFilterBackend):
|
|||
|
||||
|
||||
class SearchFilter(BaseFilterBackend):
|
||||
search_param = 'search'
|
||||
|
||||
def construct_search(self, field_name):
|
||||
if field_name.startswith('^'):
|
||||
return "%s__istartswith" % field_name[1:]
|
||||
|
@ -90,10 +92,13 @@ class SearchFilter(BaseFilterBackend):
|
|||
if not search_fields:
|
||||
return None
|
||||
|
||||
search_terms = request.QUERY_PARAMS.get(self.search_param)
|
||||
orm_lookups = [self.construct_search(str(search_field))
|
||||
for search_field in self.search_fields]
|
||||
for bit in self.query.split():
|
||||
for search_field in search_fields]
|
||||
|
||||
for bit in search_terms.split():
|
||||
or_queries = [models.Q(**{orm_lookup: bit})
|
||||
for orm_lookup in orm_lookups]
|
||||
queryset = queryset.filter(reduce(operator.or_, or_queries))
|
||||
|
||||
return queryset
|
||||
|
|
|
@ -1,17 +1,24 @@
|
|||
from __future__ import unicode_literals
|
||||
import datetime
|
||||
from decimal import Decimal
|
||||
from django.db import models
|
||||
from django.core.urlresolvers import reverse
|
||||
from django.test import TestCase
|
||||
from django.test.client import RequestFactory
|
||||
from django.utils import unittest
|
||||
from rest_framework import generics, serializers, status, filters
|
||||
from rest_framework.compat import django_filters, patterns, url
|
||||
from rest_framework.tests.models import FilterableItem, BasicModel
|
||||
from rest_framework.tests.models import BasicModel
|
||||
|
||||
factory = RequestFactory()
|
||||
|
||||
|
||||
class FilterableItem(models.Model):
|
||||
text = models.CharField(max_length=100)
|
||||
decimal = models.DecimalField(max_digits=4, decimal_places=2)
|
||||
date = models.DateField()
|
||||
|
||||
|
||||
if django_filters:
|
||||
# Basic filter on a list view.
|
||||
class FilterFieldsRootView(generics.ListCreateAPIView):
|
||||
|
@ -256,3 +263,75 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase):
|
|||
response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date))
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data, valid_item_data)
|
||||
|
||||
|
||||
class SearchFilterModel(models.Model):
|
||||
title = models.CharField(max_length=20)
|
||||
text = models.CharField(max_length=100)
|
||||
|
||||
|
||||
class SearchFilterTests(TestCase):
|
||||
def setUp(self):
|
||||
# Sequence of title/text is:
|
||||
#
|
||||
# z abc
|
||||
# zz bcd
|
||||
# zzz cde
|
||||
# ...
|
||||
for idx in range(10):
|
||||
title = 'z' * (idx + 1)
|
||||
text = (
|
||||
chr(idx + ord('a')) +
|
||||
chr(idx + ord('b')) +
|
||||
chr(idx + ord('c'))
|
||||
)
|
||||
SearchFilterModel(title=title, text=text).save()
|
||||
|
||||
def test_search(self):
|
||||
class SearchListView(generics.ListAPIView):
|
||||
model = SearchFilterModel
|
||||
filter_backends = (filters.SearchFilter,)
|
||||
search_fields = ('title', 'text')
|
||||
|
||||
view = SearchListView.as_view()
|
||||
request = factory.get('?search=b')
|
||||
response = view(request)
|
||||
self.assertEqual(
|
||||
response.data,
|
||||
[
|
||||
{u'id': 1, 'title': u'z', 'text': u'abc'},
|
||||
{u'id': 2, 'title': u'zz', 'text': u'bcd'}
|
||||
]
|
||||
)
|
||||
|
||||
def test_exact_search(self):
|
||||
class SearchListView(generics.ListAPIView):
|
||||
model = SearchFilterModel
|
||||
filter_backends = (filters.SearchFilter,)
|
||||
search_fields = ('=title', 'text')
|
||||
|
||||
view = SearchListView.as_view()
|
||||
request = factory.get('?search=zzz')
|
||||
response = view(request)
|
||||
self.assertEqual(
|
||||
response.data,
|
||||
[
|
||||
{u'id': 3, 'title': u'zzz', 'text': u'cde'}
|
||||
]
|
||||
)
|
||||
|
||||
def test_startswith_search(self):
|
||||
class SearchListView(generics.ListAPIView):
|
||||
model = SearchFilterModel
|
||||
filter_backends = (filters.SearchFilter,)
|
||||
search_fields = ('title', '^text')
|
||||
|
||||
view = SearchListView.as_view()
|
||||
request = factory.get('?search=b')
|
||||
response = view(request)
|
||||
self.assertEqual(
|
||||
response.data,
|
||||
[
|
||||
{u'id': 2, 'title': u'zz', 'text': u'bcd'}
|
||||
]
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue
Block a user