Use _prefetch_related_lookups and refine test cases

This commit is contained in:
Yuekui Li 2021-07-07 18:53:46 -07:00
parent 7f24ef2af6
commit 2da4374ee1
3 changed files with 39 additions and 33 deletions

View File

@ -1,8 +1,6 @@
""" """
Generic views that provide commonly needed behaviour. Generic views that provide commonly needed behaviour.
""" """
from typing import Iterable
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
from django.http import Http404 from django.http import Http404
@ -47,8 +45,6 @@ class GenericAPIView(views.APIView):
# The style to use for queryset pagination. # The style to use for queryset pagination.
pagination_class = api_settings.DEFAULT_PAGINATION_CLASS pagination_class = api_settings.DEFAULT_PAGINATION_CLASS
prefetch_related = []
def get_queryset(self): def get_queryset(self):
""" """
Get the list of items for this view. Get the list of items for this view.
@ -72,31 +68,10 @@ class GenericAPIView(views.APIView):
queryset = self.queryset queryset = self.queryset
if isinstance(queryset, QuerySet): if isinstance(queryset, QuerySet):
# Prefetch related objects
if self.get_prefetch_related():
queryset = queryset.prefetch_related(*self.get_prefetch_related())
# Ensure queryset is re-evaluated on each request. # Ensure queryset is re-evaluated on each request.
queryset = queryset.all() queryset = queryset.all()
return queryset return queryset
def get_prefetch_related(self):
"""
Get the list of prefetch related objects for self.queryset or instance.
This must be an iterable.
Defaults to using `self.prefetch_related`.
You may want to override this if you need to provide prefetched objects
depending on the incoming request.
(Eg. `['toppings', Prefetch('restaurants', queryset=Restaurant.objects.select_related('best_pizza'))]`)
"""
assert isinstance(self.prefetch_related, Iterable), (
"'%s' should either include an iterable `prefetch_related` attribute, "
"or override the `get_prefetch_related()` method."
% self.__class__.__name__
)
return self.prefetch_related
def get_object(self): def get_object(self):
""" """
Returns the object the view is displaying. Returns the object the view is displaying.

View File

@ -74,7 +74,8 @@ class UpdateModelMixin:
# forcibly invalidate the prefetch cache on the instance, # forcibly invalidate the prefetch cache on the instance,
# and then re-prefetch related objects # and then re-prefetch related objects
instance._prefetched_objects_cache = {} instance._prefetched_objects_cache = {}
prefetch_related_objects([instance], *self.get_prefetch_related()) queryset = self.filter_queryset(self.get_queryset())
prefetch_related_objects([instance], *queryset._prefetch_related_lookups)
return Response(serializer.data) return Response(serializer.data)

View File

@ -8,36 +8,52 @@ factory = APIRequestFactory()
class UserSerializer(serializers.ModelSerializer): class UserSerializer(serializers.ModelSerializer):
permissions = serializers.SerializerMethodField()
def get_permissions(self, obj):
ret = []
for g in obj.groups.all():
ret.extend([p.pk for p in g.permissions.all()])
return ret
class Meta: class Meta:
model = User model = User
fields = ('id', 'username', 'email', 'groups') fields = ('id', 'username', 'email', 'groups', 'permissions')
class UserUpdate(generics.UpdateAPIView): class UserUpdate(generics.UpdateAPIView):
queryset = User.objects.exclude(username='exclude').prefetch_related('groups__permissions')
serializer_class = UserSerializer
class UserUpdateWithoutPrefetchRelated(generics.UpdateAPIView):
queryset = User.objects.exclude(username='exclude') queryset = User.objects.exclude(username='exclude')
serializer_class = UserSerializer serializer_class = UserSerializer
prefetch_related = ['groups']
class TestPrefetchRelatedUpdates(TestCase): class TestPrefetchRelatedUpdates(TestCase):
def setUp(self): def setUp(self):
self.user = User.objects.create(username='tom', email='tom@example.com') self.user = User.objects.create(username='tom', email='tom@example.com')
self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')] self.groups = [Group.objects.create(name=f'group {i}') for i in range(10)]
self.user.groups.set(self.groups) self.user.groups.set(self.groups)
self.expected = { self.expected = {
'id': self.user.pk, 'id': self.user.pk,
'username': 'new', 'username': 'tom',
'groups': [1], 'groups': [group.pk for group in self.groups],
'email': 'tom@example.com', 'email': 'tom@example.com',
'permissions': [],
} }
self.view = UserUpdate.as_view() self.view = UserUpdate.as_view()
def test_prefetch_related_updates(self): def test_prefetch_related_updates(self):
self.groups.append(Group.objects.create(name='c'))
request = factory.put( request = factory.put(
'/', {'username': 'new', 'groups': [self.groups[0].pk]}, format='json' '/', {'username': 'new', 'groups': [group.pk for group in self.groups]}, format='json'
) )
self.expected['username'] = 'new'
self.expected['groups'] = [group.pk for group in self.groups]
response = self.view(request, pk=self.user.pk) response = self.view(request, pk=self.user.pk)
assert User.objects.get(pk=self.user.pk).groups.count() == 1 assert User.objects.get(pk=self.user.pk).groups.count() == 11
assert response.data == self.expected assert response.data == self.expected
def test_prefetch_related_excluding_instance_from_original_queryset(self): def test_prefetch_related_excluding_instance_from_original_queryset(self):
@ -50,4 +66,18 @@ class TestPrefetchRelatedUpdates(TestCase):
response = self.view(request, pk=self.user.pk) response = self.view(request, pk=self.user.pk)
assert User.objects.get(pk=self.user.pk).groups.count() == 1 assert User.objects.get(pk=self.user.pk).groups.count() == 1
self.expected['username'] = 'exclude' self.expected['username'] = 'exclude'
self.expected['groups'] = [self.groups[0].pk]
assert response.data == self.expected assert response.data == self.expected
def test_db_query_count(self):
request = factory.put(
'/', {'username': 'new'}, format='json'
)
with self.assertNumQueries(7):
self.view(request, pk=self.user.pk)
request = factory.put(
'/', {'username': 'new2'}, format='json'
)
with self.assertNumQueries(15):
UserUpdateWithoutPrefetchRelated.as_view()(request, pk=self.user.pk)