diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 0473787bb..12ee279fa 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -42,6 +42,7 @@ class BaseFilterBackend(object): class SearchFilter(BaseFilterBackend): # The URL query parameter used for the search. search_param = api_settings.SEARCH_PARAM + search_filter_fields_source = 'search_fields' template = 'rest_framework/filters/search.html' lookup_prefixes = { '^': 'istartswith', @@ -89,7 +90,7 @@ class SearchFilter(BaseFilterBackend): return False def filter_queryset(self, request, queryset, view): - search_fields = getattr(view, 'search_fields', None) + search_fields = getattr(view, self.search_filter_fields_source, None) search_terms = self.get_search_terms(request) if not search_fields or not search_terms: @@ -119,7 +120,7 @@ class SearchFilter(BaseFilterBackend): return queryset def to_html(self, request, queryset, view): - if not getattr(view, 'search_fields', None): + if not getattr(view, self.search_filter_fields_source, None): return '' term = self.get_search_terms(request) diff --git a/tests/test_filters.py b/tests/test_filters.py index dc5b18068..a61e45500 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -156,6 +156,24 @@ class SearchFilterTests(TestCase): reload_module(filters) + def test_subclass_defines_own_field_source(self): + class CustomSearchFilter(filters.SearchFilter): + search_filter_fields_source = 'my_search_fields' + + class SearchListView(generics.ListAPIView): + queryset = SearchFilterModel.objects.all() + serializer_class = SearchFilterSerializer + filter_backends = (CustomSearchFilter,) + my_search_fields = ('$title', '$text') + + view = SearchListView.as_view() + request = factory.get('/', {'search': 'cd'}) + response = view(request) + assert response.data == [ + {'id': 2, 'title': 'zz', 'text': 'bcd'}, + {'id': 3, 'title': 'zzz', 'text': 'cde'} + ] + class AttributeModel(models.Model): label = models.CharField(max_length=32)