mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-07-27 08:29:59 +03:00
Use _prefetch_related_lookups and refine test cases
This commit is contained in:
parent
7f24ef2af6
commit
2da4374ee1
|
@ -1,8 +1,6 @@
|
||||||
"""
|
"""
|
||||||
Generic views that provide commonly needed behaviour.
|
Generic views that provide commonly needed behaviour.
|
||||||
"""
|
"""
|
||||||
from typing import Iterable
|
|
||||||
|
|
||||||
from django.core.exceptions import ValidationError
|
from django.core.exceptions import ValidationError
|
||||||
from django.db.models.query import QuerySet
|
from django.db.models.query import QuerySet
|
||||||
from django.http import Http404
|
from django.http import Http404
|
||||||
|
@ -47,8 +45,6 @@ class GenericAPIView(views.APIView):
|
||||||
# The style to use for queryset pagination.
|
# The style to use for queryset pagination.
|
||||||
pagination_class = api_settings.DEFAULT_PAGINATION_CLASS
|
pagination_class = api_settings.DEFAULT_PAGINATION_CLASS
|
||||||
|
|
||||||
prefetch_related = []
|
|
||||||
|
|
||||||
def get_queryset(self):
|
def get_queryset(self):
|
||||||
"""
|
"""
|
||||||
Get the list of items for this view.
|
Get the list of items for this view.
|
||||||
|
@ -72,31 +68,10 @@ class GenericAPIView(views.APIView):
|
||||||
|
|
||||||
queryset = self.queryset
|
queryset = self.queryset
|
||||||
if isinstance(queryset, QuerySet):
|
if isinstance(queryset, QuerySet):
|
||||||
# Prefetch related objects
|
|
||||||
if self.get_prefetch_related():
|
|
||||||
queryset = queryset.prefetch_related(*self.get_prefetch_related())
|
|
||||||
# Ensure queryset is re-evaluated on each request.
|
# Ensure queryset is re-evaluated on each request.
|
||||||
queryset = queryset.all()
|
queryset = queryset.all()
|
||||||
return queryset
|
return queryset
|
||||||
|
|
||||||
def get_prefetch_related(self):
|
|
||||||
"""
|
|
||||||
Get the list of prefetch related objects for self.queryset or instance.
|
|
||||||
This must be an iterable.
|
|
||||||
Defaults to using `self.prefetch_related`.
|
|
||||||
|
|
||||||
You may want to override this if you need to provide prefetched objects
|
|
||||||
depending on the incoming request.
|
|
||||||
|
|
||||||
(Eg. `['toppings', Prefetch('restaurants', queryset=Restaurant.objects.select_related('best_pizza'))]`)
|
|
||||||
"""
|
|
||||||
assert isinstance(self.prefetch_related, Iterable), (
|
|
||||||
"'%s' should either include an iterable `prefetch_related` attribute, "
|
|
||||||
"or override the `get_prefetch_related()` method."
|
|
||||||
% self.__class__.__name__
|
|
||||||
)
|
|
||||||
return self.prefetch_related
|
|
||||||
|
|
||||||
def get_object(self):
|
def get_object(self):
|
||||||
"""
|
"""
|
||||||
Returns the object the view is displaying.
|
Returns the object the view is displaying.
|
||||||
|
|
|
@ -74,7 +74,8 @@ class UpdateModelMixin:
|
||||||
# forcibly invalidate the prefetch cache on the instance,
|
# forcibly invalidate the prefetch cache on the instance,
|
||||||
# and then re-prefetch related objects
|
# and then re-prefetch related objects
|
||||||
instance._prefetched_objects_cache = {}
|
instance._prefetched_objects_cache = {}
|
||||||
prefetch_related_objects([instance], *self.get_prefetch_related())
|
queryset = self.filter_queryset(self.get_queryset())
|
||||||
|
prefetch_related_objects([instance], *queryset._prefetch_related_lookups)
|
||||||
|
|
||||||
return Response(serializer.data)
|
return Response(serializer.data)
|
||||||
|
|
||||||
|
|
|
@ -8,36 +8,52 @@ factory = APIRequestFactory()
|
||||||
|
|
||||||
|
|
||||||
class UserSerializer(serializers.ModelSerializer):
|
class UserSerializer(serializers.ModelSerializer):
|
||||||
|
permissions = serializers.SerializerMethodField()
|
||||||
|
|
||||||
|
def get_permissions(self, obj):
|
||||||
|
ret = []
|
||||||
|
for g in obj.groups.all():
|
||||||
|
ret.extend([p.pk for p in g.permissions.all()])
|
||||||
|
return ret
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
model = User
|
model = User
|
||||||
fields = ('id', 'username', 'email', 'groups')
|
fields = ('id', 'username', 'email', 'groups', 'permissions')
|
||||||
|
|
||||||
|
|
||||||
class UserUpdate(generics.UpdateAPIView):
|
class UserUpdate(generics.UpdateAPIView):
|
||||||
|
queryset = User.objects.exclude(username='exclude').prefetch_related('groups__permissions')
|
||||||
|
serializer_class = UserSerializer
|
||||||
|
|
||||||
|
|
||||||
|
class UserUpdateWithoutPrefetchRelated(generics.UpdateAPIView):
|
||||||
queryset = User.objects.exclude(username='exclude')
|
queryset = User.objects.exclude(username='exclude')
|
||||||
serializer_class = UserSerializer
|
serializer_class = UserSerializer
|
||||||
prefetch_related = ['groups']
|
|
||||||
|
|
||||||
|
|
||||||
class TestPrefetchRelatedUpdates(TestCase):
|
class TestPrefetchRelatedUpdates(TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.user = User.objects.create(username='tom', email='tom@example.com')
|
self.user = User.objects.create(username='tom', email='tom@example.com')
|
||||||
self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')]
|
self.groups = [Group.objects.create(name=f'group {i}') for i in range(10)]
|
||||||
self.user.groups.set(self.groups)
|
self.user.groups.set(self.groups)
|
||||||
self.expected = {
|
self.expected = {
|
||||||
'id': self.user.pk,
|
'id': self.user.pk,
|
||||||
'username': 'new',
|
'username': 'tom',
|
||||||
'groups': [1],
|
'groups': [group.pk for group in self.groups],
|
||||||
'email': 'tom@example.com',
|
'email': 'tom@example.com',
|
||||||
|
'permissions': [],
|
||||||
}
|
}
|
||||||
self.view = UserUpdate.as_view()
|
self.view = UserUpdate.as_view()
|
||||||
|
|
||||||
def test_prefetch_related_updates(self):
|
def test_prefetch_related_updates(self):
|
||||||
|
self.groups.append(Group.objects.create(name='c'))
|
||||||
request = factory.put(
|
request = factory.put(
|
||||||
'/', {'username': 'new', 'groups': [self.groups[0].pk]}, format='json'
|
'/', {'username': 'new', 'groups': [group.pk for group in self.groups]}, format='json'
|
||||||
)
|
)
|
||||||
|
self.expected['username'] = 'new'
|
||||||
|
self.expected['groups'] = [group.pk for group in self.groups]
|
||||||
response = self.view(request, pk=self.user.pk)
|
response = self.view(request, pk=self.user.pk)
|
||||||
assert User.objects.get(pk=self.user.pk).groups.count() == 1
|
assert User.objects.get(pk=self.user.pk).groups.count() == 11
|
||||||
assert response.data == self.expected
|
assert response.data == self.expected
|
||||||
|
|
||||||
def test_prefetch_related_excluding_instance_from_original_queryset(self):
|
def test_prefetch_related_excluding_instance_from_original_queryset(self):
|
||||||
|
@ -50,4 +66,18 @@ class TestPrefetchRelatedUpdates(TestCase):
|
||||||
response = self.view(request, pk=self.user.pk)
|
response = self.view(request, pk=self.user.pk)
|
||||||
assert User.objects.get(pk=self.user.pk).groups.count() == 1
|
assert User.objects.get(pk=self.user.pk).groups.count() == 1
|
||||||
self.expected['username'] = 'exclude'
|
self.expected['username'] = 'exclude'
|
||||||
|
self.expected['groups'] = [self.groups[0].pk]
|
||||||
assert response.data == self.expected
|
assert response.data == self.expected
|
||||||
|
|
||||||
|
def test_db_query_count(self):
|
||||||
|
request = factory.put(
|
||||||
|
'/', {'username': 'new'}, format='json'
|
||||||
|
)
|
||||||
|
with self.assertNumQueries(7):
|
||||||
|
self.view(request, pk=self.user.pk)
|
||||||
|
|
||||||
|
request = factory.put(
|
||||||
|
'/', {'username': 'new2'}, format='json'
|
||||||
|
)
|
||||||
|
with self.assertNumQueries(15):
|
||||||
|
UserUpdateWithoutPrefetchRelated.as_view()(request, pk=self.user.pk)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user