This commit is contained in:
Patrick Forringer 2017-09-25 21:51:55 +00:00 committed by GitHub
commit ec3dbb84d4
2 changed files with 21 additions and 2 deletions

View File

@ -42,6 +42,7 @@ class BaseFilterBackend(object):
class SearchFilter(BaseFilterBackend): class SearchFilter(BaseFilterBackend):
# The URL query parameter used for the search. # The URL query parameter used for the search.
search_param = api_settings.SEARCH_PARAM search_param = api_settings.SEARCH_PARAM
search_filter_fields_source = 'search_fields'
template = 'rest_framework/filters/search.html' template = 'rest_framework/filters/search.html'
lookup_prefixes = { lookup_prefixes = {
'^': 'istartswith', '^': 'istartswith',
@ -89,7 +90,7 @@ class SearchFilter(BaseFilterBackend):
return False return False
def filter_queryset(self, request, queryset, view): def filter_queryset(self, request, queryset, view):
search_fields = getattr(view, 'search_fields', None) search_fields = getattr(view, self.search_filter_fields_source, None)
search_terms = self.get_search_terms(request) search_terms = self.get_search_terms(request)
if not search_fields or not search_terms: if not search_fields or not search_terms:
@ -119,7 +120,7 @@ class SearchFilter(BaseFilterBackend):
return queryset return queryset
def to_html(self, request, queryset, view): def to_html(self, request, queryset, view):
if not getattr(view, 'search_fields', None): if not getattr(view, self.search_filter_fields_source, None):
return '' return ''
term = self.get_search_terms(request) term = self.get_search_terms(request)

View File

@ -156,6 +156,24 @@ class SearchFilterTests(TestCase):
reload_module(filters) reload_module(filters)
def test_subclass_defines_own_field_source(self):
class CustomSearchFilter(filters.SearchFilter):
search_filter_fields_source = 'my_search_fields'
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (CustomSearchFilter,)
my_search_fields = ('$title', '$text')
view = SearchListView.as_view()
request = factory.get('/', {'search': 'cd'})
response = view(request)
assert response.data == [
{'id': 2, 'title': 'zz', 'text': 'bcd'},
{'id': 3, 'title': 'zzz', 'text': 'cde'}
]
class AttributeModel(models.Model): class AttributeModel(models.Model):
label = models.CharField(max_length=32) label = models.CharField(max_length=32)