From b72027fcdbd3c2e7c32ade4abd85fb53512c18d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Escolano?= Date: Fri, 5 Jan 2024 13:19:31 +0100 Subject: [PATCH] Group queries for SlugRelatedField many serializers --- rest_framework/relations.py | 21 ++++++++++++++++----- tests/test_relations_slug.py | 6 ++++++ tests/utils.py | 10 ++++++++-- 3 files changed, 30 insertions(+), 7 deletions(-) diff --git a/rest_framework/relations.py b/rest_framework/relations.py index fe0acfc09..d4b59c46d 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -4,7 +4,7 @@ from operator import attrgetter from urllib import parse from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist -from django.db.models import Manager +from django.db.models import F, Manager from django.db.models.query import QuerySet from django.urls import NoReverseMatch, Resolver404, get_script_prefix, resolve from django.utils.encoding import smart_str, uri_to_iri @@ -458,15 +458,26 @@ class SlugRelatedField(RelatedField): self.slug_field = slug_field super().__init__(**kwargs) - def to_internal_value(self, data): + def to_many_internal_value(self, data): queryset = self.get_queryset() try: - return queryset.get(**{self.slug_field: data}) - except ObjectDoesNotExist: - self.fail('does_not_exist', slug_name=self.slug_field, value=smart_str(data)) + result = ( + queryset + .filter(**{self.slug_field + "__in": data}) + .annotate(_slug_field_value=F(self.slug_field)) + .all() + ) + slugs = [item._slug_field_value for item in result] + for item in data: + if item not in slugs: + self.fail('does_not_exist', slug_name=self.slug_field, value=smart_str(item)) + return result except (TypeError, ValueError): self.fail('invalid') + def to_internal_value(self, data): + return self.to_many_internal_value([data])[0] + def to_representation(self, obj): slug = self.slug_field if "__" in slug: diff --git a/tests/test_relations_slug.py b/tests/test_relations_slug.py index 0b9ca79d3..c0343cb99 100644 --- a/tests/test_relations_slug.py +++ b/tests/test_relations_slug.py @@ -174,6 +174,12 @@ class SlugForeignKeyTests(TestCase): ] assert serializer.data == expected + def test_reverse_foreign_key_create_grouped_queries(self): + data = {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']} + serializer = ForeignKeyTargetSerializer(data=data) + with self.assertNumQueries(1): + assert serializer.is_valid() + def test_foreign_key_update_with_invalid_null(self): data = {'id': 1, 'name': 'source-1', 'target': None} instance = ForeignKeySource.objects.get(pk=1) diff --git a/tests/utils.py b/tests/utils.py index 799d98ec7..b7c2813a6 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -35,7 +35,7 @@ class MockQueryset: return list(self.items) def filter(self, **lookup): - return MockQueryset( + return MockQueryset([ item for item in self.items if all([ @@ -44,7 +44,13 @@ class MockQueryset: else attrgetter(key.replace('__', '.'))(item) == value for key, value in lookup.items() ]) - ) + ]) + + def annotate(self, **kwargs): + for key, value in kwargs.items(): + for item in self.items: + setattr(item, key, attrgetter(value.name.replace('__', '.'))(item)) + return self class BadType: