diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 6031b06ad..6ac6366c7 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -69,12 +69,12 @@ class UpdateModelMixin: serializer.is_valid(raise_exception=True) self.perform_update(serializer) - if getattr(instance, '_prefetched_objects_cache', None): + queryset = self.filter_queryset(self.get_queryset()) + if queryset._prefetch_related_lookups: # If 'prefetch_related' has been applied to a queryset, we need to # forcibly invalidate the prefetch cache on the instance, # and then re-prefetch related objects instance._prefetched_objects_cache = {} - queryset = self.filter_queryset(self.get_queryset()) prefetch_related_objects([instance], *queryset._prefetch_related_lookups) return Response(serializer.data) diff --git a/tests/test_prefetch_related.py b/tests/test_prefetch_related.py index d38066690..8e7bcf4ac 100644 --- a/tests/test_prefetch_related.py +++ b/tests/test_prefetch_related.py @@ -1,4 +1,5 @@ from django.contrib.auth.models import Group, User +from django.db.models.query import Prefetch from django.test import TestCase from rest_framework import generics, serializers @@ -21,8 +22,11 @@ class UserSerializer(serializers.ModelSerializer): fields = ('id', 'username', 'email', 'groups', 'permissions') -class UserUpdate(generics.UpdateAPIView): - queryset = User.objects.exclude(username='exclude').prefetch_related('groups__permissions') +class UserRetrieveUpdate(generics.RetrieveUpdateAPIView): + queryset = User.objects.exclude(username='exclude').prefetch_related( + Prefetch('groups', queryset=Group.objects.exclude(name='exclude')), + 'groups__permissions', + ) serializer_class = UserSerializer @@ -36,6 +40,7 @@ class TestPrefetchRelatedUpdates(TestCase): self.user = User.objects.create(username='tom', email='tom@example.com') self.groups = [Group.objects.create(name=f'group {i}') for i in range(10)] self.user.groups.set(self.groups) + self.user.groups.add(Group.objects.create(name='exclude')) self.expected = { 'id': self.user.pk, 'username': 'tom', @@ -43,7 +48,7 @@ class TestPrefetchRelatedUpdates(TestCase): 'email': 'tom@example.com', 'permissions': [], } - self.view = UserUpdate.as_view() + self.view = UserRetrieveUpdate.as_view() def test_prefetch_related_updates(self): self.groups.append(Group.objects.create(name='c')) @@ -53,7 +58,11 @@ class TestPrefetchRelatedUpdates(TestCase): self.expected['username'] = 'new' self.expected['groups'] = [group.pk for group in self.groups] response = self.view(request, pk=self.user.pk) - assert User.objects.get(pk=self.user.pk).groups.count() == 11 + assert User.objects.get(pk=self.user.pk).groups.count() == 12 + assert response.data == self.expected + # Update and fetch should get same result + request = factory.get('/') + response = self.view(request, pk=self.user.pk) assert response.data == self.expected def test_prefetch_related_excluding_instance_from_original_queryset(self): @@ -64,7 +73,7 @@ class TestPrefetchRelatedUpdates(TestCase): '/', {'username': 'exclude', 'groups': [self.groups[0].pk]}, format='json' ) response = self.view(request, pk=self.user.pk) - assert User.objects.get(pk=self.user.pk).groups.count() == 1 + assert User.objects.get(pk=self.user.pk).groups.count() == 2 self.expected['username'] = 'exclude' self.expected['groups'] = [self.groups[0].pk] assert response.data == self.expected @@ -79,5 +88,5 @@ class TestPrefetchRelatedUpdates(TestCase): request = factory.put( '/', {'username': 'new2'}, format='json' ) - with self.assertNumQueries(15): + with self.assertNumQueries(16): UserUpdateWithoutPrefetchRelated.as_view()(request, pk=self.user.pk)