Revert "Re-prefetch related objects after updating (#8043)" (#9327)

This reverts commit 2b34aa4291.
This commit is contained in:
Asif Saif Uddin 2024-03-22 04:23:30 +06:00 committed by GitHub
parent 0e4ed81627
commit da78a147f2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 32 additions and 71 deletions

View File

@ -4,8 +4,6 @@ Basic building blocks for generic class based views.
We don't bind behaviour to http method handlers yet, We don't bind behaviour to http method handlers yet,
which allows mixin classes to be composed in interesting ways. which allows mixin classes to be composed in interesting ways.
""" """
from django.db.models.query import prefetch_related_objects
from rest_framework import status from rest_framework import status
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
@ -69,13 +67,10 @@ class UpdateModelMixin:
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
self.perform_update(serializer) self.perform_update(serializer)
queryset = self.filter_queryset(self.get_queryset()) if getattr(instance, '_prefetched_objects_cache', None):
if queryset._prefetch_related_lookups:
# If 'prefetch_related' has been applied to a queryset, we need to # If 'prefetch_related' has been applied to a queryset, we need to
# forcibly invalidate the prefetch cache on the instance, # forcibly invalidate the prefetch cache on the instance.
# and then re-prefetch related objects
instance._prefetched_objects_cache = {} instance._prefetched_objects_cache = {}
prefetch_related_objects([instance], *queryset._prefetch_related_lookups)
return Response(serializer.data) return Response(serializer.data)

View File

@ -1,5 +1,4 @@
from django.contrib.auth.models import Group, User from django.contrib.auth.models import Group, User
from django.db.models.query import Prefetch
from django.test import TestCase from django.test import TestCase
from rest_framework import generics, serializers from rest_framework import generics, serializers
@ -9,84 +8,51 @@ factory = APIRequestFactory()
class UserSerializer(serializers.ModelSerializer): class UserSerializer(serializers.ModelSerializer):
permissions = serializers.SerializerMethodField()
def get_permissions(self, obj):
ret = []
for g in obj.groups.all():
ret.extend([p.pk for p in g.permissions.all()])
return ret
class Meta: class Meta:
model = User model = User
fields = ('id', 'username', 'email', 'groups', 'permissions') fields = ('id', 'username', 'email', 'groups')
class UserRetrieveUpdate(generics.RetrieveUpdateAPIView): class UserUpdate(generics.UpdateAPIView):
queryset = User.objects.exclude(username='exclude').prefetch_related( queryset = User.objects.exclude(username='exclude').prefetch_related('groups')
Prefetch('groups', queryset=Group.objects.exclude(name='exclude')),
'groups__permissions',
)
serializer_class = UserSerializer
class UserUpdateWithoutPrefetchRelated(generics.UpdateAPIView):
queryset = User.objects.exclude(username='exclude')
serializer_class = UserSerializer serializer_class = UserSerializer
class TestPrefetchRelatedUpdates(TestCase): class TestPrefetchRelatedUpdates(TestCase):
def setUp(self): def setUp(self):
self.user = User.objects.create(username='tom', email='tom@example.com') self.user = User.objects.create(username='tom', email='tom@example.com')
self.groups = [Group.objects.create(name=f'group {i}') for i in range(10)] self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')]
self.user.groups.set(self.groups) self.user.groups.set(self.groups)
self.user.groups.add(Group.objects.create(name='exclude'))
self.expected = {
'id': self.user.pk,
'username': 'tom',
'groups': [group.pk for group in self.groups],
'email': 'tom@example.com',
'permissions': [],
}
self.view = UserRetrieveUpdate.as_view()
def test_prefetch_related_updates(self): def test_prefetch_related_updates(self):
self.groups.append(Group.objects.create(name='c')) view = UserUpdate.as_view()
request = factory.put( pk = self.user.pk
'/', {'username': 'new', 'groups': [group.pk for group in self.groups]}, format='json' groups_pk = self.groups[0].pk
) request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json')
self.expected['username'] = 'new' response = view(request, pk=pk)
self.expected['groups'] = [group.pk for group in self.groups] assert User.objects.get(pk=pk).groups.count() == 1
response = self.view(request, pk=self.user.pk) expected = {
assert User.objects.get(pk=self.user.pk).groups.count() == 12 'id': pk,
assert response.data == self.expected 'username': 'new',
# Update and fetch should get same result 'groups': [1],
request = factory.get('/') 'email': 'tom@example.com'
response = self.view(request, pk=self.user.pk) }
assert response.data == self.expected assert response.data == expected
def test_prefetch_related_excluding_instance_from_original_queryset(self): def test_prefetch_related_excluding_instance_from_original_queryset(self):
""" """
Regression test for https://github.com/encode/django-rest-framework/issues/4661 Regression test for https://github.com/encode/django-rest-framework/issues/4661
""" """
request = factory.put( view = UserUpdate.as_view()
'/', {'username': 'exclude', 'groups': [self.groups[0].pk]}, format='json' pk = self.user.pk
) groups_pk = self.groups[0].pk
response = self.view(request, pk=self.user.pk) request = factory.put('/', {'username': 'exclude', 'groups': [groups_pk]}, format='json')
assert User.objects.get(pk=self.user.pk).groups.count() == 2 response = view(request, pk=pk)
self.expected['username'] = 'exclude' assert User.objects.get(pk=pk).groups.count() == 1
self.expected['groups'] = [self.groups[0].pk] expected = {
assert response.data == self.expected 'id': pk,
'username': 'exclude',
def test_db_query_count(self): 'groups': [1],
request = factory.put( 'email': 'tom@example.com'
'/', {'username': 'new'}, format='json' }
) assert response.data == expected
with self.assertNumQueries(7):
self.view(request, pk=self.user.pk)
request = factory.put(
'/', {'username': 'new2'}, format='json'
)
with self.assertNumQueries(16):
UserUpdateWithoutPrefetchRelated.as_view()(request, pk=self.user.pk)