mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-29 13:04:03 +03:00
Prevented unnecessary distinct() call in SearchFilter. (#3938)
* Prevented unnecessary distinct() call in SearchFilter. * Refactored SearchFilter lookup prefixes.
This commit is contained in:
parent
2a3b4c9822
commit
90bb0c58ce
|
@ -10,6 +10,7 @@ from functools import reduce
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.core.exceptions import ImproperlyConfigured
|
from django.core.exceptions import ImproperlyConfigured
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
from django.db.models.constants import LOOKUP_SEP
|
||||||
from django.template import loader
|
from django.template import loader
|
||||||
from django.utils import six
|
from django.utils import six
|
||||||
from django.utils.translation import ugettext_lazy as _
|
from django.utils.translation import ugettext_lazy as _
|
||||||
|
@ -132,6 +133,12 @@ class SearchFilter(BaseFilterBackend):
|
||||||
# The URL query parameter used for the search.
|
# The URL query parameter used for the search.
|
||||||
search_param = api_settings.SEARCH_PARAM
|
search_param = api_settings.SEARCH_PARAM
|
||||||
template = 'rest_framework/filters/search.html'
|
template = 'rest_framework/filters/search.html'
|
||||||
|
lookup_prefixes = {
|
||||||
|
'^': 'istartswith',
|
||||||
|
'=': 'iexact',
|
||||||
|
'@': 'search',
|
||||||
|
'$': 'iregex',
|
||||||
|
}
|
||||||
|
|
||||||
def get_search_terms(self, request):
|
def get_search_terms(self, request):
|
||||||
"""
|
"""
|
||||||
|
@ -142,16 +149,31 @@ class SearchFilter(BaseFilterBackend):
|
||||||
return params.replace(',', ' ').split()
|
return params.replace(',', ' ').split()
|
||||||
|
|
||||||
def construct_search(self, field_name):
|
def construct_search(self, field_name):
|
||||||
if field_name.startswith('^'):
|
lookup = self.lookup_prefixes.get(field_name[0])
|
||||||
return "%s__istartswith" % field_name[1:]
|
if lookup:
|
||||||
elif field_name.startswith('='):
|
field_name = field_name[1:]
|
||||||
return "%s__iexact" % field_name[1:]
|
|
||||||
elif field_name.startswith('@'):
|
|
||||||
return "%s__search" % field_name[1:]
|
|
||||||
if field_name.startswith('$'):
|
|
||||||
return "%s__iregex" % field_name[1:]
|
|
||||||
else:
|
else:
|
||||||
return "%s__icontains" % field_name
|
lookup = 'icontains'
|
||||||
|
return LOOKUP_SEP.join([field_name, lookup])
|
||||||
|
|
||||||
|
def must_call_distinct(self, opts, lookups):
|
||||||
|
"""
|
||||||
|
Return True if 'distinct()' should be used to query the given lookups.
|
||||||
|
"""
|
||||||
|
for lookup in lookups:
|
||||||
|
if lookup[0] in self.lookup_prefixes:
|
||||||
|
lookup = lookup[1:]
|
||||||
|
parts = lookup.split(LOOKUP_SEP)
|
||||||
|
for part in parts:
|
||||||
|
field = opts.get_field(part)
|
||||||
|
if hasattr(field, 'get_path_info'):
|
||||||
|
# This field is a relation, update opts to follow the relation
|
||||||
|
path_info = field.get_path_info()
|
||||||
|
opts = path_info[-1].to_opts
|
||||||
|
if any(path.m2m for path in path_info):
|
||||||
|
# This field is a m2m relation so we know we need to call distinct
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def filter_queryset(self, request, queryset, view):
|
def filter_queryset(self, request, queryset, view):
|
||||||
search_fields = getattr(view, 'search_fields', None)
|
search_fields = getattr(view, 'search_fields', None)
|
||||||
|
@ -173,10 +195,12 @@ class SearchFilter(BaseFilterBackend):
|
||||||
]
|
]
|
||||||
queryset = queryset.filter(reduce(operator.or_, queries))
|
queryset = queryset.filter(reduce(operator.or_, queries))
|
||||||
|
|
||||||
# Filtering against a many-to-many field requires us to
|
if self.must_call_distinct(queryset.model._meta, search_fields):
|
||||||
# call queryset.distinct() in order to avoid duplicate items
|
# Filtering against a many-to-many field requires us to
|
||||||
# in the resulting queryset.
|
# call queryset.distinct() in order to avoid duplicate items
|
||||||
return distinct(queryset, base)
|
# in the resulting queryset.
|
||||||
|
queryset = distinct(queryset, base)
|
||||||
|
return queryset
|
||||||
|
|
||||||
def to_html(self, request, queryset, view):
|
def to_html(self, request, queryset, view):
|
||||||
if not getattr(view, 'search_fields', None):
|
if not getattr(view, 'search_fields', None):
|
||||||
|
|
|
@ -500,6 +500,21 @@ class SearchFilterM2MTests(TestCase):
|
||||||
response = view(request)
|
response = view(request)
|
||||||
self.assertEqual(len(response.data), 1)
|
self.assertEqual(len(response.data), 1)
|
||||||
|
|
||||||
|
def test_must_call_distinct(self):
|
||||||
|
filter_ = filters.SearchFilter()
|
||||||
|
prefixes = [''] + list(filter_.lookup_prefixes)
|
||||||
|
for prefix in prefixes:
|
||||||
|
self.assertFalse(
|
||||||
|
filter_.must_call_distinct(
|
||||||
|
SearchFilterModelM2M._meta, ["%stitle" % prefix]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
filter_.must_call_distinct(
|
||||||
|
SearchFilterModelM2M._meta, ["%stitle" % prefix, "%sattributes__label" % prefix]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class OrderingFilterModel(models.Model):
|
class OrderingFilterModel(models.Model):
|
||||||
title = models.CharField(max_length=20, verbose_name='verbose title')
|
title = models.CharField(max_length=20, verbose_name='verbose title')
|
||||||
|
|
Loading…
Reference in New Issue
Block a user