mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-07-18 12:12:19 +03:00
Align SearchFilter behaviour to django.contrib.admin
This commit is contained in:
parent
dee83cebf4
commit
ad96159234
|
@ -218,7 +218,7 @@ For [JSONField][JSONField] and [HStoreField][HStoreField] fields you can filter
|
||||||
|
|
||||||
search_fields = ['data__breed', 'data__owner__other_pets__0__name']
|
search_fields = ['data__breed', 'data__owner__other_pets__0__name']
|
||||||
|
|
||||||
By default, searches will use case-insensitive partial matches. The search parameter may contain multiple search terms, which should be whitespace and/or comma separated. If multiple search terms are used then objects will be returned in the list only if all the provided terms are matched.
|
By default, searches will use case-insensitive partial matches. The search parameter may contain multiple search terms, which should be whitespace separated. If multiple search terms are used then objects will be returned in the list only if all the provided terms are matched.
|
||||||
|
|
||||||
The search behavior may be restricted by prepending various characters to the `search_fields`.
|
The search behavior may be restricted by prepending various characters to the `search_fields`.
|
||||||
|
|
||||||
|
|
|
@ -6,15 +6,17 @@ import operator
|
||||||
import warnings
|
import warnings
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
|
|
||||||
from django.core.exceptions import ImproperlyConfigured
|
from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.db.models.constants import LOOKUP_SEP
|
from django.db.models.constants import LOOKUP_SEP
|
||||||
from django.template import loader
|
from django.template import loader
|
||||||
from django.utils.encoding import force_str
|
from django.utils.encoding import force_str
|
||||||
|
from django.utils.text import smart_split, unescape_string_literal
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
|
|
||||||
from rest_framework import RemovedInDRF317Warning
|
from rest_framework import RemovedInDRF317Warning
|
||||||
from rest_framework.compat import coreapi, coreschema
|
from rest_framework.compat import coreapi, coreschema
|
||||||
|
from rest_framework.fields import CharField
|
||||||
from rest_framework.settings import api_settings
|
from rest_framework.settings import api_settings
|
||||||
|
|
||||||
|
|
||||||
|
@ -64,18 +66,37 @@ class SearchFilter(BaseFilterBackend):
|
||||||
def get_search_terms(self, request):
|
def get_search_terms(self, request):
|
||||||
"""
|
"""
|
||||||
Search terms are set by a ?search=... query parameter,
|
Search terms are set by a ?search=... query parameter,
|
||||||
and may be comma and/or whitespace delimited.
|
and may be whitespace delimited.
|
||||||
"""
|
"""
|
||||||
params = request.query_params.get(self.search_param, '')
|
value = request.query_params.get(self.search_param, '')
|
||||||
params = params.replace('\x00', '') # strip null characters
|
field = CharField(trim_whitespace=False, allow_blank=True)
|
||||||
params = params.replace(',', ' ')
|
return field.run_validation(value)
|
||||||
return params.split()
|
|
||||||
|
|
||||||
def construct_search(self, field_name):
|
def construct_search(self, field_name, queryset):
|
||||||
lookup = self.lookup_prefixes.get(field_name[0])
|
lookup = self.lookup_prefixes.get(field_name[0])
|
||||||
if lookup:
|
if lookup:
|
||||||
field_name = field_name[1:]
|
field_name = field_name[1:]
|
||||||
else:
|
else:
|
||||||
|
# Use field_name if it includes a lookup.
|
||||||
|
opts = queryset.model._meta
|
||||||
|
lookup_fields = field_name.split(LOOKUP_SEP)
|
||||||
|
# Go through the fields, following all relations.
|
||||||
|
prev_field = None
|
||||||
|
for path_part in lookup_fields:
|
||||||
|
if path_part == "pk":
|
||||||
|
path_part = opts.pk.name
|
||||||
|
try:
|
||||||
|
field = opts.get_field(path_part)
|
||||||
|
except FieldDoesNotExist:
|
||||||
|
# Use valid query lookups.
|
||||||
|
if prev_field and prev_field.get_lookup(path_part):
|
||||||
|
return field_name
|
||||||
|
else:
|
||||||
|
prev_field = field
|
||||||
|
if hasattr(field, "path_infos"):
|
||||||
|
# Update opts to follow the relation.
|
||||||
|
opts = field.path_infos[-1].to_opts
|
||||||
|
# Otherwise, use the field with icontains.
|
||||||
lookup = 'icontains'
|
lookup = 'icontains'
|
||||||
return LOOKUP_SEP.join([field_name, lookup])
|
return LOOKUP_SEP.join([field_name, lookup])
|
||||||
|
|
||||||
|
@ -113,15 +134,17 @@ class SearchFilter(BaseFilterBackend):
|
||||||
return queryset
|
return queryset
|
||||||
|
|
||||||
orm_lookups = [
|
orm_lookups = [
|
||||||
self.construct_search(str(search_field))
|
self.construct_search(str(search_field), queryset)
|
||||||
for search_field in search_fields
|
for search_field in search_fields
|
||||||
]
|
]
|
||||||
|
|
||||||
base = queryset
|
base = queryset
|
||||||
conditions = []
|
conditions = []
|
||||||
for search_term in search_terms:
|
for term in smart_split(search_terms):
|
||||||
|
if term.startswith(('"', "'")) and term[0] == term[-1]:
|
||||||
|
term = unescape_string_literal(term)
|
||||||
queries = [
|
queries = [
|
||||||
models.Q(**{orm_lookup: search_term})
|
models.Q(**{orm_lookup: term})
|
||||||
for orm_lookup in orm_lookups
|
for orm_lookup in orm_lookups
|
||||||
]
|
]
|
||||||
conditions.append(reduce(operator.or_, queries))
|
conditions.append(reduce(operator.or_, queries))
|
||||||
|
|
|
@ -11,6 +11,7 @@ from django.test.utils import override_settings
|
||||||
|
|
||||||
from rest_framework import filters, generics, serializers
|
from rest_framework import filters, generics, serializers
|
||||||
from rest_framework.compat import coreschema
|
from rest_framework.compat import coreschema
|
||||||
|
from rest_framework.exceptions import ValidationError
|
||||||
from rest_framework.test import APIRequestFactory
|
from rest_framework.test import APIRequestFactory
|
||||||
|
|
||||||
factory = APIRequestFactory()
|
factory = APIRequestFactory()
|
||||||
|
@ -50,7 +51,8 @@ class SearchFilterSerializer(serializers.ModelSerializer):
|
||||||
|
|
||||||
|
|
||||||
class SearchFilterTests(TestCase):
|
class SearchFilterTests(TestCase):
|
||||||
def setUp(self):
|
@classmethod
|
||||||
|
def setUpTestData(cls):
|
||||||
# Sequence of title/text is:
|
# Sequence of title/text is:
|
||||||
#
|
#
|
||||||
# z abc
|
# z abc
|
||||||
|
@ -66,6 +68,9 @@ class SearchFilterTests(TestCase):
|
||||||
)
|
)
|
||||||
SearchFilterModel(title=title, text=text).save()
|
SearchFilterModel(title=title, text=text).save()
|
||||||
|
|
||||||
|
SearchFilterModel(title='A title', text='The long text').save()
|
||||||
|
SearchFilterModel(title='The title', text='The "text').save()
|
||||||
|
|
||||||
def test_search(self):
|
def test_search(self):
|
||||||
class SearchListView(generics.ListAPIView):
|
class SearchListView(generics.ListAPIView):
|
||||||
queryset = SearchFilterModel.objects.all()
|
queryset = SearchFilterModel.objects.all()
|
||||||
|
@ -177,6 +182,7 @@ class SearchFilterTests(TestCase):
|
||||||
|
|
||||||
request = factory.get('/', {'search': r'^\w{3}$', 'title_only': 'true'})
|
request = factory.get('/', {'search': r'^\w{3}$', 'title_only': 'true'})
|
||||||
response = view(request)
|
response = view(request)
|
||||||
|
print(response.data)
|
||||||
assert response.data == [
|
assert response.data == [
|
||||||
{'id': 3, 'title': 'zzz', 'text': 'cde'}
|
{'id': 3, 'title': 'zzz', 'text': 'cde'}
|
||||||
]
|
]
|
||||||
|
@ -186,9 +192,21 @@ class SearchFilterTests(TestCase):
|
||||||
request = factory.get('/?search=\0as%00d\x00f')
|
request = factory.get('/?search=\0as%00d\x00f')
|
||||||
request = view.initialize_request(request)
|
request = view.initialize_request(request)
|
||||||
|
|
||||||
terms = filters.SearchFilter().get_search_terms(request)
|
with self.assertRaises(ValidationError):
|
||||||
|
filters.SearchFilter().get_search_terms(request)
|
||||||
|
|
||||||
assert terms == ['asdf']
|
def test_search_field_with_custom_lookup(self):
|
||||||
|
class SearchListView(generics.ListAPIView):
|
||||||
|
queryset = SearchFilterModel.objects.all()
|
||||||
|
serializer_class = SearchFilterSerializer
|
||||||
|
filter_backends = (filters.SearchFilter,)
|
||||||
|
search_fields = ('text__iendswith',)
|
||||||
|
view = SearchListView.as_view()
|
||||||
|
request = factory.get('/', {'search': 'c'})
|
||||||
|
response = view(request)
|
||||||
|
assert response.data == [
|
||||||
|
{'id': 1, 'title': 'z', 'text': 'abc'},
|
||||||
|
]
|
||||||
|
|
||||||
def test_search_field_with_additional_transforms(self):
|
def test_search_field_with_additional_transforms(self):
|
||||||
from django.test.utils import register_lookup
|
from django.test.utils import register_lookup
|
||||||
|
@ -242,6 +260,32 @@ class SearchFilterTests(TestCase):
|
||||||
)
|
)
|
||||||
assert search_query in rendered_search_field
|
assert search_query in rendered_search_field
|
||||||
|
|
||||||
|
def test_search_field_with_escapes(self):
|
||||||
|
class SearchListView(generics.ListAPIView):
|
||||||
|
queryset = SearchFilterModel.objects.all()
|
||||||
|
serializer_class = SearchFilterSerializer
|
||||||
|
filter_backends = (filters.SearchFilter,)
|
||||||
|
search_fields = ('title', 'text',)
|
||||||
|
view = SearchListView.as_view()
|
||||||
|
request = factory.get('/', {'search': '"\\\"text"'})
|
||||||
|
response = view(request)
|
||||||
|
assert response.data == [
|
||||||
|
{'id': 12, 'title': 'The title', 'text': 'The "text'},
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_search_field_with_quotes(self):
|
||||||
|
class SearchListView(generics.ListAPIView):
|
||||||
|
queryset = SearchFilterModel.objects.all()
|
||||||
|
serializer_class = SearchFilterSerializer
|
||||||
|
filter_backends = (filters.SearchFilter,)
|
||||||
|
search_fields = ('title', 'text',)
|
||||||
|
view = SearchListView.as_view()
|
||||||
|
request = factory.get('/', {'search': '"long text"'})
|
||||||
|
response = view(request)
|
||||||
|
assert response.data == [
|
||||||
|
{'id': 11, 'title': 'A title', 'text': 'The long text'},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class AttributeModel(models.Model):
|
class AttributeModel(models.Model):
|
||||||
label = models.CharField(max_length=32)
|
label = models.CharField(max_length=32)
|
||||||
|
@ -284,6 +328,13 @@ class SearchFilterFkTests(TestCase):
|
||||||
["%sattribute__label" % prefix, "%stitle" % prefix]
|
["%sattribute__label" % prefix, "%stitle" % prefix]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_custom_lookup_to_related_model(self):
|
||||||
|
# In this test case the attribute of the fk model comes first in the
|
||||||
|
# list of search fields.
|
||||||
|
filter_ = filters.SearchFilter()
|
||||||
|
assert 'attribute__label__icontains' == filter_.construct_search('attribute__label', SearchFilterModelFk._meta)
|
||||||
|
assert 'attribute__label__iendswith' == filter_.construct_search('attribute__label__iendswith', SearchFilterModelFk._meta)
|
||||||
|
|
||||||
|
|
||||||
class SearchFilterModelM2M(models.Model):
|
class SearchFilterModelM2M(models.Model):
|
||||||
title = models.CharField(max_length=20)
|
title = models.CharField(max_length=20)
|
||||||
|
@ -385,7 +436,7 @@ class SearchFilterToManyTests(TestCase):
|
||||||
search_fields = ('=name', 'entry__headline', '=entry__pub_date__year')
|
search_fields = ('=name', 'entry__headline', '=entry__pub_date__year')
|
||||||
|
|
||||||
view = SearchListView.as_view()
|
view = SearchListView.as_view()
|
||||||
request = factory.get('/', {'search': 'Lennon,1979'})
|
request = factory.get('/', {'search': 'Lennon 1979'})
|
||||||
response = view(request)
|
response = view(request)
|
||||||
assert len(response.data) == 1
|
assert len(response.data) == 1
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user