mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-24 10:34:03 +03:00
Add custom validators, permission classes, and PermissionTestModel for testing permissions
This commit is contained in:
parent
d3dd45b3f4
commit
f471f5871f
|
@ -882,6 +882,46 @@ class IPAddressField(CharField):
|
||||||
return super().to_internal_value(data)
|
return super().to_internal_value(data)
|
||||||
|
|
||||||
|
|
||||||
|
class AlphabeticFieldValidator:
|
||||||
|
"""
|
||||||
|
Custom validator to ensure that a field only contains alphabetic characters and spaces.
|
||||||
|
"""
|
||||||
|
def __call__(self, value):
|
||||||
|
if not isinstance(value, str):
|
||||||
|
raise ValueError("This field must be a string.")
|
||||||
|
if value == "":
|
||||||
|
raise ValueError("This field must contain only alphabetic characters and spaces.")
|
||||||
|
if not re.match(r'^[A-Za-z ]*$', value):
|
||||||
|
raise ValueError("This field must contain only alphabetic characters and spaces.")
|
||||||
|
|
||||||
|
class AlphanumericFieldValidator:
|
||||||
|
"""
|
||||||
|
Custom validator to ensure the field contains only alphanumeric characters (letters and numbers).
|
||||||
|
"""
|
||||||
|
def __call__(self, value):
|
||||||
|
if not isinstance(value, str):
|
||||||
|
raise ValueError("This field must be a string.")
|
||||||
|
if value == "":
|
||||||
|
raise ValueError("This field must contain only alphanumeric characters (letters and numbers).")
|
||||||
|
if not re.match(r'^[A-Za-z0-9]*$', value):
|
||||||
|
raise ValueError("This field must contain only alphanumeric characters (letters and numbers).")
|
||||||
|
|
||||||
|
class CustomLengthValidator:
|
||||||
|
"""
|
||||||
|
Custom validator to ensure the length of a string is within specified limits.
|
||||||
|
"""
|
||||||
|
def __init__(self, min_length=0, max_length=None):
|
||||||
|
self.min_length = min_length
|
||||||
|
self.max_length = max_length
|
||||||
|
|
||||||
|
def __call__(self, value):
|
||||||
|
if len(value) < self.min_length:
|
||||||
|
raise ValueError(f"This field must be at least {self.min_length} characters long.")
|
||||||
|
|
||||||
|
if self.max_length is not None and len(value) > self.max_length:
|
||||||
|
raise ValueError(f"This field must be no more than {self.max_length} characters long.")
|
||||||
|
|
||||||
|
|
||||||
# Number types...
|
# Number types...
|
||||||
|
|
||||||
class IntegerField(Field):
|
class IntegerField(Field):
|
||||||
|
|
|
@ -172,6 +172,31 @@ class IsAuthenticatedOrReadOnly(BasePermission):
|
||||||
request.user.is_authenticated
|
request.user.is_authenticated
|
||||||
)
|
)
|
||||||
|
|
||||||
|
class IsAdminUserOrReadOnly(BasePermission):
|
||||||
|
"""
|
||||||
|
Custom permission to only allow admin users to edit an object.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def has_permission(self, request, view):
|
||||||
|
# Allow any user to view the object
|
||||||
|
if request.method in ['GET', 'HEAD', 'OPTIONS']:
|
||||||
|
return True
|
||||||
|
# Only allow admin users to modify the object
|
||||||
|
return request.user and request.user.is_staff
|
||||||
|
|
||||||
|
|
||||||
|
class IsOwner(BasePermission):
|
||||||
|
"""
|
||||||
|
Custom permission to only allow owners of an object to edit it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def has_object_permission(self, request, view, obj):
|
||||||
|
# Allow read-only access to any request
|
||||||
|
if request.method in ['GET', 'HEAD', 'OPTIONS']:
|
||||||
|
return True
|
||||||
|
# Write permissions are only allowed to the owner of the object
|
||||||
|
return obj.owner == request.user
|
||||||
|
|
||||||
|
|
||||||
class DjangoModelPermissions(BasePermission):
|
class DjangoModelPermissions(BasePermission):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -150,3 +150,11 @@ class CustomManagerModel(RESTFrameworkModel):
|
||||||
help_text='OneToOneTarget',
|
help_text='OneToOneTarget',
|
||||||
verbose_name='OneToOneTarget',
|
verbose_name='OneToOneTarget',
|
||||||
on_delete=models.CASCADE)
|
on_delete=models.CASCADE)
|
||||||
|
|
||||||
|
|
||||||
|
class OwnershipTestModel(models.Model):
|
||||||
|
owner = models.ForeignKey(User, on_delete=models.CASCADE, related_name='ownership_test_models')
|
||||||
|
title = models.CharField(max_length=100)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.title
|
|
@ -25,7 +25,7 @@ from django.utils.timezone import activate, deactivate, override
|
||||||
import rest_framework
|
import rest_framework
|
||||||
from rest_framework import exceptions, serializers
|
from rest_framework import exceptions, serializers
|
||||||
from rest_framework.fields import (
|
from rest_framework.fields import (
|
||||||
BuiltinSignatureError, DjangoImageField, SkipField, empty,
|
AlphabeticFieldValidator, AlphanumericFieldValidator, BuiltinSignatureError, CustomLengthValidator, DjangoImageField, SkipField, empty,
|
||||||
is_simple_callable
|
is_simple_callable
|
||||||
)
|
)
|
||||||
from tests.models import UUIDForeignKeyTarget
|
from tests.models import UUIDForeignKeyTarget
|
||||||
|
@ -1061,6 +1061,112 @@ class TestFilePathField(FieldValues):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAlphabeticField:
|
||||||
|
valid_inputs = {
|
||||||
|
'John Doe': 'John Doe',
|
||||||
|
'Alice': 'Alice',
|
||||||
|
'Bob Marley': 'Bob Marley',
|
||||||
|
}
|
||||||
|
invalid_inputs = {
|
||||||
|
'John123': ['This field must contain only alphabetic characters and spaces.'],
|
||||||
|
'Alice!': ['This field must contain only alphabetic characters and spaces.'],
|
||||||
|
'': ['This field must contain only alphabetic characters and spaces.'],
|
||||||
|
}
|
||||||
|
non_string_inputs = [
|
||||||
|
123, # Integer
|
||||||
|
45.67, # Float
|
||||||
|
None, # NoneType
|
||||||
|
[], # Empty list
|
||||||
|
{}, # Empty dict
|
||||||
|
set() # Empty set
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_valid_inputs(self):
|
||||||
|
validator = AlphabeticFieldValidator()
|
||||||
|
for value in self.valid_inputs.keys():
|
||||||
|
validator(value)
|
||||||
|
|
||||||
|
def test_invalid_inputs(self):
|
||||||
|
validator = AlphabeticFieldValidator()
|
||||||
|
for value, expected_errors in self.invalid_inputs.items():
|
||||||
|
with pytest.raises(ValueError) as excinfo:
|
||||||
|
validator(value)
|
||||||
|
assert str(excinfo.value) == expected_errors[0]
|
||||||
|
|
||||||
|
def test_non_string_inputs(self):
|
||||||
|
validator = AlphabeticFieldValidator()
|
||||||
|
for value in self.non_string_inputs:
|
||||||
|
with pytest.raises(ValueError) as excinfo:
|
||||||
|
validator(value)
|
||||||
|
assert str(excinfo.value) == "This field must be a string."
|
||||||
|
|
||||||
|
|
||||||
|
class TestAlphanumericField:
|
||||||
|
valid_inputs = {
|
||||||
|
'John123': 'John123',
|
||||||
|
'Alice007': 'Alice007',
|
||||||
|
'Bob1990': 'Bob1990',
|
||||||
|
}
|
||||||
|
invalid_inputs = {
|
||||||
|
'John!': ['This field must contain only alphanumeric characters (letters and numbers).'],
|
||||||
|
'Alice 007': ['This field must contain only alphanumeric characters (letters and numbers).'],
|
||||||
|
'': ['This field must contain only alphanumeric characters (letters and numbers).'],
|
||||||
|
}
|
||||||
|
non_string_inputs = [
|
||||||
|
123, # Integer
|
||||||
|
45.67, # Float
|
||||||
|
None, # NoneType
|
||||||
|
[], # Empty list
|
||||||
|
{}, # Empty dict
|
||||||
|
set() # Empty set
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_valid_inputs(self):
|
||||||
|
validator = AlphanumericFieldValidator()
|
||||||
|
for value in self.valid_inputs.keys():
|
||||||
|
validator(value)
|
||||||
|
|
||||||
|
def test_invalid_inputs(self):
|
||||||
|
validator = AlphanumericFieldValidator()
|
||||||
|
for value, expected_errors in self.invalid_inputs.items():
|
||||||
|
with pytest.raises(ValueError) as excinfo:
|
||||||
|
validator(value)
|
||||||
|
assert str(excinfo.value) == expected_errors[0]
|
||||||
|
|
||||||
|
def test_non_string_inputs(self):
|
||||||
|
validator = AlphanumericFieldValidator()
|
||||||
|
for value in self.non_string_inputs:
|
||||||
|
with pytest.raises(ValueError) as excinfo:
|
||||||
|
validator(value)
|
||||||
|
assert str(excinfo.value) == "This field must be a string."
|
||||||
|
|
||||||
|
class TestCustomLengthField:
|
||||||
|
"""
|
||||||
|
Valid and invalid values for `CustomLengthValidator`.
|
||||||
|
"""
|
||||||
|
valid_inputs = {
|
||||||
|
'abc': 'abc', # 3 characters
|
||||||
|
'abcdefghij': 'abcdefghij', # 10 characters
|
||||||
|
}
|
||||||
|
invalid_inputs = {
|
||||||
|
'ab': ['This field must be at least 3 characters long.'], # Too short
|
||||||
|
'abcdefghijk': ['This field must be no more than 10 characters long.'], # Too long
|
||||||
|
}
|
||||||
|
field = str
|
||||||
|
|
||||||
|
def test_valid_inputs(self):
|
||||||
|
validator = CustomLengthValidator(min_length=3, max_length=10)
|
||||||
|
for value in self.valid_inputs.keys():
|
||||||
|
validator(value)
|
||||||
|
|
||||||
|
def test_invalid_inputs(self):
|
||||||
|
validator = CustomLengthValidator(min_length=3, max_length=10)
|
||||||
|
for value, expected_errors in self.invalid_inputs.items():
|
||||||
|
with pytest.raises(ValueError) as excinfo:
|
||||||
|
validator(value)
|
||||||
|
assert str(excinfo.value) == expected_errors[0]
|
||||||
|
|
||||||
|
|
||||||
# Number types...
|
# Number types...
|
||||||
|
|
||||||
class TestIntegerField(FieldValues):
|
class TestIntegerField(FieldValues):
|
||||||
|
|
|
@ -14,7 +14,7 @@ from rest_framework import (
|
||||||
)
|
)
|
||||||
from rest_framework.routers import DefaultRouter
|
from rest_framework.routers import DefaultRouter
|
||||||
from rest_framework.test import APIRequestFactory
|
from rest_framework.test import APIRequestFactory
|
||||||
from tests.models import BasicModel
|
from tests.models import BasicModel, OwnershipTestModel
|
||||||
|
|
||||||
factory = APIRequestFactory()
|
factory = APIRequestFactory()
|
||||||
|
|
||||||
|
@ -772,3 +772,51 @@ class PermissionsCompositionTests(TestCase):
|
||||||
]
|
]
|
||||||
|
|
||||||
assert filtered_permissions == expected_permissions
|
assert filtered_permissions == expected_permissions
|
||||||
|
|
||||||
|
|
||||||
|
class PermissionTests(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.factory = APIRequestFactory()
|
||||||
|
self.admin_user = User.objects.create_user(username='admin', password='password', is_staff=True)
|
||||||
|
self.regular_user = User.objects.create_user(username='user', password='password')
|
||||||
|
self.anonymous_user = AnonymousUser()
|
||||||
|
|
||||||
|
def test_is_admin_user_or_read_only_allow_read(self):
|
||||||
|
request = self.factory.get('/1', format='json')
|
||||||
|
request.user = self.anonymous_user
|
||||||
|
permission = permissions.IsAdminUserOrReadOnly()
|
||||||
|
self.assertTrue(permission.has_permission(request, None))
|
||||||
|
|
||||||
|
request.user = self.admin_user
|
||||||
|
self.assertTrue(permission.has_permission(request, None))
|
||||||
|
|
||||||
|
def test_is_admin_user_or_read_only_allow_write(self):
|
||||||
|
request = self.factory.post('/1', format='json')
|
||||||
|
request.user = self.admin_user
|
||||||
|
permission = permissions.IsAdminUserOrReadOnly()
|
||||||
|
self.assertTrue(permission.has_permission(request, None))
|
||||||
|
|
||||||
|
request.user = self.regular_user
|
||||||
|
self.assertFalse(permission.has_permission(request, None))
|
||||||
|
|
||||||
|
def test_is_owner_permission(self):
|
||||||
|
obj = OwnershipTestModel.objects.create(owner=self.admin_user, title='Test Title')
|
||||||
|
|
||||||
|
request = self.factory.post('/1', format='json')
|
||||||
|
request.user = self.admin_user
|
||||||
|
permission = permissions.IsOwner()
|
||||||
|
self.assertTrue(permission.has_object_permission(request, None, obj))
|
||||||
|
|
||||||
|
request.user = self.regular_user
|
||||||
|
self.assertFalse(permission.has_object_permission(request, None, obj))
|
||||||
|
|
||||||
|
def test_is_owner_read_access(self):
|
||||||
|
obj = OwnershipTestModel.objects.create(owner=self.admin_user, title='Test Title')
|
||||||
|
|
||||||
|
request = self.factory.get('/1', format='json')
|
||||||
|
request.user = self.regular_user
|
||||||
|
permission = permissions.IsOwner()
|
||||||
|
self.assertTrue(permission.has_object_permission(request, None, obj))
|
||||||
|
|
||||||
|
request.user = self.admin_user
|
||||||
|
self.assertTrue(permission.has_object_permission(request, None, obj))
|
Loading…
Reference in New Issue
Block a user