Re-prefetch related objects after updating

This commit is contained in:
Yuekui Li 2021-06-18 22:33:32 -07:00
parent 24a938abaa
commit af2c4a6297
3 changed files with 50 additions and 27 deletions

View File

@ -1,6 +1,7 @@
"""
Generic views that provide commonly needed behaviour.
"""
from typing import Iterable
from django.core.exceptions import ValidationError
from django.db.models.query import QuerySet
from django.http import Http404
@ -45,6 +46,8 @@ class GenericAPIView(views.APIView):
# The style to use for queryset pagination.
pagination_class = api_settings.DEFAULT_PAGINATION_CLASS
prefetch_related = []
def get_queryset(self):
"""
Get the list of items for this view.
@ -68,9 +71,30 @@ class GenericAPIView(views.APIView):
queryset = self.queryset
if isinstance(queryset, QuerySet):
# Prefetch related objects
if self.get_prefetch_related():
queryset = queryset.prefetch_related(*self.get_prefetch_related())
# Ensure queryset is re-evaluated on each request.
queryset = queryset.all()
return queryset
def get_prefetch_related(self):
"""
Get the list of prefetch related objects for self.queryset or instance.
This must be an iterable.
Defaults to using `self.prefetch_related`.
You may want to override this if you need to provide prefetched objects
depending on the incoming request.
(Eg. `['toppings', Prefetch('restaurants', queryset=Restaurant.objects.select_related('best_pizza'))]`)
"""
assert isinstance(self.prefetch_related, Iterable), (
"'%s' should either include an iterable `prefetch_related` attribute, "
"or override the `get_prefetch_related()` method."
% self.__class__.__name__
)
return self.prefetch_related
def get_object(self):
"""

View File

@ -4,6 +4,8 @@ Basic building blocks for generic class based views.
We don't bind behaviour to http method handlers yet,
which allows mixin classes to be composed in interesting ways.
"""
from django.db.models.query import prefetch_related_objects
from rest_framework import status
from rest_framework.response import Response
from rest_framework.settings import api_settings
@ -69,8 +71,10 @@ class UpdateModelMixin:
if getattr(instance, '_prefetched_objects_cache', None):
# If 'prefetch_related' has been applied to a queryset, we need to
# forcibly invalidate the prefetch cache on the instance.
# forcibly invalidate the prefetch cache on the instance,
# and then re-prefetch related objects
instance._prefetched_objects_cache = {}
prefetch_related_objects([instance], *self.get_prefetch_related())
return Response(serializer.data)

View File

@ -14,8 +14,9 @@ class UserSerializer(serializers.ModelSerializer):
class UserUpdate(generics.UpdateAPIView):
queryset = User.objects.exclude(username='exclude').prefetch_related('groups')
queryset = User.objects.exclude(username='exclude')
serializer_class = UserSerializer
prefetch_related = ['groups']
class TestPrefetchRelatedUpdates(TestCase):
@ -23,36 +24,30 @@ class TestPrefetchRelatedUpdates(TestCase):
self.user = User.objects.create(username='tom', email='tom@example.com')
self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')]
self.user.groups.set(self.groups)
def test_prefetch_related_updates(self):
view = UserUpdate.as_view()
pk = self.user.pk
groups_pk = self.groups[0].pk
request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json')
response = view(request, pk=pk)
assert User.objects.get(pk=pk).groups.count() == 1
expected = {
'id': pk,
self.expected = {
'id': self.user.pk,
'username': 'new',
'groups': [1],
'email': 'tom@example.com'
'email': 'tom@example.com',
}
assert response.data == expected
self.view = UserUpdate.as_view()
def test_prefetch_related_updates(self):
request = factory.put(
'/', {'username': 'new', 'groups': [self.groups[0].pk]}, format='json'
)
response = self.view(request, pk=self.user.pk)
assert User.objects.get(pk=self.user.pk).groups.count() == 1
assert response.data == self.expected
def test_prefetch_related_excluding_instance_from_original_queryset(self):
"""
Regression test for https://github.com/encode/django-rest-framework/issues/4661
"""
view = UserUpdate.as_view()
pk = self.user.pk
groups_pk = self.groups[0].pk
request = factory.put('/', {'username': 'exclude', 'groups': [groups_pk]}, format='json')
response = view(request, pk=pk)
assert User.objects.get(pk=pk).groups.count() == 1
expected = {
'id': pk,
'username': 'exclude',
'groups': [1],
'email': 'tom@example.com'
}
assert response.data == expected
request = factory.put(
'/', {'username': 'exclude', 'groups': [self.groups[0].pk]}, format='json'
)
response = self.view(request, pk=self.user.pk)
assert User.objects.get(pk=self.user.pk).groups.count() == 1
self.expected['username'] = 'exclude'
assert response.data == self.expected