diff --git a/lib/_range.py b/lib/_range.py index e7370a99..e0535506 100644 --- a/lib/_range.py +++ b/lib/_range.py @@ -55,7 +55,8 @@ class Range(object): """ __slots__ = ('_lower', '_upper', '_bounds') - def __init__(self, lower=None, upper=None, bounds='[)', empty=False): + def __new__(cls, lower=None, upper=None, bounds='[)', empty=False): + self = super(Range, cls).__new__(cls) if not empty: if bounds not in ('[)', '(]', '()', '[]'): raise ValueError("bound flags not valid: %r" % bounds) @@ -66,6 +67,7 @@ class Range(object): else: self._lower = self._upper = self._bounds = None + return self def __repr__(self): if self._bounds is None: @@ -131,6 +133,17 @@ class Range(object): return True + def __eq__(self, other): + return (self._lower == other._lower + and self._upper == other._upper + and self._bounds == other._bounds) + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + return hash((self._lower, self._upper, self._bounds)) + def register_range(pgrange, pyrange, conn_or_curs, globally=False): """Register a typecaster and an adapter for range a range type. diff --git a/tests/test_types_extras.py b/tests/test_types_extras.py index d55acb91..611cca30 100755 --- a/tests/test_types_extras.py +++ b/tests/test_types_extras.py @@ -888,6 +888,27 @@ class RangeTestCase(unittest.TestCase): self.assert_(20 not in r) self.assert_(21 not in r) + def test_eq_hash(self): + from psycopg2.extras import Range + def assert_equal(r1, r2): + self.assert_(r1 == r2) + self.assert_(hash(r1) == hash(r2)) + + assert_equal(Range(empty=True), Range(empty=True)) + assert_equal(Range(), Range()) + assert_equal(Range(10, None), Range(10, None)) + assert_equal(Range(10, 20), Range(10, 20)) + assert_equal(Range(10, 20), Range(10, 20, '[)')) + assert_equal(Range(10, 20, '[]'), Range(10, 20, '[]')) + + def assert_not_equal(r1, r2): + self.assert_(r1 != r2) + self.assert_(hash(r1) != hash(r2)) + + assert_not_equal(Range(10, 20), Range(10, 21)) + assert_not_equal(Range(10, 20), Range(11, 20)) + assert_not_equal(Range(10, 20, '[)'), Range(10, 20, '[]')) + def skip_if_no_range(f): def skip_if_no_range_(self):