diff --git a/lib/tz.py b/lib/tz.py index e353f9b6..cc99025b 100644 --- a/lib/tz.py +++ b/lib/tz.py @@ -38,17 +38,36 @@ class FixedOffsetTimezone(datetime.tzinfo): with a small change to the `!__init__()` method to allow for pickling and a default name in the form ``sHH:MM`` (``s`` is the sign.). + The implementation also caches instances. During creation, if a + FixedOffsetTimezone instance has previously been created with the same + offset and name that instance will be returned. This saves memory and + improves comparability. + .. __: http://docs.python.org/library/datetime.html#datetime-tzinfo """ _name = None _offset = ZERO - + + _cache = {} + def __init__(self, offset=None, name=None): if offset is not None: self._offset = datetime.timedelta(minutes = offset) if name is not None: self._name = name + def __new__(cls, offset=None, name=None): + """Return a suitable instance created earlier if it exists + """ + key = (offset, name) + try: + return cls._cache[key] + except KeyError: + tz = datetime.tzinfo.__new__(cls, offset, name) + tz.__init__(offset, name) + cls._cache[key] = tz + return tz + def __repr__(self): offset_mins = self._offset.seconds // 60 + self._offset.days * 24 * 60 return "psycopg2.tz.FixedOffsetTimezone(offset=%r, name=%r)" \ diff --git a/tests/test_dates.py b/tests/test_dates.py index 026561a2..9be68a2c 100755 --- a/tests/test_dates.py +++ b/tests/test_dates.py @@ -532,6 +532,13 @@ class FixedOffsetTimezoneTests(unittest.TestCase): tzinfo = FixedOffsetTimezone(name="FOO") self.assertEqual(repr(tzinfo), "psycopg2.tz.FixedOffsetTimezone(offset=0, name='FOO')") + def test_instance_caching(self): + self.assert_(FixedOffsetTimezone(name="FOO") is FixedOffsetTimezone(name="FOO")) + self.assert_(FixedOffsetTimezone(7 * 60) is FixedOffsetTimezone(7 * 60)) + self.assert_(FixedOffsetTimezone(-9 * 60, 'FOO') is FixedOffsetTimezone(-9 * 60, 'FOO')) + self.assert_(FixedOffsetTimezone(9 * 60) is not FixedOffsetTimezone(9 * 60, 'FOO')) + self.assert_(FixedOffsetTimezone(name='FOO') is not FixedOffsetTimezone(9 * 60, 'FOO')) + def test_suite(): return unittest.TestLoader().loadTestsFromName(__name__)