Re-prefetch related objects after updating

This commit is contained in:
Yuekui Li 2021-06-18 22:33:32 -07:00
parent 24a938abaa
commit af2c4a6297
3 changed files with 50 additions and 27 deletions

View File

@ -1,6 +1,7 @@
""" """
Generic views that provide commonly needed behaviour. Generic views that provide commonly needed behaviour.
""" """
from typing import Iterable
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
from django.http import Http404 from django.http import Http404
@ -45,6 +46,8 @@ class GenericAPIView(views.APIView):
# The style to use for queryset pagination. # The style to use for queryset pagination.
pagination_class = api_settings.DEFAULT_PAGINATION_CLASS pagination_class = api_settings.DEFAULT_PAGINATION_CLASS
prefetch_related = []
def get_queryset(self): def get_queryset(self):
""" """
Get the list of items for this view. Get the list of items for this view.
@ -68,10 +71,31 @@ class GenericAPIView(views.APIView):
queryset = self.queryset queryset = self.queryset
if isinstance(queryset, QuerySet): if isinstance(queryset, QuerySet):
# Prefetch related objects
if self.get_prefetch_related():
queryset = queryset.prefetch_related(*self.get_prefetch_related())
# Ensure queryset is re-evaluated on each request. # Ensure queryset is re-evaluated on each request.
queryset = queryset.all() queryset = queryset.all()
return queryset return queryset
def get_prefetch_related(self):
"""
Get the list of prefetch related objects for self.queryset or instance.
This must be an iterable.
Defaults to using `self.prefetch_related`.
You may want to override this if you need to provide prefetched objects
depending on the incoming request.
(Eg. `['toppings', Prefetch('restaurants', queryset=Restaurant.objects.select_related('best_pizza'))]`)
"""
assert isinstance(self.prefetch_related, Iterable), (
"'%s' should either include an iterable `prefetch_related` attribute, "
"or override the `get_prefetch_related()` method."
% self.__class__.__name__
)
return self.prefetch_related
def get_object(self): def get_object(self):
""" """
Returns the object the view is displaying. Returns the object the view is displaying.

View File

@ -4,6 +4,8 @@ Basic building blocks for generic class based views.
We don't bind behaviour to http method handlers yet, We don't bind behaviour to http method handlers yet,
which allows mixin classes to be composed in interesting ways. which allows mixin classes to be composed in interesting ways.
""" """
from django.db.models.query import prefetch_related_objects
from rest_framework import status from rest_framework import status
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
@ -69,8 +71,10 @@ class UpdateModelMixin:
if getattr(instance, '_prefetched_objects_cache', None): if getattr(instance, '_prefetched_objects_cache', None):
# If 'prefetch_related' has been applied to a queryset, we need to # If 'prefetch_related' has been applied to a queryset, we need to
# forcibly invalidate the prefetch cache on the instance. # forcibly invalidate the prefetch cache on the instance,
# and then re-prefetch related objects
instance._prefetched_objects_cache = {} instance._prefetched_objects_cache = {}
prefetch_related_objects([instance], *self.get_prefetch_related())
return Response(serializer.data) return Response(serializer.data)

View File

@ -14,8 +14,9 @@ class UserSerializer(serializers.ModelSerializer):
class UserUpdate(generics.UpdateAPIView): class UserUpdate(generics.UpdateAPIView):
queryset = User.objects.exclude(username='exclude').prefetch_related('groups') queryset = User.objects.exclude(username='exclude')
serializer_class = UserSerializer serializer_class = UserSerializer
prefetch_related = ['groups']
class TestPrefetchRelatedUpdates(TestCase): class TestPrefetchRelatedUpdates(TestCase):
@ -23,36 +24,30 @@ class TestPrefetchRelatedUpdates(TestCase):
self.user = User.objects.create(username='tom', email='tom@example.com') self.user = User.objects.create(username='tom', email='tom@example.com')
self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')] self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')]
self.user.groups.set(self.groups) self.user.groups.set(self.groups)
self.expected = {
def test_prefetch_related_updates(self): 'id': self.user.pk,
view = UserUpdate.as_view()
pk = self.user.pk
groups_pk = self.groups[0].pk
request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json')
response = view(request, pk=pk)
assert User.objects.get(pk=pk).groups.count() == 1
expected = {
'id': pk,
'username': 'new', 'username': 'new',
'groups': [1], 'groups': [1],
'email': 'tom@example.com' 'email': 'tom@example.com',
} }
assert response.data == expected self.view = UserUpdate.as_view()
def test_prefetch_related_updates(self):
request = factory.put(
'/', {'username': 'new', 'groups': [self.groups[0].pk]}, format='json'
)
response = self.view(request, pk=self.user.pk)
assert User.objects.get(pk=self.user.pk).groups.count() == 1
assert response.data == self.expected
def test_prefetch_related_excluding_instance_from_original_queryset(self): def test_prefetch_related_excluding_instance_from_original_queryset(self):
""" """
Regression test for https://github.com/encode/django-rest-framework/issues/4661 Regression test for https://github.com/encode/django-rest-framework/issues/4661
""" """
view = UserUpdate.as_view() request = factory.put(
pk = self.user.pk '/', {'username': 'exclude', 'groups': [self.groups[0].pk]}, format='json'
groups_pk = self.groups[0].pk )
request = factory.put('/', {'username': 'exclude', 'groups': [groups_pk]}, format='json') response = self.view(request, pk=self.user.pk)
response = view(request, pk=pk) assert User.objects.get(pk=self.user.pk).groups.count() == 1
assert User.objects.get(pk=pk).groups.count() == 1 self.expected['username'] = 'exclude'
expected = { assert response.data == self.expected
'id': pk,
'username': 'exclude',
'groups': [1],
'email': 'tom@example.com'
}
assert response.data == expected