Re-prefetch related objects after updating (#8043)

* Re-prefetch related objects after updating

* Fix flake8 format

* Use _prefetch_related_lookups and refine test cases

* Add more test cases and refine prefetch checking
This commit is contained in:
Yuekui 2023-01-11 01:30:15 -08:00 committed by GitHub
parent bfce663a60
commit 2b34aa4291
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 71 additions and 32 deletions

View File

@ -4,6 +4,8 @@ Basic building blocks for generic class based views.
We don't bind behaviour to http method handlers yet, We don't bind behaviour to http method handlers yet,
which allows mixin classes to be composed in interesting ways. which allows mixin classes to be composed in interesting ways.
""" """
from django.db.models.query import prefetch_related_objects
from rest_framework import status from rest_framework import status
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
@ -67,10 +69,13 @@ class UpdateModelMixin:
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
self.perform_update(serializer) self.perform_update(serializer)
if getattr(instance, '_prefetched_objects_cache', None): queryset = self.filter_queryset(self.get_queryset())
if queryset._prefetch_related_lookups:
# If 'prefetch_related' has been applied to a queryset, we need to # If 'prefetch_related' has been applied to a queryset, we need to
# forcibly invalidate the prefetch cache on the instance. # forcibly invalidate the prefetch cache on the instance,
# and then re-prefetch related objects
instance._prefetched_objects_cache = {} instance._prefetched_objects_cache = {}
prefetch_related_objects([instance], *queryset._prefetch_related_lookups)
return Response(serializer.data) return Response(serializer.data)

View File

@ -1,4 +1,5 @@
from django.contrib.auth.models import Group, User from django.contrib.auth.models import Group, User
from django.db.models.query import Prefetch
from django.test import TestCase from django.test import TestCase
from rest_framework import generics, serializers from rest_framework import generics, serializers
@ -8,51 +9,84 @@ factory = APIRequestFactory()
class UserSerializer(serializers.ModelSerializer): class UserSerializer(serializers.ModelSerializer):
permissions = serializers.SerializerMethodField()
def get_permissions(self, obj):
ret = []
for g in obj.groups.all():
ret.extend([p.pk for p in g.permissions.all()])
return ret
class Meta: class Meta:
model = User model = User
fields = ('id', 'username', 'email', 'groups') fields = ('id', 'username', 'email', 'groups', 'permissions')
class UserUpdate(generics.UpdateAPIView): class UserRetrieveUpdate(generics.RetrieveUpdateAPIView):
queryset = User.objects.exclude(username='exclude').prefetch_related('groups') queryset = User.objects.exclude(username='exclude').prefetch_related(
Prefetch('groups', queryset=Group.objects.exclude(name='exclude')),
'groups__permissions',
)
serializer_class = UserSerializer
class UserUpdateWithoutPrefetchRelated(generics.UpdateAPIView):
queryset = User.objects.exclude(username='exclude')
serializer_class = UserSerializer serializer_class = UserSerializer
class TestPrefetchRelatedUpdates(TestCase): class TestPrefetchRelatedUpdates(TestCase):
def setUp(self): def setUp(self):
self.user = User.objects.create(username='tom', email='tom@example.com') self.user = User.objects.create(username='tom', email='tom@example.com')
self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')] self.groups = [Group.objects.create(name=f'group {i}') for i in range(10)]
self.user.groups.set(self.groups) self.user.groups.set(self.groups)
self.user.groups.add(Group.objects.create(name='exclude'))
self.expected = {
'id': self.user.pk,
'username': 'tom',
'groups': [group.pk for group in self.groups],
'email': 'tom@example.com',
'permissions': [],
}
self.view = UserRetrieveUpdate.as_view()
def test_prefetch_related_updates(self): def test_prefetch_related_updates(self):
view = UserUpdate.as_view() self.groups.append(Group.objects.create(name='c'))
pk = self.user.pk request = factory.put(
groups_pk = self.groups[0].pk '/', {'username': 'new', 'groups': [group.pk for group in self.groups]}, format='json'
request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json') )
response = view(request, pk=pk) self.expected['username'] = 'new'
assert User.objects.get(pk=pk).groups.count() == 1 self.expected['groups'] = [group.pk for group in self.groups]
expected = { response = self.view(request, pk=self.user.pk)
'id': pk, assert User.objects.get(pk=self.user.pk).groups.count() == 12
'username': 'new', assert response.data == self.expected
'groups': [1], # Update and fetch should get same result
'email': 'tom@example.com' request = factory.get('/')
} response = self.view(request, pk=self.user.pk)
assert response.data == expected assert response.data == self.expected
def test_prefetch_related_excluding_instance_from_original_queryset(self): def test_prefetch_related_excluding_instance_from_original_queryset(self):
""" """
Regression test for https://github.com/encode/django-rest-framework/issues/4661 Regression test for https://github.com/encode/django-rest-framework/issues/4661
""" """
view = UserUpdate.as_view() request = factory.put(
pk = self.user.pk '/', {'username': 'exclude', 'groups': [self.groups[0].pk]}, format='json'
groups_pk = self.groups[0].pk )
request = factory.put('/', {'username': 'exclude', 'groups': [groups_pk]}, format='json') response = self.view(request, pk=self.user.pk)
response = view(request, pk=pk) assert User.objects.get(pk=self.user.pk).groups.count() == 2
assert User.objects.get(pk=pk).groups.count() == 1 self.expected['username'] = 'exclude'
expected = { self.expected['groups'] = [self.groups[0].pk]
'id': pk, assert response.data == self.expected
'username': 'exclude',
'groups': [1], def test_db_query_count(self):
'email': 'tom@example.com' request = factory.put(
} '/', {'username': 'new'}, format='json'
assert response.data == expected )
with self.assertNumQueries(7):
self.view(request, pk=self.user.pk)
request = factory.put(
'/', {'username': 'new2'}, format='json'
)
with self.assertNumQueries(16):
UserUpdateWithoutPrefetchRelated.as_view()(request, pk=self.user.pk)