SearchFilter to support Custom query Transforms

Since Some fields support `__` as a custom Transform for query lookups we needed to update the m2m checking code to handle search_fields that contain __ that are not relationships.
This commit is contained in:
matthaus woolard 2020-01-08 10:56:24 +13:00
parent 373e521f36
commit 290bde3050
2 changed files with 41 additions and 0 deletions

View File

@ -97,6 +97,9 @@ class SearchFilter(BaseFilterBackend):
if any(path.m2m for path in path_info):
# This field is a m2m relation so we know we need to call distinct
return True
else:
# This field has a custom __ query transform but is not a relational field.
break
return False
def filter_queryset(self, request, queryset, view):

View File

@ -1,9 +1,11 @@
import datetime
from importlib import reload as reload_module
import django
import pytest
from django.core.exceptions import ImproperlyConfigured
from django.db import models
from django.db.models import CharField, Transform
from django.db.models.functions import Concat, Upper
from django.test import TestCase
from django.test.utils import override_settings
@ -189,6 +191,42 @@ class SearchFilterTests(TestCase):
assert terms == ['asdf']
@pytest.mark.skipif(django.VERSION[:2] < (2, 2), reason="requires django 2.2 or higher")
def test_search_field_with_additional_transforms(self):
from django.test.utils import register_lookup
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('text__trim', )
view = SearchListView.as_view()
# an example custom transform, that trims `a` from the string.
class TrimA(Transform):
function = 'TRIM'
lookup_name = 'trim'
def as_sql(self, compiler, connection):
sql, params = compiler.compile(self.lhs)
return "trim(%s, 'a')" % sql, params
with register_lookup(CharField, TrimA):
# Search including `a`
request = factory.get('/', {'search': 'abc'})
response = view(request)
assert response.data == []
# Search excluding `a`
request = factory.get('/', {'search': 'bc'})
response = view(request)
assert response.data == [
{'id': 1, 'title': 'z', 'text': 'abc'},
{'id': 2, 'title': 'zz', 'text': 'bcd'},
]
class AttributeModel(models.Model):
label = models.CharField(max_length=32)