From 769bc1336fd5d6a7fcf10d8be3b374c3e7a21bb3 Mon Sep 17 00:00:00 2001 From: Daniel Hahler Date: Tue, 30 Jan 2018 08:45:09 +0100 Subject: [PATCH] ErrorDetail: add __eq__/__ne__ and __repr__ (#5787) This adds `__eq__` to handle `code` in comparisons. When comparing an ErrorDetail to a string (missing `code` there) the ErrorDetail's `code` is ignored, but otherwise it is taken into account. --- rest_framework/exceptions.py | 17 +++++++++++++++++ tests/test_exceptions.py | 27 +++++++++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index d885ba643..492872ae5 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -14,6 +14,7 @@ from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ungettext from rest_framework import status +from rest_framework.compat import unicode_to_repr from rest_framework.utils.serializer_helpers import ReturnDict, ReturnList @@ -73,6 +74,22 @@ class ErrorDetail(six.text_type): self.code = code return self + def __eq__(self, other): + r = super(ErrorDetail, self).__eq__(other) + try: + return r and self.code == other.code + except AttributeError: + return r + + def __ne__(self, other): + return not self.__eq__(other) + + def __repr__(self): + return unicode_to_repr('ErrorDetail(string=%r, code=%r)' % ( + six.text_type(self), + self.code, + )) + class APIException(Exception): """ diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 176aeb174..006191a49 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -53,6 +53,33 @@ class ExceptionTestCase(TestCase): 'code': 'throttled'} +class ErrorDetailTests(TestCase): + + def test_eq(self): + assert ErrorDetail('msg') == ErrorDetail('msg') + assert ErrorDetail('msg', 'code') == ErrorDetail('msg', code='code') + + assert ErrorDetail('msg') == 'msg' + assert ErrorDetail('msg', 'code') == 'msg' + + def test_ne(self): + assert ErrorDetail('msg1') != ErrorDetail('msg2') + assert ErrorDetail('msg') != ErrorDetail('msg', code='invalid') + + assert ErrorDetail('msg1') != 'msg2' + assert ErrorDetail('msg1', 'code') != 'msg2' + + def test_repr(self): + assert repr(ErrorDetail('msg1')) == \ + 'ErrorDetail(string={!r}, code=None)'.format('msg1') + assert repr(ErrorDetail('msg1', 'code')) == \ + 'ErrorDetail(string={!r}, code={!r})'.format('msg1', 'code') + + def test_str(self): + assert str(ErrorDetail('msg1')) == 'msg1' + assert str(ErrorDetail('msg1', 'code')) == 'msg1' + + class TranslationTests(TestCase): @translation.override('fr')