From f471f5871fca9b09c302aabf83f0683d078e40bf Mon Sep 17 00:00:00 2001 From: Ghosts6 Date: Sat, 19 Oct 2024 19:41:07 +0100 Subject: [PATCH] Add custom validators, permission classes, and PermissionTestModel for testing permissions --- rest_framework/fields.py | 40 +++++++++++++ rest_framework/permissions.py | 25 ++++++++ tests/models.py | 8 +++ tests/test_fields.py | 108 +++++++++++++++++++++++++++++++++- tests/test_permissions.py | 50 +++++++++++++++- 5 files changed, 229 insertions(+), 2 deletions(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 6989edc0a..1b6945f9f 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -882,6 +882,46 @@ class IPAddressField(CharField): return super().to_internal_value(data) +class AlphabeticFieldValidator: + """ + Custom validator to ensure that a field only contains alphabetic characters and spaces. + """ + def __call__(self, value): + if not isinstance(value, str): + raise ValueError("This field must be a string.") + if value == "": + raise ValueError("This field must contain only alphabetic characters and spaces.") + if not re.match(r'^[A-Za-z ]*$', value): + raise ValueError("This field must contain only alphabetic characters and spaces.") + +class AlphanumericFieldValidator: + """ + Custom validator to ensure the field contains only alphanumeric characters (letters and numbers). + """ + def __call__(self, value): + if not isinstance(value, str): + raise ValueError("This field must be a string.") + if value == "": + raise ValueError("This field must contain only alphanumeric characters (letters and numbers).") + if not re.match(r'^[A-Za-z0-9]*$', value): + raise ValueError("This field must contain only alphanumeric characters (letters and numbers).") + +class CustomLengthValidator: + """ + Custom validator to ensure the length of a string is within specified limits. + """ + def __init__(self, min_length=0, max_length=None): + self.min_length = min_length + self.max_length = max_length + + def __call__(self, value): + if len(value) < self.min_length: + raise ValueError(f"This field must be at least {self.min_length} characters long.") + + if self.max_length is not None and len(value) > self.max_length: + raise ValueError(f"This field must be no more than {self.max_length} characters long.") + + # Number types... class IntegerField(Field): diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py index 7c15eca58..3c0de3d67 100644 --- a/rest_framework/permissions.py +++ b/rest_framework/permissions.py @@ -172,6 +172,31 @@ class IsAuthenticatedOrReadOnly(BasePermission): request.user.is_authenticated ) +class IsAdminUserOrReadOnly(BasePermission): + """ + Custom permission to only allow admin users to edit an object. + """ + + def has_permission(self, request, view): + # Allow any user to view the object + if request.method in ['GET', 'HEAD', 'OPTIONS']: + return True + # Only allow admin users to modify the object + return request.user and request.user.is_staff + + +class IsOwner(BasePermission): + """ + Custom permission to only allow owners of an object to edit it. + """ + + def has_object_permission(self, request, view, obj): + # Allow read-only access to any request + if request.method in ['GET', 'HEAD', 'OPTIONS']: + return True + # Write permissions are only allowed to the owner of the object + return obj.owner == request.user + class DjangoModelPermissions(BasePermission): """ diff --git a/tests/models.py b/tests/models.py index 88e3d8dca..1c5d58c1d 100644 --- a/tests/models.py +++ b/tests/models.py @@ -150,3 +150,11 @@ class CustomManagerModel(RESTFrameworkModel): help_text='OneToOneTarget', verbose_name='OneToOneTarget', on_delete=models.CASCADE) + + +class OwnershipTestModel(models.Model): + owner = models.ForeignKey(User, on_delete=models.CASCADE, related_name='ownership_test_models') + title = models.CharField(max_length=100) + + def __str__(self): + return self.title \ No newline at end of file diff --git a/tests/test_fields.py b/tests/test_fields.py index 1403a6a35..7ac4a0857 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -25,7 +25,7 @@ from django.utils.timezone import activate, deactivate, override import rest_framework from rest_framework import exceptions, serializers from rest_framework.fields import ( - BuiltinSignatureError, DjangoImageField, SkipField, empty, + AlphabeticFieldValidator, AlphanumericFieldValidator, BuiltinSignatureError, CustomLengthValidator, DjangoImageField, SkipField, empty, is_simple_callable ) from tests.models import UUIDForeignKeyTarget @@ -1061,6 +1061,112 @@ class TestFilePathField(FieldValues): ) +class TestAlphabeticField: + valid_inputs = { + 'John Doe': 'John Doe', + 'Alice': 'Alice', + 'Bob Marley': 'Bob Marley', + } + invalid_inputs = { + 'John123': ['This field must contain only alphabetic characters and spaces.'], + 'Alice!': ['This field must contain only alphabetic characters and spaces.'], + '': ['This field must contain only alphabetic characters and spaces.'], + } + non_string_inputs = [ + 123, # Integer + 45.67, # Float + None, # NoneType + [], # Empty list + {}, # Empty dict + set() # Empty set + ] + + def test_valid_inputs(self): + validator = AlphabeticFieldValidator() + for value in self.valid_inputs.keys(): + validator(value) + + def test_invalid_inputs(self): + validator = AlphabeticFieldValidator() + for value, expected_errors in self.invalid_inputs.items(): + with pytest.raises(ValueError) as excinfo: + validator(value) + assert str(excinfo.value) == expected_errors[0] + + def test_non_string_inputs(self): + validator = AlphabeticFieldValidator() + for value in self.non_string_inputs: + with pytest.raises(ValueError) as excinfo: + validator(value) + assert str(excinfo.value) == "This field must be a string." + + +class TestAlphanumericField: + valid_inputs = { + 'John123': 'John123', + 'Alice007': 'Alice007', + 'Bob1990': 'Bob1990', + } + invalid_inputs = { + 'John!': ['This field must contain only alphanumeric characters (letters and numbers).'], + 'Alice 007': ['This field must contain only alphanumeric characters (letters and numbers).'], + '': ['This field must contain only alphanumeric characters (letters and numbers).'], + } + non_string_inputs = [ + 123, # Integer + 45.67, # Float + None, # NoneType + [], # Empty list + {}, # Empty dict + set() # Empty set + ] + + def test_valid_inputs(self): + validator = AlphanumericFieldValidator() + for value in self.valid_inputs.keys(): + validator(value) + + def test_invalid_inputs(self): + validator = AlphanumericFieldValidator() + for value, expected_errors in self.invalid_inputs.items(): + with pytest.raises(ValueError) as excinfo: + validator(value) + assert str(excinfo.value) == expected_errors[0] + + def test_non_string_inputs(self): + validator = AlphanumericFieldValidator() + for value in self.non_string_inputs: + with pytest.raises(ValueError) as excinfo: + validator(value) + assert str(excinfo.value) == "This field must be a string." + +class TestCustomLengthField: + """ + Valid and invalid values for `CustomLengthValidator`. + """ + valid_inputs = { + 'abc': 'abc', # 3 characters + 'abcdefghij': 'abcdefghij', # 10 characters + } + invalid_inputs = { + 'ab': ['This field must be at least 3 characters long.'], # Too short + 'abcdefghijk': ['This field must be no more than 10 characters long.'], # Too long + } + field = str + + def test_valid_inputs(self): + validator = CustomLengthValidator(min_length=3, max_length=10) + for value in self.valid_inputs.keys(): + validator(value) + + def test_invalid_inputs(self): + validator = CustomLengthValidator(min_length=3, max_length=10) + for value, expected_errors in self.invalid_inputs.items(): + with pytest.raises(ValueError) as excinfo: + validator(value) + assert str(excinfo.value) == expected_errors[0] + + # Number types... class TestIntegerField(FieldValues): diff --git a/tests/test_permissions.py b/tests/test_permissions.py index 39b7ed662..417dbd3c8 100644 --- a/tests/test_permissions.py +++ b/tests/test_permissions.py @@ -14,7 +14,7 @@ from rest_framework import ( ) from rest_framework.routers import DefaultRouter from rest_framework.test import APIRequestFactory -from tests.models import BasicModel +from tests.models import BasicModel, OwnershipTestModel factory = APIRequestFactory() @@ -772,3 +772,51 @@ class PermissionsCompositionTests(TestCase): ] assert filtered_permissions == expected_permissions + + +class PermissionTests(TestCase): + def setUp(self): + self.factory = APIRequestFactory() + self.admin_user = User.objects.create_user(username='admin', password='password', is_staff=True) + self.regular_user = User.objects.create_user(username='user', password='password') + self.anonymous_user = AnonymousUser() + + def test_is_admin_user_or_read_only_allow_read(self): + request = self.factory.get('/1', format='json') + request.user = self.anonymous_user + permission = permissions.IsAdminUserOrReadOnly() + self.assertTrue(permission.has_permission(request, None)) + + request.user = self.admin_user + self.assertTrue(permission.has_permission(request, None)) + + def test_is_admin_user_or_read_only_allow_write(self): + request = self.factory.post('/1', format='json') + request.user = self.admin_user + permission = permissions.IsAdminUserOrReadOnly() + self.assertTrue(permission.has_permission(request, None)) + + request.user = self.regular_user + self.assertFalse(permission.has_permission(request, None)) + + def test_is_owner_permission(self): + obj = OwnershipTestModel.objects.create(owner=self.admin_user, title='Test Title') + + request = self.factory.post('/1', format='json') + request.user = self.admin_user + permission = permissions.IsOwner() + self.assertTrue(permission.has_object_permission(request, None, obj)) + + request.user = self.regular_user + self.assertFalse(permission.has_object_permission(request, None, obj)) + + def test_is_owner_read_access(self): + obj = OwnershipTestModel.objects.create(owner=self.admin_user, title='Test Title') + + request = self.factory.get('/1', format='json') + request.user = self.regular_user + permission = permissions.IsOwner() + self.assertTrue(permission.has_object_permission(request, None, obj)) + + request.user = self.admin_user + self.assertTrue(permission.has_object_permission(request, None, obj)) \ No newline at end of file