From af2c4a6297a08cc8f62c4bbc5f36cc7b51fefe32 Mon Sep 17 00:00:00 2001 From: Yuekui Li Date: Fri, 18 Jun 2021 22:33:32 -0700 Subject: [PATCH] Re-prefetch related objects after updating --- rest_framework/generics.py | 24 +++++++++++++++++ rest_framework/mixins.py | 6 ++++- tests/test_prefetch_related.py | 47 +++++++++++++++------------------- 3 files changed, 50 insertions(+), 27 deletions(-) diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 55cfafda4..17290c5eb 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -1,6 +1,7 @@ """ Generic views that provide commonly needed behaviour. """ +from typing import Iterable from django.core.exceptions import ValidationError from django.db.models.query import QuerySet from django.http import Http404 @@ -45,6 +46,8 @@ class GenericAPIView(views.APIView): # The style to use for queryset pagination. pagination_class = api_settings.DEFAULT_PAGINATION_CLASS + prefetch_related = [] + def get_queryset(self): """ Get the list of items for this view. @@ -68,9 +71,30 @@ class GenericAPIView(views.APIView): queryset = self.queryset if isinstance(queryset, QuerySet): + # Prefetch related objects + if self.get_prefetch_related(): + queryset = queryset.prefetch_related(*self.get_prefetch_related()) # Ensure queryset is re-evaluated on each request. queryset = queryset.all() return queryset + + def get_prefetch_related(self): + """ + Get the list of prefetch related objects for self.queryset or instance. + This must be an iterable. + Defaults to using `self.prefetch_related`. + + You may want to override this if you need to provide prefetched objects + depending on the incoming request. + + (Eg. `['toppings', Prefetch('restaurants', queryset=Restaurant.objects.select_related('best_pizza'))]`) + """ + assert isinstance(self.prefetch_related, Iterable), ( + "'%s' should either include an iterable `prefetch_related` attribute, " + "or override the `get_prefetch_related()` method." + % self.__class__.__name__ + ) + return self.prefetch_related def get_object(self): """ diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 7fa8947cb..98127757f 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -4,6 +4,8 @@ Basic building blocks for generic class based views. We don't bind behaviour to http method handlers yet, which allows mixin classes to be composed in interesting ways. """ +from django.db.models.query import prefetch_related_objects + from rest_framework import status from rest_framework.response import Response from rest_framework.settings import api_settings @@ -69,8 +71,10 @@ class UpdateModelMixin: if getattr(instance, '_prefetched_objects_cache', None): # If 'prefetch_related' has been applied to a queryset, we need to - # forcibly invalidate the prefetch cache on the instance. + # forcibly invalidate the prefetch cache on the instance, + # and then re-prefetch related objects instance._prefetched_objects_cache = {} + prefetch_related_objects([instance], *self.get_prefetch_related()) return Response(serializer.data) diff --git a/tests/test_prefetch_related.py b/tests/test_prefetch_related.py index b07087c97..2f0064bf1 100644 --- a/tests/test_prefetch_related.py +++ b/tests/test_prefetch_related.py @@ -14,8 +14,9 @@ class UserSerializer(serializers.ModelSerializer): class UserUpdate(generics.UpdateAPIView): - queryset = User.objects.exclude(username='exclude').prefetch_related('groups') + queryset = User.objects.exclude(username='exclude') serializer_class = UserSerializer + prefetch_related = ['groups'] class TestPrefetchRelatedUpdates(TestCase): @@ -23,36 +24,30 @@ class TestPrefetchRelatedUpdates(TestCase): self.user = User.objects.create(username='tom', email='tom@example.com') self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')] self.user.groups.set(self.groups) - - def test_prefetch_related_updates(self): - view = UserUpdate.as_view() - pk = self.user.pk - groups_pk = self.groups[0].pk - request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json') - response = view(request, pk=pk) - assert User.objects.get(pk=pk).groups.count() == 1 - expected = { - 'id': pk, + self.expected = { + 'id': self.user.pk, 'username': 'new', 'groups': [1], - 'email': 'tom@example.com' + 'email': 'tom@example.com', } - assert response.data == expected + self.view = UserUpdate.as_view() + + def test_prefetch_related_updates(self): + request = factory.put( + '/', {'username': 'new', 'groups': [self.groups[0].pk]}, format='json' + ) + response = self.view(request, pk=self.user.pk) + assert User.objects.get(pk=self.user.pk).groups.count() == 1 + assert response.data == self.expected def test_prefetch_related_excluding_instance_from_original_queryset(self): """ Regression test for https://github.com/encode/django-rest-framework/issues/4661 """ - view = UserUpdate.as_view() - pk = self.user.pk - groups_pk = self.groups[0].pk - request = factory.put('/', {'username': 'exclude', 'groups': [groups_pk]}, format='json') - response = view(request, pk=pk) - assert User.objects.get(pk=pk).groups.count() == 1 - expected = { - 'id': pk, - 'username': 'exclude', - 'groups': [1], - 'email': 'tom@example.com' - } - assert response.data == expected + request = factory.put( + '/', {'username': 'exclude', 'groups': [self.groups[0].pk]}, format='json' + ) + response = self.view(request, pk=self.user.pk) + assert User.objects.get(pk=self.user.pk).groups.count() == 1 + self.expected['username'] = 'exclude' + assert response.data == self.expected