diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 3836e8170..4ac942957 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -10,6 +10,7 @@ from functools import reduce from django.conf import settings from django.core.exceptions import ImproperlyConfigured from django.db import models +from django.db.models.constants import LOOKUP_SEP from django.template import loader from django.utils import six from django.utils.translation import ugettext_lazy as _ @@ -132,6 +133,12 @@ class SearchFilter(BaseFilterBackend): # The URL query parameter used for the search. search_param = api_settings.SEARCH_PARAM template = 'rest_framework/filters/search.html' + lookup_prefixes = { + '^': 'istartswith', + '=': 'iexact', + '@': 'search', + '$': 'iregex', + } def get_search_terms(self, request): """ @@ -142,16 +149,31 @@ class SearchFilter(BaseFilterBackend): return params.replace(',', ' ').split() def construct_search(self, field_name): - if field_name.startswith('^'): - return "%s__istartswith" % field_name[1:] - elif field_name.startswith('='): - return "%s__iexact" % field_name[1:] - elif field_name.startswith('@'): - return "%s__search" % field_name[1:] - if field_name.startswith('$'): - return "%s__iregex" % field_name[1:] + lookup = self.lookup_prefixes.get(field_name[0]) + if lookup: + field_name = field_name[1:] else: - return "%s__icontains" % field_name + lookup = 'icontains' + return LOOKUP_SEP.join([field_name, lookup]) + + def must_call_distinct(self, opts, lookups): + """ + Return True if 'distinct()' should be used to query the given lookups. + """ + for lookup in lookups: + if lookup[0] in self.lookup_prefixes: + lookup = lookup[1:] + parts = lookup.split(LOOKUP_SEP) + for part in parts: + field = opts.get_field(part) + if hasattr(field, 'get_path_info'): + # This field is a relation, update opts to follow the relation + path_info = field.get_path_info() + opts = path_info[-1].to_opts + if any(path.m2m for path in path_info): + # This field is a m2m relation so we know we need to call distinct + return True + return False def filter_queryset(self, request, queryset, view): search_fields = getattr(view, 'search_fields', None) @@ -173,10 +195,12 @@ class SearchFilter(BaseFilterBackend): ] queryset = queryset.filter(reduce(operator.or_, queries)) - # Filtering against a many-to-many field requires us to - # call queryset.distinct() in order to avoid duplicate items - # in the resulting queryset. - return distinct(queryset, base) + if self.must_call_distinct(queryset.model._meta, search_fields): + # Filtering against a many-to-many field requires us to + # call queryset.distinct() in order to avoid duplicate items + # in the resulting queryset. + queryset = distinct(queryset, base) + return queryset def to_html(self, request, queryset, view): if not getattr(view, 'search_fields', None): diff --git a/tests/test_filters.py b/tests/test_filters.py index 646d8a625..4c3e2af0b 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -500,6 +500,21 @@ class SearchFilterM2MTests(TestCase): response = view(request) self.assertEqual(len(response.data), 1) + def test_must_call_distinct(self): + filter_ = filters.SearchFilter() + prefixes = [''] + list(filter_.lookup_prefixes) + for prefix in prefixes: + self.assertFalse( + filter_.must_call_distinct( + SearchFilterModelM2M._meta, ["%stitle" % prefix] + ) + ) + self.assertTrue( + filter_.must_call_distinct( + SearchFilterModelM2M._meta, ["%stitle" % prefix, "%sattributes__label" % prefix] + ) + ) + class OrderingFilterModel(models.Model): title = models.CharField(max_length=20, verbose_name='verbose title')