From 2da4374ee10a1544cad73d779e78f40b513cada1 Mon Sep 17 00:00:00 2001 From: Yuekui Li Date: Wed, 7 Jul 2021 18:53:46 -0700 Subject: [PATCH] Use _prefetch_related_lookups and refine test cases --- rest_framework/generics.py | 25 ------------------- rest_framework/mixins.py | 3 ++- tests/test_prefetch_related.py | 44 ++++++++++++++++++++++++++++------ 3 files changed, 39 insertions(+), 33 deletions(-) diff --git a/rest_framework/generics.py b/rest_framework/generics.py index e42ca529c..55cfafda4 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -1,8 +1,6 @@ """ Generic views that provide commonly needed behaviour. """ -from typing import Iterable - from django.core.exceptions import ValidationError from django.db.models.query import QuerySet from django.http import Http404 @@ -47,8 +45,6 @@ class GenericAPIView(views.APIView): # The style to use for queryset pagination. pagination_class = api_settings.DEFAULT_PAGINATION_CLASS - prefetch_related = [] - def get_queryset(self): """ Get the list of items for this view. @@ -72,31 +68,10 @@ class GenericAPIView(views.APIView): queryset = self.queryset if isinstance(queryset, QuerySet): - # Prefetch related objects - if self.get_prefetch_related(): - queryset = queryset.prefetch_related(*self.get_prefetch_related()) # Ensure queryset is re-evaluated on each request. queryset = queryset.all() return queryset - def get_prefetch_related(self): - """ - Get the list of prefetch related objects for self.queryset or instance. - This must be an iterable. - Defaults to using `self.prefetch_related`. - - You may want to override this if you need to provide prefetched objects - depending on the incoming request. - - (Eg. `['toppings', Prefetch('restaurants', queryset=Restaurant.objects.select_related('best_pizza'))]`) - """ - assert isinstance(self.prefetch_related, Iterable), ( - "'%s' should either include an iterable `prefetch_related` attribute, " - "or override the `get_prefetch_related()` method." - % self.__class__.__name__ - ) - return self.prefetch_related - def get_object(self): """ Returns the object the view is displaying. diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 98127757f..6031b06ad 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -74,7 +74,8 @@ class UpdateModelMixin: # forcibly invalidate the prefetch cache on the instance, # and then re-prefetch related objects instance._prefetched_objects_cache = {} - prefetch_related_objects([instance], *self.get_prefetch_related()) + queryset = self.filter_queryset(self.get_queryset()) + prefetch_related_objects([instance], *queryset._prefetch_related_lookups) return Response(serializer.data) diff --git a/tests/test_prefetch_related.py b/tests/test_prefetch_related.py index 2f0064bf1..d38066690 100644 --- a/tests/test_prefetch_related.py +++ b/tests/test_prefetch_related.py @@ -8,36 +8,52 @@ factory = APIRequestFactory() class UserSerializer(serializers.ModelSerializer): + permissions = serializers.SerializerMethodField() + + def get_permissions(self, obj): + ret = [] + for g in obj.groups.all(): + ret.extend([p.pk for p in g.permissions.all()]) + return ret + class Meta: model = User - fields = ('id', 'username', 'email', 'groups') + fields = ('id', 'username', 'email', 'groups', 'permissions') class UserUpdate(generics.UpdateAPIView): + queryset = User.objects.exclude(username='exclude').prefetch_related('groups__permissions') + serializer_class = UserSerializer + + +class UserUpdateWithoutPrefetchRelated(generics.UpdateAPIView): queryset = User.objects.exclude(username='exclude') serializer_class = UserSerializer - prefetch_related = ['groups'] class TestPrefetchRelatedUpdates(TestCase): def setUp(self): self.user = User.objects.create(username='tom', email='tom@example.com') - self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')] + self.groups = [Group.objects.create(name=f'group {i}') for i in range(10)] self.user.groups.set(self.groups) self.expected = { 'id': self.user.pk, - 'username': 'new', - 'groups': [1], + 'username': 'tom', + 'groups': [group.pk for group in self.groups], 'email': 'tom@example.com', + 'permissions': [], } self.view = UserUpdate.as_view() def test_prefetch_related_updates(self): + self.groups.append(Group.objects.create(name='c')) request = factory.put( - '/', {'username': 'new', 'groups': [self.groups[0].pk]}, format='json' + '/', {'username': 'new', 'groups': [group.pk for group in self.groups]}, format='json' ) + self.expected['username'] = 'new' + self.expected['groups'] = [group.pk for group in self.groups] response = self.view(request, pk=self.user.pk) - assert User.objects.get(pk=self.user.pk).groups.count() == 1 + assert User.objects.get(pk=self.user.pk).groups.count() == 11 assert response.data == self.expected def test_prefetch_related_excluding_instance_from_original_queryset(self): @@ -50,4 +66,18 @@ class TestPrefetchRelatedUpdates(TestCase): response = self.view(request, pk=self.user.pk) assert User.objects.get(pk=self.user.pk).groups.count() == 1 self.expected['username'] = 'exclude' + self.expected['groups'] = [self.groups[0].pk] assert response.data == self.expected + + def test_db_query_count(self): + request = factory.put( + '/', {'username': 'new'}, format='json' + ) + with self.assertNumQueries(7): + self.view(request, pk=self.user.pk) + + request = factory.put( + '/', {'username': 'new2'}, format='json' + ) + with self.assertNumQueries(15): + UserUpdateWithoutPrefetchRelated.as_view()(request, pk=self.user.pk)