allow overriding search_fields property name

This commit is contained in:
Adam Wester 2015-12-02 08:12:23 -05:00
parent d2f90fd6af
commit 30ee7a54b7
2 changed files with 25 additions and 2 deletions

View File

@ -132,6 +132,7 @@ class DjangoFilterBackend(BaseFilterBackend):
class SearchFilter(BaseFilterBackend): class SearchFilter(BaseFilterBackend):
# The URL query parameter used for the search. # The URL query parameter used for the search.
search_param = api_settings.SEARCH_PARAM search_param = api_settings.SEARCH_PARAM
search_fields_property = "search_fields"
template = 'rest_framework/filters/search.html' template = 'rest_framework/filters/search.html'
def get_search_terms(self, request): def get_search_terms(self, request):
@ -155,7 +156,7 @@ class SearchFilter(BaseFilterBackend):
return "%s__icontains" % field_name return "%s__icontains" % field_name
def filter_queryset(self, request, queryset, view): def filter_queryset(self, request, queryset, view):
search_fields = getattr(view, 'search_fields', None) search_fields = getattr(view, self.search_fields_property, None)
search_terms = self.get_search_terms(request) search_terms = self.get_search_terms(request)
if not search_fields or not search_terms: if not search_fields or not search_terms:
@ -180,7 +181,7 @@ class SearchFilter(BaseFilterBackend):
return distinct(queryset, base) return distinct(queryset, base)
def to_html(self, request, queryset, view): def to_html(self, request, queryset, view):
if not getattr(view, 'search_fields', None): if not getattr(view, self.search_fields_property, None):
return '' return ''
term = self.get_search_terms(request) term = self.get_search_terms(request)

View File

@ -447,6 +447,28 @@ class SearchFilterTests(TestCase):
reload_module(filters) reload_module(filters)
def test_override_search_fields(self):
class CustomSearchFilter(filters.SearchFilter):
search_fields_property = 'nonstandard_search_fields'
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (CustomSearchFilter,)
nonstandard_search_fields = ('title', 'text')
view = SearchListView.as_view()
request = factory.get('/', {'search': 'b'})
response = view(request)
self.assertEqual(
response.data,
[
{'id': 1, 'title': 'z', 'text': 'abc'},
{'id': 2, 'title': 'zz', 'text': 'bcd'}
]
)
class AttributeModel(models.Model): class AttributeModel(models.Model):
label = models.CharField(max_length=32) label = models.CharField(max_length=32)