Prevented unnecessary distinct() call in SearchFilter.

This commit is contained in:
Simon Charette 2016-02-16 17:55:54 -05:00
parent 79dad012b0
commit 1ada749a42
2 changed files with 41 additions and 4 deletions

View File

@ -10,6 +10,7 @@ from functools import reduce
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.db import models
from django.db.models.constants import LOOKUP_SEP
from django.template import loader
from django.utils import six
from django.utils.translation import ugettext_lazy as _
@ -153,6 +154,25 @@ class SearchFilter(BaseFilterBackend):
else:
return "%s__icontains" % field_name
def must_call_distinct(self, opts, lookups):
"""
Return True if 'distinct()' should be used to query the given lookups.
"""
for lookup in lookups:
if lookup[0] in {'^', '=', '@', '$'}:
lookup = lookup[1:]
parts = lookup.split(LOOKUP_SEP)
for part in parts:
field = opts.get_field(part)
if hasattr(field, 'get_path_info'):
# This field is a relation, update opts to follow the relation
path_info = field.get_path_info()
opts = path_info[-1].to_opts
if any(path.m2m for path in path_info):
# This field is a m2m relation so we know we need to call distinct
return True
return False
def filter_queryset(self, request, queryset, view):
search_fields = getattr(view, 'search_fields', None)
search_terms = self.get_search_terms(request)
@ -173,10 +193,12 @@ class SearchFilter(BaseFilterBackend):
]
queryset = queryset.filter(reduce(operator.or_, queries))
# Filtering against a many-to-many field requires us to
# call queryset.distinct() in order to avoid duplicate items
# in the resulting queryset.
return distinct(queryset, base)
if self.must_call_distinct(queryset.model._meta, search_fields):
# Filtering against a many-to-many field requires us to
# call queryset.distinct() in order to avoid duplicate items
# in the resulting queryset.
queryset = distinct(queryset, base)
return queryset
def to_html(self, request, queryset, view):
if not getattr(view, 'search_fields', None):

View File

@ -497,6 +497,21 @@ class SearchFilterM2MTests(TestCase):
response = view(request)
self.assertEqual(len(response.data), 1)
def test_must_call_distinct(self):
filter_ = filters.SearchFilter()
prefixes = ['', '^', '=', '@', '$']
for prefix in prefixes:
self.assertFalse(
filter_.must_call_distinct(
SearchFilterModelM2M._meta, ["%stitle" % prefix]
)
)
self.assertTrue(
filter_.must_call_distinct(
SearchFilterModelM2M._meta, ["%stitle" % prefix, "%sattributes__label" % prefix]
)
)
class OrderingFilterModel(models.Model):
title = models.CharField(max_length=20)