From d6ca95f4c7ae1d1bdbd0a72b529da5e0ffd1e9f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Escolano?= Date: Fri, 5 Jan 2024 12:52:31 +0100 Subject: [PATCH] Group queries for PrimaryKeyRelatedField many serializers --- rest_framework/relations.py | 26 ++++++++++++++++++-------- tests/test_relations_pk.py | 7 +++++++ tests/utils.py | 24 ++++++++++++++++++------ 3 files changed, 43 insertions(+), 14 deletions(-) diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 4409bce77..fe0acfc09 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -249,18 +249,25 @@ class PrimaryKeyRelatedField(RelatedField): def use_pk_only_optimization(self): return True - def to_internal_value(self, data): + def to_many_internal_value(self, data): if self.pk_field is not None: - data = self.pk_field.to_internal_value(data) + data = [self.pk_field.to_internal_value(item) for item in data] queryset = self.get_queryset() try: - if isinstance(data, bool): - raise TypeError - return queryset.get(pk=data) - except ObjectDoesNotExist: - self.fail('does_not_exist', pk_value=data) + for item in data: + if isinstance(item, bool): + raise TypeError + result = queryset.filter(pk__in=data).all() + pks = [item.pk for item in result] + for item in data: + if item not in pks: + self.fail('does_not_exist', pk_value=item) + return result except (TypeError, ValueError): - self.fail('incorrect_type', data_type=type(data).__name__) + self.fail('incorrect_type', data_type=type(data[0]).__name__) + + def to_internal_value(self, data): + return self.to_many_internal_value([data])[0] def to_representation(self, value): if self.pk_field is not None: @@ -524,6 +531,9 @@ class ManyRelatedField(Field): if not self.allow_empty and len(data) == 0: self.fail('empty') + if hasattr(self.child_relation, "to_many_internal_value"): + return self.child_relation.to_many_internal_value(data) + return [ self.child_relation.to_internal_value(item) for item in data diff --git a/tests/test_relations_pk.py b/tests/test_relations_pk.py index 7a4878a2b..260229a91 100644 --- a/tests/test_relations_pk.py +++ b/tests/test_relations_pk.py @@ -189,6 +189,13 @@ class PKManyToManyTests(TestCase): ] assert serializer.data == expected + def test_many_to_many_grouped_queries(self): + data = {'id': 4, 'name': 'source-4', 'targets': [1, 3]} + serializer = ManyToManySourceSerializer(data=data) + # Only one query should be executed even with several targets + with self.assertNumQueries(1): + assert serializer.is_valid() + def test_many_to_many_unsaved(self): source = ManyToManySource(name='source-unsaved') diff --git a/tests/utils.py b/tests/utils.py index 4ceb35309..799d98ec7 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -26,14 +26,26 @@ class MockQueryset: return self.items[val] def get(self, **lookup): - for item in self.items: - if all([ - attrgetter(key.replace('__', '.'))(item) == value - for key, value in lookup.items() - ]): - return item + result = self.filter(**lookup).all() + if len(result) > 0: + return result[0] raise ObjectDoesNotExist() + def all(self): + return list(self.items) + + def filter(self, **lookup): + return MockQueryset( + item + for item in self.items + if all([ + attrgetter(key.replace("__in", "").replace('__', '.'))(item) in value + if key.endswith("__in") + else attrgetter(key.replace('__', '.'))(item) == value + for key, value in lookup.items() + ]) + ) + class BadType: """