Implement __eq__ for validators (#8925)

* Implement equality operator and add test coverage

* Add documentation on implementation
This commit is contained in:
Maxwell Muoto 2023-04-09 03:53:47 -05:00 committed by GitHub
parent b1cec517ff
commit 0d6ef034d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 49 additions and 1 deletions

View File

@ -53,7 +53,7 @@ If we open up the Django shell using `manage.py shell` we can now
The interesting bit here is the `reference` field. We can see that the uniqueness constraint is being explicitly enforced by a validator on the serializer field. The interesting bit here is the `reference` field. We can see that the uniqueness constraint is being explicitly enforced by a validator on the serializer field.
Because of this more explicit style REST framework includes a few validator classes that are not available in core Django. These classes are detailed below. Because of this more explicit style REST framework includes a few validator classes that are not available in core Django. These classes are detailed below. REST framework validators, like their Django counterparts, implement the `__eq__` method, allowing you to compare instances for equality.
--- ---

View File

@ -79,6 +79,15 @@ class UniqueValidator:
smart_repr(self.queryset) smart_repr(self.queryset)
) )
def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return (self.message == other.message
and self.requires_context == other.requires_context
and self.queryset == other.queryset
and self.lookup == other.lookup
)
class UniqueTogetherValidator: class UniqueTogetherValidator:
""" """
@ -166,6 +175,16 @@ class UniqueTogetherValidator:
smart_repr(self.fields) smart_repr(self.fields)
) )
def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return (self.message == other.message
and self.requires_context == other.requires_context
and self.missing_message == other.missing_message
and self.queryset == other.queryset
and self.fields == other.fields
)
class ProhibitSurrogateCharactersValidator: class ProhibitSurrogateCharactersValidator:
message = _('Surrogate characters are not allowed: U+{code_point:X}.') message = _('Surrogate characters are not allowed: U+{code_point:X}.')
@ -177,6 +196,13 @@ class ProhibitSurrogateCharactersValidator:
message = self.message.format(code_point=ord(surrogate_character)) message = self.message.format(code_point=ord(surrogate_character))
raise ValidationError(message, code=self.code) raise ValidationError(message, code=self.code)
def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return (self.message == other.message
and self.code == other.code
)
class BaseUniqueForValidator: class BaseUniqueForValidator:
message = None message = None
@ -230,6 +256,17 @@ class BaseUniqueForValidator:
self.field: message self.field: message
}, code='unique') }, code='unique')
def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return (self.message == other.message
and self.missing_message == other.missing_message
and self.requires_context == other.requires_context
and self.queryset == other.queryset
and self.field == other.field
and self.date_field == other.date_field
)
def __repr__(self): def __repr__(self):
return '<%s(queryset=%s, field=%s, date_field=%s)>' % ( return '<%s(queryset=%s, field=%s, date_field=%s)>' % (
self.__class__.__name__, self.__class__.__name__,

View File

@ -1,4 +1,5 @@
import datetime import datetime
from unittest.mock import MagicMock
import pytest import pytest
from django.db import DataError, models from django.db import DataError, models
@ -787,3 +788,13 @@ class ValidatorsTests(TestCase):
validator.filter_queryset( validator.filter_queryset(
attrs=None, queryset=None, field_name='', date_field_name='' attrs=None, queryset=None, field_name='', date_field_name=''
) )
def test_equality_operator(self):
mock_queryset = MagicMock()
validator = BaseUniqueForValidator(queryset=mock_queryset, field='foo',
date_field='bar')
validator2 = BaseUniqueForValidator(queryset=mock_queryset, field='foo',
date_field='bar')
assert validator == validator2
validator2.date_field = "bar2"
assert validator != validator2