Align SearchFilter behaviour to django.contrib.admin

This commit is contained in:
sevdog 2023-06-22 11:37:21 +02:00
parent dee83cebf4
commit ad96159234
No known key found for this signature in database
GPG Key ID: D939AF7A93A9C178
3 changed files with 90 additions and 16 deletions

View File

@ -218,7 +218,7 @@ For [JSONField][JSONField] and [HStoreField][HStoreField] fields you can filter
search_fields = ['data__breed', 'data__owner__other_pets__0__name']
By default, searches will use case-insensitive partial matches. The search parameter may contain multiple search terms, which should be whitespace and/or comma separated. If multiple search terms are used then objects will be returned in the list only if all the provided terms are matched.
By default, searches will use case-insensitive partial matches. The search parameter may contain multiple search terms, which should be whitespace separated. If multiple search terms are used then objects will be returned in the list only if all the provided terms are matched.
The search behavior may be restricted by prepending various characters to the `search_fields`.

View File

@ -6,15 +6,17 @@ import operator
import warnings
from functools import reduce
from django.core.exceptions import ImproperlyConfigured
from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured
from django.db import models
from django.db.models.constants import LOOKUP_SEP
from django.template import loader
from django.utils.encoding import force_str
from django.utils.text import smart_split, unescape_string_literal
from django.utils.translation import gettext_lazy as _
from rest_framework import RemovedInDRF317Warning
from rest_framework.compat import coreapi, coreschema
from rest_framework.fields import CharField
from rest_framework.settings import api_settings
@ -64,18 +66,37 @@ class SearchFilter(BaseFilterBackend):
def get_search_terms(self, request):
"""
Search terms are set by a ?search=... query parameter,
and may be comma and/or whitespace delimited.
and may be whitespace delimited.
"""
params = request.query_params.get(self.search_param, '')
params = params.replace('\x00', '') # strip null characters
params = params.replace(',', ' ')
return params.split()
value = request.query_params.get(self.search_param, '')
field = CharField(trim_whitespace=False, allow_blank=True)
return field.run_validation(value)
def construct_search(self, field_name):
def construct_search(self, field_name, queryset):
lookup = self.lookup_prefixes.get(field_name[0])
if lookup:
field_name = field_name[1:]
else:
# Use field_name if it includes a lookup.
opts = queryset.model._meta
lookup_fields = field_name.split(LOOKUP_SEP)
# Go through the fields, following all relations.
prev_field = None
for path_part in lookup_fields:
if path_part == "pk":
path_part = opts.pk.name
try:
field = opts.get_field(path_part)
except FieldDoesNotExist:
# Use valid query lookups.
if prev_field and prev_field.get_lookup(path_part):
return field_name
else:
prev_field = field
if hasattr(field, "path_infos"):
# Update opts to follow the relation.
opts = field.path_infos[-1].to_opts
# Otherwise, use the field with icontains.
lookup = 'icontains'
return LOOKUP_SEP.join([field_name, lookup])
@ -113,15 +134,17 @@ class SearchFilter(BaseFilterBackend):
return queryset
orm_lookups = [
self.construct_search(str(search_field))
self.construct_search(str(search_field), queryset)
for search_field in search_fields
]
base = queryset
conditions = []
for search_term in search_terms:
for term in smart_split(search_terms):
if term.startswith(('"', "'")) and term[0] == term[-1]:
term = unescape_string_literal(term)
queries = [
models.Q(**{orm_lookup: search_term})
models.Q(**{orm_lookup: term})
for orm_lookup in orm_lookups
]
conditions.append(reduce(operator.or_, queries))

View File

@ -11,6 +11,7 @@ from django.test.utils import override_settings
from rest_framework import filters, generics, serializers
from rest_framework.compat import coreschema
from rest_framework.exceptions import ValidationError
from rest_framework.test import APIRequestFactory
factory = APIRequestFactory()
@ -50,7 +51,8 @@ class SearchFilterSerializer(serializers.ModelSerializer):
class SearchFilterTests(TestCase):
def setUp(self):
@classmethod
def setUpTestData(cls):
# Sequence of title/text is:
#
# z abc
@ -66,6 +68,9 @@ class SearchFilterTests(TestCase):
)
SearchFilterModel(title=title, text=text).save()
SearchFilterModel(title='A title', text='The long text').save()
SearchFilterModel(title='The title', text='The "text').save()
def test_search(self):
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
@ -177,6 +182,7 @@ class SearchFilterTests(TestCase):
request = factory.get('/', {'search': r'^\w{3}$', 'title_only': 'true'})
response = view(request)
print(response.data)
assert response.data == [
{'id': 3, 'title': 'zzz', 'text': 'cde'}
]
@ -186,9 +192,21 @@ class SearchFilterTests(TestCase):
request = factory.get('/?search=\0as%00d\x00f')
request = view.initialize_request(request)
terms = filters.SearchFilter().get_search_terms(request)
with self.assertRaises(ValidationError):
filters.SearchFilter().get_search_terms(request)
assert terms == ['asdf']
def test_search_field_with_custom_lookup(self):
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('text__iendswith',)
view = SearchListView.as_view()
request = factory.get('/', {'search': 'c'})
response = view(request)
assert response.data == [
{'id': 1, 'title': 'z', 'text': 'abc'},
]
def test_search_field_with_additional_transforms(self):
from django.test.utils import register_lookup
@ -242,6 +260,32 @@ class SearchFilterTests(TestCase):
)
assert search_query in rendered_search_field
def test_search_field_with_escapes(self):
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('title', 'text',)
view = SearchListView.as_view()
request = factory.get('/', {'search': '"\\\"text"'})
response = view(request)
assert response.data == [
{'id': 12, 'title': 'The title', 'text': 'The "text'},
]
def test_search_field_with_quotes(self):
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('title', 'text',)
view = SearchListView.as_view()
request = factory.get('/', {'search': '"long text"'})
response = view(request)
assert response.data == [
{'id': 11, 'title': 'A title', 'text': 'The long text'},
]
class AttributeModel(models.Model):
label = models.CharField(max_length=32)
@ -284,6 +328,13 @@ class SearchFilterFkTests(TestCase):
["%sattribute__label" % prefix, "%stitle" % prefix]
)
def test_custom_lookup_to_related_model(self):
# In this test case the attribute of the fk model comes first in the
# list of search fields.
filter_ = filters.SearchFilter()
assert 'attribute__label__icontains' == filter_.construct_search('attribute__label', SearchFilterModelFk._meta)
assert 'attribute__label__iendswith' == filter_.construct_search('attribute__label__iendswith', SearchFilterModelFk._meta)
class SearchFilterModelM2M(models.Model):
title = models.CharField(max_length=20)
@ -385,7 +436,7 @@ class SearchFilterToManyTests(TestCase):
search_fields = ('=name', 'entry__headline', '=entry__pub_date__year')
view = SearchListView.as_view()
request = factory.get('/', {'search': 'Lennon,1979'})
request = factory.get('/', {'search': 'Lennon 1979'})
response = view(request)
assert len(response.data) == 1