Guard against erronous direct .queryset evaluation in CBVs.

This commit is contained in:
Tom Christie 2015-07-23 17:17:18 +01:00
parent 9d136abb24
commit e05021c8c6
2 changed files with 28 additions and 0 deletions

View File

@ -7,6 +7,7 @@ import inspect
import warnings
from django.core.exceptions import PermissionDenied
from django.db import models
from django.http import Http404
from django.utils import six
from django.utils.encoding import smart_text
@ -118,8 +119,19 @@ class APIView(View):
This allows us to discover information about the view when we do URL
reverse lookups. Used for breadcrumb generation.
"""
if isinstance(getattr(cls, 'queryset', None), models.QuerySet):
def force_evaluation():
raise AssertionError(
'Do not evaluate the `.queryset` attribute directly, '
'as the result will be cached and reused between requests. '
'Use `.all()` or call `.get_queryset()` instead.'
)
cls.queryset._fetch_all = force_evaluation
view = super(APIView, cls).as_view(**initkwargs)
view.cls = cls
# Note: session based authentication is explicitly CSRF validated,
# all other authentication is CSRF exempt.
return csrf_exempt(view)

View File

@ -1,12 +1,14 @@
from __future__ import unicode_literals
import django
import pytest
from django.db import models
from django.shortcuts import get_object_or_404
from django.test import TestCase
from django.utils import six
from rest_framework import generics, renderers, serializers, status
from rest_framework.response import Response
from rest_framework.test import APIRequestFactory
from tests.models import (
BasicModel, ForeignKeySource, ForeignKeyTarget, RESTFrameworkModel
@ -527,3 +529,17 @@ class TestFilterBackendAppliedToViews(TestCase):
response = view(request).render()
self.assertContains(response, 'field_b')
self.assertNotContains(response, 'field_a')
class TestGuardedQueryset(TestCase):
def test_guarded_queryset(self):
class QuerysetAccessError(generics.ListAPIView):
queryset = BasicModel.objects.all()
def get(self, request):
return Response(list(self.queryset))
view = QuerysetAccessError.as_view()
request = factory.get('/')
with pytest.raises(AssertionError):
view(request).render()