From e7af8d662bf837e4fee844b28606cda63c0d30a9 Mon Sep 17 00:00:00 2001 From: itsdkey Date: Wed, 8 Jun 2022 14:41:26 +0200 Subject: [PATCH] tests for #5127 (#7715) * tests for #5127 * Resolves #5127 --- tests/browsable_api/no_auth_urls.py | 3 ++- tests/browsable_api/serializers.py | 8 +++++++ tests/browsable_api/test_browsable_api.py | 27 +++++++++++++++++++++++ tests/browsable_api/views.py | 22 ++++++++++++++++++ tests/models.py | 5 +++++ 5 files changed, 64 insertions(+), 1 deletion(-) create mode 100644 tests/browsable_api/serializers.py diff --git a/tests/browsable_api/no_auth_urls.py b/tests/browsable_api/no_auth_urls.py index 65701c065..33491ad92 100644 --- a/tests/browsable_api/no_auth_urls.py +++ b/tests/browsable_api/no_auth_urls.py @@ -1,7 +1,8 @@ from django.urls import path -from .views import MockView +from .views import BasicModelWithUsersViewSet, MockView urlpatterns = [ path('', MockView.as_view()), + path('basicviewset', BasicModelWithUsersViewSet.as_view({'get': 'list'})), ] diff --git a/tests/browsable_api/serializers.py b/tests/browsable_api/serializers.py new file mode 100644 index 000000000..e8a1cdef8 --- /dev/null +++ b/tests/browsable_api/serializers.py @@ -0,0 +1,8 @@ +from rest_framework.serializers import ModelSerializer +from tests.models import BasicModelWithUsers + + +class BasicSerializer(ModelSerializer): + class Meta: + model = BasicModelWithUsers + fields = '__all__' diff --git a/tests/browsable_api/test_browsable_api.py b/tests/browsable_api/test_browsable_api.py index 17644c2ac..a76d11fe3 100644 --- a/tests/browsable_api/test_browsable_api.py +++ b/tests/browsable_api/test_browsable_api.py @@ -1,8 +1,35 @@ from django.contrib.auth.models import User from django.test import TestCase, override_settings +from rest_framework.permissions import IsAuthenticated from rest_framework.test import APIClient +from .views import BasicModelWithUsersViewSet, OrganizationPermissions + + +@override_settings(ROOT_URLCONF='tests.browsable_api.no_auth_urls') +class AnonymousUserTests(TestCase): + """Tests correct handling of anonymous user request on endpoints with IsAuthenticated permission class.""" + + def setUp(self): + self.client = APIClient(enforce_csrf_checks=True) + + def tearDown(self): + self.client.logout() + + def test_get_raises_typeerror_when_anonymous_user_in_queryset_filter(self): + with self.assertRaises(TypeError): + self.client.get('/basicviewset') + + def test_get_returns_http_forbidden_when_anonymous_user(self): + old_permissions = BasicModelWithUsersViewSet.permission_classes + BasicModelWithUsersViewSet.permission_classes = [IsAuthenticated, OrganizationPermissions] + + response = self.client.get('/basicviewset') + + BasicModelWithUsersViewSet.permission_classes = old_permissions + self.assertEqual(response.status_code, 403) + @override_settings(ROOT_URLCONF='tests.browsable_api.auth_urls') class DropdownWithAuthTests(TestCase): diff --git a/tests/browsable_api/views.py b/tests/browsable_api/views.py index e1cf13a1e..e73967bf8 100644 --- a/tests/browsable_api/views.py +++ b/tests/browsable_api/views.py @@ -1,6 +1,16 @@ from rest_framework import authentication, renderers +from rest_framework.permissions import BasePermission from rest_framework.response import Response from rest_framework.views import APIView +from rest_framework.viewsets import ModelViewSet + +from ..models import BasicModelWithUsers +from .serializers import BasicSerializer + + +class OrganizationPermissions(BasePermission): + def has_object_permission(self, request, view, obj): + return request.user.is_staff or (request.user == obj.owner.organization_user.user) class MockView(APIView): @@ -9,3 +19,15 @@ class MockView(APIView): def get(self, request): return Response({'a': 1, 'b': 2, 'c': 3}) + + +class BasicModelWithUsersViewSet(ModelViewSet): + queryset = BasicModelWithUsers.objects.all() + serializer_class = BasicSerializer + permission_classes = [OrganizationPermissions] + # permission_classes = [IsAuthenticated, OrganizationPermissions] + renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer) + + def get_queryset(self): + qs = super().get_queryset().filter(users=self.request.user) + return qs diff --git a/tests/models.py b/tests/models.py index afe649760..666e9f003 100644 --- a/tests/models.py +++ b/tests/models.py @@ -1,5 +1,6 @@ import uuid +from django.contrib.auth.models import User from django.db import models from django.utils.translation import gettext_lazy as _ @@ -33,6 +34,10 @@ class ManyToManySource(RESTFrameworkModel): targets = models.ManyToManyField(ManyToManyTarget, related_name='sources') +class BasicModelWithUsers(RESTFrameworkModel): + users = models.ManyToManyField(User) + + # ForeignKey class ForeignKeyTarget(RESTFrameworkModel): name = models.CharField(max_length=100)