Use _prefetch_related_lookups and refine test cases

This commit is contained in:
Yuekui Li 2021-07-07 18:53:46 -07:00
parent 7f24ef2af6
commit 2da4374ee1
3 changed files with 39 additions and 33 deletions

View File

@ -1,8 +1,6 @@
"""
Generic views that provide commonly needed behaviour.
"""
from typing import Iterable
from django.core.exceptions import ValidationError
from django.db.models.query import QuerySet
from django.http import Http404
@ -47,8 +45,6 @@ class GenericAPIView(views.APIView):
# The style to use for queryset pagination.
pagination_class = api_settings.DEFAULT_PAGINATION_CLASS
prefetch_related = []
def get_queryset(self):
"""
Get the list of items for this view.
@ -72,31 +68,10 @@ class GenericAPIView(views.APIView):
queryset = self.queryset
if isinstance(queryset, QuerySet):
# Prefetch related objects
if self.get_prefetch_related():
queryset = queryset.prefetch_related(*self.get_prefetch_related())
# Ensure queryset is re-evaluated on each request.
queryset = queryset.all()
return queryset
def get_prefetch_related(self):
"""
Get the list of prefetch related objects for self.queryset or instance.
This must be an iterable.
Defaults to using `self.prefetch_related`.
You may want to override this if you need to provide prefetched objects
depending on the incoming request.
(Eg. `['toppings', Prefetch('restaurants', queryset=Restaurant.objects.select_related('best_pizza'))]`)
"""
assert isinstance(self.prefetch_related, Iterable), (
"'%s' should either include an iterable `prefetch_related` attribute, "
"or override the `get_prefetch_related()` method."
% self.__class__.__name__
)
return self.prefetch_related
def get_object(self):
"""
Returns the object the view is displaying.

View File

@ -74,7 +74,8 @@ class UpdateModelMixin:
# forcibly invalidate the prefetch cache on the instance,
# and then re-prefetch related objects
instance._prefetched_objects_cache = {}
prefetch_related_objects([instance], *self.get_prefetch_related())
queryset = self.filter_queryset(self.get_queryset())
prefetch_related_objects([instance], *queryset._prefetch_related_lookups)
return Response(serializer.data)

View File

@ -8,36 +8,52 @@ factory = APIRequestFactory()
class UserSerializer(serializers.ModelSerializer):
permissions = serializers.SerializerMethodField()
def get_permissions(self, obj):
ret = []
for g in obj.groups.all():
ret.extend([p.pk for p in g.permissions.all()])
return ret
class Meta:
model = User
fields = ('id', 'username', 'email', 'groups')
fields = ('id', 'username', 'email', 'groups', 'permissions')
class UserUpdate(generics.UpdateAPIView):
queryset = User.objects.exclude(username='exclude').prefetch_related('groups__permissions')
serializer_class = UserSerializer
class UserUpdateWithoutPrefetchRelated(generics.UpdateAPIView):
queryset = User.objects.exclude(username='exclude')
serializer_class = UserSerializer
prefetch_related = ['groups']
class TestPrefetchRelatedUpdates(TestCase):
def setUp(self):
self.user = User.objects.create(username='tom', email='tom@example.com')
self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')]
self.groups = [Group.objects.create(name=f'group {i}') for i in range(10)]
self.user.groups.set(self.groups)
self.expected = {
'id': self.user.pk,
'username': 'new',
'groups': [1],
'username': 'tom',
'groups': [group.pk for group in self.groups],
'email': 'tom@example.com',
'permissions': [],
}
self.view = UserUpdate.as_view()
def test_prefetch_related_updates(self):
self.groups.append(Group.objects.create(name='c'))
request = factory.put(
'/', {'username': 'new', 'groups': [self.groups[0].pk]}, format='json'
'/', {'username': 'new', 'groups': [group.pk for group in self.groups]}, format='json'
)
self.expected['username'] = 'new'
self.expected['groups'] = [group.pk for group in self.groups]
response = self.view(request, pk=self.user.pk)
assert User.objects.get(pk=self.user.pk).groups.count() == 1
assert User.objects.get(pk=self.user.pk).groups.count() == 11
assert response.data == self.expected
def test_prefetch_related_excluding_instance_from_original_queryset(self):
@ -50,4 +66,18 @@ class TestPrefetchRelatedUpdates(TestCase):
response = self.view(request, pk=self.user.pk)
assert User.objects.get(pk=self.user.pk).groups.count() == 1
self.expected['username'] = 'exclude'
self.expected['groups'] = [self.groups[0].pk]
assert response.data == self.expected
def test_db_query_count(self):
request = factory.put(
'/', {'username': 'new'}, format='json'
)
with self.assertNumQueries(7):
self.view(request, pk=self.user.pk)
request = factory.put(
'/', {'username': 'new2'}, format='json'
)
with self.assertNumQueries(15):
UserUpdateWithoutPrefetchRelated.as_view()(request, pk=self.user.pk)