Align SearchFilter behaviour to django.contrib.admin

This commit is contained in:
sevdog 2023-06-22 11:37:21 +02:00
parent dee83cebf4
commit ad96159234
No known key found for this signature in database
GPG Key ID: D939AF7A93A9C178
3 changed files with 90 additions and 16 deletions

View File

@ -218,7 +218,7 @@ For [JSONField][JSONField] and [HStoreField][HStoreField] fields you can filter
search_fields = ['data__breed', 'data__owner__other_pets__0__name'] search_fields = ['data__breed', 'data__owner__other_pets__0__name']
By default, searches will use case-insensitive partial matches. The search parameter may contain multiple search terms, which should be whitespace and/or comma separated. If multiple search terms are used then objects will be returned in the list only if all the provided terms are matched. By default, searches will use case-insensitive partial matches. The search parameter may contain multiple search terms, which should be whitespace separated. If multiple search terms are used then objects will be returned in the list only if all the provided terms are matched.
The search behavior may be restricted by prepending various characters to the `search_fields`. The search behavior may be restricted by prepending various characters to the `search_fields`.

View File

@ -6,15 +6,17 @@ import operator
import warnings import warnings
from functools import reduce from functools import reduce
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured
from django.db import models from django.db import models
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.template import loader from django.template import loader
from django.utils.encoding import force_str from django.utils.encoding import force_str
from django.utils.text import smart_split, unescape_string_literal
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from rest_framework import RemovedInDRF317Warning from rest_framework import RemovedInDRF317Warning
from rest_framework.compat import coreapi, coreschema from rest_framework.compat import coreapi, coreschema
from rest_framework.fields import CharField
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
@ -64,18 +66,37 @@ class SearchFilter(BaseFilterBackend):
def get_search_terms(self, request): def get_search_terms(self, request):
""" """
Search terms are set by a ?search=... query parameter, Search terms are set by a ?search=... query parameter,
and may be comma and/or whitespace delimited. and may be whitespace delimited.
""" """
params = request.query_params.get(self.search_param, '') value = request.query_params.get(self.search_param, '')
params = params.replace('\x00', '') # strip null characters field = CharField(trim_whitespace=False, allow_blank=True)
params = params.replace(',', ' ') return field.run_validation(value)
return params.split()
def construct_search(self, field_name): def construct_search(self, field_name, queryset):
lookup = self.lookup_prefixes.get(field_name[0]) lookup = self.lookup_prefixes.get(field_name[0])
if lookup: if lookup:
field_name = field_name[1:] field_name = field_name[1:]
else: else:
# Use field_name if it includes a lookup.
opts = queryset.model._meta
lookup_fields = field_name.split(LOOKUP_SEP)
# Go through the fields, following all relations.
prev_field = None
for path_part in lookup_fields:
if path_part == "pk":
path_part = opts.pk.name
try:
field = opts.get_field(path_part)
except FieldDoesNotExist:
# Use valid query lookups.
if prev_field and prev_field.get_lookup(path_part):
return field_name
else:
prev_field = field
if hasattr(field, "path_infos"):
# Update opts to follow the relation.
opts = field.path_infos[-1].to_opts
# Otherwise, use the field with icontains.
lookup = 'icontains' lookup = 'icontains'
return LOOKUP_SEP.join([field_name, lookup]) return LOOKUP_SEP.join([field_name, lookup])
@ -113,15 +134,17 @@ class SearchFilter(BaseFilterBackend):
return queryset return queryset
orm_lookups = [ orm_lookups = [
self.construct_search(str(search_field)) self.construct_search(str(search_field), queryset)
for search_field in search_fields for search_field in search_fields
] ]
base = queryset base = queryset
conditions = [] conditions = []
for search_term in search_terms: for term in smart_split(search_terms):
if term.startswith(('"', "'")) and term[0] == term[-1]:
term = unescape_string_literal(term)
queries = [ queries = [
models.Q(**{orm_lookup: search_term}) models.Q(**{orm_lookup: term})
for orm_lookup in orm_lookups for orm_lookup in orm_lookups
] ]
conditions.append(reduce(operator.or_, queries)) conditions.append(reduce(operator.or_, queries))

View File

