mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-10 19:56:59 +03:00
Re-prefetch related objects after updating (#8043)
* Re-prefetch related objects after updating * Fix flake8 format * Use _prefetch_related_lookups and refine test cases * Add more test cases and refine prefetch checking
This commit is contained in:
parent
bfce663a60
commit
2b34aa4291
|
@ -4,6 +4,8 @@ Basic building blocks for generic class based views.
|
|||
We don't bind behaviour to http method handlers yet,
|
||||
which allows mixin classes to be composed in interesting ways.
|
||||
"""
|
||||
from django.db.models.query import prefetch_related_objects
|
||||
|
||||
from rest_framework import status
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.settings import api_settings
|
||||
|
@ -67,10 +69,13 @@ class UpdateModelMixin:
|
|||
serializer.is_valid(raise_exception=True)
|
||||
self.perform_update(serializer)
|
||||
|
||||
if getattr(instance, '_prefetched_objects_cache', None):
|
||||
queryset = self.filter_queryset(self.get_queryset())
|
||||
if queryset._prefetch_related_lookups:
|
||||
# If 'prefetch_related' has been applied to a queryset, we need to
|
||||
# forcibly invalidate the prefetch cache on the instance.
|
||||
# forcibly invalidate the prefetch cache on the instance,
|
||||
# and then re-prefetch related objects
|
||||
instance._prefetched_objects_cache = {}
|
||||
prefetch_related_objects([instance], *queryset._prefetch_related_lookups)
|
||||
|
||||
return Response(serializer.data)
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from django.contrib.auth.models import Group, User
|
||||
from django.db.models.query import Prefetch
|
||||
from django.test import TestCase
|
||||
|
||||
from rest_framework import generics, serializers
|
||||
|
@ -8,51 +9,84 @@ factory = APIRequestFactory()
|
|||
|
||||
|
||||
class UserSerializer(serializers.ModelSerializer):
|
||||
permissions = serializers.SerializerMethodField()
|
||||
|
||||
def get_permissions(self, obj):
|
||||
ret = []
|
||||
for g in obj.groups.all():
|
||||
ret.extend([p.pk for p in g.permissions.all()])
|
||||
return ret
|
||||
|
||||
class Meta:
|
||||
model = User
|
||||
fields = ('id', 'username', 'email', 'groups')
|
||||
fields = ('id', 'username', 'email', 'groups', 'permissions')
|
||||
|
||||
|
||||
class UserUpdate(generics.UpdateAPIView):
|
||||
queryset = User.objects.exclude(username='exclude').prefetch_related('groups')
|
||||
class UserRetrieveUpdate(generics.RetrieveUpdateAPIView):
|
||||
queryset = User.objects.exclude(username='exclude').prefetch_related(
|
||||
Prefetch('groups', queryset=Group.objects.exclude(name='exclude')),
|
||||
'groups__permissions',
|
||||
)
|
||||
serializer_class = UserSerializer
|
||||
|
||||
|
||||
class UserUpdateWithoutPrefetchRelated(generics.UpdateAPIView):
|
||||
queryset = User.objects.exclude(username='exclude')
|
||||
serializer_class = UserSerializer
|
||||
|
||||
|
||||
class TestPrefetchRelatedUpdates(TestCase):
|
||||
def setUp(self):
|
||||
self.user = User.objects.create(username='tom', email='tom@example.com')
|
||||
self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')]
|
||||
self.groups = [Group.objects.create(name=f'group {i}') for i in range(10)]
|
||||
self.user.groups.set(self.groups)
|
||||
self.user.groups.add(Group.objects.create(name='exclude'))
|
||||
self.expected = {
|
||||
'id': self.user.pk,
|
||||
'username': 'tom',
|
||||
'groups': [group.pk for group in self.groups],
|
||||
'email': 'tom@example.com',
|
||||
'permissions': [],
|
||||
}
|
||||
self.view = UserRetrieveUpdate.as_view()
|
||||
|
||||
def test_prefetch_related_updates(self):
|
||||
view = UserUpdate.as_view()
|
||||
pk = self.user.pk
|
||||
groups_pk = self.groups[0].pk
|
||||
request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json')
|
||||
response = view(request, pk=pk)
|
||||
assert User.objects.get(pk=pk).groups.count() == 1
|
||||
expected = {
|
||||
'id': pk,
|
||||
'username': 'new',
|
||||
'groups': [1],
|
||||
'email': 'tom@example.com'
|
||||
}
|
||||
assert response.data == expected
|
||||
self.groups.append(Group.objects.create(name='c'))
|
||||
request = factory.put(
|
||||
'/', {'username': 'new', 'groups': [group.pk for group in self.groups]}, format='json'
|
||||
)
|
||||
self.expected['username'] = 'new'
|
||||
self.expected['groups'] = [group.pk for group in self.groups]
|
||||
response = self.view(request, pk=self.user.pk)
|
||||
assert User.objects.get(pk=self.user.pk).groups.count() == 12
|
||||
assert response.data == self.expected
|
||||
# Update and fetch should get same result
|
||||
request = factory.get('/')
|
||||
response = self.view(request, pk=self.user.pk)
|
||||
assert response.data == self.expected
|
||||
|
||||
def test_prefetch_related_excluding_instance_from_original_queryset(self):
|
||||
"""
|
||||
Regression test for https://github.com/encode/django-rest-framework/issues/4661
|
||||
"""
|
||||
view = UserUpdate.as_view()
|
||||
pk = self.user.pk
|
||||
groups_pk = self.groups[0].pk
|
||||
request = factory.put('/', {'username': 'exclude', 'groups': [groups_pk]}, format='json')
|
||||
response = view(request, pk=pk)
|
||||
assert User.objects.get(pk=pk).groups.count() == 1
|
||||
expected = {
|
||||
'id': pk,
|
||||
'username': 'exclude',
|
||||
'groups': [1],
|
||||
'email': 'tom@example.com'
|
||||
}
|
||||
assert response.data == expected
|
||||
request = factory.put(
|
||||
'/', {'username': 'exclude', 'groups': [self.groups[0].pk]}, format='json'
|
||||
)
|
||||
response = self.view(request, pk=self.user.pk)
|
||||
assert User.objects.get(pk=self.user.pk).groups.count() == 2
|
||||
self.expected['username'] = 'exclude'
|
||||
self.expected['groups'] = [self.groups[0].pk]
|
||||
assert response.data == self.expected
|
||||
|
||||
def test_db_query_count(self):
|
||||
request = factory.put(
|
||||
'/', {'username': 'new'}, format='json'
|
||||
)
|
||||
with self.assertNumQueries(7):
|
||||
self.view(request, pk=self.user.pk)
|
||||
|
||||
request = factory.put(
|
||||
'/', {'username': 'new2'}, format='json'
|
||||
)
|
||||
with self.assertNumQueries(16):
|
||||
UserUpdateWithoutPrefetchRelated.as_view()(request, pk=self.user.pk)
|
||||
|
|
Loading…
Reference in New Issue
Block a user