From d469c325037d09de0c478f362b35d6035e3ac43a Mon Sep 17 00:00:00 2001 From: Chris Withers Date: Wed, 12 Feb 2014 08:11:59 +0000 Subject: [PATCH] Provide a stable and consistent sort order for Range objects. This matches postgres server-side behaviour and helps client applications that need to sort based on the primary key of tables where the primary key is or contains a range. --- lib/_range.py | 19 +++++++++--- tests/test_types_extras.py | 63 +++++++++++++++++++++++++++++++++++--- 2 files changed, 72 insertions(+), 10 deletions(-) diff --git a/lib/_range.py b/lib/_range.py index 0f159908..5a794103 100644 --- a/lib/_range.py +++ b/lib/_range.py @@ -133,12 +133,21 @@ class Range(object): def __hash__(self): return hash((self._lower, self._upper, self._bounds)) - def __lt__(self, other): - raise TypeError( - 'Range objects cannot be ordered; please refer to the PostgreSQL' - ' documentation to perform this operation in the database') + # as the postgres docs describe for the server-side stuff, + # ordering is rather arbitrary, but will remain stable + # and consistent. - __le__ = __gt__ = __ge__ = __lt__ + def __lt__(self, other): + if not isinstance(other, Range): + return False + return ((self._lower, self._upper, self._bounds) < + (other._lower, other._upper, other._bounds)) + + def __le__(self, other): + if not isinstance(other, Range): + return False + return ((self._lower, self._upper, self._bounds) <= + (other._lower, other._upper, other._bounds)) def register_range(pgrange, pyrange, conn_or_curs, globally=False): diff --git a/tests/test_types_extras.py b/tests/test_types_extras.py index 96ffcd3c..30ff9c7a 100755 --- a/tests/test_types_extras.py +++ b/tests/test_types_extras.py @@ -1225,12 +1225,65 @@ class RangeTestCase(unittest.TestCase): self.assertEqual(Range(10, 20), IntRange(10, 20)) self.assertEqual(PositiveIntRange(10, 20), IntRange(10, 20)) - def test_not_ordered(self): + # as the postgres docs describe for the server-side stuff, + # ordering is rather arbitrary, but will remain stable + # and consistent. + + def test_ordering_lt(self): from psycopg2.extras import Range - self.assertRaises(TypeError, lambda: Range(empty=True) < Range(0,4)) - self.assertRaises(TypeError, lambda: Range(1,2) > Range(0,4)) - self.assertRaises(TypeError, lambda: Range(1,2) <= Range()) - self.assertRaises(TypeError, lambda: Range(1,2) >= Range()) + self.assertTrue(Range(empty=True) < Range(0, 4)) + self.assertFalse(Range(1, 2) < Range(0, 4)) + self.assertTrue(Range(0, 4) < Range(1, 2)) + self.assertFalse(Range(1, 2) < Range()) + self.assertTrue(Range() < Range(1, 2)) + self.assertFalse(Range(1) < Range(upper=1)) + self.assertFalse(Range() < Range()) + self.assertFalse(Range(empty=True) < Range(empty=True)) + self.assertFalse(Range(1, 2) < Range(1, 2)) + self.assertTrue(1 < Range(1, 2)) + self.assertFalse(Range(1, 2) < 1) + + def test_ordering_gt(self): + from psycopg2.extras import Range + self.assertFalse(Range(empty=True) > Range(0, 4)) + self.assertTrue(Range(1, 2) > Range(0, 4)) + self.assertFalse(Range(0, 4) > Range(1, 2)) + self.assertTrue(Range(1, 2) > Range()) + self.assertFalse(Range() > Range(1, 2)) + self.assertTrue(Range(1) > Range(upper=1)) + self.assertFalse(Range() > Range()) + self.assertFalse(Range(empty=True) > Range(empty=True)) + self.assertFalse(Range(1, 2) > Range(1, 2)) + self.assertFalse(1 > Range(1, 2)) + self.assertTrue(Range(1, 2) > 1) + + def test_ordering_le(self): + from psycopg2.extras import Range + self.assertTrue(Range(empty=True) <= Range(0, 4)) + self.assertFalse(Range(1, 2) <= Range(0, 4)) + self.assertTrue(Range(0, 4) <= Range(1, 2)) + self.assertFalse(Range(1, 2) <= Range()) + self.assertTrue(Range() <= Range(1, 2)) + self.assertFalse(Range(1) <= Range(upper=1)) + self.assertTrue(Range() <= Range()) + self.assertTrue(Range(empty=True) <= Range(empty=True)) + self.assertTrue(Range(1, 2) <= Range(1, 2)) + self.assertTrue(1 <= Range(1, 2)) + self.assertFalse(Range(1, 2) <= 1) + + def test_ordering_ge(self): + from psycopg2.extras import Range + self.assertFalse(Range(empty=True) >= Range(0, 4)) + self.assertTrue(Range(1, 2) >= Range(0, 4)) + self.assertFalse(Range(0, 4) >= Range(1, 2)) + self.assertTrue(Range(1, 2) >= Range()) + self.assertFalse(Range() >= Range(1, 2)) + self.assertTrue(Range(1) >= Range(upper=1)) + self.assertTrue(Range() >= Range()) + self.assertTrue(Range(empty=True) >= Range(empty=True)) + self.assertTrue(Range(1, 2) >= Range(1, 2)) + self.assertFalse(1 >= Range(1, 2)) + self.assertTrue(Range(1, 2) >= 1) def skip_if_no_range(f):