mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-07-27 08:29:59 +03:00
Re-prefetch related objects after updating
This commit is contained in:
parent
24a938abaa
commit
af2c4a6297
|
@ -1,6 +1,7 @@
|
||||||
"""
|
"""
|
||||||
Generic views that provide commonly needed behaviour.
|
Generic views that provide commonly needed behaviour.
|
||||||
"""
|
"""
|
||||||
|
from typing import Iterable
|
||||||
from django.core.exceptions import ValidationError
|
from django.core.exceptions import ValidationError
|
||||||
from django.db.models.query import QuerySet
|
from django.db.models.query import QuerySet
|
||||||
from django.http import Http404
|
from django.http import Http404
|
||||||
|
@ -45,6 +46,8 @@ class GenericAPIView(views.APIView):
|
||||||
# The style to use for queryset pagination.
|
# The style to use for queryset pagination.
|
||||||
pagination_class = api_settings.DEFAULT_PAGINATION_CLASS
|
pagination_class = api_settings.DEFAULT_PAGINATION_CLASS
|
||||||
|
|
||||||
|
prefetch_related = []
|
||||||
|
|
||||||
def get_queryset(self):
|
def get_queryset(self):
|
||||||
"""
|
"""
|
||||||
Get the list of items for this view.
|
Get the list of items for this view.
|
||||||
|
@ -68,9 +71,30 @@ class GenericAPIView(views.APIView):
|
||||||
|
|
||||||
queryset = self.queryset
|
queryset = self.queryset
|
||||||
if isinstance(queryset, QuerySet):
|
if isinstance(queryset, QuerySet):
|
||||||
|
# Prefetch related objects
|
||||||
|
if self.get_prefetch_related():
|
||||||
|
queryset = queryset.prefetch_related(*self.get_prefetch_related())
|
||||||
# Ensure queryset is re-evaluated on each request.
|
# Ensure queryset is re-evaluated on each request.
|
||||||
queryset = queryset.all()
|
queryset = queryset.all()
|
||||||
return queryset
|
return queryset
|
||||||
|
|
||||||
|
def get_prefetch_related(self):
|
||||||
|
"""
|
||||||
|
Get the list of prefetch related objects for self.queryset or instance.
|
||||||
|
This must be an iterable.
|
||||||
|
Defaults to using `self.prefetch_related`.
|
||||||
|
|
||||||
|
You may want to override this if you need to provide prefetched objects
|
||||||
|
depending on the incoming request.
|
||||||
|
|
||||||
|
(Eg. `['toppings', Prefetch('restaurants', queryset=Restaurant.objects.select_related('best_pizza'))]`)
|
||||||
|
"""
|
||||||
|
assert isinstance(self.prefetch_related, Iterable), (
|
||||||
|
"'%s' should either include an iterable `prefetch_related` attribute, "
|
||||||
|
"or override the `get_prefetch_related()` method."
|
||||||
|
% self.__class__.__name__
|
||||||
|
)
|
||||||
|
return self.prefetch_related
|
||||||
|
|
||||||
def get_object(self):
|
def get_object(self):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -4,6 +4,8 @@ Basic building blocks for generic class based views.
|
||||||
We don't bind behaviour to http method handlers yet,
|
We don't bind behaviour to http method handlers yet,
|
||||||
which allows mixin classes to be composed in interesting ways.
|
which allows mixin classes to be composed in interesting ways.
|
||||||
"""
|
"""
|
||||||
|
from django.db.models.query import prefetch_related_objects
|
||||||
|
|
||||||
from rest_framework import status
|
from rest_framework import status
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework.settings import api_settings
|
from rest_framework.settings import api_settings
|
||||||
|
@ -69,8 +71,10 @@ class UpdateModelMixin:
|
||||||
|
|
||||||
if getattr(instance, '_prefetched_objects_cache', None):
|
if getattr(instance, '_prefetched_objects_cache', None):
|
||||||
# If 'prefetch_related' has been applied to a queryset, we need to
|
# If 'prefetch_related' has been applied to a queryset, we need to
|
||||||
# forcibly invalidate the prefetch cache on the instance.
|
# forcibly invalidate the prefetch cache on the instance,
|
||||||
|
# and then re-prefetch related objects
|
||||||
instance._prefetched_objects_cache = {}
|
instance._prefetched_objects_cache = {}
|
||||||
|
prefetch_related_objects([instance], *self.get_prefetch_related())
|
||||||
|
|
||||||
return Response(serializer.data)
|
return Response(serializer.data)
|
||||||
|
|
||||||
|
|
|
@ -14,8 +14,9 @@ class UserSerializer(serializers.ModelSerializer):
|
||||||
|
|
||||||
|
|
||||||
class UserUpdate(generics.UpdateAPIView):
|
class UserUpdate(generics.UpdateAPIView):
|
||||||
queryset = User.objects.exclude(username='exclude').prefetch_related('groups')
|
queryset = User.objects.exclude(username='exclude')
|
||||||
serializer_class = UserSerializer
|
serializer_class = UserSerializer
|
||||||
|
prefetch_related = ['groups']
|
||||||
|
|
||||||
|
|
||||||
class TestPrefetchRelatedUpdates(TestCase):
|
class TestPrefetchRelatedUpdates(TestCase):
|
||||||
|
@ -23,36 +24,30 @@ class TestPrefetchRelatedUpdates(TestCase):
|
||||||
self.user = User.objects.create(username='tom', email='tom@example.com')
|
self.user = User.objects.create(username='tom', email='tom@example.com')
|
||||||
self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')]
|
self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')]
|
||||||
self.user.groups.set(self.groups)
|
self.user.groups.set(self.groups)
|
||||||
|
self.expected = {
|
||||||
def test_prefetch_related_updates(self):
|
'id': self.user.pk,
|
||||||
view = UserUpdate.as_view()
|
|
||||||
pk = self.user.pk
|
|
||||||
groups_pk = self.groups[0].pk
|
|
||||||
request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json')
|
|
||||||
response = view(request, pk=pk)
|
|
||||||
assert User.objects.get(pk=pk).groups.count() == 1
|
|
||||||
expected = {
|
|
||||||
'id': pk,
|
|
||||||
'username': 'new',
|
'username': 'new',
|
||||||
'groups': [1],
|
'groups': [1],
|
||||||
'email': 'tom@example.com'
|
'email': 'tom@example.com',
|
||||||
}
|
}
|
||||||
assert response.data == expected
|
self.view = UserUpdate.as_view()
|
||||||
|
|
||||||
|
def test_prefetch_related_updates(self):
|
||||||
|
request = factory.put(
|
||||||
|
'/', {'username': 'new', 'groups': [self.groups[0].pk]}, format='json'
|
||||||
|
)
|
||||||
|
response = self.view(request, pk=self.user.pk)
|
||||||
|
assert User.objects.get(pk=self.user.pk).groups.count() == 1
|
||||||
|
assert response.data == self.expected
|
||||||
|
|
||||||
def test_prefetch_related_excluding_instance_from_original_queryset(self):
|
def test_prefetch_related_excluding_instance_from_original_queryset(self):
|
||||||
"""
|
"""
|
||||||
Regression test for https://github.com/encode/django-rest-framework/issues/4661
|
Regression test for https://github.com/encode/django-rest-framework/issues/4661
|
||||||
"""
|
"""
|
||||||
view = UserUpdate.as_view()
|
request = factory.put(
|
||||||
pk = self.user.pk
|
'/', {'username': 'exclude', 'groups': [self.groups[0].pk]}, format='json'
|
||||||
groups_pk = self.groups[0].pk
|
)
|
||||||
request = factory.put('/', {'username': 'exclude', 'groups': [groups_pk]}, format='json')
|
response = self.view(request, pk=self.user.pk)
|
||||||
response = view(request, pk=pk)
|
assert User.objects.get(pk=self.user.pk).groups.count() == 1
|
||||||
assert User.objects.get(pk=pk).groups.count() == 1
|
self.expected['username'] = 'exclude'
|
||||||
expected = {
|
assert response.data == self.expected
|
||||||
'id': pk,
|
|
||||||
'username': 'exclude',
|
|
||||||
'groups': [1],
|
|
||||||
'email': 'tom@example.com'
|
|
||||||
}
|
|
||||||
assert response.data == expected
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user