Add more test cases and refine prefetch checking

This commit is contained in:
Yuekui Li 2022-11-23 19:36:18 -08:00
parent 91a9582ec6
commit f322c04159
2 changed files with 17 additions and 8 deletions

View File

@ -69,12 +69,12 @@ class UpdateModelMixin:
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
self.perform_update(serializer) self.perform_update(serializer)
if getattr(instance, '_prefetched_objects_cache', None): queryset = self.filter_queryset(self.get_queryset())
if queryset._prefetch_related_lookups:
# If 'prefetch_related' has been applied to a queryset, we need to # If 'prefetch_related' has been applied to a queryset, we need to
# forcibly invalidate the prefetch cache on the instance, # forcibly invalidate the prefetch cache on the instance,
# and then re-prefetch related objects # and then re-prefetch related objects
instance._prefetched_objects_cache = {} instance._prefetched_objects_cache = {}
queryset = self.filter_queryset(self.get_queryset())
prefetch_related_objects([instance], *queryset._prefetch_related_lookups) prefetch_related_objects([instance], *queryset._prefetch_related_lookups)
return Response(serializer.data) return Response(serializer.data)

View File

@ -1,4 +1,5 @@
from django.contrib.auth.models import Group, User from django.contrib.auth.models import Group, User
from django.db.models.query import Prefetch
from django.test import TestCase from django.test import TestCase
from rest_framework import generics, serializers from rest_framework import generics, serializers
@ -21,8 +22,11 @@ class UserSerializer(serializers.ModelSerializer):
fields = ('id', 'username', 'email', 'groups', 'permissions') fields = ('id', 'username', 'email', 'groups', 'permissions')
class UserUpdate(generics.UpdateAPIView): class UserRetrieveUpdate(generics.RetrieveUpdateAPIView):
queryset = User.objects.exclude(username='exclude').prefetch_related('groups__permissions') queryset = User.objects.exclude(username='exclude').prefetch_related(
Prefetch('groups', queryset=Group.objects.exclude(name='exclude')),
'groups__permissions',
)
serializer_class = UserSerializer serializer_class = UserSerializer
@ -36,6 +40,7 @@ class TestPrefetchRelatedUpdates(TestCase):
self.user = User.objects.create(username='tom', email='tom@example.com') self.user = User.objects.create(username='tom', email='tom@example.com')
self.groups = [Group.objects.create(name=f'group {i}') for i in range(10)] self.groups = [Group.objects.create(name=f'group {i}') for i in range(10)]
self.user.groups.set(self.groups) self.user.groups.set(self.groups)
self.user.groups.add(Group.objects.create(name='exclude'))
self.expected = { self.expected = {
'id': self.user.pk, 'id': self.user.pk,
'username': 'tom', 'username': 'tom',
@ -43,7 +48,7 @@ class TestPrefetchRelatedUpdates(TestCase):
'email': 'tom@example.com', 'email': 'tom@example.com',
'permissions': [], 'permissions': [],
} }
self.view = UserUpdate.as_view() self.view = UserRetrieveUpdate.as_view()
def test_prefetch_related_updates(self): def test_prefetch_related_updates(self):
self.groups.append(Group.objects.create(name='c')) self.groups.append(Group.objects.create(name='c'))
@ -53,7 +58,11 @@ class TestPrefetchRelatedUpdates(TestCase):
self.expected['username'] = 'new' self.expected['username'] = 'new'
self.expected['groups'] = [group.pk for group in self.groups] self.expected['groups'] = [group.pk for group in self.groups]
response = self.view(request, pk=self.user.pk) response = self.view(request, pk=self.user.pk)
assert User.objects.get(pk=self.user.pk).groups.count() == 11 assert User.objects.get(pk=self.user.pk).groups.count() == 12
assert response.data == self.expected
# Update and fetch should get same result
request = factory.get('/')
response = self.view(request, pk=self.user.pk)
assert response.data == self.expected assert response.data == self.expected
def test_prefetch_related_excluding_instance_from_original_queryset(self): def test_prefetch_related_excluding_instance_from_original_queryset(self):
@ -64,7 +73,7 @@ class TestPrefetchRelatedUpdates(TestCase):
'/', {'username': 'exclude', 'groups': [self.groups[0].pk]}, format='json' '/', {'username': 'exclude', 'groups': [self.groups[0].pk]}, format='json'
) )
response = self.view(request, pk=self.user.pk) response = self.view(request, pk=self.user.pk)
assert User.objects.get(pk=self.user.pk).groups.count() == 1 assert User.objects.get(pk=self.user.pk).groups.count() == 2
self.expected['username'] = 'exclude' self.expected['username'] = 'exclude'
self.expected['groups'] = [self.groups[0].pk] self.expected['groups'] = [self.groups[0].pk]
assert response.data == self.expected assert response.data == self.expected
@ -79,5 +88,5 @@ class TestPrefetchRelatedUpdates(TestCase):
request = factory.put( request = factory.put(
'/', {'username': 'new2'}, format='json' '/', {'username': 'new2'}, format='json'
) )
with self.assertNumQueries(15): with self.assertNumQueries(16):
UserUpdateWithoutPrefetchRelated.as_view()(request, pk=self.user.pk) UserUpdateWithoutPrefetchRelated.as_view()(request, pk=self.user.pk)