diff --git a/rest_framework/filters.py b/rest_framework/filters.py index c15723ec3..0e985d6e0 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -97,6 +97,9 @@ class SearchFilter(BaseFilterBackend): if any(path.m2m for path in path_info): # This field is a m2m relation so we know we need to call distinct return True + else: + # This field has a custom __ query transform but is not a relational field. + break return False def filter_queryset(self, request, queryset, view): diff --git a/tests/test_filters.py b/tests/test_filters.py index 6d7969a92..30cedc7d7 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -1,9 +1,11 @@ import datetime from importlib import reload as reload_module +import django import pytest from django.core.exceptions import ImproperlyConfigured from django.db import models +from django.db.models import CharField, Transform from django.db.models.functions import Concat, Upper from django.test import TestCase from django.test.utils import override_settings @@ -189,6 +191,42 @@ class SearchFilterTests(TestCase): assert terms == ['asdf'] + @pytest.mark.skipif(django.VERSION[:2] < (2, 2), reason="requires django 2.2 or higher") + def test_search_field_with_additional_transforms(self): + from django.test.utils import register_lookup + + class SearchListView(generics.ListAPIView): + queryset = SearchFilterModel.objects.all() + serializer_class = SearchFilterSerializer + filter_backends = (filters.SearchFilter,) + search_fields = ('text__trim', ) + + view = SearchListView.as_view() + + # an example custom transform, that trims `a` from the string. + class TrimA(Transform): + function = 'TRIM' + lookup_name = 'trim' + + def as_sql(self, compiler, connection): + sql, params = compiler.compile(self.lhs) + return "trim(%s, 'a')" % sql, params + + with register_lookup(CharField, TrimA): + # Search including `a` + request = factory.get('/', {'search': 'abc'}) + + response = view(request) + assert response.data == [] + + # Search excluding `a` + request = factory.get('/', {'search': 'bc'}) + response = view(request) + assert response.data == [ + {'id': 1, 'title': 'z', 'text': 'abc'}, + {'id': 2, 'title': 'zz', 'text': 'bcd'}, + ] + class AttributeModel(models.Model): label = models.CharField(max_length=32)