diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py index 15c4b9105..66a335c1c 100644 --- a/rest_framework/utils/field_mapping.py +++ b/rest_framework/utils/field_mapping.py @@ -207,8 +207,10 @@ def get_field_kwargs(field_name, model_field): if isinstance(model_field, models.GenericIPAddressField): validator_kwarg = [ validator for validator in validator_kwarg - if validator is not validators.validate_ipv46_address + if validator not in [validators.validate_ipv46_address, validators.validate_ipv6_address, validators.validate_ipv4_address] ] + kwargs['protocol'] = getattr(model_field, 'protocol', 'both') + # Our decimal validation is handled in the field code, not validator code. if isinstance(model_field, models.DecimalField): validator_kwarg = [ diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py index eac51ae70..2b56d1060 100644 --- a/tests/test_model_serializer.py +++ b/tests/test_model_serializer.py @@ -24,6 +24,7 @@ from django.test import TestCase from rest_framework import serializers from rest_framework.compat import postgres_fields +from rest_framework.exceptions import ValidationError from .models import NestedForeignKeySource @@ -404,21 +405,135 @@ class TestDurationFieldMapping(TestCase): class TestGenericIPAddressFieldValidation(TestCase): - def test_ip_address_validation(self): - class IPAddressFieldModel(models.Model): - address = models.GenericIPAddressField() - class TestSerializer(serializers.ModelSerializer): + def setUp(self): + class IPv4Model(models.Model): + address = models.GenericIPAddressField(protocol="IPv4") + + class IPv4TestSerializer(serializers.ModelSerializer): class Meta: - model = IPAddressFieldModel + model = IPv4Model fields = '__all__' - s = TestSerializer(data={'address': 'not an ip address'}) + class IPv6Model(models.Model): + address = models.GenericIPAddressField(protocol="IPv6") + + class IPv6TestSerializer(serializers.ModelSerializer): + class Meta: + model = IPv6Model + fields = '__all__' + + class BothProtocolsModel(models.Model): + address = models.GenericIPAddressField(protocol="both") + + class BothProtocolsTestSerializer(serializers.ModelSerializer): + class Meta: + model = BothProtocolsModel + fields = '__all__' + + self.ipv4_serializer = IPv4TestSerializer + self.ipv4_model = IPv4Model + self.ipv6_serializer = IPv6TestSerializer + self.ipv6_model = IPv6Model + self.both_protocols_serializer = BothProtocolsTestSerializer + self.both_protocols_model = BothProtocolsModel + + def test_ip_address_validation(self): + s = self.both_protocols_serializer(data={'address': 'not an ip address'}) self.assertFalse(s.is_valid()) self.assertEqual(1, len(s.errors['address']), 'Unexpected number of validation errors: ' '{}'.format(s.errors)) + def test_invalid_ipv4_for_ipv4_field(self): + """Test that an invalid IPv4 raises only an IPv4-related error.""" + invalid_data = {"address": "invalid-ip"} + serializer = self.ipv4_serializer(data=invalid_data) + + with self.assertRaises(ValidationError) as context: + serializer.is_valid(raise_exception=True) + + self.assertEqual( + str(context.exception.detail["address"][0]), + "Enter a valid IPv4 address." + ) + + def test_invalid_ipv6_for_ipv6_field(self): + """Test that an invalid IPv6 raises only an IPv6-related error.""" + invalid_data = {"address": "invalid-ip"} + serializer = self.ipv6_serializer(data=invalid_data) + + with self.assertRaises(ValidationError) as context: + serializer.is_valid(raise_exception=True) + + self.assertEqual( + str(context.exception.detail["address"][0]), + "Enter a valid IPv6 address." + ) + + def test_invalid_ipv6_message_v1(self): + """Test that an invalid IPv6 raises error message when data contains ':' in it.""" + invalid_data = {"address": "invalid : data"} + serializer = self.ipv6_serializer(data=invalid_data) + + with self.assertRaises(ValidationError) as context: + serializer.is_valid(raise_exception=True) + + self.assertEqual( + str(context.exception.detail["address"][0]), + "Enter a valid IPv4 or IPv6 address." + ) + + def test_invalid_ipv6_message_v2(self): + """Test that an invalid IPv6 raises error message when data doesn't contains ':' in it.""" + invalid_data = {"address": "invalid-ip"} + serializer = self.ipv6_serializer(data=invalid_data) + + with self.assertRaises(ValidationError) as context: + serializer.is_valid(raise_exception=True) + + self.assertEqual( + str(context.exception.detail["address"][0]), + "Enter a valid IPv6 address." + ) + + def test_invalid_both_protocol(self): + """Test that an invalid IP raises a combined error message when protocol is both.""" + invalid_data = {"address": "invalid-ip"} + serializer = self.both_protocols_serializer(data=invalid_data) + + with self.assertRaises(ValidationError) as context: + serializer.is_valid(raise_exception=True) + + self.assertEqual( + str(context.exception.detail["address"][0]), + "Enter a valid IPv4 or IPv6 address." + ) + + def test_valid_ipv4(self): + """Test that a valid IPv4 passes validation.""" + valid_data = {"address": "192.168.1.1"} + serializer = self.ipv4_serializer(data=valid_data) + self.assertTrue(serializer.is_valid()) + + def test_valid_ipv6(self): + """Test that a valid IPv6 passes validation.""" + valid_data = {"address": "2001:db8::ff00:42:8329"} + serializer = self.ipv6_serializer(data=valid_data) + self.assertTrue(serializer.is_valid()) + + def test_valid_ipv4_for_both_protocol(self): + """Test that a valid IPv4 is accepted when protocol is 'both'.""" + valid_data = {"address": "192.168.1.1"} + serializer = self.both_protocols_serializer(data=valid_data) + self.assertTrue(serializer.is_valid()) + + def test_valid_ipv6_for_both_protocol(self): + """Test that a valid IPv6 is accepted when protocol is 'both'.""" + valid_data = {"address": "2001:db8::ff00:42:8329"} + serializer = self.both_protocols_serializer(data=valid_data) + self.assertTrue(serializer.is_valid()) + @pytest.mark.skipif('not postgres_fields') class TestPosgresFieldsMapping(TestCase):