# Mirror of https://github.com/encode/django-rest-framework.git
# (synced 2025-11-04 09:57:55 +03:00; 73 lines, 2.5 KiB, Python)
from django.contrib.auth.models import Group, User
 | 
						|
from django.test import TestCase
 | 
						|
 | 
						|
from rest_framework import generics, serializers
 | 
						|
from rest_framework.test import APIRequestFactory
 | 
						|
 | 
						|
# Module-level request factory shared by the tests below to build PUT/PATCH requests.
factory = APIRequestFactory()
 | 
						|
 | 
						|
 | 
						|
class UserSerializer(serializers.ModelSerializer):
    """
    Model serializer for ``User`` exposing its id, username, email and groups.
    """

    class Meta:
        model = User
        fields = (
            'id',
            'username',
            'email',
            'groups',
        )
 | 
						|
 | 
						|
 | 
						|
class UserUpdate(generics.UpdateAPIView):
    # Update view whose queryset both filters and prefetches: users named
    # 'exclude' are left out, and each user's groups are prefetched.
    serializer_class = UserSerializer
    queryset = (
        User.objects
        .exclude(username='exclude')
        .prefetch_related('groups')
    )
 | 
						|
 | 
						|
 | 
						|
class TestPrefetchRelatedUpdates(TestCase):
    """
    Update-view tests against a queryset that uses ``prefetch_related``.
    """

    def setUp(self):
        # One user that belongs to two groups; each test then updates it.
        self.user = User.objects.create(username='tom', email='tom@example.com')
        self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')]
        self.user.groups.set(self.groups)

    def test_prefetch_related_updates(self):
        """
        A PUT that narrows the groups from two to one is saved and returned.
        """
        view = UserUpdate.as_view()
        pk = self.user.pk
        groups_pk = self.groups[0].pk
        request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json')
        response = view(request, pk=pk)
        # Check the status first so a rejected update fails with a clear message.
        assert response.status_code == 200
        assert User.objects.get(pk=pk).groups.count() == 1
        expected = {
            'id': pk,
            'username': 'new',
            # Compare against the pk actually sent rather than a literal 1:
            # primary keys are not guaranteed to restart at 1 for every test.
            'groups': [groups_pk],
            'email': 'tom@example.com'
        }
        assert response.data == expected

    def test_prefetch_related_excluding_instance_from_original_queryset(self):
        """
        Regression test for https://github.com/encode/django-rest-framework/issues/4661

        The PUT renames the user to 'exclude', the username filtered out by
        ``UserUpdate.queryset``, so once saved the instance no longer matches
        the view's queryset. The response must still carry the updated data.
        """
        view = UserUpdate.as_view()
        pk = self.user.pk
        groups_pk = self.groups[0].pk
        request = factory.put('/', {'username': 'exclude', 'groups': [groups_pk]}, format='json')
        response = view(request, pk=pk)
        # Check the status first so a rejected update fails with a clear message.
        assert response.status_code == 200
        assert User.objects.get(pk=pk).groups.count() == 1
        expected = {
            'id': pk,
            'username': 'exclude',
            # Compare against the pk actually sent rather than a literal 1:
            # primary keys are not guaranteed to restart at 1 for every test.
            'groups': [groups_pk],
            'email': 'tom@example.com'
        }
        assert response.data == expected

    def test_can_update_without_queryset_on_class_view(self):
        """
        A view that defines ``get_object`` but no ``queryset`` can still PATCH.
        """
        class UserUpdateWithoutQuerySet(generics.UpdateAPIView):
            # No ``queryset`` attribute on purpose: the instance is looked up
            # directly from the URL kwarg in ``get_object``.
            serializer_class = UserSerializer

            def get_object(self):
                return User.objects.get(pk=self.kwargs['pk'])

        request = factory.patch('/', {'username': 'new'})
        response = UserUpdateWithoutQuerySet.as_view()(request, pk=self.user.pk)
        assert response.status_code == 200
        assert response.data['id'] == self.user.id
        assert response.data['username'] == 'new'
        # Reload from the database to confirm the change was persisted.
        self.user.refresh_from_db()
        assert self.user.username == 'new'
 |