mirror of
				https://github.com/encode/django-rest-framework.git
				synced 2025-11-04 18:08:03 +03:00 
			
		
		
		
	Added SearchFilter.get_search_fields() hook. (#6279)
This commit is contained in:
		
							parent
							
								
									1ece516d2d
								
							
						
					
					
						commit
						d110454d4c
					
				| 
						 | 
				
			
			@ -218,6 +218,13 @@ For example:
 | 
			
		|||
 | 
			
		||||
By default, the search parameter is named `'search`', but this may be overridden with the `SEARCH_PARAM` setting.
 | 
			
		||||
 | 
			
		||||
To dynamically change search fields based on request content, it's possible to subclass the `SearchFilter` and override the `get_search_fields()` function. For example, the following subclass will only search on `title` if the query parameter `title_only` is in the request:
 | 
			
		||||
 | 
			
		||||
    class CustomSearchFilter(self, view, request):
 | 
			
		||||
        if request.query_params.get('title_only'):
 | 
			
		||||
            return ('title',)
 | 
			
		||||
        return super(CustomSearchFilter, self).get_search_fields(view, request)
 | 
			
		||||
 | 
			
		||||
For more details, see the [Django documentation][search-django-admin].
 | 
			
		||||
 | 
			
		||||
---
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -53,6 +53,14 @@ class SearchFilter(BaseFilterBackend):
 | 
			
		|||
    search_title = _('Search')
 | 
			
		||||
    search_description = _('A search term.')
 | 
			
		||||
 | 
			
		||||
    def get_search_fields(self, view, request):
 | 
			
		||||
        """
 | 
			
		||||
        Search fields are obtained from the view, but the request is always
 | 
			
		||||
        passed to this method. Sub-classes can override this method to
 | 
			
		||||
        dynamically change the search fields based on request content.
 | 
			
		||||
        """
 | 
			
		||||
        return getattr(view, 'search_fields', None)
 | 
			
		||||
 | 
			
		||||
    def get_search_terms(self, request):
 | 
			
		||||
        """
 | 
			
		||||
        Search terms are set by a ?search=... query parameter,
 | 
			
		||||
| 
						 | 
				
			
			@ -90,7 +98,7 @@ class SearchFilter(BaseFilterBackend):
 | 
			
		|||
        return False
 | 
			
		||||
 | 
			
		||||
    def filter_queryset(self, request, queryset, view):
 | 
			
		||||
        search_fields = getattr(view, 'search_fields', None)
 | 
			
		||||
        search_fields = self.get_search_fields(view, request)
 | 
			
		||||
        search_terms = self.get_search_terms(request)
 | 
			
		||||
 | 
			
		||||
        if not search_fields or not search_terms:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -156,6 +156,31 @@ class SearchFilterTests(TestCase):
 | 
			
		|||
 | 
			
		||||
        reload_module(filters)
 | 
			
		||||
 | 
			
		||||
    def test_search_with_filter_subclass(self):
 | 
			
		||||
        class CustomSearchFilter(filters.SearchFilter):
 | 
			
		||||
            # Filter that dynamically changes search fields
 | 
			
		||||
            def get_search_fields(self, view, request):
 | 
			
		||||
                if request.query_params.get('title_only'):
 | 
			
		||||
                    return ('$title',)
 | 
			
		||||
                return super(CustomSearchFilter, self).get_search_fields(view, request)
 | 
			
		||||
 | 
			
		||||
        class SearchListView(generics.ListAPIView):
 | 
			
		||||
            queryset = SearchFilterModel.objects.all()
 | 
			
		||||
            serializer_class = SearchFilterSerializer
 | 
			
		||||
            filter_backends = (CustomSearchFilter,)
 | 
			
		||||
            search_fields = ('$title', '$text')
 | 
			
		||||
 | 
			
		||||
        view = SearchListView.as_view()
 | 
			
		||||
        request = factory.get('/', {'search': '^\w{3}$'})
 | 
			
		||||
        response = view(request)
 | 
			
		||||
        assert len(response.data) == 10
 | 
			
		||||
 | 
			
		||||
        request = factory.get('/', {'search': '^\w{3}$', 'title_only': 'true'})
 | 
			
		||||
        response = view(request)
 | 
			
		||||
        assert response.data == [
 | 
			
		||||
            {'id': 3, 'title': 'zzz', 'text': 'cde'}
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AttributeModel(models.Model):
 | 
			
		||||
    label = models.CharField(max_length=32)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user