Group queries for PrimaryKeyRelatedField many serializers

This commit is contained in:
Clément Escolano 2024-01-05 12:52:31 +01:00
parent 047bec1288
commit d6ca95f4c7
No known key found for this signature in database
GPG Key ID: DE78E2F63131054A
3 changed files with 43 additions and 14 deletions

View File

@ -249,18 +249,25 @@ class PrimaryKeyRelatedField(RelatedField):
def use_pk_only_optimization(self):
return True
def to_internal_value(self, data):
def to_many_internal_value(self, data):
if self.pk_field is not None:
data = self.pk_field.to_internal_value(data)
data = [self.pk_field.to_internal_value(item) for item in data]
queryset = self.get_queryset()
try:
if isinstance(data, bool):
raise TypeError
return queryset.get(pk=data)
except ObjectDoesNotExist:
self.fail('does_not_exist', pk_value=data)
for item in data:
if isinstance(item, bool):
raise TypeError
result = queryset.filter(pk__in=data).all()
pks = [item.pk for item in result]
for item in data:
if item not in pks:
self.fail('does_not_exist', pk_value=item)
return result
except (TypeError, ValueError):
self.fail('incorrect_type', data_type=type(data).__name__)
self.fail('incorrect_type', data_type=type(data[0]).__name__)
def to_internal_value(self, data):
return self.to_many_internal_value([data])[0]
def to_representation(self, value):
if self.pk_field is not None:
@ -524,6 +531,9 @@ class ManyRelatedField(Field):
if not self.allow_empty and len(data) == 0:
self.fail('empty')
if hasattr(self.child_relation, "to_many_internal_value"):
return self.child_relation.to_many_internal_value(data)
return [
self.child_relation.to_internal_value(item)
for item in data

View File

@ -189,6 +189,13 @@ class PKManyToManyTests(TestCase):
]
assert serializer.data == expected
def test_many_to_many_grouped_queries(self):
data = {'id': 4, 'name': 'source-4', 'targets': [1, 3]}
serializer = ManyToManySourceSerializer(data=data)
# Only one query should be executed even with several targets
with self.assertNumQueries(1):
assert serializer.is_valid()
def test_many_to_many_unsaved(self):
source = ManyToManySource(name='source-unsaved')

View File

@ -26,14 +26,26 @@ class MockQueryset:
return self.items[val]
def get(self, **lookup):
for item in self.items:
if all([
attrgetter(key.replace('__', '.'))(item) == value
for key, value in lookup.items()
]):
return item
result = self.filter(**lookup).all()
if len(result) > 0:
return result[0]
raise ObjectDoesNotExist()
def all(self):
return list(self.items)
def filter(self, **lookup):
return MockQueryset(
item
for item in self.items
if all([
attrgetter(key.replace("__in", "").replace('__', '.'))(item) in value
if key.endswith("__in")
else attrgetter(key.replace('__', '.'))(item) == value
for key, value in lookup.items()
])
)
class BadType:
"""