diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 42e77d910..6f556766c 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -10,6 +10,7 @@ from functools import reduce from django.conf import settings from django.core.exceptions import ImproperlyConfigured from django.db import models +from django.db.models.constants import LOOKUP_SEP from django.template import loader from django.utils import six from django.utils.translation import ugettext_lazy as _ @@ -153,6 +154,25 @@ class SearchFilter(BaseFilterBackend): else: return "%s__icontains" % field_name + def must_call_distinct(self, opts, lookups): + """ + Return True if 'distinct()' should be used to query the given lookups. + """ + for lookup in lookups: + if lookup[0] in {'^', '=', '@', '$'}: + lookup = lookup[1:] + parts = lookup.split(LOOKUP_SEP) + for part in parts: + field = opts.get_field(part) + if hasattr(field, 'get_path_info'): + # This field is a relation, update opts to follow the relation + path_info = field.get_path_info() + opts = path_info[-1].to_opts + if any(path.m2m for path in path_info): + # This field is a m2m relation so we know we need to call distinct + return True + return False + def filter_queryset(self, request, queryset, view): search_fields = getattr(view, 'search_fields', None) search_terms = self.get_search_terms(request) @@ -173,10 +193,12 @@ class SearchFilter(BaseFilterBackend): ] queryset = queryset.filter(reduce(operator.or_, queries)) - # Filtering against a many-to-many field requires us to - # call queryset.distinct() in order to avoid duplicate items - # in the resulting queryset. - return distinct(queryset, base) + if self.must_call_distinct(queryset.model._meta, search_fields): + # Filtering against a many-to-many field requires us to + # call queryset.distinct() in order to avoid duplicate items + # in the resulting queryset. + queryset = distinct(queryset, base) + return queryset def to_html(self, request, queryset, view): if not getattr(view, 'search_fields', None): diff --git a/tests/test_filters.py b/tests/test_filters.py index 729a7b75b..b21f6d789 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -497,6 +497,21 @@ class SearchFilterM2MTests(TestCase): response = view(request) self.assertEqual(len(response.data), 1) + def test_must_call_distinct(self): + filter_ = filters.SearchFilter() + prefixes = ['', '^', '=', '@', '$'] + for prefix in prefixes: + self.assertFalse( + filter_.must_call_distinct( + SearchFilterModelM2M._meta, ["%stitle" % prefix] + ) + ) + self.assertTrue( + filter_.must_call_distinct( + SearchFilterModelM2M._meta, ["%stitle" % prefix, "%sattributes__label" % prefix] + ) + ) + class OrderingFilterModel(models.Model): title = models.CharField(max_length=20)