mirror of
				https://github.com/encode/django-rest-framework.git
				synced 2025-11-04 09:57:55 +03:00 
			
		
		
		
	Re-prefetch related objects after updating (#8043)
* Re-prefetch related objects after updating * Fix flake8 format * Use _prefetch_related_lookups and refine test cases * Add more test cases and refine prefetch checking
This commit is contained in:
		
							parent
							
								
									bfce663a60
								
							
						
					
					
						commit
						2b34aa4291
					
				| 
						 | 
				
			
			@ -4,6 +4,8 @@ Basic building blocks for generic class based views.
 | 
			
		|||
We don't bind behaviour to http method handlers yet,
 | 
			
		||||
which allows mixin classes to be composed in interesting ways.
 | 
			
		||||
"""
 | 
			
		||||
from django.db.models.query import prefetch_related_objects
 | 
			
		||||
 | 
			
		||||
from rest_framework import status
 | 
			
		||||
from rest_framework.response import Response
 | 
			
		||||
from rest_framework.settings import api_settings
 | 
			
		||||
| 
						 | 
				
			
			@ -67,10 +69,13 @@ class UpdateModelMixin:
 | 
			
		|||
        serializer.is_valid(raise_exception=True)
 | 
			
		||||
        self.perform_update(serializer)
 | 
			
		||||
 | 
			
		||||
        if getattr(instance, '_prefetched_objects_cache', None):
 | 
			
		||||
        queryset = self.filter_queryset(self.get_queryset())
 | 
			
		||||
        if queryset._prefetch_related_lookups:
 | 
			
		||||
            # If 'prefetch_related' has been applied to a queryset, we need to
 | 
			
		||||
            # forcibly invalidate the prefetch cache on the instance.
 | 
			
		||||
            # forcibly invalidate the prefetch cache on the instance,
 | 
			
		||||
            # and then re-prefetch related objects
 | 
			
		||||
            instance._prefetched_objects_cache = {}
 | 
			
		||||
            prefetch_related_objects([instance], *queryset._prefetch_related_lookups)
 | 
			
		||||
 | 
			
		||||
        return Response(serializer.data)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,4 +1,5 @@
 | 
			
		|||
from django.contrib.auth.models import Group, User
 | 
			
		||||
from django.db.models.query import Prefetch
 | 
			
		||||
from django.test import TestCase
 | 
			
		||||
 | 
			
		||||
from rest_framework import generics, serializers
 | 
			
		||||
| 
						 | 
				
			
			@ -8,51 +9,84 @@ factory = APIRequestFactory()
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class UserSerializer(serializers.ModelSerializer):
 | 
			
		||||
    permissions = serializers.SerializerMethodField()
 | 
			
		||||
 | 
			
		||||
    def get_permissions(self, obj):
 | 
			
		||||
        ret = []
 | 
			
		||||
        for g in obj.groups.all():
 | 
			
		||||
            ret.extend([p.pk for p in g.permissions.all()])
 | 
			
		||||
        return ret
 | 
			
		||||
 | 
			
		||||
    class Meta:
 | 
			
		||||
        model = User
 | 
			
		||||
        fields = ('id', 'username', 'email', 'groups')
 | 
			
		||||
        fields = ('id', 'username', 'email', 'groups', 'permissions')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class UserUpdate(generics.UpdateAPIView):
 | 
			
		||||
    queryset = User.objects.exclude(username='exclude').prefetch_related('groups')
 | 
			
		||||
class UserRetrieveUpdate(generics.RetrieveUpdateAPIView):
 | 
			
		||||
    queryset = User.objects.exclude(username='exclude').prefetch_related(
 | 
			
		||||
        Prefetch('groups', queryset=Group.objects.exclude(name='exclude')),
 | 
			
		||||
        'groups__permissions',
 | 
			
		||||
    )
 | 
			
		||||
    serializer_class = UserSerializer
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class UserUpdateWithoutPrefetchRelated(generics.UpdateAPIView):
 | 
			
		||||
    queryset = User.objects.exclude(username='exclude')
 | 
			
		||||
    serializer_class = UserSerializer
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestPrefetchRelatedUpdates(TestCase):
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        self.user = User.objects.create(username='tom', email='tom@example.com')
 | 
			
		||||
        self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')]
 | 
			
		||||
        self.groups = [Group.objects.create(name=f'group {i}') for i in range(10)]
 | 
			
		||||
        self.user.groups.set(self.groups)
 | 
			
		||||
        self.user.groups.add(Group.objects.create(name='exclude'))
 | 
			
		||||
        self.expected = {
 | 
			
		||||
            'id': self.user.pk,
 | 
			
		||||
            'username': 'tom',
 | 
			
		||||
            'groups': [group.pk for group in self.groups],
 | 
			
		||||
            'email': 'tom@example.com',
 | 
			
		||||
            'permissions': [],
 | 
			
		||||
        }
 | 
			
		||||
        self.view = UserRetrieveUpdate.as_view()
 | 
			
		||||
 | 
			
		||||
    def test_prefetch_related_updates(self):
 | 
			
		||||
        view = UserUpdate.as_view()
 | 
			
		||||
        pk = self.user.pk
 | 
			
		||||
        groups_pk = self.groups[0].pk
 | 
			
		||||
        request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json')
 | 
			
		||||
        response = view(request, pk=pk)
 | 
			
		||||
        assert User.objects.get(pk=pk).groups.count() == 1
 | 
			
		||||
        expected = {
 | 
			
		||||
            'id': pk,
 | 
			
		||||
            'username': 'new',
 | 
			
		||||
            'groups': [1],
 | 
			
		||||
            'email': 'tom@example.com'
 | 
			
		||||
        }
 | 
			
		||||
        assert response.data == expected
 | 
			
		||||
        self.groups.append(Group.objects.create(name='c'))
 | 
			
		||||
        request = factory.put(
 | 
			
		||||
            '/', {'username': 'new', 'groups': [group.pk for group in self.groups]}, format='json'
 | 
			
		||||
        )
 | 
			
		||||
        self.expected['username'] = 'new'
 | 
			
		||||
        self.expected['groups'] = [group.pk for group in self.groups]
 | 
			
		||||
        response = self.view(request, pk=self.user.pk)
 | 
			
		||||
        assert User.objects.get(pk=self.user.pk).groups.count() == 12
 | 
			
		||||
        assert response.data == self.expected
 | 
			
		||||
        # Update and fetch should get same result
 | 
			
		||||
        request = factory.get('/')
 | 
			
		||||
        response = self.view(request, pk=self.user.pk)
 | 
			
		||||
        assert response.data == self.expected
 | 
			
		||||
 | 
			
		||||
    def test_prefetch_related_excluding_instance_from_original_queryset(self):
 | 
			
		||||
        """
 | 
			
		||||
        Regression test for https://github.com/encode/django-rest-framework/issues/4661
 | 
			
		||||
        """
 | 
			
		||||
        view = UserUpdate.as_view()
 | 
			
		||||
        pk = self.user.pk
 | 
			
		||||
        groups_pk = self.groups[0].pk
 | 
			
		||||
        request = factory.put('/', {'username': 'exclude', 'groups': [groups_pk]}, format='json')
 | 
			
		||||
        response = view(request, pk=pk)
 | 
			
		||||
        assert User.objects.get(pk=pk).groups.count() == 1
 | 
			
		||||
        expected = {
 | 
			
		||||
            'id': pk,
 | 
			
		||||
            'username': 'exclude',
 | 
			
		||||
            'groups': [1],
 | 
			
		||||
            'email': 'tom@example.com'
 | 
			
		||||
        }
 | 
			
		||||
        assert response.data == expected
 | 
			
		||||
        request = factory.put(
 | 
			
		||||
            '/', {'username': 'exclude', 'groups': [self.groups[0].pk]}, format='json'
 | 
			
		||||
        )
 | 
			
		||||
        response = self.view(request, pk=self.user.pk)
 | 
			
		||||
        assert User.objects.get(pk=self.user.pk).groups.count() == 2
 | 
			
		||||
        self.expected['username'] = 'exclude'
 | 
			
		||||
        self.expected['groups'] = [self.groups[0].pk]
 | 
			
		||||
        assert response.data == self.expected
 | 
			
		||||
 | 
			
		||||
    def test_db_query_count(self):
 | 
			
		||||
        request = factory.put(
 | 
			
		||||
            '/', {'username': 'new'}, format='json'
 | 
			
		||||
        )
 | 
			
		||||
        with self.assertNumQueries(7):
 | 
			
		||||
            self.view(request, pk=self.user.pk)
 | 
			
		||||
 | 
			
		||||
        request = factory.put(
 | 
			
		||||
            '/', {'username': 'new2'}, format='json'
 | 
			
		||||
        )
 | 
			
		||||
        with self.assertNumQueries(16):
 | 
			
		||||
            UserUpdateWithoutPrefetchRelated.as_view()(request, pk=self.user.pk)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user