Merge pull request #3180 from tomchristie/guarded-queryset

Guard against erroneous direct .queryset evaluation in CBVs.
This commit is contained in:
Tom Christie 2015-07-24 09:11:56 +01:00
commit 36d8d3681a
2 changed files with 28 additions and 0 deletions

View File

@ -7,6 +7,7 @@ import inspect
import warnings import warnings
from django.core.exceptions import PermissionDenied from django.core.exceptions import PermissionDenied
from django.db import models
from django.http import Http404 from django.http import Http404
from django.utils import six from django.utils import six
from django.utils.encoding import smart_text from django.utils.encoding import smart_text
@ -118,8 +119,19 @@ class APIView(View):
This allows us to discover information about the view when we do URL This allows us to discover information about the view when we do URL
reverse lookups. Used for breadcrumb generation. reverse lookups. Used for breadcrumb generation.
""" """
if isinstance(getattr(cls, 'queryset', None), models.query.QuerySet):
def force_evaluation():
raise AssertionError(
'Do not evaluate the `.queryset` attribute directly, '
'as the result will be cached and reused between requests. '
'Use `.all()` or call `.get_queryset()` instead.'
)
cls.queryset._fetch_all = force_evaluation
cls.queryset._result_iter = force_evaluation # Django <= 1.5
view = super(APIView, cls).as_view(**initkwargs) view = super(APIView, cls).as_view(**initkwargs)
view.cls = cls view.cls = cls
# Note: session based authentication is explicitly CSRF validated, # Note: session based authentication is explicitly CSRF validated,
# all other authentication is CSRF exempt. # all other authentication is CSRF exempt.
return csrf_exempt(view) return csrf_exempt(view)

View File

@ -1,12 +1,14 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import django import django
import pytest
from django.db import models from django.db import models
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from django.test import TestCase from django.test import TestCase
from django.utils import six from django.utils import six
from rest_framework import generics, renderers, serializers, status from rest_framework import generics, renderers, serializers, status
from rest_framework.response import Response
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from tests.models import ( from tests.models import (
BasicModel, ForeignKeySource, ForeignKeyTarget, RESTFrameworkModel BasicModel, ForeignKeySource, ForeignKeyTarget, RESTFrameworkModel
@ -527,3 +529,17 @@ class TestFilterBackendAppliedToViews(TestCase):
response = view(request).render() response = view(request).render()
self.assertContains(response, 'field_b') self.assertContains(response, 'field_b')
self.assertNotContains(response, 'field_a') self.assertNotContains(response, 'field_a')
class TestGuardedQueryset(TestCase):
def test_guarded_queryset(self):
class QuerysetAccessError(generics.ListAPIView):
queryset = BasicModel.objects.all()
def get(self, request):
return Response(list(self.queryset))
view = QuerysetAccessError.as_view()
request = factory.get('/')
with pytest.raises(AssertionError):
view(request).render()