@ -11,6 +11,7 @@ from django.test.utils import override_settings
from rest_framework import filters, generics, serializers from rest_framework import filters, generics, serializers
from rest_framework.compat import coreschema from rest_framework.compat import coreschema
from rest_framework.exceptions import ValidationError
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
factory = APIRequestFactory() factory = APIRequestFactory()
@ -50,7 +51,8 @@ class SearchFilterSerializer(serializers.ModelSerializer):
class SearchFilterTests(TestCase): class SearchFilterTests(TestCase):
def setUp(self): @classmethod
def setUpTestData(cls):
# Sequence of title/text is: # Sequence of title/text is:
# #
# z abc # z abc
@ -66,6 +68,9 @@ class SearchFilterTests(TestCase):
) )
SearchFilterModel(title=title, text=text).save() SearchFilterModel(title=title, text=text).save()
SearchFilterModel(title='A title', text='The long text').save()
SearchFilterModel(title='The title', text='The "text').save()
def test_search(self): def test_search(self):
class SearchListView(generics.ListAPIView): class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all() queryset = SearchFilterModel.objects.all()
@ -177,6 +182,7 @@ class SearchFilterTests(TestCase):
request = factory.get('/', {'search': r'^\w{3}$', 'title_only': 'true'}) request = factory.get('/', {'search': r'^\w{3}$', 'title_only': 'true'})
response = view(request) response = view(request)
print(response.data)
assert response.data == [ assert response.data == [
{'id': 3, 'title': 'zzz', 'text': 'cde'} {'id': 3, 'title': 'zzz', 'text': 'cde'}
] ]
@ -186,9 +192,21 @@ class SearchFilterTests(TestCase):
request = factory.get('/?search=\0as%00d\x00f') request = factory.get('/?search=\0as%00d\x00f')
request = view.initialize_request(request) request = view.initialize_request(request)
terms = filters.SearchFilter().get_search_terms(request) with self.assertRaises(ValidationError):
filters.SearchFilter().get_search_terms(request)
assert terms == ['asdf'] def test_search_field_with_custom_lookup(self):
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('text__iendswith',)
view = SearchListView.as_view()
request = factory.get('/', {'search': 'c'})
response = view(request)
assert response.data == [
{'id': 1, 'title': 'z', 'text': 'abc'},
]
def test_search_field_with_additional_transforms(self): def test_search_field_with_additional_transforms(self):
from django.test.utils import register_lookup from django.test.utils import register_lookup
@ -242,6 +260,32 @@ class SearchFilterTests(TestCase):
) )
assert search_query in rendered_search_field assert search_query in rendered_search_field
def test_search_field_with_escapes(self):
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('title', 'text',)
view = SearchListView.as_view()
request = factory.get('/', {'search': '"\\\"text"'})
response = view(request)
assert response.data == [
{'id': 12, 'title': 'The title', 'text': 'The "text'},
]
def test_search_field_with_quotes(self):
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('title', 'text',)
view = SearchListView.as_view()
request = factory.get('/', {'search': '"long text"'})
response = view(request)
assert response.data == [
{'id': 11, 'title': 'A title', 'text': 'The long text'},
]
class AttributeModel(models.Model): class AttributeModel(models.Model):
label = models.CharField(max_length=32) label = models.CharField(max_length=32)
@ -284,6 +328,13 @@ class SearchFilterFkTests(TestCase):
["%sattribute__label" % prefix, "%stitle" % prefix] ["%sattribute__label" % prefix, "%stitle" % prefix]
) )
def test_custom_lookup_to_related_model(self):
# In this test case the attribute of the fk model comes first in the
# list of search fields.
filter_ = filters.SearchFilter()
assert 'attribute__label__icontains' == filter_.construct_search('attribute__label', SearchFilterModelFk._meta)
assert 'attribute__label__iendswith' == filter_.construct_search('attribute__label__iendswith', SearchFilterModelFk._meta)
class SearchFilterModelM2M(models.Model): class SearchFilterModelM2M(models.Model):
title = models.CharField(max_length=20) title = models.CharField(max_length=20)
@ -385,7 +436,7 @@ class SearchFilterToManyTests(TestCase):
search_fields = ('=name', 'entry__headline', '=entry__pub_date__year') search_fields = ('=name', 'entry__headline', '=entry__pub_date__year')
view = SearchListView.as_view() view = SearchListView.as_view()
request = factory.get('/', {'search': 'Lennon,1979'}) request = factory.get('/', {'search': 'Lennon 1979'})
response = view(request) response = view(request)
assert len(response.data) == 1 assert len(response.data) == 1