Re-prefetch related objects after updating (#8043)

* Re-prefetch related objects after updating

* Fix flake8 format

* Use _prefetch_related_lookups and refine test cases

* Add more test cases and refine prefetch checking
This commit is contained in:
Yuekui 2023-01-11 01:30:15 -08:00 committed by GitHub
parent bfce663a60
commit 2b34aa4291
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 71 additions and 32 deletions

View File

@ -4,6 +4,8 @@ Basic building blocks for generic class based views.
We don't bind behaviour to http method handlers yet,
which allows mixin classes to be composed in interesting ways.
"""
from django.db.models.query import prefetch_related_objects
from rest_framework import status
from rest_framework.response import Response
from rest_framework.settings import api_settings
@ -67,10 +69,13 @@ class UpdateModelMixin:
serializer.is_valid(raise_exception=True)
self.perform_update(serializer)
if getattr(instance, '_prefetched_objects_cache', None):
queryset = self.filter_queryset(self.get_queryset())
if queryset._prefetch_related_lookups:
# If 'prefetch_related' has been applied to a queryset, we need to
# forcibly invalidate the prefetch cache on the instance.
# forcibly invalidate the prefetch cache on the instance,
# and then re-prefetch related objects
instance._prefetched_objects_cache = {}
prefetch_related_objects([instance], *queryset._prefetch_related_lookups)
return Response(serializer.data)

View File

@ -1,4 +1,5 @@
from django.contrib.auth.models import Group, User
from django.db.models.query import Prefetch
from django.test import TestCase
from rest_framework import generics, serializers
@ -8,51 +9,84 @@ factory = APIRequestFactory()
class UserSerializer(serializers.ModelSerializer):
permissions = serializers.SerializerMethodField()
def get_permissions(self, obj):
ret = []
for g in obj.groups.all():
ret.extend([p.pk for p in g.permissions.all()])
return ret
class Meta:
model = User
fields = ('id', 'username', 'email', 'groups')
fields = ('id', 'username', 'email', 'groups', 'permissions')
class UserUpdate(generics.UpdateAPIView):
queryset = User.objects.exclude(username='exclude').prefetch_related('groups')
class UserRetrieveUpdate(generics.RetrieveUpdateAPIView):
queryset = User.objects.exclude(username='exclude').prefetch_related(
Prefetch('groups', queryset=Group.objects.exclude(name='exclude')),
'groups__permissions',
)
serializer_class = UserSerializer
class UserUpdateWithoutPrefetchRelated(generics.UpdateAPIView):
queryset = User.objects.exclude(username='exclude')
serializer_class = UserSerializer
class TestPrefetchRelatedUpdates(TestCase):
def setUp(self):
self.user = User.objects.create(username='tom', email='tom@example.com')
self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')]
self.groups = [Group.objects.create(name=f'group {i}') for i in range(10)]
self.user.groups.set(self.groups)
self.user.groups.add(Group.objects.create(name='exclude'))
self.expected = {
'id': self.user.pk,
'username': 'tom',
'groups': [group.pk for group in self.groups],
'email': 'tom@example.com',
'permissions': [],
}
self.view = UserRetrieveUpdate.as_view()
def test_prefetch_related_updates(self):
view = UserUpdate.as_view()
pk = self.user.pk
groups_pk = self.groups[0].pk
request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json')
response = view(request, pk=pk)
assert User.objects.get(pk=pk).groups.count() == 1
expected = {
'id': pk,
'username': 'new',
'groups': [1],
'email': 'tom@example.com'
}
assert response.data == expected
self.groups.append(Group.objects.create(name='c'))
request = factory.put(
'/', {'username': 'new', 'groups': [group.pk for group in self.groups]}, format='json'
)
self.expected['username'] = 'new'
self.expected['groups'] = [group.pk for group in self.groups]
response = self.view(request, pk=self.user.pk)
assert User.objects.get(pk=self.user.pk).groups.count() == 12
assert response.data == self.expected
# Update and fetch should get same result
request = factory.get('/')
response = self.view(request, pk=self.user.pk)
assert response.data == self.expected
def test_prefetch_related_excluding_instance_from_original_queryset(self):
"""
Regression test for https://github.com/encode/django-rest-framework/issues/4661
"""
view = UserUpdate.as_view()
pk = self.user.pk
groups_pk = self.groups[0].pk
request = factory.put('/', {'username': 'exclude', 'groups': [groups_pk]}, format='json')
response = view(request, pk=pk)
assert User.objects.get(pk=pk).groups.count() == 1
expected = {
'id': pk,
'username': 'exclude',
'groups': [1],
'email': 'tom@example.com'
}
assert response.data == expected
request = factory.put(
'/', {'username': 'exclude', 'groups': [self.groups[0].pk]}, format='json'
)
response = self.view(request, pk=self.user.pk)
assert User.objects.get(pk=self.user.pk).groups.count() == 2
self.expected['username'] = 'exclude'
self.expected['groups'] = [self.groups[0].pk]
assert response.data == self.expected
def test_db_query_count(self):
request = factory.put(
'/', {'username': 'new'}, format='json'
)
with self.assertNumQueries(7):
self.view(request, pk=self.user.pk)
request = factory.put(
'/', {'username': 'new2'}, format='json'
)
with self.assertNumQueries(16):
UserUpdateWithoutPrefetchRelated.as_view()(request, pk=self.user.pk)