diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 375242b6b..6a625ec5f 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -20,6 +20,21 @@ from rest_framework.fields import CharField from rest_framework.settings import api_settings +def search_smart_split(search_terms): + """generator that first splits string by spaces, leaving quoted phrases togheter, + then it splits non-quoted phrases by commas. + """ + for term in smart_split(search_terms): + # trim commas to avoid bad matching for quoted phrases + term = term.strip(',') + if term.startswith(('"', "'")) and term[0] == term[-1]: + # quoted phrases are kept togheter without any other split + yield unescape_string_literal(term) + else: + # non-quoted tokens are split by comma, keeping only non-empty ones + yield from (sub_term.strip() for sub_term in term.split(',') if sub_term) + + class BaseFilterBackend: """ A base class from which all filter backend classes should inherit. @@ -144,9 +159,7 @@ class SearchFilter(BaseFilterBackend): base = queryset conditions = [] - for term in smart_split(search_terms): - if term.startswith(('"', "'")) and term[0] == term[-1]: - term = unescape_string_literal(term) + for term in search_smart_split(search_terms): queries = [ models.Q(**{orm_lookup: term}) for orm_lookup in orm_lookups diff --git a/tests/test_filters.py b/tests/test_filters.py index 2a87165ca..6db0c3deb 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -6,7 +6,7 @@ from django.core.exceptions import ImproperlyConfigured from django.db import models from django.db.models import CharField, Transform from django.db.models.functions import Concat, Upper -from django.test import TestCase +from django.test import SimpleTestCase, TestCase from django.test.utils import override_settings from rest_framework import filters, generics, serializers @@ -17,6 +17,25 @@ from rest_framework.test import APIRequestFactory factory = APIRequestFactory() +class SearchSplitTests(SimpleTestCase): + + def test_keep_quoted_togheter_regardless_of_commas(self): + assert ['hello, world'] == list(filters.search_smart_split('"hello, world"')) + + def test_strips_commas_around_quoted(self): + assert ['hello, world'] == list(filters.search_smart_split(',,"hello, world"')) + assert ['hello, world'] == list(filters.search_smart_split(',,"hello, world",,')) + assert ['hello, world'] == list(filters.search_smart_split('"hello, world",,')) + + def test_splits_by_comma(self): + assert ['hello', 'world'] == list(filters.search_smart_split(',,hello, world')) + assert ['hello', 'world'] == list(filters.search_smart_split(',,hello, world,,')) + assert ['hello', 'world'] == list(filters.search_smart_split('hello, world,,')) + + def test_splits_quotes_followed_by_comma_and_sentence(self): + assert ['"hello', 'world"', 'found'] == list(filters.search_smart_split('"hello, world",found')) + + class BaseFilterTests(TestCase): def setUp(self): self.original_coreapi = filters.coreapi @@ -435,7 +454,7 @@ class SearchFilterToManyTests(TestCase): search_fields = ('=name', 'entry__headline', '=entry__pub_date__year') view = SearchListView.as_view() - request = factory.get('/', {'search': 'Lennon 1979'}) + request = factory.get('/', {'search': 'Lennon,1979'}) response = view(request) assert len(response.data) == 1