mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-09-17 09:42:29 +03:00
Add more test cases and refine prefetch checking
This commit is contained in:
parent
91a9582ec6
commit
f322c04159
|
@ -69,12 +69,12 @@ class UpdateModelMixin:
|
||||||
serializer.is_valid(raise_exception=True)
|
serializer.is_valid(raise_exception=True)
|
||||||
self.perform_update(serializer)
|
self.perform_update(serializer)
|
||||||
|
|
||||||
if getattr(instance, '_prefetched_objects_cache', None):
|
queryset = self.filter_queryset(self.get_queryset())
|
||||||
|
if queryset._prefetch_related_lookups:
|
||||||
# If 'prefetch_related' has been applied to a queryset, we need to
|
# If 'prefetch_related' has been applied to a queryset, we need to
|
||||||
# forcibly invalidate the prefetch cache on the instance,
|
# forcibly invalidate the prefetch cache on the instance,
|
||||||
# and then re-prefetch related objects
|
# and then re-prefetch related objects
|
||||||
instance._prefetched_objects_cache = {}
|
instance._prefetched_objects_cache = {}
|
||||||
queryset = self.filter_queryset(self.get_queryset())
|
|
||||||
prefetch_related_objects([instance], *queryset._prefetch_related_lookups)
|
prefetch_related_objects([instance], *queryset._prefetch_related_lookups)
|
||||||
|
|
||||||
return Response(serializer.data)
|
return Response(serializer.data)
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from django.contrib.auth.models import Group, User
|
from django.contrib.auth.models import Group, User
|
||||||
|
from django.db.models.query import Prefetch
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
|
||||||
from rest_framework import generics, serializers
|
from rest_framework import generics, serializers
|
||||||
|
@ -21,8 +22,11 @@ class UserSerializer(serializers.ModelSerializer):
|
||||||
fields = ('id', 'username', 'email', 'groups', 'permissions')
|
fields = ('id', 'username', 'email', 'groups', 'permissions')
|
||||||
|
|
||||||
|
|
||||||
class UserUpdate(generics.UpdateAPIView):
|
class UserRetrieveUpdate(generics.RetrieveUpdateAPIView):
|
||||||
queryset = User.objects.exclude(username='exclude').prefetch_related('groups__permissions')
|
queryset = User.objects.exclude(username='exclude').prefetch_related(
|
||||||
|
Prefetch('groups', queryset=Group.objects.exclude(name='exclude')),
|
||||||
|
'groups__permissions',
|
||||||
|
)
|
||||||
serializer_class = UserSerializer
|
serializer_class = UserSerializer
|
||||||
|
|
||||||
|
|
||||||
|
@ -36,6 +40,7 @@ class TestPrefetchRelatedUpdates(TestCase):
|
||||||
self.user = User.objects.create(username='tom', email='tom@example.com')
|
self.user = User.objects.create(username='tom', email='tom@example.com')
|
||||||
self.groups = [Group.objects.create(name=f'group {i}') for i in range(10)]
|
self.groups = [Group.objects.create(name=f'group {i}') for i in range(10)]
|
||||||
self.user.groups.set(self.groups)
|
self.user.groups.set(self.groups)
|
||||||
|
self.user.groups.add(Group.objects.create(name='exclude'))
|
||||||
self.expected = {
|
self.expected = {
|
||||||
'id': self.user.pk,
|
'id': self.user.pk,
|
||||||
'username': 'tom',
|
'username': 'tom',
|
||||||
|
@ -43,7 +48,7 @@ class TestPrefetchRelatedUpdates(TestCase):
|
||||||
'email': 'tom@example.com',
|
'email': 'tom@example.com',
|
||||||
'permissions': [],
|
'permissions': [],
|
||||||
}
|
}
|
||||||
self.view = UserUpdate.as_view()
|
self.view = UserRetrieveUpdate.as_view()
|
||||||
|
|
||||||
def test_prefetch_related_updates(self):
|
def test_prefetch_related_updates(self):
|
||||||
self.groups.append(Group.objects.create(name='c'))
|
self.groups.append(Group.objects.create(name='c'))
|
||||||
|
@ -53,7 +58,11 @@ class TestPrefetchRelatedUpdates(TestCase):
|
||||||
self.expected['username'] = 'new'
|
self.expected['username'] = 'new'
|
||||||
self.expected['groups'] = [group.pk for group in self.groups]
|
self.expected['groups'] = [group.pk for group in self.groups]
|
||||||
response = self.view(request, pk=self.user.pk)
|
response = self.view(request, pk=self.user.pk)
|
||||||
assert User.objects.get(pk=self.user.pk).groups.count() == 11
|
assert User.objects.get(pk=self.user.pk).groups.count() == 12
|
||||||
|
assert response.data == self.expected
|
||||||
|
# Update and fetch should get same result
|
||||||
|
request = factory.get('/')
|
||||||
|
response = self.view(request, pk=self.user.pk)
|
||||||
assert response.data == self.expected
|
assert response.data == self.expected
|
||||||
|
|
||||||
def test_prefetch_related_excluding_instance_from_original_queryset(self):
|
def test_prefetch_related_excluding_instance_from_original_queryset(self):
|
||||||
|
@ -64,7 +73,7 @@ class TestPrefetchRelatedUpdates(TestCase):
|
||||||
'/', {'username': 'exclude', 'groups': [self.groups[0].pk]}, format='json'
|
'/', {'username': 'exclude', 'groups': [self.groups[0].pk]}, format='json'
|
||||||
)
|
)
|
||||||
response = self.view(request, pk=self.user.pk)
|
response = self.view(request, pk=self.user.pk)
|
||||||
assert User.objects.get(pk=self.user.pk).groups.count() == 1
|
assert User.objects.get(pk=self.user.pk).groups.count() == 2
|
||||||
self.expected['username'] = 'exclude'
|
self.expected['username'] = 'exclude'
|
||||||
self.expected['groups'] = [self.groups[0].pk]
|
self.expected['groups'] = [self.groups[0].pk]
|
||||||
assert response.data == self.expected
|
assert response.data == self.expected
|
||||||
|
@ -79,5 +88,5 @@ class TestPrefetchRelatedUpdates(TestCase):
|
||||||
request = factory.put(
|
request = factory.put(
|
||||||
'/', {'username': 'new2'}, format='json'
|
'/', {'username': 'new2'}, format='json'
|
||||||
)
|
)
|
||||||
with self.assertNumQueries(15):
|
with self.assertNumQueries(16):
|
||||||
UserUpdateWithoutPrefetchRelated.as_view()(request, pk=self.user.pk)
|
UserUpdateWithoutPrefetchRelated.as_view()(request, pk=self.user.pk)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user