From 0d6ef034d2eed788f4fe6f9721148bf3874802ec Mon Sep 17 00:00:00 2001 From: Maxwell Muoto Date: Sun, 9 Apr 2023 03:53:47 -0500 Subject: [PATCH] Implement `__eq__` for validators (#8925) * Implement equality operator and add test coverage * Add documentation on implementation --- docs/api-guide/validators.md | 2 +- rest_framework/validators.py | 37 ++++++++++++++++++++++++++++++++++++ tests/test_validators.py | 11 +++++++++++ 3 files changed, 49 insertions(+), 1 deletion(-) diff --git a/docs/api-guide/validators.md b/docs/api-guide/validators.md index bb8466a2c..dac937d9b 100644 --- a/docs/api-guide/validators.md +++ b/docs/api-guide/validators.md @@ -53,7 +53,7 @@ If we open up the Django shell using `manage.py shell` we can now The interesting bit here is the `reference` field. We can see that the uniqueness constraint is being explicitly enforced by a validator on the serializer field. -Because of this more explicit style REST framework includes a few validator classes that are not available in core Django. These classes are detailed below. +Because of this more explicit style REST framework includes a few validator classes that are not available in core Django. These classes are detailed below. REST framework validators, like their Django counterparts, implement the `__eq__` method, allowing you to compare instances for equality. --- diff --git a/rest_framework/validators.py b/rest_framework/validators.py index a5cb75a84..07ad11b47 100644 --- a/rest_framework/validators.py +++ b/rest_framework/validators.py @@ -79,6 +79,15 @@ class UniqueValidator: smart_repr(self.queryset) ) + def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return (self.message == other.message + and self.requires_context == other.requires_context + and self.queryset == other.queryset + and self.lookup == other.lookup + ) + class UniqueTogetherValidator: """ @@ -166,6 +175,16 @@ class UniqueTogetherValidator: smart_repr(self.fields) ) + def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return (self.message == other.message + and self.requires_context == other.requires_context + and self.missing_message == other.missing_message + and self.queryset == other.queryset + and self.fields == other.fields + ) + class ProhibitSurrogateCharactersValidator: message = _('Surrogate characters are not allowed: U+{code_point:X}.') @@ -177,6 +196,13 @@ class ProhibitSurrogateCharactersValidator: message = self.message.format(code_point=ord(surrogate_character)) raise ValidationError(message, code=self.code) + def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return (self.message == other.message + and self.code == other.code + ) + class BaseUniqueForValidator: message = None @@ -230,6 +256,17 @@ class BaseUniqueForValidator: self.field: message }, code='unique') + def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return (self.message == other.message + and self.missing_message == other.missing_message + and self.requires_context == other.requires_context + and self.queryset == other.queryset + and self.field == other.field + and self.date_field == other.date_field + ) + def __repr__(self): return '<%s(queryset=%s, field=%s, date_field=%s)>' % ( self.__class__.__name__, diff --git a/tests/test_validators.py b/tests/test_validators.py index 35fef6f26..1cf42ed07 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -1,4 +1,5 @@ import datetime +from unittest.mock import MagicMock import pytest from django.db import DataError, models @@ -787,3 +788,13 @@ class ValidatorsTests(TestCase): validator.filter_queryset( attrs=None, queryset=None, field_name='', date_field_name='' ) + + def test_equality_operator(self): + mock_queryset = MagicMock() + validator = BaseUniqueForValidator(queryset=mock_queryset, field='foo', + date_field='bar') + validator2 = BaseUniqueForValidator(queryset=mock_queryset, field='foo', + date_field='bar') + assert validator == validator2 + validator2.date_field = "bar2" + assert validator != validator2