diff --git a/docs/api-guide/filtering.md b/docs/api-guide/filtering.md index 47ea8592d..6479efd2d 100644 --- a/docs/api-guide/filtering.md +++ b/docs/api-guide/filtering.md @@ -213,12 +213,12 @@ This will allow the client to filter the items in the list by making queries suc You can also perform a related lookup on a ForeignKey or ManyToManyField with the lookup API double-underscore notation: search_fields = ['username', 'email', 'profile__profession'] - + For [JSONField][JSONField] and [HStoreField][HStoreField] fields you can filter based on nested values within the data structure using the same double-underscore notation: search_fields = ['data__breed', 'data__owner__other_pets__0__name'] -By default, searches will use case-insensitive partial matches. The search parameter may contain multiple search terms, which should be whitespace and/or comma separated. If multiple search terms are used then objects will be returned in the list only if all the provided terms are matched. +By default, searches will use case-insensitive partial matches. The search parameter may contain multiple search terms, which should be whitespace separated. If multiple search terms are used then objects will be returned in the list only if all the provided terms are matched. The search behavior may be restricted by prepending various characters to the `search_fields`. diff --git a/rest_framework/filters.py b/rest_framework/filters.py index ec7fdbb53..01437f0f3 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -6,15 +6,17 @@ import operator import warnings from functools import reduce -from django.core.exceptions import ImproperlyConfigured +from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured from django.db import models from django.db.models.constants import LOOKUP_SEP from django.template import loader from django.utils.encoding import force_str +from django.utils.text import smart_split, unescape_string_literal from django.utils.translation import gettext_lazy as _ from rest_framework import RemovedInDRF317Warning from rest_framework.compat import coreapi, coreschema +from rest_framework.fields import CharField from rest_framework.settings import api_settings @@ -64,18 +66,37 @@ class SearchFilter(BaseFilterBackend): def get_search_terms(self, request): """ Search terms are set by a ?search=... query parameter, - and may be comma and/or whitespace delimited. + and may be whitespace delimited. """ - params = request.query_params.get(self.search_param, '') - params = params.replace('\x00', '') # strip null characters - params = params.replace(',', ' ') - return params.split() + value = request.query_params.get(self.search_param, '') + field = CharField(trim_whitespace=False, allow_blank=True) + return field.run_validation(value) - def construct_search(self, field_name): + def construct_search(self, field_name, queryset): lookup = self.lookup_prefixes.get(field_name[0]) if lookup: field_name = field_name[1:] else: + # Use field_name if it includes a lookup. + opts = queryset.model._meta + lookup_fields = field_name.split(LOOKUP_SEP) + # Go through the fields, following all relations. + prev_field = None + for path_part in lookup_fields: + if path_part == "pk": + path_part = opts.pk.name + try: + field = opts.get_field(path_part) + except FieldDoesNotExist: + # Use valid query lookups. + if prev_field and prev_field.get_lookup(path_part): + return field_name + else: + prev_field = field + if hasattr(field, "path_infos"): + # Update opts to follow the relation. + opts = field.path_infos[-1].to_opts + # Otherwise, use the field with icontains. lookup = 'icontains' return LOOKUP_SEP.join([field_name, lookup]) @@ -113,15 +134,17 @@ class SearchFilter(BaseFilterBackend): return queryset orm_lookups = [ - self.construct_search(str(search_field)) + self.construct_search(str(search_field), queryset) for search_field in search_fields ] base = queryset conditions = [] - for search_term in search_terms: + for term in smart_split(search_terms): + if term.startswith(('"', "'")) and term[0] == term[-1]: + term = unescape_string_literal(term) queries = [ - models.Q(**{orm_lookup: search_term}) + models.Q(**{orm_lookup: term}) for orm_lookup in orm_lookups ] conditions.append(reduce(operator.or_, queries)) diff --git a/tests/test_filters.py b/tests/test_filters.py index 2a22e30f9..f8eed4b97 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -11,6 +11,7 @@ from django.test.utils import override_settings from rest_framework import filters, generics, serializers from rest_framework.compat import coreschema +from rest_framework.exceptions import ValidationError from rest_framework.test import APIRequestFactory factory = APIRequestFactory() @@ -50,7 +51,8 @@ class SearchFilterSerializer(serializers.ModelSerializer): class SearchFilterTests(TestCase): - def setUp(self): + @classmethod + def setUpTestData(cls): # Sequence of title/text is: # # z abc @@ -66,6 +68,9 @@ class SearchFilterTests(TestCase): ) SearchFilterModel(title=title, text=text).save() + SearchFilterModel(title='A title', text='The long text').save() + SearchFilterModel(title='The title', text='The "text').save() + def test_search(self): class SearchListView(generics.ListAPIView): queryset = SearchFilterModel.objects.all() @@ -177,6 +182,7 @@ class SearchFilterTests(TestCase): request = factory.get('/', {'search': r'^\w{3}$', 'title_only': 'true'}) response = view(request) + print(response.data) assert response.data == [ {'id': 3, 'title': 'zzz', 'text': 'cde'} ] @@ -186,9 +192,21 @@ class SearchFilterTests(TestCase): request = factory.get('/?search=\0as%00d\x00f') request = view.initialize_request(request) - terms = filters.SearchFilter().get_search_terms(request) + with self.assertRaises(ValidationError): + filters.SearchFilter().get_search_terms(request) - assert terms == ['asdf'] + def test_search_field_with_custom_lookup(self): + class SearchListView(generics.ListAPIView): + queryset = SearchFilterModel.objects.all() + serializer_class = SearchFilterSerializer + filter_backends = (filters.SearchFilter,) + search_fields = ('text__iendswith',) + view = SearchListView.as_view() + request = factory.get('/', {'search': 'c'}) + response = view(request) + assert response.data == [ + {'id': 1, 'title': 'z', 'text': 'abc'}, + ] def test_search_field_with_additional_transforms(self): from django.test.utils import register_lookup @@ -242,6 +260,32 @@ class SearchFilterTests(TestCase): ) assert search_query in rendered_search_field + def test_search_field_with_escapes(self): + class SearchListView(generics.ListAPIView): + queryset = SearchFilterModel.objects.all() + serializer_class = SearchFilterSerializer + filter_backends = (filters.SearchFilter,) + search_fields = ('title', 'text',) + view = SearchListView.as_view() + request = factory.get('/', {'search': '"\\\"text"'}) + response = view(request) + assert response.data == [ + {'id': 12, 'title': 'The title', 'text': 'The "text'}, + ] + + def test_search_field_with_quotes(self): + class SearchListView(generics.ListAPIView): + queryset = SearchFilterModel.objects.all() + serializer_class = SearchFilterSerializer + filter_backends = (filters.SearchFilter,) + search_fields = ('title', 'text',) + view = SearchListView.as_view() + request = factory.get('/', {'search': '"long text"'}) + response = view(request) + assert response.data == [ + {'id': 11, 'title': 'A title', 'text': 'The long text'}, + ] + class AttributeModel(models.Model): label = models.CharField(max_length=32) @@ -284,6 +328,13 @@ class SearchFilterFkTests(TestCase): ["%sattribute__label" % prefix, "%stitle" % prefix] ) + def test_custom_lookup_to_related_model(self): + # In this test case the attribute of the fk model comes first in the + # list of search fields. + filter_ = filters.SearchFilter() + assert 'attribute__label__icontains' == filter_.construct_search('attribute__label', SearchFilterModelFk._meta) + assert 'attribute__label__iendswith' == filter_.construct_search('attribute__label__iendswith', SearchFilterModelFk._meta) + class SearchFilterModelM2M(models.Model): title = models.CharField(max_length=20) @@ -385,7 +436,7 @@ class SearchFilterToManyTests(TestCase): search_fields = ('=name', 'entry__headline', '=entry__pub_date__year') view = SearchListView.as_view() - request = factory.get('/', {'search': 'Lennon,1979'}) + request = factory.get('/', {'search': 'Lennon 1979'}) response = view(request) assert len(response.data) == 1