diff --git a/tests/test_filters.py b/tests/test_filters.py index 6d7969a92..f02c3078d 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -3,13 +3,13 @@ from importlib import reload as reload_module import pytest from django.core.exceptions import ImproperlyConfigured -from django.db import models +from django.db import connection, models # NOQA from django.db.models.functions import Concat, Upper from django.test import TestCase from django.test.utils import override_settings from rest_framework import filters, generics, serializers -from rest_framework.compat import coreschema +from rest_framework.compat import coreschema, postgres_fields from rest_framework.test import APIRequestFactory factory = APIRequestFactory() @@ -369,6 +369,75 @@ class SearchFilterAnnotatedFieldTests(TestCase): assert response.data[0]['title_text'] == 'ABCDEF' +class SearchFilterModelPostgres(models.Model): + json_field = postgres_fields.JSONField() + + +class SearchFilterPostgresSerializer(serializers.ModelSerializer): + class Meta: + model = SearchFilterModelPostgres + fields = '__all__' + + +@pytest.mark.skipif('connection.vendor != "postgresql"') +class SearchFilterPostgresTests(TestCase): + class SearchListView(generics.ListAPIView): + queryset = SearchFilterModelPostgres.objects.all() + serializer_class = SearchFilterPostgresSerializer + filter_backends = (filters.SearchFilter,) + search_fields = ('=json_field__title', 'json_field__text') + + def setUp(self): + for idx in range(10): + title = 'z' * (idx + 1) + text = ( + chr(idx + ord('a')) + + chr(idx + ord('b')) + + chr(idx + ord('c')) + ) + SearchFilterModelPostgres( + json_field={'title': title, 'text': text} + ).save() + + def test_default_json_search(self): + view = self.SearchListView.as_view() + + request = factory.get('/', {'search': 'b'}) + response = view(request) + items = { + (row['json_field']['title'], row['json_field']['text']) + for row in response.data + } + assert items == {('z', 'abc'), ('zz', 'bcd')} + + request = factory.get('/', {'search': 'EF'}) + response = view(request) + items = { + (row['json_field']['title'], row['json_field']['text']) + for row in response.data + } + assert items == {('zzzz', 'def'), ('zzzzz', 'efg')} + + def test_exact_json_search(self): + view = self.SearchListView.as_view() + + request = factory.get('/', {'search': 'zzzzz'}) + response = view(request) + items = { + (row['json_field']['title'], row['json_field']['text']) + for row in response.data + } + assert items == {('zzzzz', 'efg')} + + request = factory.get('/', {'search': 'Z'}) + response = view(request) + items = { + (row['json_field']['title'], row['json_field']['text']) + for row in response.data + } + assert items == {('z', 'abc')} + + class OrderingFilterModel(models.Model): title = models.CharField(max_length=20, verbose_name='verbose title') text = models.CharField(max_length=100)