mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-07-27 08:29:59 +03:00
Re-prefetch related objects after updating
This commit is contained in:
parent
24a938abaa
commit
af2c4a6297
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Generic views that provide commonly needed behaviour.
|
||||
"""
|
||||
from typing import Iterable
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db.models.query import QuerySet
|
||||
from django.http import Http404
|
||||
|
@ -45,6 +46,8 @@ class GenericAPIView(views.APIView):
|
|||
# The style to use for queryset pagination.
|
||||
pagination_class = api_settings.DEFAULT_PAGINATION_CLASS
|
||||
|
||||
prefetch_related = []
|
||||
|
||||
def get_queryset(self):
|
||||
"""
|
||||
Get the list of items for this view.
|
||||
|
@ -68,10 +71,31 @@ class GenericAPIView(views.APIView):
|
|||
|
||||
queryset = self.queryset
|
||||
if isinstance(queryset, QuerySet):
|
||||
# Prefetch related objects
|
||||
if self.get_prefetch_related():
|
||||
queryset = queryset.prefetch_related(*self.get_prefetch_related())
|
||||
# Ensure queryset is re-evaluated on each request.
|
||||
queryset = queryset.all()
|
||||
return queryset
|
||||
|
||||
def get_prefetch_related(self):
|
||||
"""
|
||||
Get the list of prefetch related objects for self.queryset or instance.
|
||||
This must be an iterable.
|
||||
Defaults to using `self.prefetch_related`.
|
||||
|
||||
You may want to override this if you need to provide prefetched objects
|
||||
depending on the incoming request.
|
||||
|
||||
(Eg. `['toppings', Prefetch('restaurants', queryset=Restaurant.objects.select_related('best_pizza'))]`)
|
||||
"""
|
||||
assert isinstance(self.prefetch_related, Iterable), (
|
||||
"'%s' should either include an iterable `prefetch_related` attribute, "
|
||||
"or override the `get_prefetch_related()` method."
|
||||
% self.__class__.__name__
|
||||
)
|
||||
return self.prefetch_related
|
||||
|
||||
def get_object(self):
|
||||
"""
|
||||
Returns the object the view is displaying.
|
||||
|
|
|
@ -4,6 +4,8 @@ Basic building blocks for generic class based views.
|
|||
We don't bind behaviour to http method handlers yet,
|
||||
which allows mixin classes to be composed in interesting ways.
|
||||
"""
|
||||
from django.db.models.query import prefetch_related_objects
|
||||
|
||||
from rest_framework import status
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.settings import api_settings
|
||||
|
@ -69,8 +71,10 @@ class UpdateModelMixin:
|
|||
|
||||
if getattr(instance, '_prefetched_objects_cache', None):
|
||||
# If 'prefetch_related' has been applied to a queryset, we need to
|
||||
# forcibly invalidate the prefetch cache on the instance.
|
||||
# forcibly invalidate the prefetch cache on the instance,
|
||||
# and then re-prefetch related objects
|
||||
instance._prefetched_objects_cache = {}
|
||||
prefetch_related_objects([instance], *self.get_prefetch_related())
|
||||
|
||||
return Response(serializer.data)
|
||||
|
||||
|
|
|
@ -14,8 +14,9 @@ class UserSerializer(serializers.ModelSerializer):
|
|||
|
||||
|
||||
class UserUpdate(generics.UpdateAPIView):
|
||||
queryset = User.objects.exclude(username='exclude').prefetch_related('groups')
|
||||
queryset = User.objects.exclude(username='exclude')
|
||||
serializer_class = UserSerializer
|
||||
prefetch_related = ['groups']
|
||||
|
||||
|
||||
class TestPrefetchRelatedUpdates(TestCase):
|
||||
|
@ -23,36 +24,30 @@ class TestPrefetchRelatedUpdates(TestCase):
|
|||
self.user = User.objects.create(username='tom', email='tom@example.com')
|
||||
self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')]
|
||||
self.user.groups.set(self.groups)
|
||||
|
||||
def test_prefetch_related_updates(self):
|
||||
view = UserUpdate.as_view()
|
||||
pk = self.user.pk
|
||||
groups_pk = self.groups[0].pk
|
||||
request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json')
|
||||
response = view(request, pk=pk)
|
||||
assert User.objects.get(pk=pk).groups.count() == 1
|
||||
expected = {
|
||||
'id': pk,
|
||||
self.expected = {
|
||||
'id': self.user.pk,
|
||||
'username': 'new',
|
||||
'groups': [1],
|
||||
'email': 'tom@example.com'
|
||||
'email': 'tom@example.com',
|
||||
}
|
||||
assert response.data == expected
|
||||
self.view = UserUpdate.as_view()
|
||||
|
||||
def test_prefetch_related_updates(self):
|
||||
request = factory.put(
|
||||
'/', {'username': 'new', 'groups': [self.groups[0].pk]}, format='json'
|
||||
)
|
||||
response = self.view(request, pk=self.user.pk)
|
||||
assert User.objects.get(pk=self.user.pk).groups.count() == 1
|
||||
assert response.data == self.expected
|
||||
|
||||
def test_prefetch_related_excluding_instance_from_original_queryset(self):
|
||||
"""
|
||||
Regression test for https://github.com/encode/django-rest-framework/issues/4661
|
||||
"""
|
||||
view = UserUpdate.as_view()
|
||||
pk = self.user.pk
|
||||
groups_pk = self.groups[0].pk
|
||||
request = factory.put('/', {'username': 'exclude', 'groups': [groups_pk]}, format='json')
|
||||
response = view(request, pk=pk)
|
||||
assert User.objects.get(pk=pk).groups.count() == 1
|
||||
expected = {
|
||||
'id': pk,
|
||||
'username': 'exclude',
|
||||
'groups': [1],
|
||||
'email': 'tom@example.com'
|
||||
}
|
||||
assert response.data == expected
|
||||
request = factory.put(
|
||||
'/', {'username': 'exclude', 'groups': [self.groups[0].pk]}, format='json'
|
||||
)
|
||||
response = self.view(request, pk=self.user.pk)
|
||||
assert User.objects.get(pk=self.user.pk).groups.count() == 1
|
||||
self.expected['username'] = 'exclude'
|
||||
assert response.data == self.expected
|
||||
|
|
Loading…
Reference in New Issue
Block a user