diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 6ac6366c7..7fa8947cb 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -4,8 +4,6 @@ Basic building blocks for generic class based views. We don't bind behaviour to http method handlers yet, which allows mixin classes to be composed in interesting ways. """ -from django.db.models.query import prefetch_related_objects - from rest_framework import status from rest_framework.response import Response from rest_framework.settings import api_settings @@ -69,13 +67,10 @@ class UpdateModelMixin: serializer.is_valid(raise_exception=True) self.perform_update(serializer) - queryset = self.filter_queryset(self.get_queryset()) - if queryset._prefetch_related_lookups: + if getattr(instance, '_prefetched_objects_cache', None): # If 'prefetch_related' has been applied to a queryset, we need to - # forcibly invalidate the prefetch cache on the instance, - # and then re-prefetch related objects + # forcibly invalidate the prefetch cache on the instance. instance._prefetched_objects_cache = {} - prefetch_related_objects([instance], *queryset._prefetch_related_lookups) return Response(serializer.data) diff --git a/tests/test_prefetch_related.py b/tests/test_prefetch_related.py index 8e7bcf4ac..b07087c97 100644 --- a/tests/test_prefetch_related.py +++ b/tests/test_prefetch_related.py @@ -1,5 +1,4 @@ from django.contrib.auth.models import Group, User -from django.db.models.query import Prefetch from django.test import TestCase from rest_framework import generics, serializers @@ -9,84 +8,51 @@ factory = APIRequestFactory() class UserSerializer(serializers.ModelSerializer): - permissions = serializers.SerializerMethodField() - - def get_permissions(self, obj): - ret = [] - for g in obj.groups.all(): - ret.extend([p.pk for p in g.permissions.all()]) - return ret - class Meta: model = User - fields = ('id', 'username', 'email', 'groups', 'permissions') + fields = ('id', 'username', 'email', 'groups') -class UserRetrieveUpdate(generics.RetrieveUpdateAPIView): - queryset = User.objects.exclude(username='exclude').prefetch_related( - Prefetch('groups', queryset=Group.objects.exclude(name='exclude')), - 'groups__permissions', - ) - serializer_class = UserSerializer - - -class UserUpdateWithoutPrefetchRelated(generics.UpdateAPIView): - queryset = User.objects.exclude(username='exclude') +class UserUpdate(generics.UpdateAPIView): + queryset = User.objects.exclude(username='exclude').prefetch_related('groups') serializer_class = UserSerializer class TestPrefetchRelatedUpdates(TestCase): def setUp(self): self.user = User.objects.create(username='tom', email='tom@example.com') - self.groups = [Group.objects.create(name=f'group {i}') for i in range(10)] + self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')] self.user.groups.set(self.groups) - self.user.groups.add(Group.objects.create(name='exclude')) - self.expected = { - 'id': self.user.pk, - 'username': 'tom', - 'groups': [group.pk for group in self.groups], - 'email': 'tom@example.com', - 'permissions': [], - } - self.view = UserRetrieveUpdate.as_view() def test_prefetch_related_updates(self): - self.groups.append(Group.objects.create(name='c')) - request = factory.put( - '/', {'username': 'new', 'groups': [group.pk for group in self.groups]}, format='json' - ) - self.expected['username'] = 'new' - self.expected['groups'] = [group.pk for group in self.groups] - response = self.view(request, pk=self.user.pk) - assert User.objects.get(pk=self.user.pk).groups.count() == 12 - assert response.data == self.expected - # Update and fetch should get same result - request = factory.get('/') - response = self.view(request, pk=self.user.pk) - assert response.data == self.expected + view = UserUpdate.as_view() + pk = self.user.pk + groups_pk = self.groups[0].pk + request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json') + response = view(request, pk=pk) + assert User.objects.get(pk=pk).groups.count() == 1 + expected = { + 'id': pk, + 'username': 'new', + 'groups': [1], + 'email': 'tom@example.com' + } + assert response.data == expected def test_prefetch_related_excluding_instance_from_original_queryset(self): """ Regression test for https://github.com/encode/django-rest-framework/issues/4661 """ - request = factory.put( - '/', {'username': 'exclude', 'groups': [self.groups[0].pk]}, format='json' - ) - response = self.view(request, pk=self.user.pk) - assert User.objects.get(pk=self.user.pk).groups.count() == 2 - self.expected['username'] = 'exclude' - self.expected['groups'] = [self.groups[0].pk] - assert response.data == self.expected - - def test_db_query_count(self): - request = factory.put( - '/', {'username': 'new'}, format='json' - ) - with self.assertNumQueries(7): - self.view(request, pk=self.user.pk) - - request = factory.put( - '/', {'username': 'new2'}, format='json' - ) - with self.assertNumQueries(16): - UserUpdateWithoutPrefetchRelated.as_view()(request, pk=self.user.pk) + view = UserUpdate.as_view() + pk = self.user.pk + groups_pk = self.groups[0].pk + request = factory.put('/', {'username': 'exclude', 'groups': [groups_pk]}, format='json') + response = view(request, pk=pk) + assert User.objects.get(pk=pk).groups.count() == 1 + expected = { + 'id': pk, + 'username': 'exclude', + 'groups': [1], + 'email': 'tom@example.com' + } + assert response.data == expected