mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-02-13 10:00:53 +03:00
Merge pull request #3180 from tomchristie/guarded-queryset
Guard against erroneous direct .queryset evaluation in CBVs.
This commit is contained in:
commit
36d8d3681a
|
@ -7,6 +7,7 @@ import inspect
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from django.core.exceptions import PermissionDenied
|
from django.core.exceptions import PermissionDenied
|
||||||
|
from django.db import models
|
||||||
from django.http import Http404
|
from django.http import Http404
|
||||||
from django.utils import six
|
from django.utils import six
|
||||||
from django.utils.encoding import smart_text
|
from django.utils.encoding import smart_text
|
||||||
|
@ -118,8 +119,19 @@ class APIView(View):
|
||||||
This allows us to discover information about the view when we do URL
|
This allows us to discover information about the view when we do URL
|
||||||
reverse lookups. Used for breadcrumb generation.
|
reverse lookups. Used for breadcrumb generation.
|
||||||
"""
|
"""
|
||||||
|
if isinstance(getattr(cls, 'queryset', None), models.query.QuerySet):
|
||||||
|
def force_evaluation():
|
||||||
|
raise AssertionError(
|
||||||
|
'Do not evaluate the `.queryset` attribute directly, '
|
||||||
|
'as the result will be cached and reused between requests. '
|
||||||
|
'Use `.all()` or call `.get_queryset()` instead.'
|
||||||
|
)
|
||||||
|
cls.queryset._fetch_all = force_evaluation
|
||||||
|
cls.queryset._result_iter = force_evaluation # Django <= 1.5
|
||||||
|
|
||||||
view = super(APIView, cls).as_view(**initkwargs)
|
view = super(APIView, cls).as_view(**initkwargs)
|
||||||
view.cls = cls
|
view.cls = cls
|
||||||
|
|
||||||
# Note: session based authentication is explicitly CSRF validated,
|
# Note: session based authentication is explicitly CSRF validated,
|
||||||
# all other authentication is CSRF exempt.
|
# all other authentication is CSRF exempt.
|
||||||
return csrf_exempt(view)
|
return csrf_exempt(view)
|
||||||
|
|
|
@ -1,12 +1,14 @@
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
import django
|
import django
|
||||||
|
import pytest
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.shortcuts import get_object_or_404
|
from django.shortcuts import get_object_or_404
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from django.utils import six
|
from django.utils import six
|
||||||
|
|
||||||
from rest_framework import generics, renderers, serializers, status
|
from rest_framework import generics, renderers, serializers, status
|
||||||
|
from rest_framework.response import Response
|
||||||
from rest_framework.test import APIRequestFactory
|
from rest_framework.test import APIRequestFactory
|
||||||
from tests.models import (
|
from tests.models import (
|
||||||
BasicModel, ForeignKeySource, ForeignKeyTarget, RESTFrameworkModel
|
BasicModel, ForeignKeySource, ForeignKeyTarget, RESTFrameworkModel
|
||||||
|
@ -527,3 +529,17 @@ class TestFilterBackendAppliedToViews(TestCase):
|
||||||
response = view(request).render()
|
response = view(request).render()
|
||||||
self.assertContains(response, 'field_b')
|
self.assertContains(response, 'field_b')
|
||||||
self.assertNotContains(response, 'field_a')
|
self.assertNotContains(response, 'field_a')
|
||||||
|
|
||||||
|
|
||||||
|
class TestGuardedQueryset(TestCase):
|
||||||
|
def test_guarded_queryset(self):
|
||||||
|
class QuerysetAccessError(generics.ListAPIView):
|
||||||
|
queryset = BasicModel.objects.all()
|
||||||
|
|
||||||
|
def get(self, request):
|
||||||
|
return Response(list(self.queryset))
|
||||||
|
|
||||||
|
view = QuerysetAccessError.as_view()
|
||||||
|
request = factory.get('/')
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
view(request).render()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user