diff --git a/rest_framework/views.py b/rest_framework/views.py index a709c2f6b..9c9c8e19a 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -7,6 +7,7 @@ import inspect import warnings from django.core.exceptions import PermissionDenied +from django.db import models from django.http import Http404 from django.utils import six from django.utils.encoding import smart_text @@ -118,8 +119,19 @@ class APIView(View): This allows us to discover information about the view when we do URL reverse lookups. Used for breadcrumb generation. """ + if isinstance(getattr(cls, 'queryset', None), models.QuerySet): + def force_evaluation(): + raise AssertionError( + 'Do not evaluate the `.queryset` attribute directly, ' + 'as the result will be cached and reused between requests. ' + 'Use `.all()` or call `.get_queryset()` instead.' + ) + + cls.queryset._fetch_all = force_evaluation + view = super(APIView, cls).as_view(**initkwargs) view.cls = cls + # Note: session based authentication is explicitly CSRF validated, # all other authentication is CSRF exempt. return csrf_exempt(view) diff --git a/tests/test_generics.py b/tests/test_generics.py index 219a83a5d..5db0b6f71 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -1,12 +1,14 @@ from __future__ import unicode_literals import django +import pytest from django.db import models from django.shortcuts import get_object_or_404 from django.test import TestCase from django.utils import six from rest_framework import generics, renderers, serializers, status +from rest_framework.response import Response from rest_framework.test import APIRequestFactory from tests.models import ( BasicModel, ForeignKeySource, ForeignKeyTarget, RESTFrameworkModel @@ -527,3 +529,17 @@ class TestFilterBackendAppliedToViews(TestCase): response = view(request).render() self.assertContains(response, 'field_b') self.assertNotContains(response, 'field_a') + + +class TestGuardedQueryset(TestCase): + def test_guarded_queryset(self): + class QuerysetAccessError(generics.ListAPIView): + queryset = BasicModel.objects.all() + + def get(self, request): + return Response(list(self.queryset)) + + view = QuerysetAccessError.as_view() + request = factory.get('/') + with pytest.raises(AssertionError): + view(request).render()