From 4f7e9ed3bbb0fb3567884ef4e3d340bc2f77bc68 Mon Sep 17 00:00:00 2001 From: Andrew <54764942+drewg3r@users.noreply.github.com> Date: Sun, 2 Jul 2023 09:57:20 +0200 Subject: [PATCH 1/6] Fix SearchFilter renders field with invalid value (#9023) Co-authored-by: Andrii Tarasenko --- rest_framework/filters.py | 4 +--- tests/test_filters.py | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 17e6975eb..c48504957 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -139,11 +139,9 @@ class SearchFilter(BaseFilterBackend): if not getattr(view, 'search_fields', None): return '' - term = self.get_search_terms(request) - term = term[0] if term else '' context = { 'param': self.search_param, - 'term': term + 'term': request.query_params.get(self.search_param, ''), } template = loader.get_template(self.template) return template.render(context) diff --git a/tests/test_filters.py b/tests/test_filters.py index 37ae4c7cf..2a22e30f9 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -225,6 +225,23 @@ class SearchFilterTests(TestCase): {'id': 2, 'title': 'zz', 'text': 'bcd'}, ] + def test_search_field_with_multiple_words(self): + class SearchListView(generics.ListAPIView): + queryset = SearchFilterModel.objects.all() + serializer_class = SearchFilterSerializer + filter_backends = (filters.SearchFilter,) + search_fields = ('title', 'text') + + search_query = 'foo bar,baz' + view = SearchListView() + request = factory.get('/', {'search': search_query}) + request = view.initialize_request(request) + + rendered_search_field = filters.SearchFilter().to_html( + request=request, queryset=view.queryset, view=view + ) + assert search_query in rendered_search_field + class AttributeModel(models.Model): label = models.CharField(max_length=32) From 66d86d0177673b90723664e3fde2919df66f2b60 Mon Sep 17 00:00:00 2001 From: Burson Date: Thu, 13 Jul 2023 20:50:49 +0800 Subject: [PATCH 2/6] Fix choices in ChoiceField to support IntEnum (#8955) Python support Enum in version 3.4, but changed __str__ to int.__str__ until version 3.11 to better support the replacement of existing constants use-case. [https://docs.python.org/3/library/enum.html#enum.IntEnum](https://docs.python.org/3/library/enum.html#enum.IntEnum) rest_frame work support Python 3.6+, this commit will support the Enum in choices of Field. --- rest_framework/fields.py | 17 ++++++----------- tests/test_fields.py | 25 +++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 4ce9c79c3..0b56fa7fb 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -8,6 +8,7 @@ import logging import re import uuid from collections.abc import Mapping +from enum import Enum from django.conf import settings from django.core.exceptions import ObjectDoesNotExist @@ -17,7 +18,6 @@ from django.core.validators import ( MinValueValidator, ProhibitNullCharactersValidator, RegexValidator, URLValidator, ip_address_validators ) -from django.db.models import IntegerChoices, TextChoices from django.forms import FilePathField as DjangoFilePathField from django.forms import ImageField as DjangoImageField from django.utils import timezone @@ -1401,11 +1401,8 @@ class ChoiceField(Field): def to_internal_value(self, data): if data == '' and self.allow_blank: return '' - - if isinstance(data, (IntegerChoices, TextChoices)) and str(data) != \ - str(data.value): + if isinstance(data, Enum) and str(data) != str(data.value): data = data.value - try: return self.choice_strings_to_values[str(data)] except KeyError: @@ -1414,11 +1411,8 @@ class ChoiceField(Field): def to_representation(self, value): if value in ('', None): return value - - if isinstance(value, (IntegerChoices, TextChoices)) and str(value) != \ - str(value.value): + if isinstance(value, Enum) and str(value) != str(value.value): value = value.value - return self.choice_strings_to_values.get(str(value), value) def iter_options(self): @@ -1442,8 +1436,7 @@ class ChoiceField(Field): # Allows us to deal with eg. integer choices while supporting either # integer or string input, but still get the correct datatype out. self.choice_strings_to_values = { - str(key.value) if isinstance(key, (IntegerChoices, TextChoices)) - and str(key) != str(key.value) else str(key): key for key in self.choices + str(key.value) if isinstance(key, Enum) and str(key) != str(key.value) else str(key): key for key in self.choices } choices = property(_get_choices, _set_choices) @@ -1829,6 +1822,7 @@ class HiddenField(Field): constraint on a pair of fields, as we need some way to include the date in the validated data. """ + def __init__(self, **kwargs): assert 'default' in kwargs, 'default is a required argument.' kwargs['write_only'] = True @@ -1858,6 +1852,7 @@ class SerializerMethodField(Field): def get_extra_info(self, obj): return ... # Calculate some data to return. """ + def __init__(self, method_name=None, **kwargs): self.method_name = method_name kwargs['source'] = '*' diff --git a/tests/test_fields.py b/tests/test_fields.py index 03584431e..7006d473c 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -1875,6 +1875,31 @@ class TestChoiceField(FieldValues): field.run_validation(2) assert exc_info.value.detail == ['"2" is not a valid choice.'] + def test_enum_integer_choices(self): + from enum import IntEnum + + class ChoiceCase(IntEnum): + first = auto() + second = auto() + # Enum validate + choices = [ + (ChoiceCase.first, "1"), + (ChoiceCase.second, "2") + ] + field = serializers.ChoiceField(choices=choices) + assert field.run_validation(1) == 1 + assert field.run_validation(ChoiceCase.first) == 1 + assert field.run_validation("1") == 1 + # Enum.value validate + choices = [ + (ChoiceCase.first.value, "1"), + (ChoiceCase.second.value, "2") + ] + field = serializers.ChoiceField(choices=choices) + assert field.run_validation(1) == 1 + assert field.run_validation(ChoiceCase.first) == 1 + assert field.run_validation("1") == 1 + def test_integer_choices(self): class ChoiceCase(IntegerChoices): first = auto() From 7658ffd71d4ad07ddada20b2b9538b889ec02403 Mon Sep 17 00:00:00 2001 From: Sergey Klyuykov Date: Sat, 15 Jul 2023 02:26:56 -0700 Subject: [PATCH 3/6] Fix: Pagination response schemas. (#9049) --- rest_framework/pagination.py | 3 +++ tests/test_pagination.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index 7303890b0..2b20e76af 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -239,6 +239,7 @@ class PageNumberPagination(BasePagination): def get_paginated_response_schema(self, schema): return { 'type': 'object', + 'required': ['count', 'results'], 'properties': { 'count': { 'type': 'integer', @@ -411,6 +412,7 @@ class LimitOffsetPagination(BasePagination): def get_paginated_response_schema(self, schema): return { 'type': 'object', + 'required': ['count', 'results'], 'properties': { 'count': { 'type': 'integer', @@ -906,6 +908,7 @@ class CursorPagination(BasePagination): def get_paginated_response_schema(self, schema): return { 'type': 'object', + 'required': ['results'], 'properties': { 'next': { 'type': 'string', diff --git a/tests/test_pagination.py b/tests/test_pagination.py index d606986ab..090eb0d81 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -274,6 +274,7 @@ class TestPageNumberPagination: assert self.pagination.get_paginated_response_schema(unpaginated_schema) == { 'type': 'object', + 'required': ['count', 'results'], 'properties': { 'count': { 'type': 'integer', @@ -585,6 +586,7 @@ class TestLimitOffset: assert self.pagination.get_paginated_response_schema(unpaginated_schema) == { 'type': 'object', + 'required': ['count', 'results'], 'properties': { 'count': { 'type': 'integer', @@ -937,6 +939,7 @@ class CursorPaginationTestsMixin: assert self.pagination.get_paginated_response_schema(unpaginated_schema) == { 'type': 'object', + 'required': ['results'], 'properties': { 'next': { 'type': 'string', From 5c3b6e496c9892463f48f6b50cf9a0f1d2c29e78 Mon Sep 17 00:00:00 2001 From: Amin Aminian <47900904+aminiun@users.noreply.github.com> Date: Tue, 25 Jul 2023 09:51:25 +0200 Subject: [PATCH 4/6] class name added to unkown field error (#9019) --- rest_framework/serializers.py | 4 ++-- tests/test_model_serializer.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 56fa918dc..6ee75fbc1 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -1372,8 +1372,8 @@ class ModelSerializer(Serializer): Raise an error on any unknown fields. """ raise ImproperlyConfigured( - 'Field name `%s` is not valid for model `%s`.' % - (field_name, model_class.__name__) + 'Field name `%s` is not valid for model `%s` in `%s.%s`.' % + (field_name, model_class.__name__, self.__class__.__module__, self.__class__.__name__) ) def include_extra_kwargs(self, kwargs, extra_kwargs): diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py index c5ac888f5..e2d4bbc30 100644 --- a/tests/test_model_serializer.py +++ b/tests/test_model_serializer.py @@ -315,7 +315,8 @@ class TestRegularFieldMappings(TestCase): model = RegularFieldsModel fields = ('auto_field', 'invalid') - expected = 'Field name `invalid` is not valid for model `RegularFieldsModel`.' + expected = 'Field name `invalid` is not valid for model `RegularFieldsModel` ' \ + 'in `tests.test_model_serializer.TestSerializer`.' with self.assertRaisesMessage(ImproperlyConfigured, expected): TestSerializer().fields From b99df0cf780adc3d65362a4425f9bb6d85410bcb Mon Sep 17 00:00:00 2001 From: Devid Date: Tue, 25 Jul 2023 14:01:23 +0100 Subject: [PATCH 5/6] Align SearchFilter behaviour to django.contrib.admin search (#9017) * Use subquery to remove duplicates in SearchFilter * Align SearchFilter behaviour to django.contrib.admin * Add compatibility with older django/python versions * Allow search to split also by comma after smart split * Use generator to build search conditions to reduce iterations * Improve search documentation * Update docs/api-guide/filtering.md --------- Co-authored-by: Asif Saif Uddin --- docs/api-guide/filtering.md | 18 ++++---- rest_framework/compat.py | 8 ---- rest_framework/filters.py | 83 +++++++++++++++++++++++++++---------- tests/test_filters.py | 77 ++++++++++++++++++++++++++++++++-- 4 files changed, 146 insertions(+), 40 deletions(-) diff --git a/docs/api-guide/filtering.md b/docs/api-guide/filtering.md index 47ea8592d..ff5f3c775 100644 --- a/docs/api-guide/filtering.md +++ b/docs/api-guide/filtering.md @@ -213,19 +213,23 @@ This will allow the client to filter the items in the list by making queries suc You can also perform a related lookup on a ForeignKey or ManyToManyField with the lookup API double-underscore notation: search_fields = ['username', 'email', 'profile__profession'] - + For [JSONField][JSONField] and [HStoreField][HStoreField] fields you can filter based on nested values within the data structure using the same double-underscore notation: search_fields = ['data__breed', 'data__owner__other_pets__0__name'] -By default, searches will use case-insensitive partial matches. The search parameter may contain multiple search terms, which should be whitespace and/or comma separated. If multiple search terms are used then objects will be returned in the list only if all the provided terms are matched. +By default, searches will use case-insensitive partial matches. The search parameter may contain multiple search terms, which should be whitespace and/or comma separated. If multiple search terms are used then objects will be returned in the list only if all the provided terms are matched. Searches may contain _quoted phrases_ with spaces, each phrase is considered as a single search term. -The search behavior may be restricted by prepending various characters to the `search_fields`. -* '^' Starts-with search. -* '=' Exact matches. -* '@' Full-text search. (Currently only supported Django's [PostgreSQL backend][postgres-search].) -* '$' Regex search. +The search behavior may be specified by prefixing field names in `search_fields` with one of the following characters (which is equivalent to adding `__` to the field): + +| Prefix | Lookup | | +| ------ | --------------| ------------------ | +| `^` | `istartswith` | Starts-with search.| +| `=` | `iexact` | Exact matches. | +| `$` | `iregex` | Regex search. | +| `@` | `search` | Full-text search (Currently only supported Django's [PostgreSQL backend][postgres-search]). | +| None | `icontains` | Contains search (Default). | For example: diff --git a/rest_framework/compat.py b/rest_framework/compat.py index ac5cbc572..7e80704e1 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -3,7 +3,6 @@ The `compat` module provides support for backwards compatibility with older versions of Django/Python, and compatibility wrappers around optional packages. """ import django -from django.conf import settings from django.views.generic import View @@ -14,13 +13,6 @@ def unicode_http_header(value): return value -def distinct(queryset, base): - if settings.DATABASES[queryset.db]["ENGINE"] == "django.db.backends.oracle": - # distinct analogue for Oracle users - return base.filter(pk__in=set(queryset.values_list('pk', flat=True))) - return queryset.distinct() - - # django.contrib.postgres requires psycopg2 try: from django.contrib.postgres import fields as postgres_fields diff --git a/rest_framework/filters.py b/rest_framework/filters.py index c48504957..065e72f94 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -6,18 +6,35 @@ import operator import warnings from functools import reduce -from django.core.exceptions import ImproperlyConfigured +from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured from django.db import models from django.db.models.constants import LOOKUP_SEP from django.template import loader from django.utils.encoding import force_str +from django.utils.text import smart_split, unescape_string_literal from django.utils.translation import gettext_lazy as _ from rest_framework import RemovedInDRF317Warning -from rest_framework.compat import coreapi, coreschema, distinct +from rest_framework.compat import coreapi, coreschema +from rest_framework.fields import CharField from rest_framework.settings import api_settings +def search_smart_split(search_terms): + """generator that first splits string by spaces, leaving quoted phrases togheter, + then it splits non-quoted phrases by commas. + """ + for term in smart_split(search_terms): + # trim commas to avoid bad matching for quoted phrases + term = term.strip(',') + if term.startswith(('"', "'")) and term[0] == term[-1]: + # quoted phrases are kept togheter without any other split + yield unescape_string_literal(term) + else: + # non-quoted tokens are split by comma, keeping only non-empty ones + yield from (sub_term.strip() for sub_term in term.split(',') if sub_term) + + class BaseFilterBackend: """ A base class from which all filter backend classes should inherit. @@ -64,18 +81,41 @@ class SearchFilter(BaseFilterBackend): def get_search_terms(self, request): """ Search terms are set by a ?search=... query parameter, - and may be comma and/or whitespace delimited. + and may be whitespace delimited. """ - params = request.query_params.get(self.search_param, '') - params = params.replace('\x00', '') # strip null characters - params = params.replace(',', ' ') - return params.split() + value = request.query_params.get(self.search_param, '') + field = CharField(trim_whitespace=False, allow_blank=True) + return field.run_validation(value) - def construct_search(self, field_name): + def construct_search(self, field_name, queryset): lookup = self.lookup_prefixes.get(field_name[0]) if lookup: field_name = field_name[1:] else: + # Use field_name if it includes a lookup. + opts = queryset.model._meta + lookup_fields = field_name.split(LOOKUP_SEP) + # Go through the fields, following all relations. + prev_field = None + for path_part in lookup_fields: + if path_part == "pk": + path_part = opts.pk.name + try: + field = opts.get_field(path_part) + except FieldDoesNotExist: + # Use valid query lookups. + if prev_field and prev_field.get_lookup(path_part): + return field_name + else: + prev_field = field + if hasattr(field, "path_infos"): + # Update opts to follow the relation. + opts = field.path_infos[-1].to_opts + # django < 4.1 + elif hasattr(field, 'get_path_info'): + # Update opts to follow the relation. + opts = field.get_path_info()[-1].to_opts + # Otherwise, use the field with icontains. lookup = 'icontains' return LOOKUP_SEP.join([field_name, lookup]) @@ -113,26 +153,27 @@ class SearchFilter(BaseFilterBackend): return queryset orm_lookups = [ - self.construct_search(str(search_field)) + self.construct_search(str(search_field), queryset) for search_field in search_fields ] base = queryset - conditions = [] - for search_term in search_terms: - queries = [ - models.Q(**{orm_lookup: search_term}) - for orm_lookup in orm_lookups - ] - conditions.append(reduce(operator.or_, queries)) + # generator which for each term builds the corresponding search + conditions = ( + reduce( + operator.or_, + (models.Q(**{orm_lookup: term}) for orm_lookup in orm_lookups) + ) for term in search_smart_split(search_terms) + ) queryset = queryset.filter(reduce(operator.and_, conditions)) + # Remove duplicates from results, if necessary if self.must_call_distinct(queryset, search_fields): - # Filtering against a many-to-many field requires us to - # call queryset.distinct() in order to avoid duplicate items - # in the resulting queryset. - # We try to avoid this if possible, for performance reasons. - queryset = distinct(queryset, base) + # inspired by django.contrib.admin + # this is more accurate than .distinct form M2M relationship + # also is cross-database + queryset = queryset.filter(pk=models.OuterRef('pk')) + queryset = base.filter(models.Exists(queryset)) return queryset def to_html(self, request, queryset, view): diff --git a/tests/test_filters.py b/tests/test_filters.py index 2a22e30f9..6db0c3deb 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -6,16 +6,36 @@ from django.core.exceptions import ImproperlyConfigured from django.db import models from django.db.models import CharField, Transform from django.db.models.functions import Concat, Upper -from django.test import TestCase +from django.test import SimpleTestCase, TestCase from django.test.utils import override_settings from rest_framework import filters, generics, serializers from rest_framework.compat import coreschema +from rest_framework.exceptions import ValidationError from rest_framework.test import APIRequestFactory factory = APIRequestFactory() +class SearchSplitTests(SimpleTestCase): + + def test_keep_quoted_togheter_regardless_of_commas(self): + assert ['hello, world'] == list(filters.search_smart_split('"hello, world"')) + + def test_strips_commas_around_quoted(self): + assert ['hello, world'] == list(filters.search_smart_split(',,"hello, world"')) + assert ['hello, world'] == list(filters.search_smart_split(',,"hello, world",,')) + assert ['hello, world'] == list(filters.search_smart_split('"hello, world",,')) + + def test_splits_by_comma(self): + assert ['hello', 'world'] == list(filters.search_smart_split(',,hello, world')) + assert ['hello', 'world'] == list(filters.search_smart_split(',,hello, world,,')) + assert ['hello', 'world'] == list(filters.search_smart_split('hello, world,,')) + + def test_splits_quotes_followed_by_comma_and_sentence(self): + assert ['"hello', 'world"', 'found'] == list(filters.search_smart_split('"hello, world",found')) + + class BaseFilterTests(TestCase): def setUp(self): self.original_coreapi = filters.coreapi @@ -50,7 +70,8 @@ class SearchFilterSerializer(serializers.ModelSerializer): class SearchFilterTests(TestCase): - def setUp(self): + @classmethod + def setUpTestData(cls): # Sequence of title/text is: # # z abc @@ -66,6 +87,9 @@ class SearchFilterTests(TestCase): ) SearchFilterModel(title=title, text=text).save() + SearchFilterModel(title='A title', text='The long text').save() + SearchFilterModel(title='The title', text='The "text').save() + def test_search(self): class SearchListView(generics.ListAPIView): queryset = SearchFilterModel.objects.all() @@ -186,9 +210,21 @@ class SearchFilterTests(TestCase): request = factory.get('/?search=\0as%00d\x00f') request = view.initialize_request(request) - terms = filters.SearchFilter().get_search_terms(request) + with self.assertRaises(ValidationError): + filters.SearchFilter().get_search_terms(request) - assert terms == ['asdf'] + def test_search_field_with_custom_lookup(self): + class SearchListView(generics.ListAPIView): + queryset = SearchFilterModel.objects.all() + serializer_class = SearchFilterSerializer + filter_backends = (filters.SearchFilter,) + search_fields = ('text__iendswith',) + view = SearchListView.as_view() + request = factory.get('/', {'search': 'c'}) + response = view(request) + assert response.data == [ + {'id': 1, 'title': 'z', 'text': 'abc'}, + ] def test_search_field_with_additional_transforms(self): from django.test.utils import register_lookup @@ -242,6 +278,32 @@ class SearchFilterTests(TestCase): ) assert search_query in rendered_search_field + def test_search_field_with_escapes(self): + class SearchListView(generics.ListAPIView): + queryset = SearchFilterModel.objects.all() + serializer_class = SearchFilterSerializer + filter_backends = (filters.SearchFilter,) + search_fields = ('title', 'text',) + view = SearchListView.as_view() + request = factory.get('/', {'search': '"\\\"text"'}) + response = view(request) + assert response.data == [ + {'id': 12, 'title': 'The title', 'text': 'The "text'}, + ] + + def test_search_field_with_quotes(self): + class SearchListView(generics.ListAPIView): + queryset = SearchFilterModel.objects.all() + serializer_class = SearchFilterSerializer + filter_backends = (filters.SearchFilter,) + search_fields = ('title', 'text',) + view = SearchListView.as_view() + request = factory.get('/', {'search': '"long text"'}) + response = view(request) + assert response.data == [ + {'id': 11, 'title': 'A title', 'text': 'The long text'}, + ] + class AttributeModel(models.Model): label = models.CharField(max_length=32) @@ -284,6 +346,13 @@ class SearchFilterFkTests(TestCase): ["%sattribute__label" % prefix, "%stitle" % prefix] ) + def test_custom_lookup_to_related_model(self): + # In this test case the attribute of the fk model comes first in the + # list of search fields. + filter_ = filters.SearchFilter() + assert 'attribute__label__icontains' == filter_.construct_search('attribute__label', SearchFilterModelFk._meta) + assert 'attribute__label__iendswith' == filter_.construct_search('attribute__label__iendswith', SearchFilterModelFk._meta) + class SearchFilterModelM2M(models.Model): title = models.CharField(max_length=20) From 589b5dca9e7613f7742af8baed6ed870476dd23b Mon Sep 17 00:00:00 2001 From: Pierre Chiquet Date: Wed, 26 Jul 2023 06:27:49 +0200 Subject: [PATCH 6/6] Allow to override child.run_validation call in ListSerializer (#8035) * Separated run_child_validation method in ListSerializer * fix typo * Add test_update_allow_custom_child_validation --------- Co-authored-by: Pierre Chiquet --- rest_framework/serializers.py | 13 +++++++- tests/test_serializer_lists.py | 57 +++++++++++++++++++++++++++++++++- 2 files changed, 68 insertions(+), 2 deletions(-) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 6ee75fbc1..77c181b6c 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -653,6 +653,17 @@ class ListSerializer(BaseSerializer): return value + def run_child_validation(self, data): + """ + Run validation on child serializer. + You may need to override this method to support multiple updates. For example: + + self.child.instance = self.instance.get(pk=data['id']) + self.child.initial_data = data + return super().run_child_validation(data) + """ + return self.child.run_validation(data) + def to_internal_value(self, data): """ List of dicts of native values <- List of dicts of primitive datatypes. @@ -697,7 +708,7 @@ class ListSerializer(BaseSerializer): ): self.child.instance = self.instance[idx] try: - validated = self.child.run_validation(item) + validated = self.run_child_validation(item) except ValidationError as exc: errors.append(exc.detail) else: diff --git a/tests/test_serializer_lists.py b/tests/test_serializer_lists.py index 10463d29a..4070de7a5 100644 --- a/tests/test_serializer_lists.py +++ b/tests/test_serializer_lists.py @@ -153,6 +153,61 @@ class TestListSerializerContainingNestedSerializer: assert serializer.is_valid() assert serializer.validated_data == expected_output + def test_update_allow_custom_child_validation(self): + """ + Update a list of objects thanks custom run_child_validation implementation. + """ + + class TestUpdateSerializer(serializers.Serializer): + integer = serializers.IntegerField() + boolean = serializers.BooleanField() + + def update(self, instance, validated_data): + instance._data.update(validated_data) + return instance + + def validate(self, data): + # self.instance is set to current BasicObject instance + assert isinstance(self.instance, BasicObject) + # self.initial_data is current dictionary + assert isinstance(self.initial_data, dict) + assert self.initial_data["pk"] == self.instance.pk + return super().validate(data) + + class ListUpdateSerializer(serializers.ListSerializer): + child = TestUpdateSerializer() + + def run_child_validation(self, data): + # find related instance in self.instance list + child_instance = next(o for o in self.instance if o.pk == data["pk"]) + # set instance and initial_data for child serializer + self.child.instance = child_instance + self.child.initial_data = data + return super().run_child_validation(data) + + def update(self, instance, validated_data): + return [ + self.child.update(instance, attrs) + for instance, attrs in zip(self.instance, validated_data) + ] + + instance = [ + BasicObject(pk=1, integer=11, private_field="a"), + BasicObject(pk=2, integer=22, private_field="b"), + ] + input_data = [ + {"pk": 1, "integer": "123", "boolean": "true"}, + {"pk": 2, "integer": "456", "boolean": "false"}, + ] + expected_output = [ + BasicObject(pk=1, integer=123, boolean=True, private_field="a"), + BasicObject(pk=2, integer=456, boolean=False, private_field="b"), + ] + serializer = ListUpdateSerializer(instance, data=input_data) + assert serializer.is_valid() + updated_instances = serializer.save() + assert updated_instances == expected_output + class TestNestedListSerializer: """ @@ -481,7 +536,7 @@ class TestSerializerPartialUsage: assert serializer.validated_data == {} assert serializer.errors == {} - def test_udate_as_field_allow_empty_true(self): + def test_update_as_field_allow_empty_true(self): class ListSerializer(serializers.Serializer): update_field = serializers.IntegerField() store_field = serializers.IntegerField()