mirror of
				https://github.com/encode/django-rest-framework.git
				synced 2025-11-04 18:08:03 +03:00 
			
		
		
		
	Align SearchFilter behaviour to django.contrib.admin search (#9017)
* Use subquery to remove duplicates in SearchFilter * Align SearchFilter behaviour to django.contrib.admin * Add compatibility with older django/python versions * Allow search to split also by comma after smart split * Use generator to build search conditions to reduce iterations * Improve search documentation * Update docs/api-guide/filtering.md --------- Co-authored-by: Asif Saif Uddin <auvipy@gmail.com>
This commit is contained in:
		
							parent
							
								
									5c3b6e496c
								
							
						
					
					
						commit
						b99df0cf78
					
				| 
						 | 
				
			
			@ -218,14 +218,18 @@ For [JSONField][JSONField] and [HStoreField][HStoreField] fields you can filter
 | 
			
		|||
 | 
			
		||||
    search_fields = ['data__breed', 'data__owner__other_pets__0__name']
 | 
			
		||||
 | 
			
		||||
By default, searches will use case-insensitive partial matches.  The search parameter may contain multiple search terms, which should be whitespace and/or comma separated.  If multiple search terms are used then objects will be returned in the list only if all the provided terms are matched.
 | 
			
		||||
By default, searches will use case-insensitive partial matches. The search parameter may contain multiple search terms, which should be whitespace and/or comma separated. If multiple search terms are used then objects will be returned in the list only if all the provided terms are matched. Searches may contain _quoted phrases_ with spaces, each phrase is considered as a single search term.
 | 
			
		||||
 | 
			
		||||
The search behavior may be restricted by prepending various characters to the `search_fields`.
 | 
			
		||||
 | 
			
		||||
* '^' Starts-with search.
 | 
			
		||||
* '=' Exact matches.
 | 
			
		||||
* '@' Full-text search.  (Currently only supported Django's [PostgreSQL backend][postgres-search].)
 | 
			
		||||
* '$' Regex search.
 | 
			
		||||
The search behavior may be specified by prefixing field names in `search_fields` with one of the following characters (which is equivalent to adding `__<lookup>` to the field):
 | 
			
		||||
 | 
			
		||||
| Prefix | Lookup        |                    |
 | 
			
		||||
| ------ | --------------| ------------------ |
 | 
			
		||||
| `^`    | `istartswith` | Starts-with search.|
 | 
			
		||||
| `=`    | `iexact`      | Exact matches.     |
 | 
			
		||||
| `$`    | `iregex`      | Regex search.      |
 | 
			
		||||
| `@`    | `search`      | Full-text search (Currently only supported Django's [PostgreSQL backend][postgres-search]). |
 | 
			
		||||
| None   | `icontains`   | Contains search (Default).  |
 | 
			
		||||
 | 
			
		||||
For example:
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -3,7 +3,6 @@ The `compat` module provides support for backwards compatibility with older
 | 
			
		|||
versions of Django/Python, and compatibility wrappers around optional packages.
 | 
			
		||||
"""
 | 
			
		||||
import django
 | 
			
		||||
from django.conf import settings
 | 
			
		||||
from django.views.generic import View
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -14,13 +13,6 @@ def unicode_http_header(value):
 | 
			
		|||
    return value
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def distinct(queryset, base):
 | 
			
		||||
    if settings.DATABASES[queryset.db]["ENGINE"] == "django.db.backends.oracle":
 | 
			
		||||
        # distinct analogue for Oracle users
 | 
			
		||||
        return base.filter(pk__in=set(queryset.values_list('pk', flat=True)))
 | 
			
		||||
    return queryset.distinct()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# django.contrib.postgres requires psycopg2
 | 
			
		||||
try:
 | 
			
		||||
    from django.contrib.postgres import fields as postgres_fields
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -6,18 +6,35 @@ import operator
 | 
			
		|||
import warnings
 | 
			
		||||
from functools import reduce
 | 
			
		||||
 | 
			
		||||
from django.core.exceptions import ImproperlyConfigured
 | 
			
		||||
from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured
 | 
			
		||||
from django.db import models
 | 
			
		||||
from django.db.models.constants import LOOKUP_SEP
 | 
			
		||||
from django.template import loader
 | 
			
		||||
from django.utils.encoding import force_str
 | 
			
		||||
from django.utils.text import smart_split, unescape_string_literal
 | 
			
		||||
from django.utils.translation import gettext_lazy as _
 | 
			
		||||
 | 
			
		||||
from rest_framework import RemovedInDRF317Warning
 | 
			
		||||
from rest_framework.compat import coreapi, coreschema, distinct
 | 
			
		||||
from rest_framework.compat import coreapi, coreschema
 | 
			
		||||
from rest_framework.fields import CharField
 | 
			
		||||
from rest_framework.settings import api_settings
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def search_smart_split(search_terms):
 | 
			
		||||
    """generator that first splits string by spaces, leaving quoted phrases togheter,
 | 
			
		||||
    then it splits non-quoted phrases by commas.
 | 
			
		||||
    """
 | 
			
		||||
    for term in smart_split(search_terms):
 | 
			
		||||
        # trim commas to avoid bad matching for quoted phrases
 | 
			
		||||
        term = term.strip(',')
 | 
			
		||||
        if term.startswith(('"', "'")) and term[0] == term[-1]:
 | 
			
		||||
            # quoted phrases are kept togheter without any other split
 | 
			
		||||
            yield unescape_string_literal(term)
 | 
			
		||||
        else:
 | 
			
		||||
            # non-quoted tokens are split by comma, keeping only non-empty ones
 | 
			
		||||
            yield from (sub_term.strip() for sub_term in term.split(',') if sub_term)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BaseFilterBackend:
 | 
			
		||||
    """
 | 
			
		||||
    A base class from which all filter backend classes should inherit.
 | 
			
		||||
| 
						 | 
				
			
			@ -64,18 +81,41 @@ class SearchFilter(BaseFilterBackend):
 | 
			
		|||
    def get_search_terms(self, request):
 | 
			
		||||
        """
 | 
			
		||||
        Search terms are set by a ?search=... query parameter,
 | 
			
		||||
        and may be comma and/or whitespace delimited.
 | 
			
		||||
        and may be whitespace delimited.
 | 
			
		||||
        """
 | 
			
		||||
        params = request.query_params.get(self.search_param, '')
 | 
			
		||||
        params = params.replace('\x00', '')  # strip null characters
 | 
			
		||||
        params = params.replace(',', ' ')
 | 
			
		||||
        return params.split()
 | 
			
		||||
        value = request.query_params.get(self.search_param, '')
 | 
			
		||||
        field = CharField(trim_whitespace=False, allow_blank=True)
 | 
			
		||||
        return field.run_validation(value)
 | 
			
		||||
 | 
			
		||||
    def construct_search(self, field_name):
 | 
			
		||||
    def construct_search(self, field_name, queryset):
 | 
			
		||||
        lookup = self.lookup_prefixes.get(field_name[0])
 | 
			
		||||
        if lookup:
 | 
			
		||||
            field_name = field_name[1:]
 | 
			
		||||
        else:
 | 
			
		||||
            # Use field_name if it includes a lookup.
 | 
			
		||||
            opts = queryset.model._meta
 | 
			
		||||
            lookup_fields = field_name.split(LOOKUP_SEP)
 | 
			
		||||
            # Go through the fields, following all relations.
 | 
			
		||||
            prev_field = None
 | 
			
		||||
            for path_part in lookup_fields:
 | 
			
		||||
                if path_part == "pk":
 | 
			
		||||
                    path_part = opts.pk.name
 | 
			
		||||
                try:
 | 
			
		||||
                    field = opts.get_field(path_part)
 | 
			
		||||
                except FieldDoesNotExist:
 | 
			
		||||
                    # Use valid query lookups.
 | 
			
		||||
                    if prev_field and prev_field.get_lookup(path_part):
 | 
			
		||||
                        return field_name
 | 
			
		||||
                else:
 | 
			
		||||
                    prev_field = field
 | 
			
		||||
                    if hasattr(field, "path_infos"):
 | 
			
		||||
                        # Update opts to follow the relation.
 | 
			
		||||
                        opts = field.path_infos[-1].to_opts
 | 
			
		||||
                    # django < 4.1
 | 
			
		||||
                    elif hasattr(field, 'get_path_info'):
 | 
			
		||||
                        # Update opts to follow the relation.
 | 
			
		||||
                        opts = field.get_path_info()[-1].to_opts
 | 
			
		||||
            # Otherwise, use the field with icontains.
 | 
			
		||||
            lookup = 'icontains'
 | 
			
		||||
        return LOOKUP_SEP.join([field_name, lookup])
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -113,26 +153,27 @@ class SearchFilter(BaseFilterBackend):
 | 
			
		|||
            return queryset
 | 
			
		||||
 | 
			
		||||
        orm_lookups = [
 | 
			
		||||
            self.construct_search(str(search_field))
 | 
			
		||||
            self.construct_search(str(search_field), queryset)
 | 
			
		||||
            for search_field in search_fields
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
        base = queryset
 | 
			
		||||
        conditions = []
 | 
			
		||||
        for search_term in search_terms:
 | 
			
		||||
            queries = [
 | 
			
		||||
                models.Q(**{orm_lookup: search_term})
 | 
			
		||||
                for orm_lookup in orm_lookups
 | 
			
		||||
            ]
 | 
			
		||||
            conditions.append(reduce(operator.or_, queries))
 | 
			
		||||
        # generator which for each term builds the corresponding search
 | 
			
		||||
        conditions = (
 | 
			
		||||
            reduce(
 | 
			
		||||
                operator.or_,
 | 
			
		||||
                (models.Q(**{orm_lookup: term}) for orm_lookup in orm_lookups)
 | 
			
		||||
            ) for term in search_smart_split(search_terms)
 | 
			
		||||
        )
 | 
			
		||||
        queryset = queryset.filter(reduce(operator.and_, conditions))
 | 
			
		||||
 | 
			
		||||
        # Remove duplicates from results, if necessary
 | 
			
		||||
        if self.must_call_distinct(queryset, search_fields):
 | 
			
		||||
            # Filtering against a many-to-many field requires us to
 | 
			
		||||
            # call queryset.distinct() in order to avoid duplicate items
 | 
			
		||||
            # in the resulting queryset.
 | 
			
		||||
            # We try to avoid this if possible, for performance reasons.
 | 
			
		||||
            queryset = distinct(queryset, base)
 | 
			
		||||
            # inspired by django.contrib.admin
 | 
			
		||||
            # this is more accurate than .distinct form M2M relationship
 | 
			
		||||
            # also is cross-database
 | 
			
		||||
            queryset = queryset.filter(pk=models.OuterRef('pk'))
 | 
			
		||||
            queryset = base.filter(models.Exists(queryset))
 | 
			
		||||
        return queryset
 | 
			
		||||
 | 
			
		||||
    def to_html(self, request, queryset, view):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -6,16 +6,36 @@ from django.core.exceptions import ImproperlyConfigured
 | 
			
		|||
from django.db import models
 | 
			
		||||
from django.db.models import CharField, Transform
 | 
			
		||||
from django.db.models.functions import Concat, Upper
 | 
			
		||||
from django.test import TestCase
 | 
			
		||||
from django.test import SimpleTestCase, TestCase
 | 
			
		||||
from django.test.utils import override_settings
 | 
			
		||||
 | 
			
		||||
from rest_framework import filters, generics, serializers
 | 
			
		||||
from rest_framework.compat import coreschema
 | 
			
		||||
from rest_framework.exceptions import ValidationError
 | 
			
		||||
from rest_framework.test import APIRequestFactory
 | 
			
		||||
 | 
			
		||||
factory = APIRequestFactory()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SearchSplitTests(SimpleTestCase):
 | 
			
		||||
 | 
			
		||||
    def test_keep_quoted_togheter_regardless_of_commas(self):
 | 
			
		||||
        assert ['hello, world'] == list(filters.search_smart_split('"hello, world"'))
 | 
			
		||||
 | 
			
		||||
    def test_strips_commas_around_quoted(self):
 | 
			
		||||
        assert ['hello, world'] == list(filters.search_smart_split(',,"hello, world"'))
 | 
			
		||||
        assert ['hello, world'] == list(filters.search_smart_split(',,"hello, world",,'))
 | 
			
		||||
        assert ['hello, world'] == list(filters.search_smart_split('"hello, world",,'))
 | 
			
		||||
 | 
			
		||||
    def test_splits_by_comma(self):
 | 
			
		||||
        assert ['hello', 'world'] == list(filters.search_smart_split(',,hello, world'))
 | 
			
		||||
        assert ['hello', 'world'] == list(filters.search_smart_split(',,hello, world,,'))
 | 
			
		||||
        assert ['hello', 'world'] == list(filters.search_smart_split('hello, world,,'))
 | 
			
		||||
 | 
			
		||||
    def test_splits_quotes_followed_by_comma_and_sentence(self):
 | 
			
		||||
        assert ['"hello', 'world"', 'found'] == list(filters.search_smart_split('"hello, world",found'))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BaseFilterTests(TestCase):
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        self.original_coreapi = filters.coreapi
 | 
			
		||||
| 
						 | 
				
			
			@ -50,7 +70,8 @@ class SearchFilterSerializer(serializers.ModelSerializer):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class SearchFilterTests(TestCase):
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def setUpTestData(cls):
 | 
			
		||||
        # Sequence of title/text is:
 | 
			
		||||
        #
 | 
			
		||||
        # z   abc
 | 
			
		||||
| 
						 | 
				
			
			@ -66,6 +87,9 @@ class SearchFilterTests(TestCase):
 | 
			
		|||
            )
 | 
			
		||||
            SearchFilterModel(title=title, text=text).save()
 | 
			
		||||
 | 
			
		||||
        SearchFilterModel(title='A title', text='The long text').save()
 | 
			
		||||
        SearchFilterModel(title='The title', text='The "text').save()
 | 
			
		||||
 | 
			
		||||
    def test_search(self):
 | 
			
		||||
        class SearchListView(generics.ListAPIView):
 | 
			
		||||
            queryset = SearchFilterModel.objects.all()
 | 
			
		||||
| 
						 | 
				
			
			@ -186,9 +210,21 @@ class SearchFilterTests(TestCase):
 | 
			
		|||
        request = factory.get('/?search=\0as%00d\x00f')
 | 
			
		||||
        request = view.initialize_request(request)
 | 
			
		||||
 | 
			
		||||
        terms = filters.SearchFilter().get_search_terms(request)
 | 
			
		||||
        with self.assertRaises(ValidationError):
 | 
			
		||||
            filters.SearchFilter().get_search_terms(request)
 | 
			
		||||
 | 
			
		||||
        assert terms == ['asdf']
 | 
			
		||||
    def test_search_field_with_custom_lookup(self):
 | 
			
		||||
        class SearchListView(generics.ListAPIView):
 | 
			
		||||
            queryset = SearchFilterModel.objects.all()
 | 
			
		||||
            serializer_class = SearchFilterSerializer
 | 
			
		||||
            filter_backends = (filters.SearchFilter,)
 | 
			
		||||
            search_fields = ('text__iendswith',)
 | 
			
		||||
        view = SearchListView.as_view()
 | 
			
		||||
        request = factory.get('/', {'search': 'c'})
 | 
			
		||||
        response = view(request)
 | 
			
		||||
        assert response.data == [
 | 
			
		||||
            {'id': 1, 'title': 'z', 'text': 'abc'},
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
    def test_search_field_with_additional_transforms(self):
 | 
			
		||||
        from django.test.utils import register_lookup
 | 
			
		||||
| 
						 | 
				
			
			@ -242,6 +278,32 @@ class SearchFilterTests(TestCase):
 | 
			
		|||
        )
 | 
			
		||||
        assert search_query in rendered_search_field
 | 
			
		||||
 | 
			
		||||
    def test_search_field_with_escapes(self):
 | 
			
		||||
        class SearchListView(generics.ListAPIView):
 | 
			
		||||
            queryset = SearchFilterModel.objects.all()
 | 
			
		||||
            serializer_class = SearchFilterSerializer
 | 
			
		||||
            filter_backends = (filters.SearchFilter,)
 | 
			
		||||
            search_fields = ('title', 'text',)
 | 
			
		||||
        view = SearchListView.as_view()
 | 
			
		||||
        request = factory.get('/', {'search': '"\\\"text"'})
 | 
			
		||||
        response = view(request)
 | 
			
		||||
        assert response.data == [
 | 
			
		||||
            {'id': 12, 'title': 'The title', 'text': 'The "text'},
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
    def test_search_field_with_quotes(self):
 | 
			
		||||
        class SearchListView(generics.ListAPIView):
 | 
			
		||||
            queryset = SearchFilterModel.objects.all()
 | 
			
		||||
            serializer_class = SearchFilterSerializer
 | 
			
		||||
            filter_backends = (filters.SearchFilter,)
 | 
			
		||||
            search_fields = ('title', 'text',)
 | 
			
		||||
        view = SearchListView.as_view()
 | 
			
		||||
        request = factory.get('/', {'search': '"long text"'})
 | 
			
		||||
        response = view(request)
 | 
			
		||||
        assert response.data == [
 | 
			
		||||
            {'id': 11, 'title': 'A title', 'text': 'The long text'},
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AttributeModel(models.Model):
 | 
			
		||||
    label = models.CharField(max_length=32)
 | 
			
		||||
| 
						 | 
				
			
			@ -284,6 +346,13 @@ class SearchFilterFkTests(TestCase):
 | 
			
		|||
                ["%sattribute__label" % prefix, "%stitle" % prefix]
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    def test_custom_lookup_to_related_model(self):
 | 
			
		||||
        # In this test case the attribute of the fk model comes first in the
 | 
			
		||||
        # list of search fields.
 | 
			
		||||
        filter_ = filters.SearchFilter()
 | 
			
		||||
        assert 'attribute__label__icontains' == filter_.construct_search('attribute__label', SearchFilterModelFk._meta)
 | 
			
		||||
        assert 'attribute__label__iendswith' == filter_.construct_search('attribute__label__iendswith', SearchFilterModelFk._meta)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SearchFilterModelM2M(models.Model):
 | 
			
		||||
    title = models.CharField(max_length=20)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user