diff --git a/tests/test_types_basic.py b/tests/test_types_basic.py index d0c0c919..a0f2d7c0 100755 --- a/tests/test_types_basic.py +++ b/tests/test_types_basic.py @@ -30,7 +30,7 @@ import platform from . import testutils import unittest -from .testutils import PY2, long, text_type, ConnectingTestCase +from .testutils import PY2, long, text_type, ConnectingTestCase, restore_types import psycopg2 from psycopg2.extensions import AsIs, adapt, register_adapter @@ -430,6 +430,7 @@ class AdaptSubclassTest(unittest.TestCase): s2 = Sub(s1) self.assertEqual(adapt(s1).getquoted(), adapt(s2).getquoted()) + @restore_types def test_adapt_most_specific(self): class A(object): pass @@ -442,13 +443,10 @@ class AdaptSubclassTest(unittest.TestCase): register_adapter(A, lambda a: AsIs("a")) register_adapter(B, lambda b: AsIs("b")) - try: - self.assertEqual(b'b', adapt(C()).getquoted()) - finally: - del psycopg2.extensions.adapters[A, psycopg2.extensions.ISQLQuote] - del psycopg2.extensions.adapters[B, psycopg2.extensions.ISQLQuote] + self.assertEqual(b'b', adapt(C()).getquoted()) @testutils.skip_from_python(3) + @restore_types def test_no_mro_no_joy(self): class A: pass @@ -457,12 +455,10 @@ class AdaptSubclassTest(unittest.TestCase): pass register_adapter(A, lambda a: AsIs("a")) - try: - self.assertRaises(psycopg2.ProgrammingError, adapt, B()) - finally: - del psycopg2.extensions.adapters[A, psycopg2.extensions.ISQLQuote] + self.assertRaises(psycopg2.ProgrammingError, adapt, B()) @testutils.skip_before_python(3) + @restore_types def test_adapt_subtype_3(self): class A: pass @@ -471,10 +467,7 @@ class AdaptSubclassTest(unittest.TestCase): pass register_adapter(A, lambda a: AsIs("a")) - try: - self.assertEqual(b"a", adapt(B()).getquoted()) - finally: - del psycopg2.extensions.adapters[A, psycopg2.extensions.ISQLQuote] + self.assertEqual(b"a", adapt(B()).getquoted()) def test_conform_subclass_precedence(self): class foo(tuple): diff --git a/tests/test_types_extras.py b/tests/test_types_extras.py index 2b24ce25..8be629f2 100755 --- a/tests/test_types_extras.py +++ b/tests/test_types_extras.py @@ -25,7 +25,8 @@ from pickle import dumps, loads import unittest from .testutils import (PY2, text_type, skip_if_no_uuid, skip_before_postgres, - ConnectingTestCase, py3_raises_typeerror, slow, skip_from_python) + ConnectingTestCase, py3_raises_typeerror, slow, skip_from_python, + restore_types) import psycopg2 import psycopg2.extras @@ -75,6 +76,7 @@ class TypesExtrasTests(ConnectingTestCase): s = self.execute("SELECT '{}'::uuid[] AS foo") self.failUnless(type(s) == list and len(s) == 0) + @restore_types def testINET(self): with warnings.catch_warnings(): warnings.simplefilter('ignore', DeprecationWarning) @@ -87,6 +89,7 @@ class TypesExtrasTests(ConnectingTestCase): s = self.execute("SELECT NULL::inet AS foo") self.failUnless(s is None) + @restore_types def testINETARRAY(self): with warnings.catch_warnings(): warnings.simplefilter('ignore', DeprecationWarning) @@ -258,26 +261,18 @@ class HstoreTestCase(ConnectingTestCase): self.assert_(isinstance(t[2].values()[0], unicode)) @skip_if_no_hstore + @restore_types def test_register_globally(self): - oids = HstoreAdapter.get_oids(self.conn) + HstoreAdapter.get_oids(self.conn) + register_hstore(self.conn, globally=True) + conn2 = self.connect() try: - register_hstore(self.conn, globally=True) - conn2 = self.connect() - try: - cur2 = self.conn.cursor() - cur2.execute("select 'a => b'::hstore") - r = cur2.fetchone() - self.assert_(isinstance(r[0], dict)) - finally: - conn2.close() + cur2 = self.conn.cursor() + cur2.execute("select 'a => b'::hstore") + r = cur2.fetchone() + self.assert_(isinstance(r[0], dict)) finally: - psycopg2.extensions.string_types.pop(oids[0][0]) - - # verify the caster is not around anymore - cur = self.conn.cursor() - cur.execute("select 'a => b'::hstore") - r = cur.fetchone() - self.assert_(isinstance(r[0], str)) + conn2.close() @skip_if_no_hstore def test_roundtrip(self): @@ -332,6 +327,7 @@ class HstoreTestCase(ConnectingTestCase): ok(dict(zip(ab, ab))) @skip_if_no_hstore + @restore_types def test_oid(self): cur = self.conn.cursor() cur.execute("select 'hstore'::regtype::oid") @@ -340,15 +336,11 @@ class HstoreTestCase(ConnectingTestCase): # Note: None as conn_or_cursor is just for testing: not public # interface and it may break in future. register_hstore(None, globally=True, oid=oid) - try: - cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore") - t = cur.fetchone() - self.assert_(t[0] is None) - self.assertEqual(t[1], {}) - self.assertEqual(t[2], {'a': 'b'}) - - finally: - psycopg2.extensions.string_types.pop(oid) + cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore") + t = cur.fetchone() + self.assert_(t[0] is None) + self.assertEqual(t[1], {}) + self.assertEqual(t[2], {'a': 'b'}) @skip_if_no_hstore @skip_before_postgres(8, 3) @@ -385,25 +377,21 @@ class HstoreTestCase(ConnectingTestCase): self.assertEqual(a, [{'a': '1'}, {'b': '2'}]) @skip_if_no_hstore + @restore_types def test_array_cast_oid(self): cur = self.conn.cursor() cur.execute("select 'hstore'::regtype::oid, 'hstore[]'::regtype::oid") oid, aoid = cur.fetchone() register_hstore(None, globally=True, oid=oid, array_oid=aoid) - try: - cur.execute(""" - select null::hstore, ''::hstore, - 'a => b'::hstore, '{a=>b}'::hstore[]""") - t = cur.fetchone() - self.assert_(t[0] is None) - self.assertEqual(t[1], {}) - self.assertEqual(t[2], {'a': 'b'}) - self.assertEqual(t[3], [{'a': 'b'}]) - - finally: - psycopg2.extensions.string_types.pop(oid) - psycopg2.extensions.string_types.pop(aoid) + cur.execute(""" + select null::hstore, ''::hstore, + 'a => b'::hstore, '{a=>b}'::hstore[]""") + t = cur.fetchone() + self.assert_(t[0] is None) + self.assertEqual(t[1], {}) + self.assertEqual(t[2], {'a': 'b'}) + self.assertEqual(t[3], [{'a': 'b'}]) @skip_if_no_hstore def test_non_dbapi_connection(self): @@ -598,27 +586,20 @@ class AdaptTypeTestCase(ConnectingTestCase): conn2.close() @skip_if_no_composite + @restore_types def test_register_globally(self): self._create_type("type_ii", [("a", "integer"), ("b", "integer")]) conn1 = self.connect() conn2 = self.connect() try: - t = psycopg2.extras.register_composite("type_ii", conn1, globally=True) - try: - curs1 = conn1.cursor() - curs2 = conn2.cursor() - curs1.execute("select (1,2)::type_ii") - self.assertEqual(curs1.fetchone()[0], (1, 2)) - curs2.execute("select (1,2)::type_ii") - self.assertEqual(curs2.fetchone()[0], (1, 2)) - finally: - # drop the registered typecasters to help the refcounting - # script to return precise values. - del psycopg2.extensions.string_types[t.typecaster.values[0]] - if t.array_typecaster: - del psycopg2.extensions.string_types[ - t.array_typecaster.values[0]] + psycopg2.extras.register_composite("type_ii", conn1, globally=True) + curs1 = conn1.cursor() + curs2 = conn2.cursor() + curs1.execute("select (1,2)::type_ii") + self.assertEqual(curs1.fetchone()[0], (1, 2)) + curs2.execute("select (1,2)::type_ii") + self.assertEqual(curs2.fetchone()[0], (1, 2)) finally: conn1.close() @@ -844,16 +825,14 @@ class JsonTestCase(ConnectingTestCase): obj = Decimal('123.45') self.assertQuotedEqual(curs.mogrify("%s", (MyJson(obj),)), b"'123.45'") + @restore_types def test_register_on_dict(self): psycopg2.extensions.register_adapter(dict, Json) - try: - curs = self.conn.cursor() - obj = {'a': 123} - self.assertQuotedEqual( - curs.mogrify("%s", (obj,)), b"""'{"a": 123}'""") - finally: - del psycopg2.extensions.adapters[dict, ext.ISQLQuote] + curs = self.conn.cursor() + obj = {'a': 123} + self.assertQuotedEqual( + curs.mogrify("%s", (obj,)), b"""'{"a": 123}'""") def test_type_not_available(self): curs = self.conn.cursor() @@ -889,21 +868,12 @@ class JsonTestCase(ConnectingTestCase): self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None}) @skip_if_no_json_type + @restore_types def test_register_globally(self): - old = psycopg2.extensions.string_types.get(114) - olda = psycopg2.extensions.string_types.get(199) - try: - new, newa = psycopg2.extras.register_json(self.conn, globally=True) - curs = self.conn.cursor() - curs.execute("""select '{"a": 100.0, "b": null}'::json""") - self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None}) - finally: - psycopg2.extensions.string_types.pop(new.values[0]) - psycopg2.extensions.string_types.pop(newa.values[0]) - if old: - psycopg2.extensions.register_type(old) - if olda: - psycopg2.extensions.register_type(olda) + new, newa = psycopg2.extras.register_json(self.conn, globally=True) + curs = self.conn.cursor() + curs.execute("""select '{"a": 100.0, "b": null}'::json""") + self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None}) @skip_if_no_json_type def test_loads(self): @@ -919,29 +889,20 @@ class JsonTestCase(ConnectingTestCase): self.assertEqual(data['a'], Decimal('100.0')) @skip_if_no_json_type + @restore_types def test_no_conn_curs(self): oid, array_oid = _get_json_oids(self.conn) - old = psycopg2.extensions.string_types.get(114) - olda = psycopg2.extensions.string_types.get(199) - def loads(s): return psycopg2.extras.json.loads(s, parse_float=Decimal) - try: - new, newa = psycopg2.extras.register_json( - loads=loads, oid=oid, array_oid=array_oid) - curs = self.conn.cursor() - curs.execute("""select '{"a": 100.0, "b": null}'::json""") - data = curs.fetchone()[0] - self.assert_(isinstance(data['a'], Decimal)) - self.assertEqual(data['a'], Decimal('100.0')) - finally: - psycopg2.extensions.string_types.pop(new.values[0]) - psycopg2.extensions.string_types.pop(newa.values[0]) - if old: - psycopg2.extensions.register_type(old) - if olda: - psycopg2.extensions.register_type(olda) + + new, newa = psycopg2.extras.register_json( + loads=loads, oid=oid, array_oid=array_oid) + curs = self.conn.cursor() + curs.execute("""select '{"a": 100.0, "b": null}'::json""") + data = curs.fetchone()[0] + self.assert_(isinstance(data['a'], Decimal)) + self.assertEqual(data['a'], Decimal('100.0')) @skip_before_postgres(9, 2) def test_register_default(self): @@ -1043,22 +1004,13 @@ class JsonbTestCase(ConnectingTestCase): curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""") self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None, 'test': 1}) + @restore_types def test_register_globally(self): - old = psycopg2.extensions.string_types.get(3802) - olda = psycopg2.extensions.string_types.get(3807) - try: - new, newa = psycopg2.extras.register_json(self.conn, - loads=self.myloads, globally=True, name='jsonb') - curs = self.conn.cursor() - curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""") - self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None, 'test': 1}) - finally: - psycopg2.extensions.string_types.pop(new.values[0]) - psycopg2.extensions.string_types.pop(newa.values[0]) - if old: - psycopg2.extensions.register_type(old) - if olda: - psycopg2.extensions.register_type(olda) + new, newa = psycopg2.extras.register_json(self.conn, + loads=self.myloads, globally=True, name='jsonb') + curs = self.conn.cursor() + curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""") + self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None, 'test': 1}) def test_loads(self): json = psycopg2.extras.json @@ -1558,6 +1510,7 @@ class RangeCasterTestCase(ConnectingTestCase): self.assert_(not r1.lower_inc) self.assert_(r1.upper_inc) + @restore_types def test_register_range_adapter(self): cur = self.conn.cursor() cur.execute("create type textrange as range (subtype=text)") @@ -1584,9 +1537,6 @@ class RangeCasterTestCase(ConnectingTestCase): self.assert_(not r1.lower_inc) self.assert_(r1.upper_inc) - # clear the adapters to allow precise count by scripts/refcounter.py - del ext.adapters[rc.range, ext.ISQLQuote] - def test_range_escaping(self): cur = self.conn.cursor() cur.execute("create type textrange as range (subtype=text)") @@ -1645,6 +1595,7 @@ class RangeCasterTestCase(ConnectingTestCase): self.assertRaises(psycopg2.ProgrammingError, register_range, 'nosuchrange', 'FailRange', cur) + @restore_types def test_schema_range(self): cur = self.conn.cursor() cur.execute("create schema rs") @@ -1654,10 +1605,10 @@ class RangeCasterTestCase(ConnectingTestCase): cur.execute("create type rs.r3 as range (subtype=text)") cur.execute("savepoint x") - ra1 = register_range('r1', 'r1', cur) + register_range('r1', 'r1', cur) ra2 = register_range('r2', 'r2', cur) rars2 = register_range('rs.r2', 'r2', cur) - rars3 = register_range('rs.r3', 'r3', cur) + register_range('rs.r3', 'r3', cur) self.assertNotEqual( ra2.typecaster.values[0], @@ -1671,10 +1622,6 @@ class RangeCasterTestCase(ConnectingTestCase): register_range, 'rs.r1', 'FailRange', cur) cur.execute("rollback to savepoint x;") - # clear the adapters to allow precise count by scripts/refcounter.py - for r in [ra1, ra2, rars2, rars3]: - del ext.adapters[r.range, ext.ISQLQuote] - def test_suite(): return unittest.TestLoader().loadTestsFromName(__name__) diff --git a/tests/testutils.py b/tests/testutils.py index 405da18a..996ceb7d 100644 --- a/tests/testutils.py +++ b/tests/testutils.py @@ -429,3 +429,20 @@ def slow(f): return self.skipTest("slow test") return f(self) return slow_ + + +def restore_types(f): + """Decorator to restore the adaptation system after running a test""" + @wraps(f) + def restore_types_(self): + types = psycopg2.extensions.string_types.copy() + adapters = psycopg2.extensions.adapters.copy() + try: + return f(self) + finally: + psycopg2.extensions.string_types.clear() + psycopg2.extensions.string_types.update(types) + psycopg2.extensions.adapters.clear() + psycopg2.extensions.adapters.update(adapters) + + return restore_types_