mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-01-24 08:14:16 +03:00
Fix prefetch_related updates. (#4553)
This commit is contained in:
parent
aed4ed5e73
commit
d0b3b6470a
|
@ -68,6 +68,13 @@ class UpdateModelMixin(object):
|
|||
serializer = self.get_serializer(instance, data=request.data, partial=partial)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
self.perform_update(serializer)
|
||||
|
||||
if getattr(instance, '_prefetched_objects_cache', None):
|
||||
# If 'prefetch_related' has been applied to a queryset, we need to
|
||||
# refresh the instance from the database.
|
||||
instance = self.get_object()
|
||||
serializer = self.get_serializer(instance)
|
||||
|
||||
return Response(serializer.data)
|
||||
|
||||
def perform_update(self, serializer):
|
||||
|
|
41
tests/test_prefetch_related.py
Normal file
41
tests/test_prefetch_related.py
Normal file
|
@ -0,0 +1,41 @@
|
|||
from django.contrib.auth.models import Group, User
|
||||
from django.test import TestCase
|
||||
|
||||
from rest_framework import generics, serializers
|
||||
from rest_framework.test import APIRequestFactory
|
||||
|
||||
factory = APIRequestFactory()
|
||||
|
||||
|
||||
class UserSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = User
|
||||
fields = ('id', 'username', 'email', 'groups')
|
||||
|
||||
|
||||
class UserUpdate(generics.UpdateAPIView):
|
||||
queryset = User.objects.all().prefetch_related('groups')
|
||||
serializer_class = UserSerializer
|
||||
|
||||
|
||||
class TestPrefetchRelatedUpdates(TestCase):
|
||||
def setUp(self):
|
||||
self.user = User.objects.create(username='tom', email='tom@example.com')
|
||||
self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')]
|
||||
self.user.groups = self.groups
|
||||
self.user.save()
|
||||
|
||||
def test_prefetch_related_updates(self):
|
||||
view = UserUpdate.as_view()
|
||||
pk = self.user.pk
|
||||
groups_pk = self.groups[0].pk
|
||||
request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json')
|
||||
response = view(request, pk=pk)
|
||||
assert User.objects.get(pk=pk).groups.count() == 1
|
||||
expected = {
|
||||
'id': pk,
|
||||
'username': 'new',
|
||||
'groups': [1],
|
||||
'email': 'tom@example.com'
|
||||
}
|
||||
assert response.data == expected
|
Loading…
Reference in New Issue
Block a user