mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-30 21:44:04 +03:00
Group queries for PrimaryKeyRelatedField many serializers
This commit is contained in:
parent
047bec1288
commit
d6ca95f4c7
|
@ -249,18 +249,25 @@ class PrimaryKeyRelatedField(RelatedField):
|
|||
def use_pk_only_optimization(self):
|
||||
return True
|
||||
|
||||
def to_internal_value(self, data):
|
||||
def to_many_internal_value(self, data):
|
||||
if self.pk_field is not None:
|
||||
data = self.pk_field.to_internal_value(data)
|
||||
data = [self.pk_field.to_internal_value(item) for item in data]
|
||||
queryset = self.get_queryset()
|
||||
try:
|
||||
if isinstance(data, bool):
|
||||
raise TypeError
|
||||
return queryset.get(pk=data)
|
||||
except ObjectDoesNotExist:
|
||||
self.fail('does_not_exist', pk_value=data)
|
||||
for item in data:
|
||||
if isinstance(item, bool):
|
||||
raise TypeError
|
||||
result = queryset.filter(pk__in=data).all()
|
||||
pks = [item.pk for item in result]
|
||||
for item in data:
|
||||
if item not in pks:
|
||||
self.fail('does_not_exist', pk_value=item)
|
||||
return result
|
||||
except (TypeError, ValueError):
|
||||
self.fail('incorrect_type', data_type=type(data).__name__)
|
||||
self.fail('incorrect_type', data_type=type(data[0]).__name__)
|
||||
|
||||
def to_internal_value(self, data):
|
||||
return self.to_many_internal_value([data])[0]
|
||||
|
||||
def to_representation(self, value):
|
||||
if self.pk_field is not None:
|
||||
|
@ -524,6 +531,9 @@ class ManyRelatedField(Field):
|
|||
if not self.allow_empty and len(data) == 0:
|
||||
self.fail('empty')
|
||||
|
||||
if hasattr(self.child_relation, "to_many_internal_value"):
|
||||
return self.child_relation.to_many_internal_value(data)
|
||||
|
||||
return [
|
||||
self.child_relation.to_internal_value(item)
|
||||
for item in data
|
||||
|
|
|
@ -189,6 +189,13 @@ class PKManyToManyTests(TestCase):
|
|||
]
|
||||
assert serializer.data == expected
|
||||
|
||||
def test_many_to_many_grouped_queries(self):
|
||||
data = {'id': 4, 'name': 'source-4', 'targets': [1, 3]}
|
||||
serializer = ManyToManySourceSerializer(data=data)
|
||||
# Only one query should be executed even with several targets
|
||||
with self.assertNumQueries(1):
|
||||
assert serializer.is_valid()
|
||||
|
||||
def test_many_to_many_unsaved(self):
|
||||
source = ManyToManySource(name='source-unsaved')
|
||||
|
||||
|
|
|
@ -26,14 +26,26 @@ class MockQueryset:
|
|||
return self.items[val]
|
||||
|
||||
def get(self, **lookup):
|
||||
for item in self.items:
|
||||
if all([
|
||||
attrgetter(key.replace('__', '.'))(item) == value
|
||||
for key, value in lookup.items()
|
||||
]):
|
||||
return item
|
||||
result = self.filter(**lookup).all()
|
||||
if len(result) > 0:
|
||||
return result[0]
|
||||
raise ObjectDoesNotExist()
|
||||
|
||||
def all(self):
|
||||
return list(self.items)
|
||||
|
||||
def filter(self, **lookup):
|
||||
return MockQueryset(
|
||||
item
|
||||
for item in self.items
|
||||
if all([
|
||||
attrgetter(key.replace("__in", "").replace('__', '.'))(item) in value
|
||||
if key.endswith("__in")
|
||||
else attrgetter(key.replace('__', '.'))(item) == value
|
||||
for key, value in lookup.items()
|
||||
])
|
||||
)
|
||||
|
||||
|
||||
class BadType:
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue
Block a user