Prevented unnecessary distinct() call in SearchFilter.

This commit is contained in:
Simon Charette 2016-02-16 17:55:54 -05:00
parent 79dad012b0
commit 1ada749a42
2 changed files with 41 additions and 4 deletions

View File

@ -10,6 +10,7 @@ from functools import reduce
from django.conf import settings from django.conf import settings
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.db import models from django.db import models
from django.db.models.constants import LOOKUP_SEP
from django.template import loader from django.template import loader
from django.utils import six from django.utils import six
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
@ -153,6 +154,25 @@ class SearchFilter(BaseFilterBackend):
else: else:
return "%s__icontains" % field_name return "%s__icontains" % field_name
def must_call_distinct(self, opts, lookups):
"""
Return True if 'distinct()' should be used to query the given lookups.
"""
for lookup in lookups:
if lookup[0] in {'^', '=', '@', '$'}:
lookup = lookup[1:]
parts = lookup.split(LOOKUP_SEP)
for part in parts:
field = opts.get_field(part)
if hasattr(field, 'get_path_info'):
# This field is a relation, update opts to follow the relation
path_info = field.get_path_info()
opts = path_info[-1].to_opts
if any(path.m2m for path in path_info):
# This field is a m2m relation so we know we need to call distinct
return True
return False
def filter_queryset(self, request, queryset, view): def filter_queryset(self, request, queryset, view):
search_fields = getattr(view, 'search_fields', None) search_fields = getattr(view, 'search_fields', None)
search_terms = self.get_search_terms(request) search_terms = self.get_search_terms(request)
@ -173,10 +193,12 @@ class SearchFilter(BaseFilterBackend):
] ]
queryset = queryset.filter(reduce(operator.or_, queries)) queryset = queryset.filter(reduce(operator.or_, queries))
# Filtering against a many-to-many field requires us to if self.must_call_distinct(queryset.model._meta, search_fields):
# call queryset.distinct() in order to avoid duplicate items # Filtering against a many-to-many field requires us to
# in the resulting queryset. # call queryset.distinct() in order to avoid duplicate items
return distinct(queryset, base) # in the resulting queryset.
queryset = distinct(queryset, base)
return queryset
def to_html(self, request, queryset, view): def to_html(self, request, queryset, view):
if not getattr(view, 'search_fields', None): if not getattr(view, 'search_fields', None):

View File

@ -497,6 +497,21 @@ class SearchFilterM2MTests(TestCase):
response = view(request) response = view(request)
self.assertEqual(len(response.data), 1) self.assertEqual(len(response.data), 1)
def test_must_call_distinct(self):
filter_ = filters.SearchFilter()
prefixes = ['', '^', '=', '@', '$']
for prefix in prefixes:
self.assertFalse(
filter_.must_call_distinct(
SearchFilterModelM2M._meta, ["%stitle" % prefix]
)
)
self.assertTrue(
filter_.must_call_distinct(
SearchFilterModelM2M._meta, ["%stitle" % prefix, "%sattributes__label" % prefix]
)
)
class OrderingFilterModel(models.Model): class OrderingFilterModel(models.Model):
title = models.CharField(max_length=20) title = models.CharField(max_length=20)